1
0
Эх сурвалжийг харах

Implements part of MSC 3944 by dropping cancelled&duplicated `m.room_key_request`

Mathieu Velten 1 жил өмнө
parent
commit
e25c15ea0f

+ 1 - 0
changelog.d/15842.feature

@@ -0,0 +1 @@
+Implements bullets 1 and 2 of [MSC 3944](https://github.com/matrix-org/matrix-spec-proposals/pull/3944) related to dropping cancelled and duplicated `m.room_key_request`.

+ 3 - 0
synapse/config/experimental.py

@@ -389,3 +389,6 @@ class ExperimentalConfig(Config):
         self.msc4010_push_rules_account_data = experimental.get(
         self.msc4010_push_rules_account_data = experimental.get(
             "msc4010_push_rules_account_data", False
             "msc4010_push_rules_account_data", False
         )
         )
+
+        # MSC3944: Dropping stale send-to-device messages
+        self.msc3944_enabled: bool = experimental.get("msc3944_enabled", False)

+ 50 - 7
synapse/handlers/devicemessage.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
+import json
 import logging
 import logging
 from typing import TYPE_CHECKING, Any, Dict
 from typing import TYPE_CHECKING, Any, Dict
 
 
@@ -90,6 +91,8 @@ class DeviceMessageHandler:
             burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
             burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
         )
         )
 
 
+        self._msc3944_enabled = hs.config.experimental.msc3944_enabled
+
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
         """
         """
         Handle receiving to-device messages from remote homeservers.
         Handle receiving to-device messages from remote homeservers.
@@ -220,7 +223,7 @@ class DeviceMessageHandler:
 
 
         set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
         set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
         set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)
         set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)
-        local_messages = {}
+        local_messages: Dict[str, Dict[str, JsonDict]] = {}
         remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
         remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
         for user_id, by_device in messages.items():
         for user_id, by_device in messages.items():
             # add an opentracing log entry for each message
             # add an opentracing log entry for each message
@@ -255,16 +258,56 @@ class DeviceMessageHandler:
 
 
             # we use UserID.from_string to catch invalid user ids
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
             if self.is_mine(UserID.from_string(user_id)):
-                messages_by_device = {
-                    device_id: {
+                for device_id, message_content in by_device.items():
+                    # Drop any previous identical (same request_id and requesting_device_id)
+                    # room_key_request, ignoring the action property when comparing.
+                    # This handles dropping previous identical and cancelled requests.
+                    if (
+                        self._msc3944_enabled
+                        and message_type == ToDeviceEventTypes.RoomKeyRequest
+                        and user_id == sender_user_id
+                    ):
+                        req_id = message_content.get("request_id")
+                        requesting_device_id = message_content.get(
+                            "requesting_device_id"
+                        )
+                        if req_id and requesting_device_id:
+                            previous_request_deleted = False
+                            for (
+                                stream_id,
+                                message_json,
+                            ) in await self.store.get_all_device_messages(
+                                user_id, device_id
+                            ):
+                                orig_message = json.loads(message_json)
+                                if (
+                                    orig_message["type"]
+                                    == ToDeviceEventTypes.RoomKeyRequest
+                                ):
+                                    content = orig_message.get("content", {})
+                                    if (
+                                        content.get("request_id") == req_id
+                                        and content.get("requesting_device_id")
+                                        == requesting_device_id
+                                    ):
+                                        if await self.store.delete_device_message(
+                                            stream_id
+                                        ):
+                                            previous_request_deleted = True
+
+                            if (
+                                message_content.get("action") == "request_cancellation"
+                                and previous_request_deleted
+                            ):
+                                # Do not store the cancellation since we deleted the matching
+                                # request(s) before it reaches the device.
+                                continue
+                    message = {
                         "content": message_content,
                         "content": message_content,
                         "type": message_type,
                         "type": message_type,
                         "sender": sender_user_id,
                         "sender": sender_user_id,
                     }
                     }
-                    for device_id, message_content in by_device.items()
-                }
-                if messages_by_device:
-                    local_messages[user_id] = messages_by_device
+                    local_messages.setdefault(user_id, {})[device_id] = message
             else:
             else:
                 destination = get_domain_from_id(user_id)
                 destination = get_domain_from_id(user_id)
                 remote_messages.setdefault(destination, {})[user_id] = by_device
                 remote_messages.setdefault(destination, {})[user_id] = by_device

+ 41 - 0
synapse/storage/databases/main/deviceinbox.py

@@ -27,6 +27,7 @@ from typing import (
 )
 )
 
 
 from synapse.api.constants import EventContentFields
 from synapse.api.constants import EventContentFields
+from synapse.api.errors import StoreError
 from synapse.logging import issue9533_logger
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import (
 from synapse.logging.opentracing import (
     SynapseTags,
     SynapseTags,
@@ -891,6 +892,46 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 ],
                 ],
             )
             )
 
 
+    async def delete_device_message(self, stream_id: int) -> bool:
+        """Delete a specific device message from the message inbox.
+
+        Args:
+            stream_id: the stream ID identifying the message.
+        Returns:
+            True if the message has been deleted, False if it didn't exist.
+        """
+        try:
+            await self.db_pool.simple_delete_one(
+                "device_inbox",
+                keyvalues={"stream_id": stream_id},
+                desc="delete_device_message",
+            )
+        except StoreError:
+            # Deletion failed because device message does not exist
+            return False
+        return True
+
+    async def get_all_device_messages(
+        self,
+        user_id: str,
+        device_id: str,
+    ) -> List[Tuple[int, str]]:
+        """Get all device messages in the inbox from a specific device.
+
+        Args:
+            user_id: the user ID of the device we want to query.
+            device_id: the device ID of the device we want to query.
+        Returns:
+            A list of (stream ID, message content) tuples.
+        """
+        rows = await self.db_pool.simple_select_list(
+            table="device_inbox",
+            keyvalues={"user_id": user_id, "device_id": device_id},
+            retcols=("stream_id", "message_json"),
+            desc="get_all_device_messages",
+        )
+        return [(r["stream_id"], r["message_json"]) for r in rows]
+
 
 
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"

+ 120 - 1
tests/handlers/test_device.py

@@ -19,13 +19,14 @@ from unittest import mock
 
 
 from twisted.test.proto_helpers import MemoryReactor
 from twisted.test.proto_helpers import MemoryReactor
 
 
+import synapse
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.appservice import ApplicationService
 from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
 from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
 from synapse.server import HomeServer
 from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 from synapse.util import Clock
 
 
 from tests import unittest
 from tests import unittest
@@ -37,6 +38,11 @@ user2 = "@theresa:bbb"
 
 
 
 
 class DeviceTestCase(unittest.HomeserverTestCase):
 class DeviceTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        synapse.rest.client.login.register_servlets,
+    ]
+
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.appservice_api = mock.Mock()
         self.appservice_api = mock.Mock()
         hs = self.setup_test_homeserver(
         hs = self.setup_test_homeserver(
@@ -47,6 +53,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         handler = hs.get_device_handler()
         handler = hs.get_device_handler()
         assert isinstance(handler, DeviceHandler)
         assert isinstance(handler, DeviceHandler)
         self.handler = handler
         self.handler = handler
+        self.msg_handler = hs.get_device_message_handler()
+        self.event_sources = hs.get_event_sources()
         self.store = hs.get_datastores().main
         self.store = hs.get_datastores().main
         return hs
         return hs
 
 
@@ -398,6 +406,117 @@ class DeviceTestCase(unittest.HomeserverTestCase):
             ],
             ],
         )
         )
 
 
+    @override_config({"experimental_features": {"msc3944_enabled": True}})
+    def test_duplicated_and_cancelled_room_key_request(self) -> None:
+        myuser = self.register_user("myuser", "pass")
+        self.login("myuser", "pass", "device")
+        self.login("myuser", "pass", "device2")
+        self.login("myuser", "pass", "device3")
+
+        requester = requester = create_requester(myuser)
+
+        from_token = self.event_sources.get_current_token()
+
+        # This room_key_request is for device3 and should not be deleted.
+        self.get_success(
+            self.msg_handler.send_device_message(
+                requester,
+                "m.room_key_request",
+                {
+                    myuser: {
+                        "device3": {
+                            "action": "request",
+                            "request_id": "request_id",
+                            "requesting_device_id": "device",
+                        }
+                    }
+                },
+            )
+        )
+
+        for _ in range(0, 2):
+            self.get_success(
+                self.msg_handler.send_device_message(
+                    requester,
+                    "m.room_key_request",
+                    {
+                        myuser: {
+                            "device2": {
+                                "action": "request",
+                                "request_id": "request_id",
+                                "requesting_device_id": "device",
+                            }
+                        }
+                    },
+                )
+            )
+
+            to_token = self.event_sources.get_current_token()
+
+            # Test that if we queue 2 identical room_key_request,
+            # only one is delivered to the device.
+            res = self.get_success(
+                self.store.get_messages_for_device(
+                    myuser,
+                    "device2",
+                    from_token.to_device_key,
+                    to_token.to_device_key,
+                )
+            )
+            self.assertEqual(len(res[0]), 1)
+
+        # room_key_request for device3 should still be around.
+        res = self.get_success(
+            self.store.get_messages_for_device(
+                myuser,
+                "device3",
+                from_token.to_device_key,
+                to_token.to_device_key,
+            )
+        )
+        self.assertEqual(len(res[0]), 1)
+
+        self.get_success(
+            self.msg_handler.send_device_message(
+                requester,
+                "m.room_key_request",
+                {
+                    myuser: {
+                        "device2": {
+                            "action": "request_cancellation",
+                            "request_id": "request_id",
+                            "requesting_device_id": "device",
+                        }
+                    }
+                },
+            )
+        )
+
+        to_token = self.event_sources.get_current_token()
+
+        # Test that if we cancel a room_key_request, both previous matching
+        # requests and the cancelled request are not delivered to the device.
+        res = self.get_success(
+            self.store.get_messages_for_device(
+                myuser,
+                "device2",
+                from_token.to_device_key,
+                to_token.to_device_key,
+            )
+        )
+        self.assertEqual(len(res[0]), 0)
+
+        # room_key_request for device3 should still be around.
+        res = self.get_success(
+            self.store.get_messages_for_device(
+                myuser,
+                "device3",
+                from_token.to_device_key,
+                to_token.to_device_key,
+            )
+        )
+        self.assertEqual(len(res[0]), 1)
+
 
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
 class DehydrationTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: