Browse Source

Choose state algorithm based on room version

Erik Johnston 5 years ago
parent
commit
ce6db0e547

+ 6 - 2
synapse/handlers/federation.py

@@ -274,8 +274,9 @@ class FederationHandler(BaseHandler):
                             ev_ids, get_prev_content=False, check_redacted=False
                         )
 
+                    room_version = yield self.store.get_room_version(pdu.room_id)
                     state_map = yield resolve_events_with_factory(
-                        state_groups, {pdu.event_id: pdu}, fetch
+                        room_version, state_groups, {pdu.event_id: pdu}, fetch
                     )
 
                     state = (yield self.store.get_events(state_map.values())).values()
@@ -1811,7 +1812,10 @@ class FederationHandler(BaseHandler):
                     (d.type, d.state_key): d for d in different_events if d
                 })
 
-                new_state = self.state_handler.resolve_events(
+                room_version = yield self.store.get_room_version(event.room_id)
+
+                new_state = yield self.state_handler.resolve_events(
+                    room_version,
                     [list(local_view.values()), list(remote_view.values())],
                     event
                 )

+ 3 - 2
synapse/handlers/room_member.py

@@ -341,9 +341,10 @@ class RoomMemberHandler(object):
         prev_events_and_hashes = yield self.store.get_prev_events_for_room(
             room_id,
         )
-        latest_event_ids = (
+        latest_event_ids = [
             event_id for (event_id, _, _) in prev_events_and_hashes
-        )
+        ]
+
         current_state_ids = yield self.state_handler.get_current_state_ids(
             room_id, latest_event_ids=latest_event_ids,
         )

+ 93 - 11
synapse/state/__init__.py

@@ -23,16 +23,15 @@ from frozendict import frozendict
 
 from twisted.internet import defer
 
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, RoomVersions
 from synapse.events.snapshot import EventContext
+from synapse.state import v1
 from synapse.util.async import Linearizer
 from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 
-from .v1 import resolve_events_with_factory, resolve_events_with_state_map
-
 logger = logging.getLogger(__name__)
 
 
@@ -263,8 +262,14 @@ class StateHandler(object):
             defer.returnValue(context)
 
         logger.debug("calling resolve_state_groups from compute_event_context")
+        if event.type == EventTypes.Create:
+            room_version = event.content.get("room_version", RoomVersions.V1)
+        else:
+            room_version = None
+
         entry = yield self.resolve_state_groups_for_events(
             event.room_id, [e for e, _ in event.prev_events],
+            explicit_room_version=room_version,
         )
 
         prev_state_ids = entry.state
@@ -332,13 +337,17 @@ class StateHandler(object):
         defer.returnValue(context)
 
     @defer.inlineCallbacks
-    def resolve_state_groups_for_events(self, room_id, event_ids):
+    def resolve_state_groups_for_events(self, room_id, event_ids,
+                                        explicit_room_version=None):
         """ Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
 
         Args:
-            room_id (str):
-            event_ids (list[str]):
+            room_id (str)
+            event_ids (list[str])
+            explicit_room_version (str|None): If set uses the the given room
+                version to choose the resolution algorithm. If None, then
+                checks the database for room version.
 
         Returns:
             Deferred[_StateCacheEntry]: resolved state
@@ -364,8 +373,13 @@ class StateHandler(object):
                 delta_ids=delta_ids,
             ))
 
+        room_version = explicit_room_version
+        if not room_version:
+            room_version = yield self.store.get_room_version(room_id)
+
         result = yield self._state_resolution_handler.resolve_state_groups(
-            room_id, state_groups_ids, None, self._state_map_factory,
+            room_id, room_version, state_groups_ids, None,
+            self._state_map_factory,
         )
         defer.returnValue(result)
 
@@ -374,7 +388,8 @@ class StateHandler(object):
             ev_ids, get_prev_content=False, check_redacted=False,
         )
 
-    def resolve_events(self, state_sets, event):
+    @defer.inlineCallbacks
+    def resolve_events(self, room_version, state_sets, event):
         logger.info(
             "Resolving state for %s with %d groups", event.room_id, len(state_sets)
         )
@@ -389,14 +404,18 @@ class StateHandler(object):
             for ev in st
         }
 
+        room_version = yield self.store.get_room_version(event.room_id)
+
         with Measure(self.clock, "state._resolve_events"):
-            new_state = resolve_events_with_state_map(state_set_ids, state_map)
+            new_state = resolve_events_with_state_map(
+                room_version, state_set_ids, state_map,
+            )
 
         new_state = {
             key: state_map[ev_id] for key, ev_id in iteritems(new_state)
         }
 
-        return new_state
+        defer.returnValue(new_state)
 
 
 class StateResolutionHandler(object):
@@ -429,7 +448,7 @@ class StateResolutionHandler(object):
     @defer.inlineCallbacks
     @log_function
     def resolve_state_groups(
-        self, room_id, state_groups_ids, event_map, state_map_factory,
+        self, room_id, room_version, state_groups_ids, event_map, state_map_factory,
     ):
         """Resolves conflicts between a set of state groups
 
@@ -438,6 +457,7 @@ class StateResolutionHandler(object):
 
         Args:
             room_id (str): room we are resolving for (used for logging)
+            room_version (str): version of the room
             state_groups_ids (dict[int, dict[(str, str), str]]):
                  map from state group id to the state in that state group
                 (where 'state' is a map from state key to event id)
@@ -491,6 +511,7 @@ class StateResolutionHandler(object):
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
                     new_state = yield resolve_events_with_factory(
+                        room_version,
                         list(itervalues(state_groups_ids)),
                         event_map=event_map,
                         state_map_factory=state_map_factory,
@@ -572,3 +593,64 @@ def _make_state_cache_entry(
         prev_group=prev_group,
         delta_ids=delta_ids,
     )
+
+
+def resolve_events_with_state_map(room_version, state_sets, state_map):
+    """
+    Args:
+        room_version(str): Version of the room
+        state_sets(list): List of dicts of (type, state_key) -> event_id,
+            which are the different state groups to resolve.
+        state_map(dict): a dict from event_id to event, for all events in
+            state_sets.
+
+    Returns
+        dict[(str, str), str]:
+            a map from (type, state_key) to event_id.
+    """
+    if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
+        return v1.resolve_events_with_state_map(
+            state_sets, state_map,
+        )
+    else:
+        # This should only happen if we added a version but forgot to add it to
+        # the list above.
+        raise Exception(
+            "No state resolution algorithm defined for version %r" % (room_version,)
+        )
+
+
+def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
+    """
+    Args:
+        room_version(str): Version of the room
+
+        state_sets(list): List of dicts of (type, state_key) -> event_id,
+            which are the different state groups to resolve.
+
+        event_map(dict[str,FrozenEvent]|None):
+            a dict from event_id to event, for any events that we happen to
+            have in flight (eg, those currently being persisted). This will be
+            used as a starting point fof finding the state we need; any missing
+            events will be requested via state_map_factory.
+
+            If None, all events will be fetched via state_map_factory.
+
+        state_map_factory(func): will be called
+            with a list of event_ids that are needed, and should return with
+            a Deferred of dict of event_id to event.
+
+    Returns
+        Deferred[dict[(str, str), str]]:
+            a map from (type, state_key) to event_id.
+    """
+    if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
+        return v1.resolve_events_with_factory(
+            state_sets, event_map, state_map_factory,
+        )
+    else:
+        # This should only happen if we added a version but forgot to add it to
+        # the list above.
+        raise Exception(
+            "No state resolution algorithm defined for version %r" % (room_version,)
+        )

+ 3 - 1
synapse/storage/events.py

@@ -705,9 +705,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
         }
 
         events_map = {ev.event_id: ev for ev, _ in events_context}
+        room_version = yield self.get_room_version(room_id)
+
         logger.debug("calling resolve_state_groups from preserve_events")
         res = yield self._state_resolution_handler.resolve_state_groups(
-            room_id, state_groups, events_map, get_events
+            room_id, room_version, state_groups, events_map, get_events
         )
 
         defer.returnValue((res.state, None))