Prechádzať zdrojové kódy

Async get event cache prep (#13242)

Some experimental prep work to enable external event caching based on #9379 & #12955. Doesn't actually move the cache at all, just lays the groundwork for async implemented caches.

Signed off by Nick @ Beeper (@Fizzadar)
Nick Mills-Barrett 1 rok pred
rodič
commit
cc21a431f3

+ 1 - 0
changelog.d/13242.misc

@@ -0,0 +1 @@
+Use an asynchronous cache wrapper for the get event cache. Contributed by Nick @ Beeper (@fizzadar).

+ 6 - 4
synapse/storage/database.py

@@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
-from synapse.util.async_helpers import delay_cancellation
+from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -818,12 +818,14 @@ class DatabasePool:
                     )
 
                 for after_callback, after_args, after_kwargs in after_callbacks:
-                    after_callback(*after_args, **after_kwargs)
+                    await maybe_awaitable(after_callback(*after_args, **after_kwargs))
 
                 return cast(R, result)
             except Exception:
-                for after_callback, after_args, after_kwargs in exception_callbacks:
-                    after_callback(*after_args, **after_kwargs)
+                for exception_callback, after_args, after_kwargs in exception_callbacks:
+                    await maybe_awaitable(
+                        exception_callback(*after_args, **after_kwargs)
+                    )
                 raise
 
         # To handle cancellation, we ensure that `after_callback`s and

+ 5 - 2
synapse/storage/databases/main/cache.py

@@ -193,7 +193,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
         relates_to: Optional[str],
         backfilled: bool,
     ) -> None:
-        self._invalidate_get_event_cache(event_id)
+        # This invalidates any local in-memory cached event objects, the original
+        # process triggering the invalidation is responsible for clearing any external
+        # cached objects.
+        self._invalidate_local_get_event_cache(event_id)
         self.have_seen_event.invalidate((room_id, event_id))
 
         self.get_latest_event_ids_in_room.invalidate((room_id,))
@@ -208,7 +211,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
 
         if redacts:
-            self._invalidate_get_event_cache(redacts)
+            self._invalidate_local_get_event_cache(redacts)
             # Caches which might leak edits must be invalidated for the event being
             # redacted.
             self.get_relations_for_event.invalidate((redacts,))

+ 2 - 2
synapse/storage/databases/main/events.py

@@ -1669,9 +1669,9 @@ class PersistEventsStore:
             if not row["rejects"] and not row["redacts"]:
                 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
-        def prefill() -> None:
+        async def prefill() -> None:
             for cache_entry in to_prefill:
-                self.store._get_event_cache.set(
+                await self.store._get_event_cache.set(
                     (cache_entry.event.event_id,), cache_entry
                 )
 

+ 24 - 10
synapse/storage/databases/main/events_worker.py

@@ -79,7 +79,7 @@ from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.lrucache import LruCache
+from synapse.util.caches.lrucache import AsyncLruCache
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -238,7 +238,9 @@ class EventsWorkerStore(SQLBaseStore):
                 5 * 60 * 1000,
             )
 
-        self._get_event_cache: LruCache[Tuple[str], EventCacheEntry] = LruCache(
+        self._get_event_cache: AsyncLruCache[
+            Tuple[str], EventCacheEntry
+        ] = AsyncLruCache(
             cache_name="*getEvent*",
             max_size=hs.config.caches.event_cache_size,
         )
@@ -598,7 +600,7 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             map from event id to result
         """
-        event_entry_map = self._get_events_from_cache(
+        event_entry_map = await self._get_events_from_cache(
             event_ids,
         )
 
@@ -710,12 +712,22 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_entry_map
 
-    def _invalidate_get_event_cache(self, event_id: str) -> None:
-        self._get_event_cache.invalidate((event_id,))
+    async def _invalidate_get_event_cache(self, event_id: str) -> None:
+        # First we invalidate the asynchronous cache instance. This may include
+        # out-of-process caches such as Redis/memcache. Once complete we can
+        # invalidate any in memory cache. The ordering is important here to
+        # ensure we don't pull in any remote invalid value after we invalidate
+        # the in-memory cache.
+        await self._get_event_cache.invalidate((event_id,))
         self._event_ref.pop(event_id, None)
         self._current_event_fetches.pop(event_id, None)
 
-    def _get_events_from_cache(
+    def _invalidate_local_get_event_cache(self, event_id: str) -> None:
+        self._get_event_cache.invalidate_local((event_id,))
+        self._event_ref.pop(event_id, None)
+        self._current_event_fetches.pop(event_id, None)
+
+    async def _get_events_from_cache(
         self, events: Iterable[str], update_metrics: bool = True
     ) -> Dict[str, EventCacheEntry]:
         """Fetch events from the caches.
@@ -730,7 +742,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         for event_id in events:
             # First check if it's in the event cache
-            ret = self._get_event_cache.get(
+            ret = await self._get_event_cache.get(
                 (event_id,), None, update_metrics=update_metrics
             )
             if ret:
@@ -752,7 +764,7 @@ class EventsWorkerStore(SQLBaseStore):
 
                 # We add the entry back into the cache as we want to keep
                 # recently queried events in the cache.
-                self._get_event_cache.set((event_id,), cache_entry)
+                await self._get_event_cache.set((event_id,), cache_entry)
 
         return event_map
 
@@ -1129,7 +1141,7 @@ class EventsWorkerStore(SQLBaseStore):
                 event=original_ev, redacted_event=redacted_event
             )
 
-            self._get_event_cache.set((event_id,), cache_entry)
+            await self._get_event_cache.set((event_id,), cache_entry)
             result_map[event_id] = cache_entry
 
             if not redacted_event:
@@ -1363,7 +1375,9 @@ class EventsWorkerStore(SQLBaseStore):
         # if the event cache contains the event, obviously we've seen it.
 
         cache_results = {
-            (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,))
+            (rid, eid)
+            for (rid, eid) in keys
+            if await self._get_event_cache.contains((eid,))
         }
         results = dict.fromkeys(cache_results, True)
         remaining = [k for k in keys if k not in cache_results]

+ 1 - 1
synapse/storage/databases/main/purge_events.py

@@ -302,7 +302,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
                 self._invalidate_cache_and_stream(
                     txn, self.have_seen_event, (room_id, event_id)
                 )
-                self._invalidate_get_event_cache(event_id)
+                txn.call_after(self._invalidate_get_event_cache, event_id)
 
         logger.info("[purge] done")
 

+ 3 - 1
synapse/storage/databases/main/roommember.py

@@ -843,7 +843,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # We don't update the event cache hit ratio as it completely throws off
         # the hit ratio counts. After all, we don't populate the cache if we
         # miss it here
-        event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)
+        event_map = await self._get_events_from_cache(
+            member_event_ids, update_metrics=False
+        )
 
         missing_member_event_ids = []
         for event_id in member_event_ids:

+ 38 - 0
synapse/util/caches/lrucache.py

@@ -730,3 +730,41 @@ class LruCache(Generic[KT, VT]):
         # This happens e.g. in the sync code where we have an expiring cache of
         # lru caches.
         self.clear()
+
+
+class AsyncLruCache(Generic[KT, VT]):
+    """
+    An asynchronous wrapper around a subset of the LruCache API.
+
+    On its own this doesn't change the behaviour but allows subclasses that
+    utilize external cache systems that require await behaviour to be created.
+    """
+
+    def __init__(self, *args, **kwargs):  # type: ignore
+        self._lru_cache: LruCache[KT, VT] = LruCache(*args, **kwargs)
+
+    async def get(
+        self, key: KT, default: Optional[T] = None, update_metrics: bool = True
+    ) -> Optional[VT]:
+        return self._lru_cache.get(key, update_metrics=update_metrics)
+
+    async def set(self, key: KT, value: VT) -> None:
+        self._lru_cache.set(key, value)
+
+    async def invalidate(self, key: KT) -> None:
+        # This method should invalidate any external cache and then invalidate the LruCache.
+        return self._lru_cache.invalidate(key)
+
+    def invalidate_local(self, key: KT) -> None:
+        """Remove an entry from the local cache
+
+        This variant of `invalidate` is useful if we know that the external
+        cache has already been invalidated.
+        """
+        return self._lru_cache.invalidate(key)
+
+    async def contains(self, key: KT) -> bool:
+        return self._lru_cache.contains(key)
+
+    async def clear(self) -> None:
+        self._lru_cache.clear()

+ 1 - 1
tests/handlers/test_sync.py

@@ -159,7 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
 
         # Blow away caches (supported room versions can only change due to a restart).
         self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
-        self.store._get_event_cache.clear()
+        self.get_success(self.store._get_event_cache.clear())
         self.store._event_ref.clear()
 
         # The rooms should be excluded from the sync response.

+ 4 - 4
tests/storage/databases/main/test_events_worker.py

@@ -143,7 +143,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
         self.event_id = res["event_id"]
 
         # Reset the event cache so the tests start with it empty
-        self.store._get_event_cache.clear()
+        self.get_success(self.store._get_event_cache.clear())
 
     def test_simple(self):
         """Test that we cache events that we pull from the DB."""
@@ -160,7 +160,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
         """
 
         # Reset the event cache
-        self.store._get_event_cache.clear()
+        self.get_success(self.store._get_event_cache.clear())
 
         with LoggingContext("test") as ctx:
             # We keep hold of the event event though we never use it.
@@ -170,7 +170,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
             self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
 
         # Reset the event cache
-        self.store._get_event_cache.clear()
+        self.get_success(self.store._get_event_cache.clear())
 
         with LoggingContext("test") as ctx:
             self.get_success(self.store.get_event(self.event_id))
@@ -345,7 +345,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase):
         self.event_id = res["event_id"]
 
         # Reset the event cache so the tests start with it empty
-        self.store._get_event_cache.clear()
+        self.get_success(self.store._get_event_cache.clear())
 
     @contextmanager
     def blocking_get_event_calls(

+ 1 - 1
tests/storage/test_purge.py

@@ -115,6 +115,6 @@ class PurgeTests(HomeserverTestCase):
         )
 
         # The events aren't found.
-        self.store._invalidate_get_event_cache(create_event.event_id)
+        self.store._invalidate_local_get_event_cache(create_event.event_id)
         self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
         self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)