Browse Source

Fix-up type hints for tests.push module. (#14816)

Patrick Cloke 1 year ago
parent
commit
7f2cabf271

+ 1 - 0
changelog.d/14816.misc

@@ -0,0 +1 @@
+Add missing type hints.

+ 1 - 4
mypy.ini

@@ -48,9 +48,6 @@ exclude = (?x)
    |tests/logging/__init__.py
    |tests/logging/test_terse_json.py
    |tests/module_api/test_api.py
-   |tests/push/test_email.py
-   |tests/push/test_presentable_names.py
-   |tests/push/test_push_rule_evaluator.py
    |tests/rest/client/test_transactions.py
    |tests/rest/media/v1/test_media_storage.py
    |tests/server.py
@@ -101,7 +98,7 @@ disallow_untyped_defs = True
 [mypy-tests.metrics.*]
 disallow_untyped_defs = True
 
-[mypy-tests.push.test_bulk_push_rule_evaluator]
+[mypy-tests.push.*]
 disallow_untyped_defs = True
 
 [mypy-tests.rest.*]

+ 4 - 1
stubs/synapse/synapse_rust/push.pyi

@@ -1,4 +1,4 @@
-from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union
 
 from synapse.types import JsonDict
 
@@ -54,3 +54,6 @@ class PushRuleEvaluator:
         user_id: Optional[str],
         display_name: Optional[str],
     ) -> Collection[Union[Mapping, str]]: ...
+    def matches(
+        self, condition: JsonDict, user_id: Optional[str], display_name: Optional[str]
+    ) -> bool: ...

+ 24 - 22
tests/push/test_email.py

@@ -13,25 +13,28 @@
 # limitations under the License.
 import email.message
 import os
-from typing import Dict, List, Sequence, Tuple
+from typing import Any, Dict, List, Sequence, Tuple
 
 import attr
 import pkg_resources
 
 from twisted.internet.defer import Deferred
+from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.api.errors import Codes, SynapseError
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
 
-@attr.s
+@attr.s(auto_attribs=True)
 class _User:
     "Helper wrapper for user ID and access token"
-    id = attr.ib()
-    token = attr.ib()
+    id: str
+    token: str
 
 
 class EmailPusherTests(HomeserverTestCase):
@@ -41,10 +44,9 @@ class EmailPusherTests(HomeserverTestCase):
         room.register_servlets,
         login.register_servlets,
     ]
-    user_id = True
     hijack_auth = False
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
 
         config = self.default_config()
         config["email"] = {
@@ -72,17 +74,17 @@ class EmailPusherTests(HomeserverTestCase):
         # List[Tuple[Deferred, args, kwargs]]
         self.email_attempts: List[Tuple[Deferred, Sequence, Dict]] = []
 
-        def sendmail(*args, **kwargs):
+        def sendmail(*args: Any, **kwargs: Any) -> Deferred:
             # This mocks out synapse.reactor.send_email._sendmail.
-            d = Deferred()
+            d: Deferred = Deferred()
             self.email_attempts.append((d, args, kwargs))
             return d
 
-        hs.get_send_email_handler()._sendmail = sendmail
+        hs.get_send_email_handler()._sendmail = sendmail  # type: ignore[assignment]
 
         return hs
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # Register the user who gets notified
         self.user_id = self.register_user("user", "pass")
         self.access_token = self.login("user", "pass")
@@ -129,7 +131,7 @@ class EmailPusherTests(HomeserverTestCase):
         self.auth_handler = hs.get_auth_handler()
         self.store = hs.get_datastores().main
 
-    def test_need_validated_email(self):
+    def test_need_validated_email(self) -> None:
         """Test that we can only add an email pusher if the user has validated
         their email.
         """
@@ -151,7 +153,7 @@ class EmailPusherTests(HomeserverTestCase):
         self.assertEqual(400, cm.exception.code)
         self.assertEqual(Codes.THREEPID_NOT_FOUND, cm.exception.errcode)
 
-    def test_simple_sends_email(self):
+    def test_simple_sends_email(self) -> None:
         # Create a simple room with two users
         room = self.helper.create_room_as(self.user_id, tok=self.access_token)
         self.helper.invite(
@@ -171,7 +173,7 @@ class EmailPusherTests(HomeserverTestCase):
 
         self._check_for_mail()
 
-    def test_invite_sends_email(self):
+    def test_invite_sends_email(self) -> None:
         # Create a room and invite the user to it
         room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
         self.helper.invite(
@@ -184,7 +186,7 @@ class EmailPusherTests(HomeserverTestCase):
         # We should get emailed about the invite
         self._check_for_mail()
 
-    def test_invite_to_empty_room_sends_email(self):
+    def test_invite_to_empty_room_sends_email(self) -> None:
         # Create a room and invite the user to it
         room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token)
         self.helper.invite(
@@ -200,7 +202,7 @@ class EmailPusherTests(HomeserverTestCase):
         # We should get emailed about the invite
         self._check_for_mail()
 
-    def test_multiple_members_email(self):
+    def test_multiple_members_email(self) -> None:
         # We want to test multiple notifications, so we pause processing of push
         # while we send messages.
         self.pusher._pause_processing()
@@ -227,7 +229,7 @@ class EmailPusherTests(HomeserverTestCase):
         # We should get emailed about those messages
         self._check_for_mail()
 
-    def test_multiple_rooms(self):
+    def test_multiple_rooms(self) -> None:
         # We want to test multiple notifications from multiple rooms, so we pause
         # processing of push while we send messages.
         self.pusher._pause_processing()
@@ -257,7 +259,7 @@ class EmailPusherTests(HomeserverTestCase):
         # We should get emailed about those messages
         self._check_for_mail()
 
-    def test_room_notifications_include_avatar(self):
+    def test_room_notifications_include_avatar(self) -> None:
         # Create a room and set its avatar.
         room = self.helper.create_room_as(self.user_id, tok=self.access_token)
         self.helper.send_state(
@@ -290,7 +292,7 @@ class EmailPusherTests(HomeserverTestCase):
         )
         self.assertIn("_matrix/media/v1/thumbnail/DUMMY_MEDIA_ID", html)
 
-    def test_empty_room(self):
+    def test_empty_room(self) -> None:
         """All users leaving a room shouldn't cause the pusher to break."""
         # Create a simple room with two users
         room = self.helper.create_room_as(self.user_id, tok=self.access_token)
@@ -309,7 +311,7 @@ class EmailPusherTests(HomeserverTestCase):
         # We should get emailed about that message
         self._check_for_mail()
 
-    def test_empty_room_multiple_messages(self):
+    def test_empty_room_multiple_messages(self) -> None:
         """All users leaving a room shouldn't cause the pusher to break."""
         # Create a simple room with two users
         room = self.helper.create_room_as(self.user_id, tok=self.access_token)
@@ -329,7 +331,7 @@ class EmailPusherTests(HomeserverTestCase):
         # We should get emailed about that message
         self._check_for_mail()
 
-    def test_encrypted_message(self):
+    def test_encrypted_message(self) -> None:
         room = self.helper.create_room_as(self.user_id, tok=self.access_token)
         self.helper.invite(
             room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
@@ -342,7 +344,7 @@ class EmailPusherTests(HomeserverTestCase):
         # We should get emailed about that message
         self._check_for_mail()
 
-    def test_no_email_sent_after_removed(self):
+    def test_no_email_sent_after_removed(self) -> None:
         # Create a simple room with two users
         room = self.helper.create_room_as(self.user_id, tok=self.access_token)
         self.helper.invite(
@@ -379,7 +381,7 @@ class EmailPusherTests(HomeserverTestCase):
         pushers = list(pushers)
         self.assertEqual(len(pushers), 0)
 
-    def test_remove_unlinked_pushers_background_job(self):
+    def test_remove_unlinked_pushers_background_job(self) -> None:
         """Checks that all existing pushers associated with unlinked email addresses are removed
         upon running the remove_deleted_email_pushers background update.
         """

+ 1 - 1
tests/push/test_http.py

@@ -46,7 +46,7 @@ class HTTPPusherTests(HomeserverTestCase):
 
         m = Mock()
 
-        def post_json_get_json(url, body):
+        def post_json_get_json(url: str, body: JsonDict) -> Deferred:
             d: Deferred = Deferred()
             self.push_attempts.append((d, url, body))
             return make_deferred_yieldable(d)

+ 23 - 21
tests/push/test_presentable_names.py

@@ -12,11 +12,11 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple, cast
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
-from synapse.events import FrozenEvent
+from synapse.events import EventBase, FrozenEvent
 from synapse.push.presentable_names import calculate_room_name
 from synapse.types import StateKey, StateMap
 
@@ -51,13 +51,15 @@ class MockDataStore:
             )
 
     async def get_event(
-        self, event_id: StateKey, allow_none: bool = False
+        self, event_id: str, allow_none: bool = False
     ) -> Optional[FrozenEvent]:
         assert allow_none, "Mock not configured for allow_none = False"
 
-        return self._events.get(event_id)
+        # Decode the state key from the event ID.
+        state_key = cast(Tuple[str, str], tuple(event_id.split("|", 1)))
+        return self._events.get(state_key)
 
-    async def get_events(self, event_ids: Iterable[StateKey]):
+    async def get_events(self, event_ids: Iterable[StateKey]) -> StateMap[EventBase]:
         # This is cheating since it just returns all events.
         return self._events
 
@@ -68,17 +70,17 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
 
     def _calculate_room_name(
         self,
-        events: StateMap[dict],
+        events: Iterable[Tuple[Tuple[str, str], dict]],
         user_id: str = "",
         fallback_to_members: bool = True,
         fallback_to_single_member: bool = True,
-    ):
-        # This isn't 100% accurate, but works with MockDataStore.
-        room_state_ids = {k[0]: k[0] for k in events}
+    ) -> Optional[str]:
+        # Encode the state key into the event ID.
+        room_state_ids = {k[0]: "|".join(k[0]) for k in events}
 
         return self.get_success(
             calculate_room_name(
-                MockDataStore(events),
+                MockDataStore(events),  # type: ignore[arg-type]
                 room_state_ids,
                 user_id or self.USER_ID,
                 fallback_to_members,
@@ -86,9 +88,9 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
             )
         )
 
-    def test_name(self):
+    def test_name(self) -> None:
         """A room name event should be used."""
-        events = [
+        events: List[Tuple[Tuple[str, str], dict]] = [
             ((EventTypes.Name, ""), {"name": "test-name"}),
         ]
         self.assertEqual("test-name", self._calculate_room_name(events))
@@ -100,9 +102,9 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
         events = [((EventTypes.Name, ""), {"name": 1})]
         self.assertEqual(1, self._calculate_room_name(events))
 
-    def test_canonical_alias(self):
+    def test_canonical_alias(self) -> None:
         """An canonical alias should be used."""
-        events = [
+        events: List[Tuple[Tuple[str, str], dict]] = [
             ((EventTypes.CanonicalAlias, ""), {"alias": "#test-name:test"}),
         ]
         self.assertEqual("#test-name:test", self._calculate_room_name(events))
@@ -114,9 +116,9 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
         events = [((EventTypes.CanonicalAlias, ""), {"alias": "test-name"})]
         self.assertEqual("Empty Room", self._calculate_room_name(events))
 
-    def test_invite(self):
+    def test_invite(self) -> None:
         """An invite has special behaviour."""
-        events = [
+        events: List[Tuple[Tuple[str, str], dict]] = [
             ((EventTypes.Member, self.USER_ID), {"membership": Membership.INVITE}),
             ((EventTypes.Member, self.OTHER_USER_ID), {"displayname": "Other User"}),
         ]
@@ -140,9 +142,9 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
         ]
         self.assertEqual("Room Invite", self._calculate_room_name(events))
 
-    def test_no_members(self):
+    def test_no_members(self) -> None:
         """Behaviour of an empty room."""
-        events = []
+        events: List[Tuple[Tuple[str, str], dict]] = []
         self.assertEqual("Empty Room", self._calculate_room_name(events))
 
         # Note that events with invalid (or missing) membership are ignored.
@@ -152,7 +154,7 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
         ]
         self.assertEqual("Empty Room", self._calculate_room_name(events))
 
-    def test_no_other_members(self):
+    def test_no_other_members(self) -> None:
         """Behaviour of a room with no other members in it."""
         events = [
             (
@@ -185,7 +187,7 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
             self._calculate_room_name(events, user_id=self.OTHER_USER_ID),
         )
 
-    def test_one_other_member(self):
+    def test_one_other_member(self) -> None:
         """Behaviour of a room with a single other member."""
         events = [
             ((EventTypes.Member, self.USER_ID), {"membership": Membership.JOIN}),
@@ -209,7 +211,7 @@ class PresentableNamesTestCase(unittest.HomeserverTestCase):
         ]
         self.assertEqual("@user:test", self._calculate_room_name(events))
 
-    def test_other_members(self):
+    def test_other_members(self) -> None:
         """Behaviour of a room with multiple other members."""
         # Two other members.
         events = [

+ 13 - 13
tests/push/test_push_rule_evaluator.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Optional, Union
+from typing import Dict, List, Optional, Union, cast
 
 import frozendict
 
@@ -30,7 +30,7 @@ from synapse.rest.client import login, register, room
 from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
 from synapse.synapse_rust.push import PushRuleEvaluator
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, JsonMapping, UserID
 from synapse.util import Clock
 
 from tests import unittest
@@ -39,7 +39,7 @@ from tests.test_utils.event_injection import create_event, inject_member_event
 
 class PushRuleEvaluatorTestCase(unittest.TestCase):
     def _get_evaluator(
-        self, content: JsonDict, related_events=None
+        self, content: JsonMapping, related_events: Optional[JsonDict] = None
     ) -> PushRuleEvaluator:
         event = FrozenEvent(
             {
@@ -59,7 +59,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             _flatten_dict(event),
             room_member_count,
             sender_power_level,
-            power_levels.get("notifications", {}),
+            cast(Dict[str, int], power_levels.get("notifications", {})),
             {} if related_events is None else related_events,
             True,
             event.room_version.msc3931_push_features,
@@ -70,9 +70,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         """Check for a matching display name in the body of the event."""
         evaluator = self._get_evaluator({"body": "foo bar baz"})
 
-        condition = {
-            "kind": "contains_display_name",
-        }
+        condition = {"kind": "contains_display_name"}
 
         # Blank names are skipped.
         self.assertFalse(evaluator.matches(condition, "@user:test", ""))
@@ -93,7 +91,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
 
     def _assert_matches(
-        self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
+        self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
     ) -> None:
         evaluator = self._get_evaluator(content)
         self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
@@ -287,7 +285,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
         This tests the behaviour of tweaks_for_actions.
         """
 
-        actions = [
+        actions: List[Union[Dict[str, str], str]] = [
             {"set_tweak": "sound", "value": "default"},
             {"set_tweak": "highlight"},
             "notify",
@@ -298,7 +296,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             {"sound": "default", "highlight": True},
         )
 
-    def test_related_event_match(self):
+    def test_related_event_match(self) -> None:
         evaluator = self._get_evaluator(
             {
                 "m.relates_to": {
@@ -397,7 +395,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             )
         )
 
-    def test_related_event_match_with_fallback(self):
+    def test_related_event_match_with_fallback(self) -> None:
         evaluator = self._get_evaluator(
             {
                 "m.relates_to": {
@@ -469,7 +467,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             )
         )
 
-    def test_related_event_match_no_related_event(self):
+    def test_related_event_match_no_related_event(self) -> None:
         evaluator = self._get_evaluator(
             {"msgtype": "m.text", "body": "Message without related event"}
         )
@@ -518,7 +516,9 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
         # Define an application service so that we can register appservice users
         self._service_token = "some_token"
         self._service = ApplicationService(