Browse Source

Add more missing type hints to tests. (#15028)

Patrick Cloke 1 year ago
parent
commit
30509a1010

+ 1 - 0
changelog.d/15028.misc

@@ -0,0 +1 @@
+Improve type hints.

+ 0 - 18
mypy.ini

@@ -60,24 +60,6 @@ disallow_untyped_defs = False
 [mypy-synapse.storage.database]
 disallow_untyped_defs = False
 
-[mypy-tests.scripts.test_new_matrix_user]
-disallow_untyped_defs = False
-
-[mypy-tests.server_notices.test_consent]
-disallow_untyped_defs = False
-
-[mypy-tests.server_notices.test_resource_limits_server_notices]
-disallow_untyped_defs = False
-
-[mypy-tests.test_federation]
-disallow_untyped_defs = False
-
-[mypy-tests.test_utils.*]
-disallow_untyped_defs = False
-
-[mypy-tests.test_visibility]
-disallow_untyped_defs = False
-
 [mypy-tests.unittest]
 disallow_untyped_defs = False
 

+ 2 - 2
tests/handlers/test_oidc.py

@@ -150,7 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         hs = self.setup_test_homeserver()
         self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
-        self.hs_patcher.start()
+        self.hs_patcher.start()  # type: ignore[attr-defined]
 
         self.handler = hs.get_oidc_handler()
         self.provider = self.handler._providers["oidc"]
@@ -170,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return hs
 
     def tearDown(self) -> None:
-        self.hs_patcher.stop()
+        self.hs_patcher.stop()  # type: ignore[attr-defined]
         return super().tearDown()
 
     def reset_mocks(self) -> None:

+ 16 - 9
tests/scripts/test_new_matrix_user.py

@@ -12,29 +12,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List
+from typing import List, Optional
 from unittest.mock import Mock, patch
 
 from synapse._scripts.register_new_matrix_user import request_registration
+from synapse.types import JsonDict
 
 from tests.unittest import TestCase
 
 
 class RegisterTestCase(TestCase):
-    def test_success(self):
+    def test_success(self) -> None:
         """
         The script will fetch a nonce, and then generate a MAC with it, and then
         post that MAC.
         """
 
-        def get(url, verify=None):
+        def get(url: str, verify: Optional[bool] = None) -> Mock:
             r = Mock()
             r.status_code = 200
             r.json = lambda: {"nonce": "a"}
             return r
 
-        def post(url, json=None, verify=None):
+        def post(
+            url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+        ) -> Mock:
             # Make sure we are sent the correct info
+            assert json is not None
             self.assertEqual(json["username"], "user")
             self.assertEqual(json["password"], "pass")
             self.assertEqual(json["nonce"], "a")
@@ -70,12 +74,12 @@ class RegisterTestCase(TestCase):
         # sys.exit shouldn't have been called.
         self.assertEqual(err_code, [])
 
-    def test_failure_nonce(self):
+    def test_failure_nonce(self) -> None:
         """
         If the script fails to fetch a nonce, it throws an error and quits.
         """
 
-        def get(url, verify=None):
+        def get(url: str, verify: Optional[bool] = None) -> Mock:
             r = Mock()
             r.status_code = 404
             r.reason = "Not Found"
@@ -107,20 +111,23 @@ class RegisterTestCase(TestCase):
         self.assertIn("ERROR! Received 404 Not Found", out)
         self.assertNotIn("Success!", out)
 
-    def test_failure_post(self):
+    def test_failure_post(self) -> None:
         """
         The script will fetch a nonce, and then if the final POST fails, will
         report an error and quit.
         """
 
-        def get(url, verify=None):
+        def get(url: str, verify: Optional[bool] = None) -> Mock:
             r = Mock()
             r.status_code = 200
             r.json = lambda: {"nonce": "a"}
             return r
 
-        def post(url, json=None, verify=None):
+        def post(
+            url: str, json: Optional[JsonDict] = None, verify: Optional[bool] = None
+        ) -> Mock:
             # Make sure we are sent the correct info
+            assert json is not None
             self.assertEqual(json["username"], "user")
             self.assertEqual(json["password"], "pass")
             self.assertEqual(json["nonce"], "a")

+ 8 - 6
tests/server_notices/test_consent.py

@@ -14,8 +14,12 @@
 
 import os
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.rest.client import login, room, sync
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -29,7 +33,7 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         tmpdir = self.mktemp()
         os.mkdir(tmpdir)
@@ -53,15 +57,13 @@ class ConsentNoticesTests(unittest.HomeserverTestCase):
             "room_name": "Server Notices",
         }
 
-        hs = self.setup_test_homeserver(config=config)
-
-        return hs
+        return self.setup_test_homeserver(config=config)
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.user_id = self.register_user("bob", "abc123")
         self.access_token = self.login("bob", "abc123")
 
-    def test_get_sync_message(self):
+    def test_get_sync_message(self) -> None:
         """
         When user consent server notices are enabled, a sync will cause a notice
         to fire (in a room which the user is invited to). The notice contains

+ 20 - 15
tests/server_notices/test_resource_limits_server_notices.py

@@ -24,6 +24,7 @@ from synapse.server import HomeServer
 from synapse.server_notices.resource_limits_server_notices import (
     ResourceLimitsServerNotices,
 )
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
@@ -33,7 +34,7 @@ from tests.utils import default_config
 
 
 class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         config = default_config("test")
 
         config.update(
@@ -86,18 +87,18 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))  # type: ignore[assignment]
 
     @override_config({"hs_disabled": True})
-    def test_maybe_send_server_notice_disabled_hs(self):
+    def test_maybe_send_server_notice_disabled_hs(self) -> None:
         """If the HS is disabled, we should not send notices"""
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         self._send_notice.assert_not_called()
 
     @override_config({"limit_usage_by_mau": False})
-    def test_maybe_send_server_notice_to_user_flag_off(self):
+    def test_maybe_send_server_notice_to_user_flag_off(self) -> None:
         """If mau limiting is disabled, we should not send notices"""
         self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
         self._send_notice.assert_not_called()
 
-    def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
+    def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None:
         """Test when user has blocked notice, but should have it removed"""
 
         self._rlsn._auth_blocking.check_auth_blocking = Mock(
@@ -114,7 +115,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self._rlsn._server_notices_manager.maybe_get_notice_room_for_user.assert_called_once()
         self._send_notice.assert_called_once()
 
-    def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
+    def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self) -> None:
         """
         Test when user has blocked notice, but notice ought to be there (NOOP)
         """
@@ -134,7 +135,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
 
         self._send_notice.assert_not_called()
 
-    def test_maybe_send_server_notice_to_user_add_blocked_notice(self):
+    def test_maybe_send_server_notice_to_user_add_blocked_notice(self) -> None:
         """
         Test when user does not have blocked notice, but should have one
         """
@@ -147,7 +148,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         # Would be better to check contents, but 2 calls == set blocking event
         self.assertEqual(self._send_notice.call_count, 2)
 
-    def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
+    def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self) -> None:
         """
         Test when user does not have blocked notice, nor should they (NOOP)
         """
@@ -159,7 +160,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
 
         self._send_notice.assert_not_called()
 
-    def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self):
+    def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self) -> None:
         """
         Test when user is not part of the MAU cohort - this should not ever
         happen - but ...
@@ -175,7 +176,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self._send_notice.assert_not_called()
 
     @override_config({"mau_limit_alerting": False})
-    def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self):
+    def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(
+        self,
+    ) -> None:
         """
         Test that when server is over MAU limit and alerting is suppressed, then
         an alert message is not sent into the room
@@ -191,7 +194,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self.assertEqual(self._send_notice.call_count, 0)
 
     @override_config({"mau_limit_alerting": False})
-    def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
+    def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self) -> None:
         """
         Test that when a server is disabled, that MAU limit alerting is ignored.
         """
@@ -207,7 +210,9 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
         self.assertEqual(self._send_notice.call_count, 2)
 
     @override_config({"mau_limit_alerting": False})
-    def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self):
+    def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(
+        self,
+    ) -> None:
         """
         When the room is already in a blocked state, test that when alerting
         is suppressed that the room is returned to an unblocked state.
@@ -242,7 +247,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
         sync.register_servlets,
     ]
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         c = super().default_config()
         c["server_notices"] = {
             "system_mxid_localpart": "server",
@@ -270,7 +275,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
 
         self.user_id = "@user_id:test"
 
-    def test_server_notice_only_sent_once(self):
+    def test_server_notice_only_sent_once(self) -> None:
         self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000))
 
         self.store.user_last_seen_monthly_active = Mock(
@@ -306,7 +311,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
 
         self.assertEqual(count, 1)
 
-    def test_no_invite_without_notice(self):
+    def test_no_invite_without_notice(self) -> None:
         """Tests that a user doesn't get invited to a server notices room without a
         server notice being sent.
 
@@ -328,7 +333,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
 
         m.assert_called_once_with(user_id)
 
-    def test_invite_with_notice(self):
+    def test_invite_with_notice(self) -> None:
         """Tests that, if the MAU limit is hit, the server notices user invites each user
         to a room in which it has sent a notice.
         """

+ 44 - 36
tests/test_federation.py

@@ -12,53 +12,48 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional, Union
 from unittest.mock import Mock
 
 from twisted.internet.defer import succeed
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import FederationError
 from synapse.api.room_versions import RoomVersions
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.federation.federation_base import event_from_pdu_json
+from synapse.http.types import QueryParams
 from synapse.logging.context import LoggingContext
-from synapse.types import UserID, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.retryutils import NotRetryingDestination
 
 from tests import unittest
-from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver
 from tests.test_utils import make_awaitable
 
 
 class MessageAcceptTests(unittest.HomeserverTestCase):
-    def setUp(self):
-
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.http_client = Mock()
-        self.reactor = ThreadedMemoryReactorClock()
-        self.hs_clock = Clock(self.reactor)
-        self.homeserver = setup_test_homeserver(
-            self.addCleanup,
-            federation_http_client=self.http_client,
-            clock=self.hs_clock,
-            reactor=self.reactor,
-        )
+        return self.setup_test_homeserver(federation_http_client=self.http_client)
 
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         user_id = UserID("us", "test")
         our_user = create_requester(user_id)
-        room_creator = self.homeserver.get_room_creation_handler()
+        room_creator = self.hs.get_room_creation_handler()
         self.room_id = self.get_success(
             room_creator.create_room(
                 our_user, room_creator._presets_dict["public_chat"], ratelimit=False
             )
         )[0]["room_id"]
 
-        self.store = self.homeserver.get_datastores().main
+        self.store = self.hs.get_datastores().main
 
         # Figure out what the most recent event is
         most_recent = self.get_success(
-            self.homeserver.get_datastores().main.get_latest_event_ids_in_room(
-                self.room_id
-            )
+            self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )[0]
 
         join_event = make_event_from_dict(
@@ -78,14 +73,16 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             }
         )
 
-        self.handler = self.homeserver.get_federation_handler()
-        federation_event_handler = self.homeserver.get_federation_event_handler()
+        self.handler = self.hs.get_federation_handler()
+        federation_event_handler = self.hs.get_federation_event_handler()
 
-        async def _check_event_auth(origin, event, context):
+        async def _check_event_auth(
+            origin: Optional[str], event: EventBase, context: EventContext
+        ) -> None:
             pass
 
         federation_event_handler._check_event_auth = _check_event_auth
-        self.client = self.homeserver.get_federation_client()
+        self.client = self.hs.get_federation_client()
         self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
             lambda dest, pdus, **k: succeed(pdus)
         )
@@ -104,16 +101,25 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             "$join:test.serv",
         )
 
-    def test_cant_hide_direct_ancestors(self):
+    def test_cant_hide_direct_ancestors(self) -> None:
         """
         If you send a message, you must be able to provide the direct
         prev_events that said event references.
         """
 
-        async def post_json(destination, path, data, headers=None, timeout=0):
+        async def post_json(
+            destination: str,
+            path: str,
+            data: Optional[JsonDict] = None,
+            long_retries: bool = False,
+            timeout: Optional[int] = None,
+            ignore_backoff: bool = False,
+            args: Optional[QueryParams] = None,
+        ) -> Union[JsonDict, list]:
             # If it asks us for new missing events, give them NOTHING
             if path.startswith("/_matrix/federation/v1/get_missing_events/"):
                 return {"events": []}
+            return {}
 
         self.http_client.post_json = post_json
 
@@ -138,7 +144,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
             }
         )
 
-        federation_event_handler = self.homeserver.get_federation_event_handler()
+        federation_event_handler = self.hs.get_federation_event_handler()
         with LoggingContext("test-context"):
             failure = self.get_failure(
                 federation_event_handler.on_receive_pdu("test.serv", lying_event),
@@ -158,7 +164,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
         self.assertEqual(extrem[0], "$join:test.serv")
 
-    def test_retry_device_list_resync(self):
+    def test_retry_device_list_resync(self) -> None:
         """Tests that device lists are marked as stale if they couldn't be synced, and
         that stale device lists are retried periodically.
         """
@@ -171,24 +177,26 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         # When this function is called, increment the number of resync attempts (only if
         # we're querying devices for the right user ID), then raise a
         # NotRetryingDestination error to fail the resync gracefully.
-        def query_user_devices(destination, user_id):
+        def query_user_devices(
+            destination: str, user_id: str, timeout: int = 30000
+        ) -> JsonDict:
             if user_id == remote_user_id:
                 self.resync_attempts += 1
 
             raise NotRetryingDestination(0, 0, destination)
 
         # Register the mock on the federation client.
-        federation_client = self.homeserver.get_federation_client()
+        federation_client = self.hs.get_federation_client()
         federation_client.query_user_devices = Mock(side_effect=query_user_devices)
 
         # Register a mock on the store so that the incoming update doesn't fail because
         # we don't share a room with the user.
-        store = self.homeserver.get_datastores().main
+        store = self.hs.get_datastores().main
         store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"]))
 
         # Manually inject a fake device list update. We need this update to include at
         # least one prev_id so that the user's device list will need to be retried.
-        device_list_updater = self.homeserver.get_device_handler().device_list_updater
+        device_list_updater = self.hs.get_device_handler().device_list_updater
         self.get_success(
             device_list_updater.incoming_device_list_update(
                 origin=remote_origin,
@@ -218,7 +226,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         self.reactor.advance(30)
         self.assertEqual(self.resync_attempts, 2)
 
-    def test_cross_signing_keys_retry(self):
+    def test_cross_signing_keys_retry(self) -> None:
         """Tests that resyncing a device list correctly processes cross-signing keys from
         the remote server.
         """
@@ -227,7 +235,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
 
         # Register mock device list retrieval on the federation client.
-        federation_client = self.homeserver.get_federation_client()
+        federation_client = self.hs.get_federation_client()
         federation_client.query_user_devices = Mock(
             return_value=make_awaitable(
                 {
@@ -252,7 +260,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
         )
 
         # Resync the device list.
-        device_handler = self.homeserver.get_device_handler()
+        device_handler = self.hs.get_device_handler()
         self.get_success(
             device_handler.device_list_updater.user_device_resync(remote_user_id),
         )
@@ -279,7 +287,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
 
 class StripUnsignedFromEventsTestCase(unittest.TestCase):
-    def test_strip_unauthorized_unsigned_values(self):
+    def test_strip_unauthorized_unsigned_values(self) -> None:
         event1 = {
             "sender": "@baduser:test.serv",
             "state_key": "@baduser:test.serv",
@@ -296,7 +304,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
         # Make sure unauthorized fields are stripped from unsigned
         self.assertNotIn("more warez", filtered_event.unsigned)
 
-    def test_strip_event_maintains_allowed_fields(self):
+    def test_strip_event_maintains_allowed_fields(self) -> None:
         event2 = {
             "sender": "@baduser:test.serv",
             "state_key": "@baduser:test.serv",
@@ -323,7 +331,7 @@ class StripUnsignedFromEventsTestCase(unittest.TestCase):
         self.assertIn("invite_room_state", filtered_event2.unsigned)
         self.assertEqual([], filtered_event2.unsigned["invite_room_state"])
 
-    def test_strip_event_removes_fields_based_on_event_type(self):
+    def test_strip_event_removes_fields_based_on_event_type(self) -> None:
         event3 = {
             "sender": "@baduser:test.serv",
             "state_key": "@baduser:test.serv",

+ 17 - 9
tests/test_utils/__init__.py

@@ -20,12 +20,13 @@ import sys
 import warnings
 from asyncio import Future
 from binascii import unhexlify
-from typing import Awaitable, Callable, Tuple, TypeVar
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
 from unittest.mock import Mock
 
 import attr
 import zope.interface
 
+from twisted.internet.interfaces import IProtocol
 from twisted.python.failure import Failure
 from twisted.web.client import ResponseDone
 from twisted.web.http import RESPONSES
@@ -34,6 +35,9 @@ from twisted.web.iweb import IResponse
 
 from synapse.types import JsonDict
 
+if TYPE_CHECKING:
+    from sys import UnraisableHookArgs
+
 TV = TypeVar("TV")
 
 
@@ -78,25 +82,29 @@ def setup_awaitable_errors() -> Callable[[], None]:
     unraisable_exceptions = []
     orig_unraisablehook = sys.unraisablehook
 
-    def unraisablehook(unraisable):
+    def unraisablehook(unraisable: "UnraisableHookArgs") -> None:
         unraisable_exceptions.append(unraisable.exc_value)
 
-    def cleanup():
+    def cleanup() -> None:
         """
         A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
         """
         sys.unraisablehook = orig_unraisablehook
         if unraisable_exceptions:
-            raise unraisable_exceptions.pop()
+            exc = unraisable_exceptions.pop()
+            assert exc is not None
+            raise exc
 
     sys.unraisablehook = unraisablehook
 
     return cleanup
 
 
-def simple_async_mock(return_value=None, raises=None) -> Mock:
+def simple_async_mock(
+    return_value: Optional[TV] = None, raises: Optional[Exception] = None
+) -> Mock:
     # AsyncMock is not available in python3.5, this mimics part of its behaviour
-    async def cb(*args, **kwargs):
+    async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
         if raises:
             raise raises
         return return_value
@@ -125,14 +133,14 @@ class FakeResponse:  # type: ignore[misc]
     headers: Headers = attr.Factory(Headers)
 
     @property
-    def phrase(self):
+    def phrase(self) -> bytes:
         return RESPONSES.get(self.code, b"Unknown Status")
 
     @property
-    def length(self):
+    def length(self) -> int:
         return len(self.body)
 
-    def deliverBody(self, protocol):
+    def deliverBody(self, protocol: IProtocol) -> None:
         protocol.dataReceived(self.body)
         protocol.connectionLost(Failure(ResponseDone()))
 

+ 4 - 4
tests/test_utils/event_injection.py

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import synapse.server
 from synapse.api.constants import EventTypes
@@ -32,7 +32,7 @@ async def inject_member_event(
     membership: str,
     target: Optional[str] = None,
     extra_content: Optional[dict] = None,
-    **kwargs,
+    **kwargs: Any,
 ) -> EventBase:
     """Inject a membership event into a room."""
     if target is None:
@@ -57,7 +57,7 @@ async def inject_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
     prev_event_ids: Optional[List[str]] = None,
-    **kwargs,
+    **kwargs: Any,
 ) -> EventBase:
     """Inject a generic event into a room
 
@@ -82,7 +82,7 @@ async def create_event(
     hs: synapse.server.HomeServer,
     room_version: Optional[str] = None,
     prev_event_ids: Optional[List[str]] = None,
-    **kwargs,
+    **kwargs: Any,
 ) -> Tuple[EventBase, EventContext]:
     if room_version is None:
         room_version = await hs.get_datastores().main.get_room_version_id(

+ 3 - 3
tests/test_utils/html_parsers.py

@@ -13,13 +13,13 @@
 # limitations under the License.
 
 from html.parser import HTMLParser
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, NoReturn, Optional, Tuple
 
 
 class TestHtmlParser(HTMLParser):
     """A generic HTML page parser which extracts useful things from the HTML"""
 
-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()
 
         # a list of links found in the doc
@@ -48,5 +48,5 @@ class TestHtmlParser(HTMLParser):
                 assert input_name
                 self.hiddens[input_name] = attr_dict["value"]
 
-    def error(_, message):
+    def error(self, message: str) -> NoReturn:
         raise AssertionError(message)

+ 2 - 2
tests/test_utils/logging_setup.py

@@ -25,7 +25,7 @@ class ToTwistedHandler(logging.Handler):
 
     tx_log = twisted.logger.Logger()
 
-    def emit(self, record):
+    def emit(self, record: logging.LogRecord) -> None:
         log_entry = self.format(record)
         log_level = record.levelname.lower().replace("warning", "warn")
         self.tx_log.emit(
@@ -33,7 +33,7 @@ class ToTwistedHandler(logging.Handler):
         )
 
 
-def setup_logging():
+def setup_logging() -> None:
     """Configure the python logging appropriately for the tests.
 
     (Logs will end up in _trial_temp.)

+ 5 - 5
tests/test_utils/oidc.py

@@ -14,7 +14,7 @@
 
 
 import json
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, ContextManager, Dict, List, Optional, Tuple
 from unittest.mock import Mock, patch
 from urllib.parse import parse_qs
 
@@ -77,14 +77,14 @@ class FakeOidcServer:
 
         self._id_token_overrides: Dict[str, Any] = {}
 
-    def reset_mocks(self):
+    def reset_mocks(self) -> None:
         self.request.reset_mock()
         self.get_jwks_handler.reset_mock()
         self.get_metadata_handler.reset_mock()
         self.get_userinfo_handler.reset_mock()
         self.post_token_handler.reset_mock()
 
-    def patch_homeserver(self, hs: HomeServer):
+    def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]:
         """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
 
         This patch should be used whenever the HS is expected to perform request to the
@@ -188,7 +188,7 @@ class FakeOidcServer:
 
         return self._sign(logout_token)
 
-    def id_token_override(self, overrides: dict):
+    def id_token_override(self, overrides: dict) -> ContextManager[dict]:
         """Temporarily patch the ID token generated by the token endpoint."""
         return patch.object(self, "_id_token_overrides", overrides)
 
@@ -247,7 +247,7 @@ class FakeOidcServer:
         metadata: bool = False,
         token: bool = False,
         userinfo: bool = False,
-    ):
+    ) -> ContextManager[Dict[str, Mock]]:
         """A context which makes a set of endpoints return a 500 error.
 
         Args:

+ 1 - 1
tests/test_visibility.py

@@ -258,7 +258,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
 
 class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
-    def test_out_of_band_invite_rejection(self):
+    def test_out_of_band_invite_rejection(self) -> None:
         # this is where we have received an invite event over federation, and then
         # rejected it.
         invite_pdu = {

+ 1 - 1
tests/unittest.py

@@ -315,7 +315,7 @@ class HomeserverTestCase(TestCase):
 
                 # This has to be a function and not just a Mock, because
                 # `self.helper.auth_user_id` is temporarily reassigned in some tests
-                async def get_requester(*args, **kwargs) -> Requester:
+                async def get_requester(*args: Any, **kwargs: Any) -> Requester:
                     assert self.helper.auth_user_id is not None
                     return create_requester(
                         user_id=UserID.from_string(self.helper.auth_user_id),