Browse Source

Track in memory events using weakrefs (#10533)

Erik Johnston 2 years ago
parent
commit
fcf951d5dc

+ 1 - 0
changelog.d/10533.misc

@@ -0,0 +1 @@
+Improve event caching mechanism to avoid having multiple copies of an event in memory at a time.

+ 33 - 2
synapse/storage/databases/main/events_worker.py

@@ -14,6 +14,7 @@
 
 import logging
 import threading
+import weakref
 from enum import Enum, auto
 from typing import (
     TYPE_CHECKING,
@@ -23,6 +24,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    MutableMapping,
     Optional,
     Set,
     Tuple,
@@ -248,6 +250,12 @@ class EventsWorkerStore(SQLBaseStore):
             str, ObservableDeferred[Dict[str, EventCacheEntry]]
         ] = {}
 
+        # We keep track of the events we have currently loaded in memory so that
+        # we can reuse them even if they've been evicted from the cache. We only
+        # track events that don't need redacting in here (as then we don't need
+        # to track redaction status).
+        self._event_ref: MutableMapping[str, EventBase] = weakref.WeakValueDictionary()
+
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list: List[
             Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"]
@@ -723,6 +731,8 @@ class EventsWorkerStore(SQLBaseStore):
 
     def _invalidate_get_event_cache(self, event_id: str) -> None:
         self._get_event_cache.invalidate((event_id,))
+        self._event_ref.pop(event_id, None)
+        self._current_event_fetches.pop(event_id, None)
 
     def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
@@ -738,13 +748,30 @@ class EventsWorkerStore(SQLBaseStore):
         event_map = {}
 
         for event_id in events:
+            # First check if it's in the event cache
             ret = self._get_event_cache.get(
                 (event_id,), None, update_metrics=update_metrics
             )
-            if not ret:
+            if ret:
+                event_map[event_id] = ret
                 continue
 
-            event_map[event_id] = ret
+            # Otherwise check if we still have the event in memory.
+            event = self._event_ref.get(event_id)
+            if event:
+                # Reconstruct an event cache entry
+
+                cache_entry = EventCacheEntry(
+                    event=event,
+                    # We don't cache weakrefs to redacted events, so we know
+                    # this is None.
+                    redacted_event=None,
+                )
+                event_map[event_id] = cache_entry
+
+                # We add the entry back into the cache as we want to keep
+                # recently queried events in the cache.
+                self._get_event_cache.set((event_id,), cache_entry)
 
         return event_map
 
@@ -1124,6 +1151,10 @@ class EventsWorkerStore(SQLBaseStore):
             self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
+            if not redacted_event:
+                # We only cache references to unredacted events.
+                self._event_ref[event_id] = original_ev
+
         return result_map
 
     async def _enqueue_events(self, events: Collection[str]) -> Dict[str, _EventRow]:

+ 1 - 0
tests/handlers/test_sync.py

@@ -160,6 +160,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         # Blow away caches (supported room versions can only change due to a restart).
         self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
         self.store._get_event_cache.clear()
+        self.store._event_ref.clear()
 
         # The rooms should be excluded from the sync response.
         # Get a new request key.

+ 25 - 0
tests/storage/databases/main/test_events_worker.py

@@ -154,6 +154,31 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
             # We should have fetched the event from the DB
             self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
 
+    def test_event_ref(self):
+        """Test that we reuse events that are still in memory but have fallen
+        out of the cache, rather than requesting them from the DB.
+        """
+
+        # Reset the event cache
+        self.store._get_event_cache.clear()
+
+        with LoggingContext("test") as ctx:
+            # We keep hold of the event event though we never use it.
+            event = self.get_success(self.store.get_event(self.event_id))  # noqa: F841
+
+            # We should have fetched the event from the DB
+            self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+        # Reset the event cache
+        self.store._get_event_cache.clear()
+
+        with LoggingContext("test") as ctx:
+            self.get_success(self.store.get_event(self.event_id))
+
+            # Since the event is still in memory we shouldn't have fetched it
+            # from the DB
+            self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0)
+
     def test_dedupe(self):
         """Test that if we request the same event multiple times we only pull it
         out once.