e2e_keys.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2016 OpenMarket Ltd
  3. # Copyright 2018 New Vector Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import logging
  17. from six import iteritems
  18. from canonicaljson import encode_canonical_json, json
  19. from twisted.internet import defer
  20. from synapse.api.errors import CodeMessageException, SynapseError
  21. from synapse.logging.context import make_deferred_yieldable, run_in_background
  22. from synapse.types import UserID, get_domain_from_id
  23. from synapse.util.retryutils import NotRetryingDestination
  24. logger = logging.getLogger(__name__)
class E2eKeysHandler(object):
    """Handles end-to-end-encryption key operations: querying device keys
    (locally, from the federation cache, and over federation), claiming
    one-time keys, and uploading device/one-time keys for local users.
    """

    def __init__(self, hs):
        # datastore for e2e key persistence and the remote-device cache
        self.store = hs.get_datastore()
        # federation client, used to query/claim keys on remote servers
        self.federation = hs.get_federation_client()
        self.device_handler = hs.get_device_handler()
        # predicate: is this UserID hosted on this homeserver?
        self.is_mine = hs.is_mine
        self.clock = hs.get_clock()

        # doesn't really work as part of the generic query API, because the
        # query request requires an object POST, but we abuse the
        # "query handler" interface.
        hs.get_federation_registry().register_query_handler(
            "client_keys", self.on_federation_query_client_keys
        )

    @defer.inlineCallbacks
    def query_devices(self, query_body, timeout):
        """ Handle a device key query from a client

        {
            "device_keys": {
                "<user_id>": ["<device_id>"]
            }
        }
        ->
        {
            "device_keys": {
                "<user_id>": {
                    "<device_id>": {
                        ...
                    }
                }
            }
        }

        Args:
            query_body (dict): request body as above; an empty device_id list
                for a user means "all devices for that user".
            timeout (int): timeout passed through to federation key queries.

        Returns:
            defer.Deferred[dict]: {"device_keys": results, "failures": failures}
                where failures maps remote server name -> failure dict.
        """
        device_keys_query = query_body.get("device_keys", {})

        # separate users by domain.
        # make a map from domain to user_id to device_ids
        local_query = {}
        remote_queries = {}

        for user_id, device_ids in device_keys_query.items():
            # we use UserID.from_string to catch invalid user ids
            if self.is_mine(UserID.from_string(user_id)):
                local_query[user_id] = device_ids
            else:
                remote_queries[user_id] = device_ids

        # First get local devices.
        failures = {}
        results = {}
        if local_query:
            local_result = yield self.query_local_devices(local_query)
            for user_id, keys in local_result.items():
                # only include users we actually asked about
                if user_id in local_query:
                    results[user_id] = keys

        # Now attempt to get any remote devices from our local cache.
        remote_queries_not_in_cache = {}
        if remote_queries:
            # flatten to a list of (user_id, device_id) pairs; device_id None
            # means "all devices for this user"
            query_list = []
            for user_id, device_ids in iteritems(remote_queries):
                if device_ids:
                    query_list.extend((user_id, device_id) for device_id in device_ids)
                else:
                    query_list.append((user_id, None))

            user_ids_not_in_cache, remote_results = (
                yield self.store.get_user_devices_from_cache(query_list)
            )
            for user_id, devices in iteritems(remote_results):
                user_devices = results.setdefault(user_id, {})
                for device_id, device in iteritems(devices):
                    keys = device.get("keys", None)
                    device_display_name = device.get("device_display_name", None)
                    if keys:
                        # copy before mutating: don't modify the cached dict
                        result = dict(keys)
                        unsigned = result.setdefault("unsigned", {})
                        if device_display_name:
                            unsigned["device_display_name"] = device_display_name
                        user_devices[device_id] = result

            # group the cache misses by remote server so we can query each
            # server once
            for user_id in user_ids_not_in_cache:
                domain = get_domain_from_id(user_id)
                r = remote_queries_not_in_cache.setdefault(domain, {})
                r[user_id] = remote_queries[user_id]

        # Now fetch any devices that we don't have in our cache
        @defer.inlineCallbacks
        def do_remote_query(destination):
            destination_query = remote_queries_not_in_cache[destination]
            try:
                remote_result = yield self.federation.query_client_keys(
                    destination, {"device_keys": destination_query}, timeout=timeout
                )
                for user_id, keys in remote_result["device_keys"].items():
                    # ignore any users the remote server volunteered that we
                    # didn't ask about
                    if user_id in destination_query:
                        results[user_id] = keys
            except Exception as e:
                # record the failure per-destination rather than failing the
                # whole request
                failures[destination] = _exception_to_failure(e)

        # query all destinations in parallel
        yield make_deferred_yieldable(
            defer.gatherResults(
                [
                    run_in_background(do_remote_query, destination)
                    for destination in remote_queries_not_in_cache
                ],
                consumeErrors=True,
            )
        )

        return {"device_keys": results, "failures": failures}

    @defer.inlineCallbacks
    def query_local_devices(self, query):
        """Get E2E device keys for local users

        Args:
            query (dict[string, list[string]|None): map from user_id to a list
                of devices to query (None for all devices)

        Returns:
            defer.Deferred: (resolves to dict[string, dict[string, dict]]):
                map from user_id -> device_id -> device details

        Raises:
            SynapseError: (400) if any queried user is not local to this
                homeserver.
        """
        local_query = []

        result_dict = {}
        for user_id, device_ids in query.items():
            # we use UserID.from_string to catch invalid user ids
            if not self.is_mine(UserID.from_string(user_id)):
                logger.warning("Request for keys for non-local user %s", user_id)
                raise SynapseError(400, "Not a user here")

            if not device_ids:
                # None/empty means "all devices for this user"
                local_query.append((user_id, None))
            else:
                for device_id in device_ids:
                    local_query.append((user_id, device_id))

            # make sure that each queried user appears in the result dict
            result_dict[user_id] = {}

        results = yield self.store.get_e2e_device_keys(local_query)

        # Build the result structure, un-jsonify the results, and add the
        # "unsigned" section
        for user_id, device_keys in results.items():
            for device_id, device_info in device_keys.items():
                r = dict(device_info["keys"])
                r["unsigned"] = {}
                display_name = device_info["device_display_name"]
                if display_name is not None:
                    r["unsigned"]["device_display_name"] = display_name
                result_dict[user_id][device_id] = r

        return result_dict

    @defer.inlineCallbacks
    def on_federation_query_client_keys(self, query_body):
        """ Handle a device key query from a federated server

        Delegates to query_local_devices; remote servers may only ask about
        users local to this homeserver.
        """
        device_keys_query = query_body.get("device_keys", {})
        res = yield self.query_local_devices(device_keys_query)
        return {"device_keys": res}

    @defer.inlineCallbacks
    def claim_one_time_keys(self, query, timeout):
        """Claim one-time keys for a set of devices, locally and over
        federation.

        Args:
            query (dict): {"one_time_keys": {user_id: {device_id: algorithm}}}
            timeout (int): timeout for the federation claim requests.

        Returns:
            defer.Deferred[dict]: {"one_time_keys": ..., "failures": ...}
        """
        local_query = []
        remote_queries = {}

        for user_id, device_keys in query.get("one_time_keys", {}).items():
            # we use UserID.from_string to catch invalid user ids
            if self.is_mine(UserID.from_string(user_id)):
                for device_id, algorithm in device_keys.items():
                    local_query.append((user_id, device_id, algorithm))
            else:
                # group remote users by their server
                domain = get_domain_from_id(user_id)
                remote_queries.setdefault(domain, {})[user_id] = device_keys

        results = yield self.store.claim_e2e_one_time_keys(local_query)

        json_result = {}
        failures = {}
        for user_id, device_keys in results.items():
            for device_id, keys in device_keys.items():
                for key_id, json_bytes in keys.items():
                    # NOTE(review): this assigns a fresh single-key dict per
                    # key_id, so only the last key per device is kept if the
                    # store ever returns more than one — presumably it returns
                    # at most one claimed key per device; confirm in the store.
                    json_result.setdefault(user_id, {})[device_id] = {
                        key_id: json.loads(json_bytes)
                    }

        @defer.inlineCallbacks
        def claim_client_keys(destination):
            device_keys = remote_queries[destination]
            try:
                remote_result = yield self.federation.claim_client_keys(
                    destination, {"one_time_keys": device_keys}, timeout=timeout
                )
                for user_id, keys in remote_result["one_time_keys"].items():
                    # ignore users we didn't ask this destination about
                    if user_id in device_keys:
                        json_result[user_id] = keys
            except Exception as e:
                failures[destination] = _exception_to_failure(e)

        # claim from all remote servers in parallel
        yield make_deferred_yieldable(
            defer.gatherResults(
                [
                    run_in_background(claim_client_keys, destination)
                    for destination in remote_queries
                ],
                consumeErrors=True,
            )
        )

        logger.info(
            "Claimed one-time-keys: %s",
            ",".join(
                (
                    "%s for %s:%s" % (key_id, user_id, device_id)
                    for user_id, user_keys in iteritems(json_result)
                    for device_id, device_keys in iteritems(user_keys)
                    for key_id, _ in iteritems(device_keys)
                )
            ),
        )

        return {"one_time_keys": json_result, "failures": failures}

    @defer.inlineCallbacks
    def upload_keys_for_user(self, user_id, device_id, keys):
        """Store device keys and/or one-time keys uploaded by a local user.

        Args:
            user_id (str): the user uploading the keys.
            device_id (str): the device the keys belong to.
            keys (dict): may contain "device_keys" and/or "one_time_keys".

        Returns:
            defer.Deferred[dict]: {"one_time_key_counts": ...} — the remaining
                one-time key counts for the device.
        """
        time_now = self.clock.time_msec()

        # TODO: Validate the JSON to make sure it has the right keys.
        device_keys = keys.get("device_keys", None)
        if device_keys:
            logger.info(
                "Updating device_keys for device %r for user %s at %d",
                device_id,
                user_id,
                time_now,
            )
            # TODO: Sign the JSON with the server key
            changed = yield self.store.set_e2e_device_keys(
                user_id, device_id, time_now, device_keys
            )
            if changed:
                # Only notify about device updates *if* the keys actually changed
                yield self.device_handler.notify_device_update(user_id, [device_id])

        one_time_keys = keys.get("one_time_keys", None)
        if one_time_keys:
            yield self._upload_one_time_keys_for_user(
                user_id, device_id, time_now, one_time_keys
            )

        # the device should have been registered already, but it may have been
        # deleted due to a race with a DELETE request. Or we may be using an
        # old access_token without an associated device_id. Either way, we
        # need to double-check the device is registered to avoid ending up with
        # keys without a corresponding device.
        yield self.device_handler.check_device_registered(user_id, device_id)

        result = yield self.store.count_e2e_one_time_keys(user_id, device_id)

        return {"one_time_key_counts": result}

    @defer.inlineCallbacks
    def _upload_one_time_keys_for_user(
        self, user_id, device_id, time_now, one_time_keys
    ):
        """Persist uploaded one-time keys, rejecting conflicting re-uploads.

        Raises:
            SynapseError: (400) if a key id is re-uploaded with different
                content (ignoring signatures).
        """
        logger.info(
            "Adding one_time_keys %r for device %r for user %r at %d",
            one_time_keys.keys(),
            device_id,
            user_id,
            time_now,
        )

        # make a list of (alg, id, key) tuples
        key_list = []
        for key_id, key_obj in one_time_keys.items():
            # uploaded key ids look like "<algorithm>:<key_id>"
            algorithm, key_id = key_id.split(":")
            key_list.append((algorithm, key_id, key_obj))

        # First we check if we have already persisted any of the keys.
        existing_key_map = yield self.store.get_e2e_one_time_keys(
            user_id, device_id, [k_id for _, k_id, _ in key_list]
        )

        new_keys = []  # Keys that we need to insert. (alg, id, json) tuples.
        for algorithm, key_id, key in key_list:
            ex_json = existing_key_map.get((algorithm, key_id), None)
            if ex_json:
                # already stored: it must match (modulo signatures), otherwise
                # the client is trying to change an existing key
                if not _one_time_keys_match(ex_json, key):
                    raise SynapseError(
                        400,
                        (
                            "One time key %s:%s already exists. "
                            "Old key: %s; new key: %r"
                        )
                        % (algorithm, key_id, ex_json, key),
                    )
            else:
                # canonicalise so that stored JSON compares stably on
                # subsequent uploads
                new_keys.append(
                    (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
                )

        yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
  293. def _exception_to_failure(e):
  294. if isinstance(e, CodeMessageException):
  295. return {"status": e.code, "message": str(e)}
  296. if isinstance(e, NotRetryingDestination):
  297. return {"status": 503, "message": "Not ready for retry"}
  298. # include ConnectionRefused and other errors
  299. #
  300. # Note that some Exceptions (notably twisted's ResponseFailed etc) don't
  301. # give a string for e.message, which json then fails to serialize.
  302. return {"status": 503, "message": str(e)}
  303. def _one_time_keys_match(old_key_json, new_key):
  304. old_key = json.loads(old_key_json)
  305. # if either is a string rather than an object, they must match exactly
  306. if not isinstance(old_key, dict) or not isinstance(new_key, dict):
  307. return old_key == new_key
  308. # otherwise, we strip off the 'signatures' if any, because it's legitimate
  309. # for different upload attempts to have different signatures.
  310. old_key.pop("signatures", None)
  311. new_key_copy = dict(new_key)
  312. new_key_copy.pop("signatures", None)
  313. return old_key == new_key_copy