Browse Source

Add missing type hints to tests. (#15027)

Patrick Cloke 1 year ago
parent
commit
4eed7b2ede
8 changed files with 70 additions and 76 deletions
  1. 1 0
      changelog.d/15027.misc
  2. 0 18
      mypy.ini
  3. 6 6
      tests/test_distributor.py
  4. 17 15
      tests/test_event_auth.py
  5. 22 13
      tests/test_mau.py
  6. 1 1
      tests/test_rust.py
  7. 8 8
      tests/test_test_utils.py
  8. 15 15
      tests/test_types.py

+ 1 - 0
changelog.d/15027.misc

@@ -0,0 +1 @@
+Improve type hints.

+ 0 - 18
mypy.ini

@@ -69,27 +69,9 @@ disallow_untyped_defs = False
 [mypy-tests.server_notices.test_resource_limits_server_notices]
 disallow_untyped_defs = False
 
-[mypy-tests.test_distributor]
-disallow_untyped_defs = False
-
-[mypy-tests.test_event_auth]
-disallow_untyped_defs = False
-
 [mypy-tests.test_federation]
 disallow_untyped_defs = False
 
-[mypy-tests.test_mau]
-disallow_untyped_defs = False
-
-[mypy-tests.test_rust]
-disallow_untyped_defs = False
-
-[mypy-tests.test_test_utils]
-disallow_untyped_defs = False
-
-[mypy-tests.test_types]
-disallow_untyped_defs = False
-
 [mypy-tests.test_utils.*]
 disallow_untyped_defs = False
 

+ 6 - 6
tests/test_distributor.py

@@ -21,10 +21,10 @@ from . import unittest
 
 
 class DistributorTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.dist = Distributor()
 
-    def test_signal_dispatch(self):
+    def test_signal_dispatch(self) -> None:
         self.dist.declare("alert")
 
         observer = Mock()
@@ -33,7 +33,7 @@ class DistributorTestCase(unittest.TestCase):
         self.dist.fire("alert", 1, 2, 3)
         observer.assert_called_with(1, 2, 3)
 
-    def test_signal_catch(self):
+    def test_signal_catch(self) -> None:
         self.dist.declare("alarm")
 
         observers = [Mock() for i in (1, 2)]
@@ -51,7 +51,7 @@ class DistributorTestCase(unittest.TestCase):
             self.assertEqual(mock_logger.warning.call_count, 1)
             self.assertIsInstance(mock_logger.warning.call_args[0][0], str)
 
-    def test_signal_prereg(self):
+    def test_signal_prereg(self) -> None:
         observer = Mock()
         self.dist.observe("flare", observer)
 
@@ -60,8 +60,8 @@ class DistributorTestCase(unittest.TestCase):
 
         observer.assert_called_with(4, 5)
 
-    def test_signal_undeclared(self):
-        def code():
+    def test_signal_undeclared(self) -> None:
+        def code() -> None:
             self.dist.fire("notification")
 
         self.assertRaises(KeyError, code)

+ 17 - 15
tests/test_event_auth.py

@@ -31,13 +31,13 @@ from tests.test_utils import get_awaitable_result
 class _StubEventSourceStore:
     """A stub implementation of the EventSourceStore"""
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._store: Dict[str, EventBase] = {}
 
-    def add_event(self, event: EventBase):
+    def add_event(self, event: EventBase) -> None:
         self._store[event.event_id] = event
 
-    def add_events(self, events: Iterable[EventBase]):
+    def add_events(self, events: Iterable[EventBase]) -> None:
         for event in events:
             self._store[event.event_id] = event
 
@@ -59,7 +59,7 @@ class _StubEventSourceStore:
 
 
 class EventAuthTestCase(unittest.TestCase):
-    def test_rejected_auth_events(self):
+    def test_rejected_auth_events(self) -> None:
         """
         Events that refer to rejected events in their auth events are rejected
         """
@@ -109,7 +109,7 @@ class EventAuthTestCase(unittest.TestCase):
                 )
             )
 
-    def test_create_event_with_prev_events(self):
+    def test_create_event_with_prev_events(self) -> None:
         """A create event with prev_events should be rejected
 
         https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -150,7 +150,7 @@ class EventAuthTestCase(unittest.TestCase):
                 event_auth.check_state_independent_auth_rules(event_store, bad_event)
             )
 
-    def test_duplicate_auth_events(self):
+    def test_duplicate_auth_events(self) -> None:
         """Events with duplicate auth_events should be rejected
 
         https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -196,7 +196,7 @@ class EventAuthTestCase(unittest.TestCase):
                 event_auth.check_state_independent_auth_rules(event_store, bad_event2)
             )
 
-    def test_unexpected_auth_events(self):
+    def test_unexpected_auth_events(self) -> None:
         """Events with excess auth_events should be rejected
 
         https://spec.matrix.org/v1.3/rooms/v9/#authorization-rules
@@ -236,7 +236,7 @@ class EventAuthTestCase(unittest.TestCase):
                 event_auth.check_state_independent_auth_rules(event_store, bad_event)
             )
 
-    def test_random_users_cannot_send_state_before_first_pl(self):
+    def test_random_users_cannot_send_state_before_first_pl(self) -> None:
         """
         Check that, before the first PL lands, the creator is the only user
         that can send a state event.
@@ -263,7 +263,7 @@ class EventAuthTestCase(unittest.TestCase):
             auth_events,
         )
 
-    def test_state_default_level(self):
+    def test_state_default_level(self) -> None:
         """
         Check that users above the state_default level can send state and
         those below cannot
@@ -298,7 +298,7 @@ class EventAuthTestCase(unittest.TestCase):
             auth_events,
         )
 
-    def test_alias_event(self):
+    def test_alias_event(self) -> None:
         """Alias events have special behavior up through room version 6."""
         creator = "@creator:example.com"
         other = "@other:example.com"
@@ -333,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase):
             auth_events,
         )
 
-    def test_msc2432_alias_event(self):
+    def test_msc2432_alias_event(self) -> None:
         """After MSC2432, alias events have no special behavior."""
         creator = "@creator:example.com"
         other = "@other:example.com"
@@ -366,7 +366,9 @@ class EventAuthTestCase(unittest.TestCase):
             )
 
     @parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)])
-    def test_notifications(self, room_version: RoomVersion, allow_modification: bool):
+    def test_notifications(
+        self, room_version: RoomVersion, allow_modification: bool
+    ) -> None:
         """
         Notifications power levels get checked due to MSC2209.
         """
@@ -395,7 +397,7 @@ class EventAuthTestCase(unittest.TestCase):
             with self.assertRaises(AuthError):
                 event_auth.check_state_dependent_auth_rules(pl_event, auth_events)
 
-    def test_join_rules_public(self):
+    def test_join_rules_public(self) -> None:
         """
         Test joining a public room.
         """
@@ -460,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase):
             auth_events.values(),
         )
 
-    def test_join_rules_invite(self):
+    def test_join_rules_invite(self) -> None:
         """
         Test joining an invite only room.
         """
@@ -835,7 +837,7 @@ def _power_levels_event(
     )
 
 
-def _alias_event(room_version: RoomVersion, sender: str, **kwargs) -> EventBase:
+def _alias_event(room_version: RoomVersion, sender: str, **kwargs: Any) -> EventBase:
     data = {
         "room_id": TEST_ROOM_ID,
         **_maybe_get_event_id_dict_for_room_version(room_version),

+ 22 - 13
tests/test_mau.py

@@ -14,12 +14,17 @@
 
 """Tests REST events for /rooms paths."""
 
-from typing import List
+from typing import List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.rest.client import register, sync
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 from tests.unittest import override_config
@@ -30,7 +35,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
 
     servlets = [register.register_servlets, sync.register_servlets]
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         config = default_config("test")
 
         config.update(
@@ -53,10 +58,12 @@ class TestMauLimit(unittest.HomeserverTestCase):
 
         return config
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         self.store = homeserver.get_datastores().main
 
-    def test_simple_deny_mau(self):
+    def test_simple_deny_mau(self) -> None:
         # Create and sync so that the MAU counts get updated
         token1 = self.create_user("kermit1")
         self.do_sync_for_user(token1)
@@ -75,7 +82,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.assertEqual(e.code, 403)
         self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
-    def test_as_ignores_mau(self):
+    def test_as_ignores_mau(self) -> None:
         """Test that application services can still create users when the MAU
         limit has been reached. This only works when application service
         user ip tracking is disabled.
@@ -113,7 +120,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
 
         self.create_user("as_kermit4", token=as_token, appservice=True)
 
-    def test_allowed_after_a_month_mau(self):
+    def test_allowed_after_a_month_mau(self) -> None:
         # Create and sync so that the MAU counts get updated
         token1 = self.create_user("kermit1")
         self.do_sync_for_user(token1)
@@ -132,7 +139,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.do_sync_for_user(token3)
 
     @override_config({"mau_trial_days": 1})
-    def test_trial_delay(self):
+    def test_trial_delay(self) -> None:
         # We should be able to register more than the limit initially
         token1 = self.create_user("kermit1")
         self.do_sync_for_user(token1)
@@ -165,7 +172,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
     @override_config({"mau_trial_days": 1})
-    def test_trial_users_cant_come_back(self):
+    def test_trial_users_cant_come_back(self) -> None:
         self.hs.config.server.mau_trial_days = 1
 
         # We should be able to register more than the limit initially
@@ -216,7 +223,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
         # max_mau_value should not matter
         {"max_mau_value": 1, "limit_usage_by_mau": False, "mau_stats_only": True}
     )
-    def test_tracked_but_not_limited(self):
+    def test_tracked_but_not_limited(self) -> None:
         # Simply being able to create 2 users indicates that the
         # limit was not reached.
         token1 = self.create_user("kermit1")
@@ -236,10 +243,10 @@ class TestMauLimit(unittest.HomeserverTestCase):
             "mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2},
         }
     )
-    def test_as_trial_days(self):
+    def test_as_trial_days(self) -> None:
         user_tokens: List[str] = []
 
-        def advance_time_and_sync():
+        def advance_time_and_sync() -> None:
             self.reactor.advance(24 * 60 * 61)
             for token in user_tokens:
                 self.do_sync_for_user(token)
@@ -300,7 +307,9 @@ class TestMauLimit(unittest.HomeserverTestCase):
             },
         )
 
-    def create_user(self, localpart, token=None, appservice=False):
+    def create_user(
+        self, localpart: str, token: Optional[str] = None, appservice: bool = False
+    ) -> str:
         request_data = {
             "username": localpart,
             "password": "monkey",
@@ -326,7 +335,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
 
         return access_token
 
-    def do_sync_for_user(self, token):
+    def do_sync_for_user(self, token: str) -> None:
         channel = self.make_request("GET", "/sync", access_token=token)
 
         if channel.code != 200:

+ 1 - 1
tests/test_rust.py

@@ -6,6 +6,6 @@ from tests import unittest
 class RustTestCase(unittest.TestCase):
     """Basic tests to ensure that we can call into Rust code."""
 
-    def test_basic(self):
+    def test_basic(self) -> None:
         result = sum_as_string(1, 2)
         self.assertEqual("3", result)

+ 8 - 8
tests/test_test_utils.py

@@ -17,25 +17,25 @@ from tests.utils import MockClock
 
 
 class MockClockTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.clock = MockClock()
 
-    def test_advance_time(self):
+    def test_advance_time(self) -> None:
         start_time = self.clock.time()
 
         self.clock.advance_time(20)
 
         self.assertEqual(20, self.clock.time() - start_time)
 
-    def test_later(self):
+    def test_later(self) -> None:
         invoked = [0, 0]
 
-        def _cb0():
+        def _cb0() -> None:
             invoked[0] = 1
 
         self.clock.call_later(10, _cb0)
 
-        def _cb1():
+        def _cb1() -> None:
             invoked[1] = 1
 
         self.clock.call_later(20, _cb1)
@@ -51,15 +51,15 @@ class MockClockTestCase(unittest.TestCase):
 
         self.assertTrue(invoked[1])
 
-    def test_cancel_later(self):
+    def test_cancel_later(self) -> None:
         invoked = [0, 0]
 
-        def _cb0():
+        def _cb0() -> None:
             invoked[0] = 1
 
         t0 = self.clock.call_later(10, _cb0)
 
-        def _cb1():
+        def _cb1() -> None:
             invoked[1] = 1
 
         self.clock.call_later(20, _cb1)

+ 15 - 15
tests/test_types.py

@@ -43,34 +43,34 @@ class IsMineIDTests(unittest.HomeserverTestCase):
 
 
 class UserIDTestCase(unittest.HomeserverTestCase):
-    def test_parse(self):
+    def test_parse(self) -> None:
         user = UserID.from_string("@1234abcd:test")
 
         self.assertEqual("1234abcd", user.localpart)
         self.assertEqual("test", user.domain)
         self.assertEqual(True, self.hs.is_mine(user))
 
-    def test_parse_rejects_empty_id(self):
+    def test_parse_rejects_empty_id(self) -> None:
         with self.assertRaises(SynapseError):
             UserID.from_string("")
 
-    def test_parse_rejects_missing_sigil(self):
+    def test_parse_rejects_missing_sigil(self) -> None:
         with self.assertRaises(SynapseError):
             UserID.from_string("alice:example.com")
 
-    def test_parse_rejects_missing_separator(self):
+    def test_parse_rejects_missing_separator(self) -> None:
         with self.assertRaises(SynapseError):
             UserID.from_string("@alice.example.com")
 
-    def test_validation_rejects_missing_domain(self):
+    def test_validation_rejects_missing_domain(self) -> None:
         self.assertFalse(UserID.is_valid("@alice:"))
 
-    def test_build(self):
+    def test_build(self) -> None:
         user = UserID("5678efgh", "my.domain")
 
         self.assertEqual(user.to_string(), "@5678efgh:my.domain")
 
-    def test_compare(self):
+    def test_compare(self) -> None:
         userA = UserID.from_string("@userA:my.domain")
         userAagain = UserID.from_string("@userA:my.domain")
         userB = UserID.from_string("@userB:my.domain")
@@ -80,43 +80,43 @@ class UserIDTestCase(unittest.HomeserverTestCase):
 
 
 class RoomAliasTestCase(unittest.HomeserverTestCase):
-    def test_parse(self):
+    def test_parse(self) -> None:
         room = RoomAlias.from_string("#channel:test")
 
         self.assertEqual("channel", room.localpart)
         self.assertEqual("test", room.domain)
         self.assertEqual(True, self.hs.is_mine(room))
 
-    def test_build(self):
+    def test_build(self) -> None:
         room = RoomAlias("channel", "my.domain")
 
         self.assertEqual(room.to_string(), "#channel:my.domain")
 
-    def test_validate(self):
+    def test_validate(self) -> None:
         id_string = "#test:domain,test"
         self.assertFalse(RoomAlias.is_valid(id_string))
 
 
 class MapUsernameTestCase(unittest.TestCase):
-    def testPassThrough(self):
+    def test_pass_througuh(self) -> None:
         self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
 
-    def testUpperCase(self):
+    def test_upper_case(self) -> None:
         self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
         self.assertEqual(
             map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
             "t_e_s_t__1234",
         )
 
-    def testSymbols(self):
+    def test_symbols(self) -> None:
         self.assertEqual(
             map_username_to_mxid_localpart("test=$?_1234"), "test=3d=24=3f_1234"
         )
 
-    def testLeadingUnderscore(self):
+    def test_leading_underscore(self) -> None:
         self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
 
-    def testNonAscii(self):
+    def test_non_ascii(self) -> None:
         # this should work with either a unicode or a bytes
         self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
         self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")