Browse Source

Add type hints for state. (#8140)

Patrick Cloke 3 years ago
parent
commit
5758dcf30c

+ 1 - 0
changelog.d/8140.misc

@@ -0,0 +1 @@
+Add type hints to `synapse.state`.

+ 47 - 0
stubs/frozendict.pyi

@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Stub for frozendict.
+
+from typing import (
+    Any,
+    Hashable,
+    Iterable,
+    Iterator,
+    Mapping,
+    overload,
+    Tuple,
+    TypeVar,
+)
+
+_KT = TypeVar("_KT", bound=Hashable)  # Key type.
+_VT = TypeVar("_VT")  # Value type.
+
+class frozendict(Mapping[_KT, _VT]):
+    @overload
+    def __init__(self, **kwargs: _VT) -> None: ...
+    @overload
+    def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
+    @overload
+    def __init__(
+        self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
+    ) -> None: ...
+    def __getitem__(self, key: _KT) -> _VT: ...
+    def __contains__(self, key: Any) -> bool: ...
+    def copy(self, **add_or_replace: Any) -> frozendict: ...
+    def __iter__(self) -> Iterator[_KT]: ...
+    def __len__(self) -> int: ...
+    def __repr__(self) -> str: ...
+    def __hash__(self) -> int: ...

+ 2 - 2
synapse/federation/sender/__init__.py

@@ -329,10 +329,10 @@ class FederationSender(object):
         room_id = receipt.room_id
 
         # Work out which remote servers should be poked and poke them.
-        domains = await self.state.get_current_hosts_in_room(room_id)
+        domains_set = await self.state.get_current_hosts_in_room(room_id)
         domains = [
             d
-            for d in domains
+            for d in domains_set
             if d != self.server_name
             and self._federation_shard_config.should_handle(self._instance_name, d)
         ]

+ 6 - 4
synapse/handlers/federation.py

@@ -2134,10 +2134,10 @@ class FederationHandler(BaseHandler):
             )
             state_sets = list(state_sets.values())
             state_sets.append(state)
-            current_state_ids = await self.state_handler.resolve_events(
+            current_states = await self.state_handler.resolve_events(
                 room_version, state_sets, event
             )
-            current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
+            current_state_ids = {k: e.event_id for k, e in current_states.items()}
         else:
             current_state_ids = await self.state_handler.get_current_state_ids(
                 event.room_id, latest_event_ids=extrem_ids
@@ -2149,9 +2149,11 @@ class FederationHandler(BaseHandler):
 
         # Now check if event pass auth against said current state
         auth_types = auth_types_for_event(event)
-        current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
+        current_state_ids_list = [
+            e for k, e in current_state_ids.items() if k in auth_types
+        ]
 
-        auth_events_map = await self.store.get_events(current_state_ids)
+        auth_events_map = await self.store.get_events(current_state_ids_list)
         current_auth_events = {
             (e.type, e.state_key): e for e in auth_events_map.values()
         }

+ 3 - 3
synapse/handlers/presence.py

@@ -40,7 +40,7 @@ from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.state import StateHandler
 from synapse.storage.databases.main import DataStore
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.descriptors import cached
 from synapse.util.metrics import Measure
@@ -1318,7 +1318,7 @@ async def get_interested_parties(
 
 async def get_interested_remotes(
     store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
-) -> List[Tuple[List[str], List[UserPresenceState]]]:
+) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
     """Given a list of presence states figure out which remote servers
     should be sent which.
 
@@ -1334,7 +1334,7 @@ async def get_interested_remotes(
         each tuple the list of UserPresenceState should be sent to each
         destination
     """
-    hosts_and_states = []
+    hosts_and_states = []  # type: List[Tuple[Collection[str], List[UserPresenceState]]]
 
     # First we look up the rooms each user is in (as well as any explicit
     # subscriptions), then for each distinct room we look up the remote

+ 12 - 8
synapse/handlers/room_member.py

@@ -17,7 +17,7 @@ import abc
 import logging
 import random
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
 
 from unpaddedbase64 import encode_base64
 
@@ -38,7 +38,15 @@ from synapse.events.builder import create_local_event_from_event_dict
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
 from synapse.storage.roommember import RoomsForUser
-from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
+from synapse.types import (
+    Collection,
+    JsonDict,
+    Requester,
+    RoomAlias,
+    RoomID,
+    StateMap,
+    UserID,
+)
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_joined_room, user_left_room
 
@@ -738,9 +746,7 @@ class RoomMemberHandler(object):
                 if prev_member_event.membership == Membership.JOIN:
                     await self._user_left_room(target_user, room_id)
 
-    async def _can_guest_join(
-        self, current_state_ids: Dict[Tuple[str, str], str]
-    ) -> bool:
+    async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
         """
         Returns whether a guest can join a room based on its current state.
         """
@@ -969,9 +975,7 @@ class RoomMemberHandler(object):
         )
         return stream_id
 
-    async def _is_host_in_room(
-        self, current_state_ids: Dict[Tuple[str, str], str]
-    ) -> bool:
+    async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
         # Have we just created the room, and is this about to be the very
         # first member event?
         create_event_id = current_state_ids.get(("m.room.create", ""))

+ 122 - 70
synapse/state/__init__.py

@@ -16,11 +16,22 @@
 
 import logging
 from collections import namedtuple
-from typing import Awaitable, Dict, Iterable, List, Optional, Set
+from typing import (
+    Awaitable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Union,
+    overload,
+)
 
 import attr
 from frozendict import frozendict
 from prometheus_client import Histogram
+from typing_extensions import Literal
 
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@@ -30,7 +41,7 @@ from synapse.logging.utils import log_function
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.roommember import ProfileInfo
-from synapse.types import StateMap
+from synapse.types import Collection, StateMap
 from synapse.util import Clock
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -68,8 +79,14 @@ def _gen_state_id():
 class _StateCacheEntry(object):
     __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
 
-    def __init__(self, state, state_group, prev_group=None, delta_ids=None):
-        # dict[(str, str), str] map  from (type, state_key) to event_id
+    def __init__(
+        self,
+        state: StateMap[str],
+        state_group: Optional[int],
+        prev_group: Optional[int] = None,
+        delta_ids: Optional[StateMap[str]] = None,
+    ):
+        # A map from (type, state_key) to event_id.
         self.state = frozendict(state)
 
         # the ID of a state group if one and only one is involved.
@@ -107,24 +124,49 @@ class StateHandler(object):
         self.hs = hs
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
+    @overload
     async def get_current_state(
-        self, room_id, event_type=None, state_key="", latest_event_ids=None
-    ):
-        """ Retrieves the current state for the room. This is done by
+        self,
+        room_id: str,
+        event_type: Literal[None] = None,
+        state_key: str = "",
+        latest_event_ids: Optional[List[str]] = None,
+    ) -> StateMap[EventBase]:
+        ...
+
+    @overload
+    async def get_current_state(
+        self,
+        room_id: str,
+        event_type: str,
+        state_key: str = "",
+        latest_event_ids: Optional[List[str]] = None,
+    ) -> Optional[EventBase]:
+        ...
+
+    async def get_current_state(
+        self,
+        room_id: str,
+        event_type: Optional[str] = None,
+        state_key: str = "",
+        latest_event_ids: Optional[List[str]] = None,
+    ) -> Union[Optional[EventBase], StateMap[EventBase]]:
+        """Retrieves the current state for the room. This is done by
         calling `get_latest_events_in_room` to get the leading edges of the
         event graph and then resolving any of the state conflicts.
 
         This is equivalent to getting the state of an event that were to send
         next before receiving any new events.
 
-        If `event_type` is specified, then the method returns only the one
-        event (or None) with that `event_type` and `state_key`.
-
         Returns:
-            map from (type, state_key) to event
+            If `event_type` is specified, then the method returns only the one
+            event (or None) with that `event_type` and `state_key`.
+
+            Otherwise, a map from (type, state_key) to event.
         """
         if not latest_event_ids:
             latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+        assert latest_event_ids is not None
 
         logger.debug("calling resolve_state_groups from get_current_state")
         ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
@@ -140,34 +182,30 @@ class StateHandler(object):
         state_map = await self.store.get_events(
             list(state.values()), get_prev_content=False
         )
-        state = {
+        return {
             key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
         }
 
-        return state
-
-    async def get_current_state_ids(self, room_id, latest_event_ids=None):
+    async def get_current_state_ids(
+        self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None
+    ) -> StateMap[str]:
         """Get the current state, or the state at a set of events, for a room
 
         Args:
-            room_id (str):
-
-            latest_event_ids (iterable[str]|None): if given, the forward
-                extremities to resolve. If None, we look them up from the
-                database (via a cache)
+            room_id:
+            latest_event_ids: if given, the forward extremities to resolve. If
+                None, we look them up from the database (via a cache).
 
         Returns:
-            Deferred[dict[(str, str), str)]]: the state dict, mapping from
-                (event_type, state_key) -> event_id
+            the state dict, mapping from (event_type, state_key) -> event_id
         """
         if not latest_event_ids:
             latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+        assert latest_event_ids is not None
 
         logger.debug("calling resolve_state_groups from get_current_state_ids")
         ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        state = ret.state
-
-        return state
+        return dict(ret.state)
 
     async def get_current_users_in_room(
         self, room_id: str, latest_event_ids: Optional[List[str]] = None
@@ -183,32 +221,34 @@ class StateHandler(object):
         """
         if not latest_event_ids:
             latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+        assert latest_event_ids is not None
+
         logger.debug("calling resolve_state_groups from get_current_users_in_room")
         entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        joined_users = await self.store.get_joined_users_from_state(room_id, entry)
-        return joined_users
+        return await self.store.get_joined_users_from_state(room_id, entry)
 
-    async def get_current_hosts_in_room(self, room_id):
+    async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
         event_ids = await self.store.get_latest_event_ids_in_room(room_id)
         return await self.get_hosts_in_room_at_events(room_id, event_ids)
 
-    async def get_hosts_in_room_at_events(self, room_id, event_ids):
+    async def get_hosts_in_room_at_events(
+        self, room_id: str, event_ids: List[str]
+    ) -> Set[str]:
         """Get the hosts that were in a room at the given event ids
 
         Args:
-            room_id (str):
-            event_ids (list[str]):
+            room_id:
+            event_ids:
 
         Returns:
-            Deferred[list[str]]: the hosts in the room at the given events
+            The hosts in the room at the given events
         """
         entry = await self.resolve_state_groups_for_events(room_id, event_ids)
-        joined_hosts = await self.store.get_joined_hosts(room_id, entry)
-        return joined_hosts
+        return await self.store.get_joined_hosts(room_id, entry)
 
     async def compute_event_context(
         self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
-    ):
+    ) -> EventContext:
         """Build an EventContext structure for the event.
 
         This works out what the current state should be for the event, and
@@ -221,7 +261,7 @@ class StateHandler(object):
                 when receiving an event from federation where we don't have the
                 prev events for, e.g. when backfilling.
         Returns:
-            synapse.events.snapshot.EventContext:
+            The event context.
         """
 
         if event.internal_metadata.is_outlier():
@@ -275,7 +315,7 @@ class StateHandler(object):
                 event.room_id, event.prev_event_ids()
             )
 
-            state_ids_before_event = entry.state
+            state_ids_before_event = dict(entry.state)
             state_group_before_event = entry.state_group
             state_group_before_event_prev_group = entry.prev_group
             deltas_to_state_group_before_event = entry.delta_ids
@@ -346,19 +386,18 @@ class StateHandler(object):
         )
 
     @measure_func()
-    async def resolve_state_groups_for_events(self, room_id, event_ids):
+    async def resolve_state_groups_for_events(
+        self, room_id: str, event_ids: Iterable[str]
+    ) -> _StateCacheEntry:
         """ Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
 
         Args:
-            room_id (str)
-            event_ids (list[str])
-            explicit_room_version (str|None): If set uses the the given room
-                version to choose the resolution algorithm. If None, then
-                checks the database for room version.
+            room_id
+            event_ids
 
         Returns:
-            Deferred[_StateCacheEntry]: resolved state
+            The resolved state
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
@@ -394,7 +433,12 @@ class StateHandler(object):
         )
         return result
 
-    async def resolve_events(self, room_version, state_sets, event):
+    async def resolve_events(
+        self,
+        room_version: str,
+        state_sets: Collection[Iterable[EventBase]],
+        event: EventBase,
+    ) -> StateMap[EventBase]:
         logger.info(
             "Resolving state for %s with %d groups", event.room_id, len(state_sets)
         )
@@ -414,9 +458,7 @@ class StateHandler(object):
                 state_res_store=StateResolutionStore(self.store),
             )
 
-        new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()}
-
-        return new_state
+        return {key: state_map[ev_id] for key, ev_id in new_state.items()}
 
 
 class StateResolutionHandler(object):
@@ -444,7 +486,12 @@ class StateResolutionHandler(object):
 
     @log_function
     async def resolve_state_groups(
-        self, room_id, room_version, state_groups_ids, event_map, state_res_store
+        self,
+        room_id: str,
+        room_version: str,
+        state_groups_ids: Dict[int, StateMap[str]],
+        event_map: Optional[Dict[str, EventBase]],
+        state_res_store: "StateResolutionStore",
     ):
         """Resolves conflicts between a set of state groups
 
@@ -452,13 +499,13 @@ class StateResolutionHandler(object):
         not be called for a single state group
 
         Args:
-            room_id (str): room we are resolving for (used for logging and sanity checks)
-            room_version (str): version of the room
-            state_groups_ids (dict[int, dict[(str, str), str]]):
-                 map from state group id to the state in that state group
+            room_id: room we are resolving for (used for logging and sanity checks)
+            room_version: version of the room
+            state_groups_ids:
+                A map from state group id to the state in that state group
                 (where 'state' is a map from state key to event id)
 
-            event_map(dict[str,FrozenEvent]|None):
+            event_map:
                 a dict from event_id to event, for any events that we happen to
                 have in flight (eg, those currently being persisted). This will be
                 used as a starting point fof finding the state we need; any missing
@@ -466,10 +513,10 @@ class StateResolutionHandler(object):
 
                 If None, all events will be fetched via state_res_store.
 
-            state_res_store (StateResolutionStore)
+            state_res_store
 
         Returns:
-            _StateCacheEntry: resolved state
+            The resolved state
         """
         logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
 
@@ -530,21 +577,22 @@ class StateResolutionHandler(object):
             return cache
 
 
-def _make_state_cache_entry(new_state, state_groups_ids):
+def _make_state_cache_entry(
+    new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
+) -> _StateCacheEntry:
     """Given a resolved state, and a set of input state groups, pick one to base
     a new state group on (if any), and return an appropriately-constructed
     _StateCacheEntry.
 
     Args:
-        new_state (dict[(str, str), str]): resolved state map (mapping from
-           (type, state_key) to event_id)
+        new_state: resolved state map (mapping from (type, state_key) to event_id)
 
-        state_groups_ids (dict[int, dict[(str, str), str]]):
-                 map from state group id to the state in that state group
-                (where 'state' is a map from state key to event id)
+        state_groups_ids:
+            map from state group id to the state in that state group (where
+            'state' is a map from state key to event id)
 
     Returns:
-        _StateCacheEntry
+        The cache entry.
     """
     # if the new state matches any of the input state groups, we can
     # use that state group again. Otherwise we will generate a state_id
@@ -585,7 +633,7 @@ def resolve_events_with_store(
     clock: Clock,
     room_id: str,
     room_version: str,
-    state_sets: List[StateMap[str]],
+    state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_res_store: "StateResolutionStore",
 ) -> Awaitable[StateMap[str]]:
@@ -633,15 +681,17 @@ class StateResolutionStore(object):
 
     store = attr.ib()
 
-    def get_events(self, event_ids, allow_rejected=False):
+    def get_events(
+        self, event_ids: Iterable[str], allow_rejected: bool = False
+    ) -> Awaitable[Dict[str, EventBase]]:
         """Get events from the database
 
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            allow_rejected (bool): If True return rejected events.
+            event_ids: The event_ids of the events to fetch
+            allow_rejected: If True return rejected events.
 
         Returns:
-            Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
+            An awaitable which resolves to a dict from event_id to event.
         """
 
         return self.store.get_events(
@@ -651,7 +701,9 @@ class StateResolutionStore(object):
             allow_rejected=allow_rejected,
         )
 
-    def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+    def get_auth_chain_difference(
+        self, state_sets: List[Set[str]]
+    ) -> Awaitable[Set[str]]:
         """Given sets of state events figure out the auth chain difference (as
         per state res v2 algorithm).
 
@@ -660,7 +712,7 @@ class StateResolutionStore(object):
         chain.
 
         Returns:
-            Deferred[Set[str]]: Set of event IDs.
+            An awaitable that resolves to a set of event IDs.
         """
 
         return self.store.get_auth_chain_difference(state_sets)

+ 59 - 28
synapse/state/v1.py

@@ -15,7 +15,17 @@
 
 import hashlib
 import logging
-from typing import Awaitable, Callable, Dict, List, Optional
+from typing import (
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 from synapse import event_auth
 from synapse.api.constants import EventTypes
@@ -32,10 +42,10 @@ POWER_KEY = (EventTypes.PowerLevels, "")
 
 async def resolve_events_with_store(
     room_id: str,
-    state_sets: List[StateMap[str]],
+    state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable[[List[str]], Awaitable],
-):
+    state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+) -> StateMap[str]:
     """
     Args:
         room_id: the room we are working in
@@ -56,8 +66,7 @@ async def resolve_events_with_store(
             an Awaitable that resolves to a dict of event_id to event.
 
     Returns:
-        Deferred[dict[(str, str), str]]:
-            a map from (type, state_key) to event_id.
+        A map from (type, state_key) to event_id.
     """
     if len(state_sets) == 1:
         return state_sets[0]
@@ -75,8 +84,8 @@ async def resolve_events_with_store(
         "Asking for %d/%d conflicted events", len(needed_events), needed_event_count
     )
 
-    # dict[str, FrozenEvent]: a map from state event id to event. Only includes
-    # the state events which are in conflict (and those in event_map)
+    # A map from state event id to event. Only includes the state events which
+    # are in conflict (and those in event_map).
     state_map = await state_map_factory(needed_events)
     if event_map is not None:
         state_map.update(event_map)
@@ -91,8 +100,6 @@ async def resolve_events_with_store(
 
     # get the ids of the auth events which allow us to authenticate the
     # conflicted state, picking only from the unconflicting state.
-    #
-    # dict[(str, str), str]: a map from state key to event id
     auth_events = _create_auth_events_from_maps(
         unconflicted_state, conflicted_state, state_map
     )
@@ -122,29 +129,30 @@ async def resolve_events_with_store(
     )
 
 
-def _seperate(state_sets):
+def _seperate(
+    state_sets: Iterable[StateMap[str]],
+) -> Tuple[StateMap[str], StateMap[Set[str]]]:
     """Takes the state_sets and figures out which keys are conflicted and
     which aren't. i.e., which have multiple different event_ids associated
     with them in different state sets.
 
     Args:
-        state_sets(iterable[dict[(str, str), str]]):
+        state_sets:
             List of dicts of (type, state_key) -> event_id, which are the
             different state groups to resolve.
 
     Returns:
-        (dict[(str, str), str], dict[(str, str), set[str]]):
-            A tuple of (unconflicted_state, conflicted_state), where:
+        A tuple of (unconflicted_state, conflicted_state), where:
 
-            unconflicted_state is a dict mapping (type, state_key)->event_id
-            for unconflicted state keys.
+        unconflicted_state is a dict mapping (type, state_key)->event_id
+        for unconflicted state keys.
 
-            conflicted_state is a dict mapping (type, state_key) to a set of
-            event ids for conflicted state keys.
+        conflicted_state is a dict mapping (type, state_key) to a set of
+        event ids for conflicted state keys.
     """
     state_set_iterator = iter(state_sets)
     unconflicted_state = dict(next(state_set_iterator))
-    conflicted_state = {}
+    conflicted_state = {}  # type: StateMap[Set[str]]
 
     for state_set in state_set_iterator:
         for key, value in state_set.items():
@@ -171,7 +179,21 @@ def _seperate(state_sets):
     return unconflicted_state, conflicted_state
 
 
-def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
+def _create_auth_events_from_maps(
+    unconflicted_state: StateMap[str],
+    conflicted_state: StateMap[Set[str]],
+    state_map: Dict[str, EventBase],
+) -> StateMap[str]:
+    """
+
+    Args:
+        unconflicted_state: The unconflicted state map.
+        conflicted_state: The conflicted state map.
+        state_map:
+
+    Returns:
+        A map from state key to event id.
+    """
     auth_events = {}
     for event_ids in conflicted_state.values():
         for event_id in event_ids:
@@ -179,14 +201,17 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
                 keys = event_auth.auth_types_for_event(state_map[event_id])
                 for key in keys:
                     if key not in auth_events:
-                        event_id = unconflicted_state.get(key, None)
-                        if event_id:
-                            auth_events[key] = event_id
+                        auth_event_id = unconflicted_state.get(key, None)
+                        if auth_event_id:
+                            auth_events[key] = auth_event_id
     return auth_events
 
 
 def _resolve_with_state(
-    unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
+    unconflicted_state_ids: StateMap[str],
+    conflicted_state_ids: StateMap[Set[str]],
+    auth_event_ids: StateMap[str],
+    state_map: Dict[str, EventBase],
 ):
     conflicted_state = {}
     for key, event_ids in conflicted_state_ids.items():
@@ -215,7 +240,9 @@ def _resolve_with_state(
     return new_state
 
 
-def _resolve_state_events(conflicted_state, auth_events):
+def _resolve_state_events(
+    conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase]
+) -> StateMap[EventBase]:
     """ This is where we actually decide which of the conflicted state to
     use.
 
@@ -255,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events):
     return resolved_state
 
 
-def _resolve_auth_events(events, auth_events):
+def _resolve_auth_events(
+    events: List[EventBase], auth_events: StateMap[EventBase]
+) -> EventBase:
     reverse = list(reversed(_ordered_events(events)))
 
     auth_keys = {
@@ -289,7 +318,9 @@ def _resolve_auth_events(events, auth_events):
     return event
 
 
-def _resolve_normal_events(events, auth_events):
+def _resolve_normal_events(
+    events: List[EventBase], auth_events: StateMap[EventBase]
+) -> EventBase:
     for event in _ordered_events(events):
         try:
             # The signatures have already been checked at this point
@@ -309,7 +340,7 @@ def _resolve_normal_events(events, auth_events):
     return event
 
 
-def _ordered_events(events):
+def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
     def key_func(e):
         # we have to use utf-8 rather than ascii here because it turns out we allow
         # people to send us events with non-ascii event IDs :/

+ 167 - 88
synapse/state/v2.py

@@ -16,7 +16,21 @@
 import heapq
 import itertools
 import logging
-from typing import Dict, List, Optional
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    overload,
+)
+
+from typing_extensions import Literal
 
 import synapse.state
 from synapse import event_auth
@@ -40,10 +54,10 @@ async def resolve_events_with_store(
     clock: Clock,
     room_id: str,
     room_version: str,
-    state_sets: List[StateMap[str]],
+    state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_res_store: "synapse.state.StateResolutionStore",
-):
+) -> StateMap[str]:
     """Resolves the state using the v2 state resolution algorithm
 
     Args:
@@ -63,8 +77,7 @@ async def resolve_events_with_store(
         state_res_store:
 
     Returns:
-        Deferred[dict[(str, str), str]]:
-            a map from (type, state_key) to event_id.
+        A map from (type, state_key) to event_id.
     """
 
     logger.debug("Computing conflicted state")
@@ -171,18 +184,23 @@ async def resolve_events_with_store(
     return resolved_state
 
 
-async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
+async def _get_power_level_for_sender(
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> int:
     """Return the power level of the sender of the given event according to
     their auth events.
 
     Args:
-        room_id (str)
-        event_id (str)
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        room_id
+        event_id
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[int]
+        The power level.
     """
     event = await _get_event(room_id, event_id, event_map, state_res_store)
 
@@ -217,17 +235,21 @@ async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_st
         return int(level)
 
 
-async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
+async def _get_auth_chain_difference(
+    state_sets: Sequence[StateMap[str]],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> Set[str]:
     """Compare the auth chains of each state set and return the set of events
     that only appear in some but not all of the auth chains.
 
     Args:
-        state_sets (list)
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        state_sets
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[set[str]]: Set of event IDs
+        Set of event IDs
     """
 
     difference = await state_res_store.get_auth_chain_difference(
@@ -237,17 +259,19 @@ async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
     return difference
 
 
-def _seperate(state_sets):
+def _seperate(
+    state_sets: Iterable[StateMap[str]],
+) -> Tuple[StateMap[str], StateMap[Set[str]]]:
     """Return the unconflicted and conflicted state. This is different than in
     the original algorithm, as this defines a key to be conflicted if one of
     the state sets doesn't have that key.
 
     Args:
-        state_sets (list)
+        state_sets
 
     Returns:
-        tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
-        conflicted state dict is a map from type/state_key to set of event IDs
+        A tuple of unconflicted and conflicted state. The conflicted state dict
+        is a map from type/state_key to set of event IDs
     """
     unconflicted_state = {}
     conflicted_state = {}
@@ -260,18 +284,20 @@ def _seperate(state_sets):
             event_ids.discard(None)
             conflicted_state[key] = event_ids
 
-    return unconflicted_state, conflicted_state
+    # mypy doesn't understand that discarding None above means that conflicted
+    # state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
+    return unconflicted_state, conflicted_state  # type: ignore
 
 
-def _is_power_event(event):
+def _is_power_event(event: EventBase) -> bool:
     """Return whether or not the event is a "power event", as defined by the
     v2 state resolution algorithm
 
     Args:
-        event (FrozenEvent)
+        event
 
     Returns:
-        boolean
+        True if the event is a power event.
     """
     if (event.type, event.state_key) in (
         (EventTypes.PowerLevels, ""),
@@ -288,19 +314,23 @@ def _is_power_event(event):
 
 
 async def _add_event_and_auth_chain_to_graph(
-    graph, room_id, event_id, event_map, state_res_store, auth_diff
-):
+    graph: Dict[str, Set[str]],
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    auth_diff: Set[str],
+) -> None:
     """Helper function for _reverse_topological_power_sort that add the event
     and its auth chain (that is in the auth diff) to the graph
 
     Args:
-        graph (dict[str, set[str]]): A map from event ID to the events auth
-            event IDs
-        room_id (str): the room we are working in
-        event_id (str): Event to add to the graph
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
-        auth_diff (set[str]): Set of event IDs that are in the auth difference.
+        graph: A map from event ID to the events auth event IDs
+        room_id: the room we are working in
+        event_id: Event to add to the graph
+        event_map
+        state_res_store
+        auth_diff: Set of event IDs that are in the auth difference.
     """
 
     state = [event_id]
@@ -318,24 +348,29 @@ async def _add_event_and_auth_chain_to_graph(
 
 
 async def _reverse_topological_power_sort(
-    clock, room_id, event_ids, event_map, state_res_store, auth_diff
-):
+    clock: Clock,
+    room_id: str,
+    event_ids: Iterable[str],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    auth_diff: Set[str],
+) -> List[str]:
     """Returns a list of the event_ids sorted by reverse topological ordering,
     and then by power level and origin_server_ts
 
     Args:
-        clock (Clock)
-        room_id (str): the room we are working in
-        event_ids (list[str]): The events to sort
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
-        auth_diff (set[str]): Set of event IDs that are in the auth difference.
+        clock
+        room_id: the room we are working in
+        event_ids: The events to sort
+        event_map
+        state_res_store
+        auth_diff: Set of event IDs that are in the auth difference.
 
     Returns:
-        Deferred[list[str]]: The sorted list
+        The sorted list
     """
 
-    graph = {}
+    graph = {}  # type: Dict[str, Set[str]]
     for idx, event_id in enumerate(event_ids, start=1):
         await _add_event_and_auth_chain_to_graph(
             graph, room_id, event_id, event_map, state_res_store, auth_diff
@@ -372,22 +407,28 @@ async def _reverse_topological_power_sort(
 
 
 async def _iterative_auth_checks(
-    clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
-):
+    clock: Clock,
+    room_id: str,
+    room_version: str,
+    event_ids: List[str],
+    base_state: StateMap[str],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> StateMap[str]:
     """Sequentially apply auth checks to each event in given list, updating the
     state as it goes along.
 
     Args:
-        clock (Clock)
-        room_id (str)
-        room_version (str)
-        event_ids (list[str]): Ordered list of events to apply auth checks to
-        base_state (StateMap[str]): The set of state to start with
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        clock
+        room_id
+        room_version
+        event_ids: Ordered list of events to apply auth checks to
+        base_state: The set of state to start with
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[StateMap[str]]: Returns the final updated state
+        Returns the final updated state
     """
     resolved_state = base_state.copy()
     room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@@ -439,21 +480,26 @@ async def _iterative_auth_checks(
 
 
 async def _mainline_sort(
-    clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
-):
+    clock: Clock,
+    room_id: str,
+    event_ids: List[str],
+    resolved_power_event_id: Optional[str],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> List[str]:
     """Returns a sorted list of event_ids sorted by mainline ordering based on
     the given event resolved_power_event_id
 
     Args:
-        clock (Clock)
-        room_id (str): room we're working in
-        event_ids (list[str]): Events to sort
-        resolved_power_event_id (str): The final resolved power level event ID
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        clock
+        room_id: room we're working in
+        event_ids: Events to sort
+        resolved_power_event_id: The final resolved power level event ID
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[list[str]]: The sorted list
+        The sorted list
     """
     if not event_ids:
         # It's possible for there to be no event IDs here to sort, so we can
@@ -505,59 +551,90 @@ async def _mainline_sort(
 
 
 async def _get_mainline_depth_for_event(
-    event, mainline_map, event_map, state_res_store
-):
+    event: EventBase,
+    mainline_map: Dict[str, int],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> int:
     """Get the mainline depths for the given event based on the mainline map
 
     Args:
-        event (FrozenEvent)
-        mainline_map (dict[str, int]): Map from event_id to mainline depth for
-            events in the mainline.
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        event
+        mainline_map: Map from event_id to mainline depth for events in the mainline.
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[int]
+        The mainline depth
     """
 
     room_id = event.room_id
+    tmp_event = event  # type: Optional[EventBase]
 
     # We do an iterative search, replacing `event with the power level in its
     # auth events (if any)
-    while event:
+    while tmp_event:
         depth = mainline_map.get(event.event_id)
         if depth is not None:
             return depth
 
-        auth_events = event.auth_event_ids()
-        event = None
+        auth_events = tmp_event.auth_event_ids()
+        tmp_event = None
 
         for aid in auth_events:
             aev = await _get_event(
                 room_id, aid, event_map, state_res_store, allow_none=True
             )
             if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
-                event = aev
+                tmp_event = aev
                 break
 
     # Didn't find a power level auth event, so we just return 0
     return 0
 
 
-async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
+@overload
+async def _get_event(
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    allow_none: Literal[False] = False,
+) -> EventBase:
+    ...
+
+
+@overload
+async def _get_event(
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    allow_none: Literal[True],
+) -> Optional[EventBase]:
+    ...
+
+
+async def _get_event(
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    allow_none: bool = False,
+) -> Optional[EventBase]:
     """Helper function to look up event in event_map, falling back to looking
     it up in the store
 
     Args:
-        room_id (str)
-        event_id (str)
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
-        allow_none (bool): if the event is not found, return None rather than raising
+        room_id
+        event_id
+        event_map
+        state_res_store
+        allow_none: if the event is not found, return None rather than raising
             an exception
 
     Returns:
-        Deferred[Optional[FrozenEvent]]
+        The event, or none if the event does not exist (and allow_none is True).
     """
     if event_id not in event_map:
         events = await state_res_store.get_events([event_id], allow_rejected=True)
@@ -577,7 +654,9 @@ async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=F
     return event
 
 
-def lexicographical_topological_sort(graph, key):
+def lexicographical_topological_sort(
+    graph: Dict[str, Set[str]], key: Callable[[str], Any]
+) -> Generator[str, None, None]:
     """Performs a lexicographic reverse topological sort on the graph.
 
     This returns a reverse topological sort (i.e. if node A references B then B
@@ -587,20 +666,20 @@ def lexicographical_topological_sort(graph, key):
     NOTE: `graph` is modified during the sort.
 
     Args:
-        graph (dict[str, set[str]]): A representation of the graph where each
-            node is a key in the dict and its value are the nodes edges.
-        key (func): A function that takes a node and returns a value that is
-            comparable and used to order nodes
+        graph: A representation of the graph where each node is a key in the
+            dict and its value are the nodes edges.
+        key: A function that takes a node and returns a value that is comparable
+            and used to order nodes
 
     Yields:
-        str: The next node in the topological sort
+        The next node in the topological sort
     """
 
     # Note, this is basically Kahn's algorithm except we look at nodes with no
     # outgoing edges, c.f.
     # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
     outdegree_map = graph
-    reverse_graph = {}
+    reverse_graph = {}  # type: Dict[str, Set[str]]
 
     # Lists of nodes with zero out degree. Is actually a tuple of
     # `(key(node), node)` so that sorting does the right thing

+ 1 - 0
tox.ini

@@ -209,6 +209,7 @@ commands = mypy \
             synapse/server.py \
             synapse/server_notices \
             synapse/spam_checker_api \
+            synapse/state \
             synapse/storage/databases/main/ui_auth.py \
             synapse/storage/database.py \
             synapse/storage/engines \