keyring.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from synapse.crypto.keyclient import fetch_server_key
  16. from synapse.api.errors import SynapseError, Codes
  17. from synapse.util.retryutils import get_retry_limiter
  18. from synapse.util import unwrapFirstError
  19. from synapse.util.async import ObservableDeferred
  20. from synapse.util.logcontext import (
  21. preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
  22. preserve_fn
  23. )
  24. from synapse.util.metrics import Measure
  25. from twisted.internet import defer
  26. from signedjson.sign import (
  27. verify_signed_json, signature_ids, sign_json, encode_canonical_json
  28. )
  29. from signedjson.key import (
  30. is_signing_algorithm_supported, decode_verify_key_bytes
  31. )
  32. from unpaddedbase64 import decode_base64, encode_base64
  33. from OpenSSL import crypto
  34. from collections import namedtuple
  35. import urllib
  36. import hashlib
  37. import logging
  38. logger = logging.getLogger(__name__)
  39. VerifyKeyRequest = namedtuple("VerifyRequest", (
  40. "server_name", "key_ids", "json_object", "deferred"
  41. ))
  42. """
  43. A request for a verify key to verify a JSON object.
  44. Attributes:
  45. server_name(str): The name of the server to verify against.
  46. key_ids(set(str)): The set of key_ids to that could be used to verify the
  47. JSON object
  48. json_object(dict): The JSON object to verify.
  49. deferred(twisted.internet.defer.Deferred):
  50. A deferred (server_name, key_id, verify_key) tuple that resolves when
  51. a verify key has been fetched
  52. """
  53. class KeyLookupError(ValueError):
  54. pass
  55. class Keyring(object):
  56. def __init__(self, hs):
  57. self.store = hs.get_datastore()
  58. self.clock = hs.get_clock()
  59. self.client = hs.get_http_client()
  60. self.config = hs.get_config()
  61. self.perspective_servers = self.config.perspectives
  62. self.hs = hs
  63. self.key_downloads = {}
  64. def verify_json_for_server(self, server_name, json_object):
  65. return self.verify_json_objects_for_server(
  66. [(server_name, json_object)]
  67. )[0]
  68. def verify_json_objects_for_server(self, server_and_json):
  69. """Bulk verfies signatures of json objects, bulk fetching keys as
  70. necessary.
  71. Args:
  72. server_and_json (list): List of pairs of (server_name, json_object)
  73. Returns:
  74. list of deferreds indicating success or failure to verify each
  75. json object's signature for the given server_name.
  76. """
  77. verify_requests = []
  78. for server_name, json_object in server_and_json:
  79. logger.debug("Verifying for %s", server_name)
  80. key_ids = signature_ids(json_object, server_name)
  81. if not key_ids:
  82. deferred = defer.fail(SynapseError(
  83. 400,
  84. "Not signed with a supported algorithm",
  85. Codes.UNAUTHORIZED,
  86. ))
  87. else:
  88. deferred = defer.Deferred()
  89. verify_request = VerifyKeyRequest(
  90. server_name, key_ids, json_object, deferred
  91. )
  92. verify_requests.append(verify_request)
  93. @defer.inlineCallbacks
  94. def handle_key_deferred(verify_request):
  95. server_name = verify_request.server_name
  96. try:
  97. _, key_id, verify_key = yield verify_request.deferred
  98. except IOError as e:
  99. logger.warn(
  100. "Got IOError when downloading keys for %s: %s %s",
  101. server_name, type(e).__name__, str(e.message),
  102. )
  103. raise SynapseError(
  104. 502,
  105. "Error downloading keys for %s" % (server_name,),
  106. Codes.UNAUTHORIZED,
  107. )
  108. except Exception as e:
  109. logger.exception(
  110. "Got Exception when downloading keys for %s: %s %s",
  111. server_name, type(e).__name__, str(e.message),
  112. )
  113. raise SynapseError(
  114. 401,
  115. "No key for %s with id %s" % (server_name, key_ids),
  116. Codes.UNAUTHORIZED,
  117. )
  118. json_object = verify_request.json_object
  119. try:
  120. verify_signed_json(json_object, server_name, verify_key)
  121. except:
  122. raise SynapseError(
  123. 401,
  124. "Invalid signature for server %s with key %s:%s" % (
  125. server_name, verify_key.alg, verify_key.version
  126. ),
  127. Codes.UNAUTHORIZED,
  128. )
  129. server_to_deferred = {
  130. server_name: defer.Deferred()
  131. for server_name, _ in server_and_json
  132. }
  133. with PreserveLoggingContext():
  134. # We want to wait for any previous lookups to complete before
  135. # proceeding.
  136. wait_on_deferred = self.wait_for_previous_lookups(
  137. [server_name for server_name, _ in server_and_json],
  138. server_to_deferred,
  139. )
  140. # Actually start fetching keys.
  141. wait_on_deferred.addBoth(
  142. lambda _: self.get_server_verify_keys(verify_requests)
  143. )
  144. # When we've finished fetching all the keys for a given server_name,
  145. # resolve the deferred passed to `wait_for_previous_lookups` so that
  146. # any lookups waiting will proceed.
  147. server_to_request_ids = {}
  148. def remove_deferreds(res, server_name, verify_request):
  149. request_id = id(verify_request)
  150. server_to_request_ids[server_name].discard(request_id)
  151. if not server_to_request_ids[server_name]:
  152. d = server_to_deferred.pop(server_name, None)
  153. if d:
  154. d.callback(None)
  155. return res
  156. for verify_request in verify_requests:
  157. server_name = verify_request.server_name
  158. request_id = id(verify_request)
  159. server_to_request_ids.setdefault(server_name, set()).add(request_id)
  160. deferred.addBoth(remove_deferreds, server_name, verify_request)
  161. # Pass those keys to handle_key_deferred so that the json object
  162. # signatures can be verified
  163. return [
  164. preserve_context_over_fn(handle_key_deferred, verify_request)
  165. for verify_request in verify_requests
  166. ]
  167. @defer.inlineCallbacks
  168. def wait_for_previous_lookups(self, server_names, server_to_deferred):
  169. """Waits for any previous key lookups for the given servers to finish.
  170. Args:
  171. server_names (list): list of server_names we want to lookup
  172. server_to_deferred (dict): server_name to deferred which gets
  173. resolved once we've finished looking up keys for that server
  174. """
  175. while True:
  176. wait_on = [
  177. self.key_downloads[server_name]
  178. for server_name in server_names
  179. if server_name in self.key_downloads
  180. ]
  181. if wait_on:
  182. with PreserveLoggingContext():
  183. yield defer.DeferredList(wait_on)
  184. else:
  185. break
  186. for server_name, deferred in server_to_deferred.items():
  187. d = ObservableDeferred(preserve_context_over_deferred(deferred))
  188. self.key_downloads[server_name] = d
  189. def rm(r, server_name):
  190. self.key_downloads.pop(server_name, None)
  191. return r
  192. d.addBoth(rm, server_name)
  193. def get_server_verify_keys(self, verify_requests):
  194. """Takes a dict of KeyGroups and tries to find at least one key for
  195. each group.
  196. """
  197. # These are functions that produce keys given a list of key ids
  198. key_fetch_fns = (
  199. self.get_keys_from_store, # First try the local store
  200. self.get_keys_from_perspectives, # Then try via perspectives
  201. self.get_keys_from_server, # Then try directly
  202. )
  203. @defer.inlineCallbacks
  204. def do_iterations():
  205. with Measure(self.clock, "get_server_verify_keys"):
  206. merged_results = {}
  207. missing_keys = {}
  208. for verify_request in verify_requests:
  209. missing_keys.setdefault(verify_request.server_name, set()).update(
  210. verify_request.key_ids
  211. )
  212. for fn in key_fetch_fns:
  213. results = yield fn(missing_keys.items())
  214. merged_results.update(results)
  215. # We now need to figure out which verify requests we have keys
  216. # for and which we don't
  217. missing_keys = {}
  218. requests_missing_keys = []
  219. for verify_request in verify_requests:
  220. server_name = verify_request.server_name
  221. result_keys = merged_results[server_name]
  222. if verify_request.deferred.called:
  223. # We've already called this deferred, which probably
  224. # means that we've already found a key for it.
  225. continue
  226. for key_id in verify_request.key_ids:
  227. if key_id in result_keys:
  228. with PreserveLoggingContext():
  229. verify_request.deferred.callback((
  230. server_name,
  231. key_id,
  232. result_keys[key_id],
  233. ))
  234. break
  235. else:
  236. # The else block is only reached if the loop above
  237. # doesn't break.
  238. missing_keys.setdefault(server_name, set()).update(
  239. verify_request.key_ids
  240. )
  241. requests_missing_keys.append(verify_request)
  242. if not missing_keys:
  243. break
  244. for verify_request in requests_missing_keys.values():
  245. verify_request.deferred.errback(SynapseError(
  246. 401,
  247. "No key for %s with id %s" % (
  248. verify_request.server_name, verify_request.key_ids,
  249. ),
  250. Codes.UNAUTHORIZED,
  251. ))
  252. def on_err(err):
  253. for verify_request in verify_requests:
  254. if not verify_request.deferred.called:
  255. verify_request.deferred.errback(err)
  256. do_iterations().addErrback(on_err)
  257. @defer.inlineCallbacks
  258. def get_keys_from_store(self, server_name_and_key_ids):
  259. res = yield preserve_context_over_deferred(defer.gatherResults(
  260. [
  261. preserve_fn(self.store.get_server_verify_keys)(
  262. server_name, key_ids
  263. ).addCallback(lambda ks, server: (server, ks), server_name)
  264. for server_name, key_ids in server_name_and_key_ids
  265. ],
  266. consumeErrors=True,
  267. )).addErrback(unwrapFirstError)
  268. defer.returnValue(dict(res))
  269. @defer.inlineCallbacks
  270. def get_keys_from_perspectives(self, server_name_and_key_ids):
  271. @defer.inlineCallbacks
  272. def get_key(perspective_name, perspective_keys):
  273. try:
  274. result = yield self.get_server_verify_key_v2_indirect(
  275. server_name_and_key_ids, perspective_name, perspective_keys
  276. )
  277. defer.returnValue(result)
  278. except Exception as e:
  279. logger.exception(
  280. "Unable to get key from %r: %s %s",
  281. perspective_name,
  282. type(e).__name__, str(e.message),
  283. )
  284. defer.returnValue({})
  285. results = yield preserve_context_over_deferred(defer.gatherResults(
  286. [
  287. preserve_fn(get_key)(p_name, p_keys)
  288. for p_name, p_keys in self.perspective_servers.items()
  289. ],
  290. consumeErrors=True,
  291. )).addErrback(unwrapFirstError)
  292. union_of_keys = {}
  293. for result in results:
  294. for server_name, keys in result.items():
  295. union_of_keys.setdefault(server_name, {}).update(keys)
  296. defer.returnValue(union_of_keys)
  297. @defer.inlineCallbacks
  298. def get_keys_from_server(self, server_name_and_key_ids):
  299. @defer.inlineCallbacks
  300. def get_key(server_name, key_ids):
  301. limiter = yield get_retry_limiter(
  302. server_name,
  303. self.clock,
  304. self.store,
  305. )
  306. with limiter:
  307. keys = None
  308. try:
  309. keys = yield self.get_server_verify_key_v2_direct(
  310. server_name, key_ids
  311. )
  312. except Exception as e:
  313. logger.info(
  314. "Unable to get key %r for %r directly: %s %s",
  315. key_ids, server_name,
  316. type(e).__name__, str(e.message),
  317. )
  318. if not keys:
  319. keys = yield self.get_server_verify_key_v1_direct(
  320. server_name, key_ids
  321. )
  322. keys = {server_name: keys}
  323. defer.returnValue(keys)
  324. results = yield preserve_context_over_deferred(defer.gatherResults(
  325. [
  326. preserve_fn(get_key)(server_name, key_ids)
  327. for server_name, key_ids in server_name_and_key_ids
  328. ],
  329. consumeErrors=True,
  330. )).addErrback(unwrapFirstError)
  331. merged = {}
  332. for result in results:
  333. merged.update(result)
  334. defer.returnValue({
  335. server_name: keys
  336. for server_name, keys in merged.items()
  337. if keys
  338. })
  339. @defer.inlineCallbacks
  340. def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
  341. perspective_name,
  342. perspective_keys):
  343. # TODO(mark): Set the minimum_valid_until_ts to that needed by
  344. # the events being validated or the current time if validating
  345. # an incoming request.
  346. query_response = yield self.client.post_json(
  347. destination=perspective_name,
  348. path=b"/_matrix/key/v2/query",
  349. data={
  350. u"server_keys": {
  351. server_name: {
  352. key_id: {
  353. u"minimum_valid_until_ts": 0
  354. } for key_id in key_ids
  355. }
  356. for server_name, key_ids in server_names_and_key_ids
  357. }
  358. },
  359. long_retries=True,
  360. )
  361. keys = {}
  362. responses = query_response["server_keys"]
  363. for response in responses:
  364. if (u"signatures" not in response
  365. or perspective_name not in response[u"signatures"]):
  366. raise KeyLookupError(
  367. "Key response not signed by perspective server"
  368. " %r" % (perspective_name,)
  369. )
  370. verified = False
  371. for key_id in response[u"signatures"][perspective_name]:
  372. if key_id in perspective_keys:
  373. verify_signed_json(
  374. response,
  375. perspective_name,
  376. perspective_keys[key_id]
  377. )
  378. verified = True
  379. if not verified:
  380. logging.info(
  381. "Response from perspective server %r not signed with a"
  382. " known key, signed with: %r, known keys: %r",
  383. perspective_name,
  384. list(response[u"signatures"][perspective_name]),
  385. list(perspective_keys)
  386. )
  387. raise KeyLookupError(
  388. "Response not signed with a known key for perspective"
  389. " server %r" % (perspective_name,)
  390. )
  391. processed_response = yield self.process_v2_response(
  392. perspective_name, response, only_from_server=False
  393. )
  394. for server_name, response_keys in processed_response.items():
  395. keys.setdefault(server_name, {}).update(response_keys)
  396. yield preserve_context_over_deferred(defer.gatherResults(
  397. [
  398. preserve_fn(self.store_keys)(
  399. server_name=server_name,
  400. from_server=perspective_name,
  401. verify_keys=response_keys,
  402. )
  403. for server_name, response_keys in keys.items()
  404. ],
  405. consumeErrors=True
  406. )).addErrback(unwrapFirstError)
  407. defer.returnValue(keys)
  408. @defer.inlineCallbacks
  409. def get_server_verify_key_v2_direct(self, server_name, key_ids):
  410. keys = {}
  411. for requested_key_id in key_ids:
  412. if requested_key_id in keys:
  413. continue
  414. (response, tls_certificate) = yield fetch_server_key(
  415. server_name, self.hs.tls_server_context_factory,
  416. path=(b"/_matrix/key/v2/server/%s" % (
  417. urllib.quote(requested_key_id),
  418. )).encode("ascii"),
  419. )
  420. if (u"signatures" not in response
  421. or server_name not in response[u"signatures"]):
  422. raise KeyLookupError("Key response not signed by remote server")
  423. if "tls_fingerprints" not in response:
  424. raise KeyLookupError("Key response missing TLS fingerprints")
  425. certificate_bytes = crypto.dump_certificate(
  426. crypto.FILETYPE_ASN1, tls_certificate
  427. )
  428. sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
  429. sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
  430. response_sha256_fingerprints = set()
  431. for fingerprint in response[u"tls_fingerprints"]:
  432. if u"sha256" in fingerprint:
  433. response_sha256_fingerprints.add(fingerprint[u"sha256"])
  434. if sha256_fingerprint_b64 not in response_sha256_fingerprints:
  435. raise KeyLookupError("TLS certificate not allowed by fingerprints")
  436. response_keys = yield self.process_v2_response(
  437. from_server=server_name,
  438. requested_ids=[requested_key_id],
  439. response_json=response,
  440. )
  441. keys.update(response_keys)
  442. yield preserve_context_over_deferred(defer.gatherResults(
  443. [
  444. preserve_fn(self.store_keys)(
  445. server_name=key_server_name,
  446. from_server=server_name,
  447. verify_keys=verify_keys,
  448. )
  449. for key_server_name, verify_keys in keys.items()
  450. ],
  451. consumeErrors=True
  452. )).addErrback(unwrapFirstError)
  453. defer.returnValue(keys)
  454. @defer.inlineCallbacks
  455. def process_v2_response(self, from_server, response_json,
  456. requested_ids=[], only_from_server=True):
  457. time_now_ms = self.clock.time_msec()
  458. response_keys = {}
  459. verify_keys = {}
  460. for key_id, key_data in response_json["verify_keys"].items():
  461. if is_signing_algorithm_supported(key_id):
  462. key_base64 = key_data["key"]
  463. key_bytes = decode_base64(key_base64)
  464. verify_key = decode_verify_key_bytes(key_id, key_bytes)
  465. verify_key.time_added = time_now_ms
  466. verify_keys[key_id] = verify_key
  467. old_verify_keys = {}
  468. for key_id, key_data in response_json["old_verify_keys"].items():
  469. if is_signing_algorithm_supported(key_id):
  470. key_base64 = key_data["key"]
  471. key_bytes = decode_base64(key_base64)
  472. verify_key = decode_verify_key_bytes(key_id, key_bytes)
  473. verify_key.expired = key_data["expired_ts"]
  474. verify_key.time_added = time_now_ms
  475. old_verify_keys[key_id] = verify_key
  476. results = {}
  477. server_name = response_json["server_name"]
  478. if only_from_server:
  479. if server_name != from_server:
  480. raise KeyLookupError(
  481. "Expected a response for server %r not %r" % (
  482. from_server, server_name
  483. )
  484. )
  485. for key_id in response_json["signatures"].get(server_name, {}):
  486. if key_id not in response_json["verify_keys"]:
  487. raise KeyLookupError(
  488. "Key response must include verification keys for all"
  489. " signatures"
  490. )
  491. if key_id in verify_keys:
  492. verify_signed_json(
  493. response_json,
  494. server_name,
  495. verify_keys[key_id]
  496. )
  497. signed_key_json = sign_json(
  498. response_json,
  499. self.config.server_name,
  500. self.config.signing_key[0],
  501. )
  502. signed_key_json_bytes = encode_canonical_json(signed_key_json)
  503. ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
  504. updated_key_ids = set(requested_ids)
  505. updated_key_ids.update(verify_keys)
  506. updated_key_ids.update(old_verify_keys)
  507. response_keys.update(verify_keys)
  508. response_keys.update(old_verify_keys)
  509. yield preserve_context_over_deferred(defer.gatherResults(
  510. [
  511. preserve_fn(self.store.store_server_keys_json)(
  512. server_name=server_name,
  513. key_id=key_id,
  514. from_server=server_name,
  515. ts_now_ms=time_now_ms,
  516. ts_expires_ms=ts_valid_until_ms,
  517. key_json_bytes=signed_key_json_bytes,
  518. )
  519. for key_id in updated_key_ids
  520. ],
  521. consumeErrors=True,
  522. )).addErrback(unwrapFirstError)
  523. results[server_name] = response_keys
  524. defer.returnValue(results)
  525. @defer.inlineCallbacks
  526. def get_server_verify_key_v1_direct(self, server_name, key_ids):
  527. """Finds a verification key for the server with one of the key ids.
  528. Args:
  529. server_name (str): The name of the server to fetch a key for.
  530. keys_ids (list of str): The key_ids to check for.
  531. """
  532. # Try to fetch the key from the remote server.
  533. (response, tls_certificate) = yield fetch_server_key(
  534. server_name, self.hs.tls_server_context_factory
  535. )
  536. # Check the response.
  537. x509_certificate_bytes = crypto.dump_certificate(
  538. crypto.FILETYPE_ASN1, tls_certificate
  539. )
  540. if ("signatures" not in response
  541. or server_name not in response["signatures"]):
  542. raise KeyLookupError("Key response not signed by remote server")
  543. if "tls_certificate" not in response:
  544. raise KeyLookupError("Key response missing TLS certificate")
  545. tls_certificate_b64 = response["tls_certificate"]
  546. if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
  547. raise KeyLookupError("TLS certificate doesn't match")
  548. # Cache the result in the datastore.
  549. time_now_ms = self.clock.time_msec()
  550. verify_keys = {}
  551. for key_id, key_base64 in response["verify_keys"].items():
  552. if is_signing_algorithm_supported(key_id):
  553. key_bytes = decode_base64(key_base64)
  554. verify_key = decode_verify_key_bytes(key_id, key_bytes)
  555. verify_key.time_added = time_now_ms
  556. verify_keys[key_id] = verify_key
  557. for key_id in response["signatures"][server_name]:
  558. if key_id not in response["verify_keys"]:
  559. raise KeyLookupError(
  560. "Key response must include verification keys for all"
  561. " signatures"
  562. )
  563. if key_id in verify_keys:
  564. verify_signed_json(
  565. response,
  566. server_name,
  567. verify_keys[key_id]
  568. )
  569. yield self.store.store_server_certificate(
  570. server_name,
  571. server_name,
  572. time_now_ms,
  573. tls_certificate,
  574. )
  575. yield self.store_keys(
  576. server_name=server_name,
  577. from_server=server_name,
  578. verify_keys=verify_keys,
  579. )
  580. defer.returnValue(verify_keys)
  581. @defer.inlineCallbacks
  582. def store_keys(self, server_name, from_server, verify_keys):
  583. """Store a collection of verify keys for a given server
  584. Args:
  585. server_name(str): The name of the server the keys are for.
  586. from_server(str): The server the keys were downloaded from.
  587. verify_keys(dict): A mapping of key_id to VerifyKey.
  588. Returns:
  589. A deferred that completes when the keys are stored.
  590. """
  591. # TODO(markjh): Store whether the keys have expired.
  592. yield preserve_context_over_deferred(defer.gatherResults(
  593. [
  594. preserve_fn(self.store.store_server_verify_key)(
  595. server_name, server_name, key.time_added, key
  596. )
  597. for key_id, key in verify_keys.items()
  598. ],
  599. consumeErrors=True,
  600. )).addErrback(unwrapFirstError)