# e2e_keys.py
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  16. import logging
  17. from six import iteritems
  18. from canonicaljson import encode_canonical_json, json
  19. from twisted.internet import defer
  20. from synapse.api.errors import CodeMessageException, FederationDeniedError, SynapseError
  21. from synapse.types import UserID, get_domain_from_id
  22. from synapse.util.logcontext import make_deferred_yieldable, run_in_background
  23. from synapse.util.retryutils import NotRetryingDestination
  24. logger = logging.getLogger(__name__)
class E2eKeysHandler(object):
    """Handles end-to-end-encryption key queries, one-time-key claims and key
    uploads, for both local clients and remote (federated) homeservers.
    """

    def __init__(self, hs):
        self.store = hs.get_datastore()
        self.federation = hs.get_federation_client()
        self.device_handler = hs.get_device_handler()
        self.is_mine = hs.is_mine
        self.clock = hs.get_clock()

        # doesn't really work as part of the generic query API, because the
        # query request requires an object POST, but we abuse the
        # "query handler" interface.
        hs.get_federation_registry().register_query_handler(
            "client_keys", self.on_federation_query_client_keys
        )

    @defer.inlineCallbacks
    def query_devices(self, query_body, timeout):
        """ Handle a device key query from a client

        {
            "device_keys": {
                "<user_id>": ["<device_id>"]
            }
        }
        ->
        {
            "device_keys": {
                "<user_id>": {
                    "<device_id>": {
                        ...
                    }
                }
            }
        }

        Args:
            query_body (dict): parsed request body, as illustrated above. An
                empty device-id list requests all devices for that user.
            timeout (int|None): timeout applied to each federation request

        Returns:
            defer.Deferred[dict]: with "device_keys" (the merged results) and
                "failures" (map from remote domain to a failure dict) keys
        """
        device_keys_query = query_body.get("device_keys", {})

        # separate users by domain.
        # make a map from domain to user_id to device_ids
        local_query = {}
        remote_queries = {}

        for user_id, device_ids in device_keys_query.items():
            # we use UserID.from_string to catch invalid user ids
            if self.is_mine(UserID.from_string(user_id)):
                local_query[user_id] = device_ids
            else:
                remote_queries[user_id] = device_ids

        # First get local devices.
        failures = {}
        results = {}
        if local_query:
            local_result = yield self.query_local_devices(local_query)
            for user_id, keys in local_result.items():
                if user_id in local_query:
                    results[user_id] = keys

        # Now attempt to get any remote devices from our local cache.
        remote_queries_not_in_cache = {}
        if remote_queries:
            # Flatten the per-user query into (user_id, device_id) pairs;
            # (user_id, None) means "all devices for this user".
            query_list = []
            for user_id, device_ids in iteritems(remote_queries):
                if device_ids:
                    query_list.extend((user_id, device_id) for device_id in device_ids)
                else:
                    query_list.append((user_id, None))

            user_ids_not_in_cache, remote_results = (
                yield self.store.get_user_devices_from_cache(
                    query_list
                )
            )
            for user_id, devices in iteritems(remote_results):
                user_devices = results.setdefault(user_id, {})
                for device_id, device in iteritems(devices):
                    keys = device.get("keys", None)
                    device_display_name = device.get("device_display_name", None)
                    if keys:
                        # Copy before mutating: the cached dict must not be
                        # altered when we graft on the "unsigned" section.
                        result = dict(keys)
                        unsigned = result.setdefault("unsigned", {})
                        if device_display_name:
                            unsigned["device_display_name"] = device_display_name
                        user_devices[device_id] = result

            # Users we could not serve from the cache get batched into
            # per-destination-domain federation queries.
            for user_id in user_ids_not_in_cache:
                domain = get_domain_from_id(user_id)
                r = remote_queries_not_in_cache.setdefault(domain, {})
                r[user_id] = remote_queries[user_id]

        # Now fetch any devices that we don't have in our cache
        @defer.inlineCallbacks
        def do_remote_query(destination):
            # Query a single remote homeserver. Successful results are merged
            # into the shared `results` dict; any error is converted into a
            # per-destination entry in `failures` rather than propagated.
            destination_query = remote_queries_not_in_cache[destination]
            try:
                remote_result = yield self.federation.query_client_keys(
                    destination,
                    {"device_keys": destination_query},
                    timeout=timeout
                )

                for user_id, keys in remote_result["device_keys"].items():
                    # only accept keys for users we actually asked about
                    if user_id in destination_query:
                        results[user_id] = keys

            except Exception as e:
                failures[destination] = _exception_to_failure(e)

        # Run all the per-destination queries in parallel; consumeErrors is
        # safe because errors are already captured in `failures` above.
        yield make_deferred_yieldable(defer.gatherResults([
            run_in_background(do_remote_query, destination)
            for destination in remote_queries_not_in_cache
        ], consumeErrors=True))

        defer.returnValue({
            "device_keys": results, "failures": failures,
        })

    @defer.inlineCallbacks
    def query_local_devices(self, query):
        """Get E2E device keys for local users

        Args:
            query (dict[string, list[string]|None]): map from user_id to a list
                of devices to query (None for all devices)

        Returns:
            defer.Deferred: (resolves to dict[string, dict[string, dict]]):
                map from user_id -> device_id -> device details

        Raises:
            SynapseError: (400) if any queried user is not local to this
                homeserver
        """
        local_query = []

        result_dict = {}
        for user_id, device_ids in query.items():
            # we use UserID.from_string to catch invalid user ids
            if not self.is_mine(UserID.from_string(user_id)):
                logger.warning("Request for keys for non-local user %s",
                               user_id)
                raise SynapseError(400, "Not a user here")

            if not device_ids:
                local_query.append((user_id, None))
            else:
                for device_id in device_ids:
                    local_query.append((user_id, device_id))

            # make sure that each queried user appears in the result dict
            result_dict[user_id] = {}

        results = yield self.store.get_e2e_device_keys(local_query)

        # Build the result structure, un-jsonify the results, and add the
        # "unsigned" section
        for user_id, device_keys in results.items():
            for device_id, device_info in device_keys.items():
                r = dict(device_info["keys"])
                r["unsigned"] = {}
                display_name = device_info["device_display_name"]
                if display_name is not None:
                    r["unsigned"]["device_display_name"] = display_name
                result_dict[user_id][device_id] = r

        defer.returnValue(result_dict)

    @defer.inlineCallbacks
    def on_federation_query_client_keys(self, query_body):
        """ Handle a device key query from a federated server

        Same body format as query_devices, but only local users may be
        queried (query_local_devices raises otherwise) and no "failures"
        section is returned.
        """
        device_keys_query = query_body.get("device_keys", {})
        res = yield self.query_local_devices(device_keys_query)
        defer.returnValue({"device_keys": res})

    @defer.inlineCallbacks
    def claim_one_time_keys(self, query, timeout):
        """Claim one-time keys for a set of devices, local and remote.

        Args:
            query (dict): expected to contain a "one_time_keys" key mapping
                user_id -> device_id -> algorithm (shape inferred from the
                loops below)
            timeout (int|None): timeout applied to each federation request

        Returns:
            defer.Deferred[dict]: with "one_time_keys" (user_id -> device_id
                -> {key_id: key}) and "failures" (per-remote-domain) keys
        """
        local_query = []
        remote_queries = {}

        for user_id, device_keys in query.get("one_time_keys", {}).items():
            # we use UserID.from_string to catch invalid user ids
            if self.is_mine(UserID.from_string(user_id)):
                for device_id, algorithm in device_keys.items():
                    local_query.append((user_id, device_id, algorithm))
            else:
                domain = get_domain_from_id(user_id)
                remote_queries.setdefault(domain, {})[user_id] = device_keys

        results = yield self.store.claim_e2e_one_time_keys(local_query)

        # Local results come back as JSON strings; parse them into the
        # response structure.
        json_result = {}
        failures = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    json_result.setdefault(user_id, {})[device_id] = {
                        key_id: json.loads(json_bytes)
                    }

        @defer.inlineCallbacks
        def claim_client_keys(destination):
            # Claim from a single remote homeserver; errors are recorded in
            # the shared `failures` dict rather than propagated.
            device_keys = remote_queries[destination]
            try:
                remote_result = yield self.federation.claim_client_keys(
                    destination,
                    {"one_time_keys": device_keys},
                    timeout=timeout
                )
                for user_id, keys in remote_result["one_time_keys"].items():
                    if user_id in device_keys:
                        json_result[user_id] = keys
            except Exception as e:
                failures[destination] = _exception_to_failure(e)

        # Claim from all remote destinations in parallel.
        yield make_deferred_yieldable(defer.gatherResults([
            run_in_background(claim_client_keys, destination)
            for destination in remote_queries
        ], consumeErrors=True))

        logger.info(
            "Claimed one-time-keys: %s",
            ",".join((
                "%s for %s:%s" % (key_id, user_id, device_id)
                for user_id, user_keys in iteritems(json_result)
                for device_id, device_keys in iteritems(user_keys)
                for key_id, _ in iteritems(device_keys)
            )),
        )

        defer.returnValue({
            "one_time_keys": json_result,
            "failures": failures
        })

    @defer.inlineCallbacks
    def upload_keys_for_user(self, user_id, device_id, keys):
        """Store device keys and/or one-time keys uploaded by a local client.

        Args:
            user_id (str): user the keys belong to
            device_id (str): device the keys belong to
            keys (dict): upload body; "device_keys" and "one_time_keys"
                entries are both optional

        Returns:
            defer.Deferred[dict]: {"one_time_key_counts": ...} where the value
                is whatever count_e2e_one_time_keys returns (presumably a map
                from algorithm to remaining key count — verify against store)
        """
        time_now = self.clock.time_msec()

        # TODO: Validate the JSON to make sure it has the right keys.
        device_keys = keys.get("device_keys", None)
        if device_keys:
            logger.info(
                "Updating device_keys for device %r for user %s at %d",
                device_id, user_id, time_now
            )
            # TODO: Sign the JSON with the server key
            changed = yield self.store.set_e2e_device_keys(
                user_id, device_id, time_now, device_keys,
            )
            if changed:
                # Only notify about device updates *if* the keys actually changed
                yield self.device_handler.notify_device_update(user_id, [device_id])

        one_time_keys = keys.get("one_time_keys", None)
        if one_time_keys:
            yield self._upload_one_time_keys_for_user(
                user_id, device_id, time_now, one_time_keys,
            )

        # the device should have been registered already, but it may have been
        # deleted due to a race with a DELETE request. Or we may be using an
        # old access_token without an associated device_id. Either way, we
        # need to double-check the device is registered to avoid ending up with
        # keys without a corresponding device.
        yield self.device_handler.check_device_registered(user_id, device_id)

        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)

        defer.returnValue({"one_time_key_counts": result})

    @defer.inlineCallbacks
    def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
                                       one_time_keys):
        """Persist newly-uploaded one-time keys for a device.

        Keys are keyed as "<algorithm>:<key_id>". Re-uploading an existing key
        with the same value is a no-op; re-uploading with a different value is
        rejected.

        Args:
            user_id (str): user the keys belong to
            device_id (str): device the keys belong to
            time_now (int): upload timestamp in milliseconds
            one_time_keys (dict): map from "<algorithm>:<key_id>" to key object

        Raises:
            SynapseError: (400) if a key already exists with a different value
        """
        logger.info(
            "Adding one_time_keys %r for device %r for user %r at %d",
            one_time_keys.keys(), device_id, user_id, time_now,
        )

        # make a list of (alg, id, key) tuples
        key_list = []
        for key_id, key_obj in one_time_keys.items():
            algorithm, key_id = key_id.split(":")
            key_list.append((
                algorithm, key_id, key_obj
            ))

        # First we check if we have already persisted any of the keys.
        existing_key_map = yield self.store.get_e2e_one_time_keys(
            user_id, device_id, [k_id for _, k_id, _ in key_list]
        )

        new_keys = []  # Keys that we need to insert. (alg, id, json) tuples.
        for algorithm, key_id, key in key_list:
            ex_json = existing_key_map.get((algorithm, key_id), None)
            if ex_json:
                # Signatures may legitimately differ between uploads, so only
                # a change to the key material itself is an error.
                if not _one_time_keys_match(ex_json, key):
                    raise SynapseError(
                        400,
                        ("One time key %s:%s already exists. "
                         "Old key: %s; new key: %r") %
                        (algorithm, key_id, ex_json, key)
                    )
            else:
                new_keys.append((algorithm, key_id, encode_canonical_json(key)))

        yield self.store.add_e2e_one_time_keys(
            user_id, device_id, time_now, new_keys
        )
  287. def _exception_to_failure(e):
  288. if isinstance(e, CodeMessageException):
  289. return {
  290. "status": e.code, "message": e.message,
  291. }
  292. if isinstance(e, NotRetryingDestination):
  293. return {
  294. "status": 503, "message": "Not ready for retry",
  295. }
  296. if isinstance(e, FederationDeniedError):
  297. return {
  298. "status": 403, "message": "Federation Denied",
  299. }
  300. # include ConnectionRefused and other errors
  301. #
  302. # Note that some Exceptions (notably twisted's ResponseFailed etc) don't
  303. # give a string for e.message, which json then fails to serialize.
  304. return {
  305. "status": 503, "message": str(e.message),
  306. }
  307. def _one_time_keys_match(old_key_json, new_key):
  308. old_key = json.loads(old_key_json)
  309. # if either is a string rather than an object, they must match exactly
  310. if not isinstance(old_key, dict) or not isinstance(new_key, dict):
  311. return old_key == new_key
  312. # otherwise, we strip off the 'signatures' if any, because it's legitimate
  313. # for different upload attempts to have different signatures.
  314. old_key.pop("signatures", None)
  315. new_key_copy = dict(new_key)
  316. new_key_copy.pop("signatures", None)
  317. return old_key == new_key_copy