devicemessage.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from canonicaljson import json
  17. from twisted.internet import defer
  18. from synapse.api.errors import SynapseError
  19. from synapse.logging.opentracing import (
  20. get_active_span_text_map,
  21. log_kv,
  22. set_tag,
  23. start_active_span,
  24. )
  25. from synapse.types import UserID, get_domain_from_id
  26. from synapse.util.stringutils import random_string
  27. logger = logging.getLogger(__name__)
  28. class DeviceMessageHandler(object):
  29. def __init__(self, hs):
  30. """
  31. Args:
  32. hs (synapse.server.HomeServer): server
  33. """
  34. self.store = hs.get_datastore()
  35. self.notifier = hs.get_notifier()
  36. self.is_mine = hs.is_mine
  37. self.federation = hs.get_federation_sender()
  38. hs.get_federation_registry().register_edu_handler(
  39. "m.direct_to_device", self.on_direct_to_device_edu
  40. )
  41. @defer.inlineCallbacks
  42. def on_direct_to_device_edu(self, origin, content):
  43. local_messages = {}
  44. sender_user_id = content["sender"]
  45. if origin != get_domain_from_id(sender_user_id):
  46. logger.warning(
  47. "Dropping device message from %r with spoofed sender %r",
  48. origin,
  49. sender_user_id,
  50. )
  51. message_type = content["type"]
  52. message_id = content["message_id"]
  53. for user_id, by_device in content["messages"].items():
  54. # we use UserID.from_string to catch invalid user ids
  55. if not self.is_mine(UserID.from_string(user_id)):
  56. logger.warning("Request for keys for non-local user %s", user_id)
  57. raise SynapseError(400, "Not a user here")
  58. messages_by_device = {
  59. device_id: {
  60. "content": message_content,
  61. "type": message_type,
  62. "sender": sender_user_id,
  63. }
  64. for device_id, message_content in by_device.items()
  65. }
  66. if messages_by_device:
  67. local_messages[user_id] = messages_by_device
  68. stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
  69. origin, message_id, local_messages
  70. )
  71. self.notifier.on_new_event(
  72. "to_device_key", stream_id, users=local_messages.keys()
  73. )
  74. @defer.inlineCallbacks
  75. def send_device_message(self, sender_user_id, message_type, messages):
  76. set_tag("number_of_messages", len(messages))
  77. set_tag("sender", sender_user_id)
  78. local_messages = {}
  79. remote_messages = {}
  80. for user_id, by_device in messages.items():
  81. # we use UserID.from_string to catch invalid user ids
  82. if self.is_mine(UserID.from_string(user_id)):
  83. messages_by_device = {
  84. device_id: {
  85. "content": message_content,
  86. "type": message_type,
  87. "sender": sender_user_id,
  88. }
  89. for device_id, message_content in by_device.items()
  90. }
  91. if messages_by_device:
  92. local_messages[user_id] = messages_by_device
  93. else:
  94. destination = get_domain_from_id(user_id)
  95. remote_messages.setdefault(destination, {})[user_id] = by_device
  96. message_id = random_string(16)
  97. context = get_active_span_text_map()
  98. remote_edu_contents = {}
  99. for destination, messages in remote_messages.items():
  100. with start_active_span("to_device_for_user"):
  101. set_tag("destination", destination)
  102. remote_edu_contents[destination] = {
  103. "messages": messages,
  104. "sender": sender_user_id,
  105. "type": message_type,
  106. "message_id": message_id,
  107. "org.matrix.opentracing_context": json.dumps(context),
  108. }
  109. log_kv({"local_messages": local_messages})
  110. stream_id = yield self.store.add_messages_to_device_inbox(
  111. local_messages, remote_edu_contents
  112. )
  113. self.notifier.on_new_event(
  114. "to_device_key", stream_id, users=local_messages.keys()
  115. )
  116. log_kv({"remote_messages": remote_messages})
  117. for destination in remote_messages.keys():
  118. # Enqueue a new federation transaction to send the new
  119. # device messages to each remote destination.
  120. self.federation.send_device_messages(destination)