|
@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
|
|
|
> ERROR server stopping
|
|
|
* connection closed by server *
|
|
|
"""
|
|
|
-import abc
|
|
|
import fcntl
|
|
|
import logging
|
|
|
import struct
|
|
@@ -64,22 +63,15 @@ from synapse.metrics import LaterGauge
|
|
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
|
|
from synapse.replication.tcp.commands import (
|
|
|
COMMAND_MAP,
|
|
|
- VALID_CLIENT_COMMANDS,
|
|
|
- VALID_SERVER_COMMANDS,
|
|
|
Command,
|
|
|
ErrorCommand,
|
|
|
NameCommand,
|
|
|
PingCommand,
|
|
|
- PositionCommand,
|
|
|
- RdataCommand,
|
|
|
RemoteServerUpCommand,
|
|
|
ReplicateCommand,
|
|
|
ServerCommand,
|
|
|
- SyncCommand,
|
|
|
- UserSyncCommand,
|
|
|
)
|
|
|
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
|
|
|
-from synapse.types import Collection
|
|
|
from synapse.util import Clock
|
|
|
from synapse.util.stringutils import random_string
|
|
|
|
|
@@ -128,16 +120,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|
|
|
|
|
delimiter = b"\n"
|
|
|
|
|
|
- # Valid commands we expect to receive
|
|
|
- VALID_INBOUND_COMMANDS = [] # type: Collection[str]
|
|
|
-
|
|
|
- # Valid commands we can send
|
|
|
- VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
|
|
|
-
|
|
|
max_line_buffer = 10000
|
|
|
|
|
|
- def __init__(self, clock):
|
|
|
+ def __init__(self, clock, handler):
|
|
|
self.clock = clock
|
|
|
+ self.handler = handler
|
|
|
|
|
|
self.last_received_command = self.clock.time_msec()
|
|
|
self.last_sent_command = 0
|
|
@@ -177,6 +164,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|
|
# can time us out.
|
|
|
self.send_command(PingCommand(self.clock.time_msec()))
|
|
|
|
|
|
+ self.handler.new_connection(self)
|
|
|
+
|
|
|
def send_ping(self):
|
|
|
"""Periodically sends a ping and checks if we should close the connection
|
|
|
due to the other side timing out.
|
|
@@ -214,11 +203,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|
|
line = line.decode("utf-8")
|
|
|
cmd_name, rest_of_line = line.split(" ", 1)
|
|
|
|
|
|
- if cmd_name not in self.VALID_INBOUND_COMMANDS:
|
|
|
- logger.error("[%s] invalid command %s", self.id(), cmd_name)
|
|
|
- self.send_error("invalid command: %s", cmd_name)
|
|
|
- return
|
|
|
-
|
|
|
self.last_received_command = self.clock.time_msec()
|
|
|
|
|
|
self.inbound_commands_counter[cmd_name] = (
|
|
@@ -250,8 +234,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|
|
Args:
|
|
|
cmd: received command
|
|
|
"""
|
|
|
- handler = getattr(self, "on_%s" % (cmd.NAME,))
|
|
|
- await handler(cmd)
|
|
|
+ handled = False
|
|
|
+
|
|
|
+ # First call any command handlers on this instance. These are for TCP
|
|
|
+ # specific handling.
|
|
|
+ cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
|
|
|
+ if cmd_func:
|
|
|
+ await cmd_func(cmd)
|
|
|
+ handled = True
|
|
|
+
|
|
|
+ # Then call out to the handler.
|
|
|
+ cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
|
|
|
+ if cmd_func:
|
|
|
+ await cmd_func(cmd)
|
|
|
+ handled = True
|
|
|
+
|
|
|
+ if not handled:
|
|
|
+ logger.warning("Unhandled command: %r", cmd)
|
|
|
|
|
|
def close(self):
|
|
|
logger.warning("[%s] Closing connection", self.id())
|
|
@@ -259,6 +258,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|
|
self.transport.loseConnection()
|
|
|
self.on_connection_closed()
|
|
|
|
|
|
+ def send_remote_server_up(self, server: str):
|
|
|
+ self.send_command(RemoteServerUpCommand(server))
|
|
|
+
|
|
|
def send_error(self, error_string, *args):
|
|
|
"""Send an error to remote and close the connection.
|
|
|
"""
|
|
@@ -380,6 +382,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|
|
self.state = ConnectionStates.CLOSED
|
|
|
self.pending_commands = []
|
|
|
|
|
|
+ self.handler.lost_connection(self)
|
|
|
+
|
|
|
if self.transport:
|
|
|
self.transport.unregisterProducer()
|
|
|
|
|
@@ -403,162 +407,35 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
|
|
|
|
|
|
|
|
class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|
|
- VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
|
|
|
- VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
|
|
|
-
|
|
|
- def __init__(self, server_name, clock, streamer):
|
|
|
- BaseReplicationStreamProtocol.__init__(self, clock) # Old style class
|
|
|
+ def __init__(self, hs, server_name, clock, handler):
|
|
|
+ BaseReplicationStreamProtocol.__init__(self, clock, handler) # Old style class
|
|
|
|
|
|
self.server_name = server_name
|
|
|
- self.streamer = streamer
|
|
|
|
|
|
def connectionMade(self):
|
|
|
self.send_command(ServerCommand(self.server_name))
|
|
|
BaseReplicationStreamProtocol.connectionMade(self)
|
|
|
- self.streamer.new_connection(self)
|
|
|
|
|
|
async def on_NAME(self, cmd):
|
|
|
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
|
|
|
self.name = cmd.data
|
|
|
|
|
|
- async def on_USER_SYNC(self, cmd):
|
|
|
- await self.streamer.on_user_sync(
|
|
|
- cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
|
|
|
- )
|
|
|
-
|
|
|
- async def on_CLEAR_USER_SYNC(self, cmd):
|
|
|
- await self.streamer.on_clear_user_syncs(cmd.instance_id)
|
|
|
-
|
|
|
- async def on_REPLICATE(self, cmd):
|
|
|
- # Subscribe to all streams we're publishing to.
|
|
|
- for stream_name in self.streamer.streams_by_name:
|
|
|
- current_token = self.streamer.get_stream_token(stream_name)
|
|
|
- self.send_command(PositionCommand(stream_name, current_token))
|
|
|
-
|
|
|
- async def on_FEDERATION_ACK(self, cmd):
|
|
|
- self.streamer.federation_ack(cmd.token)
|
|
|
-
|
|
|
- async def on_REMOVE_PUSHER(self, cmd):
|
|
|
- await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
|
|
|
-
|
|
|
- async def on_INVALIDATE_CACHE(self, cmd):
|
|
|
- await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
|
|
|
-
|
|
|
- async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
|
|
- self.streamer.on_remote_server_up(cmd.data)
|
|
|
-
|
|
|
- async def on_USER_IP(self, cmd):
|
|
|
- self.streamer.on_user_ip(
|
|
|
- cmd.user_id,
|
|
|
- cmd.access_token,
|
|
|
- cmd.ip,
|
|
|
- cmd.user_agent,
|
|
|
- cmd.device_id,
|
|
|
- cmd.last_seen,
|
|
|
- )
|
|
|
-
|
|
|
- def stream_update(self, stream_name, token, data):
|
|
|
- """Called when a new update is available to stream to clients.
|
|
|
-
|
|
|
- We need to check if the client is interested in the stream or not
|
|
|
- """
|
|
|
- self.send_command(RdataCommand(stream_name, token, data))
|
|
|
-
|
|
|
- def send_sync(self, data):
|
|
|
- self.send_command(SyncCommand(data))
|
|
|
-
|
|
|
- def send_remote_server_up(self, server: str):
|
|
|
- self.send_command(RemoteServerUpCommand(server))
|
|
|
-
|
|
|
- def on_connection_closed(self):
|
|
|
- BaseReplicationStreamProtocol.on_connection_closed(self)
|
|
|
- self.streamer.lost_connection(self)
|
|
|
-
|
|
|
-
|
|
|
-class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
|
|
|
- """
|
|
|
- The interface for the handler that should be passed to
|
|
|
- ClientReplicationStreamProtocol
|
|
|
- """
|
|
|
-
|
|
|
- @abc.abstractmethod
|
|
|
- async def on_rdata(self, stream_name, token, rows):
|
|
|
- """Called to handle a batch of replication data with a given stream token.
|
|
|
-
|
|
|
- Args:
|
|
|
- stream_name (str): name of the replication stream for this batch of rows
|
|
|
- token (int): stream token for this batch of rows
|
|
|
- rows (list): a list of Stream.ROW_TYPE objects as returned by
|
|
|
- Stream.parse_row.
|
|
|
- """
|
|
|
- raise NotImplementedError()
|
|
|
-
|
|
|
- @abc.abstractmethod
|
|
|
- async def on_position(self, stream_name, token):
|
|
|
- """Called when we get new position data."""
|
|
|
- raise NotImplementedError()
|
|
|
-
|
|
|
- @abc.abstractmethod
|
|
|
- def on_sync(self, data):
|
|
|
- """Called when get a new SYNC command."""
|
|
|
- raise NotImplementedError()
|
|
|
-
|
|
|
- @abc.abstractmethod
|
|
|
- async def on_remote_server_up(self, server: str):
|
|
|
- """Called when get a new REMOTE_SERVER_UP command."""
|
|
|
- raise NotImplementedError()
|
|
|
-
|
|
|
- @abc.abstractmethod
|
|
|
- def get_streams_to_replicate(self):
|
|
|
- """Called when a new connection has been established and we need to
|
|
|
- subscribe to streams.
|
|
|
-
|
|
|
- Returns:
|
|
|
- map from stream name to the most recent update we have for
|
|
|
- that stream (ie, the point we want to start replicating from)
|
|
|
- """
|
|
|
- raise NotImplementedError()
|
|
|
-
|
|
|
- @abc.abstractmethod
|
|
|
- def get_currently_syncing_users(self):
|
|
|
- """Get the list of currently syncing users (if any). This is called
|
|
|
- when a connection has been established and we need to send the
|
|
|
- currently syncing users."""
|
|
|
- raise NotImplementedError()
|
|
|
-
|
|
|
- @abc.abstractmethod
|
|
|
- def update_connection(self, connection):
|
|
|
- """Called when a connection has been established (or lost with None).
|
|
|
- """
|
|
|
- raise NotImplementedError()
|
|
|
-
|
|
|
- @abc.abstractmethod
|
|
|
- def finished_connecting(self):
|
|
|
- """Called when we have successfully subscribed and caught up to all
|
|
|
- streams we're interested in.
|
|
|
- """
|
|
|
- raise NotImplementedError()
|
|
|
-
|
|
|
|
|
|
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|
|
- VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
|
|
|
- VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
|
|
|
-
|
|
|
def __init__(
|
|
|
self,
|
|
|
hs: "HomeServer",
|
|
|
client_name: str,
|
|
|
server_name: str,
|
|
|
clock: Clock,
|
|
|
- handler: AbstractReplicationClientHandler,
|
|
|
+ handler,
|
|
|
):
|
|
|
- BaseReplicationStreamProtocol.__init__(self, clock)
|
|
|
+ BaseReplicationStreamProtocol.__init__(self, clock, handler)
|
|
|
|
|
|
self.instance_id = hs.get_instance_id()
|
|
|
|
|
|
self.client_name = client_name
|
|
|
self.server_name = server_name
|
|
|
- self.handler = handler
|
|
|
|
|
|
self.streams = {
|
|
|
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
|
|
@@ -574,105 +451,16 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|
|
self.pending_batches = {} # type: Dict[str, List[Any]]
|
|
|
|
|
|
def connectionMade(self):
|
|
|
- self.send_command(NameCommand(self.client_name))
|
|
|
BaseReplicationStreamProtocol.connectionMade(self)
|
|
|
|
|
|
- # Once we've connected subscribe to the necessary streams
|
|
|
+ self.send_command(NameCommand(self.client_name))
|
|
|
self.replicate()
|
|
|
|
|
|
- # Tell the server if we have any users currently syncing (should only
|
|
|
- # happen on synchrotrons)
|
|
|
- currently_syncing = self.handler.get_currently_syncing_users()
|
|
|
- now = self.clock.time_msec()
|
|
|
- for user_id in currently_syncing:
|
|
|
- self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
|
|
|
-
|
|
|
- # We've now finished connecting to so inform the client handler
|
|
|
- self.handler.update_connection(self)
|
|
|
-
|
|
|
async def on_SERVER(self, cmd):
|
|
|
if cmd.data != self.server_name:
|
|
|
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
|
|
|
self.send_error("Wrong remote")
|
|
|
|
|
|
- async def on_RDATA(self, cmd):
|
|
|
- stream_name = cmd.stream_name
|
|
|
- inbound_rdata_count.labels(stream_name).inc()
|
|
|
-
|
|
|
- try:
|
|
|
- row = STREAMS_MAP[stream_name].parse_row(cmd.row)
|
|
|
- except Exception:
|
|
|
- logger.exception(
|
|
|
- "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
|
|
|
- )
|
|
|
- raise
|
|
|
-
|
|
|
- if cmd.token is None or stream_name in self.streams_connecting:
|
|
|
- # I.e. this is part of a batch of updates for this stream. Batch
|
|
|
- # until we get an update for the stream with a non None token
|
|
|
- self.pending_batches.setdefault(stream_name, []).append(row)
|
|
|
- else:
|
|
|
- # Check if this is the last of a batch of updates
|
|
|
- rows = self.pending_batches.pop(stream_name, [])
|
|
|
- rows.append(row)
|
|
|
- await self.handler.on_rdata(stream_name, cmd.token, rows)
|
|
|
-
|
|
|
- async def on_POSITION(self, cmd: PositionCommand):
|
|
|
- stream = self.streams.get(cmd.stream_name)
|
|
|
- if not stream:
|
|
|
- logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
|
|
|
- return
|
|
|
-
|
|
|
- # Find where we previously streamed up to.
|
|
|
- current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
|
|
|
- if current_token is None:
|
|
|
- logger.warning(
|
|
|
- "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
|
|
|
- )
|
|
|
- return
|
|
|
-
|
|
|
- # Fetch all updates between then and now.
|
|
|
- limited = True
|
|
|
- while limited:
|
|
|
- updates, current_token, limited = await stream.get_updates_since(
|
|
|
- current_token, cmd.token
|
|
|
- )
|
|
|
-
|
|
|
- # Check if the connection was closed underneath us, if so we bail
|
|
|
- # rather than risk having concurrent catch ups going on.
|
|
|
- if self.state == ConnectionStates.CLOSED:
|
|
|
- return
|
|
|
-
|
|
|
- if updates:
|
|
|
- await self.handler.on_rdata(
|
|
|
- cmd.stream_name,
|
|
|
- current_token,
|
|
|
- [stream.parse_row(update[1]) for update in updates],
|
|
|
- )
|
|
|
-
|
|
|
- # We've now caught up to position sent to us, notify handler.
|
|
|
- await self.handler.on_position(cmd.stream_name, cmd.token)
|
|
|
-
|
|
|
- self.streams_connecting.discard(cmd.stream_name)
|
|
|
- if not self.streams_connecting:
|
|
|
- self.handler.finished_connecting()
|
|
|
-
|
|
|
- # Check if the connection was closed underneath us, if so we bail
|
|
|
- # rather than risk having concurrent catch ups going on.
|
|
|
- if self.state == ConnectionStates.CLOSED:
|
|
|
- return
|
|
|
-
|
|
|
- # Handle any RDATA that came in while we were catching up.
|
|
|
- rows = self.pending_batches.pop(cmd.stream_name, [])
|
|
|
- if rows:
|
|
|
- await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
|
|
|
-
|
|
|
- async def on_SYNC(self, cmd):
|
|
|
- self.handler.on_sync(cmd.data)
|
|
|
-
|
|
|
- async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
|
|
|
- self.handler.on_remote_server_up(cmd.data)
|
|
|
-
|
|
|
def replicate(self):
|
|
|
"""Send the subscription request to the server
|
|
|
"""
|
|
@@ -680,10 +468,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
|
|
|
|
|
self.send_command(ReplicateCommand())
|
|
|
|
|
|
- def on_connection_closed(self):
|
|
|
- BaseReplicationStreamProtocol.on_connection_closed(self)
|
|
|
- self.handler.update_connection(None)
|
|
|
-
|
|
|
|
|
|
# The following simply registers metrics for the replication connections
|
|
|
|