Преглед на файлове

Move command processing out of transport

Erik Johnston преди 4 години
родител
ревизия
22fc68762e

+ 10 - 11
synapse/app/generic_worker.py

@@ -64,8 +64,9 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.registration import SlavedRegistrationStore
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.room import RoomStore
 from synapse.replication.slave.storage.transactions import SlavedTransactionStore
 from synapse.replication.slave.storage.transactions import SlavedTransactionStore
-from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.replication.tcp.client import ReplicationClientFactory
 from synapse.replication.tcp.commands import ClearUserSyncsCommand
 from synapse.replication.tcp.commands import ClearUserSyncsCommand
+from synapse.replication.tcp.handler import WorkerReplicationDataHandler
 from synapse.replication.tcp.streams import (
 from synapse.replication.tcp.streams import (
     AccountDataStream,
     AccountDataStream,
     DeviceListsStream,
     DeviceListsStream,
@@ -598,25 +599,26 @@ class GenericWorkerServer(HomeServer):
             else:
             else:
                 logger.warning("Unrecognized listener type: %s", listener["type"])
                 logger.warning("Unrecognized listener type: %s", listener["type"])
 
 
-        self.get_tcp_replication().start_replication(self)
+        factory = ReplicationClientFactory(self, self.config.worker_name)
+        host = self.config.worker_replication_host
+        port = self.config.worker_replication_port
+        self.get_reactor().connectTCP(host, port, factory)
 
 
     def remove_pusher(self, app_id, push_key, user_id):
     def remove_pusher(self, app_id, push_key, user_id):
         self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
         self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
 
 
-    def build_tcp_replication(self):
-        return GenericWorkerReplicationHandler(self)
-
     def build_presence_handler(self):
     def build_presence_handler(self):
         return GenericWorkerPresence(self)
         return GenericWorkerPresence(self)
 
 
     def build_typing_handler(self):
     def build_typing_handler(self):
         return GenericWorkerTyping(self)
         return GenericWorkerTyping(self)
 
 
+    def build_replication_data_handler(self):
+        return GenericWorkerReplicationHandler(self)
 
 
-class GenericWorkerReplicationHandler(ReplicationClientHandler):
-    def __init__(self, hs):
-        super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore())
 
 
+class GenericWorkerReplicationHandler(WorkerReplicationDataHandler):
+    def __init__(self, hs):
         self.store = hs.get_datastore()
         self.store = hs.get_datastore()
         self.typing_handler = hs.get_typing_handler()
         self.typing_handler = hs.get_typing_handler()
         # NB this is a SynchrotronPresence, not a normal PresenceHandler
         # NB this is a SynchrotronPresence, not a normal PresenceHandler
@@ -644,9 +646,6 @@ class GenericWorkerReplicationHandler(ReplicationClientHandler):
             args.update(self.send_handler.stream_positions())
             args.update(self.send_handler.stream_positions())
         return args
         return args
 
 
-    def get_currently_syncing_users(self):
-        return self.presence_handler.get_currently_syncing_users()
-
     async def process_and_notify(self, stream_name, token, rows):
     async def process_and_notify(self, stream_name, token, rows):
         try:
         try:
             if self.send_handler:
             if self.send_handler:

+ 3 - 188
synapse/replication/tcp/client.py

@@ -16,26 +16,10 @@
 """
 """
 
 
 import logging
 import logging
-from typing import Dict, List, Optional
 
 
-from twisted.internet import defer
 from twisted.internet.protocol import ReconnectingClientFactory
 from twisted.internet.protocol import ReconnectingClientFactory
 
 
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.tcp.protocol import (
-    AbstractReplicationClientHandler,
-    ClientReplicationStreamProtocol,
-)
-
-from .commands import (
-    Command,
-    FederationAckCommand,
-    InvalidateCacheCommand,
-    RemoteServerUpCommand,
-    RemovePusherCommand,
-    UserIpCommand,
-    UserSyncCommand,
-)
+from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -51,9 +35,9 @@ class ReplicationClientFactory(ReconnectingClientFactory):
     initialDelay = 0.1
     initialDelay = 0.1
     maxDelay = 1  # Try at least once every N seconds
     maxDelay = 1  # Try at least once every N seconds
 
 
-    def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
+    def __init__(self, hs, client_name):
         self.client_name = client_name
         self.client_name = client_name
-        self.handler = handler
+        self.handler = hs.get_tcp_replication()
         self.server_name = hs.config.server_name
         self.server_name = hs.config.server_name
         self.hs = hs
         self.hs = hs
         self._clock = hs.get_clock()  # As self.clock is defined in super class
         self._clock = hs.get_clock()  # As self.clock is defined in super class
@@ -76,172 +60,3 @@ class ReplicationClientFactory(ReconnectingClientFactory):
     def clientConnectionFailed(self, connector, reason):
     def clientConnectionFailed(self, connector, reason):
         logger.error("Failed to connect to replication: %r", reason)
         logger.error("Failed to connect to replication: %r", reason)
         ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
         ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
-
-
-class ReplicationClientHandler(AbstractReplicationClientHandler):
-    """A base handler that can be passed to the ReplicationClientFactory.
-
-    By default proxies incoming replication data to the SlaveStore.
-    """
-
-    def __init__(self, store: BaseSlavedStore):
-        self.store = store
-
-        # The current connection. None if we are currently (re)connecting
-        self.connection = None
-
-        # Any pending commands to be sent once a new connection has been
-        # established
-        self.pending_commands = []  # type: List[Command]
-
-        # Map from string -> deferred, to wake up when receiveing a SYNC with
-        # the given string.
-        # Used for tests.
-        self.awaiting_syncs = {}  # type: Dict[str, defer.Deferred]
-
-        # The factory used to create connections.
-        self.factory = None  # type: Optional[ReplicationClientFactory]
-
-    def start_replication(self, hs):
-        """Helper method to start a replication connection to the remote server
-        using TCP.
-        """
-        client_name = hs.config.worker_name
-        self.factory = ReplicationClientFactory(hs, client_name, self)
-        host = hs.config.worker_replication_host
-        port = hs.config.worker_replication_port
-        hs.get_reactor().connectTCP(host, port, self.factory)
-
-    async def on_rdata(self, stream_name, token, rows):
-        """Called to handle a batch of replication data with a given stream token.
-
-        By default this just pokes the slave store. Can be overridden in subclasses to
-        handle more.
-
-        Args:
-            stream_name (str): name of the replication stream for this batch of rows
-            token (int): stream token for this batch of rows
-            rows (list): a list of Stream.ROW_TYPE objects as returned by
-                Stream.parse_row.
-        """
-        logger.debug("Received rdata %s -> %s", stream_name, token)
-        self.store.process_replication_rows(stream_name, token, rows)
-
-    async def on_position(self, stream_name, token):
-        """Called when we get new position data. By default this just pokes
-        the slave store.
-
-        Can be overriden in subclasses to handle more.
-        """
-        self.store.process_replication_rows(stream_name, token, [])
-
-    def on_sync(self, data):
-        """When we received a SYNC we wake up any deferreds that were waiting
-        for the sync with the given data.
-
-        Used by tests.
-        """
-        d = self.awaiting_syncs.pop(data, None)
-        if d:
-            d.callback(data)
-
-    def on_remote_server_up(self, server: str):
-        """Called when get a new REMOTE_SERVER_UP command."""
-
-    def get_streams_to_replicate(self) -> Dict[str, int]:
-        """Called when a new connection has been established and we need to
-        subscribe to streams.
-
-        Returns:
-            map from stream name to the most recent update we have for
-            that stream (ie, the point we want to start replicating from)
-        """
-        args = self.store.stream_positions()
-        user_account_data = args.pop("user_account_data", None)
-        room_account_data = args.pop("room_account_data", None)
-        if user_account_data:
-            args["account_data"] = user_account_data
-        elif room_account_data:
-            args["account_data"] = room_account_data
-
-        return args
-
-    def get_currently_syncing_users(self):
-        """Get the list of currently syncing users (if any). This is called
-        when a connection has been established and we need to send the
-        currently syncing users. (Overriden by the synchrotron's only)
-        """
-        return []
-
-    def send_command(self, cmd):
-        """Send a command to master (when we get establish a connection if we
-        don't have one already.)
-        """
-        if self.connection:
-            self.connection.send_command(cmd)
-        else:
-            logger.warning("Queuing command as not connected: %r", cmd.NAME)
-            self.pending_commands.append(cmd)
-
-    def send_federation_ack(self, token):
-        """Ack data for the federation stream. This allows the master to drop
-        data stored purely in memory.
-        """
-        self.send_command(FederationAckCommand(token))
-
-    def send_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
-        """Poke the master that a user has started/stopped syncing.
-        """
-        self.send_command(
-            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
-        )
-
-    def send_remove_pusher(self, app_id, push_key, user_id):
-        """Poke the master to remove a pusher for a user
-        """
-        cmd = RemovePusherCommand(app_id, push_key, user_id)
-        self.send_command(cmd)
-
-    def send_invalidate_cache(self, cache_func, keys):
-        """Poke the master to invalidate a cache.
-        """
-        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
-        self.send_command(cmd)
-
-    def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
-        """Tell the master that the user made a request.
-        """
-        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
-        self.send_command(cmd)
-
-    def send_remote_server_up(self, server: str):
-        self.send_command(RemoteServerUpCommand(server))
-
-    def await_sync(self, data):
-        """Returns a deferred that is resolved when we receive a SYNC command
-        with given data.
-
-        [Not currently] used by tests.
-        """
-        return self.awaiting_syncs.setdefault(data, defer.Deferred())
-
-    def update_connection(self, connection):
-        """Called when a connection has been established (or lost with None).
-        """
-        self.connection = connection
-        if connection:
-            for cmd in self.pending_commands:
-                connection.send_command(cmd)
-            self.pending_commands = []
-
-    def finished_connecting(self):
-        """Called when we have successfully subscribed and caught up to all
-        streams we're interested in.
-        """
-        logger.info("Finished connecting to server")
-
-        # We don't reset the delay any earlier as otherwise if there is a
-        # problem during start up we'll end up tight looping connecting to the
-        # server.
-        if self.factory:
-            self.factory.resetDelay()

+ 410 - 0
synapse/replication/tcp/handler.py

@@ -0,0 +1,410 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A replication client for use by synapse workers.
+"""
+
+import logging
+from typing import Any, Callable, Dict, List
+
+from prometheus_client import Counter
+
+from synapse.metrics import LaterGauge
+from synapse.replication.tcp.commands import (
+    ClearUserSyncsCommand,
+    Command,
+    FederationAckCommand,
+    InvalidateCacheCommand,
+    PositionCommand,
+    RdataCommand,
+    RemoteServerUpCommand,
+    RemovePusherCommand,
+    ReplicateCommand,
+    UserIpCommand,
+    UserSyncCommand,
+)
+from synapse.replication.tcp.streams import STREAMS_MAP, Stream
+from synapse.util.async_helpers import Linearizer
+
+logger = logging.getLogger(__name__)
+
+
+user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
+federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
+remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
+invalidate_cache_counter = Counter(
+    "synapse_replication_tcp_resource_invalidate_cache", ""
+)
+user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
+
+
+class ReplicationClientHandler:
+    """Handles incoming commands from replication.
+
+    Proxies data to `HomeServer.get_replication_data_handler()`.
+    """
+
+    def __init__(self, hs):
+        self.replication_data_handler = hs.get_replication_data_handler()
+        self.store = hs.get_datastore()
+        self.notifier = hs.get_notifier()
+        self.clock = hs.get_clock()
+        self.presence_handler = hs.get_presence_handler()
+        self.instance_id = hs.get_instance_id()
+
+        self._position_linearizer = Linearizer("replication_position")
+
+        self.connections = []  # type: List[Any]
+
+        self.streams = {
+            stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
+        }  # type: Dict[str, Stream]
+
+        LaterGauge(
+            "synapse_replication_tcp_resource_total_connections",
+            "",
+            [],
+            lambda: len(self.connections),
+        )
+
+        LaterGauge(
+            "synapse_replication_tcp_resource_connections_per_stream",
+            "",
+            ["stream_name"],
+            lambda: {
+                (stream_name,): len(
+                    [
+                        conn
+                        for conn in self.connections
+                        if stream_name in conn.replication_streams
+                    ]
+                )
+                for stream_name in self.streams
+            },
+        )
+
+        # Map of stream to batched updates. See RdataCommand for info on how
+        # batching works.
+        self.pending_batches = {}  # type: Dict[str, List[Any]]
+
+        self.is_master = hs.config.worker_app is None
+
+        self.federation_sender = None
+        if self.is_master and not hs.config.send_federation:
+            self.federation_sender = hs.get_federation_sender()
+
+        self._server_notices_sender = None
+        if self.is_master:
+            self._server_notices_sender = hs.get_server_notices_sender()
+            self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
+
+    def new_connection(self, connection):
+        self.connections.append(connection)
+
+    def lost_connection(self, connection):
+        try:
+            self.connections.remove(connection)
+        except ValueError:
+            pass
+
+    def connected(self) -> bool:
+        """Do we have any replication connections open?
+
+        Used to no-op if nothing is connected.
+        """
+        return bool(self.connections)
+
+    async def on_REPLICATE(self, cmd: ReplicateCommand):
+        # We only want to announce positions by the writer of the streams.
+        # Currently this is just the master process.
+        if not self.is_master:
+            return
+
+        if not self.connections:
+            raise Exception("Not connected")
+
+        for stream_name, stream in self.streams.items():
+            current_token = stream.current_token()
+            self.send_command(PositionCommand(stream_name, current_token))
+
+    async def on_USER_SYNC(self, cmd: UserSyncCommand):
+        user_sync_counter.inc()
+
+        if self.is_master:
+            await self.presence_handler.update_external_syncs_row(
+                cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
+            )
+
+    async def on_CLEAR_USER_SYNC(self, cmd: ClearUserSyncsCommand):
+        if self.is_master:
+            await self.presence_handler.update_external_syncs_clear(cmd.instance_id)
+
+    async def on_FEDERATION_ACK(self, cmd: FederationAckCommand):
+        federation_ack_counter.inc()
+
+        if self.federation_sender:
+            self.federation_sender.federation_ack(cmd.token)
+
+    async def on_REMOVE_PUSHER(self, cmd: RemovePusherCommand):
+        remove_pusher_counter.inc()
+
+        if self.is_master:
+            await self.store.delete_pusher_by_app_id_pushkey_user_id(
+                app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
+            )
+
+            self.notifier.on_new_replication_data()
+
+    async def on_INVALIDATE_CACHE(self, cmd: InvalidateCacheCommand):
+        invalidate_cache_counter.inc()
+
+        if self.is_master:
+            # We invalidate the cache locally, but then also stream that to other
+            # workers.
+            await self.store.invalidate_cache_and_stream(
+                cmd.cache_func, tuple(cmd.keys)
+            )
+
+    async def on_USER_IP(self, cmd: UserIpCommand):
+        user_ip_cache_counter.inc()
+
+        if self.is_master:
+            await self.store.insert_client_ip(
+                cmd.user_id,
+                cmd.access_token,
+                cmd.ip,
+                cmd.user_agent,
+                cmd.device_id,
+                cmd.last_seen,
+            )
+
+        if self._server_notices_sender:
+            await self._server_notices_sender.on_user_ip(cmd.user_id)
+
+    async def on_RDATA(self, cmd: RdataCommand):
+        stream_name = cmd.stream_name
+
+        try:
+            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
+        except Exception:
+            logger.exception("[%s] Failed to parse RDATA: %r", stream_name, cmd.row)
+            raise
+
+        if cmd.token is None:
+            # I.e. this is part of a batch of updates for this stream. Batch
+            # until we get an update for the stream with a non None token
+            self.pending_batches.setdefault(stream_name, []).append(row)
+        else:
+            # Check if this is the last of a batch of updates
+            rows = self.pending_batches.pop(stream_name, [])
+            rows.append(row)
+            await self.on_rdata(stream_name, cmd.token, rows)
+
+    async def on_rdata(self, stream_name: str, token: int, rows: list):
+        """Called to handle a batch of replication data with a given stream token.
+
+        Args:
+            stream_name: name of the replication stream for this batch of rows
+            token: stream token for this batch of rows
+            rows: a list of Stream.ROW_TYPE objects as returned by
+                Stream.parse_row.
+        """
+        logger.debug("Received rdata %s -> %s", stream_name, token)
+        await self.replication_data_handler.on_rdata(stream_name, token, rows)
+
+    async def on_POSITION(self, cmd: PositionCommand):
+        stream = self.streams.get(cmd.stream_name)
+        if not stream:
+            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
+            return
+
+        # We protect catching up with a linearizer in case the replicaiton
+        # connection reconnects under us.
+        with await self._position_linearizer.queue(cmd.stream_name):
+            # Find where we previously streamed up to.
+            current_token = self.replication_data_handler.get_streams_to_replicate().get(
+                cmd.stream_name
+            )
+            if current_token is None:
+                logger.debug(
+                    "Got POSITION for stream we're not subscribed to: %s",
+                    cmd.stream_name,
+                )
+                return
+
+            # Fetch all updates between then and now.
+            limited = cmd.token != current_token
+            while limited:
+                updates, current_token, limited = await stream.get_updates_since(
+                    current_token, cmd.token
+                )
+                if updates:
+                    await self.on_rdata(
+                        cmd.stream_name,
+                        current_token,
+                        [stream.parse_row(update[1]) for update in updates],
+                    )
+
+            # We've now caught up to position sent to us, notify handler.
+            await self.replication_data_handler.on_position(cmd.stream_name, cmd.token)
+
+        # Handle any RDATA that came in while we were catching up.
+        rows = self.pending_batches.pop(cmd.stream_name, [])
+        if rows:
+            await self.on_rdata(cmd.stream_name, rows[-1].token, rows)
+
+    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+        """Called when get a new REMOTE_SERVER_UP command."""
+        if self.is_master:
+            self.notifier.notify_remote_server_up(cmd.data)
+
+    def get_currently_syncing_users(self):
+        """Get the list of currently syncing users (if any). This is called
+        when a connection has been established and we need to send the
+        currently syncing users.
+        """
+        return self.presence_handler.get_currently_syncing_users()
+
+    def send_command(self, cmd: Command):
+        """Send a command to master (when we get establish a connection if we
+        don't have one already.)
+        """
+        for conn in self.connections:
+            conn.send_command(cmd)
+
+    def send_federation_ack(self, token: int):
+        """Ack data for the federation stream. This allows the master to drop
+        data stored purely in memory.
+        """
+        self.send_command(FederationAckCommand(token))
+
+    def send_user_sync(
+        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+    ):
+        """Poke the master that a user has started/stopped syncing.
+        """
+        self.send_command(
+            UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
+        )
+
+    def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
+        """Poke the master to remove a pusher for a user
+        """
+        cmd = RemovePusherCommand(app_id, push_key, user_id)
+        self.send_command(cmd)
+
+    def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
+        """Poke the master to invalidate a cache.
+        """
+        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
+        self.send_command(cmd)
+
+    def send_user_ip(
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: str,
+        last_seen: int,
+    ):
+        """Tell the master that the user made a request.
+        """
+        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
+        self.send_command(cmd)
+
+    def send_remote_server_up(self, server: str):
+        self.send_command(RemoteServerUpCommand(server))
+
+    def stream_update(self, stream_name: str, token: str, data: Any):
+        """Called when a new update is available to stream to clients.
+
+        We need to check if the client is interested in the stream or not
+        """
+        self.send_command(RdataCommand(stream_name, token, data))
+
+
+class DummyReplicationDataHandler:
+    """A replication data handler that simply discards all data.
+    """
+
+    async def on_rdata(self, stream_name: str, token: int, rows: list):
+        """Called to handle a batch of replication data with a given stream token.
+
+        By default this just pokes the slave store. Can be overridden in subclasses to
+        handle more.
+
+        Args:
+            stream_name (str): name of the replication stream for this batch of rows
+            token (int): stream token for this batch of rows
+            rows (list): a list of Stream.ROW_TYPE objects as returned by
+                Stream.parse_row.
+        """
+        pass
+
+    def get_streams_to_replicate(self) -> Dict[str, int]:
+        """Called when a new connection has been established and we need to
+        subscribe to streams.
+
+        Returns:
+            map from stream name to the most recent update we have for
+            that stream (ie, the point we want to start replicating from)
+        """
+        return {}
+
+    async def on_position(self, stream_name: str, token: int):
+        pass
+
+
+class WorkerReplicationDataHandler:
+    """A replication data handler that calls slave data stores.
+    """
+
+    def __init__(self, store):
+        self.store = store
+
+    async def on_rdata(self, stream_name: str, token: int, rows: list):
+        """Called to handle a batch of replication data with a given stream token.
+
+        By default this just pokes the slave store. Can be overridden in subclasses to
+        handle more.
+
+        Args:
+            stream_name (str): name of the replication stream for this batch of rows
+            token (int): stream token for this batch of rows
+            rows (list): a list of Stream.ROW_TYPE objects as returned by
+                Stream.parse_row.
+        """
+        self.store.process_replication_rows(stream_name, token, rows)
+
+    def get_streams_to_replicate(self) -> Dict[str, int]:
+        """Called when a new connection has been established and we need to
+        subscribe to streams.
+
+        Returns:
+            map from stream name to the most recent update we have for
+            that stream (ie, the point we want to start replicating from)
+        """
+        args = self.store.stream_positions()
+        user_account_data = args.pop("user_account_data", None)
+        room_account_data = args.pop("room_account_data", None)
+        if user_account_data:
+            args["account_data"] = user_account_data
+        elif room_account_data:
+            args["account_data"] = room_account_data
+        return args
+
+    async def on_position(self, stream_name: str, token: int):
+        self.store.process_replication_rows(stream_name, token, [])

+ 31 - 247
synapse/replication/tcp/protocol.py

@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
     > ERROR server stopping
     > ERROR server stopping
     * connection closed by server *
     * connection closed by server *
 """
 """
-import abc
 import fcntl
 import fcntl
 import logging
 import logging
 import struct
 import struct
@@ -64,22 +63,15 @@ from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import (
 from synapse.replication.tcp.commands import (
     COMMAND_MAP,
     COMMAND_MAP,
-    VALID_CLIENT_COMMANDS,
-    VALID_SERVER_COMMANDS,
     Command,
     Command,
     ErrorCommand,
     ErrorCommand,
     NameCommand,
     NameCommand,
     PingCommand,
     PingCommand,
-    PositionCommand,
-    RdataCommand,
     RemoteServerUpCommand,
     RemoteServerUpCommand,
     ReplicateCommand,
     ReplicateCommand,
     ServerCommand,
     ServerCommand,
-    SyncCommand,
-    UserSyncCommand,
 )
 )
 from synapse.replication.tcp.streams import STREAMS_MAP, Stream
 from synapse.replication.tcp.streams import STREAMS_MAP, Stream
-from synapse.types import Collection
 from synapse.util import Clock
 from synapse.util import Clock
 from synapse.util.stringutils import random_string
 from synapse.util.stringutils import random_string
 
 
@@ -128,16 +120,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
 
     delimiter = b"\n"
     delimiter = b"\n"
 
 
-    # Valid commands we expect to receive
-    VALID_INBOUND_COMMANDS = []  # type: Collection[str]
-
-    # Valid commands we can send
-    VALID_OUTBOUND_COMMANDS = []  # type: Collection[str]
-
     max_line_buffer = 10000
     max_line_buffer = 10000
 
 
-    def __init__(self, clock):
+    def __init__(self, clock, handler):
         self.clock = clock
         self.clock = clock
+        self.handler = handler
 
 
         self.last_received_command = self.clock.time_msec()
         self.last_received_command = self.clock.time_msec()
         self.last_sent_command = 0
         self.last_sent_command = 0
@@ -177,6 +164,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # can time us out.
         # can time us out.
         self.send_command(PingCommand(self.clock.time_msec()))
         self.send_command(PingCommand(self.clock.time_msec()))
 
 
+        self.handler.new_connection(self)
+
     def send_ping(self):
     def send_ping(self):
         """Periodically sends a ping and checks if we should close the connection
         """Periodically sends a ping and checks if we should close the connection
         due to the other side timing out.
         due to the other side timing out.
@@ -214,11 +203,6 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         line = line.decode("utf-8")
         line = line.decode("utf-8")
         cmd_name, rest_of_line = line.split(" ", 1)
         cmd_name, rest_of_line = line.split(" ", 1)
 
 
-        if cmd_name not in self.VALID_INBOUND_COMMANDS:
-            logger.error("[%s] invalid command %s", self.id(), cmd_name)
-            self.send_error("invalid command: %s", cmd_name)
-            return
-
         self.last_received_command = self.clock.time_msec()
         self.last_received_command = self.clock.time_msec()
 
 
         self.inbound_commands_counter[cmd_name] = (
         self.inbound_commands_counter[cmd_name] = (
@@ -250,8 +234,23 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         Args:
         Args:
             cmd: received command
             cmd: received command
         """
         """
-        handler = getattr(self, "on_%s" % (cmd.NAME,))
-        await handler(cmd)
+        handled = False
+
+        # First call any command handlers on this instance. These are for TCP
+        # specific handling.
+        cmd_func = getattr(self, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        # Then call out to the handler.
+        cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+        if cmd_func:
+            await cmd_func(cmd)
+            handled = True
+
+        if not handled:
+            logger.warning("Unhandled command: %r", cmd)
 
 
     def close(self):
     def close(self):
         logger.warning("[%s] Closing connection", self.id())
         logger.warning("[%s] Closing connection", self.id())
@@ -259,6 +258,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.transport.loseConnection()
         self.transport.loseConnection()
         self.on_connection_closed()
         self.on_connection_closed()
 
 
+    def send_remote_server_up(self, server: str):
+        self.send_command(RemoteServerUpCommand(server))
+
     def send_error(self, error_string, *args):
     def send_error(self, error_string, *args):
         """Send an error to remote and close the connection.
         """Send an error to remote and close the connection.
         """
         """
@@ -380,6 +382,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.state = ConnectionStates.CLOSED
         self.state = ConnectionStates.CLOSED
         self.pending_commands = []
         self.pending_commands = []
 
 
+        self.handler.lost_connection(self)
+
         if self.transport:
         if self.transport:
             self.transport.unregisterProducer()
             self.transport.unregisterProducer()
 
 
@@ -403,162 +407,35 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
 
 
 
 class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
 class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
-    VALID_INBOUND_COMMANDS = VALID_CLIENT_COMMANDS
-    VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS
-
-    def __init__(self, server_name, clock, streamer):
-        BaseReplicationStreamProtocol.__init__(self, clock)  # Old style class
+    def __init__(self, hs, server_name, clock, handler):
+        BaseReplicationStreamProtocol.__init__(self, clock, handler)  # Old style class
 
 
         self.server_name = server_name
         self.server_name = server_name
-        self.streamer = streamer
 
 
     def connectionMade(self):
     def connectionMade(self):
         self.send_command(ServerCommand(self.server_name))
         self.send_command(ServerCommand(self.server_name))
         BaseReplicationStreamProtocol.connectionMade(self)
         BaseReplicationStreamProtocol.connectionMade(self)
-        self.streamer.new_connection(self)
 
 
     async def on_NAME(self, cmd):
     async def on_NAME(self, cmd):
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
         self.name = cmd.data
 
 
-    async def on_USER_SYNC(self, cmd):
-        await self.streamer.on_user_sync(
-            cmd.instance_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
-        )
-
-    async def on_CLEAR_USER_SYNC(self, cmd):
-        await self.streamer.on_clear_user_syncs(cmd.instance_id)
-
-    async def on_REPLICATE(self, cmd):
-        # Subscribe to all streams we're publishing to.
-        for stream_name in self.streamer.streams_by_name:
-            current_token = self.streamer.get_stream_token(stream_name)
-            self.send_command(PositionCommand(stream_name, current_token))
-
-    async def on_FEDERATION_ACK(self, cmd):
-        self.streamer.federation_ack(cmd.token)
-
-    async def on_REMOVE_PUSHER(self, cmd):
-        await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
-
-    async def on_INVALIDATE_CACHE(self, cmd):
-        await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
-
-    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
-        self.streamer.on_remote_server_up(cmd.data)
-
-    async def on_USER_IP(self, cmd):
-        self.streamer.on_user_ip(
-            cmd.user_id,
-            cmd.access_token,
-            cmd.ip,
-            cmd.user_agent,
-            cmd.device_id,
-            cmd.last_seen,
-        )
-
-    def stream_update(self, stream_name, token, data):
-        """Called when a new update is available to stream to clients.
-
-        We need to check if the client is interested in the stream or not
-        """
-        self.send_command(RdataCommand(stream_name, token, data))
-
-    def send_sync(self, data):
-        self.send_command(SyncCommand(data))
-
-    def send_remote_server_up(self, server: str):
-        self.send_command(RemoteServerUpCommand(server))
-
-    def on_connection_closed(self):
-        BaseReplicationStreamProtocol.on_connection_closed(self)
-        self.streamer.lost_connection(self)
-
-
-class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
-    """
-    The interface for the handler that should be passed to
-    ClientReplicationStreamProtocol
-    """
-
-    @abc.abstractmethod
-    async def on_rdata(self, stream_name, token, rows):
-        """Called to handle a batch of replication data with a given stream token.
-
-        Args:
-            stream_name (str): name of the replication stream for this batch of rows
-            token (int): stream token for this batch of rows
-            rows (list): a list of Stream.ROW_TYPE objects as returned by
-                Stream.parse_row.
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    async def on_position(self, stream_name, token):
-        """Called when we get new position data."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def on_sync(self, data):
-        """Called when get a new SYNC command."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    async def on_remote_server_up(self, server: str):
-        """Called when get a new REMOTE_SERVER_UP command."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_streams_to_replicate(self):
-        """Called when a new connection has been established and we need to
-        subscribe to streams.
-
-        Returns:
-            map from stream name to the most recent update we have for
-            that stream (ie, the point we want to start replicating from)
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_currently_syncing_users(self):
-        """Get the list of currently syncing users (if any). This is called
-        when a connection has been established and we need to send the
-        currently syncing users."""
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def update_connection(self, connection):
-        """Called when a connection has been established (or lost with None).
-        """
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def finished_connecting(self):
-        """Called when we have successfully subscribed and caught up to all
-        streams we're interested in.
-        """
-        raise NotImplementedError()
-
 
 
 class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
-    VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
-    VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
-
     def __init__(
     def __init__(
         self,
         self,
         hs: "HomeServer",
         hs: "HomeServer",
         client_name: str,
         client_name: str,
         server_name: str,
         server_name: str,
         clock: Clock,
         clock: Clock,
-        handler: AbstractReplicationClientHandler,
+        handler,
     ):
     ):
-        BaseReplicationStreamProtocol.__init__(self, clock)
+        BaseReplicationStreamProtocol.__init__(self, clock, handler)
 
 
         self.instance_id = hs.get_instance_id()
         self.instance_id = hs.get_instance_id()
 
 
         self.client_name = client_name
         self.client_name = client_name
         self.server_name = server_name
         self.server_name = server_name
-        self.handler = handler
 
 
         self.streams = {
         self.streams = {
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
             stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
@@ -574,105 +451,16 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.pending_batches = {}  # type: Dict[str, List[Any]]
         self.pending_batches = {}  # type: Dict[str, List[Any]]
 
 
     def connectionMade(self):
     def connectionMade(self):
-        self.send_command(NameCommand(self.client_name))
         BaseReplicationStreamProtocol.connectionMade(self)
         BaseReplicationStreamProtocol.connectionMade(self)
 
 
-        # Once we've connected subscribe to the necessary streams
+        self.send_command(NameCommand(self.client_name))
         self.replicate()
         self.replicate()
 
 
-        # Tell the server if we have any users currently syncing (should only
-        # happen on synchrotrons)
-        currently_syncing = self.handler.get_currently_syncing_users()
-        now = self.clock.time_msec()
-        for user_id in currently_syncing:
-            self.send_command(UserSyncCommand(self.instance_id, user_id, True, now))
-
-        # We've now finished connecting to so inform the client handler
-        self.handler.update_connection(self)
-
     async def on_SERVER(self, cmd):
     async def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")
             self.send_error("Wrong remote")
 
 
-    async def on_RDATA(self, cmd):
-        stream_name = cmd.stream_name
-        inbound_rdata_count.labels(stream_name).inc()
-
-        try:
-            row = STREAMS_MAP[stream_name].parse_row(cmd.row)
-        except Exception:
-            logger.exception(
-                "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
-            )
-            raise
-
-        if cmd.token is None or stream_name in self.streams_connecting:
-            # I.e. this is part of a batch of updates for this stream. Batch
-            # until we get an update for the stream with a non None token
-            self.pending_batches.setdefault(stream_name, []).append(row)
-        else:
-            # Check if this is the last of a batch of updates
-            rows = self.pending_batches.pop(stream_name, [])
-            rows.append(row)
-            await self.handler.on_rdata(stream_name, cmd.token, rows)
-
-    async def on_POSITION(self, cmd: PositionCommand):
-        stream = self.streams.get(cmd.stream_name)
-        if not stream:
-            logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
-            return
-
-        # Find where we previously streamed up to.
-        current_token = self.handler.get_streams_to_replicate().get(cmd.stream_name)
-        if current_token is None:
-            logger.warning(
-                "Got POSITION for stream we're not subscribed to: %s", cmd.stream_name
-            )
-            return
-
-        # Fetch all updates between then and now.
-        limited = True
-        while limited:
-            updates, current_token, limited = await stream.get_updates_since(
-                current_token, cmd.token
-            )
-
-            # Check if the connection was closed underneath us, if so we bail
-            # rather than risk having concurrent catch ups going on.
-            if self.state == ConnectionStates.CLOSED:
-                return
-
-            if updates:
-                await self.handler.on_rdata(
-                    cmd.stream_name,
-                    current_token,
-                    [stream.parse_row(update[1]) for update in updates],
-                )
-
-        # We've now caught up to position sent to us, notify handler.
-        await self.handler.on_position(cmd.stream_name, cmd.token)
-
-        self.streams_connecting.discard(cmd.stream_name)
-        if not self.streams_connecting:
-            self.handler.finished_connecting()
-
-        # Check if the connection was closed underneath us, if so we bail
-        # rather than risk having concurrent catch ups going on.
-        if self.state == ConnectionStates.CLOSED:
-            return
-
-        # Handle any RDATA that came in while we were catching up.
-        rows = self.pending_batches.pop(cmd.stream_name, [])
-        if rows:
-            await self.handler.on_rdata(cmd.stream_name, rows[-1].token, rows)
-
-    async def on_SYNC(self, cmd):
-        self.handler.on_sync(cmd.data)
-
-    async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
-        self.handler.on_remote_server_up(cmd.data)
-
     def replicate(self):
     def replicate(self):
         """Send the subscription request to the server
         """Send the subscription request to the server
         """
         """
@@ -680,10 +468,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
 
 
         self.send_command(ReplicateCommand())
         self.send_command(ReplicateCommand())
 
 
-    def on_connection_closed(self):
-        BaseReplicationStreamProtocol.on_connection_closed(self)
-        self.handler.update_connection(None)
-
 
 
 # The following simply registers metrics for the replication connections
 # The following simply registers metrics for the replication connections
 
 

+ 17 - 149
synapse/replication/tcp/resource.py

@@ -17,7 +17,7 @@
 
 
 import logging
 import logging
 import random
 import random
-from typing import Any, Dict, List
+from typing import Dict
 
 
 from six import itervalues
 from six import itervalues
 
 
@@ -25,9 +25,8 @@ from prometheus_client import Counter
 
 
 from twisted.internet.protocol import Factory
 from twisted.internet.protocol import Factory
 
 
-from synapse.metrics import LaterGauge
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.metrics import Measure, measure_func
+from synapse.util.metrics import Measure
 
 
 from .protocol import ServerReplicationStreamProtocol
 from .protocol import ServerReplicationStreamProtocol
 from .streams import STREAMS_MAP, Stream
 from .streams import STREAMS_MAP, Stream
@@ -36,13 +35,6 @@ from .streams.federation import FederationStream
 stream_updates_counter = Counter(
 stream_updates_counter = Counter(
     "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
     "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
 )
 )
-user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
-federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
-remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter(
-    "synapse_replication_tcp_resource_invalidate_cache", ""
-)
-user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -52,13 +44,18 @@ class ReplicationStreamProtocolFactory(Factory):
     """
     """
 
 
     def __init__(self, hs):
     def __init__(self, hs):
-        self.streamer = hs.get_replication_streamer()
+        self.handler = hs.get_tcp_replication()
         self.clock = hs.get_clock()
         self.clock = hs.get_clock()
         self.server_name = hs.config.server_name
         self.server_name = hs.config.server_name
+        self.hs = hs
+
+        # Ensure the replication streamer is started if we register a
+        # replication server endpoint.
+        hs.get_replication_streamer()
 
 
     def buildProtocol(self, addr):
     def buildProtocol(self, addr):
         return ServerReplicationStreamProtocol(
         return ServerReplicationStreamProtocol(
-            self.server_name, self.clock, self.streamer
+            self.hs, self.server_name, self.clock, self.handler
         )
         )
 
 
 
 
@@ -78,16 +75,6 @@ class ReplicationStreamer(object):
 
 
         self._replication_torture_level = hs.config.replication_torture_level
         self._replication_torture_level = hs.config.replication_torture_level
 
 
-        # Current connections.
-        self.connections = []  # type: List[ServerReplicationStreamProtocol]
-
-        LaterGauge(
-            "synapse_replication_tcp_resource_total_connections",
-            "",
-            [],
-            lambda: len(self.connections),
-        )
-
         # List of streams that clients can subscribe to.
         # List of streams that clients can subscribe to.
         # We only support federation stream if federation sending hase been
         # We only support federation stream if federation sending hase been
         # disabled on the master.
         # disabled on the master.
@@ -99,39 +86,17 @@ class ReplicationStreamer(object):
 
 
         self.streams_by_name = {stream.NAME: stream for stream in self.streams}
         self.streams_by_name = {stream.NAME: stream for stream in self.streams}
 
 
-        LaterGauge(
-            "synapse_replication_tcp_resource_connections_per_stream",
-            "",
-            ["stream_name"],
-            lambda: {
-                (stream_name,): len(
-                    [
-                        conn
-                        for conn in self.connections
-                        if stream_name in conn.replication_streams
-                    ]
-                )
-                for stream_name in self.streams_by_name
-            },
-        )
-
         self.federation_sender = None
         self.federation_sender = None
         if not hs.config.send_federation:
         if not hs.config.send_federation:
             self.federation_sender = hs.get_federation_sender()
             self.federation_sender = hs.get_federation_sender()
 
 
         self.notifier.add_replication_callback(self.on_notifier_poke)
         self.notifier.add_replication_callback(self.on_notifier_poke)
-        self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
 
 
         # Keeps track of whether we are currently checking for updates
         # Keeps track of whether we are currently checking for updates
         self.is_looping = False
         self.is_looping = False
         self.pending_updates = False
         self.pending_updates = False
 
 
-        hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.on_shutdown)
-
-    def on_shutdown(self):
-        # close all connections on shutdown
-        for conn in self.connections:
-            conn.send_error("server shutting down")
+        self.client = hs.get_tcp_replication()
 
 
     def get_streams(self) -> Dict[str, Stream]:
     def get_streams(self) -> Dict[str, Stream]:
         """Get a mapp from stream name to stream instance.
         """Get a mapp from stream name to stream instance.
@@ -145,7 +110,7 @@ class ReplicationStreamer(object):
         This should get called each time new data is available, even if it
         This should get called each time new data is available, even if it
         is currently being executed, so that nothing gets missed
         is currently being executed, so that nothing gets missed
         """
         """
-        if not self.connections:
+        if not self.client.connected():
             # Don't bother if nothing is listening. We still need to advance
             # Don't bother if nothing is listening. We still need to advance
             # the stream tokens otherwise they'll fall beihind forever
             # the stream tokens otherwise they'll fall beihind forever
             for stream in self.streams:
             for stream in self.streams:
@@ -202,9 +167,7 @@ class ReplicationStreamer(object):
                             raise
                             raise
 
 
                         logger.debug(
                         logger.debug(
-                            "Sending %d updates to %d connections",
-                            len(updates),
-                            len(self.connections),
+                            "Sending %d updates", len(updates),
                         )
                         )
 
 
                         if updates:
                         if updates:
@@ -220,112 +183,17 @@ class ReplicationStreamer(object):
                         # token. See RdataCommand for more details.
                         # token. See RdataCommand for more details.
                         batched_updates = _batch_updates(updates)
                         batched_updates = _batch_updates(updates)
 
 
-                        for conn in self.connections:
-                            for token, row in batched_updates:
-                                try:
-                                    conn.stream_update(stream.NAME, token, row)
-                                except Exception:
-                                    logger.exception("Failed to replicate")
+                        for token, row in batched_updates:
+                            try:
+                                self.client.stream_update(stream.NAME, token, row)
+                            except Exception:
+                                logger.exception("Failed to replicate")
 
 
             logger.debug("No more pending updates, breaking poke loop")
             logger.debug("No more pending updates, breaking poke loop")
         finally:
         finally:
             self.pending_updates = False
             self.pending_updates = False
             self.is_looping = False
             self.is_looping = False
 
 
-    def get_stream_token(self, stream_name):
-        """For a given stream get all updates since token. This is called when
-        a client first subscribes to a stream.
-        """
-        stream = self.streams_by_name.get(stream_name, None)
-        if not stream:
-            raise Exception("unknown stream %s", stream_name)
-
-        return stream.current_token()
-
-    @measure_func("repl.federation_ack")
-    def federation_ack(self, token):
-        """We've received an ack for federation stream from a client.
-        """
-        federation_ack_counter.inc()
-        if self.federation_sender:
-            self.federation_sender.federation_ack(token)
-
-    @measure_func("repl.on_user_sync")
-    async def on_user_sync(self, instance_id, user_id, is_syncing, last_sync_ms):
-        """A client has started/stopped syncing on a worker.
-        """
-        user_sync_counter.inc()
-        await self.presence_handler.update_external_syncs_row(
-            instance_id, user_id, is_syncing, last_sync_ms
-        )
-
-    async def on_clear_user_syncs(self, instance_id):
-        """A replication client wants us to drop all their UserSync data.
-        """
-        await self.presence_handler.update_external_syncs_clear(instance_id)
-
-    @measure_func("repl.on_remove_pusher")
-    async def on_remove_pusher(self, app_id, push_key, user_id):
-        """A client has asked us to remove a pusher
-        """
-        remove_pusher_counter.inc()
-        await self.store.delete_pusher_by_app_id_pushkey_user_id(
-            app_id=app_id, pushkey=push_key, user_id=user_id
-        )
-
-        self.notifier.on_new_replication_data()
-
-    @measure_func("repl.on_invalidate_cache")
-    async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
-        """The client has asked us to invalidate a cache
-        """
-        invalidate_cache_counter.inc()
-
-        # We invalidate the cache locally, but then also stream that to other
-        # workers.
-        await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
-
-    @measure_func("repl.on_user_ip")
-    async def on_user_ip(
-        self, user_id, access_token, ip, user_agent, device_id, last_seen
-    ):
-        """The client saw a user request
-        """
-        user_ip_cache_counter.inc()
-        await self.store.insert_client_ip(
-            user_id, access_token, ip, user_agent, device_id, last_seen
-        )
-        await self._server_notices_sender.on_user_ip(user_id)
-
-    @measure_func("repl.on_remote_server_up")
-    def on_remote_server_up(self, server: str):
-        self.notifier.notify_remote_server_up(server)
-
-    def send_remote_server_up(self, server: str):
-        for conn in self.connections:
-            conn.send_remote_server_up(server)
-
-    def send_sync_to_all_connections(self, data):
-        """Sends a SYNC command to all clients.
-
-        Used in tests.
-        """
-        for conn in self.connections:
-            conn.send_sync(data)
-
-    def new_connection(self, connection):
-        """A new client connection has been established
-        """
-        self.connections.append(connection)
-
-    def lost_connection(self, connection):
-        """A client connection has been lost
-        """
-        try:
-            self.connections.remove(connection)
-        except ValueError:
-            pass
-
 
 
 def _batch_updates(updates):
 def _batch_updates(updates):
     """Takes a list of updates of form [(token, row)] and sets the token to
     """Takes a list of updates of form [(token, row)] and sets the token to

+ 9 - 1
synapse/server.py

@@ -87,6 +87,10 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.notifier import Notifier
 from synapse.notifier import Notifier
 from synapse.push.action_generator import ActionGenerator
 from synapse.push.action_generator import ActionGenerator
 from synapse.push.pusherpool import PusherPool
 from synapse.push.pusherpool import PusherPool
+from synapse.replication.tcp.handler import (
+    DummyReplicationDataHandler,
+    ReplicationClientHandler,
+)
 from synapse.replication.tcp.resource import ReplicationStreamer
 from synapse.replication.tcp.resource import ReplicationStreamer
 from synapse.rest.media.v1.media_repository import (
 from synapse.rest.media.v1.media_repository import (
     MediaRepository,
     MediaRepository,
@@ -206,6 +210,7 @@ class HomeServer(object):
         "password_policy_handler",
         "password_policy_handler",
         "storage",
         "storage",
         "replication_streamer",
         "replication_streamer",
+        "replication_data_handler",
     ]
     ]
 
 
     REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
     REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -468,7 +473,7 @@ class HomeServer(object):
         return ReadMarkerHandler(self)
         return ReadMarkerHandler(self)
 
 
     def build_tcp_replication(self):
     def build_tcp_replication(self):
-        raise NotImplementedError()
+        return ReplicationClientHandler(self)
 
 
     def build_action_generator(self):
     def build_action_generator(self):
         return ActionGenerator(self)
         return ActionGenerator(self)
@@ -562,6 +567,9 @@ class HomeServer(object):
     def build_replication_streamer(self) -> ReplicationStreamer:
     def build_replication_streamer(self) -> ReplicationStreamer:
         return ReplicationStreamer(self)
         return ReplicationStreamer(self)
 
 
+    def build_replication_data_handler(self):
+        return DummyReplicationDataHandler()
+
     def remove_pusher(self, app_id, push_key, user_id):
     def remove_pusher(self, app_id, push_key, user_id):
         return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
         return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
 
 

+ 1 - 1
synapse/server.pyi

@@ -106,7 +106,7 @@ class HomeServer(object):
         pass
         pass
     def get_tcp_replication(
     def get_tcp_replication(
         self,
         self,
-    ) -> synapse.replication.tcp.client.ReplicationClientHandler:
+    ) -> synapse.replication.tcp.handler.ReplicationClientHandler:
         pass
         pass
     def get_federation_registry(
     def get_federation_registry(
         self,
         self,

+ 14 - 10
tests/replication/slave/storage/_base.py

@@ -15,9 +15,10 @@
 
 
 from mock import Mock, NonCallableMock
 from mock import Mock, NonCallableMock
 
 
-from synapse.replication.tcp.client import (
-    ReplicationClientFactory,
+from synapse.replication.tcp.client import ReplicationClientFactory
+from synapse.replication.tcp.handler import (
     ReplicationClientHandler,
     ReplicationClientHandler,
+    WorkerReplicationDataHandler,
 )
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.storage.database import make_conn
 from synapse.storage.database import make_conn
@@ -51,16 +52,19 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
         self.event_id = 0
         self.event_id = 0
 
 
         server_factory = ReplicationStreamProtocolFactory(self.hs)
         server_factory = ReplicationStreamProtocolFactory(self.hs)
-        self.streamer = server_factory.streamer
-
-        handler_factory = Mock()
-        self.replication_handler = ReplicationClientHandler(self.slaved_store)
-        self.replication_handler.factory = handler_factory
-
-        client_factory = ReplicationClientFactory(
-            self.hs, "client_name", self.replication_handler
+        self.streamer = hs.get_replication_streamer()
+
+        # We now do some gut wrenching so that we have a client that is based
+        # off of the slave store rather than the main store.
+        self.replication_handler = ReplicationClientHandler(self.hs)
+        self.replication_handler.store = self.slaved_store
+        self.replication_handler.replication_data_handler = WorkerReplicationDataHandler(
+            self.slaved_store
         )
         )
 
 
+        client_factory = ReplicationClientFactory(self.hs, "client_name")
+        client_factory.handler = self.replication_handler
+
         server = server_factory.buildProtocol(None)
         server = server_factory.buildProtocol(None)
         client = client_factory.buildProtocol(None)
         client = client_factory.buildProtocol(None)
 
 

+ 14 - 24
tests/replication/tcp/streams/_base.py

@@ -15,7 +15,7 @@
 
 
 from mock import Mock
 from mock import Mock
 
 
-from synapse.replication.tcp.commands import ReplicateCommand
+from synapse.replication.tcp.handler import ReplicationClientHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 
 
@@ -26,15 +26,20 @@ from tests.server import FakeTransport
 class BaseStreamTestCase(unittest.HomeserverTestCase):
 class BaseStreamTestCase(unittest.HomeserverTestCase):
     """Base class for tests of the replication streams"""
     """Base class for tests of the replication streams"""
 
 
+    def make_homeserver(self, reactor, clock):
+        self.test_handler = Mock(wraps=TestReplicationClientHandler())
+        return self.setup_test_homeserver(replication_data_handler=self.test_handler)
+
     def prepare(self, reactor, clock, hs):
     def prepare(self, reactor, clock, hs):
         # build a replication server
         # build a replication server
-        server_factory = ReplicationStreamProtocolFactory(self.hs)
-        self.streamer = server_factory.streamer
+        server_factory = ReplicationStreamProtocolFactory(hs)
+        self.streamer = hs.get_replication_streamer()
         self.server = server_factory.buildProtocol(None)
         self.server = server_factory.buildProtocol(None)
 
 
-        self.test_handler = Mock(wraps=TestReplicationClientHandler())
+        repl_handler = ReplicationClientHandler(hs)
+        repl_handler.handler = self.test_handler
         self.client = ClientReplicationStreamProtocol(
         self.client = ClientReplicationStreamProtocol(
-            hs, "client", "test", clock, self.test_handler,
+            hs, "client", "test", clock, repl_handler,
         )
         )
 
 
         self._client_transport = None
         self._client_transport = None
@@ -69,14 +74,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.streamer.on_notifier_poke()
         self.streamer.on_notifier_poke()
         self.pump(0.1)
         self.pump(0.1)
 
 
-    def replicate_stream(self):
-        """Make the client end a REPLICATE command to set up a subscription to a stream"""
-        self.client.send_command(ReplicateCommand())
-
-
-class TestReplicationClientHandler(object):
-    """Drop-in for ReplicationClientHandler which just collects RDATA rows"""
 
 
+class TestReplicationClientHandler:
     def __init__(self):
     def __init__(self):
         self.streams = set()
         self.streams = set()
         self._received_rdata_rows = []
         self._received_rdata_rows = []
@@ -88,18 +87,9 @@ class TestReplicationClientHandler(object):
                 positions[stream] = max(token, positions.get(stream, 0))
                 positions[stream] = max(token, positions.get(stream, 0))
         return positions
         return positions
 
 
-    def get_currently_syncing_users(self):
-        return []
-
-    def update_connection(self, connection):
-        pass
-
-    def finished_connecting(self):
-        pass
-
-    async def on_position(self, stream_name, token):
-        """Called when we get new position data."""
-
     async def on_rdata(self, stream_name, token, rows):
     async def on_rdata(self, stream_name, token, rows):
         for r in rows:
         for r in rows:
             self._received_rdata_rows.append((stream_name, token, r))
             self._received_rdata_rows.append((stream_name, token, r))
+
+    async def on_position(self, stream_name, token):
+        pass

+ 0 - 1
tests/replication/tcp/streams/test_receipts.py

@@ -24,7 +24,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
         self.reconnect()
         self.reconnect()
 
 
         # make the client subscribe to the receipts stream
         # make the client subscribe to the receipts stream
-        self.replicate_stream()
         self.test_handler.streams.add("receipts")
         self.test_handler.streams.add("receipts")
 
 
         # tell the master to send a new receipt
         # tell the master to send a new receipt