Browse Source

Prefill more stream change caches. (#12372)

Erik Johnston 2 years ago
parent
commit
66053b6bfb

+ 1 - 0
changelog.d/12372.feature

@@ -0,0 +1 @@
+Reduce overhead of restarting synchrotrons.

+ 2 - 23
synapse/replication/slave/storage/devices.py

@@ -20,7 +20,6 @@ from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatu
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.devices import DeviceWorkerStore
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -33,8 +32,6 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
     ):
-        super().__init__(database, db_conn, hs)
-
         self.hs = hs
 
         self._device_list_id_gen = SlavedIdTracker(
@@ -47,26 +44,8 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
                 ("device_lists_changes_in_room", "stream_id"),
             ],
         )
-        device_list_max = self._device_list_id_gen.get_current_token()
-        device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_lists_stream",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=device_list_max,
-            limit=1000,
-        )
-        self._device_list_stream_cache = StreamChangeCache(
-            "DeviceListStreamChangeCache",
-            min_device_list_id,
-            prefilled_cache=device_list_prefill,
-        )
-        self._user_signature_stream_cache = StreamChangeCache(
-            "UserSignatureStreamChangeCache", device_list_max
-        )
-        self._device_list_federation_stream_cache = StreamChangeCache(
-            "DeviceListFederationStreamChangeCache", device_list_max
-        )
+
+        super().__init__(database, db_conn, hs)
 
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()

+ 27 - 16
synapse/storage/database.py

@@ -2030,29 +2030,40 @@ class DatabasePool:
         max_value: int,
         limit: int = 100000,
     ) -> Tuple[Dict[Any, int], int]:
-        # Fetch a mapping of room_id -> max stream position for "recent" rooms.
-        # It doesn't really matter how many we get, the StreamChangeCache will
-        # do the right thing to ensure it respects the max size of cache.
-        sql = (
-            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
-            " WHERE %(stream)s > ? - %(limit)s"
-            " GROUP BY %(entity)s"
-        ) % {
-            "table": table,
-            "entity": entity_column,
-            "stream": stream_column,
-            "limit": limit,
-        }
+        """Gets roughly the last N changes in the given stream table as a
+        map from entity to the stream ID of the most recent change.
+
+        Also returns the minimum stream ID.
+        """
+
+        # This may return many rows for the same entity, but the `limit` is only
+        # a suggestion so we don't care that much.
+        #
+        # Note: Some stream tables can have multiple rows with the same stream
+        # ID. Instead of handling this with complicated SQL, we instead simply
+        # add one to the returned minimum stream ID to ensure correctness.
+        sql = f"""
+            SELECT {entity_column}, {stream_column}
+            FROM {table}
+            ORDER BY {stream_column} DESC
+            LIMIT ?
+        """
 
         txn = db_conn.cursor(txn_name="get_cache_dict")
-        txn.execute(sql, (int(max_value),))
+        txn.execute(sql, (limit,))
 
-        cache = {row[0]: int(row[1]) for row in txn}
+        # The rows come out in reverse stream ID order, so we want to keep the
+        # stream ID of the first row for each entity.
+        cache: Dict[Any, int] = {}
+        for row in txn:
+            cache.setdefault(row[0], int(row[1]))
 
         txn.close()
 
         if cache:
-            min_val = min(cache.values())
+            # We add one here as we don't know if we have all rows for the
+            # minimum stream ID.
+            min_val = min(cache.values()) + 1
         else:
             min_val = max_value
 

+ 0 - 21
synapse/storage/databases/main/__init__.py

@@ -183,27 +183,6 @@ class DataStore(
 
         super().__init__(database, db_conn, hs)
 
-        device_list_max = self._device_list_id_gen.get_current_token()
-        device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_lists_stream",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=device_list_max,
-            limit=1000,
-        )
-        self._device_list_stream_cache = StreamChangeCache(
-            "DeviceListStreamChangeCache",
-            min_device_list_id,
-            prefilled_cache=device_list_prefill,
-        )
-        self._user_signature_stream_cache = StreamChangeCache(
-            "UserSignatureStreamChangeCache", device_list_max
-        )
-        self._device_list_federation_stream_cache = StreamChangeCache(
-            "DeviceListFederationStreamChangeCache", device_list_max
-        )
-
         events_max = self._stream_id_gen.get_current_token()
         curr_state_delta_prefill, min_curr_state_delta_id = self.db_pool.get_cache_dict(
             db_conn,

+ 50 - 0
synapse/storage/databases/main/devices.py

@@ -46,6 +46,7 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
 
@@ -71,6 +72,55 @@ class DeviceWorkerStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
+        device_list_max = self._device_list_id_gen.get_current_token()
+        device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_lists_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=device_list_max,
+            limit=10000,
+        )
+        self._device_list_stream_cache = StreamChangeCache(
+            "DeviceListStreamChangeCache",
+            min_device_list_id,
+            prefilled_cache=device_list_prefill,
+        )
+
+        (
+            user_signature_stream_prefill,
+            user_signature_stream_list_id,
+        ) = self.db_pool.get_cache_dict(
+            db_conn,
+            "user_signature_stream",
+            entity_column="from_user_id",
+            stream_column="stream_id",
+            max_value=device_list_max,
+            limit=1000,
+        )
+        self._user_signature_stream_cache = StreamChangeCache(
+            "UserSignatureStreamChangeCache",
+            user_signature_stream_list_id,
+            prefilled_cache=user_signature_stream_prefill,
+        )
+
+        (
+            device_list_federation_prefill,
+            device_list_federation_list_id,
+        ) = self.db_pool.get_cache_dict(
+            db_conn,
+            "device_lists_outbound_pokes",
+            entity_column="destination",
+            stream_column="stream_id",
+            max_value=device_list_max,
+            limit=10000,
+        )
+        self._device_list_federation_stream_cache = StreamChangeCache(
+            "DeviceListFederationStreamChangeCache",
+            device_list_federation_list_id,
+            prefilled_cache=device_list_federation_prefill,
+        )
+
         if hs.config.worker.run_background_tasks:
             self._clock.looping_call(
                 self._prune_old_outbound_device_pokes, 60 * 60 * 1000

+ 12 - 1
synapse/storage/databases/main/receipts.py

@@ -98,8 +98,19 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         super().__init__(database, db_conn, hs)
 
+        max_receipts_stream_id = self.get_max_receipt_stream_id()
+        receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
+            db_conn,
+            "receipts_linearized",
+            entity_column="room_id",
+            stream_column="stream_id",
+            max_value=max_receipts_stream_id,
+            limit=10000,
+        )
         self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
+            "ReceiptsRoomChangeCache",
+            min_receipts_stream_id,
+            prefilled_cache=receipts_stream_prefill,
         )
 
     def get_max_receipt_stream_id(self) -> int: