Browse Source

Add a configuration to exclude rooms from sync response (#12310)

Brendan Abolivier 2 years ago
parent
commit
437a8ed9ef

+ 1 - 0
changelog.d/12310.feature

@@ -0,0 +1 @@
+Add a configuration option to remove a specific set of rooms from sync responses.

+ 9 - 0
docs/sample_config.yaml

@@ -539,6 +539,15 @@ templates:
   #
   #custom_template_directory: /path/to/custom/templates/
 
+# List of rooms to exclude from sync responses. This is useful for server
+# administrators wishing to group users into a room without these users being able
+# to see it from their client.
+#
+# By default, no room is excluded.
+#
+#exclude_rooms_from_sync:
+#    - !foo:example.com
+
 
 # Message retention policy at the server level.
 #

+ 13 - 0
synapse/config/server.py

@@ -680,6 +680,10 @@ class ServerConfig(Config):
             config.get("use_account_validity_in_account_status") or False
         )
 
+        self.rooms_to_exclude_from_sync: List[str] = (
+            config.get("exclude_rooms_from_sync") or []
+        )
+
     def has_tls_listener(self) -> bool:
         return any(listener.tls for listener in self.listeners)
 
@@ -1234,6 +1238,15 @@ class ServerConfig(Config):
           # information about using custom templates.
           #
           #custom_template_directory: /path/to/custom/templates/
+
+        # List of rooms to exclude from sync responses. This is useful for server
+        # administrators wishing to group users into a room without these users being able
+        # to see it from their client.
+        #
+        # By default, no room is excluded.
+        #
+        #exclude_rooms_from_sync:
+        #    - !foo:example.com
         """
             % locals()
         )

+ 17 - 6
synapse/handlers/sync.py

@@ -298,6 +298,8 @@ class SyncHandler:
             expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
         )
 
+        self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync
+
     async def wait_for_sync_for_user(
         self,
         requester: Requester,
@@ -1607,13 +1609,15 @@ class SyncHandler:
         ignored_users = await self.store.ignored_users(user_id)
         if since_token:
             room_changes = await self._get_rooms_changed(
-                sync_result_builder, ignored_users
+                sync_result_builder, ignored_users, self.rooms_to_exclude
             )
             tags_by_room = await self.store.get_updated_tags(
                 user_id, since_token.account_data_key
             )
         else:
-            room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
+            room_changes = await self._get_all_rooms(
+                sync_result_builder, ignored_users, self.rooms_to_exclude
+            )
             tags_by_room = await self.store.get_tags_for_user(user_id)
 
         log_kv({"rooms_changed": len(room_changes.room_entries)})
@@ -1689,7 +1693,10 @@ class SyncHandler:
         return False
 
     async def _get_rooms_changed(
-        self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
+        self,
+        sync_result_builder: "SyncResultBuilder",
+        ignored_users: FrozenSet[str],
+        excluded_rooms: List[str],
     ) -> _RoomChanges:
         """Determine the changes in rooms to report to the user.
 
@@ -1721,7 +1728,7 @@ class SyncHandler:
         #       _have_rooms_changed. We could keep the results in memory to avoid a
         #       second query, at the cost of more complicated source code.
         membership_change_events = await self.store.get_membership_changes_for_user(
-            user_id, since_token.room_key, now_token.room_key
+            user_id, since_token.room_key, now_token.room_key, excluded_rooms
         )
 
         mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
@@ -1922,7 +1929,10 @@ class SyncHandler:
         )
 
     async def _get_all_rooms(
-        self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
+        self,
+        sync_result_builder: "SyncResultBuilder",
+        ignored_users: FrozenSet[str],
+        ignored_rooms: List[str],
     ) -> _RoomChanges:
         """Returns entries for all rooms for the user.
 
@@ -1933,7 +1943,7 @@ class SyncHandler:
         Args:
             sync_result_builder
             ignored_users: Set of users ignored by user.
-
+            ignored_rooms: List of rooms to ignore.
         """
 
         user_id = sync_result_builder.sync_config.user.to_string()
@@ -1944,6 +1954,7 @@ class SyncHandler:
         room_list = await self.store.get_rooms_for_local_user_where_membership_is(
             user_id=user_id,
             membership_list=Membership.LIST,
+            excluded_rooms=ignored_rooms,
         )
 
         room_entries = []

+ 16 - 5
synapse/storage/databases/main/roommember.py

@@ -361,7 +361,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         return None
 
     async def get_rooms_for_local_user_where_membership_is(
-        self, user_id: str, membership_list: Collection[str]
+        self,
+        user_id: str,
+        membership_list: Collection[str],
+        excluded_rooms: Optional[List[str]] = None,
     ) -> List[RoomsForUser]:
         """Get all the rooms for this *local* user where the membership for this user
         matches one in the membership list.
@@ -372,6 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             user_id: The user ID.
             membership_list: A list of synapse.api.constants.Membership
                 values which the user must be in.
+            excluded_rooms: A list of rooms to ignore.
 
         Returns:
             The RoomsForUser that the user matches the membership types.
@@ -386,12 +390,19 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             membership_list,
         )
 
-        # Now we filter out forgotten rooms
-        forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
-        return [room for room in rooms if room.room_id not in forgotten_rooms]
+        # Now we filter out forgotten and excluded rooms
+        rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id)
+
+        if excluded_rooms is not None:
+            rooms_to_exclude.update(set(excluded_rooms))
+
+        return [room for room in rooms if room.room_id not in rooms_to_exclude]
 
     def _get_rooms_for_local_user_where_membership_is_txn(
-        self, txn, user_id: str, membership_list: List[str]
+        self,
+        txn,
+        user_id: str,
+        membership_list: List[str],
     ) -> List[RoomsForUser]:
         # Paranoia check.
         if not self.hs.is_mine_id(user_id):

+ 20 - 10
synapse/storage/databases/main/stream.py

@@ -36,7 +36,7 @@ what sort order was used:
 """
 
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set, Tuple
 
 import attr
 from frozendict import frozendict
@@ -585,7 +585,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         return ret, key
 
     async def get_membership_changes_for_user(
-        self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
+        self,
+        user_id: str,
+        from_key: RoomStreamToken,
+        to_key: RoomStreamToken,
+        excluded_rooms: Optional[List[str]] = None,
     ) -> List[EventBase]:
         """Fetch membership events for a given user.
 
@@ -610,23 +614,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             min_from_id = from_key.stream
             max_to_id = to_key.get_max_stream_pos()
 
+            args: List[Any] = [user_id, min_from_id, max_to_id]
+
+            ignore_room_clause = ""
+            if excluded_rooms is not None and len(excluded_rooms) > 0:
+                ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join(
+                    "?" for _ in excluded_rooms
+                )
+                args = args + excluded_rooms
+
             sql = """
                 SELECT m.event_id, instance_name, topological_ordering, stream_ordering
                 FROM events AS e, room_memberships AS m
                 WHERE e.event_id = m.event_id
                     AND m.user_id = ?
                     AND e.stream_ordering > ? AND e.stream_ordering <= ?
+                    %s
                 ORDER BY e.stream_ordering ASC
-            """
-            txn.execute(
-                sql,
-                (
-                    user_id,
-                    min_from_id,
-                    max_to_id,
-                ),
+            """ % (
+                ignore_room_clause,
             )
 
+            txn.execute(sql, args)
+
             rows = [
                 _EventDictReturn(event_id, None, stream_ordering)
                 for event_id, instance_name, topological_ordering, stream_ordering in txn

+ 62 - 0
tests/rest/client/test_sync.py

@@ -772,3 +772,65 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
         self.assertIn(
             self.user_id, device_list_changes, incremental_sync_channel.json_body
         )
+
+
+class ExcludeRoomTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        self.user_id = self.register_user("user", "password")
+        self.tok = self.login("user", "password")
+
+        self.excluded_room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+        self.included_room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
+
+        # We need to manually append the room ID, because we can't know the ID before
+        # creating the room, and we can't set the config after starting the homeserver.
+        self.hs.get_sync_handler().rooms_to_exclude.append(self.excluded_room_id)
+
+    def test_join_leave(self) -> None:
+        """Tests that rooms are correctly excluded from the 'join' and 'leave' sections of
+        sync responses.
+        """
+        channel = self.make_request("GET", "/sync", access_token=self.tok)
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
+        self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
+
+        self.helper.leave(self.excluded_room_id, self.user_id, tok=self.tok)
+        self.helper.leave(self.included_room_id, self.user_id, tok=self.tok)
+
+        channel = self.make_request(
+            "GET",
+            "/sync?since=" + channel.json_body["next_batch"],
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["leave"])
+        self.assertIn(self.included_room_id, channel.json_body["rooms"]["leave"])
+
+    def test_invite(self) -> None:
+        """Tests that rooms are correctly excluded from the 'invite' section of sync
+        responses.
+        """
+        invitee = self.register_user("invitee", "password")
+        invitee_tok = self.login("invitee", "password")
+
+        self.helper.invite(self.excluded_room_id, self.user_id, invitee, tok=self.tok)
+        self.helper.invite(self.included_room_id, self.user_id, invitee, tok=self.tok)
+
+        channel = self.make_request("GET", "/sync", access_token=invitee_tok)
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"])
+        self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"])