Browse Source

Track ongoing event fetches correctly (again) (#11376)

The previous fix for the ongoing event fetches counter
(8eec25a1d9d656905db18a2c62a5552e63db2667) was both insufficient and
incorrect.

When the database is unreachable, `_do_fetch` never gets run and so
`_event_fetch_ongoing` is never decremented.

The previous fix also moved the `_event_fetch_ongoing` decrement outside
of the `_event_fetch_lock` which allowed race conditions to corrupt the
counter.
Sean Quah 2 years ago
parent
commit
c675a18071

+ 1 - 0
changelog.d/11376.bugfix

@@ -0,0 +1 @@
+Fix a long-standing bug where all requests that read events from the database could get stuck as a result of losing the database connection, for real this time. Also fix a race condition introduced in the previous insufficient fix in 1.47.0.

+ 112 - 42
synapse/storage/databases/main/events_worker.py

@@ -72,7 +72,7 @@ from synapse.util.metrics import Measure
 logger = logging.getLogger(__name__)
 
 
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# These values are used in the `enqueue_event` and `_fetch_loop` methods to
 # control how we batch/bulk fetch events from the database.
 # The values are plucked out of thing air to make initial sync run faster
 # on jki.re
@@ -602,7 +602,7 @@ class EventsWorkerStore(SQLBaseStore):
             # already due to `_get_events_from_db`).
             fetching_deferred: ObservableDeferred[
                 Dict[str, _EventCacheEntry]
-            ] = ObservableDeferred(defer.Deferred())
+            ] = ObservableDeferred(defer.Deferred(), consumeErrors=True)
             for event_id in missing_events_ids:
                 self._current_event_fetches[event_id] = fetching_deferred
 
@@ -736,35 +736,118 @@ class EventsWorkerStore(SQLBaseStore):
             for e in state_to_include.values()
         ]
 
-    def _do_fetch(self, conn: Connection) -> None:
+    def _maybe_start_fetch_thread(self) -> None:
+        """Starts an event fetch thread if we are not yet at the maximum number."""
+        with self._event_fetch_lock:
+            if (
+                self._event_fetch_list
+                and self._event_fetch_ongoing < EVENT_QUEUE_THREADS
+            ):
+                self._event_fetch_ongoing += 1
+                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+                # `_event_fetch_ongoing` is decremented in `_fetch_thread`.
+                should_start = True
+            else:
+                should_start = False
+
+        if should_start:
+            run_as_background_process("fetch_events", self._fetch_thread)
+
+    async def _fetch_thread(self) -> None:
+        """Services requests for events from `_event_fetch_list`."""
+        exc = None
+        try:
+            await self.db_pool.runWithConnection(self._fetch_loop)
+        except BaseException as e:
+            exc = e
+            raise
+        finally:
+            should_restart = False
+            event_fetches_to_fail = []
+            with self._event_fetch_lock:
+                self._event_fetch_ongoing -= 1
+                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+
+                # There may still be work remaining in `_event_fetch_list` if we
+                # failed, or it was added in between us deciding to exit and
+                # decrementing `_event_fetch_ongoing`.
+                if self._event_fetch_list:
+                    if exc is None:
+                        # We decided to exit, but then some more work was added
+                        # before `_event_fetch_ongoing` was decremented.
+                        # If a new event fetch thread was not started, we should
+                        # restart ourselves since the remaining event fetch threads
+                        # may take a while to get around to the new work.
+                        #
+                        # Unfortunately it is not possible to tell whether a new
+                        # event fetch thread was started, so we restart
+                        # unconditionally. If we are unlucky, we will end up with
+                        # an idle fetch thread, but it will time out after
+                        # `EVENT_QUEUE_ITERATIONS * EVENT_QUEUE_TIMEOUT_S` seconds
+                        # in any case.
+                        #
+                        # Note that multiple fetch threads may run down this path at
+                        # the same time.
+                        should_restart = True
+                    elif isinstance(exc, Exception):
+                        if self._event_fetch_ongoing == 0:
+                            # We were the last remaining fetcher and failed.
+                            # Fail any outstanding fetches since no one else will
+                            # handle them.
+                            event_fetches_to_fail = self._event_fetch_list
+                            self._event_fetch_list = []
+                        else:
+                            # We weren't the last remaining fetcher, so another
+                            # fetcher will pick up the work. This will either happen
+                            # after their existing work, however long that takes,
+                            # or after at most `EVENT_QUEUE_TIMEOUT_S` seconds if
+                            # they are idle.
+                            pass
+                    else:
+                        # The exception is a `SystemExit`, `KeyboardInterrupt` or
+                        # `GeneratorExit`. Don't try to do anything clever here.
+                        pass
+
+            if should_restart:
+                # We exited cleanly but noticed more work.
+                self._maybe_start_fetch_thread()
+
+            if event_fetches_to_fail:
+                # We were the last remaining fetcher and failed.
+                # Fail any outstanding fetches since no one else will handle them.
+                assert exc is not None
+                with PreserveLoggingContext():
+                    for _, deferred in event_fetches_to_fail:
+                        deferred.errback(exc)
+
+    def _fetch_loop(self, conn: Connection) -> None:
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
         """
-        try:
-            i = 0
-            while True:
-                with self._event_fetch_lock:
-                    event_list = self._event_fetch_list
-                    self._event_fetch_list = []
-
-                    if not event_list:
-                        single_threaded = self.database_engine.single_threaded
-                        if (
-                            not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
-                            or single_threaded
-                            or i > EVENT_QUEUE_ITERATIONS
-                        ):
-                            break
-                        else:
-                            self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
-                            i += 1
-                            continue
-                    i = 0
+        i = 0
+        while True:
+            with self._event_fetch_lock:
+                event_list = self._event_fetch_list
+                self._event_fetch_list = []
+
+                if not event_list:
+                    # There are no requests waiting. If we haven't yet reached the
+                    # maximum iteration limit, wait for some more requests to turn up.
+                    # Otherwise, bail out.
+                    single_threaded = self.database_engine.single_threaded
+                    if (
+                        not self.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING
+                        or single_threaded
+                        or i > EVENT_QUEUE_ITERATIONS
+                    ):
+                        return
+
+                    self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+                    i += 1
+                    continue
+                i = 0
 
-                self._fetch_event_list(conn, event_list)
-        finally:
-            self._event_fetch_ongoing -= 1
-            event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
+            self._fetch_event_list(conn, event_list)
 
     def _fetch_event_list(
         self, conn: Connection, event_list: List[Tuple[List[str], defer.Deferred]]
@@ -806,9 +889,7 @@ class EventsWorkerStore(SQLBaseStore):
                 # We only want to resolve deferreds from the main thread
                 def fire(evs, exc):
                     for _, d in evs:
-                        if not d.called:
-                            with PreserveLoggingContext():
-                                d.errback(exc)
+                        d.errback(exc)
 
                 with PreserveLoggingContext():
                     self.hs.get_reactor().callFromThread(fire, event_list, e)
@@ -983,20 +1064,9 @@ class EventsWorkerStore(SQLBaseStore):
         events_d = defer.Deferred()
         with self._event_fetch_lock:
             self._event_fetch_list.append((events, events_d))
-
             self._event_fetch_lock.notify()
 
-            if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
-                self._event_fetch_ongoing += 1
-                event_fetch_ongoing_gauge.set(self._event_fetch_ongoing)
-                should_start = True
-            else:
-                should_start = False
-
-        if should_start:
-            run_as_background_process(
-                "fetch_events", self.db_pool.runWithConnection, self._do_fetch
-            )
+        self._maybe_start_fetch_thread()
 
         logger.debug("Loading %d events: %s", len(events), events)
         with PreserveLoggingContext():

+ 138 - 1
tests/storage/databases/main/test_events_worker.py

@@ -12,11 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+from contextlib import contextmanager
+from typing import Generator
 
+from twisted.enterprise.adbapi import ConnectionPool
+from twisted.internet.defer import ensureDeferred
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.room_versions import EventFormatVersions, RoomVersions
 from synapse.logging.context import LoggingContext
 from synapse.rest import admin
 from synapse.rest.client import login, room
-from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.server import HomeServer
+from synapse.storage.databases.main.events_worker import (
+    EVENT_QUEUE_THREADS,
+    EventsWorkerStore,
+)
+from synapse.storage.types import Connection
+from synapse.util import Clock
 from synapse.util.async_helpers import yieldable_gather_results
 
 from tests import unittest
@@ -144,3 +157,127 @@ class EventCacheTestCase(unittest.HomeserverTestCase):
 
             # We should have fetched the event from the DB
             self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)
+
+
+class DatabaseOutageTestCase(unittest.HomeserverTestCase):
+    """Test event fetching during a database outage."""
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+        self.store: EventsWorkerStore = hs.get_datastore()
+
+        self.room_id = f"!room:{hs.hostname}"
+        self.event_ids = [f"event{i}" for i in range(20)]
+
+        self._populate_events()
+
+    def _populate_events(self) -> None:
+        """Ensure that there are test events in the database.
+
+        When testing with the in-memory SQLite database, all the events are lost during
+        the simulated outage.
+
+        To ensure consistency between `room_id`s and `event_id`s before and after the
+        outage, rows are built and inserted manually.
+
+        Upserts are used to handle the non-SQLite case where events are not lost.
+        """
+        self.get_success(
+            self.store.db_pool.simple_upsert(
+                "rooms",
+                {"room_id": self.room_id},
+                {"room_version": RoomVersions.V4.identifier},
+            )
+        )
+
+        self.event_ids = [f"event{i}" for i in range(20)]
+        for idx, event_id in enumerate(self.event_ids):
+            self.get_success(
+                self.store.db_pool.simple_upsert(
+                    "events",
+                    {"event_id": event_id},
+                    {
+                        "event_id": event_id,
+                        "room_id": self.room_id,
+                        "topological_ordering": idx,
+                        "stream_ordering": idx,
+                        "type": "test",
+                        "processed": True,
+                        "outlier": False,
+                    },
+                )
+            )
+            self.get_success(
+                self.store.db_pool.simple_upsert(
+                    "event_json",
+                    {"event_id": event_id},
+                    {
+                        "room_id": self.room_id,
+                        "json": json.dumps({"type": "test", "room_id": self.room_id}),
+                        "internal_metadata": "{}",
+                        "format_version": EventFormatVersions.V3,
+                    },
+                )
+            )
+
+    @contextmanager
+    def _outage(self) -> Generator[None, None, None]:
+        """Simulate a database outage.
+
+        Returns:
+            A context manager. While the context is active, any attempts to connect to
+            the database will fail.
+        """
+        connection_pool = self.store.db_pool._db_pool
+
+        # Close all connections and shut down the database `ThreadPool`.
+        connection_pool.close()
+
+        # Restart the database `ThreadPool`.
+        connection_pool.start()
+
+        original_connection_factory = connection_pool.connectionFactory
+
+        def connection_factory(_pool: ConnectionPool) -> Connection:
+            raise Exception("Could not connect to the database.")
+
+        connection_pool.connectionFactory = connection_factory  # type: ignore[assignment]
+        try:
+            yield
+        finally:
+            connection_pool.connectionFactory = original_connection_factory
+
+            # If the in-memory SQLite database is being used, all the events are gone.
+            # Restore the test data.
+            self._populate_events()
+
+    def test_failure(self) -> None:
+        """Test that event fetches do not get stuck during a database outage."""
+        with self._outage():
+            failure = self.get_failure(
+                self.store.get_event(self.event_ids[0]), Exception
+            )
+            self.assertEqual(str(failure.value), "Could not connect to the database.")
+
+    def test_recovery(self) -> None:
+        """Test that event fetchers recover after a database outage."""
+        with self._outage():
+            # Kick off a bunch of event fetches but do not pump the reactor
+            event_deferreds = []
+            for event_id in self.event_ids:
+                event_deferreds.append(ensureDeferred(self.store.get_event(event_id)))
+
+            # We should have maxed out on event fetcher threads
+            self.assertEqual(self.store._event_fetch_ongoing, EVENT_QUEUE_THREADS)
+
+            # All the event fetchers will fail
+            self.pump()
+            self.assertEqual(self.store._event_fetch_ongoing, 0)
+
+            for event_deferred in event_deferreds:
+                failure = self.get_failure(event_deferred, Exception)
+                self.assertEqual(
+                    str(failure.value), "Could not connect to the database."
+                )
+
+        # This next event fetch should succeed
+        self.get_success(self.store.get_event(self.event_ids[0]))