Batch up replication requests to request the resyncing of remote users' devices. (#14716)

reivilibre · 1 year ago · commit ba4ea7d13f

+ 1 - 0
changelog.d/14716.misc

@@ -0,0 +1 @@
+Batch up replication requests to request the resyncing of remote users' devices.

+ 98 - 26
synapse/handlers/device.py

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -33,6 +34,7 @@ from synapse.api.errors import (
     Codes,
     FederationDeniedError,
     HttpResponseException,
+    InvalidAPICallError,
     RequestSendFailed,
     SynapseError,
 )
@@ -45,6 +47,7 @@ from synapse.types import (
     JsonDict,
     StreamKeyType,
     StreamToken,
+    UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
@@ -893,12 +896,47 @@ class DeviceListWorkerUpdater:
 
     def __init__(self, hs: "HomeServer"):
         from synapse.replication.http.devices import (
+            ReplicationMultiUserDevicesResyncRestServlet,
             ReplicationUserDevicesResyncRestServlet,
         )
 
         self._user_device_resync_client = (
             ReplicationUserDevicesResyncRestServlet.make_client(hs)
         )
+        self._multi_user_device_resync_client = (
+            ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
+        )
+
+    async def multi_user_device_resync(
+        self, user_ids: List[str], mark_failed_as_stale: bool = True
+    ) -> Dict[str, Optional[JsonDict]]:
+        """
+        Like `user_device_resync` but operates on multiple users **from the same origin**
+        at once.
+
+        Returns:
+            A dict mapping each user ID to the same dict that
+            `user_device_resync` returns.
+        """
+        # mark_failed_as_stale is not sent over replication; assert the default
+        # so a caller passing False fails loudly instead of being silently ignored.
+        assert mark_failed_as_stale
+
+        if not user_ids:
+            # Shortcut empty requests
+            return {}
+
+        try:
+            return await self._multi_user_device_resync_client(user_ids=user_ids)
+        except SynapseError as err:
+            if not (
+                err.code == HTTPStatus.NOT_FOUND and err.errcode == Codes.UNRECOGNIZED
+            ):
+                raise
+
+            # Fall back to single requests
+            result: Dict[str, Optional[JsonDict]] = {}
+            for user_id in user_ids:
+                result[user_id] = await self._user_device_resync_client(user_id=user_id)
+            return result
 
     async def user_device_resync(
         self, user_id: str, mark_failed_as_stale: bool = True
@@ -913,8 +951,10 @@ class DeviceListWorkerUpdater:
             A dict with device info as under the "devices" in the result of this
             request:
             https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+            None when we weren't able to fetch the device info for some reason,
+            e.g. due to a connection problem.
         """
-        return await self._user_device_resync_client(user_id=user_id)
+        return (await self.multi_user_device_resync([user_id]))[user_id]
 
 
 class DeviceListUpdater(DeviceListWorkerUpdater):
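On a worker, `user_device_resync` is now a thin wrapper over the batched call. A minimal usage sketch, assuming `updater` is the `DeviceListWorkerUpdater` held by the device handler (names hypothetical):

    # A single-user resync is just the batched call with a one-element list:
    devices = await updater.user_device_resync("@alice:example.org")

    # ...which is equivalent to:
    devices = (await updater.multi_user_device_resync(["@alice:example.org"]))[
        "@alice:example.org"
    ]

    # Batching users from the same remote server into one replication
    # request is the point of this change:
    device_map = await updater.multi_user_device_resync(
        ["@alice:remote.example", "@bob:remote.example"]
    )

If the master is still running an older Synapse without the new endpoint, the `except SynapseError` branch above matches the 404/`M_UNRECOGNIZED` response and falls back to one request per user, so rolling upgrades keep working.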
@@ -1160,19 +1200,66 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
             # Allow future calls to retry resyncing out-of-sync device lists.
             self._resync_retry_in_progress = False
 
+    async def multi_user_device_resync(
+        self, user_ids: List[str], mark_failed_as_stale: bool = True
+    ) -> Dict[str, Optional[JsonDict]]:
+        """
+        Like `user_device_resync` but operates on multiple users **from the same origin**
+        at once.
+
+        Returns:
+            A dict mapping each user ID to the same dict that
+            `user_device_resync` returns.
+        """
+        if not user_ids:
+            return {}
+
+        origins = {UserID.from_string(user_id).domain for user_id in user_ids}
+
+        if len(origins) != 1:
+            raise InvalidAPICallError(f"Only one origin permitted, got {origins!r}")
+
+        result = {}
+        failed = set()
+        # TODO(Perf): Actually batch these up
+        for user_id in user_ids:
+            user_result, user_failed = await self._user_device_resync_returning_failed(
+                user_id
+            )
+            result[user_id] = user_result
+            if user_failed:
+                failed.add(user_id)
+
+        if mark_failed_as_stale:
+            await self.store.mark_remote_users_device_caches_as_stale(failed)
+
+        return result
+
     async def user_device_resync(
         self, user_id: str, mark_failed_as_stale: bool = True
     ) -> Optional[JsonDict]:
+        result, failed = await self._user_device_resync_returning_failed(user_id)
+
+        if failed and mark_failed_as_stale:
+            # Mark the remote user's device list as stale so we know we need to retry
+            # it later.
+            await self.store.mark_remote_users_device_caches_as_stale((user_id,))
+
+        return result
+
+    async def _user_device_resync_returning_failed(
+        self, user_id: str
+    ) -> Tuple[Optional[JsonDict], bool]:
         """Fetches all devices for a user and updates the device cache with them.
 
         Args:
             user_id: The user's id whose device_list will be updated.
-            mark_failed_as_stale: Whether to mark the user's device list as stale
-                if the attempt to resync failed.
         Returns:
-            A dict with device info as under the "devices" in the result of this
-            request:
-            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+            - A dict with device info as under the "devices" in the result of this
+              request:
+              https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+              None when we weren't able to fetch the device info for some reason,
+              e.g. due to a connection problem.
+            - True iff the resync failed and the device list should be marked as stale.
         """
         logger.debug("Attempting to resync the device list for %s", user_id)
         log_kv({"message": "Doing resync to update device list."})
@@ -1181,12 +1268,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
         try:
             result = await self.federation.query_user_devices(origin, user_id)
         except NotRetryingDestination:
-            if mark_failed_as_stale:
-                # Mark the remote user's device list as stale so we know we need to retry
-                # it later.
-                await self.store.mark_remote_user_device_cache_as_stale(user_id)
-
-            return None
+            return None, True
         except (RequestSendFailed, HttpResponseException) as e:
             logger.warning(
                 "Failed to handle device list update for %s: %s",
@@ -1194,23 +1276,18 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
                 e,
             )
 
-            if mark_failed_as_stale:
-                # Mark the remote user's device list as stale so we know we need to retry
-                # it later.
-                await self.store.mark_remote_user_device_cache_as_stale(user_id)
-
             # We abort on exceptions rather than accepting the update
             # as otherwise synapse will 'forget' that its device list
             # is out of date. If we bail then we will retry the resync
             # next time we get a device list update for this user_id.
             # This makes it more likely that the device lists will
             # eventually become consistent.
-            return None
+            return None, True
         except FederationDeniedError as e:
             set_tag("error", True)
             log_kv({"reason": "FederationDeniedError"})
             logger.info(e)
-            return None
+            return None, False
         except Exception as e:
             set_tag("error", True)
             log_kv(
@@ -1218,12 +1295,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
             )
             logger.exception("Failed to handle device list update for %s", user_id)
 
-            if mark_failed_as_stale:
-                # Mark the remote user's device list as stale so we know we need to retry
-                # it later.
-                await self.store.mark_remote_user_device_cache_as_stale(user_id)
-
-            return None
+            return None, True
         log_kv({"result": result})
         stream_id = result["stream_id"]
         devices = result["devices"]
@@ -1305,7 +1377,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
         # point.
         self._seen_updates[user_id] = {stream_id}
 
-        return result
+        return result, False
 
     async def process_cross_signing_key_update(
         self,
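The refactor also pulls the "mark as stale" side effect out of the fetch path: `_user_device_resync_returning_failed` only reports whether the resync failed, and callers decide when to write the stale markers. This is what lets `multi_user_device_resync` collect every failure and mark them in a single storage call; a sketch of the pattern (mirroring the loop above, not verbatim code):

    failed: Set[str] = set()
    results: Dict[str, Optional[JsonDict]] = {}
    for user_id in user_ids:
        result, user_failed = await self._user_device_resync_returning_failed(user_id)
        results[user_id] = result
        if user_failed:
            failed.add(user_id)

    # One storage round-trip for all failed users, instead of one upsert per user.
    await self.store.mark_remote_users_device_caches_as_stale(failed)

Note that `FederationDeniedError` returns `(None, False)`: a federation whitelist rejection is permanent, so there is nothing to gain from marking the cache stale and retrying.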

+ 1 - 1
synapse/handlers/devicemessage.py

@@ -195,7 +195,7 @@ class DeviceMessageHandler:
                 sender_user_id,
                 unknown_devices,
             )
-            await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
+            await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,))
 
             # Immediately attempt a resync in the background
             run_in_background(self._user_device_resync, user_id=sender_user_id)

+ 55 - 38
synapse/handlers/e2e_keys.py

@@ -36,8 +36,8 @@ from synapse.types import (
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
-from synapse.util import json_decoder, unwrapFirstError
-from synapse.util.async_helpers import Linearizer, delay_cancellation
+from synapse.util import json_decoder
+from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.cancellation import cancellable
 from synapse.util.retryutils import NotRetryingDestination
 
@@ -238,24 +238,28 @@ class E2eKeysHandler:
             # Now fetch any devices that we don't have in our cache
             # TODO It might make sense to propagate cancellations into the
             #      deferreds which are querying remote homeservers.
-            await make_deferred_yieldable(
-                delay_cancellation(
-                    defer.gatherResults(
-                        [
-                            run_in_background(
-                                self._query_devices_for_destination,
-                                results,
-                                cross_signing_keys,
-                                failures,
-                                destination,
-                                queries,
-                                timeout,
-                            )
-                            for destination, queries in remote_queries_not_in_cache.items()
-                        ],
-                        consumeErrors=True,
-                    ).addErrback(unwrapFirstError)
+            logger.debug(
+                "%d destinations to query devices for", len(remote_queries_not_in_cache)
+            )
+
+            async def _query(
+                destination_queries: Tuple[str, Dict[str, Iterable[str]]]
+            ) -> None:
+                destination, queries = destination_queries
+                return await self._query_devices_for_destination(
+                    results,
+                    cross_signing_keys,
+                    failures,
+                    destination,
+                    queries,
+                    timeout,
                 )
+
+            await concurrently_execute(
+                _query,
+                remote_queries_not_in_cache.items(),
+                10,
+                delay_cancellation=True,
             )
 
             ret = {"device_keys": results, "failures": failures}
@@ -300,28 +304,41 @@ class E2eKeysHandler:
         # queries. We use the more efficient batched query_client_keys for all
         # remaining users
         user_ids_updated = []
-        for (user_id, device_list) in destination_query.items():
-            if user_id in user_ids_updated:
-                continue
 
-            if device_list:
-                continue
+        # Perform a user device resync for each user only once, and only if:
+        # - they have an empty device_list
+        # - they are in some rooms that this server can see
+        users_to_resync_devices = {
+            user_id
+            for (user_id, device_list) in destination_query.items()
+            if (not device_list) and (await self.store.get_rooms_for_user(user_id))
+        }
 
-            room_ids = await self.store.get_rooms_for_user(user_id)
-            if not room_ids:
-                continue
+        logger.debug(
+            "%d users to resync devices for from destination %s",
+            len(users_to_resync_devices),
+            destination,
+        )
 
-            # We've decided we're sharing a room with this user and should
-            # probably be tracking their device lists. However, we haven't
-            # done an initial sync on the device list so we do it now.
-            try:
-                resync_results = (
-                    await self.device_handler.device_list_updater.user_device_resync(
-                        user_id
-                    )
+        try:
+            user_resync_results = (
+                await self.device_handler.device_list_updater.multi_user_device_resync(
+                    list(users_to_resync_devices)
                 )
+            )
+            for user_id in users_to_resync_devices:
+                resync_results = user_resync_results[user_id]
+
                 if resync_results is None:
-                    raise ValueError("Device resync failed")
+                    # TODO: It's weird that we'll store a failure against a
+                    #       destination, yet continue processing users from that
+                    #       destination.
+                    #       We might want to consider changing this, but for now
+                    #       I'm leaving it as I found it.
+                    failures[destination] = _exception_to_failure(
+                        ValueError(f"Device resync failed for {user_id!r}")
+                    )
+                    continue
 
                 # Add the device keys to the results.
                 user_devices = resync_results["devices"]
@@ -339,8 +356,8 @@ class E2eKeysHandler:
 
                 if self_signing_key:
                     cross_signing_keys["self_signing_keys"][user_id] = self_signing_key
-            except Exception as e:
-                failures[destination] = _exception_to_failure(e)
+        except Exception as e:
+            failures[destination] = _exception_to_failure(e)
 
         if len(destination_query) == len(user_ids_updated):
             # We've updated all the users in the query and we do not need to

+ 1 - 1
synapse/handlers/federation_event.py

@@ -1423,7 +1423,7 @@ class FederationEventHandler:
         """
 
         try:
-            await self._store.mark_remote_user_device_cache_as_stale(sender)
+            await self._store.mark_remote_users_device_caches_as_stale((sender,))
 
             # Immediately attempt a resync in the background
             if self._config.worker.worker_app:

+ 73 - 1
synapse/replication/http/devices.py

@@ -13,12 +13,13 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 from twisted.web.server import Request
 
 from synapse.http.server import HttpServer
 from synapse.http.servlet import parse_json_object_from_request
+from synapse.logging.opentracing import active_span
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict
 
@@ -84,6 +85,76 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
         return 200, user_devices
 
 
+class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
+    """Ask master to resync the device list for multiple users from the same
+    remote server by contacting their server.
+
+    This must happen on master so that the results can be correctly cached in
+    the database and streamed to workers.
+
+    Request format:
+
+        POST /_synapse/replication/multi_user_device_resync
+
+        {
+            "user_ids": ["@alice:example.org", "@bob:example.org", ...]
+        }
+
+    The response is roughly equivalent to the
+    `/_matrix/federation/v1/user/devices/:user_id` response, but wrapped in a
+    map from user ID to per-user response, e.g.:
+
+        {
+            "@alice:example.org": {
+                "devices": [
+                    {
+                        "device_id": "JLAFKJWSCS",
+                        "keys": { ... },
+                        "device_display_name": "Alice's Mobile Phone"
+                    }
+                ]
+            },
+            ...
+        }
+    """
+
+    NAME = "multi_user_device_resync"
+    PATH_ARGS = ()
+    CACHE = False
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
+        from synapse.handlers.device import DeviceHandler
+
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.device_list_updater = handler.device_list_updater
+
+        self.store = hs.get_datastores().main
+        self.clock = hs.get_clock()
+
+    @staticmethod
+    async def _serialize_payload(user_ids: List[str]) -> JsonDict:  # type: ignore[override]
+        return {"user_ids": user_ids}
+
+    async def _handle_request(  # type: ignore[override]
+        self, request: Request
+    ) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
+        content = parse_json_object_from_request(request)
+        user_ids: List[str] = content["user_ids"]
+
+        logger.info("Resync for %r", user_ids)
+        span = active_span()
+        if span:
+            span.set_tag("user_ids", f"{user_ids!r}")
+
+        multi_user_devices = await self.device_list_updater.multi_user_device_resync(
+            user_ids
+        )
+
+        return 200, multi_user_devices
+
+
 class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
     """Ask master to upload keys for the user and send them out over federation to
     update other servers.
@@ -151,4 +222,5 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
 
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
+    ReplicationMultiUserDevicesResyncRestServlet(hs).register(http_server)
     ReplicationUploadKeysForUserRestServlet(hs).register(http_server)
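Workers reach the new endpoint through the client generated by `ReplicationEndpoint.make_client`, as `DeviceListWorkerUpdater.__init__` does above. A minimal sketch, assuming `hs` is a worker's `HomeServer`:

    from synapse.replication.http.devices import (
        ReplicationMultiUserDevicesResyncRestServlet,
    )

    client = ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)

    # Issues POST /_synapse/replication/multi_user_device_resync to the
    # master, returning a dict that maps each user ID to its resync result
    # (or None where the resync failed for that user).
    device_map = await client(user_ids=["@alice:example.org", "@bob:example.org"])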

+ 22 - 8
synapse/storage/databases/main/devices.py

@@ -54,7 +54,7 @@ from synapse.storage.util.id_generators import (
     AbstractStreamIdTracker,
     StreamIdGenerator,
 )
-from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
+from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
@@ -1069,16 +1069,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
         return {row["user_id"] for row in rows}
 
-    async def mark_remote_user_device_cache_as_stale(self, user_id: str) -> None:
+    async def mark_remote_users_device_caches_as_stale(
+        self, user_ids: StrCollection
+    ) -> None:
         """Records that the server has reason to believe the cache of the devices
         for the remote users is out of date.
         """
-        await self.db_pool.simple_upsert(
-            table="device_lists_remote_resync",
-            keyvalues={"user_id": user_id},
-            values={},
-            insertion_values={"added_ts": self._clock.time_msec()},
-            desc="mark_remote_user_device_cache_as_stale",
+
+        def _mark_remote_users_device_caches_as_stale_txn(
+            txn: LoggingTransaction,
+        ) -> None:
+            # TODO add insertion_values support to simple_upsert_many and use
+            #      that!
+            for user_id in user_ids:
+                self.db_pool.simple_upsert_txn(
+                    txn,
+                    table="device_lists_remote_resync",
+                    keyvalues={"user_id": user_id},
+                    values={},
+                    insertion_values={"added_ts": self._clock.time_msec()},
+                )
+
+        await self.db_pool.runInteraction(
+            "mark_remote_users_device_caches_as_stale",
+            _mark_remote_users_device_caches_as_stale_txn,
         )
 
     async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None:
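The storage change folds the per-user upserts into a single transaction, so marking N users stale costs one `runInteraction` instead of N. Because `added_ts` is passed as `insertion_values`, it is written only when the row is first created: re-marking an already-stale user preserves the original timestamp. A usage sketch, assuming `store` is the main datastore:

    # Marks both users' device caches stale in one database transaction.
    # added_ts is set only on first insertion, so a user who was already
    # marked stale keeps the original timestamp.
    await store.mark_remote_users_device_caches_as_stale(
        ("@alice:remote.example", "@bob:remote.example")
    )

The TODO records the remaining step: once `simple_upsert_many` grows `insertion_values` support, the per-user loop inside the transaction can collapse into a single upsert.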

+ 4 - 0
synapse/types/__init__.py

@@ -77,6 +77,10 @@ JsonMapping = Mapping[str, Any]
 # A JSON-serialisable object.
 JsonSerializable = object
 
+# Collection[str] that does not include str itself; str being a Sequence[str]
+# is very misleading and results in bugs.
+StrCollection = Union[Tuple[str, ...], List[str], Set[str]]
+
 
 # Note that this seems to require inheriting *directly* from Interface in order
 # for mypy-zope to realize it is an interface.
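`StrCollection` exists because `str` is itself a `Sequence[str]`: a bare user ID passed where a collection is expected type-checks fine but iterates character by character. A short illustration of the bug it prevents:

    from typing import Collection

    user_ids: Collection[str] = "@alice:example.org"  # mypy accepts this!
    for user_id in user_ids:
        print(user_id)  # "@", "a", "l", ... one character at a time

    # With StrCollection, the bare string is rejected at type-check time:
    # user_ids: StrCollection = "@alice:example.org"  # mypy error

This is why callers of the new `mark_remote_users_device_caches_as_stale` pass single users as one-element tuples such as `(sender_user_id,)`.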

+ 51 - 4
synapse/util/async_helpers.py

@@ -205,7 +205,10 @@ T = TypeVar("T")
 
 
 async def concurrently_execute(
-    func: Callable[[T], Any], args: Iterable[T], limit: int
+    func: Callable[[T], Any],
+    args: Iterable[T],
+    limit: int,
+    delay_cancellation: bool = False,
 ) -> None:
     """Executes the function with each argument concurrently while limiting
     the number of concurrent executions.
@@ -215,6 +218,8 @@ async def concurrently_execute(
         args: List of arguments to pass to func, each invocation of func
             gets a single argument.
         limit: Maximum number of concurrent executions.
+        delay_cancellation: Whether to delay cancellation until after the invocations
+            have finished.
 
     Returns:
         None, when all function invocations have finished. The return values
@@ -233,9 +238,16 @@ async def concurrently_execute(
     # We use `itertools.islice` to handle the case where the number of args is
     # less than the limit, avoiding spawning unnecessary background
     # tasks.
-    await yieldable_gather_results(
-        _concurrently_execute_inner, (value for value in itertools.islice(it, limit))
-    )
+    if delay_cancellation:
+        await yieldable_gather_results_delaying_cancellation(
+            _concurrently_execute_inner,
+            (value for value in itertools.islice(it, limit)),
+        )
+    else:
+        await yieldable_gather_results(
+            _concurrently_execute_inner,
+            (value for value in itertools.islice(it, limit)),
+        )
 
 
 P = ParamSpec("P")
@@ -292,6 +304,41 @@ async def yieldable_gather_results(
         raise dfe.subFailure.value from None
 
 
+async def yieldable_gather_results_delaying_cancellation(
+    func: Callable[Concatenate[T, P], Awaitable[R]],
+    iter: Iterable[T],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> List[R]:
+    """Executes the function with each argument concurrently.
+    Cancellation is delayed until after all the results have been gathered.
+
+    See `yieldable_gather_results`.
+
+    Args:
+        func: Function to execute that returns a Deferred
+        iter: An iterable that yields items that get passed as the first
+            argument to the function
+        *args: Arguments to be passed to each call to func
+        **kwargs: Keyword arguments to be passed to each call to func
+
+    Returns:
+        A list containing the results of the function
+    """
+    try:
+        return await make_deferred_yieldable(
+            delay_cancellation(
+                defer.gatherResults(
+                    [run_in_background(func, item, *args, **kwargs) for item in iter],  # type: ignore[arg-type]
+                    consumeErrors=True,
+                )
+            )
+        )
+    except defer.FirstError as dfe:
+        assert isinstance(dfe.subFailure.value, BaseException)
+        raise dfe.subFailure.value from None
+
+
 T1 = TypeVar("T1")
 T2 = TypeVar("T2")
 T3 = TypeVar("T3")
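`yieldable_gather_results_delaying_cancellation` is the cancellation-tolerant twin of `yieldable_gather_results`: it wraps the gather in `delay_cancellation`, so a cancelled caller waits for in-flight work to finish before the `CancelledError` propagates. A minimal sketch (function names hypothetical):

    from synapse.util.async_helpers import (
        yieldable_gather_results_delaying_cancellation,
    )

    async def fetch_devices(user_id: str) -> None:
        ...  # a remote call we do not want to abandon half-way

    # Even if the awaiting request is cancelled, every fetch_devices call
    # runs to completion before the cancellation is re-raised.
    await yieldable_gather_results_delaying_cancellation(
        fetch_devices, ["@alice:example.org", "@bob:example.org"]
    )

This preserves the behaviour the old `e2e_keys` code got from its hand-rolled `delay_cancellation(defer.gatherResults(...))`, while letting `concurrently_execute` layer the concurrency limit on top.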