123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- # Copyright 2016 OpenMarket Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import logging
- from typing import TYPE_CHECKING, Any, Dict
- from synapse.api.constants import EduTypes, ToDeviceEventTypes
- from synapse.api.errors import SynapseError
- from synapse.api.ratelimiting import Ratelimiter
- from synapse.logging.context import run_in_background
- from synapse.logging.opentracing import (
- SynapseTags,
- get_active_span_text_map,
- log_kv,
- set_tag,
- )
- from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
- from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
- from synapse.util import json_encoder
- from synapse.util.stringutils import random_string
- if TYPE_CHECKING:
- from synapse.server import HomeServer
- logger = logging.getLogger(__name__)
- class DeviceMessageHandler:
- def __init__(self, hs: "HomeServer"):
- """
- Args:
- hs: server
- """
- self.store = hs.get_datastores().main
- self.notifier = hs.get_notifier()
- self.is_mine = hs.is_mine
- # We only need to poke the federation sender explicitly if its on the
- # same instance. Other federation sender instances will get notified by
- # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
- # in the to-device replication stream.
- self.federation_sender = None
- if hs.should_send_federation():
- self.federation_sender = hs.get_federation_sender()
- # If we can handle the to device EDUs we do so, otherwise we route them
- # to the appropriate worker.
- if hs.get_instance_name() in hs.config.worker.writers.to_device:
- hs.get_federation_registry().register_edu_handler(
- EduTypes.DIRECT_TO_DEVICE, self.on_direct_to_device_edu
- )
- else:
- hs.get_federation_registry().register_instances_for_edu(
- EduTypes.DIRECT_TO_DEVICE,
- hs.config.worker.writers.to_device,
- )
- # The handler to call when we think a user's device list might be out of
- # sync. We do all device list resyncing on the master instance, so if
- # we're on a worker we hit the device resync replication API.
- if hs.config.worker.worker_app is None:
- self._user_device_resync = (
- hs.get_device_handler().device_list_updater.user_device_resync
- )
- else:
- self._user_device_resync = (
- ReplicationUserDevicesResyncRestServlet.make_client(hs)
- )
- # a rate limiter for room key requests. The keys are
- # (sending_user_id, sending_device_id).
- self._ratelimiter = Ratelimiter(
- store=self.store,
- clock=hs.get_clock(),
- rate_hz=hs.config.ratelimiting.rc_key_requests.per_second,
- burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
- )
- async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
- """
- Handle receiving to-device messages from remote homeservers.
- Args:
- origin: The remote homeserver.
- content: The JSON dictionary containing the to-device messages.
- """
- local_messages = {}
- sender_user_id = content["sender"]
- if origin != get_domain_from_id(sender_user_id):
- logger.warning(
- "Dropping device message from %r with spoofed sender %r",
- origin,
- sender_user_id,
- )
- message_type = content["type"]
- message_id = content["message_id"]
- for user_id, by_device in content["messages"].items():
- # we use UserID.from_string to catch invalid user ids
- if not self.is_mine(UserID.from_string(user_id)):
- logger.warning("To-device message to non-local user %s", user_id)
- raise SynapseError(400, "Not a user here")
- if not by_device:
- continue
- # Ratelimit key requests by the sending user.
- if message_type == ToDeviceEventTypes.RoomKeyRequest:
- allowed, _ = await self._ratelimiter.can_do_action(
- None, (sender_user_id, None)
- )
- if not allowed:
- logger.info(
- "Dropping room_key_request from %s to %s due to rate limit",
- sender_user_id,
- user_id,
- )
- continue
- messages_by_device = {
- device_id: {
- "content": message_content,
- "type": message_type,
- "sender": sender_user_id,
- }
- for device_id, message_content in by_device.items()
- }
- local_messages[user_id] = messages_by_device
- await self._check_for_unknown_devices(
- message_type, sender_user_id, by_device
- )
- # Add messages to the database.
- # Retrieve the stream id of the last-processed to-device message.
- last_stream_id = await self.store.add_messages_from_remote_to_device_inbox(
- origin, message_id, local_messages
- )
- # Notify listeners that there are new to-device messages to process,
- # handing them the latest stream id.
- self.notifier.on_new_event(
- StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
- )
- async def _check_for_unknown_devices(
- self,
- message_type: str,
- sender_user_id: str,
- by_device: Dict[str, Dict[str, Any]],
- ) -> None:
- """Checks inbound device messages for unknown remote devices, and if
- found marks the remote cache for the user as stale.
- """
- if message_type != "m.room_key_request":
- return
- # Get the sending device IDs
- requesting_device_ids = set()
- for message_content in by_device.values():
- device_id = message_content.get("requesting_device_id")
- requesting_device_ids.add(device_id)
- # Check if we are tracking the devices of the remote user.
- room_ids = await self.store.get_rooms_for_user(sender_user_id)
- if not room_ids:
- logger.info(
- "Received device message from remote device we don't"
- " share a room with: %s %s",
- sender_user_id,
- requesting_device_ids,
- )
- return
- # If we are tracking check that we know about the sending
- # devices.
- cached_devices = await self.store.get_cached_devices_for_user(sender_user_id)
- unknown_devices = requesting_device_ids - set(cached_devices)
- if unknown_devices:
- logger.info(
- "Received device message from remote device not in our cache: %s %s",
- sender_user_id,
- unknown_devices,
- )
- await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
- # Immediately attempt a resync in the background
- run_in_background(self._user_device_resync, user_id=sender_user_id)
- async def send_device_message(
- self,
- requester: Requester,
- message_type: str,
- messages: Dict[str, Dict[str, JsonDict]],
- ) -> None:
- """
- Handle a request from a user to send to-device message(s).
- Args:
- requester: The user that is sending the to-device messages.
- message_type: The type of to-device messages that are being sent.
- messages: A dictionary containing recipients mapped to messages intended for them.
- """
- sender_user_id = requester.user.to_string()
- message_id = random_string(16)
- set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
- log_kv({"number_of_to_device_messages": len(messages)})
- set_tag("sender", sender_user_id)
- local_messages = {}
- remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
- for user_id, by_device in messages.items():
- # Ratelimit local cross-user key requests by the sending device.
- if (
- message_type == ToDeviceEventTypes.RoomKeyRequest
- and user_id != sender_user_id
- ):
- allowed, _ = await self._ratelimiter.can_do_action(
- requester, (sender_user_id, requester.device_id)
- )
- if not allowed:
- logger.info(
- "Dropping room_key_request from %s to %s due to rate limit",
- sender_user_id,
- user_id,
- )
- continue
- # we use UserID.from_string to catch invalid user ids
- if self.is_mine(UserID.from_string(user_id)):
- messages_by_device = {
- device_id: {
- "content": message_content,
- "type": message_type,
- "sender": sender_user_id,
- "message_id": message_id,
- }
- for device_id, message_content in by_device.items()
- }
- if messages_by_device:
- local_messages[user_id] = messages_by_device
- log_kv(
- {
- "user_id": user_id,
- "device_id": list(messages_by_device),
- }
- )
- else:
- destination = get_domain_from_id(user_id)
- remote_messages.setdefault(destination, {})[user_id] = by_device
- context = get_active_span_text_map()
- remote_edu_contents = {}
- for destination, messages in remote_messages.items():
- log_kv({"destination": destination})
- remote_edu_contents[destination] = {
- "messages": messages,
- "sender": sender_user_id,
- "type": message_type,
- "message_id": message_id,
- "org.matrix.opentracing_context": json_encoder.encode(context),
- }
- # Add messages to the database.
- # Retrieve the stream id of the last-processed to-device message.
- last_stream_id = await self.store.add_messages_to_device_inbox(
- local_messages, remote_edu_contents
- )
- # Notify listeners that there are new to-device messages to process,
- # handing them the latest stream id.
- self.notifier.on_new_event(
- StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
- )
- if self.federation_sender:
- for destination in remote_messages.keys():
- # Enqueue a new federation transaction to send the new
- # device messages to each remote destination.
- self.federation_sender.send_device_messages(destination)
|