Browse Source

Faster room joins: avoid blocking when pulling events with missing prevs (#13355)

Avoid blocking on full state in `_resolve_state_at_missing_prevs` and
return a new flag indicating whether the resolved state is partial.
Thread that flag around so that it makes it into the event context.

Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Sean Quah 1 year ago
parent
commit
335ebb21cc

+ 1 - 0
changelog.d/13355.misc

@@ -0,0 +1 @@
+Faster room joins: avoid blocking when pulling events with partially missing prev events.

+ 92 - 24
synapse/handlers/federation_event.py

@@ -278,7 +278,9 @@ class FederationEventHandler:
                 )
 
         try:
-            await self._process_received_pdu(origin, pdu, state_ids=None)
+            await self._process_received_pdu(
+                origin, pdu, state_ids=None, partial_state=None
+            )
         except PartialStateConflictError:
             # The room was un-partial stated while we were processing the PDU.
             # Try once more, with full state this time.
@@ -286,7 +288,9 @@ class FederationEventHandler:
                 "Room %s was un-partial stated while processing the PDU, trying again.",
                 room_id,
             )
-            await self._process_received_pdu(origin, pdu, state_ids=None)
+            await self._process_received_pdu(
+                origin, pdu, state_ids=None, partial_state=None
+            )
 
     async def on_send_membership_event(
         self, origin: str, event: EventBase
@@ -534,14 +538,36 @@ class FederationEventHandler:
             #
             # This is the same operation as we do when we receive a regular event
             # over federation.
-            state_ids = await self._resolve_state_at_missing_prevs(destination, event)
-
-            # build a new state group for it if need be
-            context = await self._state_handler.compute_event_context(
-                event,
-                state_ids_before_event=state_ids,
+            state_ids, partial_state = await self._resolve_state_at_missing_prevs(
+                destination, event
             )
-            if context.partial_state:
+
+            # There are three possible cases for (state_ids, partial_state):
+            #   * `state_ids` and `partial_state` are both `None` if we had all the
+            #     prev_events. The prev_events may or may not have partial state and
+            #     we won't know until we compute the event context.
+            #   * `state_ids` is not `None` and `partial_state` is `False` if we were
+            #     missing some prev_events (but we have full state for any we did
+            #     have). We calculated the full state after the prev_events.
+            #   * `state_ids` is not `None` and `partial_state` is `True` if we were
+            #     missing some, but not all, prev_events. At least one of the
+            #     prev_events we did have had partial state, so we calculated a partial
+            #     state after the prev_events.
+
+            context = None
+            if state_ids is not None and partial_state:
+                # the state after the prev events is still partial. We can't de-partial
+                # state the event, so don't bother building the event context.
+                pass
+            else:
+                # build a new state group for it if need be
+                context = await self._state_handler.compute_event_context(
+                    event,
+                    state_ids_before_event=state_ids,
+                    partial_state=partial_state,
+                )
+
+            if context is None or context.partial_state:
                 # this can happen if some or all of the event's prev_events still have
                 # partial state - ie, an event has an earlier stream_ordering than one
                 # or more of its prev_events, so we de-partial-state it before its
@@ -806,14 +832,39 @@ class FederationEventHandler:
             return
 
         try:
-            state_ids = await self._resolve_state_at_missing_prevs(origin, event)
-            # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
-            #   not return partial state
-            #   https://github.com/matrix-org/synapse/issues/13002
+            try:
+                state_ids, partial_state = await self._resolve_state_at_missing_prevs(
+                    origin, event
+                )
+                await self._process_received_pdu(
+                    origin,
+                    event,
+                    state_ids=state_ids,
+                    partial_state=partial_state,
+                    backfilled=backfilled,
+                )
+            except PartialStateConflictError:
+                # The room was un-partial stated while we were processing the event.
+                # Try once more, with full state this time.
+                state_ids, partial_state = await self._resolve_state_at_missing_prevs(
+                    origin, event
+                )
 
-            await self._process_received_pdu(
-                origin, event, state_ids=state_ids, backfilled=backfilled
-            )
+                # We ought to have full state now, barring some unlikely race where we left and
+                # rejoined the room in the background.
+                if state_ids is not None and partial_state:
+                    raise AssertionError(
+                        f"Event {event.event_id} still has a partial resolved state "
+                        f"after room {event.room_id} was un-partial stated"
+                    )
+
+                await self._process_received_pdu(
+                    origin,
+                    event,
+                    state_ids=state_ids,
+                    partial_state=partial_state,
+                    backfilled=backfilled,
+                )
         except FederationError as e:
             if e.code == 403:
                 logger.warning("Pulled event %s failed history check.", event_id)
@@ -822,7 +873,7 @@ class FederationEventHandler:
 
     async def _resolve_state_at_missing_prevs(
         self, dest: str, event: EventBase
-    ) -> Optional[StateMap[str]]:
+    ) -> Tuple[Optional[StateMap[str]], Optional[bool]]:
         """Calculate the state at an event with missing prev_events.
 
         This is used when we have pulled a batch of events from a remote server, and
@@ -849,8 +900,10 @@ class FederationEventHandler:
             event: an event to check for missing prevs.
 
         Returns:
-            if we already had all the prev events, `None`. Otherwise, returns
-            the event ids of the state at `event`.
+            if we already had all the prev events, `None, None`. Otherwise, returns a
+            tuple containing:
+             * the event ids of the state at `event`.
+             * a boolean indicating whether the state may be partial.
 
         Raises:
             FederationError if we fail to get the state from the remote server after any
@@ -864,7 +917,7 @@ class FederationEventHandler:
         missing_prevs = prevs - seen
 
         if not missing_prevs:
-            return None
+            return None, None
 
         logger.info(
             "Event %s is missing prev_events %s: calculating state for a "
@@ -876,9 +929,15 @@ class FederationEventHandler:
         # resolve them to find the correct state at the current event.
 
         try:
+            # Determine whether we may be about to retrieve partial state.
+            # Events may be un-partial stated right after we compute the partial state
+            # flag, but that's okay, as long as the flag errs on the conservative side.
+            partial_state_flags = await self._store.get_partial_state_events(seen)
+            partial_state = any(partial_state_flags.values())
+
             # Get the state of the events we know about
             ours = await self._state_storage_controller.get_state_groups_ids(
-                room_id, seen
+                room_id, seen, await_full_state=False
             )
 
             # state_maps is a list of mappings from (type, state_key) to event_id
@@ -924,7 +983,7 @@ class FederationEventHandler:
                 "We can't get valid state history.",
                 affected=event_id,
             )
-        return state_map
+        return state_map, partial_state
 
     async def _get_state_ids_after_missing_prev_event(
         self,
@@ -1094,6 +1153,7 @@ class FederationEventHandler:
         origin: str,
         event: EventBase,
         state_ids: Optional[StateMap[str]],
+        partial_state: Optional[bool],
         backfilled: bool = False,
     ) -> None:
         """Called when we have a new non-outlier event.
@@ -1117,14 +1177,21 @@ class FederationEventHandler:
 
             state_ids: Normally None, but if we are handling a gap in the graph
                 (ie, we are missing one or more prev_events), the resolved state at the
-                event. Must not be partial state.
+                event.
+
+            partial_state:
+                `True` if `state_ids` is partial and omits non-critical membership
+                events.
+                `False` if `state_ids` is the full state.
+                `None` if `state_ids` is not provided. In this case, the flag will be
+                calculated based on `event`'s prev events.
 
             backfilled: True if this is part of a historical batch of events (inhibits
                 notification to clients, and validation of device keys.)
 
         PartialStateConflictError: if the room was un-partial stated in between
             computing the state at the event and persisting it. The caller should retry
-            exactly once in this case. Will never be raised if `state_ids` is provided.
+            exactly once in this case.
         """
         logger.debug("Processing event: %s", event)
         assert not event.internal_metadata.outlier
@@ -1132,6 +1199,7 @@ class FederationEventHandler:
         context = await self._state_handler.compute_event_context(
             event,
             state_ids_before_event=state_ids,
+            partial_state=partial_state,
         )
         try:
             await self._check_event_auth(origin, event, context)

+ 4 - 0
synapse/handlers/message.py

@@ -1135,6 +1135,10 @@ class EventCreationHandler:
             context = await self.state.compute_event_context(
                 event,
                 state_ids_before_event=state_map_for_event,
+                # TODO(faster_joins): check how MSC2716 works and whether we can have
+                #   partial state here
+                #   https://github.com/matrix-org/synapse/issues/13003
+                partial_state=False,
             )
         else:
             context = await self.state.compute_event_context(event)

+ 12 - 6
synapse/state/__init__.py

@@ -255,7 +255,7 @@ class StateHandler:
         self,
         event: EventBase,
         state_ids_before_event: Optional[StateMap[str]] = None,
-        partial_state: bool = False,
+        partial_state: Optional[bool] = None,
     ) -> EventContext:
         """Build an EventContext structure for a non-outlier event.
 
@@ -270,8 +270,12 @@ class StateHandler:
                 it can't be calculated from existing events. This is normally
                 only specified when receiving an event from federation where we
                 don't have the prev events, e.g. when backfilling.
-            partial_state: True if `state_ids_before_event` is partial and omits
-                non-critical membership events
+            partial_state:
+                `True` if `state_ids_before_event` is partial and omits non-critical
+                membership events.
+                `False` if `state_ids_before_event` is the full state.
+                `None` when `state_ids_before_event` is not provided. In this case, the
+                flag will be calculated based on `event`'s prev events.
         Returns:
             The event context.
         """
@@ -298,12 +302,14 @@ class StateHandler:
                 )
             )
 
+            # the partial_state flag must be provided
+            assert partial_state is not None
         else:
             # otherwise, we'll need to resolve the state across the prev_events.
 
             # partial_state should not be set explicitly in this case:
             # we work it out dynamically
-            assert not partial_state
+            assert partial_state is None
 
             # if any of the prev-events have partial state, so do we.
             # (This is slightly racy - the prev-events might get fixed up before we use
@@ -313,13 +319,13 @@ class StateHandler:
             incomplete_prev_events = await self.store.get_partial_state_events(
                 prev_event_ids
             )
-            if any(incomplete_prev_events.values()):
+            partial_state = any(incomplete_prev_events.values())
+            if partial_state:
                 logger.debug(
                     "New/incoming event %s refers to prev_events %s with partial state",
                     event.event_id,
                     [k for (k, v) in incomplete_prev_events.items() if v],
                 )
-                partial_state = True
 
             logger.debug("calling resolve_state_groups from compute_event_context")
             # we've already taken into account partial state, so no need to wait for

+ 6 - 2
synapse/storage/controllers/state.py

@@ -82,13 +82,15 @@ class StateStorageController:
         return state_group_delta.prev_group, state_group_delta.delta_ids
 
     async def get_state_groups_ids(
-        self, _room_id: str, event_ids: Collection[str]
+        self, _room_id: str, event_ids: Collection[str], await_full_state: bool = True
     ) -> Dict[int, MutableStateMap[str]]:
         """Get the event IDs of all the state for the state groups for the given events
 
         Args:
             _room_id: id of the room for these events
             event_ids: ids of the events
+            await_full_state: if `True`, will block if we do not yet have complete
+               state at these events.
 
         Returns:
             dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -100,7 +102,9 @@ class StateStorageController:
         if not event_ids:
             return {}
 
-        event_to_groups = await self.get_state_group_for_events(event_ids)
+        event_to_groups = await self.get_state_group_for_events(
+            event_ids, await_full_state=await_full_state
+        )
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(groups)

+ 1 - 0
tests/handlers/test_federation.py

@@ -287,6 +287,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
                     state_ids={
                         (e.type, e.state_key): e.event_id for e in current_state
                     },
+                    partial_state=False,
                 )
             )
 

+ 6 - 1
tests/storage/test_events.py

@@ -70,7 +70,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
     def persist_event(self, event, state=None):
         """Persist the event, with optional state"""
         context = self.get_success(
-            self.state.compute_event_context(event, state_ids_before_event=state)
+            self.state.compute_event_context(
+                event,
+                state_ids_before_event=state,
+                partial_state=None if state is None else False,
+            )
         )
         self.get_success(self._persistence.persist_event(event, context))
 
@@ -148,6 +152,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
             self.state.compute_event_context(
                 remote_event_2,
                 state_ids_before_event=state_before_gap,
+                partial_state=False,
             )
         )
 

+ 2 - 0
tests/test_state.py

@@ -462,6 +462,7 @@ class StateTestCase(unittest.TestCase):
                 state_ids_before_event={
                     (e.type, e.state_key): e.event_id for e in old_state
                 },
+                partial_state=False,
             )
         )
 
@@ -492,6 +493,7 @@ class StateTestCase(unittest.TestCase):
                 state_ids_before_event={
                     (e.type, e.state_key): e.event_id for e in old_state
                 },
+                partial_state=False,
             )
         )