devicemessage.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from twisted.internet import defer
  17. from synapse.api.errors import SynapseError
  18. from synapse.types import UserID, get_domain_from_id
  19. from synapse.util.stringutils import random_string
  20. logger = logging.getLogger(__name__)
  21. class DeviceMessageHandler(object):
  22. def __init__(self, hs):
  23. """
  24. Args:
  25. hs (synapse.server.HomeServer): server
  26. """
  27. self.store = hs.get_datastore()
  28. self.notifier = hs.get_notifier()
  29. self.is_mine = hs.is_mine
  30. self.federation = hs.get_federation_sender()
  31. hs.get_federation_registry().register_edu_handler(
  32. "m.direct_to_device", self.on_direct_to_device_edu
  33. )
  34. @defer.inlineCallbacks
  35. def on_direct_to_device_edu(self, origin, content):
  36. local_messages = {}
  37. sender_user_id = content["sender"]
  38. if origin != get_domain_from_id(sender_user_id):
  39. logger.warn(
  40. "Dropping device message from %r with spoofed sender %r",
  41. origin, sender_user_id
  42. )
  43. message_type = content["type"]
  44. message_id = content["message_id"]
  45. for user_id, by_device in content["messages"].items():
  46. # we use UserID.from_string to catch invalid user ids
  47. if not self.is_mine(UserID.from_string(user_id)):
  48. logger.warning("Request for keys for non-local user %s",
  49. user_id)
  50. raise SynapseError(400, "Not a user here")
  51. messages_by_device = {
  52. device_id: {
  53. "content": message_content,
  54. "type": message_type,
  55. "sender": sender_user_id,
  56. }
  57. for device_id, message_content in by_device.items()
  58. }
  59. if messages_by_device:
  60. local_messages[user_id] = messages_by_device
  61. stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
  62. origin, message_id, local_messages
  63. )
  64. self.notifier.on_new_event(
  65. "to_device_key", stream_id, users=local_messages.keys()
  66. )
  67. @defer.inlineCallbacks
  68. def send_device_message(self, sender_user_id, message_type, messages):
  69. local_messages = {}
  70. remote_messages = {}
  71. for user_id, by_device in messages.items():
  72. # we use UserID.from_string to catch invalid user ids
  73. if self.is_mine(UserID.from_string(user_id)):
  74. messages_by_device = {
  75. device_id: {
  76. "content": message_content,
  77. "type": message_type,
  78. "sender": sender_user_id,
  79. }
  80. for device_id, message_content in by_device.items()
  81. }
  82. if messages_by_device:
  83. local_messages[user_id] = messages_by_device
  84. else:
  85. destination = get_domain_from_id(user_id)
  86. remote_messages.setdefault(destination, {})[user_id] = by_device
  87. message_id = random_string(16)
  88. remote_edu_contents = {}
  89. for destination, messages in remote_messages.items():
  90. remote_edu_contents[destination] = {
  91. "messages": messages,
  92. "sender": sender_user_id,
  93. "type": message_type,
  94. "message_id": message_id,
  95. }
  96. stream_id = yield self.store.add_messages_to_device_inbox(
  97. local_messages, remote_edu_contents
  98. )
  99. self.notifier.on_new_event(
  100. "to_device_key", stream_id, users=local_messages.keys()
  101. )
  102. for destination in remote_messages.keys():
  103. # Enqueue a new federation transaction to send the new
  104. # device messages to each remote destination.
  105. self.federation.send_device_messages(destination)