Browse Source

Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens. (#12775)

Shay 2 years ago
parent
commit
19d79b6ebe

+ 1 - 0
changelog.d/12775.misc

@@ -0,0 +1 @@
+Refactor `resolve_state_groups_for_events` to not pull out full state when no state resolution happens.

+ 19 - 16
synapse/state/__init__.py

@@ -288,7 +288,6 @@ class StateHandler:
         #
         # first of all, figure out the state before the event
         #
-
         if old_state:
             # if we're given the state before the event, then we use that
             state_ids_before_event: StateMap[str] = {
@@ -419,33 +418,37 @@ class StateHandler:
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
-        # map from state group id to the state in that state group (where
-        # 'state' is a map from state key to event id)
-        # dict[int, dict[(str, str), str]]
-        state_groups_ids = await self.state_store.get_state_groups_ids(
-            room_id, event_ids
-        )
-
-        if len(state_groups_ids) == 0:
-            return _StateCacheEntry(state={}, state_group=None)
-        elif len(state_groups_ids) == 1:
-            name, state_list = list(state_groups_ids.items()).pop()
+        state_groups = await self.state_store.get_state_group_for_events(event_ids)
 
-            prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
+        state_group_ids = state_groups.values()
 
+        # check if each event has same state group id, if so there's no state to resolve
+        state_group_ids_set = set(state_group_ids)
+        if len(state_group_ids_set) == 1:
+            (state_group_id,) = state_group_ids_set
+            state = await self.state_store.get_state_for_groups(state_group_ids_set)
+            prev_group, delta_ids = await self.state_store.get_state_group_delta(
+                state_group_id
+            )
             return _StateCacheEntry(
-                state=state_list,
-                state_group=name,
+                state=state[state_group_id],
+                state_group=state_group_id,
                 prev_group=prev_group,
                 delta_ids=delta_ids,
             )
+        elif len(state_group_ids_set) == 0:
+            return _StateCacheEntry(state={}, state_group=None)
 
         room_version = await self.store.get_room_version_id(room_id)
 
+        state_to_resolve = await self.state_store.get_state_for_groups(
+            state_group_ids_set
+        )
+
         result = await self._state_resolution_handler.resolve_state_groups(
             room_id,
             room_version,
-            state_groups_ids,
+            state_to_resolve,
             None,
             state_res_store=StateResolutionStore(self.store),
         )

+ 1 - 1
synapse/storage/databases/state/store.py

@@ -189,7 +189,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         group: int,
         state_filter: StateFilter,
     ) -> Tuple[MutableStateMap[str], bool]:
-        """Checks if group is in cache. See `_get_state_for_groups`
+        """Checks if group is in cache. See `get_state_for_groups`
 
         Args:
             cache: the state group cache to use

+ 6 - 6
synapse/storage/state.py

@@ -586,7 +586,7 @@ class StateGroupStorage:
         if not event_ids:
             return {}
 
-        event_to_groups = await self._get_state_group_for_events(event_ids)
+        event_to_groups = await self.get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -602,7 +602,7 @@ class StateGroupStorage:
         Returns:
             Resolves to a map of (type, state_key) -> event_id
         """
-        group_to_state = await self._get_state_for_groups((state_group,))
+        group_to_state = await self.get_state_for_groups((state_group,))
 
         return group_to_state[state_group]
 
@@ -675,7 +675,7 @@ class StateGroupStorage:
             RuntimeError if we don't have a state group for one or more of the events
                (ie they are outliers or unknown)
         """
-        event_to_groups = await self._get_state_group_for_events(event_ids)
+        event_to_groups = await self.get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
@@ -716,7 +716,7 @@ class StateGroupStorage:
             RuntimeError if we don't have a state group for one or more of the events
                 (ie they are outliers or unknown)
         """
-        event_to_groups = await self._get_state_group_for_events(event_ids)
+        event_to_groups = await self.get_state_group_for_events(event_ids)
 
         groups = set(event_to_groups.values())
         group_to_state = await self.stores.state._get_state_for_groups(
@@ -774,7 +774,7 @@ class StateGroupStorage:
         )
         return state_map[event_id]
 
-    def _get_state_for_groups(
+    def get_state_for_groups(
         self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
         """Gets the state at each of a list of state groups, optionally
@@ -792,7 +792,7 @@ class StateGroupStorage:
             groups, state_filter or StateFilter.all()
         )
 
-    async def _get_state_group_for_events(
+    async def get_state_group_for_events(
         self,
         event_ids: Collection[str],
         await_full_state: bool = True,

+ 13 - 0
tests/test_state.py

@@ -129,6 +129,19 @@ class _DummyStore:
     async def get_room_version_id(self, room_id):
         return RoomVersions.V1.identifier
 
+    async def get_state_group_for_events(self, event_ids):
+        res = {}
+        for event in event_ids:
+            res[event] = self._event_to_state_group[event]
+        return res
+
+    async def get_state_for_groups(self, groups):
+        res = {}
+        for group in groups:
+            state = self._group_to_state[group]
+            res[group] = state
+        return res
+
 
 class DictObj(dict):
     def __init__(self, **kwargs):