keyring.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from synapse.crypto.keyclient import fetch_server_key
  16. from synapse.api.errors import SynapseError, Codes
  17. from synapse.util.retryutils import get_retry_limiter
  18. from synapse.util import unwrapFirstError
  19. from synapse.util.async import ObservableDeferred
  20. from synapse.util.logcontext import (
  21. preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
  22. preserve_fn
  23. )
  24. from twisted.internet import defer
  25. from signedjson.sign import (
  26. verify_signed_json, signature_ids, sign_json, encode_canonical_json
  27. )
  28. from signedjson.key import (
  29. is_signing_algorithm_supported, decode_verify_key_bytes
  30. )
  31. from unpaddedbase64 import decode_base64, encode_base64
  32. from OpenSSL import crypto
  33. from collections import namedtuple
  34. import urllib
  35. import hashlib
  36. import logging
  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)

  # One signed json object awaiting verification: the server that should have
  # signed it, an id tying the group back to that object, and the ids of the
  # keys the object is signed with.
  KeyGroup = namedtuple("KeyGroup", ["server_name", "group_id", "key_ids"])
  class Keyring(object):
      """Fetches, stores and checks the signing keys of remote servers.

      Keys are looked up in the local datastore first, then via the configured
      perspective servers, and finally from the origin server itself.
      """

      def __init__(self, hs):
          """
          Args:
              hs: the homeserver object. Must provide get_datastore(),
                  get_clock(), get_http_client() and get_config().
          """
          self.store = hs.get_datastore()
          self.clock = hs.get_clock()
          self.client = hs.get_http_client()
          self.config = hs.get_config()
          self.perspective_servers = self.config.perspectives
          self.hs = hs

          # server_name -> ObservableDeferred which fires once the key lookups
          # currently in flight for that server have finished.
          self.key_downloads = {}

      def verify_json_for_server(self, server_name, json_object):
          """Verifies the signature on a single json object.

          Args:
              server_name (str): the server which should have signed the object
              json_object (dict): the signed json object

          Returns:
              The first (and only) entry of the list returned by
              verify_json_objects_for_server for this pair.
          """
          return self.verify_json_objects_for_server(
              [(server_name, json_object)]
          )[0]

      def verify_json_objects_for_server(self, server_and_json):
          """Bulk verifies signatures of json objects, bulk fetching keys as
          necessary.

          Args:
              server_and_json (list): List of pairs of (server_name, json_object)

          Returns:
              list of deferreds indicating success or failure to verify each
              json object's signature for the given server_name. The list is in
              the same order as server_and_json.
          """
          group_id_to_json = {}
          group_id_to_group = {}
          group_ids = []
          deferreds = {}

          # Every (server_name, json_object) pair gets its own group, whose id
          # is simply its position in the input list.
          for group_id, (server_name, json_object) in enumerate(server_and_json):
              logger.debug("Verifying for %s", server_name)
              group_ids.append(group_id)

              key_ids = signature_ids(json_object, server_name)
              if not key_ids:
                  deferreds[group_id] = defer.fail(SynapseError(
                      400,
                      "Not signed with a supported algorithm",
                      Codes.UNAUTHORIZED,
                  ))
              else:
                  deferreds[group_id] = defer.Deferred()

              group_id_to_group[group_id] = KeyGroup(
                  server_name, group_id, key_ids
              )
              group_id_to_json[group_id] = json_object

          @defer.inlineCallbacks
          def handle_key_deferred(group, deferred):
              """Waits for a key for `group`, then checks the signature."""
              server_name = group.server_name
              try:
                  _, _, key_id, verify_key = yield deferred
              except IOError as e:
                  logger.warn(
                      "Got IOError when downloading keys for %s: %s %s",
                      server_name, type(e).__name__, str(e.message),
                  )
                  raise SynapseError(
                      502,
                      "Error downloading keys for %s" % (server_name,),
                      Codes.UNAUTHORIZED,
                  )
              except Exception as e:
                  logger.exception(
                      "Got Exception when downloading keys for %s: %s %s",
                      server_name, type(e).__name__, str(e.message),
                  )
                  # Report this group's own key ids: the `key_ids` name of the
                  # enclosing loop is late-bound and would always hold the ids
                  # of the last pair that was processed.
                  raise SynapseError(
                      401,
                      "No key for %s with id %s" % (server_name, group.key_ids),
                      Codes.UNAUTHORIZED,
                  )

              json_object = group_id_to_json[group.group_id]

              try:
                  verify_signed_json(json_object, server_name, verify_key)
              except Exception:
                  raise SynapseError(
                      401,
                      "Invalid signature for server %s with key %s:%s" % (
                          server_name, verify_key.alg, verify_key.version
                      ),
                      Codes.UNAUTHORIZED,
                  )

          server_to_deferred = {
              server_name: defer.Deferred()
              for server_name, _ in server_and_json
          }

          with PreserveLoggingContext():
              # We want to wait for any previous lookups to complete before
              # proceeding.
              wait_on_deferred = self.wait_for_previous_lookups(
                  [server_name for server_name, _ in server_and_json],
                  server_to_deferred,
              )

              # Actually start fetching keys.
              wait_on_deferred.addBoth(
                  lambda _: self.get_server_verify_keys(
                      group_id_to_group, deferreds
                  )
              )

              # When we've finished fetching all the keys for a given
              # server_name, resolve the deferred passed to
              # `wait_for_previous_lookups` so that any lookups waiting will
              # proceed.
              server_to_gids = {}

              def remove_deferreds(res, server_name, group_id):
                  server_to_gids[server_name].discard(group_id)
                  if not server_to_gids[server_name]:
                      d = server_to_deferred.pop(server_name, None)
                      if d:
                          d.callback(None)
                  return res

              # Register every group before attaching any callback: a deferred
              # that has already fired runs `remove_deferreds` immediately, and
              # must not release the server while other groups for the same
              # server are still unregistered.
              for g_id in deferreds:
                  server_name = group_id_to_group[g_id].server_name
                  server_to_gids.setdefault(server_name, set()).add(g_id)

              for g_id, deferred in deferreds.items():
                  server_name = group_id_to_group[g_id].server_name
                  deferred.addBoth(remove_deferreds, server_name, g_id)

          # Pass those keys to handle_key_deferred so that the json object
          # signatures can be verified
          return [
              preserve_context_over_fn(
                  handle_key_deferred,
                  group_id_to_group[g_id],
                  deferreds[g_id],
              )
              for g_id in group_ids
          ]

      @defer.inlineCallbacks
      def wait_for_previous_lookups(self, server_names, server_to_deferred):
          """Waits for any previous key lookups for the given servers to finish.

          Once nothing is in flight for any of the servers, registers the
          deferreds in `server_to_deferred` as the in-flight lookups.

          Args:
              server_names (list): list of server_names we want to lookup
              server_to_deferred (dict): server_name to deferred which gets
                  resolved once we've finished looking up keys for that server
          """
          # Loop, as a new lookup may have been registered while we waited.
          while True:
              wait_on = [
                  self.key_downloads[server_name]
                  for server_name in server_names
                  if server_name in self.key_downloads
              ]
              if not wait_on:
                  break
              with PreserveLoggingContext():
                  yield defer.DeferredList(wait_on)

          def rm(r, server_name):
              self.key_downloads.pop(server_name, None)
              return r

          for server_name, deferred in server_to_deferred.items():
              d = ObservableDeferred(preserve_context_over_deferred(deferred))
              self.key_downloads[server_name] = d
              d.addBoth(rm, server_name)

      def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
          """Takes a dict of KeyGroups and tries to find at least one key for
          each group.

          Args:
              group_id_to_group (dict): group_id -> KeyGroup
              group_id_to_deferred (dict): group_id -> Deferred. Each deferred
                  that has not fired yet is resolved with the tuple
                  (group_id, server_name, key_id, verify_key) as soon as a key
                  is found for its group, or failed with a SynapseError (401)
                  if no fetcher can supply one.

          Returns:
              dict: `group_id_to_deferred`, as passed in.
          """
          # These are functions that produce keys given a list of key ids
          key_fetch_fns = (
              self.get_keys_from_store,  # First try the local store
              self.get_keys_from_perspectives,  # Then try via perspectives
              self.get_keys_from_server,  # Then try directly
          )

          @defer.inlineCallbacks
          def do_iterations():
              # server_name -> {key_id: verify_key}, accumulated over fetchers
              merged_results = {}

              # Only groups whose deferred has not fired still need a key;
              # firing a deferred twice raises AlreadyCalledError.
              pending_groups = [
                  group for group in group_id_to_group.values()
                  if not group_id_to_deferred[group.group_id].called
              ]

              for fn in key_fetch_fns:
                  if not pending_groups:
                      break

                  missing_keys = {}
                  for group in pending_groups:
                      missing_keys.setdefault(group.server_name, set()).update(
                          group.key_ids
                      )

                  results = yield fn(list(missing_keys.items()))
                  # Merge per server so that keys found by an earlier fetcher
                  # are not thrown away.
                  for server_name, keys in results.items():
                      merged_results.setdefault(server_name, {}).update(keys)

                  # We now need to figure out which groups we have keys for
                  # and which we don't
                  still_pending = []
                  for group in pending_groups:
                      deferred = group_id_to_deferred[group.group_id]
                      if deferred.called:
                          continue
                      server_keys = merged_results.get(group.server_name, {})
                      for key_id in group.key_ids:
                          if key_id in server_keys:
                              with PreserveLoggingContext():
                                  deferred.callback((
                                      group.group_id,
                                      group.server_name,
                                      key_id,
                                      server_keys[key_id],
                                  ))
                              break
                      else:
                          still_pending.append(group)
                  pending_groups = still_pending

              # Nobody could supply a key for whatever is left.
              for group in pending_groups:
                  deferred = group_id_to_deferred[group.group_id]
                  if deferred.called:
                      continue
                  deferred.errback(SynapseError(
                      401,
                      "No key for %s with id %s" % (
                          group.server_name, group.key_ids,
                      ),
                      Codes.UNAUTHORIZED,
                  ))

          def on_err(err):
              # Make sure nothing is left waiting forever if fetching blew up.
              for deferred in group_id_to_deferred.values():
                  if not deferred.called:
                      deferred.errback(err)

          do_iterations().addErrback(on_err)

          return group_id_to_deferred

      @defer.inlineCallbacks
      def get_keys_from_store(self, server_name_and_key_ids):
          """Looks the requested keys up in the local datastore.

          Args:
              server_name_and_key_ids (list): pairs of (server_name, key_ids)

          Returns:
              Deferred[dict]: server_name -> whatever the store's
              get_server_verify_keys returned for that server.
          """
          res = yield defer.gatherResults(
              [
                  self.store.get_server_verify_keys(
                      server_name, key_ids
                  ).addCallback(lambda ks, server: (server, ks), server_name)
                  for server_name, key_ids in server_name_and_key_ids
              ],
              consumeErrors=True,
          ).addErrback(unwrapFirstError)

          defer.returnValue(dict(res))

      @defer.inlineCallbacks
      def get_keys_from_perspectives(self, server_name_and_key_ids):
          """Asks every configured perspective server for the requested keys.

          A perspective server which fails is logged and contributes nothing.

          Args:
              server_name_and_key_ids (list): pairs of (server_name, key_ids)

          Returns:
              Deferred[dict]: server_name -> {key_id: verify_key}, the union of
              what the perspective servers returned.
          """
          @defer.inlineCallbacks
          def get_key(perspective_name, perspective_keys):
              try:
                  result = yield self.get_server_verify_key_v2_indirect(
                      server_name_and_key_ids, perspective_name, perspective_keys
                  )
                  defer.returnValue(result)
              except Exception as e:
                  logger.exception(
                      "Unable to get key from %r: %s %s",
                      perspective_name,
                      type(e).__name__, str(e.message),
                  )
                  defer.returnValue({})

          results = yield defer.gatherResults(
              [
                  get_key(p_name, p_keys)
                  for p_name, p_keys in self.perspective_servers.items()
              ],
              consumeErrors=True,
          ).addErrback(unwrapFirstError)

          union_of_keys = {}
          for result in results:
              for server_name, keys in result.items():
                  union_of_keys.setdefault(server_name, {}).update(keys)

          defer.returnValue(union_of_keys)

      @defer.inlineCallbacks
      def get_keys_from_server(self, server_name_and_key_ids):
          """Fetches the requested keys directly from the origin servers.

          The v2 key API is tried first; if it fails or yields nothing, the v1
          API is used instead.

          Args:
              server_name_and_key_ids (list): pairs of (server_name, key_ids)

          Returns:
              Deferred[dict]: server_name -> {key_id: verify_key}; servers for
              which nothing was found are omitted.
          """
          @defer.inlineCallbacks
          def get_key(server_name, key_ids):
              limiter = yield get_retry_limiter(
                  server_name,
                  self.clock,
                  self.store,
              )
              with limiter:
                  keys = None
                  try:
                      keys = yield self.get_server_verify_key_v2_direct(
                          server_name, key_ids
                      )
                  except Exception as e:
                      logger.info(
                          "Unable to getting key %r for %r directly: %s %s",
                          key_ids, server_name,
                          type(e).__name__, str(e.message),
                      )

                  if not keys:
                      keys = yield self.get_server_verify_key_v1_direct(
                          server_name, key_ids
                      )

                      # v1 returns a bare {key_id: key} dict; give it the same
                      # shape as the v2 result.
                      keys = {server_name: keys}

              defer.returnValue(keys)

          results = yield defer.gatherResults(
              [
                  get_key(server_name, key_ids)
                  for server_name, key_ids in server_name_and_key_ids
              ],
              consumeErrors=True,
          ).addErrback(unwrapFirstError)

          merged = {}
          for result in results:
              merged.update(result)

          defer.returnValue({
              server_name: keys
              for server_name, keys in merged.items()
              if keys
          })

      @defer.inlineCallbacks
      def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
                                            perspective_name,
                                            perspective_keys):
          """Fetches keys for a number of servers via one perspective server.

          Args:
              server_names_and_key_ids (list): pairs of (server_name, key_ids)
              perspective_name (str): the perspective server to query
              perspective_keys (dict): key_id -> verify key of the perspective
                  server; every response must be signed with one of these.

          Returns:
              Deferred[dict]: server_name -> {key_id: verify_key}

          Raises:
              ValueError: if a response is not signed by the perspective server
                  with a known key.
          """
          # TODO(mark): Set the minimum_valid_until_ts to that needed by
          # the events being validated or the current time if validating
          # an incoming request.
          query_response = yield self.client.post_json(
              destination=perspective_name,
              path=b"/_matrix/key/v2/query",
              data={
                  u"server_keys": {
                      server_name: {
                          key_id: {
                              u"minimum_valid_until_ts": 0
                          } for key_id in key_ids
                      }
                      for server_name, key_ids in server_names_and_key_ids
                  }
              },
              long_retries=True,
          )

          keys = {}

          responses = query_response["server_keys"]

          for response in responses:
              if (u"signatures" not in response
                      or perspective_name not in response[u"signatures"]):
                  raise ValueError(
                      "Key response not signed by perspective server"
                      " %r" % (perspective_name,)
                  )

              verified = False
              for key_id in response[u"signatures"][perspective_name]:
                  if key_id in perspective_keys:
                      verify_signed_json(
                          response,
                          perspective_name,
                          perspective_keys[key_id]
                      )
                      verified = True

              if not verified:
                  logger.info(
                      "Response from perspective server %r not signed with a"
                      " known key, signed with: %r, known keys: %r",
                      perspective_name,
                      list(response[u"signatures"][perspective_name]),
                      list(perspective_keys)
                  )
                  raise ValueError(
                      "Response not signed with a known key for perspective"
                      " server %r" % (perspective_name,)
                  )

              processed_response = yield self.process_v2_response(
                  perspective_name, response
              )

              for server_name, response_keys in processed_response.items():
                  keys.setdefault(server_name, {}).update(response_keys)

          yield defer.gatherResults(
              [
                  preserve_fn(self.store_keys)(
                      server_name=server_name,
                      from_server=perspective_name,
                      verify_keys=response_keys,
                  )
                  for server_name, response_keys in keys.items()
              ],
              consumeErrors=True
          ).addErrback(unwrapFirstError)

          defer.returnValue(keys)

      @defer.inlineCallbacks
      def get_server_verify_key_v2_direct(self, server_name, key_ids):
          """Fetches keys from the origin server using the v2 key API.

          Args:
              server_name (str): the server to fetch keys from
              key_ids (iterable of str): the key ids we are interested in

          Returns:
              Deferred[dict]: server_name -> {key_id: verify_key}

          Raises:
              ValueError: if a response is for a different server, is not
                  signed by the server, or does not match the TLS certificate
                  it was served with.
          """
          keys = {}

          for requested_key_id in key_ids:
              # An earlier response may already have included this key.
              if requested_key_id in keys.get(server_name, {}):
                  continue

              (response, tls_certificate) = yield fetch_server_key(
                  server_name, self.hs.tls_server_context_factory,
                  path=(b"/_matrix/key/v2/server/%s" % (
                      urllib.quote(requested_key_id),
                  )).encode("ascii"),
              )

              if (u"signatures" not in response
                      or server_name not in response[u"signatures"]):
                  raise ValueError("Key response not signed by remote server")

              # The response is processed (and its self-signature checked)
              # under the server_name it claims, so that claim must match the
              # server we actually asked.
              if response.get(u"server_name") != server_name:
                  raise ValueError(
                      "Expected a key response for server %r, not %r" % (
                          server_name, response.get(u"server_name"),
                      )
                  )

              if "tls_fingerprints" not in response:
                  raise ValueError("Key response missing TLS fingerprints")

              certificate_bytes = crypto.dump_certificate(
                  crypto.FILETYPE_ASN1, tls_certificate
              )
              sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
              sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)

              response_sha256_fingerprints = set()
              for fingerprint in response[u"tls_fingerprints"]:
                  if u"sha256" in fingerprint:
                      response_sha256_fingerprints.add(fingerprint[u"sha256"])

              if sha256_fingerprint_b64 not in response_sha256_fingerprints:
                  raise ValueError("TLS certificate not allowed by fingerprints")

              response_keys = yield self.process_v2_response(
                  from_server=server_name,
                  requested_ids=[requested_key_id],
                  response_json=response,
              )

              # Merge rather than overwrite, so keys from earlier responses
              # are kept.
              for key_server_name, verify_keys in response_keys.items():
                  keys.setdefault(key_server_name, {}).update(verify_keys)

          yield defer.gatherResults(
              [
                  preserve_fn(self.store_keys)(
                      server_name=key_server_name,
                      from_server=server_name,
                      verify_keys=verify_keys,
                  )
                  for key_server_name, verify_keys in keys.items()
              ],
              consumeErrors=True
          ).addErrback(unwrapFirstError)

          defer.returnValue(keys)

      @defer.inlineCallbacks
      def process_v2_response(self, from_server, response_json,
                              requested_ids=None):
          """Parses a v2 key response and stores the (re-signed) response json.

          Args:
              from_server (str): the server the response was obtained from
              response_json (dict): the key response
              requested_ids (iterable of str|None): key ids which were asked
                  for; the response is stored against these ids as well as
                  against every key id it contains.

          Returns:
              Deferred[dict]: {server_name: {key_id: verify_key}} where
              server_name is taken from the response.

          Raises:
              ValueError: if the response is signed by its server with a key
                  which is not listed in its verify_keys.
          """
          time_now_ms = self.clock.time_msec()
          response_keys = {}
          verify_keys = {}
          for key_id, key_data in response_json["verify_keys"].items():
              if is_signing_algorithm_supported(key_id):
                  key_base64 = key_data["key"]
                  key_bytes = decode_base64(key_base64)
                  verify_key = decode_verify_key_bytes(key_id, key_bytes)
                  verify_key.time_added = time_now_ms
                  verify_keys[key_id] = verify_key

          old_verify_keys = {}
          for key_id, key_data in response_json["old_verify_keys"].items():
              if is_signing_algorithm_supported(key_id):
                  key_base64 = key_data["key"]
                  key_bytes = decode_base64(key_base64)
                  verify_key = decode_verify_key_bytes(key_id, key_bytes)
                  verify_key.expired = key_data["expired_ts"]
                  verify_key.time_added = time_now_ms
                  old_verify_keys[key_id] = verify_key

          server_name = response_json["server_name"]
          for key_id in response_json["signatures"].get(server_name, {}):
              if key_id not in response_json["verify_keys"]:
                  raise ValueError(
                      "Key response must include verification keys for all"
                      " signatures"
                  )
              if key_id in verify_keys:
                  verify_signed_json(
                      response_json,
                      server_name,
                      verify_keys[key_id]
                  )

          # Add our own signature before the response is stored.
          signed_key_json = sign_json(
              response_json,
              self.config.server_name,
              self.config.signing_key[0],
          )

          signed_key_json_bytes = encode_canonical_json(signed_key_json)
          ts_valid_until_ms = signed_key_json[u"valid_until_ts"]

          updated_key_ids = set(requested_ids or ())
          updated_key_ids.update(verify_keys)
          updated_key_ids.update(old_verify_keys)

          response_keys.update(verify_keys)
          response_keys.update(old_verify_keys)

          yield defer.gatherResults(
              [
                  preserve_fn(self.store.store_server_keys_json)(
                      server_name=server_name,
                      key_id=key_id,
                      from_server=from_server,
                      ts_now_ms=time_now_ms,
                      ts_expires_ms=ts_valid_until_ms,
                      key_json_bytes=signed_key_json_bytes,
                  )
                  for key_id in updated_key_ids
              ],
              consumeErrors=True,
          ).addErrback(unwrapFirstError)

          defer.returnValue({server_name: response_keys})

      @defer.inlineCallbacks
      def get_server_verify_key_v1_direct(self, server_name, key_ids):
          """Fetches the server's verification keys using the v1 key API.

          Args:
              server_name (str): The name of the server to fetch a key for.
              key_ids (list of str): The key_ids to check for. Not used to
                  filter the response: every supported key in it is returned.

          Returns:
              Deferred[dict]: key_id -> verify_key

          Raises:
              ValueError: if the response is not signed by the server, or its
                  TLS certificate does not match the one it was served with.
          """

          # Try to fetch the key from the remote server.

          (response, tls_certificate) = yield fetch_server_key(
              server_name, self.hs.tls_server_context_factory
          )

          # Check the response.

          x509_certificate_bytes = crypto.dump_certificate(
              crypto.FILETYPE_ASN1, tls_certificate
          )

          if ("signatures" not in response
                  or server_name not in response["signatures"]):
              raise ValueError("Key response not signed by remote server")

          if "tls_certificate" not in response:
              raise ValueError("Key response missing TLS certificate")

          tls_certificate_b64 = response["tls_certificate"]

          if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
              raise ValueError("TLS certificate doesn't match")

          # Cache the result in the datastore.

          time_now_ms = self.clock.time_msec()

          verify_keys = {}
          for key_id, key_base64 in response["verify_keys"].items():
              if is_signing_algorithm_supported(key_id):
                  key_bytes = decode_base64(key_base64)
                  verify_key = decode_verify_key_bytes(key_id, key_bytes)
                  verify_key.time_added = time_now_ms
                  verify_keys[key_id] = verify_key

          for key_id in response["signatures"][server_name]:
              if key_id not in response["verify_keys"]:
                  raise ValueError(
                      "Key response must include verification keys for all"
                      " signatures"
                  )
              if key_id in verify_keys:
                  verify_signed_json(
                      response,
                      server_name,
                      verify_keys[key_id]
                  )

          yield self.store.store_server_certificate(
              server_name,
              server_name,
              time_now_ms,
              tls_certificate,
          )

          yield self.store_keys(
              server_name=server_name,
              from_server=server_name,
              verify_keys=verify_keys,
          )

          defer.returnValue(verify_keys)

      @defer.inlineCallbacks
      def store_keys(self, server_name, from_server, verify_keys):
          """Store a collection of verify keys for a given server

          Args:
              server_name(str): The name of the server the keys are for.
              from_server(str): The server the keys were downloaded from.
              verify_keys(dict): A mapping of key_id to VerifyKey.

          Returns:
              A deferred that completes when the keys are stored.
          """
          # TODO(markjh): Store whether the keys have expired.
          yield defer.gatherResults(
              [
                  preserve_fn(self.store.store_server_verify_key)(
                      server_name, from_server, key.time_added, key
                  )
                  for key in verify_keys.values()
              ],
              consumeErrors=True,
          ).addErrback(unwrapFirstError)