Browse Source

Faster joins: filter out non local events when a room doesn't have its full state (#14404)

Signed-off-by: Mathieu Velten <mathieuv@matrix.org>
Mathieu Velten 1 year ago
parent
commit
1526ff389f

+ 1 - 0
changelog.d/14404.misc

@@ -0,0 +1 @@
+Faster joins: filter out non local events when a room doesn't have its full state.

+ 1 - 0
synapse/federation/sender/per_destination_queue.py

@@ -505,6 +505,7 @@ class PerDestinationQueue:
                     new_pdus = await filter_events_for_server(
                         self._storage_controllers,
                         self._destination,
+                        self._server_name,
                         new_pdus,
                         redact=False,
                     )

+ 10 - 5
synapse/handlers/federation.py

@@ -379,6 +379,7 @@ class FederationHandler:
             filtered_extremities = await filter_events_for_server(
                 self._storage_controllers,
                 self.server_name,
+                self.server_name,
                 events_to_check,
                 redact=False,
                 check_history_visibility_only=True,
@@ -1231,7 +1232,9 @@ class FederationHandler:
     async def on_backfill_request(
         self, origin: str, room_id: str, pdu_list: List[str], limit: int
     ) -> List[EventBase]:
-        await self._event_auth_handler.assert_host_in_room(room_id, origin)
+        # We allow partially joined rooms since in this case we are filtering out
+        # non-local events in `filter_events_for_server`.
+        await self._event_auth_handler.assert_host_in_room(room_id, origin, True)
 
         # Synapse asks for 100 events per backfill request. Do not allow more.
         limit = min(limit, 100)
@@ -1252,7 +1255,7 @@ class FederationHandler:
         )
 
         events = await filter_events_for_server(
-            self._storage_controllers, origin, events
+            self._storage_controllers, origin, self.server_name, events
         )
 
         return events
@@ -1283,7 +1286,7 @@ class FederationHandler:
         await self._event_auth_handler.assert_host_in_room(event.room_id, origin)
 
         events = await filter_events_for_server(
-            self._storage_controllers, origin, [event]
+            self._storage_controllers, origin, self.server_name, [event]
         )
         event = events[0]
         return event
@@ -1296,7 +1299,9 @@ class FederationHandler:
         latest_events: List[str],
         limit: int,
     ) -> List[EventBase]:
-        await self._event_auth_handler.assert_host_in_room(room_id, origin)
+        # We allow partially joined rooms since in this case we are filtering out
+        # non-local events in `filter_events_for_server`.
+        await self._event_auth_handler.assert_host_in_room(room_id, origin, True)
 
         # Only allow up to 20 events to be retrieved per request.
         limit = min(limit, 20)
@@ -1309,7 +1314,7 @@ class FederationHandler:
         )
 
         missing_events = await filter_events_for_server(
-            self._storage_controllers, origin, missing_events
+            self._storage_controllers, origin, self.server_name, missing_events
         )
 
         return missing_events

+ 26 - 3
synapse/visibility.py

@@ -563,7 +563,8 @@ def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str:
 
 async def filter_events_for_server(
     storage: StorageControllers,
-    server_name: str,
+    target_server_name: str,
+    local_server_name: str,
     events: List[EventBase],
     redact: bool = True,
     check_history_visibility_only: bool = False,
@@ -603,7 +604,7 @@ async def filter_events_for_server(
         # if the server is either in the room or has been invited
         # into the room.
         for ev in memberships.values():
-            assert get_domain_from_id(ev.state_key) == server_name
+            assert get_domain_from_id(ev.state_key) == target_server_name
 
             memtype = ev.membership
             if memtype == Membership.JOIN:
@@ -622,6 +623,24 @@ async def filter_events_for_server(
         # to no users having been erased.
         erased_senders = {}
 
+    # Filter out non-local events when we are in the middle of a partial join, since our servers
+    # list can be out of date and we could leak events to servers not in the room anymore.
+    # This can also be true for local events but we consider it to be an acceptable risk.
+
+    # We do this check as a first step and before retrieving membership events because
+    # otherwise a room could be fully joined after we retrieve those, which would then bypass
+    # this check but would base the filtering on an outdated view of the membership events.
+
+    partial_state_invisible_events = set()
+    if not check_history_visibility_only:
+        for e in events:
+            sender_domain = get_domain_from_id(e.sender)
+            if (
+                sender_domain != local_server_name
+                and await storage.main.is_partial_state_room(e.room_id)
+            ):
+                partial_state_invisible_events.add(e)
+
     # Let's check to see if all the events have a history visibility
     # of "shared" or "world_readable". If that's the case then we don't
     # need to check membership (as we know the server is in the room).
@@ -636,7 +655,7 @@ async def filter_events_for_server(
             if event_to_history_vis[e.event_id]
             not in (HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE)
         ],
-        server_name,
+        target_server_name,
     )
 
     to_return = []
@@ -645,6 +664,10 @@ async def filter_events_for_server(
         visible = check_event_is_visible(
             event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {})
         )
+
+        if e in partial_state_invisible_events:
+            visible = False
+
         if visible and not erased:
             to_return.append(e)
         elif redact:

+ 5 - 5
tests/test_visibility.py

@@ -61,7 +61,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
         filtered = self.get_success(
             filter_events_for_server(
-                self._storage_controllers, "test_server", events_to_filter
+                self._storage_controllers, "test_server", "hs", events_to_filter
             )
         )
 
@@ -83,7 +83,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         self.assertEqual(
             self.get_success(
                 filter_events_for_server(
-                    self._storage_controllers, "remote_hs", [outlier]
+                    self._storage_controllers, "remote_hs", "hs", [outlier]
                 )
             ),
             [outlier],
@@ -94,7 +94,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
         filtered = self.get_success(
             filter_events_for_server(
-                self._storage_controllers, "remote_hs", [outlier, evt]
+                self._storage_controllers, "remote_hs", "local_hs", [outlier, evt]
             )
         )
         self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
@@ -106,7 +106,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         # be redacted)
         filtered = self.get_success(
             filter_events_for_server(
-                self._storage_controllers, "other_server", [outlier, evt]
+                self._storage_controllers, "other_server", "local_hs", [outlier, evt]
             )
         )
         self.assertEqual(filtered[0], outlier)
@@ -141,7 +141,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         # ... and the filtering happens.
         filtered = self.get_success(
             filter_events_for_server(
-                self._storage_controllers, "test_server", events_to_filter
+                self._storage_controllers, "test_server", "local_hs", events_to_filter
             )
         )