Browse Source

Type annotations in `synapse.storage.databases.main.devices` (#13025)

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
David Robertson 1 year ago
parent
commit
97e9fbe1b2

+ 1 - 0
changelog.d/13025.misc

@@ -0,0 +1 @@
+Add type annotations to `synapse.storage.databases.main.devices`.

+ 0 - 1
mypy.ini

@@ -27,7 +27,6 @@ exclude = (?x)
   ^(
    |synapse/storage/databases/__init__.py
    |synapse/storage/databases/main/cache.py
-   |synapse/storage/databases/main/devices.py
    |synapse/storage/schema/
 
    |tests/api/test_auth.py

+ 1 - 2
synapse/replication/slave/storage/devices.py

@@ -19,13 +19,12 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
 from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.devices import DeviceWorkerStore
-from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 
-class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
+class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
     def __init__(
         self,
         database: DatabasePool,

+ 1 - 0
synapse/storage/databases/main/__init__.py

@@ -195,6 +195,7 @@ class DataStore(
         self._min_stream_order_on_start = self.get_room_min_stream_ordering()
 
     def get_device_stream_token(self) -> int:
+        # TODO: shouldn't this be moved to `DeviceWorkerStore`?
         return self._device_list_id_gen.get_current_token()
 
     async def get_users(self) -> List[JsonDict]:

+ 33 - 18
synapse/storage/databases/main/devices.py

@@ -28,6 +28,8 @@ from typing import (
     cast,
 )
 
+from typing_extensions import Literal
+
 from synapse.api.constants import EduTypes
 from synapse.api.errors import Codes, StoreError
 from synapse.logging.opentracing import (
@@ -44,6 +46,8 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.types import Cursor
 from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -65,7 +69,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
 BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
 
 
-class DeviceWorkerStore(SQLBaseStore):
+class DeviceWorkerStore(EndToEndKeyWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -74,7 +78,9 @@ class DeviceWorkerStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        device_list_max = self._device_list_id_gen.get_current_token()
+        # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
+        # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
+        device_list_max = self._device_list_id_gen.get_current_token()  # type: ignore[attr-defined]
         device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
             db_conn,
             "device_lists_stream",
@@ -339,8 +345,9 @@ class DeviceWorkerStore(SQLBaseStore):
         # following this stream later.
         last_processed_stream_id = from_stream_id
 
-        query_map = {}
-        cross_signing_keys_by_user = {}
+        # A map of (user ID, device ID) to (stream ID, context).
+        query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
+        cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {}
         for user_id, device_id, update_stream_id, update_context in updates:
             # Calculate the remaining length budget.
             # Note that, for now, each entry in `cross_signing_keys_by_user`
@@ -596,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore):
             txn=txn,
             table="device_lists_outbound_last_success",
             key_names=("destination", "user_id"),
-            key_values=((destination, user_id) for user_id, _ in rows),
+            key_values=[(destination, user_id) for user_id, _ in rows],
             value_names=("stream_id",),
             value_values=((stream_id,) for _, stream_id in rows),
         )
@@ -621,7 +628,9 @@ class DeviceWorkerStore(SQLBaseStore):
             The new stream ID.
         """
 
-        async with self._device_list_id_gen.get_next() as stream_id:
+        # TODO: this looks like it's _writing_. Should this be on DeviceStore rather
+        #  than DeviceWorkerStore?
+        async with self._device_list_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
             await self.db_pool.runInteraction(
                 "add_user_sig_change_to_streams",
                 self._add_user_signature_change_txn,
@@ -686,7 +695,7 @@ class DeviceWorkerStore(SQLBaseStore):
         } - users_needing_resync
         user_ids_not_in_cache = user_ids - user_ids_in_cache
 
-        results = {}
+        results: Dict[str, Dict[str, JsonDict]] = {}
         for user_id, device_id in query_list:
             if user_id not in user_ids_in_cache:
                 continue
@@ -727,7 +736,7 @@ class DeviceWorkerStore(SQLBaseStore):
     def get_cached_device_list_changes(
         self,
         from_key: int,
-    ) -> Optional[Set[str]]:
+    ) -> Optional[List[str]]:
         """Get set of users whose devices have changed since `from_key`, or None
         if that information is not in our cache.
         """
@@ -737,7 +746,7 @@ class DeviceWorkerStore(SQLBaseStore):
     async def get_users_whose_devices_changed(
         self,
         from_key: int,
-        user_ids: Optional[Iterable[str]] = None,
+        user_ids: Optional[Collection[str]] = None,
         to_key: Optional[int] = None,
     ) -> Set[str]:
         """Get set of users whose devices have changed since `from_key` that
@@ -757,6 +766,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         # Get set of users who *may* have changed. Users not in the returned
         # list have definitely not changed.
+        user_ids_to_check: Optional[Collection[str]]
         if user_ids is None:
             # Get set of all users that have had device list changes since 'from_key'
             user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
@@ -772,7 +782,7 @@ class DeviceWorkerStore(SQLBaseStore):
             return set()
 
         def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
-            changes = set()
+            changes: Set[str] = set()
 
             stream_id_where_clause = "stream_id > ?"
             sql_args = [from_key]
@@ -788,6 +798,9 @@ class DeviceWorkerStore(SQLBaseStore):
             """
 
             # Query device changes with a batch of users at a time
+            # Assertion for mypy's benefit; see also
+            # https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
+            assert user_ids_to_check is not None
             for chunk in batch_iter(user_ids_to_check, 100):
                 clause, args = make_in_list_sql_clause(
                     txn.database_engine, "user_id", chunk
@@ -854,7 +867,9 @@ class DeviceWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def _get_all_device_list_changes_for_remotes(txn):
+        def _get_all_device_list_changes_for_remotes(
+            txn: Cursor,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             # This query Does The Right Thing where it'll correctly apply the
             # bounds to the inner queries.
             sql = """
@@ -913,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
             desc="get_device_list_last_stream_id_for_remotes",
         )
 
-        results = {user_id: None for user_id in user_ids}
+        results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
         results.update({row["user_id"]: row["stream_id"] for row in rows})
 
         return results
@@ -1337,9 +1352,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         # Map of (user_id, device_id) -> bool. If there is an entry that implies
         # the device exists.
-        self.device_id_exists_cache = LruCache(
-            cache_name="device_id_exists", max_size=10000
-        )
+        self.device_id_exists_cache: LruCache[
+            Tuple[str, str], Literal[True]
+        ] = LruCache(cache_name="device_id_exists", max_size=10000)
 
     async def store_device(
         self,
@@ -1651,7 +1666,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 context,
             )
 
-        async with self._device_list_id_gen.get_next_mult(
+        async with self._device_list_id_gen.get_next_mult(  # type: ignore[attr-defined]
             len(device_ids)
         ) as stream_ids:
             await self.db_pool.runInteraction(
@@ -1704,7 +1719,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         device_ids: Iterable[str],
         hosts: Collection[str],
         stream_ids: List[int],
-        context: Dict[str, str],
+        context: Optional[Dict[str, str]],
     ) -> None:
         for host in hosts:
             txn.call_after(
@@ -1875,7 +1890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 [],
             )
 
-        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+        async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:  # type: ignore[attr-defined]
             return await self.db_pool.runInteraction(
                 "add_device_list_outbound_pokes",
                 add_device_list_outbound_pokes_txn,