Răsfoiți Sursa

Add type hints to some tests/handlers files. (#12224)

Dirk Klimpel 2 ani în urmă
părinte
comite
5dd949bee6

+ 1 - 0
changelog.d/12224.misc

@@ -0,0 +1 @@
+Add type hints to tests files.

+ 0 - 5
mypy.ini

@@ -67,13 +67,8 @@ exclude = (?x)
    |tests/federation/transport/test_knocking.py
    |tests/federation/transport/test_knocking.py
    |tests/federation/transport/test_server.py
    |tests/federation/transport/test_server.py
    |tests/handlers/test_cas.py
    |tests/handlers/test_cas.py
-   |tests/handlers/test_directory.py
-   |tests/handlers/test_e2e_keys.py
    |tests/handlers/test_federation.py
    |tests/handlers/test_federation.py
-   |tests/handlers/test_oidc.py
    |tests/handlers/test_presence.py
    |tests/handlers/test_presence.py
-   |tests/handlers/test_profile.py
-   |tests/handlers/test_saml.py
    |tests/handlers/test_typing.py
    |tests/handlers/test_typing.py
    |tests/http/federation/test_matrix_federation_agent.py
    |tests/http/federation/test_matrix_federation_agent.py
    |tests/http/federation/test_srv_resolver.py
    |tests/http/federation/test_srv_resolver.py

+ 47 - 37
tests/handlers/test_directory.py

@@ -12,14 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
-
+from typing import Any, Awaitable, Callable, Dict
 from unittest.mock import Mock
 from unittest.mock import Mock
 
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.api.errors
 import synapse.api.errors
 import synapse.rest.admin
 import synapse.rest.admin
 from synapse.api.constants import EventTypes
 from synapse.api.constants import EventTypes
 from synapse.rest.client import directory, login, room
 from synapse.rest.client import directory, login, room
-from synapse.types import RoomAlias, create_requester
+from synapse.server import HomeServer
+from synapse.types import JsonDict, RoomAlias, create_requester
+from synapse.util import Clock
 
 
 from tests import unittest
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.test_utils import make_awaitable
@@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable
 class DirectoryTestCase(unittest.HomeserverTestCase):
 class DirectoryTestCase(unittest.HomeserverTestCase):
     """Tests the directory service."""
     """Tests the directory service."""
 
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.mock_federation = Mock()
         self.mock_federation = Mock()
         self.mock_registry = Mock()
         self.mock_registry = Mock()
 
 
-        self.query_handlers = {}
+        self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
 
 
-        def register_query_handler(query_type, handler):
+        def register_query_handler(
+            query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+        ) -> None:
             self.query_handlers[query_type] = handler
             self.query_handlers[query_type] = handler
 
 
         self.mock_registry.register_query_handler = register_query_handler
         self.mock_registry.register_query_handler = register_query_handler
@@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
 
 
         return hs
         return hs
 
 
-    def test_get_local_association(self):
+    def test_get_local_association(self) -> None:
         self.get_success(
         self.get_success(
             self.store.create_room_alias_association(
             self.store.create_room_alias_association(
                 self.my_room, "!8765qwer:test", ["test"]
                 self.my_room, "!8765qwer:test", ["test"]
@@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
 
 
         self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
         self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
 
 
-    def test_get_remote_association(self):
+    def test_get_remote_association(self) -> None:
         self.mock_federation.make_query.return_value = make_awaitable(
         self.mock_federation.make_query.return_value = make_awaitable(
             {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
             {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
         )
         )
@@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
             ignore_backoff=True,
             ignore_backoff=True,
         )
         )
 
 
-    def test_incoming_fed_query(self):
+    def test_incoming_fed_query(self) -> None:
         self.get_success(
         self.get_success(
             self.store.create_room_alias_association(
             self.store.create_room_alias_association(
                 self.your_room, "!8765asdf:test", ["test"]
                 self.your_room, "!8765asdf:test", ["test"]
@@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
         directory.register_servlets,
         directory.register_servlets,
     ]
     ]
 
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.handler = hs.get_directory_handler()
         self.handler = hs.get_directory_handler()
 
 
         # Create user
         # Create user
@@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
         self.test_user_tok = self.login("user", "pass")
         self.test_user_tok = self.login("user", "pass")
         self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
         self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
 
 
-    def test_create_alias_joined_room(self):
+    def test_create_alias_joined_room(self) -> None:
         """A user can create an alias for a room they're in."""
         """A user can create an alias for a room they're in."""
         self.get_success(
         self.get_success(
             self.handler.create_association(
             self.handler.create_association(
@@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
             )
             )
         )
         )
 
 
-    def test_create_alias_other_room(self):
+    def test_create_alias_other_room(self) -> None:
         """A user cannot create an alias for a room they're NOT in."""
         """A user cannot create an alias for a room they're NOT in."""
         other_room_id = self.helper.create_room_as(
         other_room_id = self.helper.create_room_as(
             self.admin_user, tok=self.admin_user_tok
             self.admin_user, tok=self.admin_user_tok
@@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
             synapse.api.errors.SynapseError,
             synapse.api.errors.SynapseError,
         )
         )
 
 
-    def test_create_alias_admin(self):
+    def test_create_alias_admin(self) -> None:
         """An admin can create an alias for a room they're NOT in."""
         """An admin can create an alias for a room they're NOT in."""
         other_room_id = self.helper.create_room_as(
         other_room_id = self.helper.create_room_as(
             self.test_user, tok=self.test_user_tok
             self.test_user, tok=self.test_user_tok
@@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
         directory.register_servlets,
         directory.register_servlets,
     ]
     ]
 
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.store = hs.get_datastores().main
         self.handler = hs.get_directory_handler()
         self.handler = hs.get_directory_handler()
         self.state_handler = hs.get_state_handler()
         self.state_handler = hs.get_state_handler()
@@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
         self.test_user_tok = self.login("user", "pass")
         self.test_user_tok = self.login("user", "pass")
         self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
         self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok)
 
 
-    def _create_alias(self, user):
+    def _create_alias(self, user) -> None:
         # Create a new alias to this room.
         # Create a new alias to this room.
         self.get_success(
         self.get_success(
             self.store.create_room_alias_association(
             self.store.create_room_alias_association(
@@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
             )
             )
         )
         )
 
 
-    def test_delete_alias_not_allowed(self):
+    def test_delete_alias_not_allowed(self) -> None:
         """A user that doesn't meet the expected guidelines cannot delete an alias."""
         """A user that doesn't meet the expected guidelines cannot delete an alias."""
         self._create_alias(self.admin_user)
         self._create_alias(self.admin_user)
         self.get_failure(
         self.get_failure(
@@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
             synapse.api.errors.AuthError,
             synapse.api.errors.AuthError,
         )
         )
 
 
-    def test_delete_alias_creator(self):
+    def test_delete_alias_creator(self) -> None:
         """An alias creator can delete their own alias."""
         """An alias creator can delete their own alias."""
         # Create an alias from a different user.
         # Create an alias from a different user.
         self._create_alias(self.test_user)
         self._create_alias(self.test_user)
@@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
             synapse.api.errors.SynapseError,
             synapse.api.errors.SynapseError,
         )
         )
 
 
-    def test_delete_alias_admin(self):
+    def test_delete_alias_admin(self) -> None:
         """A server admin can delete an alias created by another user."""
         """A server admin can delete an alias created by another user."""
         # Create an alias from a different user.
         # Create an alias from a different user.
         self._create_alias(self.test_user)
         self._create_alias(self.test_user)
@@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
             synapse.api.errors.SynapseError,
             synapse.api.errors.SynapseError,
         )
         )
 
 
-    def test_delete_alias_sufficient_power(self):
+    def test_delete_alias_sufficient_power(self) -> None:
         """A user with a sufficient power level should be able to delete an alias."""
         """A user with a sufficient power level should be able to delete an alias."""
         self._create_alias(self.admin_user)
         self._create_alias(self.admin_user)
 
 
@@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         directory.register_servlets,
         directory.register_servlets,
     ]
     ]
 
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.store = hs.get_datastores().main
         self.handler = hs.get_directory_handler()
         self.handler = hs.get_directory_handler()
         self.state_handler = hs.get_state_handler()
         self.state_handler = hs.get_state_handler()
@@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         )
         )
         return room_alias
         return room_alias
 
 
-    def _set_canonical_alias(self, content):
+    def _set_canonical_alias(self, content) -> None:
         """Configure the canonical alias state on the room."""
         """Configure the canonical alias state on the room."""
         self.helper.send_state(
         self.helper.send_state(
             self.room_id,
             self.room_id,
@@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
             )
             )
         )
         )
 
 
-    def test_remove_alias(self):
+    def test_remove_alias(self) -> None:
         """Removing an alias that is the canonical alias should remove it there too."""
         """Removing an alias that is the canonical alias should remove it there too."""
         # Set this new alias as the canonical alias for this room
         # Set this new alias as the canonical alias for this room
         self._set_canonical_alias(
         self._set_canonical_alias(
@@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
         self.assertNotIn("alias", data["content"])
         self.assertNotIn("alias", data["content"])
         self.assertNotIn("alt_aliases", data["content"])
         self.assertNotIn("alt_aliases", data["content"])
 
 
-    def test_remove_other_alias(self):
+    def test_remove_other_alias(self) -> None:
         """Removing an alias listed as in alt_aliases should remove it there too."""
         """Removing an alias listed as in alt_aliases should remove it there too."""
         # Create a second alias.
         # Create a second alias.
         other_test_alias = "#test2:test"
         other_test_alias = "#test2:test"
@@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
 
 
     servlets = [directory.register_servlets, room.register_servlets]
     servlets = [directory.register_servlets, room.register_servlets]
 
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config = super().default_config()
 
 
         # Add custom alias creation rules to the config.
         # Add custom alias creation rules to the config.
@@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
 
 
         return config
         return config
 
 
-    def test_denied(self):
+    def test_denied(self) -> None:
         room_id = self.helper.create_room_as(self.user_id)
         room_id = self.helper.create_room_as(self.user_id)
 
 
         channel = self.make_request(
         channel = self.make_request(
@@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
         )
         )
         self.assertEqual(403, channel.code, channel.result)
         self.assertEqual(403, channel.code, channel.result)
 
 
-    def test_allowed(self):
+    def test_allowed(self) -> None:
         room_id = self.helper.create_room_as(self.user_id)
         room_id = self.helper.create_room_as(self.user_id)
 
 
         channel = self.make_request(
         channel = self.make_request(
@@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
         )
         )
         self.assertEqual(200, channel.code, channel.result)
         self.assertEqual(200, channel.code, channel.result)
 
 
-    def test_denied_during_creation(self):
+    def test_denied_during_creation(self) -> None:
         """A room alias that is not allowed should be rejected during creation."""
         """A room alias that is not allowed should be rejected during creation."""
         # Invalid room alias.
         # Invalid room alias.
         self.helper.create_room_as(
         self.helper.create_room_as(
@@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
             extra_content={"room_alias_name": "foo"},
             extra_content={"room_alias_name": "foo"},
         )
         )
 
 
-    def test_allowed_during_creation(self):
+    def test_allowed_during_creation(self) -> None:
         """A valid room alias should be allowed during creation."""
         """A valid room alias should be allowed during creation."""
         room_id = self.helper.create_room_as(
         room_id = self.helper.create_room_as(
             self.user_id,
             self.user_id,
@@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
     data = {"room_alias_name": "unofficial_test"}
     data = {"room_alias_name": "unofficial_test"}
     allowed_localpart = "allowed"
     allowed_localpart = "allowed"
 
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config = super().default_config()
 
 
         # Add custom room list publication rules to the config.
         # Add custom room list publication rules to the config.
@@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
 
 
         return config
         return config
 
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+    ) -> HomeServer:
         self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
         self.allowed_user_id = self.register_user(self.allowed_localpart, "pass")
         self.allowed_access_token = self.login(self.allowed_localpart, "pass")
         self.allowed_access_token = self.login(self.allowed_localpart, "pass")
 
 
@@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
 
 
         return hs
         return hs
 
 
-    def test_denied_without_publication_permission(self):
+    def test_denied_without_publication_permission(self) -> None:
         """
         """
         Try to create a room, register an alias for it, and publish it,
         Try to create a room, register an alias for it, and publish it,
         as a user without permission to publish rooms.
         as a user without permission to publish rooms.
@@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
             expect_code=403,
             expect_code=403,
         )
         )
 
 
-    def test_allowed_when_creating_private_room(self):
+    def test_allowed_when_creating_private_room(self) -> None:
         """
         """
         Try to create a room, register an alias for it, and NOT publish it,
         Try to create a room, register an alias for it, and NOT publish it,
         as a user without permission to publish rooms.
         as a user without permission to publish rooms.
@@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
             expect_code=200,
             expect_code=200,
         )
         )
 
 
-    def test_allowed_with_publication_permission(self):
+    def test_allowed_with_publication_permission(self) -> None:
         """
         """
         Try to create a room, register an alias for it, and publish it,
         Try to create a room, register an alias for it, and publish it,
         as a user WITH permission to publish rooms.
         as a user WITH permission to publish rooms.
@@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
             expect_code=200,
             expect_code=200,
         )
         )
 
 
-    def test_denied_publication_with_invalid_alias(self):
+    def test_denied_publication_with_invalid_alias(self) -> None:
         """
         """
         Try to create a room, register an alias for it, and publish it,
         Try to create a room, register an alias for it, and publish it,
         as a user WITH permission to publish rooms.
         as a user WITH permission to publish rooms.
@@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
             expect_code=403,
             expect_code=403,
         )
         )
 
 
-    def test_can_create_as_private_room_after_rejection(self):
+    def test_can_create_as_private_room_after_rejection(self) -> None:
         """
         """
         After failing to publish a room with an alias as a user without publish permission,
         After failing to publish a room with an alias as a user without publish permission,
         retry as the same user, but without publishing the room.
         retry as the same user, but without publishing the room.
@@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
         self.test_denied_without_publication_permission()
         self.test_denied_without_publication_permission()
         self.test_allowed_when_creating_private_room()
         self.test_allowed_when_creating_private_room()
 
 
-    def test_can_create_with_permission_after_rejection(self):
+    def test_can_create_with_permission_after_rejection(self) -> None:
         """
         """
         After failing to publish a room with an alias as a user without publish permission,
         After failing to publish a room with an alias as a user without publish permission,
         retry as someone with permission, using the same alias.
         retry as someone with permission, using the same alias.
@@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
 
 
     servlets = [directory.register_servlets, room.register_servlets]
     servlets = [directory.register_servlets, room.register_servlets]
 
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, hs: HomeServer
+    ) -> HomeServer:
         room_id = self.helper.create_room_as(self.user_id)
         room_id = self.helper.create_room_as(self.user_id)
 
 
         channel = self.make_request(
         channel = self.make_request(
@@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
 
 
         return hs
         return hs
 
 
-    def test_disabling_room_list(self):
+    def test_disabling_room_list(self) -> None:
         self.room_list_handler.enable_room_list_search = True
         self.room_list_handler.enable_room_list_search = True
         self.directory_handler.enable_room_list_search = True
         self.directory_handler.enable_room_list_search = True
 
 

+ 20 - 16
tests/handlers/test_e2e_keys.py

@@ -20,33 +20,37 @@ from parameterized import parameterized
 from signedjson import key as key, sign as sign
 from signedjson import key as key, sign as sign
 
 
 from twisted.internet import defer
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.errors import Codes, SynapseError
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 
 from tests import unittest
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.test_utils import make_awaitable
 
 
 
 
 class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
 class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         return self.setup_test_homeserver(federation_client=mock.Mock())
         return self.setup_test_homeserver(federation_client=mock.Mock())
 
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.handler = hs.get_e2e_keys_handler()
         self.handler = hs.get_e2e_keys_handler()
         self.store = self.hs.get_datastores().main
         self.store = self.hs.get_datastores().main
 
 
-    def test_query_local_devices_no_devices(self):
+    def test_query_local_devices_no_devices(self) -> None:
         """If the user has no devices, we expect an empty list."""
         """If the user has no devices, we expect an empty list."""
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
         res = self.get_success(self.handler.query_local_devices({local_user: None}))
         res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
         self.assertDictEqual(res, {local_user: {}})
 
 
-    def test_reupload_one_time_keys(self):
+    def test_reupload_one_time_keys(self) -> None:
         """we should be able to re-upload the same keys"""
         """we should be able to re-upload the same keys"""
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         device_id = "xyz"
-        keys = {
+        keys: JsonDict = {
             "alg1:k1": "key1",
             "alg1:k1": "key1",
             "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
             "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
             "alg2:k3": {"key": "key3"},
             "alg2:k3": {"key": "key3"},
@@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
             res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}}
         )
         )
 
 
-    def test_change_one_time_keys(self):
+    def test_change_one_time_keys(self) -> None:
         """attempts to change one-time-keys should be rejected"""
         """attempts to change one-time-keys should be rejected"""
 
 
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
@@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             SynapseError,
             SynapseError,
         )
         )
 
 
-    def test_claim_one_time_key(self):
+    def test_claim_one_time_key(self) -> None:
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         device_id = "xyz"
         keys = {"alg1:k1": "key1"}
         keys = {"alg1:k1": "key1"}
@@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             },
             },
         )
         )
 
 
-    def test_fallback_key(self):
+    def test_fallback_key(self) -> None:
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         device_id = "xyz"
         fallback_key = {"alg1:k1": "fallback_key1"}
         fallback_key = {"alg1:k1": "fallback_key1"}
@@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
         )
         )
 
 
-    def test_replace_master_key(self):
+    def test_replace_master_key(self) -> None:
         """uploading a new signing key should make the old signing key unavailable"""
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
         keys1 = {
         keys1 = {
@@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         )
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
 
-    def test_reupload_signatures(self):
+    def test_reupload_signatures(self) -> None:
         """re-uploading a signature should not fail"""
         """re-uploading a signature should not fail"""
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
         keys1 = {
         keys1 = {
@@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
         self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
         self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
         self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
 
 
-    def test_self_signing_key_doesnt_show_up_as_device(self):
+    def test_self_signing_key_doesnt_show_up_as_device(self) -> None:
         """signing keys should be hidden when fetching a user's devices"""
         """signing keys should be hidden when fetching a user's devices"""
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
         keys1 = {
         keys1 = {
@@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         res = self.get_success(self.handler.query_local_devices({local_user: None}))
         res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
         self.assertDictEqual(res, {local_user: {}})
 
 
-    def test_upload_signatures(self):
+    def test_upload_signatures(self) -> None:
         """should check signatures that are uploaded"""
         """should check signatures that are uploaded"""
         # set up a user with cross-signing keys and a device.  This user will
         # set up a user with cross-signing keys and a device.  This user will
         # try uploading signatures
         # try uploading signatures
@@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
             other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
         )
         )
 
 
-    def test_query_devices_remote_no_sync(self):
+    def test_query_devices_remote_no_sync(self) -> None:
         """Tests that querying keys for a remote user that we don't share a room
         """Tests that querying keys for a remote user that we don't share a room
         with returns the cross signing keys correctly.
         with returns the cross signing keys correctly.
         """
         """
@@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             },
             },
         )
         )
 
 
-    def test_query_devices_remote_sync(self):
+    def test_query_devices_remote_sync(self) -> None:
         """Tests that querying keys for a remote user that we share a room with,
         """Tests that querying keys for a remote user that we share a room with,
         but haven't yet fetched the keys for, returns the cross signing keys
         but haven't yet fetched the keys for, returns the cross signing keys
         correctly.
         correctly.
@@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             (["device_1", "device_2"],),
             (["device_1", "device_2"],),
         ]
         ]
     )
     )
-    def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+    def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None:
         """Test that requests for all of a remote user's devices are cached.
         """Test that requests for all of a remote user's devices are cached.
 
 
         We do this by asserting that only one call over federation was made, and that
         We do this by asserting that only one call over federation was made, and that
@@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         """
         """
         local_user_id = "@test:test"
         local_user_id = "@test:test"
         remote_user_id = "@test:other"
         remote_user_id = "@test:other"
-        request_body = {"device_keys": {remote_user_id: []}}
+        request_body: JsonDict = {"device_keys": {remote_user_id: []}}
 
 
         response_devices = [
         response_devices = [
             {
             {

+ 50 - 44
tests/handlers/test_oidc.py

@@ -13,14 +13,18 @@
 # limitations under the License.
 # limitations under the License.
 import json
 import json
 import os
 import os
+from typing import Any, Dict
 from unittest.mock import ANY, Mock, patch
 from unittest.mock import ANY, Mock, patch
 from urllib.parse import parse_qs, urlparse
 from urllib.parse import parse_qs, urlparse
 
 
 import pymacaroons
 import pymacaroons
 
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.handlers.sso import MappingException
 from synapse.handlers.sso import MappingException
 from synapse.server import HomeServer
 from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
 from synapse.util.macaroons import get_value_from_macaroon
 from synapse.util.macaroons import get_value_from_macaroon
 
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
@@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
         }
 
 
 
 
-async def get_json(url):
+async def get_json(url: str) -> JsonDict:
     # Mock get_json calls to handle jwks & oidc discovery endpoints
     # Mock get_json calls to handle jwks & oidc discovery endpoints
     if url == WELL_KNOWN:
     if url == WELL_KNOWN:
         # Minimal discovery document, as defined in OpenID.Discovery
         # Minimal discovery document, as defined in OpenID.Discovery
@@ -116,6 +120,8 @@ async def get_json(url):
     elif url == JWKS_URI:
     elif url == JWKS_URI:
         return {"keys": []}
         return {"keys": []}
 
 
+    return {}
+
 
 
 def _key_file_path() -> str:
 def _key_file_path() -> str:
     """path to a file containing the private half of a test key"""
     """path to a file containing the private half of a test key"""
@@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
     if not HAS_OIDC:
     if not HAS_OIDC:
         skip = "requires OIDC"
         skip = "requires OIDC"
 
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
         config["public_baseurl"] = BASE_URL
         return config
         return config
 
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.http_client = Mock(spec=["get_json"])
         self.http_client = Mock(spec=["get_json"])
         self.http_client.get_json.side_effect = get_json
         self.http_client.get_json.side_effect = get_json
         self.http_client.user_agent = b"Synapse Test"
         self.http_client.user_agent = b"Synapse Test"
@@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         sso_handler = hs.get_sso_handler()
         sso_handler = hs.get_sso_handler()
         # Mock the render error method.
         # Mock the render error method.
         self.render_error = Mock(return_value=None)
         self.render_error = Mock(return_value=None)
-        sso_handler.render_error = self.render_error
+        sso_handler.render_error = self.render_error  # type: ignore[assignment]
 
 
         # Reduce the number of attempts when generating MXIDs.
         # Reduce the number of attempts when generating MXIDs.
         sso_handler._MAP_USERNAME_RETRIES = 3
         sso_handler._MAP_USERNAME_RETRIES = 3
@@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return args
         return args
 
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_config(self):
+    def test_config(self) -> None:
         """Basic config correctly sets up the callback URL and client auth correctly."""
         """Basic config correctly sets up the callback URL and client auth correctly."""
         self.assertEqual(self.provider._callback_url, CALLBACK_URL)
         self.assertEqual(self.provider._callback_url, CALLBACK_URL)
         self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
         self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
         self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
         self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
 
 
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
-    def test_discovery(self):
+    def test_discovery(self) -> None:
         """The handler should discover the endpoints from OIDC discovery document."""
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
         # This would throw if some metadata were invalid
         metadata = self.get_success(self.provider.load_metadata())
         metadata = self.get_success(self.provider.load_metadata())
@@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.http_client.get_json.assert_not_called()
         self.http_client.get_json.assert_not_called()
 
 
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
-    def test_no_discovery(self):
+    def test_no_discovery(self) -> None:
         """When discovery is disabled, it should not try to load from discovery document."""
         """When discovery is disabled, it should not try to load from discovery document."""
         self.get_success(self.provider.load_metadata())
         self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
         self.http_client.get_json.assert_not_called()
 
 
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
-    def test_load_jwks(self):
+    def test_load_jwks(self) -> None:
         """JWKS loading is done once (then cached) if used."""
         """JWKS loading is done once (then cached) if used."""
         jwks = self.get_success(self.provider.load_jwks())
         jwks = self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
@@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
             self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_validate_config(self):
+    def test_validate_config(self) -> None:
         """Provider metadatas are extensively validated."""
         """Provider metadatas are extensively validated."""
         h = self.provider
         h = self.provider
 
 
@@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             force_load_metadata()
             force_load_metadata()
 
 
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
-    def test_skip_verification(self):
+    def test_skip_verification(self) -> None:
         """Provider metadata validation can be disabled by config."""
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
             # This should not throw
             get_awaitable_result(self.provider.load_metadata())
             get_awaitable_result(self.provider.load_metadata())
 
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_redirect_request(self):
+    def test_redirect_request(self) -> None:
         """The redirect request has the right arguments & generates a valid session cookie."""
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["cookies"])
         req = Mock(spec=["cookies"])
         req.cookies = []
         req.cookies = []
@@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(redirect, "http://client/redirect")
         self.assertEqual(redirect, "http://client/redirect")
 
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_callback_error(self):
+    def test_callback_error(self) -> None:
         """Errors from the provider returned in the callback are displayed."""
         """Errors from the provider returned in the callback are displayed."""
         request = Mock(args={})
         request = Mock(args={})
         request.args[b"error"] = [b"invalid_client"]
         request.args[b"error"] = [b"invalid_client"]
@@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertRenderedError("invalid_client", "some description")
         self.assertRenderedError("invalid_client", "some description")
 
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_callback(self):
+    def test_callback(self) -> None:
         """Code callback works and display errors if something went wrong.
         """Code callback works and display errors if something went wrong.
 
 
         A lot of scenarios are tested here:
         A lot of scenarios are tested here:
@@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": username,
             "username": username,
         }
         }
         expected_user_id = "@%s:%s" % (username, self.hs.hostname)
         expected_user_id = "@%s:%s" % (username, self.hs.hostname)
-        self.provider._exchange_code = simple_async_mock(return_value=token)
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
+        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
         auth_handler = self.hs.get_auth_handler()
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
         auth_handler.complete_sso_login = simple_async_mock()
 
 
@@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertRenderedError("mapping_error")
             self.assertRenderedError("mapping_error")
 
 
         # Handle ID token errors
         # Handle ID token errors
-        self.provider._parse_id_token = simple_async_mock(raises=Exception())
+        self.provider._parse_id_token = simple_async_mock(raises=Exception())  # type: ignore[assignment]
         self.get_success(self.handler.handle_oidc_callback(request))
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
         self.assertRenderedError("invalid_token")
 
 
@@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "type": "bearer",
             "type": "bearer",
             "access_token": "access_token",
             "access_token": "access_token",
         }
         }
-        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
         self.get_success(self.handler.handle_oidc_callback(request))
         self.get_success(self.handler.handle_oidc_callback(request))
 
 
         auth_handler.complete_sso_login.assert_called_once_with(
         auth_handler.complete_sso_login.assert_called_once_with(
@@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         id_token = {
         id_token = {
             "sid": "abcdefgh",
             "sid": "abcdefgh",
         }
         }
-        self.provider._parse_id_token = simple_async_mock(return_value=id_token)
-        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=id_token)  # type: ignore[assignment]
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
         auth_handler.complete_sso_login.reset_mock()
         auth_handler.complete_sso_login.reset_mock()
         self.provider._fetch_userinfo.reset_mock()
         self.provider._fetch_userinfo.reset_mock()
         self.get_success(self.handler.handle_oidc_callback(request))
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.render_error.assert_not_called()
         self.render_error.assert_not_called()
 
 
         # Handle userinfo fetching error
         # Handle userinfo fetching error
-        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
+        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())  # type: ignore[assignment]
         self.get_success(self.handler.handle_oidc_callback(request))
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
         self.assertRenderedError("fetch_error")
 
 
         # Handle code exchange failure
         # Handle code exchange failure
         from synapse.handlers.oidc import OidcError
         from synapse.handlers.oidc import OidcError
 
 
-        self.provider._exchange_code = simple_async_mock(
+        self.provider._exchange_code = simple_async_mock(  # type: ignore[assignment]
             raises=OidcError("invalid_request")
             raises=OidcError("invalid_request")
         )
         )
         self.get_success(self.handler.handle_oidc_callback(request))
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_request")
         self.assertRenderedError("invalid_request")
 
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_callback_session(self):
+    def test_callback_session(self) -> None:
         """The callback verifies the session presence and validity"""
         """The callback verifies the session presence and validity"""
         request = Mock(spec=["args", "getCookie", "cookies"])
         request = Mock(spec=["args", "getCookie", "cookies"])
 
 
@@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config(
     @override_config(
         {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
         {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
     )
     )
-    def test_exchange_code(self):
+    def test_exchange_code(self) -> None:
         """Code exchange behaves correctly and handles various error scenarios."""
         """Code exchange behaves correctly and handles various error scenarios."""
         token = {"type": "bearer"}
         token = {"type": "bearer"}
         token_json = json.dumps(token).encode("utf-8")
         token_json = json.dumps(token).encode("utf-8")
@@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
             }
         }
         }
     )
     )
-    def test_exchange_code_jwt_key(self):
+    def test_exchange_code_jwt_key(self) -> None:
         """Test that code exchange works with a JWK client secret."""
         """Test that code exchange works with a JWK client secret."""
         from authlib.jose import jwt
         from authlib.jose import jwt
 
 
@@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
             }
         }
         }
     )
     )
-    def test_exchange_code_no_auth(self):
+    def test_exchange_code_no_auth(self) -> None:
         """Test that code exchange works with no client secret."""
         """Test that code exchange works with no client secret."""
         token = {"type": "bearer"}
         token = {"type": "bearer"}
         self.http_client.request = simple_async_mock(
         self.http_client.request = simple_async_mock(
@@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
             }
         }
         }
     )
     )
-    def test_extra_attributes(self):
+    def test_extra_attributes(self) -> None:
         """
         """
         Login while using a mapping provider that implements get_extra_attributes.
         Login while using a mapping provider that implements get_extra_attributes.
         """
         """
@@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "foo",
             "username": "foo",
             "phone": "1234567",
             "phone": "1234567",
         }
         }
-        self.provider._exchange_code = simple_async_mock(return_value=token)
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
         auth_handler = self.hs.get_auth_handler()
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
         auth_handler.complete_sso_login = simple_async_mock()
 
 
@@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         )
 
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_map_userinfo_to_user(self):
+    def test_map_userinfo_to_user(self) -> None:
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
         auth_handler = self.hs.get_auth_handler()
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
         auth_handler.complete_sso_login = simple_async_mock()
 
 
-        userinfo = {
+        userinfo: dict = {
             "sub": "test_user",
             "sub": "test_user",
             "username": "test_user",
             "username": "test_user",
         }
         }
@@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         )
 
 
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
-    def test_map_userinfo_to_existing_user(self):
+    def test_map_userinfo_to_existing_user(self) -> None:
         """Existing users can log in with OpenID Connect when allow_existing_users is True."""
         """Existing users can log in with OpenID Connect when allow_existing_users is True."""
         store = self.hs.get_datastores().main
         store = self.hs.get_datastores().main
         user = UserID.from_string("@test_user:test")
         user = UserID.from_string("@test_user:test")
@@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         )
 
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_map_userinfo_to_invalid_localpart(self):
+    def test_map_userinfo_to_invalid_localpart(self) -> None:
         """If the mapping provider generates an invalid localpart it should be rejected."""
         """If the mapping provider generates an invalid localpart it should be rejected."""
         self.get_success(
         self.get_success(
             _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
             _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
@@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
             }
         }
         }
     )
     )
-    def test_map_userinfo_to_user_retries(self):
+    def test_map_userinfo_to_user_retries(self) -> None:
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
         auth_handler = self.hs.get_auth_handler()
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
         auth_handler.complete_sso_login = simple_async_mock()
@@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         )
 
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_empty_localpart(self):
+    def test_empty_localpart(self) -> None:
         """Attempts to map onto an empty localpart should be rejected."""
         """Attempts to map onto an empty localpart should be rejected."""
         userinfo = {
         userinfo = {
             "sub": "tester",
             "sub": "tester",
@@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
             }
         }
         }
     )
     )
-    def test_null_localpart(self):
+    def test_null_localpart(self) -> None:
         """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
         """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
         userinfo = {
         userinfo = {
             "sub": "tester",
             "sub": "tester",
@@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
             }
         }
         }
     )
     )
-    def test_attribute_requirements(self):
+    def test_attribute_requirements(self) -> None:
         """The required attributes must be met from the OIDC userinfo response."""
         """The required attributes must be met from the OIDC userinfo response."""
         auth_handler = self.hs.get_auth_handler()
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
         auth_handler.complete_sso_login = simple_async_mock()
@@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
             }
         }
         }
     )
     )
-    def test_attribute_requirements_contains(self):
+    def test_attribute_requirements_contains(self) -> None:
         """Test that auth succeeds if userinfo attribute CONTAINS required value"""
         """Test that auth succeeds if userinfo attribute CONTAINS required value"""
         auth_handler = self.hs.get_auth_handler()
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
         auth_handler.complete_sso_login = simple_async_mock()
@@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
             }
         }
         }
     )
     )
-    def test_attribute_requirements_mismatch(self):
+    def test_attribute_requirements_mismatch(self) -> None:
         """
         """
         Test that auth fails if attributes exist but don't match,
         Test that auth fails if attributes exist but don't match,
         or are non-string values.
         or are non-string values.
@@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         auth_handler = self.hs.get_auth_handler()
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
         auth_handler.complete_sso_login = simple_async_mock()
         # userinfo with "test": "not_foobar" attribute should fail
         # userinfo with "test": "not_foobar" attribute should fail
-        userinfo = {
+        userinfo: dict = {
             "sub": "tester",
             "sub": "tester",
             "username": "tester",
             "username": "tester",
             "test": "not_foobar",
             "test": "not_foobar",
@@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
 
 
     handler = hs.get_oidc_handler()
     handler = hs.get_oidc_handler()
     provider = handler._providers["oidc"]
     provider = handler._providers["oidc"]
-    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
-    provider._parse_id_token = simple_async_mock(return_value=userinfo)
-    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})  # type: ignore[assignment]
+    provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
+    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
 
 
     state = "state"
     state = "state"
     session = handler._token_generator.generate_oidc_session_token(
     session = handler._token_generator.generate_oidc_session_token(

+ 24 - 19
tests/handlers/test_profile.py

@@ -11,14 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
-from typing import Any, Dict
+from typing import Any, Awaitable, Callable, Dict
 from unittest.mock import Mock
 from unittest.mock import Mock
 
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.types
 import synapse.types
 from synapse.api.errors import AuthError, SynapseError
 from synapse.api.errors import AuthError, SynapseError
 from synapse.rest import admin
 from synapse.rest import admin
 from synapse.server import HomeServer
 from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
 
 
 from tests import unittest
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.test_utils import make_awaitable
@@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
 
     servlets = [admin.register_servlets]
     servlets = [admin.register_servlets]
 
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.mock_federation = Mock()
         self.mock_federation = Mock()
         self.mock_registry = Mock()
         self.mock_registry = Mock()
 
 
-        self.query_handlers = {}
+        self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
 
 
-        def register_query_handler(query_type, handler):
+        def register_query_handler(
+            query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
+        ) -> None:
             self.query_handlers[query_type] = handler
             self.query_handlers[query_type] = handler
 
 
         self.mock_registry.register_query_handler = register_query_handler
         self.mock_registry.register_query_handler = register_query_handler
@@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         )
         )
         return hs
         return hs
 
 
-    def prepare(self, reactor, clock, hs: HomeServer):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.store = hs.get_datastores().main
 
 
         self.frank = UserID.from_string("@1234abcd:test")
         self.frank = UserID.from_string("@1234abcd:test")
@@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
 
         self.handler = hs.get_profile_handler()
         self.handler = hs.get_profile_handler()
 
 
-    def test_get_my_name(self):
+    def test_get_my_name(self) -> None:
         self.get_success(
         self.get_success(
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
         )
         )
@@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
 
         self.assertEqual("Frank", displayname)
         self.assertEqual("Frank", displayname)
 
 
-    def test_set_my_name(self):
+    def test_set_my_name(self) -> None:
         self.get_success(
         self.get_success(
             self.handler.set_displayname(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
@@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             self.get_success(self.store.get_profile_displayname(self.frank.localpart))
             self.get_success(self.store.get_profile_displayname(self.frank.localpart))
         )
         )
 
 
-    def test_set_my_name_if_disabled(self):
+    def test_set_my_name_if_disabled(self) -> None:
         self.hs.config.registration.enable_set_displayname = False
         self.hs.config.registration.enable_set_displayname = False
 
 
         # Setting displayname for the first time is allowed
         # Setting displayname for the first time is allowed
@@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             SynapseError,
             SynapseError,
         )
         )
 
 
-    def test_set_my_name_noauth(self):
+    def test_set_my_name_noauth(self) -> None:
         self.get_failure(
         self.get_failure(
             self.handler.set_displayname(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
                 self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
@@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             AuthError,
             AuthError,
         )
         )
 
 
-    def test_get_other_name(self):
+    def test_get_other_name(self) -> None:
         self.mock_federation.make_query.return_value = make_awaitable(
         self.mock_federation.make_query.return_value = make_awaitable(
             {"displayname": "Alice"}
             {"displayname": "Alice"}
         )
         )
@@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             ignore_backoff=True,
             ignore_backoff=True,
         )
         )
 
 
-    def test_incoming_fed_query(self):
+    def test_incoming_fed_query(self) -> None:
         self.get_success(self.store.create_profile("caroline"))
         self.get_success(self.store.create_profile("caroline"))
         self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
         self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
 
 
@@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
 
         self.assertEqual({"displayname": "Caroline"}, response)
         self.assertEqual({"displayname": "Caroline"}, response)
 
 
-    def test_get_my_avatar(self):
+    def test_get_my_avatar(self) -> None:
         self.get_success(
         self.get_success(
             self.store.set_profile_avatar_url(
             self.store.set_profile_avatar_url(
                 self.frank.localpart, "http://my.server/me.png"
                 self.frank.localpart, "http://my.server/me.png"
@@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
 
 
         self.assertEqual("http://my.server/me.png", avatar_url)
         self.assertEqual("http://my.server/me.png", avatar_url)
 
 
-    def test_set_my_avatar(self):
+    def test_set_my_avatar(self) -> None:
         self.get_success(
         self.get_success(
             self.handler.set_avatar_url(
             self.handler.set_avatar_url(
                 self.frank,
                 self.frank,
@@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
         )
         )
 
 
-    def test_set_my_avatar_if_disabled(self):
+    def test_set_my_avatar_if_disabled(self) -> None:
         self.hs.config.registration.enable_set_avatar_url = False
         self.hs.config.registration.enable_set_avatar_url = False
 
 
         # Setting displayname for the first time is allowed
         # Setting displayname for the first time is allowed
@@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
             SynapseError,
             SynapseError,
         )
         )
 
 
-    def test_avatar_constraints_no_config(self):
+    def test_avatar_constraints_no_config(self) -> None:
         """Tests that the method to check an avatar against configured constraints skips
         """Tests that the method to check an avatar against configured constraints skips
         all of its check if no constraint is configured.
         all of its check if no constraint is configured.
         """
         """
@@ -263,7 +268,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertTrue(res)
         self.assertTrue(res)
 
 
     @unittest.override_config({"max_avatar_size": 50})
     @unittest.override_config({"max_avatar_size": 50})
-    def test_avatar_constraints_missing(self):
+    def test_avatar_constraints_missing(self) -> None:
         """Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
         """Tests that an avatar isn't allowed if the file at the given MXC URI couldn't
         be found.
         be found.
         """
         """
@@ -273,7 +278,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertFalse(res)
         self.assertFalse(res)
 
 
     @unittest.override_config({"max_avatar_size": 50})
     @unittest.override_config({"max_avatar_size": 50})
-    def test_avatar_constraints_file_size(self):
+    def test_avatar_constraints_file_size(self) -> None:
         """Tests that a file that's above the allowed file size is forbidden but one
         """Tests that a file that's above the allowed file size is forbidden but one
         that's below it is allowed.
         that's below it is allowed.
         """
         """
@@ -295,7 +300,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
         self.assertFalse(res)
         self.assertFalse(res)
 
 
     @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
     @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
-    def test_avatar_constraint_mime_type(self):
+    def test_avatar_constraint_mime_type(self) -> None:
         """Tests that a file with an unauthorised MIME type is forbidden but one with
         """Tests that a file with an unauthorised MIME type is forbidden but one with
         an authorised content type is allowed.
         an authorised content type is allowed.
         """
         """

+ 14 - 10
tests/handlers/test_saml.py

@@ -12,12 +12,16 @@
 #  See the License for the specific language governing permissions and
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 #  limitations under the License.
 
 
-from typing import Optional
+from typing import Any, Dict, Optional
 from unittest.mock import Mock
 from unittest.mock import Mock
 
 
 import attr
 import attr
 
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.errors import RedirectException
 from synapse.api.errors import RedirectException
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 
 from tests.test_utils import simple_async_mock
 from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 from tests.unittest import HomeserverTestCase, override_config
@@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider):
 
 
 
 
 class SamlHandlerTestCase(HomeserverTestCase):
 class SamlHandlerTestCase(HomeserverTestCase):
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
         config["public_baseurl"] = BASE_URL
-        saml_config = {
+        saml_config: Dict[str, Any] = {
             "sp_config": {"metadata": {}},
             "sp_config": {"metadata": {}},
             # Disable grandfathering.
             # Disable grandfathering.
             "grandfathered_mxid_source_attribute": None,
             "grandfathered_mxid_source_attribute": None,
@@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
 
         return config
         return config
 
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver()
         hs = self.setup_test_homeserver()
 
 
         self.handler = hs.get_saml_handler()
         self.handler = hs.get_saml_handler()
@@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
     elif not has_xmlsec1:
     elif not has_xmlsec1:
         skip = "Requires xmlsec1"
         skip = "Requires xmlsec1"
 
 
-    def test_map_saml_response_to_user(self):
+    def test_map_saml_response_to_user(self) -> None:
         """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
         """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
 
 
         # stub out the auth handler
         # stub out the auth handler
@@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
         )
 
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
-    def test_map_saml_response_to_existing_user(self):
+    def test_map_saml_response_to_existing_user(self) -> None:
         """Existing users can log in with SAML account."""
         """Existing users can log in with SAML account."""
         store = self.hs.get_datastores().main
         store = self.hs.get_datastores().main
         self.get_success(
         self.get_success(
@@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             auth_provider_session_id=None,
             auth_provider_session_id=None,
         )
         )
 
 
-    def test_map_saml_response_to_invalid_localpart(self):
+    def test_map_saml_response_to_invalid_localpart(self) -> None:
         """If the mapping provider generates an invalid localpart it should be rejected."""
         """If the mapping provider generates an invalid localpart it should be rejected."""
 
 
         # stub out the auth handler
         # stub out the auth handler
@@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
         )
         )
         auth_handler.complete_sso_login.assert_not_called()
         auth_handler.complete_sso_login.assert_not_called()
 
 
-    def test_map_saml_response_to_user_retries(self):
+    def test_map_saml_response_to_user_retries(self) -> None:
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
 
 
         # stub out the auth handler and error renderer
         # stub out the auth handler and error renderer
@@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             }
             }
         }
         }
     )
     )
-    def test_map_saml_response_redirect(self):
+    def test_map_saml_response_redirect(self) -> None:
         """Test a mapping provider that raises a RedirectException"""
         """Test a mapping provider that raises a RedirectException"""
 
 
         saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
         saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
@@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             },
             },
         }
         }
     )
     )
-    def test_attribute_requirements(self):
+    def test_attribute_requirements(self) -> None:
         """The required attributes must be met from the SAML response."""
         """The required attributes must be met from the SAML response."""
 
 
         # stub out the auth handler
         # stub out the auth handler