Преглед изворни кода

Add type hints to synapse.events.*. (#11066)

Except `synapse/events/__init__.py`, which will be done in a follow-up.
Patrick Cloke пре 2 година
родитељ
комит
1f9d0b8a7a

+ 1 - 0
changelog.d/11066.misc

@@ -0,0 +1 @@
+Add type hints to `synapse.events`.

+ 6 - 0
mypy.ini

@@ -22,8 +22,11 @@ files =
   synapse/crypto,
   synapse/event_auth.py,
   synapse/events/builder.py,
+  synapse/events/presence_router.py,
+  synapse/events/snapshot.py,
   synapse/events/spamcheck.py,
   synapse/events/third_party_rules.py,
+  synapse/events/utils.py,
   synapse/events/validator.py,
   synapse/federation,
   synapse/groups,
@@ -96,6 +99,9 @@ files =
   tests/util/test_itertools.py,
   tests/util/test_stream_change_cache.py
 
+[mypy-synapse.events.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.handlers.*]
 disallow_untyped_defs = True
 

+ 2 - 2
synapse/events/builder.py

@@ -90,13 +90,13 @@ class EventBuilder:
     )
 
     @property
-    def state_key(self):
+    def state_key(self) -> str:
         if self._state_key is not None:
             return self._state_key
 
         raise AttributeError("state_key")
 
-    def is_state(self):
+    def is_state(self) -> bool:
         return self._state_key is not None
 
     async def build(

+ 11 - 10
synapse/events/presence_router.py

@@ -14,6 +14,7 @@
 import logging
 from typing import (
     TYPE_CHECKING,
+    Any,
     Awaitable,
     Callable,
     Dict,
@@ -33,14 +34,13 @@ if TYPE_CHECKING:
 GET_USERS_FOR_STATES_CALLBACK = Callable[
     [Iterable[UserPresenceState]], Awaitable[Dict[str, Set[UserPresenceState]]]
 ]
-GET_INTERESTED_USERS_CALLBACK = Callable[
-    [str], Awaitable[Union[Set[str], "PresenceRouter.ALL_USERS"]]
-]
+# This must either return a set of strings or the constant PresenceRouter.ALL_USERS.
+GET_INTERESTED_USERS_CALLBACK = Callable[[str], Awaitable[Union[Set[str], str]]]
 
 logger = logging.getLogger(__name__)
 
 
-def load_legacy_presence_router(hs: "HomeServer"):
+def load_legacy_presence_router(hs: "HomeServer") -> None:
     """Wrapper that loads a presence router module configured using the old
     configuration, and registers the hooks they implement.
     """
@@ -69,9 +69,10 @@ def load_legacy_presence_router(hs: "HomeServer"):
         if f is None:
             return None
 
-        def run(*args, **kwargs):
-            # mypy doesn't do well across function boundaries so we need to tell it
-            # f is definitely not None.
+        def run(*args: Any, **kwargs: Any) -> Awaitable:
+            # Assertion required because mypy can't prove we won't change `f`
+            # back to `None`. See
+            # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
             assert f is not None
 
             return maybe_awaitable(f(*args, **kwargs))
@@ -104,7 +105,7 @@ class PresenceRouter:
         self,
         get_users_for_states: Optional[GET_USERS_FOR_STATES_CALLBACK] = None,
         get_interested_users: Optional[GET_INTERESTED_USERS_CALLBACK] = None,
-    ):
+    ) -> None:
         # PresenceRouter modules are required to implement both of these methods
         # or neither of them as they are assumed to act in a complementary manner
         paired_methods = [get_users_for_states, get_interested_users]
@@ -142,7 +143,7 @@ class PresenceRouter:
             # Don't include any extra destinations for presence updates
             return {}
 
-        users_for_states = {}
+        users_for_states: Dict[str, Set[UserPresenceState]] = {}
         # run all the callbacks for get_users_for_states and combine the results
         for callback in self._get_users_for_states_callbacks:
             try:
@@ -171,7 +172,7 @@ class PresenceRouter:
 
         return users_for_states
 
-    async def get_interested_users(self, user_id: str) -> Union[Set[str], ALL_USERS]:
+    async def get_interested_users(self, user_id: str) -> Union[Set[str], str]:
         """
         Retrieve a list of users that `user_id` is interested in receiving the
         presence of. This will be in addition to those they share a room with.

+ 60 - 50
synapse/events/snapshot.py

@@ -11,17 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import attr
 from frozendict import frozendict
 
+from twisted.internet.defer import Deferred
+
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
 from synapse.logging.context import make_deferred_yieldable, run_in_background
-from synapse.types import StateMap
+from synapse.types import JsonDict, StateMap
 
 if TYPE_CHECKING:
+    from synapse.storage import Storage
     from synapse.storage.databases.main import DataStore
 
 
@@ -112,13 +115,13 @@ class EventContext:
 
     @staticmethod
     def with_state(
-        state_group,
-        state_group_before_event,
-        current_state_ids,
-        prev_state_ids,
-        prev_group=None,
-        delta_ids=None,
-    ):
+        state_group: Optional[int],
+        state_group_before_event: Optional[int],
+        current_state_ids: Optional[StateMap[str]],
+        prev_state_ids: Optional[StateMap[str]],
+        prev_group: Optional[int] = None,
+        delta_ids: Optional[StateMap[str]] = None,
+    ) -> "EventContext":
         return EventContext(
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
@@ -129,22 +132,22 @@ class EventContext:
         )
 
     @staticmethod
-    def for_outlier():
+    def for_outlier() -> "EventContext":
         """Return an EventContext instance suitable for persisting an outlier event"""
         return EventContext(
             current_state_ids={},
             prev_state_ids={},
         )
 
-    async def serialize(self, event: EventBase, store: "DataStore") -> dict:
+    async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
         """Converts self to a type that can be serialized as JSON, and then
         deserialized by `deserialize`
 
         Args:
-            event (FrozenEvent): The event that this context relates to
+            event: The event that this context relates to
 
         Returns:
-            dict
+            The serialized event.
         """
 
         # We don't serialize the full state dicts, instead they get pulled out
@@ -170,17 +173,16 @@ class EventContext:
         }
 
     @staticmethod
-    def deserialize(storage, input):
+    def deserialize(storage: "Storage", input: JsonDict) -> "EventContext":
         """Converts a dict that was produced by `serialize` back into a
         EventContext.
 
         Args:
-            storage (Storage): Used to convert AS ID to AS object and fetch
-                state.
-            input (dict): A dict produced by `serialize`
+            storage: Used to convert AS ID to AS object and fetch state.
+            input: A dict produced by `serialize`
 
         Returns:
-            EventContext
+            The event context.
         """
         context = _AsyncEventContextImpl(
             # We use the state_group and prev_state_id stuff to pull the
@@ -241,22 +243,25 @@ class EventContext:
         await self._ensure_fetched()
         return self._current_state_ids
 
-    async def get_prev_state_ids(self):
+    async def get_prev_state_ids(self) -> StateMap[str]:
         """
         Gets the room state map, excluding this event.
 
         For a non-state event, this will be the same as get_current_state_ids().
 
         Returns:
-            dict[(str, str), str]|None: Returns None if state_group
-                is None, which happens when the associated event is an outlier.
-                Maps a (type, state_key) to the event ID of the state event matching
-                this tuple.
+            Returns {} if state_group is None, which happens when the associated
+            event is an outlier.
+
+            Maps a (type, state_key) to the event ID of the state event matching
+            this tuple.
         """
         await self._ensure_fetched()
+        # There *should* be previous state IDs now.
+        assert self._prev_state_ids is not None
         return self._prev_state_ids
 
-    def get_cached_current_state_ids(self):
+    def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
         """Gets the current state IDs if we have them already cached.
 
         It is an error to access this for a rejected event, since rejected state should
@@ -264,16 +269,17 @@ class EventContext:
         ``rejected`` is set.
 
         Returns:
-            dict[(str, str), str]|None: Returns None if we haven't cached the
-            state or if state_group is None, which happens when the associated
-            event is an outlier.
+            Returns None if we haven't cached the state or if state_group is None
+            (which happens when the associated event is an outlier).
+
+            Otherwise, returns the the current state IDs.
         """
         if self.rejected:
             raise RuntimeError("Attempt to access state_ids of rejected event")
 
         return self._current_state_ids
 
-    async def _ensure_fetched(self):
+    async def _ensure_fetched(self) -> None:
         return None
 
 
@@ -285,46 +291,46 @@ class _AsyncEventContextImpl(EventContext):
 
     Attributes:
 
-        _storage (Storage)
+        _storage
 
-        _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
-            been calculated. None if we haven't started calculating yet
+        _fetching_state_deferred: Resolves when *_state_ids have been calculated.
+            None if we haven't started calculating yet
 
-        _event_type (str): The type of the event the context is associated with.
+        _event_type: The type of the event the context is associated with.
 
-        _event_state_key (str): The state_key of the event the context is
-            associated with.
+        _event_state_key: The state_key of the event the context is associated with.
 
-        _prev_state_id (str|None): If the event associated with the context is
-            a state event, then `_prev_state_id` is the event_id of the state
-            that was replaced.
+        _prev_state_id: If the event associated with the context is a state event,
+            then `_prev_state_id` is the event_id of the state that was replaced.
     """
 
     # This needs to have a default as we're inheriting
-    _storage = attr.ib(default=None)
-    _prev_state_id = attr.ib(default=None)
-    _event_type = attr.ib(default=None)
-    _event_state_key = attr.ib(default=None)
-    _fetching_state_deferred = attr.ib(default=None)
+    _storage: "Storage" = attr.ib(default=None)
+    _prev_state_id: Optional[str] = attr.ib(default=None)
+    _event_type: str = attr.ib(default=None)
+    _event_state_key: Optional[str] = attr.ib(default=None)
+    _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
 
-    async def _ensure_fetched(self):
+    async def _ensure_fetched(self) -> None:
         if not self._fetching_state_deferred:
             self._fetching_state_deferred = run_in_background(self._fill_out_state)
 
-        return await make_deferred_yieldable(self._fetching_state_deferred)
+        await make_deferred_yieldable(self._fetching_state_deferred)
 
-    async def _fill_out_state(self):
+    async def _fill_out_state(self) -> None:
         """Called to populate the _current_state_ids and _prev_state_ids
         attributes by loading from the database.
         """
         if self.state_group is None:
             return
 
-        self._current_state_ids = await self._storage.state.get_state_ids_for_group(
+        current_state_ids = await self._storage.state.get_state_ids_for_group(
             self.state_group
         )
+        # Set this separately so mypy knows current_state_ids is not None.
+        self._current_state_ids = current_state_ids
         if self._event_state_key is not None:
-            self._prev_state_ids = dict(self._current_state_ids)
+            self._prev_state_ids = dict(current_state_ids)
 
             key = (self._event_type, self._event_state_key)
             if self._prev_state_id:
@@ -332,10 +338,12 @@ class _AsyncEventContextImpl(EventContext):
             else:
                 self._prev_state_ids.pop(key, None)
         else:
-            self._prev_state_ids = self._current_state_ids
+            self._prev_state_ids = current_state_ids
 
 
-def _encode_state_dict(state_dict):
+def _encode_state_dict(
+    state_dict: Optional[StateMap[str]],
+) -> Optional[List[Tuple[str, str, str]]]:
     """Since dicts of (type, state_key) -> event_id cannot be serialized in
     JSON we need to convert them to a form that can.
     """
@@ -345,7 +353,9 @@ def _encode_state_dict(state_dict):
     return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()]
 
 
-def _decode_state_dict(input):
+def _decode_state_dict(
+    input: Optional[List[Tuple[str, str, str]]]
+) -> Optional[StateMap[str]]:
     """Decodes a state dict encoded using `_encode_state_dict` above"""
     if input is None:
         return None

+ 14 - 11
synapse/events/spamcheck.py

@@ -77,7 +77,7 @@ CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK = Callable[
 ]
 
 
-def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
+def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
     """Wrapper that loads spam checkers configured using the old configuration, and
     registers the spam checker hooks they implement.
     """
@@ -129,9 +129,9 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
                         request_info: Collection[Tuple[str, str]],
                         auth_provider_id: Optional[str],
                     ) -> Union[Awaitable[RegistrationBehaviour], RegistrationBehaviour]:
-                        # We've already made sure f is not None above, but mypy doesn't
-                        # do well across function boundaries so we need to tell it f is
-                        # definitely not None.
+                        # Assertion required because mypy can't prove we won't
+                        # change `f` back to `None`. See
+                        # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
                         assert f is not None
 
                         return f(
@@ -146,9 +146,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
                         "Bad signature for callback check_registration_for_spam",
                     )
 
-            def run(*args, **kwargs):
-                # mypy doesn't do well across function boundaries so we need to tell it
-                # wrapped_func is definitely not None.
+            def run(*args: Any, **kwargs: Any) -> Awaitable:
+                # Assertion required because mypy can't prove we won't change `f`
+                # back to `None`. See
+                # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
                 assert wrapped_func is not None
 
                 return maybe_awaitable(wrapped_func(*args, **kwargs))
@@ -165,7 +166,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
 
 
 class SpamChecker:
-    def __init__(self):
+    def __init__(self) -> None:
         self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
         self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
         self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
@@ -209,7 +210,7 @@ class SpamChecker:
             CHECK_REGISTRATION_FOR_SPAM_CALLBACK
         ] = None,
         check_media_file_for_spam: Optional[CHECK_MEDIA_FILE_FOR_SPAM_CALLBACK] = None,
-    ):
+    ) -> None:
         """Register callbacks from module for each hook."""
         if check_event_for_spam is not None:
             self._check_event_for_spam_callbacks.append(check_event_for_spam)
@@ -275,7 +276,9 @@ class SpamChecker:
 
         return False
 
-    async def user_may_join_room(self, user_id: str, room_id: str, is_invited: bool):
+    async def user_may_join_room(
+        self, user_id: str, room_id: str, is_invited: bool
+    ) -> bool:
         """Checks if a given users is allowed to join a room.
         Not called when a user creates a room.
 
@@ -285,7 +288,7 @@ class SpamChecker:
             is_invited: Whether the user is invited into the room
 
         Returns:
-            bool: Whether the user may join the room
+            Whether the user may join the room
         """
         for callback in self._user_may_join_room_callbacks:
             if await callback(user_id, room_id, is_invited) is False:

+ 13 - 12
synapse/events/third_party_rules.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple
 
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
@@ -38,7 +38,7 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
 ]
 
 
-def load_legacy_third_party_event_rules(hs: "HomeServer"):
+def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
     """Wrapper that loads a third party event rules module configured using the old
     configuration, and registers the hooks they implement.
     """
@@ -77,9 +77,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
                 event: EventBase,
                 state_events: StateMap[EventBase],
             ) -> Tuple[bool, Optional[dict]]:
-                # We've already made sure f is not None above, but mypy doesn't do well
-                # across function boundaries so we need to tell it f is definitely not
-                # None.
+                # Assertion required because mypy can't prove we won't change
+                # `f` back to `None`. See
+                # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
                 assert f is not None
 
                 res = await f(event, state_events)
@@ -98,9 +98,9 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
             async def wrap_on_create_room(
                 requester: Requester, config: dict, is_requester_admin: bool
             ) -> None:
-                # We've already made sure f is not None above, but mypy doesn't do well
-                # across function boundaries so we need to tell it f is definitely not
-                # None.
+                # Assertion required because mypy can't prove we won't change
+                # `f` back to `None`. See
+                # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
                 assert f is not None
 
                 res = await f(requester, config, is_requester_admin)
@@ -112,9 +112,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"):
 
             return wrap_on_create_room
 
-        def run(*args, **kwargs):
-            # mypy doesn't do well across function boundaries so we need to tell it
-            # f is definitely not None.
+        def run(*args: Any, **kwargs: Any) -> Awaitable:
+            # Assertion required because mypy can't prove we won't change  `f`
+            # back to `None`. See
+            # https://mypy.readthedocs.io/en/latest/common_issues.html#narrowing-and-inner-functions
             assert f is not None
 
             return maybe_awaitable(f(*args, **kwargs))
@@ -162,7 +163,7 @@ class ThirdPartyEventRules:
         check_visibility_can_be_modified: Optional[
             CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
         ] = None,
-    ):
+    ) -> None:
         """Register callbacks from modules for each hook."""
         if check_event_allowed is not None:
             self._check_event_allowed_callbacks.append(check_event_allowed)

+ 67 - 46
synapse/events/utils.py

@@ -13,18 +13,32 @@
 # limitations under the License.
 import collections.abc
 import re
-from typing import Any, Mapping, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from frozendict import frozendict
 
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
+from synapse.types import JsonDict
 from synapse.util.async_helpers import yieldable_gather_results
 from synapse.util.frozenutils import unfreeze
 
 from . import EventBase
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
 # (?<!stuff) matches if the current position in the string is not preceded
 # by a match for 'stuff'.
@@ -65,7 +79,7 @@ def prune_event(event: EventBase) -> EventBase:
     return pruned_event
 
 
-def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
+def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
     """Redacts the event_dict in the same way as `prune_event`, except it
     operates on dicts rather than event objects
 
@@ -97,7 +111,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
 
     new_content = {}
 
-    def add_fields(*fields):
+    def add_fields(*fields: str) -> None:
         for field in fields:
             if field in event_dict["content"]:
                 new_content[field] = event_dict["content"][field]
@@ -151,7 +165,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
 
     allowed_fields["content"] = new_content
 
-    unsigned = {}
+    unsigned: JsonDict = {}
     allowed_fields["unsigned"] = unsigned
 
     event_unsigned = event_dict.get("unsigned", {})
@@ -164,16 +178,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
     return allowed_fields
 
 
-def _copy_field(src, dst, field):
+def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
     """Copy the field in 'src' to 'dst'.
 
     For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
     then dst={"foo":{"bar":5}}.
 
     Args:
-        src(dict): The dict to read from.
-        dst(dict): The dict to modify.
-        field(list<str>): List of keys to drill down to in 'src'.
+        src: The dict to read from.
+        dst: The dict to modify.
+        field: List of keys to drill down to in 'src'.
     """
     if len(field) == 0:  # this should be impossible
         return
@@ -205,7 +219,7 @@ def _copy_field(src, dst, field):
     sub_out_dict[key_to_move] = sub_dict[key_to_move]
 
 
-def only_fields(dictionary, fields):
+def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
     """Return a new dict with only the fields in 'dictionary' which are present
     in 'fields'.
 
@@ -215,11 +229,11 @@ def only_fields(dictionary, fields):
     A literal '.' character in a field name may be escaped using a '\'.
 
     Args:
-        dictionary(dict): The dictionary to read from.
-        fields(list<str>): A list of fields to copy over. Only shallow refs are
+        dictionary: The dictionary to read from.
+        fields: A list of fields to copy over. Only shallow refs are
         taken.
     Returns:
-        dict: A new dictionary with only the given fields. If fields was empty,
+        A new dictionary with only the given fields. If fields was empty,
         the same dictionary is returned.
     """
     if len(fields) == 0:
@@ -235,17 +249,17 @@ def only_fields(dictionary, fields):
         [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
     ]
 
-    output = {}
+    output: JsonDict = {}
     for field_array in split_fields:
         _copy_field(dictionary, output, field_array)
     return output
 
 
-def format_event_raw(d):
+def format_event_raw(d: JsonDict) -> JsonDict:
     return d
 
 
-def format_event_for_client_v1(d):
+def format_event_for_client_v1(d: JsonDict) -> JsonDict:
     d = format_event_for_client_v2(d)
 
     sender = d.get("sender")
@@ -267,7 +281,7 @@ def format_event_for_client_v1(d):
     return d
 
 
-def format_event_for_client_v2(d):
+def format_event_for_client_v2(d: JsonDict) -> JsonDict:
     drop_keys = (
         "auth_events",
         "prev_events",
@@ -282,37 +296,37 @@ def format_event_for_client_v2(d):
     return d
 
 
-def format_event_for_client_v2_without_room_id(d):
+def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
     d = format_event_for_client_v2(d)
     d.pop("room_id", None)
     return d
 
 
 def serialize_event(
-    e,
-    time_now_ms,
-    as_client_event=True,
-    event_format=format_event_for_client_v1,
-    token_id=None,
-    only_event_fields=None,
-    include_stripped_room_state=False,
-):
+    e: Union[JsonDict, EventBase],
+    time_now_ms: int,
+    as_client_event: bool = True,
+    event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
+    token_id: Optional[str] = None,
+    only_event_fields: Optional[List[str]] = None,
+    include_stripped_room_state: bool = False,
+) -> JsonDict:
     """Serialize event for clients
 
     Args:
-        e (EventBase)
-        time_now_ms (int)
-        as_client_event (bool)
+        e
+        time_now_ms
+        as_client_event
         event_format
         token_id
         only_event_fields
-        include_stripped_room_state (bool): Some events can have stripped room state
+        include_stripped_room_state: Some events can have stripped room state
             stored in the `unsigned` field. This is required for invite and knock
             functionality. If this option is False, that state will be removed from the
             event before it is returned. Otherwise, it will be kept.
 
     Returns:
-        dict
+        The serialized event dictionary.
     """
 
     # FIXME(erikj): To handle the case of presence events and the like
@@ -369,25 +383,29 @@ class EventClientSerializer:
     clients.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.experimental_msc1849_support_enabled = (
             hs.config.server.experimental_msc1849_support_enabled
         )
 
     async def serialize_event(
-        self, event, time_now, bundle_aggregations=True, **kwargs
-    ):
+        self,
+        event: Union[JsonDict, EventBase],
+        time_now: int,
+        bundle_aggregations: bool = True,
+        **kwargs: Any,
+    ) -> JsonDict:
         """Serializes a single event.
 
         Args:
-            event (EventBase)
-            time_now (int): The current time in milliseconds
-            bundle_aggregations (bool): Whether to bundle in related events
+            event
+            time_now: The current time in milliseconds
+            bundle_aggregations: Whether to bundle in related events
             **kwargs: Arguments to pass to `serialize_event`
 
         Returns:
-            dict: The serialized event
+            The serialized event
         """
         # To handle the case of presence events and the like
         if not isinstance(event, EventBase):
@@ -448,25 +466,27 @@ class EventClientSerializer:
 
         return serialized_event
 
-    def serialize_events(self, events, time_now, **kwargs):
+    async def serialize_events(
+        self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
+    ) -> List[JsonDict]:
         """Serializes multiple events.
 
         Args:
-            event (iter[EventBase])
-            time_now (int): The current time in milliseconds
+            event
+            time_now: The current time in milliseconds
             **kwargs: Arguments to pass to `serialize_event`
 
         Returns:
-            Deferred[list[dict]]: The list of serialized events
+            The list of serialized events
         """
-        return yieldable_gather_results(
+        return await yieldable_gather_results(
             self.serialize_event, events, time_now=time_now, **kwargs
         )
 
 
 def copy_power_levels_contents(
     old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
-):
+) -> Dict[str, Union[int, Dict[str, int]]]:
     """Copy the content of a power_levels event, unfreezing frozendicts along the way
 
     Raises:
@@ -475,7 +495,7 @@ def copy_power_levels_contents(
     if not isinstance(old_power_levels, collections.abc.Mapping):
         raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
 
-    power_levels = {}
+    power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
     for k, v in old_power_levels.items():
 
         if isinstance(v, int):
@@ -483,7 +503,8 @@ def copy_power_levels_contents(
             continue
 
         if isinstance(v, collections.abc.Mapping):
-            power_levels[k] = h = {}
+            h: Dict[str, int] = {}
+            power_levels[k] = h
             for k1, v1 in v.items():
                 # we should only have one level of nesting
                 if not isinstance(v1, int):
@@ -498,7 +519,7 @@ def copy_power_levels_contents(
     return power_levels
 
 
-def validate_canonicaljson(value: Any):
+def validate_canonicaljson(value: Any) -> None:
     """
     Ensure that the JSON object is valid according to the rules of canonical JSON.
 

+ 10 - 8
synapse/events/validator.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import collections.abc
-from typing import Union
+from typing import Iterable, Union
 
 import jsonschema
 
@@ -28,11 +28,11 @@ from synapse.events.utils import (
     validate_canonicaljson,
 )
 from synapse.federation.federation_server import server_matches_acl_event
-from synapse.types import EventID, RoomID, UserID
+from synapse.types import EventID, JsonDict, RoomID, UserID
 
 
 class EventValidator:
-    def validate_new(self, event: EventBase, config: HomeServerConfig):
+    def validate_new(self, event: EventBase, config: HomeServerConfig) -> None:
         """Validates the event has roughly the right format
 
         Args:
@@ -116,7 +116,7 @@ class EventValidator:
                     errcode=Codes.BAD_JSON,
                 )
 
-    def _validate_retention(self, event: EventBase):
+    def _validate_retention(self, event: EventBase) -> None:
         """Checks that an event that defines the retention policy for a room respects the
         format enforced by the spec.
 
@@ -156,7 +156,7 @@ class EventValidator:
                 errcode=Codes.BAD_JSON,
             )
 
-    def validate_builder(self, event: Union[EventBase, EventBuilder]):
+    def validate_builder(self, event: Union[EventBase, EventBuilder]) -> None:
         """Validates that the builder/event has roughly the right format. Only
         checks values that we expect a proto event to have, rather than all the
         fields an event would have
@@ -204,14 +204,14 @@ class EventValidator:
 
             self._ensure_state_event(event)
 
-    def _ensure_strings(self, d, keys):
+    def _ensure_strings(self, d: JsonDict, keys: Iterable[str]) -> None:
         for s in keys:
             if s not in d:
                 raise SynapseError(400, "'%s' not in content" % (s,))
             if not isinstance(d[s], str):
                 raise SynapseError(400, "'%s' not a string type" % (s,))
 
-    def _ensure_state_event(self, event):
+    def _ensure_state_event(self, event: Union[EventBase, EventBuilder]) -> None:
         if not event.is_state():
             raise SynapseError(400, "'%s' must be state events" % (event.type,))
 
@@ -244,7 +244,9 @@ POWER_LEVELS_SCHEMA = {
 }
 
 
-def _create_power_level_validator():
+# This could return something newer than Draft 7, but that's the current "latest"
+# validator.
+def _create_power_level_validator() -> jsonschema.Draft7Validator:
     validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
 
     # by default jsonschema does not consider a frozendict to be an object so

+ 20 - 2
synapse/handlers/room.py

@@ -465,17 +465,35 @@ class RoomCreationHandler:
         # the room has been created
         # Calculate the minimum power level needed to clone the room
         event_power_levels = power_levels.get("events", {})
+        if not isinstance(event_power_levels, dict):
+            event_power_levels = {}
         state_default = power_levels.get("state_default", 50)
+        try:
+            state_default_int = int(state_default)  # type: ignore[arg-type]
+        except (TypeError, ValueError):
+            state_default_int = 50
         ban = power_levels.get("ban", 50)
-        needed_power_level = max(state_default, ban, max(event_power_levels.values()))
+        try:
+            ban = int(ban)  # type: ignore[arg-type]
+        except (TypeError, ValueError):
+            ban = 50
+        needed_power_level = max(
+            state_default_int, ban, max(event_power_levels.values())
+        )
 
         # Get the user's current power level, this matches the logic in get_user_power_level,
         # but without the entire state map.
         user_power_levels = power_levels.setdefault("users", {})
+        if not isinstance(user_power_levels, dict):
+            user_power_levels = {}
         users_default = power_levels.get("users_default", 0)
         current_power_level = user_power_levels.get(user_id, users_default)
+        try:
+            current_power_level_int = int(current_power_level)  # type: ignore[arg-type]
+        except (TypeError, ValueError):
+            current_power_level_int = 0
         # Raise the requester's power level in the new room if necessary
-        if current_power_level < needed_power_level:
+        if current_power_level_int < needed_power_level:
             user_power_levels[user_id] = needed_power_level
 
         await self._send_events_for_new_room(

+ 4 - 4
synapse/rest/client/relations.py

@@ -232,12 +232,12 @@ class RelationPaginationServlet(RestServlet):
         # Similarly, we don't allow relations to be applied to relations, so we
         # return the original relations without any aggregations on top of them
         # here.
-        events = await self._event_serializer.serialize_events(
+        serialized_events = await self._event_serializer.serialize_events(
             events, now, bundle_aggregations=False
         )
 
         return_value = pagination_chunk.to_dict()
-        return_value["chunk"] = events
+        return_value["chunk"] = serialized_events
         return_value["original_event"] = original_event
 
         return 200, return_value
@@ -416,10 +416,10 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         )
 
         now = self.clock.time_msec()
-        events = await self._event_serializer.serialize_events(events, now)
+        serialized_events = await self._event_serializer.serialize_events(events, now)
 
         return_value = result.to_dict()
-        return_value["chunk"] = events
+        return_value["chunk"] = serialized_events
 
         return 200, return_value