Browse Source

Add type hints to synapse.events.*. (#11066)

Except `synapse/events/__init__.py`, which will be done in a follow-up.
Patrick Cloke 2 years ago
parent
commit
1f9d0b8a7a

+ 1 - 0
changelog.d/11066.misc

@@ -0,0 +1 @@
+Add type hints to `synapse.events`.

+ 6 - 0
mypy.ini

@@ -22,8 +22,11 @@ files =
   synapse/crypto,
   synapse/crypto,
   synapse/event_auth.py,
   synapse/event_auth.py,
   synapse/events/builder.py,
   synapse/events/builder.py,
+  synapse/events/presence_router.py,
+  synapse/events/snapshot.py,
   synapse/events/spamcheck.py,
   synapse/events/spamcheck.py,
   synapse/events/third_party_rules.py,
   synapse/events/third_party_rules.py,
+  synapse/events/utils.py,
   synapse/events/validator.py,
   synapse/events/validator.py,
   synapse/federation,
   synapse/federation,
   synapse/groups,
   synapse/groups,
@@ -96,6 +99,9 @@ files =
   tests/util/test_itertools.py,
   tests/util/test_itertools.py,
   tests/util/test_stream_change_cache.py
   tests/util/test_stream_change_cache.py
 
 
+[mypy-synapse.events.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.handlers.*]
 [mypy-synapse.handlers.*]
 disallow_untyped_defs = True
 disallow_untyped_defs = True
 
 

+ 2 - 2
synapse/events/builder.py

@@ -90,13 +90,13 @@ class EventBuilder:
     )
     )
 
 
     @property
     @property
-    def state_key(self):
+    def state_key(self) -> str:
         if self._state_key is not None:
         if self._state_key is not None:
             return self._state_key
             return self._state_key
 
 
         raise AttributeError("state_key")
         raise AttributeError("state_key")
 
 
-    def is_state(self):
+    def is_state(self) -> bool:
         return self._state_key is not None
         return self._state_key is not None
 
 
     async def build(
     async def build(

+ 11 - 10
synapse/events/presence_router.py

@@ -14,6 +14,7 @@
 import logging
 import logging
 from typing import (
 from typing import (
     TYPE_CHECKING,
     TYPE_CHECKING,
+    Any,
     Awaitable,
     Awaitable,
     Callable,
     Callable,
     Dict,
     Dict,
@@ -33,14 +34,13 @@ if TYPE_CHECKING:
 GET_USERS_FOR_STATES_CALLBACK = Callable[
 GET_USERS_FOR_STATES_CALLBACK = Callable[
     [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
     [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
 ]
 ]
-GET_INTERESTED_USERS_CALLBACK = Callable[
-    [str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]]
-]
+# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
+GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
-def load_legacy_presence_router(hs: "HomeServer"):
+def load_legacy_presence_router(hs: "HomeServer") -> None:
     """Wrapper that loads a presence router module configured using the old
     """Wrapper that loads a presence router module configured using the old
     configuration, and registers the hooks they implement.
     configuration, and registers the hooks they implement.
     """
     """
@@ -69,9 +69,10 @@ def load_legacy_presence_router(hs: "HomeServer"):
         if f is None:
         if f is None:
             return None
             return None
 
 
-        def run(*args, **kwargs):
-            # mypy doesn't do well across function boundaries so we need to tell it
-            # f is definitely not None.
+        def run(*args: Any, **kwargs: Any) -> Awaitable:
+            # Assertion required because mypy can't prove we won't change `f`
+            # back to `None`. See
+            # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
             assert f is not None
             assert f is not None
 
 
             return maybe_awaitable(f(*args, **kwargs))
             return maybe_awaitable(f(*args, **kwargs))
@@ -104,7 +105,7 @@ class PresenceRouter:
         self,
         self,
         get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
         get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
         get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
         get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
-    ):
+    ) -> None:
         # PresenceRouter modules are required to implement both of these methods
         # PresenceRouter modules are required to implement both of these methods
         # or neither of them as they are assumed to act in a complementary manner
         # or neither of them as they are assumed to act in a complementary manner
         paired_methods = [get_users_for_states, get_interested_users]
         paired_methods = [get_users_for_states, get_interested_users]
@@ -142,7 +143,7 @@ class PresenceRouter:
             # Don't include any extra destinations for presence updates
             # Don't include any extra destinations for presence updates
             return {}
             return {}
 
 
-        users_for_states = {}
+        users_for_states: Dict[str, Set[UserPresenceState]] = {}
         # run all the callbacks for get_users_for_states and combine the results
         # run all the callbacks for get_users_for_states and combine the results
         for callback in self._get_users_for_states_callbacks:
         for callback in self._get_users_for_states_callbacks:
             try:
             try:
@@ -171,7 +172,7 @@ class PresenceRouter:
 
 
         return users_for_states
         return users_for_states
 
 
-    async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
+    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
         """
         """
         Retrieve a list of users that `user_id` is interested in receiving the
         Retrieve a list of users that `user_id` is interested in receiving the
         presence of. This will be in addition to those they share a room with.
         presence of. This will be in addition to those they share a room with.

+ 60 - 50
synapse/events/snapshot.py

@@ -11,17 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 
 import attr
 import attr
 from frozendict import frozendict
 from frozendict import frozendict
 
 
+from twisted.internet.defer import Deferred
+
 from synapse.appservice import ApplicationService
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.context import make_deferred_yieldable, run_in_background
-from synapse.types import StateMap
+from synapse.types import JsonDict, StateMap
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
+    from synapse.storage import Storage
     from synapse.storage.databases.main import DataStore
     from synapse.storage.databases.main import DataStore
 
 
 
 
@@ -112,13 +115,13 @@ class EventContext:
 
 
     @staticmethod
     @staticmethod
     def with_state(
     def with_state(
-        state_group,
-        state_group_before_event,
-        current_state_ids,
-        prev_state_ids,
-        prev_group=None,
-        delta_ids=None,
-    ):
+        state_group: Optional[int],
+        state_group_before_event: Optional[int],
+        current_state_ids: Optional[StateMap[str]],
+        prev_state_ids: Optional[StateMap[str]],
+        prev_group: Optional[int] = None,
+        delta_ids: Optional[StateMap[str]] = None,
+    ) -> "EventContext":
         return EventContext(
         return EventContext(
             current_state_ids=current_state_ids,
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
             prev_state_ids=prev_state_ids,
@@ -129,22 +132,22 @@ class EventContext:
         )
         )
 
 
     @staticmethod
     @staticmethod
-    def for_outlier():
+    def for_outlier() -> "EventContext":
         """Return an EventContext instance suitable for persisting an outlier event"""
         """Return an EventContext instance suitable for persisting an outlier event"""
         return EventContext(
         return EventContext(
             current_state_ids={},
             current_state_ids={},
             prev_state_ids={},
             prev_state_ids={},
         )
         )
 
 
-    async def serialize(self, event: EventBase, store: "DataStore") -> dict:
+    async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
         """Converts self to a type that can be serialized as JSON, and then
         """Converts self to a type that can be serialized as JSON, and then
         deserialized by `deserialize`
         deserialized by `deserialize`
 
 
         Args:
         Args:
-            event (FrozenEvent): The event that this context relates to
+            event: The event that this context relates to
 
 
         Returns:
         Returns:
-            dict
+            The serialized event.
         """
         """
 
 
         # We don't serialize the full state dicts, instead they get pulled out
         # We don't serialize the full state dicts, instead they get pulled out
@@ -170,17 +173,16 @@ class EventContext:
         }
         }
 
 
     @staticmethod
     @staticmethod
-    def deserialize(storage, input):
+    def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
         """Converts a dict that was produced by `serialize` back into a
         """Converts a dict that was produced by `serialize` back into a
         EventContext.
         EventContext.
 
 
         Args:
         Args:
-            storage (Storage): Used to convert AS ID to AS object and fetch
-                state.
-            input (dict): A dict produced by `serialize`
+            storage: Used to convert AS ID to AS object and fetch state.
+            input: A dict produced by `serialize`
 
 
         Returns:
         Returns:
-            EventContext
+            The event context.
         """
         """
         context = _AsyncEventContextImpl(
         context = _AsyncEventContextImpl(
             # We use the state_group and prev_state_id stuff to pull the
             # We use the state_group and prev_state_id stuff to pull the
@@ -241,22 +243,25 @@ class EventContext:
         await self._ensure_fetched()
         await self._ensure_fetched()
         return self._current_state_ids
         return self._current_state_ids
 
 
-    async def get_prev_state_ids(self):
+    async def get_prev_state_ids(self) -> StateMap[str]:
         """
         """
         Gets the room state map, excluding this event.
         Gets the room state map, excluding this event.
 
 
         For a non-state event, this will be the same as get_current_state_ids().
         For a non-state event, this will be the same as get_current_state_ids().
 
 
         Returns:
         Returns:
-            dict[(str, str), str]|None: Returns None if state_group
-                is None, which happens when the associated event is an outlier.
-                Maps a (type, state_key) to the event ID of the state event matching
-                this tuple.
+            Returns {} if state_group is None, which happens when the associated
+            event is an outlier.
+
+            Maps a (type, state_key) to the event ID of the state event matching
+            this tuple.
         """
         """
         await self._ensure_fetched()
         await self._ensure_fetched()
+        # There *should* be previous state IDs now.
+        assert self._prev_state_ids is not None
         return self._prev_state_ids
         return self._prev_state_ids
 
 
-    def get_cached_current_state_ids(self):
+    def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
         """Gets the current state IDs if we have them already cached.
         """Gets the current state IDs if we have them already cached.
 
 
         It is an error to access this for a rejected event, since rejected state should
         It is an error to access this for a rejected event, since rejected state should
@@ -264,16 +269,17 @@ class EventContext:
         ``rejected`` is set.
         ``rejected`` is set.
 
 
         Returns:
         Returns:
-            dict[(str, str), str]|None: Returns None if we haven't cached the
-            state or if state_group is None, which happens when the associated
-            event is an outlier.
+            Returns None if we haven't cached the state or if state_group is None
+            (which happens when the associated event is an outlier).
+
+            Otherwise, returns the the current state IDs.
         """
         """
         if self.rejected:
         if self.rejected:
             raise RuntimeError("Attempt to access state_ids of rejected event")
             raise RuntimeError("Attempt to access state_ids of rejected event")
 
 
         return self._current_state_ids
         return self._current_state_ids
 
 
-    async def _ensure_fetched(self):
+    async def _ensure_fetched(self) -> None:
         return None
         return None
 
 
 
 
@@ -285,46 +291,46 @@ class _AsyncEventContextImpl(EventContext):
 
 
     Attributes:
     Attributes:
 
 
-        _storage (Storage)
+        _storage
 
 
-        _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
-            been calculated. None if we haven't started calculating yet
+        _fetching_state_deferred: Resolves when *_state_ids have been calculated.
+            None if we haven't started calculating yet
 
 
-        _event_type (str): The type of the event the context is associated with.
+        _event_type: The type of the event the context is associated with.
 
 
-        _event_state_key (str): The state_key of the event the context is
-            associated with.
+        _event_state_key: The state_key of the event the context is associated with.
 
 
-        _prev_state_id (str|None): If the event associated with the context is
-            a state event, then `_prev_state_id` is the event_id of the state
-            that was replaced.
+        _prev_state_id: If the event associated with the context is a state event,
+            then `_prev_state_id` is the event_id of the state that was replaced.
     """
     """
 
 
     # This needs to have a default as we're inheriting
     # This needs to have a default as we're inheriting
-    _storage = attr.ib(default=None)
-    _prev_state_id = attr.ib(default=None)
-    _event_type = attr.ib(default=None)
-    _event_state_key = attr.ib(default=None)
-    _fetching_state_deferred = attr.ib(default=None)
+    _storage: "Storage" = attr.ib(default=None)
+    _prev_state_id: Optional[str] = attr.ib(default=None)
+    _event_type: str = attr.ib(default=None)
+    _event_state_key: Optional[str] = attr.ib(default=None)
+    _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
 
 
-    async def _ensure_fetched(self):
+    async def _ensure_fetched(self) -> None:
         if not self._fetching_state_deferred:
         if not self._fetching_state_deferred:
             self._fetching_state_deferred = run_in_background(self._fill_out_state)
             self._fetching_state_deferred = run_in_background(self._fill_out_state)
 
 
-        return await make_deferred_yieldable(self._fetching_state_deferred)
+        await make_deferred_yieldable(self._fetching_state_deferred)
 
 
-    async def _fill_out_state(self):
+    async def _fill_out_state(self) -> None:
         """Called to populate the _current_state_ids and _prev_state_ids
         """Called to populate the _current_state_ids and _prev_state_ids
         attributes by loading from the database.
         attributes by loading from the database.
         """
         """
         if self.state_group is None:
         if self.state_group is None:
             return
             return
 
 
-        self._current_state_ids = await self._storage.state.get_state_ids_for_group(
+        current_state_ids = await self._storage.state.get_state_ids_for_group(
             self.state_group
             self.state_group
         )
         )
+        # Set this separately so mypy knows current_state_ids is not None.
+        self._current_state_ids = current_state_ids
         if self._event_state_key is not None:
         if self._event_state_key is not None:
-            self._prev_state_ids = dict(self._current_state_ids)
+            self._prev_state_ids = dict(current_state_ids)
 
 
             key = (self._event_type, self._event_state_key)
             key = (self._event_type, self._event_state_key)
             if self._prev_state_id:
             if self._prev_state_id:
@@ -332,10 +338,12 @@ class _AsyncEventContextImpl(EventContext):
             else:
             else:
                 self._prev_state_ids.pop(key, None)
                 self._prev_state_ids.pop(key, None)
         else:
         else:
-            self._prev_state_ids = self._current_state_ids
+            self._prev_state_ids = current_state_ids
 
 
 
 
-def _encode_state_dict(state_dict):
+def _encode_state_dict(
+    state_dict: Optional[StateMap[str]],
+) -> Optional[List[Tuple[str, str, str]]]:
     """Since dicts of (type, state_key) -> event_id cannot be serialized in
     """Since dicts of (type, state_key) -> event_id cannot be serialized in
     JSON we need to convert them to a form that can.
     JSON we need to convert them to a form that can.
     """
     """
@@ -345,7 +353,9 @@ def _encode_state_dict(state_dict):
     return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()]
     return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()]
 
 
 
 
-def _decode_state_dict(input):
+def _decode_state_dict(
+    input: Optional[List[Tuple[str, str, str]]]
+) -> Optional[StateMap[str]]:
     """Decodes a state dict encoded using `_encode_state_dict` above"""
     """Decodes a state dict encoded using `_encode_state_dict` above"""
     if input is None:
     if input is None:
         return None
         return None

+ 14 - 11
synapse/events/spamcheck.py

@@ -77,7 +77,7 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
 ]
 ]
 
 
 
 
-def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
+def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
     """Wrapper that loads spam checkers configured using the old configuration, and
     """Wrapper that loads spam checkers configured using the old configuration, and
     registers the spam checker hooks they implement.
     registers the spam checker hooks they implement.
     """
     """
@@ -129,9 +129,9 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
                         request_info: Collection[Tuple[str, str]],
                         request_info: Collection[Tuple[str, str]],
                         auth_provider_id: Optional[str],
                         auth_provider_id: Optional[str],
                     ) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]:
                     ) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]:
-                        # We've already made sure f is not None above, but mypy doesn't
-                        # do well across function boundaries so we need to tell it f is
-                        # definitely not None.
+                        # Assertion required because mypy can't prove we won't
+                        # change `f` back to `None`. See
+                        # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
                         assert f is not None
                         assert f is not None
 
 
                         return f(
                         return f(
@@ -146,9 +146,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
                         "Bad signature for callback check_registration_for_spam",
                         "Bad signature for callback check_registration_for_spam",
                     )
                     )
 
 
-            def run(*args, **kwargs):
-                # mypy doesn't do well across function boundaries so we need to tell it
-                # wrapped_func is definitely not None.
+            def run(*args: Any, **kwargs: Any) -> Awaitable:
+                # Assertion required because mypy can't prove we won't change `f`
+                # back to `None`. See
+                # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
                 assert wrapped_func is not None
                 assert wrapped_func is not None
 
 
                 return maybe_awaitable(wrapped_func(*args, **kwargs))
                 return maybe_awaitable(wrapped_func(*args, **kwargs))
@@ -165,7 +166,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
 
 
 
 
 class SpamChecker:
 class SpamChecker:
-    def __init__(self):
+    def __init__(self) -> None:
         self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
         self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
         self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
         self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
         self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
         self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
@@ -209,7 +210,7 @@ class SpamChecker:
             CHECK_REGISTRATION_FOR_SPAM_CALLBACK
             CHECK_REGISTRATION_FOR_SPAM_CALLBACK
         ] = None,
         ] = None,
         check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
         check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
-    ):
+    ) -> None:
         """Register callbacks from module for each hook."""
         """Register callbacks from module for each hook."""
         if check_event_for_spam is not None:
         if check_event_for_spam is not None:
             self._check_event_for_spam_callbacks.append(check_event_for_spam)
             self._check_event_for_spam_callbacks.append(check_event_for_spam)
@@ -275,7 +276,9 @@ class SpamChecker:
 
 
         return False
         return False
 
 
-    async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool):
+    async def user_may_join_room(
+        self, user_id: str, room_id: str, is_invited: bool
+    ) -> bool:
         """Checks if a given users is allowed to join a room.
         """Checks if a given users is allowed to join a room.
         Not called when a user creates a room.
         Not called when a user creates a room.
 
 
@@ -285,7 +288,7 @@ class SpamChecker:
             is_invited: Whether the user is invited into the room
             is_invited: Whether the user is invited into the room
 
 
         Returns:
         Returns:
-            bool: Whether the user may join the room
+            Whether the user may join the room
         """
         """
         for callback in self._user_may_join_room_callbacks:
         for callback in self._user_may_join_room_callbacks:
             if await callback(user_id, room_id, is_invited) is False:
             if await callback(user_id, room_id, is_invited) is False:

+ 13 - 12
synapse/events/third_party_rules.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 import logging
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple
 
 
 from synapse.api.errors import SynapseError
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events import EventBase
@@ -38,7 +38,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
 ]
 ]
 
 
 
 
-def load_legacy_third_party_event_rules(hs: "HomeServer"):
+def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
     """Wrapper that loads a third party event rules module configured using the old
     """Wrapper that loads a third party event rules module configured using the old
     configuration, and registers the hooks they implement.
     configuration, and registers the hooks they implement.
     """
     """
@@ -77,9 +77,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
                 event: EventBase,
                 event: EventBase,
                 state_events: StateMap[EventBase],
                 state_events: StateMap[EventBase],
             ) -> Tuple[bool, Optional[dict]]:
             ) -> Tuple[bool, Optional[dict]]:
-                # We've already made sure f is not None above, but mypy doesn't do well
-                # across function boundaries so we need to tell it f is definitely not
-                # None.
+                # Assertion required because mypy can't prove we won't change
+                # `f` back to `None`. See
+                # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
                 assert f is not None
                 assert f is not None
 
 
                 res = await f(event, state_events)
                 res = await f(event, state_events)
@@ -98,9 +98,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
             async def wrap_on_create_room(
             async def wrap_on_create_room(
                 requester: Requester, config: dict, is_requester_admin: bool
                 requester: Requester, config: dict, is_requester_admin: bool
             ) -> None:
             ) -> None:
-                # We've already made sure f is not None above, but mypy doesn't do well
-                # across function boundaries so we need to tell it f is definitely not
-                # None.
+                # Assertion required because mypy can't prove we won't change
+                # `f` back to `None`. See
+                # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
                 assert f is not None
                 assert f is not None
 
 
                 res = await f(requester, config, is_requester_admin)
                 res = await f(requester, config, is_requester_admin)
@@ -112,9 +112,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
 
 
             return wrap_on_create_room
             return wrap_on_create_room
 
 
-        def run(*args, **kwargs):
-            # mypy doesn't do well across function boundaries so we need to tell it
-            # f is definitely not None.
+        def run(*args: Any, **kwargs: Any) -> Awaitable:
+            # Assertion required because mypy can't prove we won't change  `f`
+            # back to `None`. See
+            # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
             assert f is not None
             assert f is not None
 
 
             return maybe_awaitable(f(*args, **kwargs))
             return maybe_awaitable(f(*args, **kwargs))
@@ -162,7 +163,7 @@ class ThirdPartyEventRules:
         check_visibility_can_be_modified: Optional[
         check_visibility_can_be_modified: Optional[
             CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
             CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
         ] = None,
         ] = None,
-    ):
+    ) -> None:
         """Register callbacks from modules for each hook."""
         """Register callbacks from modules for each hook."""
         if check_event_allowed is not None:
         if check_event_allowed is not None:
             self._check_event_allowed_callbacks.append(check_event_allowed)
             self._check_event_allowed_callbacks.append(check_event_allowed)

+ 67 - 46
synapse/events/utils.py

@@ -13,18 +13,32 @@
 # limitations under the License.
 # limitations under the License.
 import collections.abc
 import collections.abc
 import re
 import re
-from typing import Any, Mapping, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)
 
 
 from frozendict import frozendict
 from frozendict import frozendict
 
 
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
 from synapse.api.room_versions import RoomVersion
+from synapse.types import JsonDict
 from synapse.util.async_helpers import yieldable_gather_results
 from synapse.util.async_helpers import yieldable_gather_results
 from synapse.util.frozenutils import unfreeze
 from synapse.util.frozenutils import unfreeze
 
 
 from . import EventBase
 from . import EventBase
 
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
 # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
 # (?<!stuff) matches if the current position in the string is not preceded
 # (?<!stuff) matches if the current position in the string is not preceded
 # by a match for 'stuff'.
 # by a match for 'stuff'.
@@ -65,7 +79,7 @@ def prune_event(event: EventBase) -> EventBase:
     return pruned_event
     return pruned_event
 
 
 
 
-def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
+def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
     """Redacts the event_dict in the same way as `prune_event`, except it
     """Redacts the event_dict in the same way as `prune_event`, except it
     operates on dicts rather than event objects
     operates on dicts rather than event objects
 
 
@@ -97,7 +111,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
 
 
     new_content = {}
     new_content = {}
 
 
-    def add_fields(*fields):
+    def add_fields(*fields: str) -> None:
         for field in fields:
         for field in fields:
             if field in event_dict["content"]:
             if field in event_dict["content"]:
                 new_content[field] = event_dict["content"][field]
                 new_content[field] = event_dict["content"][field]
@@ -151,7 +165,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
 
 
     allowed_fields["content"] = new_content
     allowed_fields["content"] = new_content
 
 
-    unsigned = {}
+    unsigned: JsonDict = {}
     allowed_fields["unsigned"] = unsigned
     allowed_fields["unsigned"] = unsigned
 
 
     event_unsigned = event_dict.get("unsigned", {})
     event_unsigned = event_dict.get("unsigned", {})
@@ -164,16 +178,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
     return allowed_fields
     return allowed_fields
 
 
 
 
-def _copy_field(src, dst, field):
+def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
     """Copy the field in 'src' to 'dst'.
     """Copy the field in 'src' to 'dst'.
 
 
     For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
     For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
     then dst={"foo":{"bar":5}}.
     then dst={"foo":{"bar":5}}.
 
 
     Args:
     Args:
-        src(dict): The dict to read from.
-        dst(dict): The dict to modify.
-        field(list<str>): List of keys to drill down to in 'src'.
+        src: The dict to read from.
+        dst: The dict to modify.
+        field: List of keys to drill down to in 'src'.
     """
     """
     if len(field) == 0:  # this should be impossible
     if len(field) == 0:  # this should be impossible
         return
         return
@@ -205,7 +219,7 @@ def _copy_field(src, dst, field):
     sub_out_dict[key_to_move] = sub_dict[key_to_move]
     sub_out_dict[key_to_move] = sub_dict[key_to_move]
 
 
 
 
-def only_fields(dictionary, fields):
+def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
     """Return a new dict with only the fields in 'dictionary' which are present
     """Return a new dict with only the fields in 'dictionary' which are present
     in 'fields'.
     in 'fields'.
 
 
@@ -215,11 +229,11 @@ def only_fields(dictionary, fields):
     A literal '.' character in a field name may be escaped using a '\'.
     A literal '.' character in a field name may be escaped using a '\'.
 
 
     Args:
     Args:
-        dictionary(dict): The dictionary to read from.
-        fields(list<str>): A list of fields to copy over. Only shallow refs are
+        dictionary: The dictionary to read from.
+        fields: A list of fields to copy over. Only shallow refs are
         taken.
         taken.
     Returns:
     Returns:
-        dict: A new dictionary with only the given fields. If fields was empty,
+        A new dictionary with only the given fields. If fields was empty,
         the same dictionary is returned.
         the same dictionary is returned.
     """
     """
     if len(fields) == 0:
     if len(fields) == 0:
@@ -235,17 +249,17 @@ def only_fields(dictionary, fields):
         [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
         [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
     ]
     ]
 
 
-    output = {}
+    output: JsonDict = {}
     for field_array in split_fields:
     for field_array in split_fields:
         _copy_field(dictionary, output, field_array)
         _copy_field(dictionary, output, field_array)
     return output
     return output
 
 
 
 
-def format_event_raw(d):
+def format_event_raw(d: JsonDict) -> JsonDict:
     return d
     return d
 
 
 
 
-def format_event_for_client_v1(d):
+def format_event_for_client_v1(d: JsonDict) -> JsonDict:
     d = format_event_for_client_v2(d)
     d = format_event_for_client_v2(d)
 
 
     sender = d.get("sender")
     sender = d.get("sender")
@@ -267,7 +281,7 @@ def format_event_for_client_v1(d):
     return d
     return d
 
 
 
 
-def format_event_for_client_v2(d):
+def format_event_for_client_v2(d: JsonDict) -> JsonDict:
     drop_keys = (
     drop_keys = (
         "auth_events",
         "auth_events",
         "prev_events",
         "prev_events",
@@ -282,37 +296,37 @@ def format_event_for_client_v2(d):
     return d
     return d
 
 
 
 
-def format_event_for_client_v2_without_room_id(d):
+def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
     d = format_event_for_client_v2(d)
     d = format_event_for_client_v2(d)
     d.pop("room_id", None)
     d.pop("room_id", None)
     return d
     return d
 
 
 
 
 def serialize_event(
 def serialize_event(
-    e,
-    time_now_ms,
-    as_client_event=True,
-    event_format=format_event_for_client_v1,
-    token_id=None,
-    only_event_fields=None,
-    include_stripped_room_state=False,
-):
+    e: Union[JsonDict, EventBase],
+    time_now_ms: int,
+    as_client_event: bool = True,
+    event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
+    token_id: Optional[str] = None,
+    only_event_fields: Optional[List[str]] = None,
+    include_stripped_room_state: bool = False,
+) -> JsonDict:
     """Serialize event for clients
     """Serialize event for clients
 
 
     Args:
     Args:
-        e (EventBase)
-        time_now_ms (int)
-        as_client_event (bool)
+        e
+        time_now_ms
+        as_client_event
         event_format
         event_format
         token_id
         token_id
         only_event_fields
         only_event_fields
-        include_stripped_room_state (bool): Some events can have stripped room state
+        include_stripped_room_state: Some events can have stripped room state
             stored in the `unsigned` field. This is required for invite and knock
             stored in the `unsigned` field. This is required for invite and knock
             functionality. If this option is False, that state will be removed from the
             functionality. If this option is False, that state will be removed from the
             event before it is returned. Otherwise, it will be kept.
             event before it is returned. Otherwise, it will be kept.
 
 
     Returns:
     Returns:
-        dict
+        The serialized event dictionary.
     """
     """
 
 
     # FIXME(erikj): To handle the case of presence events and the like
     # FIXME(erikj): To handle the case of presence events and the like
@@ -369,25 +383,29 @@ class EventClientSerializer:
     clients.
     clients.
     """
     """
 
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.store = hs.get_datastore()
         self.experimental_msc1849_support_enabled = (
         self.experimental_msc1849_support_enabled = (
             hs.config.server.experimental_msc1849_support_enabled
             hs.config.server.experimental_msc1849_support_enabled
         )
         )
 
 
     async def serialize_event(
     async def serialize_event(
-        self, event, time_now, bundle_aggregations=True, **kwargs
-    ):
+        self,
+        event: Union[JsonDict, EventBase],
+        time_now: int,
+        bundle_aggregations: bool = True,
+        **kwargs: Any,
+    ) -> JsonDict:
         """Serializes a single event.
         """Serializes a single event.
 
 
         Args:
         Args:
-            event (EventBase)
-            time_now (int): The current time in milliseconds
-            bundle_aggregations (bool): Whether to bundle in related events
+            event
+            time_now: The current time in milliseconds
+            bundle_aggregations: Whether to bundle in related events
             **kwargs: Arguments to pass to `serialize_event`
             **kwargs: Arguments to pass to `serialize_event`
 
 
         Returns:
         Returns:
-            dict: The serialized event
+            The serialized event
         """
         """
         # To handle the case of presence events and the like
         # To handle the case of presence events and the like
         if not isinstance(event, EventBase):
         if not isinstance(event, EventBase):
@@ -448,25 +466,27 @@ class EventClientSerializer:
 
 
         return serialized_event
         return serialized_event
 
 
-    def serialize_events(self, events, time_now, **kwargs):
+    async def serialize_events(
+        self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
+    ) -> List[JsonDict]:
         """Serializes multiple events.
         """Serializes multiple events.
 
 
         Args:
         Args:
-            event (iter[EventBase])
-            time_now (int): The current time in milliseconds
+            event
+            time_now: The current time in milliseconds
             **kwargs: Arguments to pass to `serialize_event`
             **kwargs: Arguments to pass to `serialize_event`
 
 
         Returns:
         Returns:
-            Deferred[list[dict]]: The list of serialized events
+            The list of serialized events
         """
         """
-        return yieldable_gather_results(
+        return await yieldable_gather_results(
             self.serialize_event, events, time_now=time_now, **kwargs
             self.serialize_event, events, time_now=time_now, **kwargs
         )
         )
 
 
 
 
 def copy_power_levels_contents(
 def copy_power_levels_contents(
     old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
     old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
-):
+) -> Dict[str, Union[int, Dict[str, int]]]:
     """Copy the content of a power_levels event, unfreezing frozendicts along the way
     """Copy the content of a power_levels event, unfreezing frozendicts along the way
 
 
     Raises:
     Raises:
@@ -475,7 +495,7 @@ def copy_power_levels_contents(
     if not isinstance(old_power_levels, collections.abc.Mapping):
     if not isinstance(old_power_levels, collections.abc.Mapping):
         raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
         raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
 
 
-    power_levels = {}
+    power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
     for k, v in old_power_levels.items():
     for k, v in old_power_levels.items():
 
 
         if isinstance(v, int):
         if isinstance(v, int):
@@ -483,7 +503,8 @@ def copy_power_levels_contents(
             continue
             continue
 
 
         if isinstance(v, collections.abc.Mapping):
         if isinstance(v, collections.abc.Mapping):
-            power_levels[k] = h = {}
+            h: Dict[str, int] = {}
+            power_levels[k] = h
             for k1, v1 in v.items():
             for k1, v1 in v.items():
                 # we should only have one level of nesting
                 # we should only have one level of nesting
                 if not isinstance(v1, int):
                 if not isinstance(v1, int):
@@ -498,7 +519,7 @@ def copy_power_levels_contents(
     return power_levels
     return power_levels
 
 
 
 
-def validate_canonicaljson(value: Any):
+def validate_canonicaljson(value: Any) -> None:
     """
     """
     Ensure that the JSON object is valid according to the rules of canonical JSON.
     Ensure that the JSON object is valid according to the rules of canonical JSON.
 
 

+ 10 - 8
synapse/events/validator.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 import collections.abc
 import collections.abc
-from typing import Union
+from typing import Iterable, Union
 
 
 import jsonschema
 import jsonschema
 
 
@@ -28,11 +28,11 @@ from synapse.events.utils import (
     validate_canonicaljson,
     validate_canonicaljson,
 )
 )
 from synapse.federation.federation_server import server_matches_acl_event
 from synapse.federation.federation_server import server_matches_acl_event
-from synapse.types import EventID, RoomID, UserID
+from synapse.types import EventID, JsonDict, RoomID, UserID
 
 
 
 
 class EventValidator:
 class EventValidator:
-    def validate_new(self, event: EventBase, config: HomeServerConfig):
+    def validate_new(self, event: EventBase, config: HomeServerConfig) -> None:
         """Validates the event has roughly the right format
         """Validates the event has roughly the right format
 
 
         Args:
         Args:
@@ -116,7 +116,7 @@ class EventValidator:
                     errcode=Codes.BAD_JSON,
                     errcode=Codes.BAD_JSON,
                 )
                 )
 
 
-    def _validate_retention(self, event: EventBase):
+    def _validate_retention(self, event: EventBase) -> None:
         """Checks that an event that defines the retention policy for a room respects the
         """Checks that an event that defines the retention policy for a room respects the
         format enforced by the spec.
         format enforced by the spec.
 
 
@@ -156,7 +156,7 @@ class EventValidator:
                 errcode=Codes.BAD_JSON,
                 errcode=Codes.BAD_JSON,
             )
             )
 
 
-    def validate_builder(self, event: Union[EventBase, EventBuilder]):
+    def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None:
         """Validates that the builder/event has roughly the right format. Only
         """Validates that the builder/event has roughly the right format. Only
         checks values that we expect a proto event to have, rather than all the
         checks values that we expect a proto event to have, rather than all the
         fields an event would have
         fields an event would have
@@ -204,14 +204,14 @@ class EventValidator:
 
 
             self._ensure_state_event(event)
             self._ensure_state_event(event)
 
 
-    def _ensure_strings(self, d, keys):
+    def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
         for s in keys:
         for s in keys:
             if s not in d:
             if s not in d:
                 raise SynapseError(400, "'%s' not in content" % (s,))
                 raise SynapseError(400, "'%s' not in content" % (s,))
             if not isinstance(d[s], str):
             if not isinstance(d[s], str):
                 raise SynapseError(400, "'%s' not a string type" % (s,))
                 raise SynapseError(400, "'%s' not a string type" % (s,))
 
 
-    def _ensure_state_event(self, event):
+    def _ensure_state_event(self, event: Union[EventBase, EventBuilder]) -> None:
         if not event.is_state():
         if not event.is_state():
             raise SynapseError(400, "'%s' must be state events" % (event.type,))
             raise SynapseError(400, "'%s' must be state events" % (event.type,))
 
 
@@ -244,7 +244,9 @@ POWER_LEVELS_SCHEMA = {
 }
 }
 
 
 
 
-def _create_power_level_validator():
+# This could return something newer than Draft 7, but that's the current "latest"
+# validator.
+def _create_power_level_validator() -> jsonschema.Draft7Validator:
     validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
     validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
 
 
     # by default jsonschema does not consider a frozendict to be an object so
     # by default jsonschema does not consider a frozendict to be an object so

+ 20 - 2
synapse/handlers/room.py

@@ -465,17 +465,35 @@ class RoomCreationHandler:
         # the room has been created
         # the room has been created
         # Calculate the minimum power level needed to clone the room
         # Calculate the minimum power level needed to clone the room
         event_power_levels = power_levels.get("events", {})
         event_power_levels = power_levels.get("events", {})
+        if not isinstance(event_power_levels, dict):
+            event_power_levels = {}
         state_default = power_levels.get("state_default", 50)
         state_default = power_levels.get("state_default", 50)
+        try:
+            state_default_int = int(state_default)  # type: ignore[arg-type]
+        except (TypeError, ValueError):
+            state_default_int = 50
         ban = power_levels.get("ban", 50)
         ban = power_levels.get("ban", 50)
-        needed_power_level = max(state_default, ban, max(event_power_levels.values()))
+        try:
+            ban = int(ban)  # type: ignore[arg-type]
+        except (TypeError, ValueError):
+            ban = 50
+        needed_power_level = max(
+            state_default_int, ban, max(event_power_levels.values())
+        )
 
 
         # Get the user's current power level, this matches the logic in get_user_power_level,
         # Get the user's current power level, this matches the logic in get_user_power_level,
         # but without the entire state map.
         # but without the entire state map.
         user_power_levels = power_levels.setdefault("users", {})
         user_power_levels = power_levels.setdefault("users", {})
+        if not isinstance(user_power_levels, dict):
+            user_power_levels = {}
         users_default = power_levels.get("users_default", 0)
         users_default = power_levels.get("users_default", 0)
         current_power_level = user_power_levels.get(user_id, users_default)
         current_power_level = user_power_levels.get(user_id, users_default)
+        try:
+            current_power_level_int = int(current_power_level)  # type: ignore[arg-type]
+        except (TypeError, ValueError):
+            current_power_level_int = 0
         # Raise the requester's power level in the new room if necessary
         # Raise the requester's power level in the new room if necessary
-        if current_power_level < needed_power_level:
+        if current_power_level_int < needed_power_level:
             user_power_levels[user_id] = needed_power_level
             user_power_levels[user_id] = needed_power_level
 
 
         await self._send_events_for_new_room(
         await self._send_events_for_new_room(

+ 4 - 4
synapse/rest/client/relations.py

@@ -232,12 +232,12 @@ class RelationPaginationServlet(RestServlet):
         # Similarly, we don't allow relations to be applied to relations, so we
         # Similarly, we don't allow relations to be applied to relations, so we
         # return the original relations without any aggregations on top of them
         # return the original relations without any aggregations on top of them
         # here.
         # here.
-        events = await self._event_serializer.serialize_events(
+        serialized_events = await self._event_serializer.serialize_events(
             events, now, bundle_aggregations=False
             events, now, bundle_aggregations=False
         )
         )
 
 
         return_value = pagination_chunk.to_dict()
         return_value = pagination_chunk.to_dict()
-        return_value["chunk"] = events
+        return_value["chunk"] = serialized_events
         return_value["original_event"] = original_event
         return_value["original_event"] = original_event
 
 
         return 200, return_value
         return 200, return_value
@@ -416,10 +416,10 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         )
         )
 
 
         now = self.clock.time_msec()
         now = self.clock.time_msec()
-        events = await self._event_serializer.serialize_events(events, now)
+        serialized_events = await self._event_serializer.serialize_events(events, now)
 
 
         return_value = result.to_dict()
         return_value = result.to_dict()
-        return_value["chunk"] = events
+        return_value["chunk"] = serialized_events
 
 
         return 200, return_value
         return 200, return_value