Ver Fonte

Add type hints to `tests/rest/client` (#12108)

* Add type hints to `tests/rest/client`

* newsfile

* fix imports

* add `test_account.py`

* Remove one type hint in `test_report_event.py`

* change `on_create_room` to `async`

* update new functions in `test_third_party_rules.py`

* Add `test_filter.py`

* add `test_rooms.py`

* change to `assertEquals` to `assertEqual`

* lint
Dirk Klimpel há 2 anos atrás
pai
commit
2ffaf30803

+ 1 - 0
changelog.d/12108.misc

@@ -0,0 +1 @@
+Add type hints to `tests/rest/client`.

+ 0 - 6
mypy.ini

@@ -78,13 +78,7 @@ exclude = (?x)
    |tests/push/test_http.py
    |tests/push/test_presentable_names.py
    |tests/push/test_push_rule_evaluator.py
-   |tests/rest/client/test_account.py
-   |tests/rest/client/test_filter.py
-   |tests/rest/client/test_report_event.py
-   |tests/rest/client/test_rooms.py
-   |tests/rest/client/test_third_party_rules.py
    |tests/rest/client/test_transactions.py
-   |tests/rest/client/test_typing.py
    |tests/rest/key/v2/test_remote_key_resource.py
    |tests/rest/media/v1/test_base.py
    |tests/rest/media/v1/test_media_storage.py

+ 157 - 133
tests/rest/client/test_account.py

@@ -15,11 +15,12 @@ import json
 import os
 import re
 from email.parser import Parser
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 from unittest.mock import Mock
 
 import pkg_resources
 
+from twisted.internet.interfaces import IReactorTCP
 from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
@@ -30,6 +31,7 @@ from synapse.rest import admin
 from synapse.rest.client import account, login, register, room
 from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
 from synapse.server import HomeServer
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
@@ -46,7 +48,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
 
         # Email config.
@@ -67,20 +69,27 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         hs = self.setup_test_homeserver(config=config)
 
         async def sendmail(
-            reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
-        ):
-            self.email_attempts.append(msg)
-
-        self.email_attempts = []
+            reactor: IReactorTCP,
+            smtphost: str,
+            smtpport: int,
+            from_addr: str,
+            to_addr: str,
+            msg_bytes: bytes,
+            *args: Any,
+            **kwargs: Any,
+        ) -> None:
+            self.email_attempts.append(msg_bytes)
+
+        self.email_attempts: List[bytes] = []
         hs.get_send_email_handler()._sendmail = sendmail
 
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
         self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
 
-    def test_basic_password_reset(self):
+    def test_basic_password_reset(self) -> None:
         """Test basic password reset flow"""
         old_password = "monkey"
         new_password = "kangeroo"
@@ -118,7 +127,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         self.attempt_wrong_password_login("kermit", old_password)
 
     @override_config({"rc_3pid_validation": {"burst_count": 3}})
-    def test_ratelimit_by_email(self):
+    def test_ratelimit_by_email(self) -> None:
         """Test that we ratelimit /requestToken for the same email."""
         old_password = "monkey"
         new_password = "kangeroo"
@@ -139,7 +148,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        def reset(ip):
+        def reset(ip: str) -> None:
             client_secret = "foobar"
             session_id = self._request_token(email, client_secret, ip)
 
@@ -166,7 +175,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(cm.exception.code, 429)
 
-    def test_basic_password_reset_canonicalise_email(self):
+    def test_basic_password_reset_canonicalise_email(self) -> None:
         """Test basic password reset flow
         Request password reset with different spelling
         """
@@ -206,7 +215,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Assert we can't log in with the old password
         self.attempt_wrong_password_login("kermit", old_password)
 
-    def test_cant_reset_password_without_clicking_link(self):
+    def test_cant_reset_password_without_clicking_link(self) -> None:
         """Test that we do actually need to click the link in the email"""
         old_password = "monkey"
         new_password = "kangeroo"
@@ -241,7 +250,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Assert we can't log in with the new password
         self.attempt_wrong_password_login("kermit", new_password)
 
-    def test_no_valid_token(self):
+    def test_no_valid_token(self) -> None:
         """Test that we do actually need to request a token and can't just
         make a session up.
         """
@@ -277,7 +286,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         self.attempt_wrong_password_login("kermit", new_password)
 
     @unittest.override_config({"request_token_inhibit_3pid_errors": True})
-    def test_password_reset_bad_email_inhibit_error(self):
+    def test_password_reset_bad_email_inhibit_error(self) -> None:
         """Test that triggering a password reset with an email address that isn't bound
         to an account doesn't leak the lack of binding for that address if configured
         that way.
@@ -292,7 +301,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
         self.assertIsNotNone(session_id)
 
-    def _request_token(self, email, client_secret, ip="127.0.0.1"):
+    def _request_token(
+        self,
+        email: str,
+        client_secret: str,
+        ip: str = "127.0.0.1",
+    ) -> str:
         channel = self.make_request(
             "POST",
             b"account/password/email/requestToken",
@@ -309,7 +323,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
         return channel.json_body["sid"]
 
-    def _validate_token(self, link):
+    def _validate_token(self, link: str) -> None:
         # Remove the host
         path = link.replace("https://example.com", "")
 
@@ -339,7 +353,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(200, channel.code, channel.result)
 
-    def _get_link_from_email(self):
+    def _get_link_from_email(self) -> str:
         assert self.email_attempts, "No emails have been sent"
 
         raw_msg = self.email_attempts[-1].decode("UTF-8")
@@ -354,14 +368,19 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         if not text:
             self.fail("Could not find text portion of email to parse")
 
+        assert text is not None
         match = re.search(r"https://example.com\S+", text)
         assert match, "Could not find link in email"
 
         return match.group(0)
 
     def _reset_password(
-        self, new_password, session_id, client_secret, expected_code=200
-    ):
+        self,
+        new_password: str,
+        session_id: str,
+        client_secret: str,
+        expected_code: int = 200,
+    ) -> None:
         channel = self.make_request(
             "POST",
             b"account/password",
@@ -388,11 +407,11 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.hs = self.setup_test_homeserver()
         return self.hs
 
-    def test_deactivate_account(self):
+    def test_deactivate_account(self) -> None:
         user_id = self.register_user("kermit", "test")
         tok = self.login("kermit", "test")
 
@@ -407,7 +426,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
         channel = self.make_request("GET", "account/whoami", access_token=tok)
         self.assertEqual(channel.code, 401)
 
-    def test_pending_invites(self):
+    def test_pending_invites(self) -> None:
         """Tests that deactivating a user rejects every pending invite for them."""
         store = self.hs.get_datastores().main
 
@@ -448,7 +467,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(memberships), 1, memberships)
         self.assertEqual(memberships[0].room_id, room_id, memberships)
 
-    def deactivate(self, user_id, tok):
+    def deactivate(self, user_id: str, tok: str) -> None:
         request_data = json.dumps(
             {
                 "auth": {
@@ -474,12 +493,12 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
         register.register_servlets,
     ]
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["allow_guest_access"] = True
         return config
 
-    def test_GET_whoami(self):
+    def test_GET_whoami(self) -> None:
         device_id = "wouldgohere"
         user_id = self.register_user("kermit", "test")
         tok = self.login("kermit", "test", device_id=device_id)
@@ -496,7 +515,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
             },
         )
 
-    def test_GET_whoami_guests(self):
+    def test_GET_whoami_guests(self) -> None:
         channel = self.make_request(
             b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}"
         )
@@ -516,7 +535,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
             },
         )
 
-    def test_GET_whoami_appservices(self):
+    def test_GET_whoami_appservices(self) -> None:
         user_id = "@as:test"
         as_token = "i_am_an_app_service"
 
@@ -541,7 +560,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
         )
         self.assertFalse(hasattr(whoami, "device_id"))
 
-    def _whoami(self, tok):
+    def _whoami(self, tok: str) -> JsonDict:
         channel = self.make_request("GET", "account/whoami", {}, access_token=tok)
         self.assertEqual(channel.code, 200)
         return channel.json_body
@@ -555,7 +574,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         synapse.rest.admin.register_servlets_for_client_rest_resource,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
 
         # Email config.
@@ -576,16 +595,23 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.hs = self.setup_test_homeserver(config=config)
 
         async def sendmail(
-            reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
-        ):
-            self.email_attempts.append(msg)
-
-        self.email_attempts = []
+            reactor: IReactorTCP,
+            smtphost: str,
+            smtpport: int,
+            from_addr: str,
+            to_addr: str,
+            msg_bytes: bytes,
+            *args: Any,
+            **kwargs: Any,
+        ) -> None:
+            self.email_attempts.append(msg_bytes)
+
+        self.email_attempts: List[bytes] = []
         self.hs.get_send_email_handler()._sendmail = sendmail
 
         return self.hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastores().main
 
         self.user_id = self.register_user("kermit", "test")
@@ -593,83 +619,73 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         self.email = "test@example.com"
         self.url_3pid = b"account/3pid"
 
-    def test_add_valid_email(self):
-        self.get_success(self._add_email(self.email, self.email))
+    def test_add_valid_email(self) -> None:
+        self._add_email(self.email, self.email)
 
-    def test_add_valid_email_second_time(self):
-        self.get_success(self._add_email(self.email, self.email))
-        self.get_success(
-            self._request_token_invalid_email(
-                self.email,
-                expected_errcode=Codes.THREEPID_IN_USE,
-                expected_error="Email is already in use",
-            )
+    def test_add_valid_email_second_time(self) -> None:
+        self._add_email(self.email, self.email)
+        self._request_token_invalid_email(
+            self.email,
+            expected_errcode=Codes.THREEPID_IN_USE,
+            expected_error="Email is already in use",
         )
 
-    def test_add_valid_email_second_time_canonicalise(self):
-        self.get_success(self._add_email(self.email, self.email))
-        self.get_success(
-            self._request_token_invalid_email(
-                "TEST@EXAMPLE.COM",
-                expected_errcode=Codes.THREEPID_IN_USE,
-                expected_error="Email is already in use",
-            )
+    def test_add_valid_email_second_time_canonicalise(self) -> None:
+        self._add_email(self.email, self.email)
+        self._request_token_invalid_email(
+            "TEST@EXAMPLE.COM",
+            expected_errcode=Codes.THREEPID_IN_USE,
+            expected_error="Email is already in use",
         )
 
-    def test_add_email_no_at(self):
-        self.get_success(
-            self._request_token_invalid_email(
-                "address-without-at.bar",
-                expected_errcode=Codes.UNKNOWN,
-                expected_error="Unable to parse email address",
-            )
+    def test_add_email_no_at(self) -> None:
+        self._request_token_invalid_email(
+            "address-without-at.bar",
+            expected_errcode=Codes.UNKNOWN,
+            expected_error="Unable to parse email address",
         )
 
-    def test_add_email_two_at(self):
-        self.get_success(
-            self._request_token_invalid_email(
-                "foo@foo@test.bar",
-                expected_errcode=Codes.UNKNOWN,
-                expected_error="Unable to parse email address",
-            )
+    def test_add_email_two_at(self) -> None:
+        self._request_token_invalid_email(
+            "foo@foo@test.bar",
+            expected_errcode=Codes.UNKNOWN,
+            expected_error="Unable to parse email address",
         )
 
-    def test_add_email_bad_format(self):
-        self.get_success(
-            self._request_token_invalid_email(
-                "user@bad.example.net@good.example.com",
-                expected_errcode=Codes.UNKNOWN,
-                expected_error="Unable to parse email address",
-            )
+    def test_add_email_bad_format(self) -> None:
+        self._request_token_invalid_email(
+            "user@bad.example.net@good.example.com",
+            expected_errcode=Codes.UNKNOWN,
+            expected_error="Unable to parse email address",
         )
 
-    def test_add_email_domain_to_lower(self):
-        self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar"))
+    def test_add_email_domain_to_lower(self) -> None:
+        self._add_email("foo@TEST.BAR", "foo@test.bar")
 
-    def test_add_email_domain_with_umlaut(self):
-        self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com"))
+    def test_add_email_domain_with_umlaut(self) -> None:
+        self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")
 
-    def test_add_email_address_casefold(self):
-        self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com"))
+    def test_add_email_address_casefold(self) -> None:
+        self._add_email("Strauß@Example.com", "strauss@example.com")
 
-    def test_address_trim(self):
-        self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
+    def test_address_trim(self) -> None:
+        self._add_email(" foo@test.bar ", "foo@test.bar")
 
     @override_config({"rc_3pid_validation": {"burst_count": 3}})
-    def test_ratelimit_by_ip(self):
+    def test_ratelimit_by_ip(self) -> None:
         """Tests that adding emails is ratelimited by IP"""
 
         # We expect to be able to set three emails before getting ratelimited.
-        self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
-        self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
-        self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
+        self._add_email("foo1@test.bar", "foo1@test.bar")
+        self._add_email("foo2@test.bar", "foo2@test.bar")
+        self._add_email("foo3@test.bar", "foo3@test.bar")
 
         with self.assertRaises(HttpResponseException) as cm:
-            self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
+            self._add_email("foo4@test.bar", "foo4@test.bar")
 
         self.assertEqual(cm.exception.code, 429)
 
-    def test_add_email_if_disabled(self):
+    def test_add_email_if_disabled(self) -> None:
         """Test adding email to profile when doing so is disallowed"""
         self.hs.config.registration.enable_3pid_changes = False
 
@@ -695,7 +711,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             },
             access_token=self.user_id_tok,
         )
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
         # Get user
@@ -705,10 +721,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             access_token=self.user_id_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
         self.assertFalse(channel.json_body["threepids"])
 
-    def test_delete_email(self):
+    def test_delete_email(self) -> None:
         """Test deleting an email from profile"""
         # Add a threepid
         self.get_success(
@@ -727,7 +743,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             {"medium": "email", "address": self.email},
             access_token=self.user_id_tok,
         )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         # Get user
         channel = self.make_request(
@@ -736,10 +752,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             access_token=self.user_id_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
         self.assertFalse(channel.json_body["threepids"])
 
-    def test_delete_email_if_disabled(self):
+    def test_delete_email_if_disabled(self) -> None:
         """Test deleting an email from profile when disallowed"""
         self.hs.config.registration.enable_3pid_changes = False
 
@@ -761,7 +777,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             access_token=self.user_id_tok,
         )
 
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
         # Get user
@@ -771,11 +787,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             access_token=self.user_id_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
 
-    def test_cant_add_email_without_clicking_link(self):
+    def test_cant_add_email_without_clicking_link(self) -> None:
         """Test that we do actually need to click the link in the email"""
         client_secret = "foobar"
         session_id = self._request_token(self.email, client_secret)
@@ -797,7 +813,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             },
             access_token=self.user_id_tok,
         )
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
         self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
 
         # Get user
@@ -807,10 +823,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             access_token=self.user_id_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
         self.assertFalse(channel.json_body["threepids"])
 
-    def test_no_valid_token(self):
+    def test_no_valid_token(self) -> None:
         """Test that we do actually need to request a token and can't just
         make a session up.
         """
@@ -832,7 +848,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             },
             access_token=self.user_id_tok,
         )
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
         self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
 
         # Get user
@@ -842,11 +858,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             access_token=self.user_id_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
         self.assertFalse(channel.json_body["threepids"])
 
     @override_config({"next_link_domain_whitelist": None})
-    def test_next_link(self):
+    def test_next_link(self) -> None:
         """Tests a valid next_link parameter value with no whitelist (good case)"""
         self._request_token(
             "something@example.com",
@@ -856,7 +872,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         )
 
     @override_config({"next_link_domain_whitelist": None})
-    def test_next_link_exotic_protocol(self):
+    def test_next_link_exotic_protocol(self) -> None:
         """Tests using a esoteric protocol as a next_link parameter value.
         Someone may be hosting a client on IPFS etc.
         """
@@ -868,7 +884,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         )
 
     @override_config({"next_link_domain_whitelist": None})
-    def test_next_link_file_uri(self):
+    def test_next_link_file_uri(self) -> None:
         """Tests next_link parameters cannot be file URI"""
         # Attempt to use a next_link value that points to the local disk
         self._request_token(
@@ -879,7 +895,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         )
 
     @override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
-    def test_next_link_domain_whitelist(self):
+    def test_next_link_domain_whitelist(self) -> None:
         """Tests next_link parameters must fit the whitelist if provided"""
 
         # Ensure not providing a next_link parameter still works
@@ -912,7 +928,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         )
 
     @override_config({"next_link_domain_whitelist": []})
-    def test_empty_next_link_domain_whitelist(self):
+    def test_empty_next_link_domain_whitelist(self) -> None:
         """Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
         disallowed
         """
@@ -962,28 +978,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
     def _request_token_invalid_email(
         self,
-        email,
-        expected_errcode,
-        expected_error,
-        client_secret="foobar",
-    ):
+        email: str,
+        expected_errcode: str,
+        expected_error: str,
+        client_secret: str = "foobar",
+    ) -> None:
         channel = self.make_request(
             "POST",
             b"account/3pid/email/requestToken",
             {"client_secret": client_secret, "email": email, "send_attempt": 1},
         )
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(400, channel.code, msg=channel.result["body"])
         self.assertEqual(expected_errcode, channel.json_body["errcode"])
         self.assertEqual(expected_error, channel.json_body["error"])
 
-    def _validate_token(self, link):
+    def _validate_token(self, link: str) -> None:
         # Remove the host
         path = link.replace("https://example.com", "")
 
         channel = self.make_request("GET", path, shorthand=False)
         self.assertEqual(200, channel.code, channel.result)
 
-    def _get_link_from_email(self):
+    def _get_link_from_email(self) -> str:
         assert self.email_attempts, "No emails have been sent"
 
         raw_msg = self.email_attempts[-1].decode("UTF-8")
@@ -998,12 +1014,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
         if not text:
             self.fail("Could not find text portion of email to parse")
 
+        assert text is not None
         match = re.search(r"https://example.com\S+", text)
         assert match, "Could not find link in email"
 
         return match.group(0)
 
-    def _add_email(self, request_email, expected_email):
+    def _add_email(self, request_email: str, expected_email: str) -> None:
         """Test adding an email to profile"""
         previous_email_attempts = len(self.email_attempts)
 
@@ -1030,7 +1047,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             access_token=self.user_id_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
 
         # Get user
         channel = self.make_request(
@@ -1039,7 +1056,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             access_token=self.user_id_tok,
         )
 
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
 
         threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
@@ -1055,18 +1072,18 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
 
     url = "/_matrix/client/unstable/org.matrix.msc3720/account_status"
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         config["experimental_features"] = {"msc3720_enabled": True}
 
         return self.setup_test_homeserver(config=config)
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.requester = self.register_user("requester", "password")
         self.requester_tok = self.login("requester", "password")
-        self.server_name = homeserver.config.server.server_name
+        self.server_name = hs.config.server.server_name
 
-    def test_missing_mxid(self):
+    def test_missing_mxid(self) -> None:
         """Tests that not providing any MXID raises an error."""
         self._test_status(
             users=None,
@@ -1074,7 +1091,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
             expected_errcode=Codes.MISSING_PARAM,
         )
 
-    def test_invalid_mxid(self):
+    def test_invalid_mxid(self) -> None:
         """Tests that providing an invalid MXID raises an error."""
         self._test_status(
             users=["bad:test"],
@@ -1082,7 +1099,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
             expected_errcode=Codes.INVALID_PARAM,
         )
 
-    def test_local_user_not_exists(self):
+    def test_local_user_not_exists(self) -> None:
         """Tests that the account status endpoints correctly reports that a user doesn't
         exist.
         """
@@ -1098,7 +1115,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
             expected_failures=[],
         )
 
-    def test_local_user_exists(self):
+    def test_local_user_exists(self) -> None:
         """Tests that the account status endpoint correctly reports that a user doesn't
         exist.
         """
@@ -1115,7 +1132,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
             expected_failures=[],
         )
 
-    def test_local_user_deactivated(self):
+    def test_local_user_deactivated(self) -> None:
         """Tests that the account status endpoint correctly reports a deactivated user."""
         user = self.register_user("someuser", "password")
         self.get_success(
@@ -1135,7 +1152,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
             expected_failures=[],
         )
 
-    def test_mixed_local_and_remote_users(self):
+    def test_mixed_local_and_remote_users(self) -> None:
         """Tests that if some users are remote the account status endpoint correctly
         merges the remote responses with the local result.
         """
@@ -1150,7 +1167,13 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
             "@bad:badremote",
         ]
 
-        async def post_json(destination, path, data, *a, **kwa):
+        async def post_json(
+            destination: str,
+            path: str,
+            data: Optional[JsonDict] = None,
+            *a: Any,
+            **kwa: Any,
+        ) -> Union[JsonDict, list]:
             if destination == "remote":
                 return {
                     "account_statuses": {
@@ -1160,9 +1183,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
                         },
                     }
                 }
-            if destination == "otherremote":
-                return {}
-            if destination == "badremote":
+            elif destination == "badremote":
                 # badremote tries to overwrite the status of a user that doesn't belong
                 # to it (i.e. users[1]) with false data, which Synapse is expected to
                 # ignore.
@@ -1176,6 +1197,9 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
                         },
                     }
                 }
+            # if destination == "otherremote"
+            else:
+                return {}
 
         # Register a mock that will return the expected result depending on the remote.
         self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
@@ -1205,7 +1229,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
         expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
         expected_failures: Optional[List[str]] = None,
         expected_errcode: Optional[str] = None,
-    ):
+    ) -> None:
         """Send a request to the account status endpoint and check that the response
         matches with what's expected.
 

+ 16 - 13
tests/rest/client/test_filter.py

@@ -12,10 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import Codes
 from synapse.rest.client import filter
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -30,11 +32,11 @@ class FilterTestCase(unittest.HomeserverTestCase):
     EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
     servlets = [filter.register_servlets]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.filtering = hs.get_filtering()
         self.store = hs.get_datastores().main
 
-    def test_add_filter(self):
+    def test_add_filter(self) -> None:
         channel = self.make_request(
             "POST",
             "/_matrix/client/r0/user/%s/filter" % (self.user_id),
@@ -43,11 +45,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertEqual(channel.json_body, {"filter_id": "0"})
-        filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
+        filter = self.get_success(
+            self.store.get_user_filter(user_localpart="apple", filter_id=0)
+        )
         self.pump()
-        self.assertEqual(filter.result, self.EXAMPLE_FILTER)
+        self.assertEqual(filter, self.EXAMPLE_FILTER)
 
-    def test_add_filter_for_other_user(self):
+    def test_add_filter_for_other_user(self) -> None:
         channel = self.make_request(
             "POST",
             "/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
@@ -57,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"403")
         self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
 
-    def test_add_filter_non_local_user(self):
+    def test_add_filter_non_local_user(self) -> None:
         _is_mine = self.hs.is_mine
         self.hs.is_mine = lambda target_user: False
         channel = self.make_request(
@@ -70,14 +74,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"403")
         self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
 
-    def test_get_filter(self):
-        filter_id = defer.ensureDeferred(
+    def test_get_filter(self) -> None:
+        filter_id = self.get_success(
             self.filtering.add_user_filter(
                 user_localpart="apple", user_filter=self.EXAMPLE_FILTER
             )
         )
         self.reactor.advance(1)
-        filter_id = filter_id.result
         channel = self.make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
         )
@@ -85,7 +88,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"200")
         self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
 
-    def test_get_filter_non_existant(self):
+    def test_get_filter_non_existant(self) -> None:
         channel = self.make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
         )
@@ -95,7 +98,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
 
     # Currently invalid params do not have an appropriate errcode
     # in errors.py
-    def test_get_filter_invalid_id(self):
+    def test_get_filter_invalid_id(self) -> None:
         channel = self.make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
         )
@@ -103,7 +106,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.result["code"], b"400")
 
     # No ID also returns an invalid_id error
-    def test_get_filter_no_id(self):
+    def test_get_filter_no_id(self) -> None:
         channel = self.make_request(
             "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
         )

+ 2 - 2
tests/rest/client/test_relations.py

@@ -15,7 +15,7 @@
 
 import itertools
 import urllib.parse
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 from unittest.mock import patch
 
 from twisted.test.proto_helpers import MemoryReactor
@@ -45,7 +45,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
     ]
     hijack_auth = False
 
-    def default_config(self) -> dict:
+    def default_config(self) -> Dict[str, Any]:
         # We need to enable msc1849 support for aggregations
         config = super().default_config()
 

+ 15 - 10
tests/rest/client/test_report_event.py

@@ -14,8 +14,13 @@
 
 import json
 
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.rest.client import login, report_event, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -28,7 +33,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
         report_event.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.login("admin", "pass")
         self.other_user = self.register_user("user", "pass")
@@ -42,35 +47,35 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
         self.event_id = resp["event_id"]
         self.report_path = f"rooms/{self.room_id}/report/{self.event_id}"
 
-    def test_reason_str_and_score_int(self):
+    def test_reason_str_and_score_int(self) -> None:
         data = {"reason": "this makes me sad", "score": -100}
         self._assert_status(200, data)
 
-    def test_no_reason(self):
+    def test_no_reason(self) -> None:
         data = {"score": 0}
         self._assert_status(200, data)
 
-    def test_no_score(self):
+    def test_no_score(self) -> None:
         data = {"reason": "this makes me sad"}
         self._assert_status(200, data)
 
-    def test_no_reason_and_no_score(self):
-        data = {}
+    def test_no_reason_and_no_score(self) -> None:
+        data: JsonDict = {}
         self._assert_status(200, data)
 
-    def test_reason_int_and_score_str(self):
+    def test_reason_int_and_score_str(self) -> None:
         data = {"reason": 10, "score": "string"}
         self._assert_status(400, data)
 
-    def test_reason_zero_and_score_blank(self):
+    def test_reason_zero_and_score_blank(self) -> None:
         data = {"reason": 0, "score": ""}
         self._assert_status(400, data)
 
-    def test_reason_and_score_null(self):
+    def test_reason_and_score_null(self) -> None:
         data = {"reason": None, "score": None}
         self._assert_status(400, data)
 
-    def _assert_status(self, response_status, data):
+    def _assert_status(self, response_status: int, data: JsonDict) -> None:
         channel = self.make_request(
             "POST",
             self.report_path,

Diff do ficheiro suprimidas por serem muito extensas
+ 129 - 124
tests/rest/client/test_rooms.py


+ 69 - 39
tests/rest/client/test_third_party_rules.py

@@ -12,16 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import threading
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventTypes, LoginType, Membership
 from synapse.api.errors import SynapseError
+from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.rest import admin
 from synapse.rest.client import account, login, profile, room
+from synapse.server import HomeServer
 from synapse.types import JsonDict, Requester, StateMap
+from synapse.util import Clock
 from synapse.util.frozenutils import unfreeze
 
 from tests import unittest
@@ -34,7 +40,7 @@ thread_local = threading.local()
 
 
 class LegacyThirdPartyRulesTestModule:
-    def __init__(self, config: Dict, module_api: "ModuleApi"):
+    def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
         # keep a record of the "current" rules module, so that the test can patch
         # it if desired.
         thread_local.rules_module = self
@@ -42,32 +48,36 @@ class LegacyThirdPartyRulesTestModule:
 
     async def on_create_room(
         self, requester: Requester, config: dict, is_requester_admin: bool
-    ):
+    ) -> bool:
         return True
 
-    async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+    async def check_event_allowed(
+        self, event: EventBase, state: StateMap[EventBase]
+    ) -> Union[bool, dict]:
         return True
 
     @staticmethod
-    def parse_config(config):
+    def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
         return config
 
 
 class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
-    def __init__(self, config: Dict, module_api: "ModuleApi"):
+    def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
         super().__init__(config, module_api)
 
-    def on_create_room(
+    async def on_create_room(
         self, requester: Requester, config: dict, is_requester_admin: bool
-    ):
+    ) -> bool:
         return False
 
 
 class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
-    def __init__(self, config: Dict, module_api: "ModuleApi"):
+    def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
         super().__init__(config, module_api)
 
-    async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
+    async def check_event_allowed(
+        self, event: EventBase, state: StateMap[EventBase]
+    ) -> JsonDict:
         d = event.get_dict()
         content = unfreeze(event.content)
         content["foo"] = "bar"
@@ -84,7 +94,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         account.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver()
 
         load_legacy_third_party_event_rules(hs)
@@ -94,22 +104,30 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         # Note that these checks are not relevant to this test case.
 
         # Have this homeserver auto-approve all event signature checking.
-        async def approve_all_signature_checking(_, pdu):
+        async def approve_all_signature_checking(
+            _: RoomVersion, pdu: EventBase
+        ) -> EventBase:
             return pdu
 
-        hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking
+        hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking  # type: ignore[assignment]
 
         # Have this homeserver skip event auth checks. This is necessary due to
         # event auth checks ensuring that events were signed by the sender's homeserver.
-        async def _check_event_auth(origin, event, context, *args, **kwargs):
+        async def _check_event_auth(
+            origin: str,
+            event: EventBase,
+            context: EventContext,
+            *args: Any,
+            **kwargs: Any,
+        ) -> EventContext:
             return context
 
-        hs.get_federation_event_handler()._check_event_auth = _check_event_auth
+        hs.get_federation_event_handler()._check_event_auth = _check_event_auth  # type: ignore[assignment]
 
         return hs
 
-    def prepare(self, reactor, clock, homeserver):
-        super().prepare(reactor, clock, homeserver)
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        super().prepare(reactor, clock, hs)
         # Create some users and a room to play with during the tests
         self.user_id = self.register_user("kermit", "monkey")
         self.invitee = self.register_user("invitee", "hackme")
@@ -121,13 +139,15 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         except Exception:
             pass
 
-    def test_third_party_rules(self):
+    def test_third_party_rules(self) -> None:
         """Tests that a forbidden event is forbidden from being sent, but an allowed one
         can be sent.
         """
         # patch the rules module with a Mock which will return False for some event
         # types
-        async def check(ev, state):
+        async def check(
+            ev: EventBase, state: StateMap[EventBase]
+        ) -> Tuple[bool, Optional[JsonDict]]:
             return ev.type != "foo.bar.forbidden", None
 
         callback = Mock(spec=[], side_effect=check)
@@ -161,7 +181,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         )
         self.assertEqual(channel.result["code"], b"403", channel.result)
 
-    def test_third_party_rules_workaround_synapse_errors_pass_through(self):
+    def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
         """
         Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042
         is functional: that SynapseErrors are passed through from check_event_allowed
@@ -172,7 +192,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         """
 
         class NastyHackException(SynapseError):
-            def error_dict(self):
+            def error_dict(self) -> JsonDict:
                 """
                 This overrides SynapseError's `error_dict` to nastily inject
                 JSON into the error response.
@@ -182,7 +202,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
                 return result
 
         # add a callback that will raise our hacky exception
-        async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]:
+        async def check(
+            ev: EventBase, state: StateMap[EventBase]
+        ) -> Tuple[bool, Optional[JsonDict]]:
             raise NastyHackException(429, "message")
 
         self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
@@ -202,11 +224,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             {"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"},
         )
 
-    def test_cannot_modify_event(self):
+    def test_cannot_modify_event(self) -> None:
         """cannot accidentally modify an event before it is persisted"""
 
         # first patch the event checker so that it will try to modify the event
-        async def check(ev: EventBase, state):
+        async def check(
+            ev: EventBase, state: StateMap[EventBase]
+        ) -> Tuple[bool, Optional[JsonDict]]:
             ev.content = {"x": "y"}
             return True, None
 
@@ -223,10 +247,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         # 500 Internal Server Error
         self.assertEqual(channel.code, 500, channel.result)
 
-    def test_modify_event(self):
+    def test_modify_event(self) -> None:
         """The module can return a modified version of the event"""
         # first patch the event checker so that it will modify the event
-        async def check(ev: EventBase, state):
+        async def check(
+            ev: EventBase, state: StateMap[EventBase]
+        ) -> Tuple[bool, Optional[JsonDict]]:
             d = ev.get_dict()
             d["content"] = {"x": "y"}
             return True, d
@@ -253,10 +279,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         ev = channel.json_body
         self.assertEqual(ev["content"]["x"], "y")
 
-    def test_message_edit(self):
+    def test_message_edit(self) -> None:
         """Ensure that the module doesn't cause issues with edited messages."""
         # first patch the event checker so that it will modify the event
-        async def check(ev: EventBase, state):
+        async def check(
+            ev: EventBase, state: StateMap[EventBase]
+        ) -> Tuple[bool, Optional[JsonDict]]:
             d = ev.get_dict()
             d["content"] = {
                 "msgtype": "m.text",
@@ -315,7 +343,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         ev = channel.json_body
         self.assertEqual(ev["content"]["body"], "EDITED BODY")
 
-    def test_send_event(self):
+    def test_send_event(self) -> None:
         """Tests that a module can send an event into a room via the module api"""
         content = {
             "msgtype": "m.text",
@@ -344,7 +372,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             }
         }
     )
-    def test_legacy_check_event_allowed(self):
+    def test_legacy_check_event_allowed(self) -> None:
         """Tests that the wrapper for legacy check_event_allowed callbacks works
         correctly.
         """
@@ -379,13 +407,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             }
         }
     )
-    def test_legacy_on_create_room(self):
+    def test_legacy_on_create_room(self) -> None:
         """Tests that the wrapper for legacy on_create_room callbacks works
         correctly.
         """
         self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
 
-    def test_sent_event_end_up_in_room_state(self):
+    def test_sent_event_end_up_in_room_state(self) -> None:
         """Tests that a state event sent by a module while processing another state event
         doesn't get dropped from the state of the room. This is to guard against a bug
         where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830
@@ -400,7 +428,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         api = self.hs.get_module_api()
 
         # Define a callback that sends a custom event on power levels update.
-        async def test_fn(event: EventBase, state_events):
+        async def test_fn(
+            event: EventBase, state_events: StateMap[EventBase]
+        ) -> Tuple[bool, Optional[JsonDict]]:
             if event.is_state and event.type == EventTypes.PowerLevels:
                 await api.create_and_send_event_into_room(
                     {
@@ -436,7 +466,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             self.assertEqual(channel.code, 200)
             self.assertEqual(channel.json_body["i"], i)
 
-    def test_on_new_event(self):
+    def test_on_new_event(self) -> None:
         """Test that the on_new_event callback is called on new events"""
         on_new_event = Mock(make_awaitable(None))
         self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
@@ -501,7 +531,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
         self.assertEqual(channel.code, 200, channel.result)
 
-    def _update_power_levels(self, event_default: int = 0):
+    def _update_power_levels(self, event_default: int = 0) -> None:
         """Updates the room's power levels.
 
         Args:
@@ -533,7 +563,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             tok=self.tok,
         )
 
-    def test_on_profile_update(self):
+    def test_on_profile_update(self) -> None:
         """Tests that the on_profile_update module callback is correctly called on
         profile updates.
         """
@@ -592,7 +622,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         self.assertEqual(profile_info.display_name, displayname)
         self.assertEqual(profile_info.avatar_url, avatar_url)
 
-    def test_on_profile_update_admin(self):
+    def test_on_profile_update_admin(self) -> None:
         """Tests that the on_profile_update module callback is correctly called on
         profile updates triggered by a server admin.
         """
@@ -634,7 +664,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         self.assertEqual(profile_info.display_name, displayname)
         self.assertEqual(profile_info.avatar_url, avatar_url)
 
-    def test_on_user_deactivation_status_changed(self):
+    def test_on_user_deactivation_status_changed(self) -> None:
         """Tests that the on_user_deactivation_status_changed module callback is called
         correctly when processing a user's deactivation.
         """
@@ -691,7 +721,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         args = profile_mock.call_args[0]
         self.assertTrue(args[3])
 
-    def test_on_user_deactivation_status_changed_admin(self):
+    def test_on_user_deactivation_status_changed_admin(self) -> None:
         """Tests that the on_user_deactivation_status_changed module callback is called
         correctly when processing a user's deactivation triggered by a server admin as
         well as a reactivation.

+ 25 - 16
tests/rest/client/test_typing.py

@@ -14,11 +14,16 @@
 # limitations under the License.
 
 """Tests REST events for /rooms paths."""
-
+from typing import Any
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.rest.client import room
+from synapse.server import HomeServer
+from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import UserID
+from synapse.util import Clock
 
 from tests import unittest
 
@@ -33,7 +38,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
     user = UserID.from_string(user_id)
     servlets = [room.register_servlets]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         hs = self.setup_test_homeserver(
             "red",
@@ -43,30 +48,34 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
 
         self.event_source = hs.get_event_sources().sources.typing
 
-        hs.get_federation_handler = Mock()
+        hs.get_federation_handler = Mock()  # type: ignore[assignment]
 
-        async def get_user_by_access_token(token=None, allow_guest=False):
-            return {
-                "user": UserID.from_string(self.auth_user_id),
-                "token_id": 1,
-                "is_guest": False,
-            }
+        async def get_user_by_access_token(
+            token: str,
+            rights: str = "access",
+            allow_expired: bool = False,
+        ) -> TokenLookupResult:
+            return TokenLookupResult(
+                user_id=self.user_id,
+                is_guest=False,
+                token_id=1,
+            )
 
-        hs.get_auth().get_user_by_access_token = get_user_by_access_token
+        hs.get_auth().get_user_by_access_token = get_user_by_access_token  # type: ignore[assignment]
 
-        async def _insert_client_ip(*args, **kwargs):
+        async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
             return None
 
-        hs.get_datastores().main.insert_client_ip = _insert_client_ip
+        hs.get_datastores().main.insert_client_ip = _insert_client_ip  # type: ignore[assignment]
 
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.room_id = self.helper.create_room_as(self.user_id)
         # Need another user to make notifications actually work
         self.helper.join(self.room_id, user="@jim:red")
 
-    def test_set_typing(self):
+    def test_set_typing(self) -> None:
         channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
@@ -95,7 +104,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
             ],
         )
 
-    def test_set_not_typing(self):
+    def test_set_not_typing(self) -> None:
         channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),
@@ -103,7 +112,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(200, channel.code)
 
-    def test_typing_timeout(self):
+    def test_typing_timeout(self) -> None:
         channel = self.make_request(
             "PUT",
             "/rooms/%s/typing/%s" % (self.room_id, self.user_id),

Alguns ficheiros não foram mostrados porque muitos ficheiros mudaram neste diff