|
@@ -16,6 +16,7 @@
|
|
|
import abc
|
|
|
from typing import (
|
|
|
TYPE_CHECKING,
|
|
|
+ Any,
|
|
|
Collection,
|
|
|
Dict,
|
|
|
Iterable,
|
|
@@ -39,6 +40,7 @@ from synapse.appservice import (
|
|
|
TransactionUnusedFallbackKeys,
|
|
|
)
|
|
|
from synapse.logging.opentracing import log_kv, set_tag, trace
|
|
|
+from synapse.replication.tcp.streams._base import DeviceListsStream
|
|
|
from synapse.storage._base import SQLBaseStore, db_to_json
|
|
|
from synapse.storage.database import (
|
|
|
DatabasePool,
|
|
@@ -104,6 +106,23 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|
|
self.hs.config.federation.allow_device_name_lookup_over_federation
|
|
|
)
|
|
|
|
|
|
+ def process_replication_rows(
|
|
|
+ self,
|
|
|
+ stream_name: str,
|
|
|
+ instance_name: str,
|
|
|
+ token: int,
|
|
|
+ rows: Iterable[Any],
|
|
|
+ ) -> None:
|
|
|
+ if stream_name == DeviceListsStream.NAME:
|
|
|
+ for row in rows:
|
|
|
+ assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
|
|
|
+ if row.entity.startswith("@"):
|
|
|
+ self._get_e2e_device_keys_for_federation_query_inner.invalidate(
|
|
|
+ (row.entity,)
|
|
|
+ )
|
|
|
+
|
|
|
+ super().process_replication_rows(stream_name, instance_name, token, rows)
|
|
|
+
|
|
|
async def get_e2e_device_keys_for_federation_query(
|
|
|
self, user_id: str
|
|
|
) -> Tuple[int, List[JsonDict]]:
|
|
@@ -114,6 +133,50 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|
|
"""
|
|
|
now_stream_id = self.get_device_stream_token()
|
|
|
|
|
|
+ # We need to be careful with the caching here, as we need to always
|
|
|
+ # return *all* persisted devices; however, there may be a lag between a
|
|
|
+ # new device being persisted and the cache being invalidated.
|
|
|
+ cached_results = (
|
|
|
+ self._get_e2e_device_keys_for_federation_query_inner.cache.get_immediate(
|
|
|
+ user_id, None
|
|
|
+ )
|
|
|
+ )
|
|
|
+ if cached_results is not None:
|
|
|
+ # Check that there have been no new devices added by another worker
|
|
|
+ # after the cache was populated. This should be quick as there should be few rows
|
|
|
+ # with a higher stream ordering.
|
|
|
+ #
|
|
|
+ # Note that we invalidate based on the device stream, so we only
|
|
|
+ # have to check for potential invalidations after the
|
|
|
+ # `now_stream_id`.
|
|
|
+ sql = """
|
|
|
+ SELECT user_id FROM device_lists_stream
|
|
|
+ WHERE stream_id >= ? AND user_id = ?
|
|
|
+ """
|
|
|
+ rows = await self.db_pool.execute(
|
|
|
+ "get_e2e_device_keys_for_federation_query_check",
|
|
|
+ None,
|
|
|
+ sql,
|
|
|
+ now_stream_id,
|
|
|
+ user_id,
|
|
|
+ )
|
|
|
+ if not rows:
|
|
|
+ # No new rows, so cache is still valid.
|
|
|
+ return now_stream_id, cached_results
|
|
|
+
|
|
|
+ # There have been new rows, so let's invalidate the cache and run the query.
|
|
|
+ self._get_e2e_device_keys_for_federation_query_inner.invalidate((user_id,))
|
|
|
+
|
|
|
+ results = await self._get_e2e_device_keys_for_federation_query_inner(user_id)
|
|
|
+
|
|
|
+ return now_stream_id, results
|
|
|
+
|
|
|
+ @cached(iterable=True)
|
|
|
+ async def _get_e2e_device_keys_for_federation_query_inner(
|
|
|
+ self, user_id: str
|
|
|
+ ) -> List[JsonDict]:
|
|
|
+ """Get all devices (with any device keys) for a user"""
|
|
|
+
|
|
|
devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
|
|
|
|
|
|
if devices:
|
|
@@ -134,9 +197,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
|
|
|
|
|
results.append(result)
|
|
|
|
|
|
- return now_stream_id, results
|
|
|
+ return results
|
|
|
|
|
|
- return now_stream_id, []
|
|
|
+ return []
|
|
|
|
|
|
@trace
|
|
|
@cancellable
|