|
@@ -17,7 +17,7 @@ import logging
|
|
|
import re
|
|
|
import urllib.parse
|
|
|
from inspect import signature
|
|
|
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple
|
|
|
+from typing import TYPE_CHECKING, Any, Awaitable, Callable, ClassVar, Dict, List, Tuple
|
|
|
|
|
|
from prometheus_client import Counter, Gauge
|
|
|
|
|
@@ -27,6 +27,7 @@ from twisted.web.server import Request
|
|
|
from synapse.api.errors import HttpResponseException, SynapseError
|
|
|
from synapse.http import RequestTimedOutError
|
|
|
from synapse.http.server import HttpServer
|
|
|
+from synapse.http.servlet import parse_json_object_from_request
|
|
|
from synapse.http.site import SynapseRequest
|
|
|
from synapse.logging import opentracing
|
|
|
from synapse.logging.opentracing import trace_with_opname
|
|
@@ -53,6 +54,9 @@ _outgoing_request_counter = Counter(
|
|
|
)
|
|
|
|
|
|
|
|
|
+_STREAM_POSITION_KEY = "_INT_STREAM_POS"
|
|
|
+
|
|
|
+
|
|
|
class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|
|
"""Helper base class for defining new replication HTTP endpoints.
|
|
|
|
|
@@ -94,6 +98,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|
|
a connection error is received.
|
|
|
RETRY_ON_CONNECT_ERROR_ATTEMPTS (int): Number of attempts to retry when
|
|
|
receiving connection errors, each will backoff exponentially longer.
|
|
|
+ WAIT_FOR_STREAMS (bool): Whether to wait for replication streams to
|
|
|
+ catch up before processing the request and/or response. Defaults to
|
|
|
+ True.
|
|
|
"""
|
|
|
|
|
|
NAME: str = abc.abstractproperty() # type: ignore
|
|
@@ -104,6 +111,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|
|
RETRY_ON_CONNECT_ERROR = True
|
|
|
RETRY_ON_CONNECT_ERROR_ATTEMPTS = 5 # =63s (2^6-1)
|
|
|
|
|
|
+ WAIT_FOR_STREAMS: ClassVar[bool] = True
|
|
|
+
|
|
|
def __init__(self, hs: "HomeServer"):
|
|
|
if self.CACHE:
|
|
|
self.response_cache: ResponseCache[str] = ResponseCache(
|
|
@@ -126,6 +135,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|
|
if hs.config.worker.worker_replication_secret:
|
|
|
self._replication_secret = hs.config.worker.worker_replication_secret
|
|
|
|
|
|
+ self._streams = hs.get_replication_command_handler().get_streams_to_replicate()
|
|
|
+ self._replication = hs.get_replication_data_handler()
|
|
|
+ self._instance_name = hs.get_instance_name()
|
|
|
+
|
|
|
def _check_auth(self, request: Request) -> None:
|
|
|
# Get the authorization header.
|
|
|
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
|
@@ -160,7 +173,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|
|
|
|
|
@abc.abstractmethod
|
|
|
async def _handle_request(
|
|
|
- self, request: Request, **kwargs: Any
|
|
|
+ self, request: Request, content: JsonDict, **kwargs: Any
|
|
|
) -> Tuple[int, JsonDict]:
|
|
|
"""Handle incoming request.
|
|
|
|
|
@@ -201,6 +214,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|
|
|
|
|
@trace_with_opname("outgoing_replication_request")
|
|
|
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
|
|
|
+ # We have to pull these out here to avoid circular dependencies...
|
|
|
+ streams = hs.get_replication_command_handler().get_streams_to_replicate()
|
|
|
+ replication = hs.get_replication_data_handler()
|
|
|
+
|
|
|
with outgoing_gauge.track_inprogress():
|
|
|
if instance_name == local_instance_name:
|
|
|
raise Exception("Trying to send HTTP request to self")
|
|
@@ -219,6 +236,24 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|
|
|
|
|
data = await cls._serialize_payload(**kwargs)
|
|
|
|
|
|
+ if cls.METHOD != "GET" and cls.WAIT_FOR_STREAMS:
|
|
|
+ # Include the current stream positions that we write to. We
|
|
|
+ # don't do this for GETs as they don't have a body, and we
|
|
|
+ # generally assume that a GET won't rely on data we have
|
|
|
+ # written.
|
|
|
+ if _STREAM_POSITION_KEY in data:
|
|
|
+ raise Exception(
|
|
|
+                    "data to send contains %r key" % (_STREAM_POSITION_KEY,)
|
|
|
+ )
|
|
|
+
|
|
|
+ data[_STREAM_POSITION_KEY] = {
|
|
|
+ "streams": {
|
|
|
+ stream.NAME: stream.current_token(local_instance_name)
|
|
|
+ for stream in streams
|
|
|
+ },
|
|
|
+ "instance_name": local_instance_name,
|
|
|
+ }
|
|
|
+
|
|
|
url_args = [
|
|
|
urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
|
|
|
]
|
|
@@ -308,6 +343,18 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|
|
) from e
|
|
|
|
|
|
_outgoing_request_counter.labels(cls.NAME, 200).inc()
|
|
|
+
|
|
|
+ # Wait on any streams that the remote may have written to.
|
|
|
+ for stream_name, position in result.get(
|
|
|
+ _STREAM_POSITION_KEY, {}
|
|
|
+ ).items():
|
|
|
+ await replication.wait_for_stream_position(
|
|
|
+ instance_name=instance_name,
|
|
|
+ stream_name=stream_name,
|
|
|
+ position=position,
|
|
|
+ raise_on_timeout=False,
|
|
|
+ )
|
|
|
+
|
|
|
return result
|
|
|
|
|
|
return send_request
|
|
@@ -353,6 +400,23 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|
|
if self._replication_secret:
|
|
|
self._check_auth(request)
|
|
|
|
|
|
+ if self.METHOD == "GET":
|
|
|
+ # GET APIs always have an empty body.
|
|
|
+ content = {}
|
|
|
+ else:
|
|
|
+ content = parse_json_object_from_request(request)
|
|
|
+
|
|
|
+ # Wait on any streams that the remote may have written to.
|
|
|
+ for stream_name, position in content.get(_STREAM_POSITION_KEY, {"streams": {}})[
|
|
|
+ "streams"
|
|
|
+ ].items():
|
|
|
+ await self._replication.wait_for_stream_position(
|
|
|
+ instance_name=content[_STREAM_POSITION_KEY]["instance_name"],
|
|
|
+ stream_name=stream_name,
|
|
|
+ position=position,
|
|
|
+ raise_on_timeout=False,
|
|
|
+ )
|
|
|
+
|
|
|
if self.CACHE:
|
|
|
txn_id = kwargs.pop("txn_id")
|
|
|
|
|
@@ -361,13 +425,28 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
|
|
# correctly yet. In particular, there may be issues to do with logging
|
|
|
# context lifetimes.
|
|
|
|
|
|
- return await self.response_cache.wrap(
|
|
|
- txn_id, self._handle_request, request, **kwargs
|
|
|
+ code, response = await self.response_cache.wrap(
|
|
|
+ txn_id, self._handle_request, request, content, **kwargs
|
|
|
)
|
|
|
+ else:
|
|
|
+ # The `@cancellable` decorator may be applied to `_handle_request`. But we
|
|
|
+ # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
|
|
|
+ # so we have to set up the cancellable flag ourselves.
|
|
|
+ request.is_render_cancellable = is_function_cancellable(
|
|
|
+ self._handle_request
|
|
|
+ )
|
|
|
+
|
|
|
+ code, response = await self._handle_request(request, content, **kwargs)
|
|
|
+
|
|
|
+ # Return streams we may have written to in the course of processing this
|
|
|
+ # request.
|
|
|
+ if _STREAM_POSITION_KEY in response:
|
|
|
+            raise Exception("data to send contains %r key" % (_STREAM_POSITION_KEY,))
|
|
|
|
|
|
- # The `@cancellable` decorator may be applied to `_handle_request`. But we
|
|
|
- # told `HttpServer.register_paths` that our handler is `_check_auth_and_handle`,
|
|
|
- # so we have to set up the cancellable flag ourselves.
|
|
|
- request.is_render_cancellable = is_function_cancellable(self._handle_request)
|
|
|
+ if self.WAIT_FOR_STREAMS:
|
|
|
+ response[_STREAM_POSITION_KEY] = {
|
|
|
+ stream.NAME: stream.current_token(self._instance_name)
|
|
|
+ for stream in self._streams
|
|
|
+ }
|
|
|
|
|
|
- return await self._handle_request(request, **kwargs)
|
|
|
+ return code, response
|