keyring.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902
  1. # Copyright 2014-2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import abc
  15. import logging
  16. import urllib
  17. from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple
  18. import attr
  19. from signedjson.key import (
  20. decode_verify_key_bytes,
  21. encode_verify_key_base64,
  22. get_verify_key,
  23. is_signing_algorithm_supported,
  24. )
  25. from signedjson.sign import (
  26. SignatureVerifyException,
  27. encode_canonical_json,
  28. signature_ids,
  29. verify_signed_json,
  30. )
  31. from signedjson.types import VerifyKey
  32. from unpaddedbase64 import decode_base64
  33. from twisted.internet import defer
  34. from synapse.api.errors import (
  35. Codes,
  36. HttpResponseException,
  37. RequestSendFailed,
  38. SynapseError,
  39. )
  40. from synapse.config.key import TrustedKeyServer
  41. from synapse.events import EventBase
  42. from synapse.events.utils import prune_event_dict
  43. from synapse.logging.context import make_deferred_yieldable, run_in_background
  44. from synapse.storage.keys import FetchKeyResult
  45. from synapse.types import JsonDict
  46. from synapse.util import unwrapFirstError
  47. from synapse.util.async_helpers import yieldable_gather_results
  48. from synapse.util.batching_queue import BatchingQueue
  49. from synapse.util.retryutils import NotRetryingDestination
  50. if TYPE_CHECKING:
  51. from synapse.server import HomeServer
  52. logger = logging.getLogger(__name__)
  53. @attr.s(slots=True, frozen=True, cmp=False, auto_attribs=True)
  54. class VerifyJsonRequest:
  55. """
  56. A request to verify a JSON object.
  57. Attributes:
  58. server_name: The name of the server to verify against.
  59. get_json_object: A callback to fetch the JSON object to verify.
  60. A callback is used to allow deferring the creation of the JSON
  61. object to verify until needed, e.g. for events we can defer
  62. creating the redacted copy. This reduces the memory usage when
  63. there are large numbers of in flight requests.
  64. minimum_valid_until_ts: time at which we require the signing key to
  65. be valid. (0 implies we don't care)
  66. key_ids: The set of key_ids to that could be used to verify the JSON object
  67. """
  68. server_name: str
  69. get_json_object: Callable[[], JsonDict]
  70. minimum_valid_until_ts: int
  71. key_ids: List[str]
  72. @staticmethod
  73. def from_json_object(
  74. server_name: str,
  75. json_object: JsonDict,
  76. minimum_valid_until_ms: int,
  77. ) -> "VerifyJsonRequest":
  78. """Create a VerifyJsonRequest to verify all signatures on a signed JSON
  79. object for the given server.
  80. """
  81. key_ids = signature_ids(json_object, server_name)
  82. return VerifyJsonRequest(
  83. server_name,
  84. lambda: json_object,
  85. minimum_valid_until_ms,
  86. key_ids=key_ids,
  87. )
  88. @staticmethod
  89. def from_event(
  90. server_name: str,
  91. event: EventBase,
  92. minimum_valid_until_ms: int,
  93. ) -> "VerifyJsonRequest":
  94. """Create a VerifyJsonRequest to verify all signatures on an event
  95. object for the given server.
  96. """
  97. key_ids = list(event.signatures.get(server_name, []))
  98. return VerifyJsonRequest(
  99. server_name,
  100. # We defer creating the redacted json object, as it uses a lot more
  101. # memory than the Event object itself.
  102. lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
  103. minimum_valid_until_ms,
  104. key_ids=key_ids,
  105. )
  106. class KeyLookupError(ValueError):
  107. pass
  108. @attr.s(slots=True, frozen=True, auto_attribs=True)
  109. class _FetchKeyRequest:
  110. """A request for keys for a given server.
  111. We will continue to try and fetch until we have all the keys listed under
  112. `key_ids` (with an appropriate `valid_until_ts` property) or we run out of
  113. places to fetch keys from.
  114. Attributes:
  115. server_name: The name of the server that owns the keys.
  116. minimum_valid_until_ts: The timestamp which the keys must be valid until.
  117. key_ids: The IDs of the keys to attempt to fetch
  118. """
  119. server_name: str
  120. minimum_valid_until_ts: int
  121. key_ids: List[str]
  122. class Keyring:
  123. """Handles verifying signed JSON objects and fetching the keys needed to do
  124. so.
  125. """
  126. def __init__(
  127. self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
  128. ):
  129. self.clock = hs.get_clock()
  130. if key_fetchers is None:
  131. key_fetchers = (
  132. StoreKeyFetcher(hs),
  133. PerspectivesKeyFetcher(hs),
  134. ServerKeyFetcher(hs),
  135. )
  136. self._key_fetchers = key_fetchers
  137. self._server_queue: BatchingQueue[
  138. _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]
  139. ] = BatchingQueue(
  140. "keyring_server",
  141. clock=hs.get_clock(),
  142. process_batch_callback=self._inner_fetch_key_requests,
  143. )
  144. self._hostname = hs.hostname
  145. # build a FetchKeyResult for each of our own keys, to shortcircuit the
  146. # fetcher.
  147. self._local_verify_keys: Dict[str, FetchKeyResult] = {}
  148. for key_id, key in hs.config.key.old_signing_keys.items():
  149. self._local_verify_keys[key_id] = FetchKeyResult(
  150. verify_key=key, valid_until_ts=key.expired
  151. )
  152. vk = get_verify_key(hs.signing_key)
  153. self._local_verify_keys[f"{vk.alg}:{vk.version}"] = FetchKeyResult(
  154. verify_key=vk,
  155. valid_until_ts=2**63, # fake future timestamp
  156. )
  157. async def verify_json_for_server(
  158. self,
  159. server_name: str,
  160. json_object: JsonDict,
  161. validity_time: int,
  162. ) -> None:
  163. """Verify that a JSON object has been signed by a given server
  164. Completes if the the object was correctly signed, otherwise raises.
  165. Args:
  166. server_name: name of the server which must have signed this object
  167. json_object: object to be checked
  168. validity_time: timestamp at which we require the signing key to
  169. be valid. (0 implies we don't care)
  170. """
  171. request = VerifyJsonRequest.from_json_object(
  172. server_name,
  173. json_object,
  174. validity_time,
  175. )
  176. return await self.process_request(request)
  177. def verify_json_objects_for_server(
  178. self, server_and_json: Iterable[Tuple[str, dict, int]]
  179. ) -> List[defer.Deferred]:
  180. """Bulk verifies signatures of json objects, bulk fetching keys as
  181. necessary.
  182. Args:
  183. server_and_json:
  184. Iterable of (server_name, json_object, validity_time)
  185. tuples.
  186. validity_time is a timestamp at which the signing key must be
  187. valid.
  188. Returns:
  189. List<Deferred[None]>: for each input triplet, a deferred indicating success
  190. or failure to verify each json object's signature for the given
  191. server_name. The deferreds run their callbacks in the sentinel
  192. logcontext.
  193. """
  194. return [
  195. run_in_background(
  196. self.process_request,
  197. VerifyJsonRequest.from_json_object(
  198. server_name,
  199. json_object,
  200. validity_time,
  201. ),
  202. )
  203. for server_name, json_object, validity_time in server_and_json
  204. ]
  205. async def verify_event_for_server(
  206. self,
  207. server_name: str,
  208. event: EventBase,
  209. validity_time: int,
  210. ) -> None:
  211. await self.process_request(
  212. VerifyJsonRequest.from_event(
  213. server_name,
  214. event,
  215. validity_time,
  216. )
  217. )
  218. async def process_request(self, verify_request: VerifyJsonRequest) -> None:
  219. """Processes the `VerifyJsonRequest`. Raises if the object is not signed
  220. by the server, the signatures don't match or we failed to fetch the
  221. necessary keys.
  222. """
  223. if not verify_request.key_ids:
  224. raise SynapseError(
  225. 400,
  226. f"Not signed by {verify_request.server_name}",
  227. Codes.UNAUTHORIZED,
  228. )
  229. found_keys: Dict[str, FetchKeyResult] = {}
  230. # If we are the originating server, short-circuit the key-fetch for any keys
  231. # we already have
  232. if verify_request.server_name == self._hostname:
  233. for key_id in verify_request.key_ids:
  234. if key_id in self._local_verify_keys:
  235. found_keys[key_id] = self._local_verify_keys[key_id]
  236. key_ids_to_find = set(verify_request.key_ids) - found_keys.keys()
  237. if key_ids_to_find:
  238. # Add the keys we need to verify to the queue for retrieval. We queue
  239. # up requests for the same server so we don't end up with many in flight
  240. # requests for the same keys.
  241. key_request = _FetchKeyRequest(
  242. server_name=verify_request.server_name,
  243. minimum_valid_until_ts=verify_request.minimum_valid_until_ts,
  244. key_ids=list(key_ids_to_find),
  245. )
  246. found_keys_by_server = await self._server_queue.add_to_queue(
  247. key_request, key=verify_request.server_name
  248. )
  249. # Since we batch up requests the returned set of keys may contain keys
  250. # from other servers, so we pull out only the ones we care about.
  251. found_keys.update(found_keys_by_server.get(verify_request.server_name, {}))
  252. # Verify each signature we got valid keys for, raising if we can't
  253. # verify any of them.
  254. verified = False
  255. for key_id in verify_request.key_ids:
  256. key_result = found_keys.get(key_id)
  257. if not key_result:
  258. continue
  259. if key_result.valid_until_ts < verify_request.minimum_valid_until_ts:
  260. continue
  261. await self._process_json(key_result.verify_key, verify_request)
  262. verified = True
  263. if not verified:
  264. raise SynapseError(
  265. 401,
  266. f"Failed to find any key to satisfy: {key_request}",
  267. Codes.UNAUTHORIZED,
  268. )
  269. async def _process_json(
  270. self, verify_key: VerifyKey, verify_request: VerifyJsonRequest
  271. ) -> None:
  272. """Processes the `VerifyJsonRequest`. Raises if the signature can't be
  273. verified.
  274. """
  275. try:
  276. verify_signed_json(
  277. verify_request.get_json_object(),
  278. verify_request.server_name,
  279. verify_key,
  280. )
  281. except SignatureVerifyException as e:
  282. logger.debug(
  283. "Error verifying signature for %s:%s:%s with key %s: %s",
  284. verify_request.server_name,
  285. verify_key.alg,
  286. verify_key.version,
  287. encode_verify_key_base64(verify_key),
  288. str(e),
  289. )
  290. raise SynapseError(
  291. 401,
  292. "Invalid signature for server %s with key %s:%s: %s"
  293. % (
  294. verify_request.server_name,
  295. verify_key.alg,
  296. verify_key.version,
  297. str(e),
  298. ),
  299. Codes.UNAUTHORIZED,
  300. )
  301. async def _inner_fetch_key_requests(
  302. self, requests: List[_FetchKeyRequest]
  303. ) -> Dict[str, Dict[str, FetchKeyResult]]:
  304. """Processing function for the queue of `_FetchKeyRequest`."""
  305. logger.debug("Starting fetch for %s", requests)
  306. # First we need to deduplicate requests for the same key. We do this by
  307. # taking the *maximum* requested `minimum_valid_until_ts` for each pair
  308. # of server name/key ID.
  309. server_to_key_to_ts: Dict[str, Dict[str, int]] = {}
  310. for request in requests:
  311. by_server = server_to_key_to_ts.setdefault(request.server_name, {})
  312. for key_id in request.key_ids:
  313. existing_ts = by_server.get(key_id, 0)
  314. by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts)
  315. deduped_requests = [
  316. _FetchKeyRequest(server_name, minimum_valid_ts, [key_id])
  317. for server_name, by_server in server_to_key_to_ts.items()
  318. for key_id, minimum_valid_ts in by_server.items()
  319. ]
  320. logger.debug("Deduplicated key requests to %s", deduped_requests)
  321. # For each key we call `_inner_verify_request` which will handle
  322. # fetching each key. Note these shouldn't throw if we fail to contact
  323. # other servers etc.
  324. results_per_request = await yieldable_gather_results(
  325. self._inner_fetch_key_request,
  326. deduped_requests,
  327. )
  328. # We now convert the returned list of results into a map from server
  329. # name to key ID to FetchKeyResult, to return.
  330. to_return: Dict[str, Dict[str, FetchKeyResult]] = {}
  331. for (request, results) in zip(deduped_requests, results_per_request):
  332. to_return_by_server = to_return.setdefault(request.server_name, {})
  333. for key_id, key_result in results.items():
  334. existing = to_return_by_server.get(key_id)
  335. if not existing or existing.valid_until_ts < key_result.valid_until_ts:
  336. to_return_by_server[key_id] = key_result
  337. return to_return
  338. async def _inner_fetch_key_request(
  339. self, verify_request: _FetchKeyRequest
  340. ) -> Dict[str, FetchKeyResult]:
  341. """Attempt to fetch the given key by calling each key fetcher one by
  342. one.
  343. """
  344. logger.debug("Starting fetch for %s", verify_request)
  345. found_keys: Dict[str, FetchKeyResult] = {}
  346. missing_key_ids = set(verify_request.key_ids)
  347. for fetcher in self._key_fetchers:
  348. if not missing_key_ids:
  349. break
  350. logger.debug("Getting keys from %s for %s", fetcher, verify_request)
  351. keys = await fetcher.get_keys(
  352. verify_request.server_name,
  353. list(missing_key_ids),
  354. verify_request.minimum_valid_until_ts,
  355. )
  356. for key_id, key in keys.items():
  357. if not key:
  358. continue
  359. # If we already have a result for the given key ID we keep the
  360. # one with the highest `valid_until_ts`.
  361. existing_key = found_keys.get(key_id)
  362. if existing_key:
  363. if key.valid_until_ts <= existing_key.valid_until_ts:
  364. continue
  365. # We always store the returned key even if it doesn't the
  366. # `minimum_valid_until_ts` requirement, as some verification
  367. # requests may still be able to be satisfied by it.
  368. #
  369. # We still keep looking for the key from other fetchers in that
  370. # case though.
  371. found_keys[key_id] = key
  372. if key.valid_until_ts < verify_request.minimum_valid_until_ts:
  373. continue
  374. missing_key_ids.discard(key_id)
  375. return found_keys
  376. class KeyFetcher(metaclass=abc.ABCMeta):
  377. def __init__(self, hs: "HomeServer"):
  378. self._queue = BatchingQueue(
  379. self.__class__.__name__, hs.get_clock(), self._fetch_keys
  380. )
  381. async def get_keys(
  382. self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
  383. ) -> Dict[str, FetchKeyResult]:
  384. results = await self._queue.add_to_queue(
  385. _FetchKeyRequest(
  386. server_name=server_name,
  387. key_ids=key_ids,
  388. minimum_valid_until_ts=minimum_valid_until_ts,
  389. )
  390. )
  391. return results.get(server_name, {})
  392. @abc.abstractmethod
  393. async def _fetch_keys(
  394. self, keys_to_fetch: List[_FetchKeyRequest]
  395. ) -> Dict[str, Dict[str, FetchKeyResult]]:
  396. pass
  397. class StoreKeyFetcher(KeyFetcher):
  398. """KeyFetcher impl which fetches keys from our data store"""
  399. def __init__(self, hs: "HomeServer"):
  400. super().__init__(hs)
  401. self.store = hs.get_datastores().main
  402. async def _fetch_keys(
  403. self, keys_to_fetch: List[_FetchKeyRequest]
  404. ) -> Dict[str, Dict[str, FetchKeyResult]]:
  405. key_ids_to_fetch = (
  406. (queue_value.server_name, key_id)
  407. for queue_value in keys_to_fetch
  408. for key_id in queue_value.key_ids
  409. )
  410. res = await self.store.get_server_verify_keys(key_ids_to_fetch)
  411. keys: Dict[str, Dict[str, FetchKeyResult]] = {}
  412. for (server_name, key_id), key in res.items():
  413. keys.setdefault(server_name, {})[key_id] = key
  414. return keys
  415. class BaseV2KeyFetcher(KeyFetcher):
  416. def __init__(self, hs: "HomeServer"):
  417. super().__init__(hs)
  418. self.store = hs.get_datastores().main
  419. self.config = hs.config
  420. async def process_v2_response(
  421. self, from_server: str, response_json: JsonDict, time_added_ms: int
  422. ) -> Dict[str, FetchKeyResult]:
  423. """Parse a 'Server Keys' structure from the result of a /key request
  424. This is used to parse either the entirety of the response from
  425. GET /_matrix/key/v2/server, or a single entry from the list returned by
  426. POST /_matrix/key/v2/query.
  427. Checks that each signature in the response that claims to come from the origin
  428. server is valid, and that there is at least one such signature.
  429. Stores the json in server_keys_json so that it can be used for future responses
  430. to /_matrix/key/v2/query.
  431. Args:
  432. from_server: the name of the server producing this result: either
  433. the origin server for a /_matrix/key/v2/server request, or the notary
  434. for a /_matrix/key/v2/query.
  435. response_json: the json-decoded Server Keys response object
  436. time_added_ms: the timestamp to record in server_keys_json
  437. Returns:
  438. Map from key_id to result object
  439. """
  440. ts_valid_until_ms = response_json["valid_until_ts"]
  441. # start by extracting the keys from the response, since they may be required
  442. # to validate the signature on the response.
  443. verify_keys = {}
  444. for key_id, key_data in response_json["verify_keys"].items():
  445. if is_signing_algorithm_supported(key_id):
  446. key_base64 = key_data["key"]
  447. key_bytes = decode_base64(key_base64)
  448. verify_key = decode_verify_key_bytes(key_id, key_bytes)
  449. verify_keys[key_id] = FetchKeyResult(
  450. verify_key=verify_key, valid_until_ts=ts_valid_until_ms
  451. )
  452. server_name = response_json["server_name"]
  453. verified = False
  454. for key_id in response_json["signatures"].get(server_name, {}):
  455. key = verify_keys.get(key_id)
  456. if not key:
  457. # the key may not be present in verify_keys if:
  458. # * we got the key from the notary server, and:
  459. # * the key belongs to the notary server, and:
  460. # * the notary server is using a different key to sign notary
  461. # responses.
  462. continue
  463. verify_signed_json(response_json, server_name, key.verify_key)
  464. verified = True
  465. break
  466. if not verified:
  467. raise KeyLookupError(
  468. "Key response for %s is not signed by the origin server"
  469. % (server_name,)
  470. )
  471. for key_id, key_data in response_json["old_verify_keys"].items():
  472. if is_signing_algorithm_supported(key_id):
  473. key_base64 = key_data["key"]
  474. key_bytes = decode_base64(key_base64)
  475. verify_key = decode_verify_key_bytes(key_id, key_bytes)
  476. verify_keys[key_id] = FetchKeyResult(
  477. verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
  478. )
  479. key_json_bytes = encode_canonical_json(response_json)
  480. await make_deferred_yieldable(
  481. defer.gatherResults(
  482. [
  483. run_in_background(
  484. self.store.store_server_keys_json,
  485. server_name=server_name,
  486. key_id=key_id,
  487. from_server=from_server,
  488. ts_now_ms=time_added_ms,
  489. ts_expires_ms=ts_valid_until_ms,
  490. key_json_bytes=key_json_bytes,
  491. )
  492. for key_id in verify_keys
  493. ],
  494. consumeErrors=True,
  495. ).addErrback(unwrapFirstError)
  496. )
  497. return verify_keys
  498. class PerspectivesKeyFetcher(BaseV2KeyFetcher):
  499. """KeyFetcher impl which fetches keys from the "perspectives" servers"""
  500. def __init__(self, hs: "HomeServer"):
  501. super().__init__(hs)
  502. self.clock = hs.get_clock()
  503. self.client = hs.get_federation_http_client()
  504. self.key_servers = self.config.key.key_servers
  505. async def _fetch_keys(
  506. self, keys_to_fetch: List[_FetchKeyRequest]
  507. ) -> Dict[str, Dict[str, FetchKeyResult]]:
  508. """see KeyFetcher._fetch_keys"""
  509. async def get_key(key_server: TrustedKeyServer) -> Dict:
  510. try:
  511. return await self.get_server_verify_key_v2_indirect(
  512. keys_to_fetch, key_server
  513. )
  514. except KeyLookupError as e:
  515. logger.warning(
  516. "Key lookup failed from %r: %s", key_server.server_name, e
  517. )
  518. except Exception as e:
  519. logger.exception(
  520. "Unable to get key from %r: %s %s",
  521. key_server.server_name,
  522. type(e).__name__,
  523. str(e),
  524. )
  525. return {}
  526. results = await make_deferred_yieldable(
  527. defer.gatherResults(
  528. [run_in_background(get_key, server) for server in self.key_servers],
  529. consumeErrors=True,
  530. ).addErrback(unwrapFirstError)
  531. )
  532. union_of_keys: Dict[str, Dict[str, FetchKeyResult]] = {}
  533. for result in results:
  534. for server_name, keys in result.items():
  535. union_of_keys.setdefault(server_name, {}).update(keys)
  536. return union_of_keys
  537. async def get_server_verify_key_v2_indirect(
  538. self, keys_to_fetch: List[_FetchKeyRequest], key_server: TrustedKeyServer
  539. ) -> Dict[str, Dict[str, FetchKeyResult]]:
  540. """
  541. Args:
  542. keys_to_fetch:
  543. the keys to be fetched.
  544. key_server: notary server to query for the keys
  545. Returns:
  546. Map from server_name -> key_id -> FetchKeyResult
  547. Raises:
  548. KeyLookupError if there was an error processing the entire response from
  549. the server
  550. """
  551. perspective_name = key_server.server_name
  552. logger.info(
  553. "Requesting keys %s from notary server %s",
  554. keys_to_fetch,
  555. perspective_name,
  556. )
  557. request: JsonDict = {}
  558. for queue_value in keys_to_fetch:
  559. # there may be multiple requests for each server, so we have to merge
  560. # them intelligently.
  561. request_for_server = {
  562. key_id: {
  563. "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
  564. }
  565. for key_id in queue_value.key_ids
  566. }
  567. request.setdefault(queue_value.server_name, {}).update(request_for_server)
  568. logger.debug("Request to notary server %s: %s", perspective_name, request)
  569. try:
  570. query_response = await self.client.post_json(
  571. destination=perspective_name,
  572. path="/_matrix/key/v2/query",
  573. data={"server_keys": request},
  574. )
  575. except (NotRetryingDestination, RequestSendFailed) as e:
  576. # these both have str() representations which we can't really improve upon
  577. raise KeyLookupError(str(e))
  578. except HttpResponseException as e:
  579. raise KeyLookupError("Remote server returned an error: %s" % (e,))
  580. logger.debug(
  581. "Response from notary server %s: %s", perspective_name, query_response
  582. )
  583. keys: Dict[str, Dict[str, FetchKeyResult]] = {}
  584. added_keys: List[Tuple[str, str, FetchKeyResult]] = []
  585. time_now_ms = self.clock.time_msec()
  586. assert isinstance(query_response, dict)
  587. for response in query_response["server_keys"]:
  588. # do this first, so that we can give useful errors thereafter
  589. server_name = response.get("server_name")
  590. if not isinstance(server_name, str):
  591. raise KeyLookupError(
  592. "Malformed response from key notary server %s: invalid server_name"
  593. % (perspective_name,)
  594. )
  595. try:
  596. self._validate_perspectives_response(key_server, response)
  597. processed_response = await self.process_v2_response(
  598. perspective_name, response, time_added_ms=time_now_ms
  599. )
  600. except KeyLookupError as e:
  601. logger.warning(
  602. "Error processing response from key notary server %s for origin "
  603. "server %s: %s",
  604. perspective_name,
  605. server_name,
  606. e,
  607. )
  608. # we continue to process the rest of the response
  609. continue
  610. added_keys.extend(
  611. (server_name, key_id, key) for key_id, key in processed_response.items()
  612. )
  613. keys.setdefault(server_name, {}).update(processed_response)
  614. await self.store.store_server_verify_keys(
  615. perspective_name, time_now_ms, added_keys
  616. )
  617. return keys
  618. def _validate_perspectives_response(
  619. self, key_server: TrustedKeyServer, response: JsonDict
  620. ) -> None:
  621. """Optionally check the signature on the result of a /key/query request
  622. Args:
  623. key_server: the notary server that produced this result
  624. response: the json-decoded Server Keys response object
  625. """
  626. perspective_name = key_server.server_name
  627. perspective_keys = key_server.verify_keys
  628. if perspective_keys is None:
  629. # signature checking is disabled on this server
  630. return
  631. if (
  632. "signatures" not in response
  633. or perspective_name not in response["signatures"]
  634. ):
  635. raise KeyLookupError("Response not signed by the notary server")
  636. verified = False
  637. for key_id in response["signatures"][perspective_name]:
  638. if key_id in perspective_keys:
  639. verify_signed_json(response, perspective_name, perspective_keys[key_id])
  640. verified = True
  641. if not verified:
  642. raise KeyLookupError(
  643. "Response not signed with a known key: signed with: %r, known keys: %r"
  644. % (
  645. list(response["signatures"][perspective_name].keys()),
  646. list(perspective_keys.keys()),
  647. )
  648. )
  649. class ServerKeyFetcher(BaseV2KeyFetcher):
  650. """KeyFetcher impl which fetches keys from the origin servers"""
  651. def __init__(self, hs: "HomeServer"):
  652. super().__init__(hs)
  653. self.clock = hs.get_clock()
  654. self.client = hs.get_federation_http_client()
  655. async def get_keys(
  656. self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
  657. ) -> Dict[str, FetchKeyResult]:
  658. results = await self._queue.add_to_queue(
  659. _FetchKeyRequest(
  660. server_name=server_name,
  661. key_ids=key_ids,
  662. minimum_valid_until_ts=minimum_valid_until_ts,
  663. ),
  664. key=server_name,
  665. )
  666. return results.get(server_name, {})
  667. async def _fetch_keys(
  668. self, keys_to_fetch: List[_FetchKeyRequest]
  669. ) -> Dict[str, Dict[str, FetchKeyResult]]:
  670. """
  671. Args:
  672. keys_to_fetch:
  673. the keys to be fetched. server_name -> key_ids
  674. Returns:
  675. Map from server_name -> key_id -> FetchKeyResult
  676. """
  677. results = {}
  678. async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None:
  679. server_name = key_to_fetch_item.server_name
  680. key_ids = key_to_fetch_item.key_ids
  681. try:
  682. keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
  683. results[server_name] = keys
  684. except KeyLookupError as e:
  685. logger.warning(
  686. "Error looking up keys %s from %s: %s", key_ids, server_name, e
  687. )
  688. except Exception:
  689. logger.exception("Error getting keys %s from %s", key_ids, server_name)
  690. await yieldable_gather_results(get_key, keys_to_fetch)
  691. return results
  692. async def get_server_verify_key_v2_direct(
  693. self, server_name: str, key_ids: Iterable[str]
  694. ) -> Dict[str, FetchKeyResult]:
  695. """
  696. Args:
  697. server_name:
  698. key_ids:
  699. Returns:
  700. Map from key ID to lookup result
  701. Raises:
  702. KeyLookupError if there was a problem making the lookup
  703. """
  704. keys: Dict[str, FetchKeyResult] = {}
  705. for requested_key_id in key_ids:
  706. # we may have found this key as a side-effect of asking for another.
  707. if requested_key_id in keys:
  708. continue
  709. time_now_ms = self.clock.time_msec()
  710. try:
  711. response = await self.client.get_json(
  712. destination=server_name,
  713. path="/_matrix/key/v2/server/"
  714. + urllib.parse.quote(requested_key_id),
  715. ignore_backoff=True,
  716. # we only give the remote server 10s to respond. It should be an
  717. # easy request to handle, so if it doesn't reply within 10s, it's
  718. # probably not going to.
  719. #
  720. # Furthermore, when we are acting as a notary server, we cannot
  721. # wait all day for all of the origin servers, as the requesting
  722. # server will otherwise time out before we can respond.
  723. #
  724. # (Note that get_json may make 4 attempts, so this can still take
  725. # almost 45 seconds to fetch the headers, plus up to another 60s to
  726. # read the response).
  727. timeout=10000,
  728. )
  729. except (NotRetryingDestination, RequestSendFailed) as e:
  730. # these both have str() representations which we can't really improve
  731. # upon
  732. raise KeyLookupError(str(e))
  733. except HttpResponseException as e:
  734. raise KeyLookupError("Remote server returned an error: %s" % (e,))
  735. assert isinstance(response, dict)
  736. if response["server_name"] != server_name:
  737. raise KeyLookupError(
  738. "Expected a response for server %r not %r"
  739. % (server_name, response["server_name"])
  740. )
  741. response_keys = await self.process_v2_response(
  742. from_server=server_name,
  743. response_json=response,
  744. time_added_ms=time_now_ms,
  745. )
  746. await self.store.store_server_verify_keys(
  747. server_name,
  748. time_now_ms,
  749. ((server_name, key_id, key) for key_id, key in response_keys.items()),
  750. )
  751. keys.update(response_keys)
  752. return keys