Browse Source

Improve ServerNoticeServlet to avoid duplicate requests (#10679)

Fixes: #9544
Dirk Klimpel 2 years ago
parent
commit
e62cdbef1a

+ 1 - 0
changelog.d/10679.bugfix

@@ -0,0 +1 @@
+Improve ServerNoticeServlet to avoid duplicate requests and add unit tests.

+ 4 - 1
synapse/rest/admin/__init__.py

@@ -223,7 +223,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RoomMembersRestServlet(hs).register(http_server)
     DeleteRoomRestServlet(hs).register(http_server)
     JoinRoomAliasServlet(hs).register(http_server)
-    SendServerNoticeServlet(hs).register(http_server)
     VersionServlet(hs).register(http_server)
     UserAdminServlet(hs).register(http_server)
     UserMembershipRestServlet(hs).register(http_server)
@@ -247,6 +246,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     NewRegistrationTokenRestServlet(hs).register(http_server)
     RegistrationTokenRestServlet(hs).register(http_server)
 
+    # Some servlets only get registered for the main process.
+    if hs.config.worker_app is None:
+        SendServerNoticeServlet(hs).register(http_server)
+
 
 def register_servlets_for_client_rest_resource(
     hs: "HomeServer", http_server: HttpServer

+ 12 - 7
synapse/rest/admin/server_notice_servlet.py

@@ -14,7 +14,7 @@
 from typing import TYPE_CHECKING, Optional, Tuple
 
 from synapse.api.constants import EventTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
@@ -53,6 +53,8 @@ class SendServerNoticeServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
+        self.server_notices_manager = hs.get_server_notices_manager()
+        self.admin_handler = hs.get_admin_handler()
         self.txns = HttpTransactionCache(hs)
 
     def register(self, json_resource: HttpServer):
@@ -79,19 +81,22 @@ class SendServerNoticeServlet(RestServlet):
         # We grab the server notices manager here as its initialisation has a check for worker processes,
         # but worker processes still need to initialise SendServerNoticeServlet (as it is part of the
         # admin api).
-        if not self.hs.get_server_notices_manager().is_enabled():
+        if not self.server_notices_manager.is_enabled():
             raise SynapseError(400, "Server notices are not enabled on this server")
 
-        user_id = body["user_id"]
-        UserID.from_string(user_id)
-        if not self.hs.is_mine_id(user_id):
+        target_user = UserID.from_string(body["user_id"])
+        if not self.hs.is_mine(target_user):
             raise SynapseError(400, "Server notices can only be sent to local users")
 
-        event = await self.hs.get_server_notices_manager().send_notice(
-            user_id=body["user_id"],
+        if not await self.admin_handler.get_user(target_user):
+            raise NotFoundError("User not found")
+
+        event = await self.server_notices_manager.send_notice(
+            user_id=target_user.to_string(),
             type=event_type,
             state_key=state_key,
             event_content=body["content"],
+            txn_id=txn_id,
         )
 
         return 200, {"event_id": event.event_id}

+ 8 - 9
synapse/server_notices/server_notices_manager.py

@@ -12,26 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
 from synapse.events import EventBase
 from synapse.types import UserID, create_requester
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 SERVER_NOTICE_ROOM_TAG = "m.server_notice"
 
 
 class ServerNoticesManager:
-    def __init__(self, hs):
-        """
-
-        Args:
-            hs (synapse.server.HomeServer):
-        """
-
+    def __init__(self, hs: "HomeServer"):
         self._store = hs.get_datastore()
         self._config = hs.config
         self._account_data_handler = hs.get_account_data_handler()
@@ -58,6 +55,7 @@ class ServerNoticesManager:
         event_content: dict,
         type: str = EventTypes.Message,
         state_key: Optional[str] = None,
+        txn_id: Optional[str] = None,
     ) -> EventBase:
         """Send a notice to the given user
 
@@ -68,6 +66,7 @@ class ServerNoticesManager:
             event_content: content of event to send
             type: type of event
             is_state_event: Is the event a state event
+            txn_id: The transaction ID.
         """
         room_id = await self.get_or_create_notice_room_for_user(user_id)
         await self.maybe_invite_user_to_room(user_id, room_id)
@@ -90,7 +89,7 @@ class ServerNoticesManager:
             event_dict["state_key"] = state_key
 
         event, _ = await self._event_creation_handler.create_and_send_nonmember_event(
-            requester, event_dict, ratelimit=False
+            requester, event_dict, ratelimit=False, txn_id=txn_id
         )
         return event
 

+ 450 - 0
tests/rest/admin/test_server_notice.py

@@ -0,0 +1,450 @@
+# Copyright 2021 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login, room, sync
+from synapse.storage.roommember import RoomsForUser
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class ServerNoticeTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+        sync.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.room_shutdown_handler = hs.get_room_shutdown_handler()
+        self.pagination_handler = hs.get_pagination_handler()
+        self.server_notices_manager = self.hs.get_server_notices_manager()
+
+        # Create user
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_token = self.login("user", "pass")
+
+        self.url = "/_synapse/admin/v1/send_server_notice"
+
+    def test_no_auth(self):
+        """Try to send a server notice without authentication."""
+        channel = self.make_request("POST", self.url)
+
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_requester_is_no_admin(self):
+        """If the user is not a server admin, an error is returned."""
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.other_user_token,
+        )
+
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+    def test_user_does_not_exist(self):
+        """Tests that a lookup for a user that does not exist returns a 404"""
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={"user_id": "@unknown_person:test", "content": ""},
+        )
+
+        self.assertEqual(404, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+    @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+    def test_user_is_not_local(self):
+        """
+        Tests that a lookup for a user that is not a local returns a 400
+        """
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={
+                "user_id": "@unknown_person:unknown_domain",
+                "content": "",
+            },
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(
+            "Server notices can only be sent to local users", channel.json_body["error"]
+        )
+
+    @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+    def test_invalid_parameter(self):
+        """If parameters are invalid, an error is returned."""
+
+        # no content, no user
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
+
+        # no content
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={"user_id": self.other_user},
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+        # no body
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={"user_id": self.other_user, "content": ""},
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual("'body' not in content", channel.json_body["error"])
+
+        # no msgtype
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={"user_id": self.other_user, "content": {"body": ""}},
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual("'msgtype' not in content", channel.json_body["error"])
+
+    def test_server_notice_disabled(self):
+        """Tests that server returns error if server notice is disabled"""
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={
+                "user_id": self.other_user,
+                "content": "",
+            },
+        )
+
+        self.assertEqual(400, channel.code, msg=channel.json_body)
+        self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+        self.assertEqual(
+            "Server notices are not enabled on this server", channel.json_body["error"]
+        )
+
+    @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+    def test_send_server_notice(self):
+        """
+        Tests that sending two server notices is successfully,
+        the server uses the same room and do not send messages twice.
+        """
+        # user has no room memberships
+        self._check_invite_and_join_status(self.other_user, 0, 0)
+
+        # send first message
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={
+                "user_id": self.other_user,
+                "content": {"msgtype": "m.text", "body": "test msg one"},
+            },
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # user has one invite
+        invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+        room_id = invited_rooms[0].room_id
+
+        # user joins the room and is member now
+        self.helper.join(room=room_id, user=self.other_user, tok=self.other_user_token)
+        self._check_invite_and_join_status(self.other_user, 0, 1)
+
+        # get messages
+        messages = self._sync_and_get_messages(room_id, self.other_user_token)
+        self.assertEqual(len(messages), 1)
+        self.assertEqual(messages[0]["content"]["body"], "test msg one")
+        self.assertEqual(messages[0]["sender"], "@notices:test")
+
+        # invalidate cache of server notices room_ids
+        self.get_success(
+            self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+        )
+
+        # send second message
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={
+                "user_id": self.other_user,
+                "content": {"msgtype": "m.text", "body": "test msg two"},
+            },
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # user has no new invites or memberships
+        self._check_invite_and_join_status(self.other_user, 0, 1)
+
+        # get messages
+        messages = self._sync_and_get_messages(room_id, self.other_user_token)
+
+        self.assertEqual(len(messages), 2)
+        self.assertEqual(messages[0]["content"]["body"], "test msg one")
+        self.assertEqual(messages[0]["sender"], "@notices:test")
+        self.assertEqual(messages[1]["content"]["body"], "test msg two")
+        self.assertEqual(messages[1]["sender"], "@notices:test")
+
+    @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+    def test_send_server_notice_leave_room(self):
+        """
+        Tests that sending a server notices is successfully.
+        The user leaves the room and the second message appears
+        in a new room.
+        """
+        # user has no room memberships
+        self._check_invite_and_join_status(self.other_user, 0, 0)
+
+        # send first message
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={
+                "user_id": self.other_user,
+                "content": {"msgtype": "m.text", "body": "test msg one"},
+            },
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # user has one invite
+        invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+        first_room_id = invited_rooms[0].room_id
+
+        # user joins the room and is member now
+        self.helper.join(
+            room=first_room_id, user=self.other_user, tok=self.other_user_token
+        )
+        self._check_invite_and_join_status(self.other_user, 0, 1)
+
+        # get messages
+        messages = self._sync_and_get_messages(first_room_id, self.other_user_token)
+        self.assertEqual(len(messages), 1)
+        self.assertEqual(messages[0]["content"]["body"], "test msg one")
+        self.assertEqual(messages[0]["sender"], "@notices:test")
+
+        # user leaves the romm
+        self.helper.leave(
+            room=first_room_id, user=self.other_user, tok=self.other_user_token
+        )
+
+        # user is not member anymore
+        self._check_invite_and_join_status(self.other_user, 0, 0)
+
+        # invalidate cache of server notices room_ids
+        # if server tries to send to a cached room_id the user gets the message
+        # in old room
+        self.get_success(
+            self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+        )
+
+        # send second message
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={
+                "user_id": self.other_user,
+                "content": {"msgtype": "m.text", "body": "test msg two"},
+            },
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # user has one invite
+        invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+        second_room_id = invited_rooms[0].room_id
+
+        # user joins the room and is member now
+        self.helper.join(
+            room=second_room_id, user=self.other_user, tok=self.other_user_token
+        )
+        self._check_invite_and_join_status(self.other_user, 0, 1)
+
+        # get messages
+        messages = self._sync_and_get_messages(second_room_id, self.other_user_token)
+
+        self.assertEqual(len(messages), 1)
+        self.assertEqual(messages[0]["content"]["body"], "test msg two")
+        self.assertEqual(messages[0]["sender"], "@notices:test")
+        # room has the same id
+        self.assertNotEqual(first_room_id, second_room_id)
+
+    @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+    def test_send_server_notice_delete_room(self):
+        """
+        Tests that the user get server notice in a new room
+        after the first server notice room was deleted.
+        """
+        # user has no room memberships
+        self._check_invite_and_join_status(self.other_user, 0, 0)
+
+        # send first message
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={
+                "user_id": self.other_user,
+                "content": {"msgtype": "m.text", "body": "test msg one"},
+            },
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # user has one invite
+        invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+        first_room_id = invited_rooms[0].room_id
+
+        # user joins the room and is member now
+        self.helper.join(
+            room=first_room_id, user=self.other_user, tok=self.other_user_token
+        )
+        self._check_invite_and_join_status(self.other_user, 0, 1)
+
+        # get messages
+        messages = self._sync_and_get_messages(first_room_id, self.other_user_token)
+        self.assertEqual(len(messages), 1)
+        self.assertEqual(messages[0]["content"]["body"], "test msg one")
+        self.assertEqual(messages[0]["sender"], "@notices:test")
+
+        # shut down and purge room
+        self.get_success(
+            self.room_shutdown_handler.shutdown_room(first_room_id, self.admin_user)
+        )
+        self.get_success(self.pagination_handler.purge_room(first_room_id))
+
+        # user is not member anymore
+        self._check_invite_and_join_status(self.other_user, 0, 0)
+
+        # It doesn't really matter what API we use here, we just want to assert
+        # that the room doesn't exist.
+        summary = self.get_success(self.store.get_room_summary(first_room_id))
+        # The summary should be empty since the room doesn't exist.
+        self.assertEqual(summary, {})
+
+        # invalidate cache of server notices room_ids
+        # if server tries to send to a cached room_id it gives an error
+        self.get_success(
+            self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+        )
+
+        # send second message
+        channel = self.make_request(
+            "POST",
+            self.url,
+            access_token=self.admin_user_tok,
+            content={
+                "user_id": self.other_user,
+                "content": {"msgtype": "m.text", "body": "test msg two"},
+            },
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+
+        # user has one invite
+        invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+        second_room_id = invited_rooms[0].room_id
+
+        # user joins the room and is member now
+        self.helper.join(
+            room=second_room_id, user=self.other_user, tok=self.other_user_token
+        )
+        self._check_invite_and_join_status(self.other_user, 0, 1)
+
+        # get message
+        messages = self._sync_and_get_messages(second_room_id, self.other_user_token)
+
+        self.assertEqual(len(messages), 1)
+        self.assertEqual(messages[0]["content"]["body"], "test msg two")
+        self.assertEqual(messages[0]["sender"], "@notices:test")
+        # second room has new ID
+        self.assertNotEqual(first_room_id, second_room_id)
+
+    def _check_invite_and_join_status(
+        self, user_id: str, expected_invites: int, expected_memberships: int
+    ) -> RoomsForUser:
+        """Check invite and room membership status of a user.
+
+        Args
+            user_id: user to check
+            expected_invites: number of expected invites of this user
+            expected_memberships: number of expected room memberships of this user
+        Returns
+            room_ids from the rooms that the user is invited
+        """
+
+        invited_rooms = self.get_success(
+            self.store.get_invited_rooms_for_local_user(user_id)
+        )
+        self.assertEqual(expected_invites, len(invited_rooms))
+
+        room_ids = self.get_success(self.store.get_rooms_for_user(user_id))
+        self.assertEqual(expected_memberships, len(room_ids))
+
+        return invited_rooms
+
+    def _sync_and_get_messages(self, room_id: str, token: str) -> List[JsonDict]:
+        """
+        Do a sync and get messages of a room.
+
+        Args
+            room_id: room that contains the messages
+            token: access token of user
+
+        Returns
+            list of messages contained in the room
+        """
+        channel = self.make_request(
+            "GET", "/_matrix/client/r0/sync", access_token=token
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Get the messages
+        room = channel.json_body["rooms"]["join"][room_id]
+        messages = [
+            x for x in room["timeline"]["events"] if x["type"] == "m.room.message"
+        ]
+        return messages