Browse Source

Fix join ratelimiter breaking profile updates and idempotency (#8153)

Brendan Abolivier 3 years ago
parent
commit
393a811a41

+ 1 - 0
changelog.d/8153.bugfix

@@ -0,0 +1 @@
+Fix a bug introduced in v1.19.0 that would cause e.g. profile updates to fail due to incorrect application of rate limits on join requests.

+ 25 - 21
synapse/handlers/room_member.py

@@ -210,24 +210,40 @@ class RoomMemberHandler(object):
             _, stream_id = await self.store.get_event_ordering(duplicate.event_id)
             return duplicate.event_id, stream_id
 
-        stream_id = await self.event_creation_handler.handle_new_client_event(
-            requester, event, context, extra_users=[target], ratelimit=ratelimit,
-        )
-
         prev_state_ids = await context.get_prev_state_ids()
 
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
 
+        newly_joined = False
         if event.membership == Membership.JOIN:
-            # Only fire user_joined_room if the user has actually joined the
-            # room. Don't bother if the user is just changing their profile
-            # info.
             newly_joined = True
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
                 newly_joined = prev_member_event.membership != Membership.JOIN
+
+            # Only rate-limit if the user actually joined the room, otherwise we'll end
+            # up blocking profile updates.
             if newly_joined:
-                await self._user_joined_room(target, room_id)
+                time_now_s = self.clock.time()
+                (
+                    allowed,
+                    time_allowed,
+                ) = self._join_rate_limiter_local.can_requester_do_action(requester)
+
+                if not allowed:
+                    raise LimitExceededError(
+                        retry_after_ms=int(1000 * (time_allowed - time_now_s))
+                    )
+
+        stream_id = await self.event_creation_handler.handle_new_client_event(
+            requester, event, context, extra_users=[target], ratelimit=ratelimit,
+        )
+
+        if event.membership == Membership.JOIN and newly_joined:
+            # Only fire user_joined_room if the user has actually joined the
+            # room. Don't bother if the user is just changing their profile
+            # info.
+            await self._user_joined_room(target, room_id)
         elif event.membership == Membership.LEAVE:
             if prev_member_event_id:
                 prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -457,19 +473,7 @@ class RoomMemberHandler(object):
                     # so don't really fit into the general auth process.
                     raise AuthError(403, "Guest access not allowed")
 
-            if is_host_in_room:
-                time_now_s = self.clock.time()
-                (
-                    allowed,
-                    time_allowed,
-                ) = self._join_rate_limiter_local.can_requester_do_action(requester,)
-
-                if not allowed:
-                    raise LimitExceededError(
-                        retry_after_ms=int(1000 * (time_allowed - time_now_s))
-                    )
-
-            else:
+            if not is_host_in_room:
                 time_now_s = self.clock.time()
                 (
                     allowed,

+ 86 - 1
tests/rest/client/v1/test_rooms.py

@@ -28,7 +28,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.handlers.pagination import PurgeStatus
 from synapse.rest.client.v1 import directory, login, profile, room
 from synapse.rest.client.v2_alpha import account
-from synapse.types import JsonDict, RoomAlias
+from synapse.types import JsonDict, RoomAlias, UserID
 from synapse.util.stringutils import random_string
 
 from tests import unittest
@@ -675,6 +675,91 @@ class RoomMemberStateTestCase(RoomBase):
         self.assertEquals(json.loads(content), channel.json_body)
 
 
+class RoomJoinRatelimitTestCase(RoomBase):
+    user_id = "@sid1:red"
+
+    servlets = [
+        profile.register_servlets,
+        room.register_servlets,
+    ]
+
+    @unittest.override_config(
+        {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+    )
+    def test_join_local_ratelimit(self):
+        """Tests that local joins are actually rate-limited."""
+        for i in range(5):
+            self.helper.create_room_as(self.user_id)
+
+        self.helper.create_room_as(self.user_id, expect_code=429)
+
+    @unittest.override_config(
+        {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+    )
+    def test_join_local_ratelimit_profile_change(self):
+        """Tests that sending a profile update into all of the user's joined rooms isn't
+        rate-limited by the rate-limiter on joins."""
+
+        # Create and join more rooms than the rate-limiting config allows in a second.
+        room_ids = [
+            self.helper.create_room_as(self.user_id),
+            self.helper.create_room_as(self.user_id),
+            self.helper.create_room_as(self.user_id),
+        ]
+        self.reactor.advance(1)
+        room_ids = room_ids + [
+            self.helper.create_room_as(self.user_id),
+            self.helper.create_room_as(self.user_id),
+            self.helper.create_room_as(self.user_id),
+        ]
+
+        # Create a profile for the user, since it hasn't been done on registration.
+        store = self.hs.get_datastore()
+        store.create_profile(UserID.from_string(self.user_id).localpart)
+
+        # Update the display name for the user.
+        path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
+        request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
+        self.render(request)
+        self.assertEquals(channel.code, 200, channel.json_body)
+
+        # Check that all the rooms have been sent a profile update into.
+        for room_id in room_ids:
+            path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
+                room_id,
+                self.user_id,
+            )
+
+            request, channel = self.make_request("GET", path)
+            self.render(request)
+            self.assertEquals(channel.code, 200)
+
+            self.assertIn("displayname", channel.json_body)
+            self.assertEquals(channel.json_body["displayname"], "John Doe")
+
+    @unittest.override_config(
+        {"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
+    )
+    def test_join_local_ratelimit_idempotent(self):
+        """Tests that the room join endpoints remain idempotent despite rate-limiting
+        on room joins."""
+        room_id = self.helper.create_room_as(self.user_id)
+
+        # Let's test both paths to be sure.
+        paths_to_test = [
+            "/_matrix/client/r0/rooms/%s/join",
+            "/_matrix/client/r0/join/%s",
+        ]
+
+        for path in paths_to_test:
+            # Make sure we send more requests than the rate-limiting config would allow
+            # if all of these requests ended up joining the user to a room.
+            for i in range(6):
+                request, channel = self.make_request("POST", path % room_id, {})
+                self.render(request)
+                self.assertEquals(channel.code, 200)
+
+
 class RoomMessagesTestCase(RoomBase):
     """ Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
 

+ 7 - 3
tests/rest/client/v1/utils.py

@@ -39,7 +39,9 @@ class RestHelper(object):
     resource = attr.ib()
     auth_user_id = attr.ib()
 
-    def create_room_as(self, room_creator=None, is_public=True, tok=None):
+    def create_room_as(
+        self, room_creator=None, is_public=True, tok=None, expect_code=200,
+    ):
         temp_id = self.auth_user_id
         self.auth_user_id = room_creator
         path = "/_matrix/client/r0/createRoom"
@@ -54,9 +56,11 @@ class RestHelper(object):
         )
         render(request, self.resource, self.hs.get_reactor())
 
-        assert channel.result["code"] == b"200", channel.result
+        assert channel.result["code"] == b"%d" % expect_code, channel.result
         self.auth_user_id = temp_id
-        return channel.json_body["room_id"]
+
+        if expect_code == 200:
+            return channel.json_body["room_id"]
 
     def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
         self.change_membership(