devicemessage.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. # Copyright 2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from typing import TYPE_CHECKING, Any, Dict
  16. from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
  17. from synapse.api.errors import SynapseError
  18. from synapse.api.ratelimiting import Ratelimiter
  19. from synapse.logging.context import run_in_background
  20. from synapse.logging.opentracing import (
  21. SynapseTags,
  22. get_active_span_text_map,
  23. log_kv,
  24. set_tag,
  25. )
  26. from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
  27. from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
  28. from synapse.util import json_encoder
  29. from synapse.util.stringutils import random_string
  30. if TYPE_CHECKING:
  31. from synapse.server import HomeServer
# Module-level logger for this handler module.
logger = logging.getLogger(__name__)
  33. class DeviceMessageHandler:
  34. def __init__(self, hs: "HomeServer"):
  35. """
  36. Args:
  37. hs: server
  38. """
  39. self.store = hs.get_datastores().main
  40. self.notifier = hs.get_notifier()
  41. self.is_mine = hs.is_mine
  42. # We only need to poke the federation sender explicitly if its on the
  43. # same instance. Other federation sender instances will get notified by
  44. # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
  45. # in the to-device replication stream.
  46. self.federation_sender = None
  47. if hs.should_send_federation():
  48. self.federation_sender = hs.get_federation_sender()
  49. # If we can handle the to device EDUs we do so, otherwise we route them
  50. # to the appropriate worker.
  51. if hs.get_instance_name() in hs.config.worker.writers.to_device:
  52. hs.get_federation_registry().register_edu_handler(
  53. EduTypes.DIRECT_TO_DEVICE, self.on_direct_to_device_edu
  54. )
  55. else:
  56. hs.get_federation_registry().register_instances_for_edu(
  57. EduTypes.DIRECT_TO_DEVICE,
  58. hs.config.worker.writers.to_device,
  59. )
  60. # The handler to call when we think a user's device list might be out of
  61. # sync. We do all device list resyncing on the master instance, so if
  62. # we're on a worker we hit the device resync replication API.
  63. if hs.config.worker.worker_app is None:
  64. self._user_device_resync = (
  65. hs.get_device_handler().device_list_updater.user_device_resync
  66. )
  67. else:
  68. self._user_device_resync = (
  69. ReplicationUserDevicesResyncRestServlet.make_client(hs)
  70. )
  71. # a rate limiter for room key requests. The keys are
  72. # (sending_user_id, sending_device_id).
  73. self._ratelimiter = Ratelimiter(
  74. store=self.store,
  75. clock=hs.get_clock(),
  76. rate_hz=hs.config.ratelimiting.rc_key_requests.per_second,
  77. burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
  78. )
  79. async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
  80. """
  81. Handle receiving to-device messages from remote homeservers.
  82. Args:
  83. origin: The remote homeserver.
  84. content: The JSON dictionary containing the to-device messages.
  85. """
  86. local_messages = {}
  87. sender_user_id = content["sender"]
  88. if origin != get_domain_from_id(sender_user_id):
  89. logger.warning(
  90. "Dropping device message from %r with spoofed sender %r",
  91. origin,
  92. sender_user_id,
  93. )
  94. message_type = content["type"]
  95. message_id = content["message_id"]
  96. for user_id, by_device in content["messages"].items():
  97. # we use UserID.from_string to catch invalid user ids
  98. if not self.is_mine(UserID.from_string(user_id)):
  99. logger.warning("To-device message to non-local user %s", user_id)
  100. raise SynapseError(400, "Not a user here")
  101. if not by_device:
  102. continue
  103. # Ratelimit key requests by the sending user.
  104. if message_type == ToDeviceEventTypes.RoomKeyRequest:
  105. allowed, _ = await self._ratelimiter.can_do_action(
  106. None, (sender_user_id, None)
  107. )
  108. if not allowed:
  109. logger.info(
  110. "Dropping room_key_request from %s to %s due to rate limit",
  111. sender_user_id,
  112. user_id,
  113. )
  114. continue
  115. messages_by_device = {
  116. device_id: {
  117. "content": message_content,
  118. "type": message_type,
  119. "sender": sender_user_id,
  120. }
  121. for device_id, message_content in by_device.items()
  122. }
  123. local_messages[user_id] = messages_by_device
  124. await self._check_for_unknown_devices(
  125. message_type, sender_user_id, by_device
  126. )
  127. # Add messages to the database.
  128. # Retrieve the stream id of the last-processed to-device message.
  129. last_stream_id = await self.store.add_messages_from_remote_to_device_inbox(
  130. origin, message_id, local_messages
  131. )
  132. # Notify listeners that there are new to-device messages to process,
  133. # handing them the latest stream id.
  134. self.notifier.on_new_event(
  135. StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
  136. )
  137. async def _check_for_unknown_devices(
  138. self,
  139. message_type: str,
  140. sender_user_id: str,
  141. by_device: Dict[str, Dict[str, Any]],
  142. ) -> None:
  143. """Checks inbound device messages for unknown remote devices, and if
  144. found marks the remote cache for the user as stale.
  145. """
  146. if message_type != "m.room_key_request":
  147. return
  148. # Get the sending device IDs
  149. requesting_device_ids = set()
  150. for message_content in by_device.values():
  151. device_id = message_content.get("requesting_device_id")
  152. requesting_device_ids.add(device_id)
  153. # Check if we are tracking the devices of the remote user.
  154. room_ids = await self.store.get_rooms_for_user(sender_user_id)
  155. if not room_ids:
  156. logger.info(
  157. "Received device message from remote device we don't"
  158. " share a room with: %s %s",
  159. sender_user_id,
  160. requesting_device_ids,
  161. )
  162. return
  163. # If we are tracking check that we know about the sending
  164. # devices.
  165. cached_devices = await self.store.get_cached_devices_for_user(sender_user_id)
  166. unknown_devices = requesting_device_ids - set(cached_devices)
  167. if unknown_devices:
  168. logger.info(
  169. "Received device message from remote device not in our cache: %s %s",
  170. sender_user_id,
  171. unknown_devices,
  172. )
  173. await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,))
  174. # Immediately attempt a resync in the background
  175. run_in_background(self._user_device_resync, user_id=sender_user_id)
  176. async def send_device_message(
  177. self,
  178. requester: Requester,
  179. message_type: str,
  180. messages: Dict[str, Dict[str, JsonDict]],
  181. ) -> None:
  182. """
  183. Handle a request from a user to send to-device message(s).
  184. Args:
  185. requester: The user that is sending the to-device messages.
  186. message_type: The type of to-device messages that are being sent.
  187. messages: A dictionary containing recipients mapped to messages intended for them.
  188. """
  189. sender_user_id = requester.user.to_string()
  190. set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
  191. set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)
  192. local_messages = {}
  193. remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
  194. for user_id, by_device in messages.items():
  195. # add an opentracing log entry for each message
  196. for device_id, message_content in by_device.items():
  197. log_kv(
  198. {
  199. "event": "send_to_device_message",
  200. "user_id": user_id,
  201. "device_id": device_id,
  202. EventContentFields.TO_DEVICE_MSGID: message_content.get(
  203. EventContentFields.TO_DEVICE_MSGID
  204. ),
  205. }
  206. )
  207. # Ratelimit local cross-user key requests by the sending device.
  208. if (
  209. message_type == ToDeviceEventTypes.RoomKeyRequest
  210. and user_id != sender_user_id
  211. ):
  212. allowed, _ = await self._ratelimiter.can_do_action(
  213. requester, (sender_user_id, requester.device_id)
  214. )
  215. if not allowed:
  216. log_kv({"message": f"dropping key requests to {user_id}"})
  217. logger.info(
  218. "Dropping room_key_request from %s to %s due to rate limit",
  219. sender_user_id,
  220. user_id,
  221. )
  222. continue
  223. # we use UserID.from_string to catch invalid user ids
  224. if self.is_mine(UserID.from_string(user_id)):
  225. messages_by_device = {
  226. device_id: {
  227. "content": message_content,
  228. "type": message_type,
  229. "sender": sender_user_id,
  230. }
  231. for device_id, message_content in by_device.items()
  232. }
  233. if messages_by_device:
  234. local_messages[user_id] = messages_by_device
  235. else:
  236. destination = get_domain_from_id(user_id)
  237. remote_messages.setdefault(destination, {})[user_id] = by_device
  238. context = get_active_span_text_map()
  239. remote_edu_contents = {}
  240. for destination, messages in remote_messages.items():
  241. # The EDU contains a "message_id" property which is used for
  242. # idempotence. Make up a random one.
  243. message_id = random_string(16)
  244. log_kv({"destination": destination, "message_id": message_id})
  245. remote_edu_contents[destination] = {
  246. "messages": messages,
  247. "sender": sender_user_id,
  248. "type": message_type,
  249. "message_id": message_id,
  250. "org.matrix.opentracing_context": json_encoder.encode(context),
  251. }
  252. # Add messages to the database.
  253. # Retrieve the stream id of the last-processed to-device message.
  254. last_stream_id = await self.store.add_messages_to_device_inbox(
  255. local_messages, remote_edu_contents
  256. )
  257. # Notify listeners that there are new to-device messages to process,
  258. # handing them the latest stream id.
  259. self.notifier.on_new_event(
  260. StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
  261. )
  262. if self.federation_sender:
  263. for destination in remote_messages.keys():
  264. # Enqueue a new federation transaction to send the new
  265. # device messages to each remote destination.
  266. self.federation_sender.send_device_messages(destination)