瀏覽代碼

Merge pull request #4482 from matrix-org/erikj/event_auth_room_version

Pass through room version to event auth
Erik Johnston 5 年之前
父節點
當前提交
f1a04462eb

+ 1 - 0
changelog.d/4482.misc

@@ -0,0 +1 @@
+Add infrastructure to support different event formats

+ 10 - 4
synapse/api/auth.py

@@ -65,7 +65,7 @@ class Auth(object):
         register_cache("cache", "token_cache", self.token_cache)
         register_cache("cache", "token_cache", self.token_cache)
 
 
     @defer.inlineCallbacks
     @defer.inlineCallbacks
-    def check_from_context(self, event, context, do_sig_check=True):
+    def check_from_context(self, room_version, event, context, do_sig_check=True):
         prev_state_ids = yield context.get_prev_state_ids(self.store)
         prev_state_ids = yield context.get_prev_state_ids(self.store)
         auth_events_ids = yield self.compute_auth_events(
         auth_events_ids = yield self.compute_auth_events(
             event, prev_state_ids, for_verification=True,
             event, prev_state_ids, for_verification=True,
@@ -74,12 +74,16 @@ class Auth(object):
         auth_events = {
         auth_events = {
             (e.type, e.state_key): e for e in itervalues(auth_events)
             (e.type, e.state_key): e for e in itervalues(auth_events)
         }
         }
-        self.check(event, auth_events=auth_events, do_sig_check=do_sig_check)
+        self.check(
+            room_version, event,
+            auth_events=auth_events, do_sig_check=do_sig_check,
+        )
 
 
-    def check(self, event, auth_events, do_sig_check=True):
+    def check(self, room_version, event, auth_events, do_sig_check=True):
         """ Checks if this event is correctly authed.
         """ Checks if this event is correctly authed.
 
 
         Args:
         Args:
+            room_version (str): version of the room
             event: the event being checked.
             event: the event being checked.
             auth_events (dict: event-key -> event): the existing room state.
             auth_events (dict: event-key -> event): the existing room state.
 
 
@@ -88,7 +92,9 @@ class Auth(object):
             True if the auth checks pass.
             True if the auth checks pass.
         """
         """
         with Measure(self.clock, "auth.check"):
         with Measure(self.clock, "auth.check"):
-            event_auth.check(event, auth_events, do_sig_check=do_sig_check)
+            event_auth.check(
+                room_version, event, auth_events, do_sig_check=do_sig_check
+            )
 
 
     @defer.inlineCallbacks
     @defer.inlineCallbacks
     def check_joined_room(self, room_id, user_id, current_state=None):
     def check_joined_room(self, room_id, user_id, current_state=None):

+ 2 - 1
synapse/event_auth.py

@@ -27,10 +27,11 @@ from synapse.types import UserID, get_domain_from_id
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
-def check(event, auth_events, do_sig_check=True, do_size_check=True):
+def check(room_version, event, auth_events, do_sig_check=True, do_size_check=True):
     """ Checks if this event is correctly authed.
     """ Checks if this event is correctly authed.
 
 
     Args:
     Args:
+        room_version (str): the version of the room
         event: the event being checked.
         event: the event being checked.
         auth_events (dict: event-key -> event): the existing room state.
         auth_events (dict: event-key -> event): the existing room state.
 
 

+ 12 - 8
synapse/handlers/federation.py

@@ -1189,7 +1189,9 @@ class FederationHandler(BaseHandler):
 
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
         # when we get the event back in `on_send_join_request`
-        yield self.auth.check_from_context(event, context, do_sig_check=False)
+        yield self.auth.check_from_context(
+            room_version, event, context, do_sig_check=False,
+        )
 
 
         defer.returnValue(event)
         defer.returnValue(event)
 
 
@@ -1388,7 +1390,9 @@ class FederationHandler(BaseHandler):
         try:
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
             # when we get the event back in `on_send_leave_request`
-            yield self.auth.check_from_context(event, context, do_sig_check=False)
+            yield self.auth.check_from_context(
+                room_version, event, context, do_sig_check=False,
+            )
         except AuthError as e:
         except AuthError as e:
             logger.warn("Failed to create new leave %r because %s", event, e)
             logger.warn("Failed to create new leave %r because %s", event, e)
             raise e
             raise e
@@ -1683,7 +1687,7 @@ class FederationHandler(BaseHandler):
                 auth_for_e[(EventTypes.Create, "")] = create_event
                 auth_for_e[(EventTypes.Create, "")] = create_event
 
 
             try:
             try:
-                self.auth.check(e, auth_events=auth_for_e)
+                self.auth.check(room_version, e, auth_events=auth_for_e)
             except SynapseError as err:
             except SynapseError as err:
                 # we may get SynapseErrors here as well as AuthErrors. For
                 # we may get SynapseErrors here as well as AuthErrors. For
                 # instance, there are a couple of (ancient) events in some
                 # instance, there are a couple of (ancient) events in some
@@ -1927,6 +1931,8 @@ class FederationHandler(BaseHandler):
         current_state = set(e.event_id for e in auth_events.values())
         current_state = set(e.event_id for e in auth_events.values())
         different_auth = event_auth_events - current_state
         different_auth = event_auth_events - current_state
 
 
+        room_version = yield self.store.get_room_version(event.room_id)
+
         if different_auth and not event.internal_metadata.is_outlier():
         if different_auth and not event.internal_metadata.is_outlier():
             # Do auth conflict res.
             # Do auth conflict res.
             logger.info("Different auth: %s", different_auth)
             logger.info("Different auth: %s", different_auth)
@@ -1951,8 +1957,6 @@ class FederationHandler(BaseHandler):
                     (d.type, d.state_key): d for d in different_events if d
                     (d.type, d.state_key): d for d in different_events if d
                 })
                 })
 
 
-                room_version = yield self.store.get_room_version(event.room_id)
-
                 new_state = yield self.state_handler.resolve_events(
                 new_state = yield self.state_handler.resolve_events(
                     room_version,
                     room_version,
                     [list(local_view.values()), list(remote_view.values())],
                     [list(local_view.values()), list(remote_view.values())],
@@ -2052,7 +2056,7 @@ class FederationHandler(BaseHandler):
                 )
                 )
 
 
         try:
         try:
-            self.auth.check(event, auth_events=auth_events)
+            self.auth.check(room_version, event, auth_events=auth_events)
         except AuthError as e:
         except AuthError as e:
             logger.warn("Failed auth resolution for %r because %s", event, e)
             logger.warn("Failed auth resolution for %r because %s", event, e)
             raise e
             raise e
@@ -2288,7 +2292,7 @@ class FederationHandler(BaseHandler):
             )
             )
 
 
             try:
             try:
-                yield self.auth.check_from_context(event, context)
+                yield self.auth.check_from_context(room_version, event, context)
             except AuthError as e:
             except AuthError as e:
                 logger.warn("Denying new third party invite %r because %s", event, e)
                 logger.warn("Denying new third party invite %r because %s", event, e)
                 raise e
                 raise e
@@ -2330,7 +2334,7 @@ class FederationHandler(BaseHandler):
         )
         )
 
 
         try:
         try:
-            self.auth.check_from_context(event, context)
+            self.auth.check_from_context(room_version, event, context)
         except AuthError as e:
         except AuthError as e:
             logger.warn("Denying third party invite %r because %s", event, e)
             logger.warn("Denying third party invite %r because %s", event, e)
             raise e
             raise e

+ 7 - 2
synapse/handlers/message.py

@@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
 from twisted.internet import defer
 from twisted.internet import defer
 from twisted.internet.defer import succeed
 from twisted.internet.defer import succeed
 
 
-from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
+from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions
 from synapse.api.errors import (
 from synapse.api.errors import (
     AuthError,
     AuthError,
     Codes,
     Codes,
@@ -611,8 +611,13 @@ class EventCreationHandler(object):
             extra_users (list(UserID)): Any extra users to notify about event
             extra_users (list(UserID)): Any extra users to notify about event
         """
         """
 
 
+        if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""):
+            room_version = event.content.get("room_version", RoomVersions.V1)
+        else:
+            room_version = yield self.store.get_room_version(event.room_id)
+
         try:
         try:
-            yield self.auth.check_from_context(event, context)
+            yield self.auth.check_from_context(room_version, event, context)
         except AuthError as err:
         except AuthError as err:
             logger.warn("Denying new event %r because %s", event, err)
             logger.warn("Denying new event %r because %s", event, err)
             raise err
             raise err

+ 4 - 1
synapse/handlers/room.py

@@ -123,7 +123,10 @@ class RoomCreationHandler(BaseHandler):
                     token_id=requester.access_token_id,
                     token_id=requester.access_token_id,
                 )
                 )
             )
             )
-            yield self.auth.check_from_context(tombstone_event, tombstone_context)
+            old_room_version = yield self.store.get_room_version(old_room_id)
+            yield self.auth.check_from_context(
+                old_room_version, tombstone_event, tombstone_context,
+            )
 
 
             yield self.clone_existing_room(
             yield self.clone_existing_room(
                 requester,
                 requester,

+ 1 - 1
synapse/state/__init__.py

@@ -611,7 +611,7 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
         RoomVersions.VDH_TEST, RoomVersions.STATE_V2_TEST, RoomVersions.V2,
         RoomVersions.VDH_TEST, RoomVersions.STATE_V2_TEST, RoomVersions.V2,
     ):
     ):
         return v2.resolve_events_with_store(
         return v2.resolve_events_with_store(
-            state_sets, event_map, state_res_store,
+            room_version, state_sets, event_map, state_res_store,
         )
         )
     else:
     else:
         # This should only happen if we added a version but forgot to add it to
         # This should only happen if we added a version but forgot to add it to

+ 11 - 3
synapse/state/v1.py

@@ -21,7 +21,7 @@ from six import iteritems, iterkeys, itervalues
 from twisted.internet import defer
 from twisted.internet import defer
 
 
 from synapse import event_auth
 from synapse import event_auth
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, RoomVersions
 from synapse.api.errors import AuthError
 from synapse.api.errors import AuthError
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
@@ -274,7 +274,11 @@ def _resolve_auth_events(events, auth_events):
         auth_events[(prev_event.type, prev_event.state_key)] = prev_event
         auth_events[(prev_event.type, prev_event.state_key)] = prev_event
         try:
         try:
             # The signatures have already been checked at this point
             # The signatures have already been checked at this point
-            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
+            event_auth.check(
+                RoomVersions.V1, event, auth_events,
+                do_sig_check=False,
+                do_size_check=False,
+            )
             prev_event = event
             prev_event = event
         except AuthError:
         except AuthError:
             return prev_event
             return prev_event
@@ -286,7 +290,11 @@ def _resolve_normal_events(events, auth_events):
     for event in _ordered_events(events):
     for event in _ordered_events(events):
         try:
         try:
             # The signatures have already been checked at this point
             # The signatures have already been checked at this point
-            event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
+            event_auth.check(
+                RoomVersions.V1, event, auth_events,
+                do_sig_check=False,
+                do_size_check=False,
+            )
             return event
             return event
         except AuthError:
         except AuthError:
             pass
             pass

+ 9 - 5
synapse/state/v2.py

@@ -29,10 +29,12 @@ logger = logging.getLogger(__name__)
 
 
 
 
 @defer.inlineCallbacks
 @defer.inlineCallbacks
-def resolve_events_with_store(state_sets, event_map, state_res_store):
+def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
     """Resolves the state using the v2 state resolution algorithm
     """Resolves the state using the v2 state resolution algorithm
 
 
     Args:
     Args:
+        room_version (str): The room version
+
         state_sets(list): List of dicts of (type, state_key) -> event_id,
         state_sets(list): List of dicts of (type, state_key) -> event_id,
             which are the different state groups to resolve.
             which are the different state groups to resolve.
 
 
@@ -104,7 +106,7 @@ def resolve_events_with_store(state_sets, event_map, state_res_store):
 
 
     # Now sequentially auth each one
     # Now sequentially auth each one
     resolved_state = yield _iterative_auth_checks(
     resolved_state = yield _iterative_auth_checks(
-        sorted_power_events, unconflicted_state, event_map,
+        room_version, sorted_power_events, unconflicted_state, event_map,
         state_res_store,
         state_res_store,
     )
     )
 
 
@@ -129,7 +131,7 @@ def resolve_events_with_store(state_sets, event_map, state_res_store):
     logger.debug("resolving remaining events")
     logger.debug("resolving remaining events")
 
 
     resolved_state = yield _iterative_auth_checks(
     resolved_state = yield _iterative_auth_checks(
-        leftover_events, resolved_state, event_map,
+        room_version, leftover_events, resolved_state, event_map,
         state_res_store,
         state_res_store,
     )
     )
 
 
@@ -350,11 +352,13 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
 
 
 
 
 @defer.inlineCallbacks
 @defer.inlineCallbacks
-def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store):
+def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
+                           state_res_store):
     """Sequentially apply auth checks to each event in given list, updating the
     """Sequentially apply auth checks to each event in given list, updating the
     state as it goes along.
     state as it goes along.
 
 
     Args:
     Args:
+        room_version (str)
         event_ids (list[str]): Ordered list of events to apply auth checks to
         event_ids (list[str]): Ordered list of events to apply auth checks to
         base_state (dict[tuple[str, str], str]): The set of state to start with
         base_state (dict[tuple[str, str], str]): The set of state to start with
         event_map (dict[str,FrozenEvent])
         event_map (dict[str,FrozenEvent])
@@ -385,7 +389,7 @@ def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store):
 
 
         try:
         try:
             event_auth.check(
             event_auth.check(
-                event, auth_events,
+                room_version, event, auth_events,
                 do_sig_check=False,
                 do_sig_check=False,
                 do_size_check=False
                 do_size_check=False
             )
             )

+ 3 - 1
tests/state/test_v2.py

@@ -19,7 +19,7 @@ from six.moves import zip
 
 
 import attr
 import attr
 
 
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomVersions
 from synapse.event_auth import auth_types_for_event
 from synapse.event_auth import auth_types_for_event
 from synapse.events import FrozenEvent
 from synapse.events import FrozenEvent
 from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
 from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
@@ -539,6 +539,7 @@ class StateTestCase(unittest.TestCase):
                 state_before = dict(state_at_event[prev_events[0]])
                 state_before = dict(state_at_event[prev_events[0]])
             else:
             else:
                 state_d = resolve_events_with_store(
                 state_d = resolve_events_with_store(
+                    RoomVersions.V2,
                     [state_at_event[n] for n in prev_events],
                     [state_at_event[n] for n in prev_events],
                     event_map=event_map,
                     event_map=event_map,
                     state_res_store=TestStateResolutionStore(event_map),
                     state_res_store=TestStateResolutionStore(event_map),
@@ -685,6 +686,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
         # Test that we correctly handle passing `None` as the event_map
         # Test that we correctly handle passing `None` as the event_map
 
 
         state_d = resolve_events_with_store(
         state_d = resolve_events_with_store(
+            RoomVersions.V2,
             [self.state_at_bob, self.state_at_charlie],
             [self.state_at_bob, self.state_at_charlie],
             event_map=None,
             event_map=None,
             state_res_store=TestStateResolutionStore(self.event_map),
             state_res_store=TestStateResolutionStore(self.event_map),

+ 11 - 2
tests/test_event_auth.py

@@ -16,6 +16,7 @@
 import unittest
 import unittest
 
 
 from synapse import event_auth
 from synapse import event_auth
+from synapse.api.constants import RoomVersions
 from synapse.api.errors import AuthError
 from synapse.api.errors import AuthError
 from synapse.events import FrozenEvent
 from synapse.events import FrozenEvent
 
 
@@ -35,12 +36,16 @@ class EventAuthTestCase(unittest.TestCase):
         }
         }
 
 
         # creator should be able to send state
         # creator should be able to send state
-        event_auth.check(_random_state_event(creator), auth_events, do_sig_check=False)
+        event_auth.check(
+            RoomVersions.V1, _random_state_event(creator), auth_events,
+            do_sig_check=False,
+        )
 
 
         # joiner should not be able to send state
         # joiner should not be able to send state
         self.assertRaises(
         self.assertRaises(
             AuthError,
             AuthError,
             event_auth.check,
             event_auth.check,
+            RoomVersions.V1,
             _random_state_event(joiner),
             _random_state_event(joiner),
             auth_events,
             auth_events,
             do_sig_check=False,
             do_sig_check=False,
@@ -69,13 +74,17 @@ class EventAuthTestCase(unittest.TestCase):
         self.assertRaises(
         self.assertRaises(
             AuthError,
             AuthError,
             event_auth.check,
             event_auth.check,
+            RoomVersions.V1,
             _random_state_event(pleb),
             _random_state_event(pleb),
             auth_events,
             auth_events,
             do_sig_check=False,
             do_sig_check=False,
         ),
         ),
 
 
         # king should be able to send state
         # king should be able to send state
-        event_auth.check(_random_state_event(king), auth_events, do_sig_check=False)
+        event_auth.check(
+            RoomVersions.V1, _random_state_event(king), auth_events,
+            do_sig_check=False,
+        )
 
 
 
 
 # helpers for making events
 # helpers for making events