|
@@ -28,6 +28,8 @@ from typing import (
|
|
|
cast,
|
|
|
)
|
|
|
|
|
|
+from typing_extensions import Literal
|
|
|
+
|
|
|
from synapse.api.constants import EduTypes
|
|
|
from synapse.api.errors import Codes, StoreError
|
|
|
from synapse.logging.opentracing import (
|
|
@@ -44,6 +46,8 @@ from synapse.storage.database import (
|
|
|
LoggingTransaction,
|
|
|
make_tuple_comparison_clause,
|
|
|
)
|
|
|
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
|
|
|
+from synapse.storage.types import Cursor
|
|
|
from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
|
|
|
from synapse.util import json_decoder, json_encoder
|
|
|
from synapse.util.caches.descriptors import cached, cachedList
|
|
@@ -65,7 +69,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
|
|
|
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
|
|
|
|
|
|
|
|
|
-class DeviceWorkerStore(SQLBaseStore):
|
|
|
+class DeviceWorkerStore(EndToEndKeyWorkerStore):
|
|
|
def __init__(
|
|
|
self,
|
|
|
database: DatabasePool,
|
|
@@ -74,7 +78,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
):
|
|
|
super().__init__(database, db_conn, hs)
|
|
|
|
|
|
- device_list_max = self._device_list_id_gen.get_current_token()
|
|
|
+ # Type-ignore: _device_list_id_gen is mixed in from either DataStore (as a
|
|
|
+ # StreamIdGenerator) or SlavedDataStore (as a SlavedIdTracker).
|
|
|
+ device_list_max = self._device_list_id_gen.get_current_token() # type: ignore[attr-defined]
|
|
|
device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
|
|
|
db_conn,
|
|
|
"device_lists_stream",
|
|
@@ -339,8 +345,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
# following this stream later.
|
|
|
last_processed_stream_id = from_stream_id
|
|
|
|
|
|
- query_map = {}
|
|
|
- cross_signing_keys_by_user = {}
|
|
|
+ # A map of (user ID, device ID) to (stream ID, context).
|
|
|
+ query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
|
|
|
+ cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {}
|
|
|
for user_id, device_id, update_stream_id, update_context in updates:
|
|
|
# Calculate the remaining length budget.
|
|
|
# Note that, for now, each entry in `cross_signing_keys_by_user`
|
|
@@ -596,7 +603,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
txn=txn,
|
|
|
table="device_lists_outbound_last_success",
|
|
|
key_names=("destination", "user_id"),
|
|
|
- key_values=((destination, user_id) for user_id, _ in rows),
|
|
|
+ key_values=[(destination, user_id) for user_id, _ in rows],
|
|
|
value_names=("stream_id",),
|
|
|
value_values=((stream_id,) for _, stream_id in rows),
|
|
|
)
|
|
@@ -621,7 +628,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
The new stream ID.
|
|
|
"""
|
|
|
|
|
|
- async with self._device_list_id_gen.get_next() as stream_id:
|
|
|
+ # TODO: this looks like it's _writing_. Should this be on DeviceStore rather
|
|
|
+ # than DeviceWorkerStore?
|
|
|
+ async with self._device_list_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
|
|
|
await self.db_pool.runInteraction(
|
|
|
"add_user_sig_change_to_streams",
|
|
|
self._add_user_signature_change_txn,
|
|
@@ -686,7 +695,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
} - users_needing_resync
|
|
|
user_ids_not_in_cache = user_ids - user_ids_in_cache
|
|
|
|
|
|
- results = {}
|
|
|
+ results: Dict[str, Dict[str, JsonDict]] = {}
|
|
|
for user_id, device_id in query_list:
|
|
|
if user_id not in user_ids_in_cache:
|
|
|
continue
|
|
@@ -727,7 +736,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
def get_cached_device_list_changes(
|
|
|
self,
|
|
|
from_key: int,
|
|
|
- ) -> Optional[Set[str]]:
|
|
|
+ ) -> Optional[List[str]]:
|
|
|
"""Get set of users whose devices have changed since `from_key`, or None
|
|
|
if that information is not in our cache.
|
|
|
"""
|
|
@@ -737,7 +746,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
async def get_users_whose_devices_changed(
|
|
|
self,
|
|
|
from_key: int,
|
|
|
- user_ids: Optional[Iterable[str]] = None,
|
|
|
+ user_ids: Optional[Collection[str]] = None,
|
|
|
to_key: Optional[int] = None,
|
|
|
) -> Set[str]:
|
|
|
"""Get set of users whose devices have changed since `from_key` that
|
|
@@ -757,6 +766,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
"""
|
|
|
# Get set of users who *may* have changed. Users not in the returned
|
|
|
# list have definitely not changed.
|
|
|
+ user_ids_to_check: Optional[Collection[str]]
|
|
|
if user_ids is None:
|
|
|
# Get set of all users that have had device list changes since 'from_key'
|
|
|
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
|
|
@@ -772,7 +782,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
return set()
|
|
|
|
|
|
def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
|
|
|
- changes = set()
|
|
|
+ changes: Set[str] = set()
|
|
|
|
|
|
stream_id_where_clause = "stream_id > ?"
|
|
|
sql_args = [from_key]
|
|
@@ -788,6 +798,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
"""
|
|
|
|
|
|
# Query device changes with a batch of users at a time
|
|
|
+ # Assertion for mypy's benefit; see also
|
|
|
+ # https://mypy.readthedocs.io/en/stable/common_issues.html#narrowing-and-inner-functions
|
|
|
+ assert user_ids_to_check is not None
|
|
|
for chunk in batch_iter(user_ids_to_check, 100):
|
|
|
clause, args = make_in_list_sql_clause(
|
|
|
txn.database_engine, "user_id", chunk
|
|
@@ -854,7 +867,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
if last_id == current_id:
|
|
|
return [], current_id, False
|
|
|
|
|
|
- def _get_all_device_list_changes_for_remotes(txn):
|
|
|
+ def _get_all_device_list_changes_for_remotes(
|
|
|
+ txn: Cursor,
|
|
|
+ ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
|
|
# This query Does The Right Thing where it'll correctly apply the
|
|
|
# bounds to the inner queries.
|
|
|
sql = """
|
|
@@ -913,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|
|
desc="get_device_list_last_stream_id_for_remotes",
|
|
|
)
|
|
|
|
|
|
- results = {user_id: None for user_id in user_ids}
|
|
|
+ results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
|
|
|
results.update({row["user_id"]: row["stream_id"] for row in rows})
|
|
|
|
|
|
return results
|
|
@@ -1337,9 +1352,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
|
|
|
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
|
|
# the device exists.
|
|
|
- self.device_id_exists_cache = LruCache(
|
|
|
- cache_name="device_id_exists", max_size=10000
|
|
|
- )
|
|
|
+ self.device_id_exists_cache: LruCache[
|
|
|
+ Tuple[str, str], Literal[True]
|
|
|
+ ] = LruCache(cache_name="device_id_exists", max_size=10000)
|
|
|
|
|
|
async def store_device(
|
|
|
self,
|
|
@@ -1651,7 +1666,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
context,
|
|
|
)
|
|
|
|
|
|
- async with self._device_list_id_gen.get_next_mult(
|
|
|
+ async with self._device_list_id_gen.get_next_mult( # type: ignore[attr-defined]
|
|
|
len(device_ids)
|
|
|
) as stream_ids:
|
|
|
await self.db_pool.runInteraction(
|
|
@@ -1704,7 +1719,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
device_ids: Iterable[str],
|
|
|
hosts: Collection[str],
|
|
|
stream_ids: List[int],
|
|
|
- context: Dict[str, str],
|
|
|
+ context: Optional[Dict[str, str]],
|
|
|
) -> None:
|
|
|
for host in hosts:
|
|
|
txn.call_after(
|
|
@@ -1875,7 +1890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|
|
[],
|
|
|
)
|
|
|
|
|
|
- async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
|
|
|
+ async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids: # type: ignore[attr-defined]
|
|
|
return await self.db_pool.runInteraction(
|
|
|
"add_device_list_outbound_pokes",
|
|
|
add_device_list_outbound_pokes_txn,
|