Browse Source

Faster joins: Avoid starting duplicate partial state syncs (#14844)

Currently, we will try to start a new partial state sync every time we
perform a remote join, which is undesirable if there is already one
running for a given room.

We intend to perform remote joins whenever additional local users wish
to join a partial state room, so let's ensure that we do not start more
than one concurrent partial state sync for any given room.

------------------------------------------------------------------------

There is a race condition where the homeserver leaves a room and later
rejoins while the partial state sync from the previous membership is
still running. There is no guarantee that the previous partial state
sync will process the latest join, so we restart it if needed.

Signed-off-by: Sean Quah <seanq@matrix.org>
Sean Quah 1 year ago
parent
commit
cdea7c11d0
3 changed files with 210 additions and 9 deletions
  1. 1 0
      changelog.d/14844.misc
  2. 98 8
      synapse/handlers/federation.py
  3. 111 1
      tests/handlers/test_federation.py

+ 1 - 0
changelog.d/14844.misc

@@ -0,0 +1 @@
+Add check to avoid starting duplicate partial state syncs.

+ 98 - 8
synapse/handlers/federation.py

@@ -27,6 +27,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Set,
     Tuple,
     Union,
 )
@@ -171,12 +172,23 @@ class FederationHandler:
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
+        # Tracks running partial state syncs by room ID.
+        # Partial state syncs currently only run on the main process, so it's okay to
+        # track them in-memory for now.
+        self._active_partial_state_syncs: Set[str] = set()
+        # Tracks partial state syncs we may want to restart.
+        # A dictionary mapping room IDs to (initial destination, other destinations)
+        # tuples.
+        self._partial_state_syncs_maybe_needing_restart: Dict[
+            str, Tuple[Optional[str], Collection[str]]
+        ] = {}
+
         # if this is the main process, fire off a background process to resume
         # any partial-state-resync operations which were in flight when we
         # were shut down.
         if not hs.config.worker.worker_app:
             run_as_background_process(
-                "resume_sync_partial_state_room", self._resume_sync_partial_state_room
+                "resume_sync_partial_state_room", self._resume_partial_state_room_sync
             )
 
     @trace
@@ -679,9 +691,7 @@ class FederationHandler:
                 if ret.partial_state:
                     # Kick off the process of asynchronously fetching the state for this
                     # room.
-                    run_as_background_process(
-                        desc="sync_partial_state_room",
-                        func=self._sync_partial_state_room,
+                    self._start_partial_state_room_sync(
                         initial_destination=origin,
                         other_destinations=ret.servers_in_room,
                         room_id=room_id,
@@ -1660,20 +1670,100 @@ class FederationHandler:
         # well.
         return None
 
-    async def _resume_sync_partial_state_room(self) -> None:
+    async def _resume_partial_state_room_sync(self) -> None:
         """Resumes resyncing of all partial-state rooms after a restart."""
         assert not self.config.worker.worker_app
 
         partial_state_rooms = await self.store.get_partial_state_room_resync_info()
         for room_id, resync_info in partial_state_rooms.items():
-            run_as_background_process(
-                desc="sync_partial_state_room",
-                func=self._sync_partial_state_room,
+            self._start_partial_state_room_sync(
                 initial_destination=resync_info.joined_via,
                 other_destinations=resync_info.servers_in_room,
                 room_id=room_id,
             )
 
+    def _start_partial_state_room_sync(
+        self,
+        initial_destination: Optional[str],
+        other_destinations: Collection[str],
+        room_id: str,
+    ) -> None:
+        """Starts the background process to resync the state of a partial state room,
+        if it is not already running.
+
+        Args:
+            initial_destination: the initial homeserver to pull the state from
+            other_destinations: other homeservers to try to pull the state from, if
+                `initial_destination` is unavailable
+            room_id: room to be resynced
+        """
+
+        async def _sync_partial_state_room_wrapper() -> None:
+            if room_id in self._active_partial_state_syncs:
+                # Another local user has joined the room while there is already a
+                # partial state sync running. This implies that there is a new join
+                # event to un-partial state. We might find ourselves in one of a few
+                # scenarios:
+                #  1. There is an existing partial state sync. The partial state sync
+                #     un-partial states the new join event before completing and all is
+                #     well.
+                #  2. Before the latest join, the homeserver was no longer in the room
+                #     and there is an existing partial state sync from our previous
+                #     membership of the room. The partial state sync may have:
+                #      a) succeeded, but not yet terminated. The room will not be
+                #         un-partial stated again unless we restart the partial state
+                #         sync.
+                #      b) failed, because we were no longer in the room and remote
+                #         homeservers were refusing our requests, but not yet
+                #         terminated. After the latest join, remote homeservers may
+                #         start answering our requests again, so we should restart the
+                #         partial state sync.
+                # In the cases where we would want to restart the partial state sync,
+                # the room would have the partial state flag when the partial state sync
+                # terminates.
+                self._partial_state_syncs_maybe_needing_restart[room_id] = (
+                    initial_destination,
+                    other_destinations,
+                )
+                return
+
+            self._active_partial_state_syncs.add(room_id)
+
+            try:
+                await self._sync_partial_state_room(
+                    initial_destination=initial_destination,
+                    other_destinations=other_destinations,
+                    room_id=room_id,
+                )
+            finally:
+                # Read the room's partial state flag while we still hold the claim to
+                # being the active partial state sync (so that another partial state
+                # sync can't come along and mess with it under us).
+                # Normally, the partial state flag will be gone. If it isn't, then we
+                # may find ourselves in scenario 2a or 2b as described in the comment
+                # above, where we want to restart the partial state sync.
+                is_still_partial_state_room = await self.store.is_partial_state_room(
+                    room_id
+                )
+                self._active_partial_state_syncs.remove(room_id)
+
+                if room_id in self._partial_state_syncs_maybe_needing_restart:
+                    (
+                        restart_initial_destination,
+                        restart_other_destinations,
+                    ) = self._partial_state_syncs_maybe_needing_restart.pop(room_id)
+
+                    if is_still_partial_state_room:
+                        self._start_partial_state_room_sync(
+                            initial_destination=restart_initial_destination,
+                            other_destinations=restart_other_destinations,
+                            room_id=room_id,
+                        )
+
+        run_as_background_process(
+            desc="sync_partial_state_room", func=_sync_partial_state_room_wrapper
+        )
+
     async def _sync_partial_state_room(
         self,
         initial_destination: Optional[str],

+ 111 - 1
tests/handlers/test_federation.py

@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import cast
+from typing import Collection, Optional, cast
 from unittest import TestCase
 from unittest.mock import Mock, patch
 
+from twisted.internet.defer import Deferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes
@@ -679,3 +680,112 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             f"Stale partial-stated room flag left over for {room_id} after a"
             f" failed do_invite_join!",
         )
+
+    def test_duplicate_partial_state_room_syncs(self) -> None:
+        """
+        Tests that concurrent partial state syncs are not started for the same room.
+        """
+        is_partial_state = True
+        end_sync: "Deferred[None]" = Deferred()
+
+        async def is_partial_state_room(room_id: str) -> bool:
+            return is_partial_state
+
+        async def sync_partial_state_room(
+            initial_destination: Optional[str],
+            other_destinations: Collection[str],
+            room_id: str,
+        ) -> None:
+            nonlocal end_sync
+            try:
+                await end_sync
+            finally:
+                end_sync = Deferred()
+
+        mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+        mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+        fed_handler = self.hs.get_federation_handler()
+        store = self.hs.get_datastores().main
+
+        with patch.object(
+            fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+        ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+            # Start the partial state sync.
+            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # Try to start another partial state sync.
+            # Nothing should happen.
+            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # End the partial state sync
+            is_partial_state = False
+            end_sync.callback(None)
+
+            # The partial state sync should not be restarted.
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # The next attempt to start the partial state sync should work.
+            is_partial_state = True
+            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+    def test_partial_state_room_sync_restart(self) -> None:
+        """
+        Tests that partial state syncs are restarted when a second partial state sync
+        was deduplicated and the first partial state sync fails.
+        """
+        is_partial_state = True
+        end_sync: "Deferred[None]" = Deferred()
+
+        async def is_partial_state_room(room_id: str) -> bool:
+            return is_partial_state
+
+        async def sync_partial_state_room(
+            initial_destination: Optional[str],
+            other_destinations: Collection[str],
+            room_id: str,
+        ) -> None:
+            nonlocal end_sync
+            try:
+                await end_sync
+            finally:
+                end_sync = Deferred()
+
+        mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+        mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+        fed_handler = self.hs.get_federation_handler()
+        store = self.hs.get_datastores().main
+
+        with patch.object(
+            fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+        ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+            # Start the partial state sync.
+            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # Fail the partial state sync.
+            # The partial state sync should not be restarted.
+            end_sync.errback(Exception("Failed to request /state_ids"))
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # Start the partial state sync again.
+            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+            # Deduplicate another partial state sync.
+            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+            # Fail the partial state sync.
+            # It should restart with the latest parameters.
+            end_sync.errback(Exception("Failed to request /state_ids"))
+            self.assertEqual(mock_sync_partial_state_room.call_count, 3)
+            mock_sync_partial_state_room.assert_called_with(
+                initial_destination="hs3",
+                other_destinations=["hs2"],
+                room_id="room_id",
+            )