소스 검색

Wait for streams to catch up when processing HTTP replication. (#14820)

This should hopefully mitigate a class of races where data gets out of
sync due a HTTP replication request racing with the replication streams.
Erik Johnston 1 년 전
부모
커밋
9187fd940e

+ 1 - 0
changelog.d/14820.bugfix

@@ -0,0 +1 @@
+Fix rare races when using workers.

+ 4 - 0
synapse/handlers/federation_event.py

@@ -2259,6 +2259,10 @@ class FederationEventHandler:
                 event_and_contexts, backfilled=backfilled
             )
 
+            # After persistence we always need to notify replication there may
+            # be new data.
+            self._notifier.notify_replication()
+
             if self._ephemeral_messages_enabled:
                 for event in events:
                     # If there's an expiry timestamp on the event, schedule its expiry.

+ 88 - 9
synapse/replication/http/_base.py

@@ -17,7 +17,7 @@ import logging
 import re
 import urllib.parse
 from inspect import signature
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple
 
 from prometheus_client import Counter, Gauge
 
@@ -27,6 +27,7 @@ from twisted.web.server import Request
 from synapse.api.errors import HttpResponseException, SynapseError
 from synapse.http import RequestTimedOutError
 from synapse.http.server import HttpServer
+from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing
 from synapse.logging.opentracing import trace_with_opname
@@ -53,6 +54,9 @@ _outgoing_request_counter = Counter(
 )
 
 
+_STREAM_POSITION_KEY = "_INT_STREAM_POS"
+
+
 class ReplicationEndpoint(metaclass=abc.ABCMeta):
     """Helper base class for defining new replication HTTP endpoints.
 
@@ -94,6 +98,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             a connection error is received.
         RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when
             receiving connection errors, each will backoff exponentially longer.
+        WAIT_FOR_STREAMS (bool): Whether to wait for replication streams to
+            catch up before processing the request and/or response. Defaults to
+            True.
     """
 
     NAME: str = abc.abstractproperty()  # type: ignore
@@ -104,6 +111,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
     RETRY_ON_CONNECT_ERROR = True
     RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5  # =63s (2^6-1)
 
+    WAIT_FOR_STREAMS: ClassVar[bool] = True
+
     def __init__(self, hs: "HomeServer"):
         if self.CACHE:
             self.response_cache: ResponseCache[str] = ResponseCache(
@@ -126,6 +135,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         if hs.config.worker.worker_replication_secret:
             self._replication_secret = hs.config.worker.worker_replication_secret
 
+        self._streams = hs.get_replication_command_handler().get_streams_to_replicate()
+        self._replication = hs.get_replication_data_handler()
+        self._instance_name = hs.get_instance_name()
+
     def _check_auth(self, request: Request) -> None:
         # Get the authorization header.
         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
@@ -160,7 +173,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
     @abc.abstractmethod
     async def _handle_request(
-        self, request: Request, **kwargs: Any
+        self, request: Request, content: JsonDict, **kwargs: Any
     ) -> Tuple[int, JsonDict]:
         """Handle incoming request.
 
@@ -201,6 +214,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
         @trace_with_opname("outgoing_replication_request")
         async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
+            # We have to pull these out here to avoid circular dependencies...
+            streams = hs.get_replication_command_handler().get_streams_to_replicate()
+            replication = hs.get_replication_data_handler()
+
             with outgoing_gauge.track_inprogress():
                 if instance_name == local_instance_name:
                     raise Exception("Trying to send HTTP request to self")
@@ -219,6 +236,24 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
 
                 data = await cls._serialize_payload(**kwargs)
 
+                if cls.METHOD != "GET" and cls.WAIT_FOR_STREAMS:
+                    # Include the current stream positions that we write to. We
+                    # don't do this for GETs as they don't have a body, and we
+                    # generally assume that a GET won't rely on data we have
+                    # written.
+                    if _STREAM_POSITION_KEY in data:
+                        raise Exception(
+                            "data to send contains %r key", _STREAM_POSITION_KEY
+                        )
+
+                    data[_STREAM_POSITION_KEY] = {
+                        "streams": {
+                            stream.NAME: stream.current_token(local_instance_name)
+                            for stream in streams
+                        },
+                        "instance_name": local_instance_name,
+                    }
+
                 url_args = [
                     urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
                 ]
@@ -308,6 +343,18 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
                     ) from e
 
                 _outgoing_request_counter.labels(cls.NAME, 200).inc()
+
+                # Wait on any streams that the remote may have written to.
+                for stream_name, position in result.get(
+                    _STREAM_POSITION_KEY, {}
+                ).items():
+                    await replication.wait_for_stream_position(
+                        instance_name=instance_name,
+                        stream_name=stream_name,
+                        position=position,
+                        raise_on_timeout=False,
+                    )
+
                 return result
 
         return send_request
@@ -353,6 +400,23 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
         if self._replication_secret:
             self._check_auth(request)
 
+        if self.METHOD == "GET":
+            # GET APIs always have an empty body.
+            content = {}
+        else:
+            content = parse_json_object_from_request(request)
+
+        # Wait on any streams that the remote may have written to.
+        for stream_name, position in content.get(_STREAM_POSITION_KEY, {"streams": {}})[
+            "streams"
+        ].items():
+            await self._replication.wait_for_stream_position(
+                instance_name=content[_STREAM_POSITION_KEY]["instance_name"],
+                stream_name=stream_name,
+                position=position,
+                raise_on_timeout=False,
+            )
+
         if self.CACHE:
             txn_id = kwargs.pop("txn_id")
 
@@ -361,13 +425,28 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
             # correctly yet. In particular, there may be issues to do with logging
             # context lifetimes.
 
-            return await self.response_cache.wrap(
-                txn_id, self._handle_request, request, **kwargs
+            code, response = await self.response_cache.wrap(
+                txn_id, self._handle_request, request, content, **kwargs
             )
+        else:
+            # The `@cancellable` decorator may be applied to `_handle_request`. But we
+            # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
+            # so we have to set up the cancellable flag ourselves.
+            request.is_render_cancellable = is_function_cancellable(
+                self._handle_request
+            )
+
+            code, response = await self._handle_request(request, content, **kwargs)
+
+        # Return streams we may have written to in the course of processing this
+        # request.
+        if _STREAM_POSITION_KEY in response:
+            raise Exception("data to send contains %r key", _STREAM_POSITION_KEY)
 
-        # The `@cancellable` decorator may be applied to `_handle_request`. But we
-        # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
-        # so we have to set up the cancellable flag ourselves.
-        request.is_render_cancellable = is_function_cancellable(self._handle_request)
+        if self.WAIT_FOR_STREAMS:
+            response[_STREAM_POSITION_KEY] = {
+                stream.NAME: stream.current_token(self._instance_name)
+                for stream in self._streams
+            }
 
-        return await self._handle_request(request, **kwargs)
+        return code, response

+ 16 - 13
synapse/replication/http/account_data.py

@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
 from twisted.web.server import Request
 
 from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict
 
@@ -61,10 +60,8 @@ class ReplicationAddUserAccountDataRestServlet(ReplicationEndpoint):
         return payload
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, user_id: str, account_data_type: str
+        self, request: Request, content: JsonDict, user_id: str, account_data_type: str
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
-
         max_stream_id = await self.handler.add_account_data_for_user(
             user_id, account_data_type, content["content"]
         )
@@ -101,7 +98,7 @@ class ReplicationRemoveUserAccountDataRestServlet(ReplicationEndpoint):
         return {}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, user_id: str, account_data_type: str
+        self, request: Request, content: JsonDict, user_id: str, account_data_type: str
     ) -> Tuple[int, JsonDict]:
         max_stream_id = await self.handler.remove_account_data_for_user(
             user_id, account_data_type
@@ -143,10 +140,13 @@ class ReplicationAddRoomAccountDataRestServlet(ReplicationEndpoint):
         return payload
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, user_id: str, room_id: str, account_data_type: str
+        self,
+        request: Request,
+        content: JsonDict,
+        user_id: str,
+        room_id: str,
+        account_data_type: str,
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
-
         max_stream_id = await self.handler.add_account_data_to_room(
             user_id, room_id, account_data_type, content["content"]
         )
@@ -183,7 +183,12 @@ class ReplicationRemoveRoomAccountDataRestServlet(ReplicationEndpoint):
         return {}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, user_id: str, room_id: str, account_data_type: str
+        self,
+        request: Request,
+        content: JsonDict,
+        user_id: str,
+        room_id: str,
+        account_data_type: str,
     ) -> Tuple[int, JsonDict]:
         max_stream_id = await self.handler.remove_account_data_for_room(
             user_id, room_id, account_data_type
@@ -225,10 +230,8 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
         return payload
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, user_id: str, room_id: str, tag: str
+        self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
-
         max_stream_id = await self.handler.add_tag_to_room(
             user_id, room_id, tag, content["content"]
         )
@@ -266,7 +269,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
         return {}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, user_id: str, room_id: str, tag: str
+        self, request: Request, content: JsonDict, user_id: str, room_id: str, tag: str
     ) -> Tuple[int, JsonDict]:
         max_stream_id = await self.handler.remove_tag_from_room(
             user_id,

+ 3 - 7
synapse/replication/http/devices.py

@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 from twisted.web.server import Request
 
 from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
 from synapse.logging.opentracing import active_span
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict
@@ -78,7 +77,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
         return {}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, user_id: str
+        self, request: Request, content: JsonDict, user_id: str
     ) -> Tuple[int, Optional[JsonDict]]:
         user_devices = await self.device_list_updater.user_device_resync(user_id)
 
@@ -138,9 +137,8 @@ class ReplicationMultiUserDevicesResyncRestServlet(ReplicationEndpoint):
         return {"user_ids": user_ids}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request
+        self, request: Request, content: JsonDict
     ) -> Tuple[int, Dict[str, Optional[JsonDict]]]:
-        content = parse_json_object_from_request(request)
         user_ids: List[str] = content["user_ids"]
 
         logger.info("Resync for %r", user_ids)
@@ -205,10 +203,8 @@ class ReplicationUploadKeysForUserRestServlet(ReplicationEndpoint):
         }
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request
+        self, request: Request, content: JsonDict
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
-
         user_id = content["user_id"]
         device_id = content["device_id"]
         keys = content["keys"]

+ 9 - 19
synapse/replication/http/federation.py

@@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.snapshot import EventContext
 from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict
 from synapse.util.metrics import Measure
@@ -114,10 +113,8 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
 
         return payload
 
-    async def _handle_request(self, request: Request) -> Tuple[int, JsonDict]:  # type: ignore[override]
+    async def _handle_request(self, request: Request, content: JsonDict) -> Tuple[int, JsonDict]:  # type: ignore[override]
         with Measure(self.clock, "repl_fed_send_events_parse"):
-            content = parse_json_object_from_request(request)
-
             room_id = content["room_id"]
             backfilled = content["backfilled"]
 
@@ -181,13 +178,10 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
         return {"origin": origin, "content": content}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, edu_type: str
+        self, request: Request, content: JsonDict, edu_type: str
     ) -> Tuple[int, JsonDict]:
-        with Measure(self.clock, "repl_fed_send_edu_parse"):
-            content = parse_json_object_from_request(request)
-
-            origin = content["origin"]
-            edu_content = content["content"]
+        origin = content["origin"]
+        edu_content = content["content"]
 
         logger.info("Got %r edu from %s", edu_type, origin)
 
@@ -231,13 +225,10 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
         return {"args": args}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, query_type: str
+        self, request: Request, content: JsonDict, query_type: str
     ) -> Tuple[int, JsonDict]:
-        with Measure(self.clock, "repl_fed_query_parse"):
-            content = parse_json_object_from_request(request)
-
-            args = content["args"]
-            args["origin"] = content["origin"]
+        args = content["args"]
+        args["origin"] = content["origin"]
 
         logger.info("Got %r query from %s", query_type, args["origin"])
 
@@ -274,7 +265,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
         return {}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, room_id: str
+        self, request: Request, content: JsonDict, room_id: str
     ) -> Tuple[int, JsonDict]:
         await self.store.clean_room_for_join(room_id)
 
@@ -307,9 +298,8 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
         return {"room_version": room_version.identifier}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, room_id: str
+        self, request: Request, content: JsonDict, room_id: str
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
         room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
         await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
         return 200, {}

+ 1 - 4
synapse/replication/http/login.py

@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Tuple, cast
 from twisted.web.server import Request
 
 from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict
 
@@ -73,10 +72,8 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
         }
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, user_id: str
+        self, request: Request, content: JsonDict, user_id: str
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
-
         device_id = content["device_id"]
         initial_display_name = content["initial_display_name"]
         is_guest = content["is_guest"]

+ 10 - 12
synapse/replication/http/membership.py

@@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 from twisted.web.server import Request
 
 from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict, Requester, UserID
@@ -79,10 +78,8 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
         }
 
     async def _handle_request(  # type: ignore[override]
-        self, request: SynapseRequest, room_id: str, user_id: str
+        self, request: SynapseRequest, content: JsonDict, room_id: str, user_id: str
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
-
         remote_room_hosts = content["remote_room_hosts"]
         event_content = content["content"]
 
@@ -147,11 +144,10 @@ class ReplicationRemoteKnockRestServlet(ReplicationEndpoint):
     async def _handle_request(  # type: ignore[override]
         self,
         request: SynapseRequest,
+        content: JsonDict,
         room_id: str,
         user_id: str,
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
-
         remote_room_hosts = content["remote_room_hosts"]
         event_content = content["content"]
 
@@ -217,10 +213,8 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         }
 
     async def _handle_request(  # type: ignore[override]
-        self, request: SynapseRequest, invite_event_id: str
+        self, request: SynapseRequest, content: JsonDict, invite_event_id: str
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
-
         txn_id = content["txn_id"]
         event_content = content["content"]
 
@@ -285,10 +279,9 @@ class ReplicationRemoteRescindKnockRestServlet(ReplicationEndpoint):
     async def _handle_request(  # type: ignore[override]
         self,
         request: SynapseRequest,
+        content: JsonDict,
         knock_event_id: str,
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
-
         txn_id = content["txn_id"]
         event_content = content["content"]
 
@@ -347,7 +340,12 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
         return {}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, room_id: str, user_id: str, change: str
+        self,
+        request: Request,
+        content: JsonDict,
+        room_id: str,
+        user_id: str,
+        change: str,
     ) -> Tuple[int, JsonDict]:
         logger.info("user membership change: %s in %s", user_id, room_id)
 

+ 2 - 5
synapse/replication/http/presence.py

@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
 from twisted.web.server import Request
 
 from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict, UserID
 
@@ -56,7 +55,7 @@ class ReplicationBumpPresenceActiveTime(ReplicationEndpoint):
         return {}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, user_id: str
+        self, request: Request, content: JsonDict, user_id: str
     ) -> Tuple[int, JsonDict]:
         await self._presence_handler.bump_presence_active_time(
             UserID.from_string(user_id)
@@ -107,10 +106,8 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
         }
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, user_id: str
+        self, request: Request, content: JsonDict, user_id: str
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
-
         await self._presence_handler.set_state(
             UserID.from_string(user_id),
             content["state"],

+ 1 - 4
synapse/replication/http/push.py

@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Tuple
 from twisted.web.server import Request
 
 from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict
 
@@ -61,10 +60,8 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
         return payload
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, user_id: str
+        self, request: Request, content: JsonDict, user_id: str
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
-
         app_id = content["app_id"]
         pushkey = content["pushkey"]
 

+ 2 - 7
synapse/replication/http/register.py

@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional, Tuple
 from twisted.web.server import Request
 
 from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict
 
@@ -96,10 +95,8 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
         }
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, user_id: str
+        self, request: Request, content: JsonDict, user_id: str
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
-
         await self.registration_handler.check_registration_ratelimit(content["address"])
 
         # Always default admin users to approved (since it means they were created by
@@ -150,10 +147,8 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
         return {"auth_result": auth_result, "access_token": access_token}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, user_id: str
+        self, request: Request, content: JsonDict, user_id: str
     ) -> Tuple[int, JsonDict]:
-        content = parse_json_object_from_request(request)
-
         auth_result = content["auth_result"]
         access_token = content["access_token"]
 

+ 1 - 4
synapse/replication/http/send_event.py

@@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.snapshot import EventContext
 from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util.metrics import Measure
@@ -114,11 +113,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
         return payload
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, event_id: str
+        self, request: Request, content: JsonDict, event_id: str
     ) -> Tuple[int, JsonDict]:
         with Measure(self.clock, "repl_send_event_parse"):
-            content = parse_json_object_from_request(request)
-
             event_dict = content["event"]
             room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]]
             internal_metadata = content["internal_metadata"]

+ 1 - 3
synapse/replication/http/send_events.py

@@ -21,7 +21,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.snapshot import EventContext
 from synapse.http.server import HttpServer
-from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util.metrics import Measure
@@ -114,10 +113,9 @@ class ReplicationSendEventsRestServlet(ReplicationEndpoint):
         return payload
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request
+        self, request: Request, payload: JsonDict
     ) -> Tuple[int, JsonDict]:
         with Measure(self.clock, "repl_send_events_parse"):
-            payload = parse_json_object_from_request(request)
             events_and_context = []
             events = payload["events"]
 

+ 1 - 1
synapse/replication/http/state.py

@@ -57,7 +57,7 @@ class ReplicationUpdateCurrentStateRestServlet(ReplicationEndpoint):
         return {}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, room_id: str
+        self, request: Request, content: JsonDict, room_id: str
     ) -> Tuple[int, JsonDict]:
         writer_instance = self._events_shard_config.get_instance(room_id)
         if writer_instance != self._instance_name:

+ 5 - 1
synapse/replication/http/streams.py

@@ -54,6 +54,10 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
     PATH_ARGS = ("stream_name",)
     METHOD = "GET"
 
+    # We don't want to wait for replication streams to catch up, as this gets
+    # called in the process of catching replication streams up.
+    WAIT_FOR_STREAMS = False
+
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
@@ -67,7 +71,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
         return {"from_token": from_token, "upto_token": upto_token}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request, stream_name: str
+        self, request: Request, content: JsonDict, stream_name: str
     ) -> Tuple[int, JsonDict]:
         stream = self.streams.get(stream_name)
         if stream is None:

+ 23 - 2
synapse/replication/tcp/client.py

@@ -16,6 +16,7 @@
 import logging
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
+from twisted.internet import defer
 from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IAddress, IConnector
 from twisted.internet.protocol import ReconnectingClientFactory
@@ -314,10 +315,21 @@ class ReplicationDataHandler:
             self.send_handler.wake_destination(server)
 
     async def wait_for_stream_position(
-        self, instance_name: str, stream_name: str, position: int
+        self,
+        instance_name: str,
+        stream_name: str,
+        position: int,
+        raise_on_timeout: bool = True,
     ) -> None:
         """Wait until this instance has received updates up to and including
         the given stream position.
+
+        Args:
+            instance_name
+            stream_name
+            position
+            raise_on_timeout: Whether to raise an exception if we time out
+                waiting for the updates, or if we log an error and return.
         """
 
         if instance_name == self._instance_name:
@@ -345,7 +357,16 @@ class ReplicationDataHandler:
         # We measure here to get in flight counts and average waiting time.
         with Measure(self._clock, "repl.wait_for_stream_position"):
             logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
-            await make_deferred_yieldable(deferred)
+            try:
+                await make_deferred_yieldable(deferred)
+            except defer.TimeoutError:
+                logger.error("Timed out waiting for stream %s", stream_name)
+
+                if raise_on_timeout:
+                    raise
+
+                return
+
             logger.info(
                 "Finished waiting for repl stream %r to reach %s", stream_name, position
             )

+ 19 - 24
synapse/replication/tcp/resource.py

@@ -199,33 +199,28 @@ class ReplicationStreamer:
                             # The token has advanced but there is no data to
                             # send, so we send a `POSITION` to inform other
                             # workers of the updated position.
-                            if stream.NAME == EventsStream.NAME:
-                                # XXX: We only do this for the EventStream as it
-                                # turns out that e.g. account data streams share
-                                # their "current token" with each other, meaning
-                                # that it is *not* safe to send a POSITION.
-
-                                # Note: `last_token` may not *actually* be the
-                                # last token we sent out in a RDATA or POSITION.
-                                # This can happen if we sent out an RDATA for
-                                # position X when our current token was say X+1.
-                                # Other workers will see RDATA for X and then a
-                                # POSITION with last token of X+1, which will
-                                # cause them to check if there were any missing
-                                # updates between X and X+1.
-                                logger.info(
-                                    "Sending position: %s -> %s",
+
+                            # Note: `last_token` may not *actually* be the
+                            # last token we sent out in a RDATA or POSITION.
+                            # This can happen if we sent out an RDATA for
+                            # position X when our current token was say X+1.
+                            # Other workers will see RDATA for X and then a
+                            # POSITION with last token of X+1, which will
+                            # cause them to check if there were any missing
+                            # updates between X and X+1.
+                            logger.info(
+                                "Sending position: %s -> %s",
+                                stream.NAME,
+                                current_token,
+                            )
+                            self.command_handler.send_command(
+                                PositionCommand(
                                     stream.NAME,
+                                    self._instance_name,
+                                    last_token,
                                     current_token,
                                 )
-                                self.command_handler.send_command(
-                                    PositionCommand(
-                                        stream.NAME,
-                                        self._instance_name,
-                                        last_token,
-                                        current_token,
-                                    )
-                                )
+                            )
                             continue
 
                         # Some streams return multiple rows with the same stream IDs,

+ 19 - 15
synapse/storage/util/id_generators.py

@@ -378,6 +378,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             self._current_positions.values(), default=1
         )
 
+        if not writers:
+            # If there have been no explicit writers given then any instance can
+            # write to the stream. In which case, let's pre-seed our own
+            # position with the current minimum.
+            self._current_positions[self._instance_name] = self._persisted_upto_position
+
     def _load_current_ids(
         self,
         db_conn: LoggingDatabaseConnection,
@@ -695,24 +701,22 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
 
         heapq.heappush(self._known_persisted_positions, new_id)
 
-        # If we're a writer and we don't have any active writes we update our
-        # current position to the latest position seen. This allows the instance
-        # to report a recent position when asked, rather than a potentially old
-        # one (if this instance hasn't written anything for a while).
-        our_current_position = self._current_positions.get(self._instance_name)
-        if (
-            our_current_position
-            and not self._unfinished_ids
-            and not self._in_flight_fetches
-        ):
-            self._current_positions[self._instance_name] = max(
-                our_current_position, new_id
-            )
-
         # We move the current min position up if the minimum current positions
         # of all instances is higher (since by definition all positions less
         # that that have been persisted).
-        min_curr = min(self._current_positions.values(), default=0)
+        our_current_position = self._current_positions.get(self._instance_name, 0)
+        min_curr = min(
+            (
+                token
+                for name, token in self._current_positions.items()
+                if name != self._instance_name
+            ),
+            default=our_current_position,
+        )
+
+        if our_current_position and (self._unfinished_ids or self._in_flight_fetches):
+            min_curr = min(min_curr, our_current_position)
+
         self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
 
         # We now iterate through the seen positions, discarding those that are

+ 6 - 0
synapse/types/__init__.py

@@ -604,6 +604,12 @@ class RoomStreamToken:
         elif self.instance_map:
             entries = []
             for name, pos in self.instance_map.items():
+                if pos <= self.stream:
+                    # Ignore instances who are below the minimum stream position
+                    # (we might know they've advanced without seeing a recent
+                    # write from them).
+                    continue
+
                 instance_id = await store.get_id_for_instance(name)
                 entries.append(f"{instance_id}.{pos}")
 

+ 5 - 4
tests/replication/http/test__base.py

@@ -44,7 +44,7 @@ class CancellableReplicationEndpoint(ReplicationEndpoint):
 
     @cancellable
     async def _handle_request(  # type: ignore[override]
-        self, request: Request
+        self, request: Request, content: JsonDict
     ) -> Tuple[int, JsonDict]:
         await self.clock.sleep(1.0)
         return HTTPStatus.OK, {"result": True}
@@ -54,6 +54,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
     NAME = "uncancellable_sleep"
     PATH_ARGS = ()
     CACHE = False
+    WAIT_FOR_STREAMS = False
 
     def __init__(self, hs: HomeServer):
         super().__init__(hs)
@@ -64,7 +65,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
         return {}
 
     async def _handle_request(  # type: ignore[override]
-        self, request: Request
+        self, request: Request, content: JsonDict
     ) -> Tuple[int, JsonDict]:
         await self.clock.sleep(1.0)
         return HTTPStatus.OK, {"result": True}
@@ -85,7 +86,7 @@ class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
     def test_cancellable_disconnect(self) -> None:
         """Test that handlers with the `@cancellable` flag can be cancelled."""
         path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
-        channel = self.make_request("POST", path, await_result=False)
+        channel = self.make_request("POST", path, await_result=False, content={})
         test_disconnect(
             self.reactor,
             channel,
@@ -96,7 +97,7 @@ class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
     def test_uncancellable_disconnect(self) -> None:
         """Test that handlers without the `@cancellable` flag cannot be cancelled."""
         path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
-        channel = self.make_request("POST", path, await_result=False)
+        channel = self.make_request("POST", path, await_result=False, content={})
         test_disconnect(
             self.reactor,
             channel,

+ 9 - 11
tests/storage/test_id_generators.py

@@ -349,8 +349,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         # The first ID gen will notice that it can advance its token to 7 as it
         # has no in progress writes...
-        self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 7})
-        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
 
         # ... but the second ID gen doesn't know that.
@@ -366,8 +366,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
                 self.assertEqual(stream_id, 8)
 
                 self.assertEqual(
-                    first_id_gen.get_positions(), {"first": 7, "second": 7}
+                    first_id_gen.get_positions(), {"first": 3, "second": 7}
                 )
+                self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
 
         self.get_success(_get_next_async())
 
@@ -473,7 +474,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         id_gen = self._create_id_generator("first", writers=["first", "second"])
 
-        self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
+        self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
 
         self.assertEqual(id_gen.get_persisted_upto_position(), 5)
 
@@ -720,7 +721,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
 
         self.get_success(_get_next_async2())
 
-        self.assertEqual(id_gen_1.get_positions(), {"first": -2, "second": -2})
+        self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
         self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
         self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
         self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
@@ -816,15 +817,12 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
         first_id_gen = self._create_id_generator("first", writers=["first", "second"])
         second_id_gen = self._create_id_generator("second", writers=["first", "second"])
 
-        # The first ID gen will notice that it can advance its token to 7 as it
-        # has no in progress writes...
-        self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6})
-        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
+        self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
+        self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
         self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6)
         self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
 
-        # ... but the second ID gen doesn't know that.
         self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
         self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3)
         self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
-        self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
+        self.assertEqual(second_id_gen.get_persisted_upto_position(), 7)