test_keyring.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2017 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import time
  16. from mock import Mock
  17. import canonicaljson
  18. import signedjson.key
  19. import signedjson.sign
  20. from signedjson.key import encode_verify_key_base64, get_verify_key
  21. from twisted.internet import defer
  22. from synapse.api.errors import SynapseError
  23. from synapse.crypto import keyring
  24. from synapse.crypto.keyring import (
  25. PerspectivesKeyFetcher,
  26. ServerKeyFetcher,
  27. StoreKeyFetcher,
  28. )
  29. from synapse.logging.context import (
  30. LoggingContext,
  31. PreserveLoggingContext,
  32. make_deferred_yieldable,
  33. )
  34. from synapse.storage.keys import FetchKeyResult
  35. from tests import unittest
  36. class MockPerspectiveServer(object):
  37. def __init__(self):
  38. self.server_name = "mock_server"
  39. self.key = signedjson.key.generate_signing_key(0)
  40. def get_verify_keys(self):
  41. vk = signedjson.key.get_verify_key(self.key)
  42. return {"%s:%s" % (vk.alg, vk.version): encode_verify_key_base64(vk)}
  43. def get_signed_key(self, server_name, verify_key):
  44. key_id = "%s:%s" % (verify_key.alg, verify_key.version)
  45. res = {
  46. "server_name": server_name,
  47. "old_verify_keys": {},
  48. "valid_until_ts": time.time() * 1000 + 3600,
  49. "verify_keys": {key_id: {"key": encode_verify_key_base64(verify_key)}},
  50. }
  51. self.sign_response(res)
  52. return res
  53. def sign_response(self, res):
  54. signedjson.sign.sign_json(res, self.server_name, self.key)
  55. class KeyringTestCase(unittest.HomeserverTestCase):
  56. def make_homeserver(self, reactor, clock):
  57. self.mock_perspective_server = MockPerspectiveServer()
  58. self.http_client = Mock()
  59. config = self.default_config()
  60. config["trusted_key_servers"] = [
  61. {
  62. "server_name": self.mock_perspective_server.server_name,
  63. "verify_keys": self.mock_perspective_server.get_verify_keys(),
  64. }
  65. ]
  66. return self.setup_test_homeserver(
  67. handlers=None, http_client=self.http_client, config=config
  68. )
  69. def check_context(self, _, expected):
  70. self.assertEquals(
  71. getattr(LoggingContext.current_context(), "request", None), expected
  72. )
  73. def test_wait_for_previous_lookups(self):
  74. kr = keyring.Keyring(self.hs)
  75. lookup_1_deferred = defer.Deferred()
  76. lookup_2_deferred = defer.Deferred()
  77. # we run the lookup in a logcontext so that the patched inlineCallbacks can check
  78. # it is doing the right thing with logcontexts.
  79. wait_1_deferred = run_in_context(
  80. kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
  81. )
  82. # there were no previous lookups, so the deferred should be ready
  83. self.successResultOf(wait_1_deferred)
  84. # set off another wait. It should block because the first lookup
  85. # hasn't yet completed.
  86. wait_2_deferred = run_in_context(
  87. kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
  88. )
  89. self.assertFalse(wait_2_deferred.called)
  90. # let the first lookup complete (in the sentinel context)
  91. lookup_1_deferred.callback(None)
  92. # now the second wait should complete.
  93. self.successResultOf(wait_2_deferred)
  94. def test_verify_json_objects_for_server_awaits_previous_requests(self):
  95. key1 = signedjson.key.generate_signing_key(1)
  96. kr = keyring.Keyring(self.hs)
  97. json1 = {}
  98. signedjson.sign.sign_json(json1, "server10", key1)
  99. persp_resp = {
  100. "server_keys": [
  101. self.mock_perspective_server.get_signed_key(
  102. "server10", signedjson.key.get_verify_key(key1)
  103. )
  104. ]
  105. }
  106. persp_deferred = defer.Deferred()
  107. @defer.inlineCallbacks
  108. def get_perspectives(**kwargs):
  109. self.assertEquals(LoggingContext.current_context().request, "11")
  110. with PreserveLoggingContext():
  111. yield persp_deferred
  112. defer.returnValue(persp_resp)
  113. self.http_client.post_json.side_effect = get_perspectives
  114. # start off a first set of lookups
  115. @defer.inlineCallbacks
  116. def first_lookup():
  117. with LoggingContext("11") as context_11:
  118. context_11.request = "11"
  119. res_deferreds = kr.verify_json_objects_for_server(
  120. [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
  121. )
  122. # the unsigned json should be rejected pretty quickly
  123. self.assertTrue(res_deferreds[1].called)
  124. try:
  125. yield res_deferreds[1]
  126. self.assertFalse("unsigned json didn't cause a failure")
  127. except SynapseError:
  128. pass
  129. self.assertFalse(res_deferreds[0].called)
  130. res_deferreds[0].addBoth(self.check_context, None)
  131. yield make_deferred_yieldable(res_deferreds[0])
  132. # let verify_json_objects_for_server finish its work before we kill the
  133. # logcontext
  134. yield self.clock.sleep(0)
  135. d0 = first_lookup()
  136. # wait a tick for it to send the request to the perspectives server
  137. # (it first tries the datastore)
  138. self.pump()
  139. self.http_client.post_json.assert_called_once()
  140. # a second request for a server with outstanding requests
  141. # should block rather than start a second call
  142. @defer.inlineCallbacks
  143. def second_lookup():
  144. with LoggingContext("12") as context_12:
  145. context_12.request = "12"
  146. self.http_client.post_json.reset_mock()
  147. self.http_client.post_json.return_value = defer.Deferred()
  148. res_deferreds_2 = kr.verify_json_objects_for_server(
  149. [("server10", json1, 0, "test")]
  150. )
  151. res_deferreds_2[0].addBoth(self.check_context, None)
  152. yield make_deferred_yieldable(res_deferreds_2[0])
  153. # let verify_json_objects_for_server finish its work before we kill the
  154. # logcontext
  155. yield self.clock.sleep(0)
  156. d2 = second_lookup()
  157. self.pump()
  158. self.http_client.post_json.assert_not_called()
  159. # complete the first request
  160. persp_deferred.callback(persp_resp)
  161. self.get_success(d0)
  162. self.get_success(d2)
  163. def test_verify_json_for_server(self):
  164. kr = keyring.Keyring(self.hs)
  165. key1 = signedjson.key.generate_signing_key(1)
  166. r = self.hs.datastore.store_server_verify_keys(
  167. "server9",
  168. time.time() * 1000,
  169. [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
  170. )
  171. self.get_success(r)
  172. json1 = {}
  173. signedjson.sign.sign_json(json1, "server9", key1)
  174. # should fail immediately on an unsigned object
  175. d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
  176. self.failureResultOf(d, SynapseError)
  177. # should suceed on a signed object
  178. d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
  179. # self.assertFalse(d.called)
  180. self.get_success(d)
  181. def test_verify_json_for_server_with_null_valid_until_ms(self):
  182. """Tests that we correctly handle key requests for keys we've stored
  183. with a null `ts_valid_until_ms`
  184. """
  185. mock_fetcher = keyring.KeyFetcher()
  186. mock_fetcher.get_keys = Mock(return_value=defer.succeed({}))
  187. kr = keyring.Keyring(
  188. self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
  189. )
  190. key1 = signedjson.key.generate_signing_key(1)
  191. r = self.hs.datastore.store_server_verify_keys(
  192. "server9",
  193. time.time() * 1000,
  194. [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],
  195. )
  196. self.get_success(r)
  197. json1 = {}
  198. signedjson.sign.sign_json(json1, "server9", key1)
  199. # should fail immediately on an unsigned object
  200. d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
  201. self.failureResultOf(d, SynapseError)
  202. # should fail on a signed object with a non-zero minimum_valid_until_ms,
  203. # as it tries to refetch the keys and fails.
  204. d = _verify_json_for_server(
  205. kr, "server9", json1, 500, "test signed non-zero min"
  206. )
  207. self.get_failure(d, SynapseError)
  208. # We expect the keyring tried to refetch the key once.
  209. mock_fetcher.get_keys.assert_called_once_with(
  210. {"server9": {get_key_id(key1): 500}}
  211. )
  212. # should succeed on a signed object with a 0 minimum_valid_until_ms
  213. d = _verify_json_for_server(
  214. kr, "server9", json1, 0, "test signed with zero min"
  215. )
  216. self.get_success(d)
  217. def test_verify_json_dedupes_key_requests(self):
  218. """Two requests for the same key should be deduped."""
  219. key1 = signedjson.key.generate_signing_key(1)
  220. def get_keys(keys_to_fetch):
  221. # there should only be one request object (with the max validity)
  222. self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
  223. return defer.succeed(
  224. {
  225. "server1": {
  226. get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
  227. }
  228. }
  229. )
  230. mock_fetcher = keyring.KeyFetcher()
  231. mock_fetcher.get_keys = Mock(side_effect=get_keys)
  232. kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
  233. json1 = {}
  234. signedjson.sign.sign_json(json1, "server1", key1)
  235. # the first request should succeed; the second should fail because the key
  236. # has expired
  237. results = kr.verify_json_objects_for_server(
  238. [("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")]
  239. )
  240. self.assertEqual(len(results), 2)
  241. self.get_success(results[0])
  242. e = self.get_failure(results[1], SynapseError).value
  243. self.assertEqual(e.errcode, "M_UNAUTHORIZED")
  244. self.assertEqual(e.code, 401)
  245. # there should have been a single call to the fetcher
  246. mock_fetcher.get_keys.assert_called_once()
  247. def test_verify_json_falls_back_to_other_fetchers(self):
  248. """If the first fetcher cannot provide a recent enough key, we fall back"""
  249. key1 = signedjson.key.generate_signing_key(1)
  250. def get_keys1(keys_to_fetch):
  251. self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
  252. return defer.succeed(
  253. {
  254. "server1": {
  255. get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
  256. }
  257. }
  258. )
  259. def get_keys2(keys_to_fetch):
  260. self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
  261. return defer.succeed(
  262. {
  263. "server1": {
  264. get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
  265. }
  266. }
  267. )
  268. mock_fetcher1 = keyring.KeyFetcher()
  269. mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
  270. mock_fetcher2 = keyring.KeyFetcher()
  271. mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
  272. kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
  273. json1 = {}
  274. signedjson.sign.sign_json(json1, "server1", key1)
  275. results = kr.verify_json_objects_for_server(
  276. [("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")]
  277. )
  278. self.assertEqual(len(results), 2)
  279. self.get_success(results[0])
  280. e = self.get_failure(results[1], SynapseError).value
  281. self.assertEqual(e.errcode, "M_UNAUTHORIZED")
  282. self.assertEqual(e.code, 401)
  283. # there should have been a single call to each fetcher
  284. mock_fetcher1.get_keys.assert_called_once()
  285. mock_fetcher2.get_keys.assert_called_once()
  286. class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
  287. def make_homeserver(self, reactor, clock):
  288. self.http_client = Mock()
  289. hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
  290. return hs
  291. def test_get_keys_from_server(self):
  292. # arbitrarily advance the clock a bit
  293. self.reactor.advance(100)
  294. SERVER_NAME = "server2"
  295. fetcher = ServerKeyFetcher(self.hs)
  296. testkey = signedjson.key.generate_signing_key("ver1")
  297. testverifykey = signedjson.key.get_verify_key(testkey)
  298. testverifykey_id = "ed25519:ver1"
  299. VALID_UNTIL_TS = 200 * 1000
  300. # valid response
  301. response = {
  302. "server_name": SERVER_NAME,
  303. "old_verify_keys": {},
  304. "valid_until_ts": VALID_UNTIL_TS,
  305. "verify_keys": {
  306. testverifykey_id: {
  307. "key": signedjson.key.encode_verify_key_base64(testverifykey)
  308. }
  309. },
  310. }
  311. signedjson.sign.sign_json(response, SERVER_NAME, testkey)
  312. def get_json(destination, path, **kwargs):
  313. self.assertEqual(destination, SERVER_NAME)
  314. self.assertEqual(path, "/_matrix/key/v2/server/key1")
  315. return response
  316. self.http_client.get_json.side_effect = get_json
  317. keys_to_fetch = {SERVER_NAME: {"key1": 0}}
  318. keys = self.get_success(fetcher.get_keys(keys_to_fetch))
  319. k = keys[SERVER_NAME][testverifykey_id]
  320. self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
  321. self.assertEqual(k.verify_key, testverifykey)
  322. self.assertEqual(k.verify_key.alg, "ed25519")
  323. self.assertEqual(k.verify_key.version, "ver1")
  324. # check that the perspectives store is correctly updated
  325. lookup_triplet = (SERVER_NAME, testverifykey_id, None)
  326. key_json = self.get_success(
  327. self.hs.get_datastore().get_server_keys_json([lookup_triplet])
  328. )
  329. res = key_json[lookup_triplet]
  330. self.assertEqual(len(res), 1)
  331. res = res[0]
  332. self.assertEqual(res["key_id"], testverifykey_id)
  333. self.assertEqual(res["from_server"], SERVER_NAME)
  334. self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
  335. self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
  336. # we expect it to be encoded as canonical json *before* it hits the db
  337. self.assertEqual(
  338. bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
  339. )
  340. # change the server name: the result should be ignored
  341. response["server_name"] = "OTHER_SERVER"
  342. keys = self.get_success(fetcher.get_keys(keys_to_fetch))
  343. self.assertEqual(keys, {})
  344. class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
  345. def make_homeserver(self, reactor, clock):
  346. self.mock_perspective_server = MockPerspectiveServer()
  347. self.http_client = Mock()
  348. config = self.default_config()
  349. config["trusted_key_servers"] = [
  350. {
  351. "server_name": self.mock_perspective_server.server_name,
  352. "verify_keys": self.mock_perspective_server.get_verify_keys(),
  353. }
  354. ]
  355. return self.setup_test_homeserver(
  356. handlers=None, http_client=self.http_client, config=config
  357. )
  358. def test_get_keys_from_perspectives(self):
  359. # arbitrarily advance the clock a bit
  360. self.reactor.advance(100)
  361. fetcher = PerspectivesKeyFetcher(self.hs)
  362. SERVER_NAME = "server2"
  363. testkey = signedjson.key.generate_signing_key("ver1")
  364. testverifykey = signedjson.key.get_verify_key(testkey)
  365. testverifykey_id = "ed25519:ver1"
  366. VALID_UNTIL_TS = 200 * 1000
  367. # valid response
  368. response = {
  369. "server_name": SERVER_NAME,
  370. "old_verify_keys": {},
  371. "valid_until_ts": VALID_UNTIL_TS,
  372. "verify_keys": {
  373. testverifykey_id: {
  374. "key": signedjson.key.encode_verify_key_base64(testverifykey)
  375. }
  376. },
  377. }
  378. # the response must be signed by both the origin server and the perspectives
  379. # server.
  380. signedjson.sign.sign_json(response, SERVER_NAME, testkey)
  381. self.mock_perspective_server.sign_response(response)
  382. def post_json(destination, path, data, **kwargs):
  383. self.assertEqual(destination, self.mock_perspective_server.server_name)
  384. self.assertEqual(path, "/_matrix/key/v2/query")
  385. # check that the request is for the expected key
  386. q = data["server_keys"]
  387. self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
  388. return {"server_keys": [response]}
  389. self.http_client.post_json.side_effect = post_json
  390. keys_to_fetch = {SERVER_NAME: {"key1": 0}}
  391. keys = self.get_success(fetcher.get_keys(keys_to_fetch))
  392. self.assertIn(SERVER_NAME, keys)
  393. k = keys[SERVER_NAME][testverifykey_id]
  394. self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
  395. self.assertEqual(k.verify_key, testverifykey)
  396. self.assertEqual(k.verify_key.alg, "ed25519")
  397. self.assertEqual(k.verify_key.version, "ver1")
  398. # check that the perspectives store is correctly updated
  399. lookup_triplet = (SERVER_NAME, testverifykey_id, None)
  400. key_json = self.get_success(
  401. self.hs.get_datastore().get_server_keys_json([lookup_triplet])
  402. )
  403. res = key_json[lookup_triplet]
  404. self.assertEqual(len(res), 1)
  405. res = res[0]
  406. self.assertEqual(res["key_id"], testverifykey_id)
  407. self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
  408. self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
  409. self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
  410. self.assertEqual(
  411. bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
  412. )
  413. def test_invalid_perspectives_responses(self):
  414. """Check that invalid responses from the perspectives server are rejected"""
  415. # arbitrarily advance the clock a bit
  416. self.reactor.advance(100)
  417. SERVER_NAME = "server2"
  418. testkey = signedjson.key.generate_signing_key("ver1")
  419. testverifykey = signedjson.key.get_verify_key(testkey)
  420. testverifykey_id = "ed25519:ver1"
  421. VALID_UNTIL_TS = 200 * 1000
  422. def build_response():
  423. # valid response
  424. response = {
  425. "server_name": SERVER_NAME,
  426. "old_verify_keys": {},
  427. "valid_until_ts": VALID_UNTIL_TS,
  428. "verify_keys": {
  429. testverifykey_id: {
  430. "key": signedjson.key.encode_verify_key_base64(testverifykey)
  431. }
  432. },
  433. }
  434. # the response must be signed by both the origin server and the perspectives
  435. # server.
  436. signedjson.sign.sign_json(response, SERVER_NAME, testkey)
  437. self.mock_perspective_server.sign_response(response)
  438. return response
  439. def get_key_from_perspectives(response):
  440. fetcher = PerspectivesKeyFetcher(self.hs)
  441. keys_to_fetch = {SERVER_NAME: {"key1": 0}}
  442. def post_json(destination, path, data, **kwargs):
  443. self.assertEqual(destination, self.mock_perspective_server.server_name)
  444. self.assertEqual(path, "/_matrix/key/v2/query")
  445. return {"server_keys": [response]}
  446. self.http_client.post_json.side_effect = post_json
  447. return self.get_success(fetcher.get_keys(keys_to_fetch))
  448. # start with a valid response so we can check we are testing the right thing
  449. response = build_response()
  450. keys = get_key_from_perspectives(response)
  451. k = keys[SERVER_NAME][testverifykey_id]
  452. self.assertEqual(k.verify_key, testverifykey)
  453. # remove the perspectives server's signature
  454. response = build_response()
  455. del response["signatures"][self.mock_perspective_server.server_name]
  456. self.http_client.post_json.return_value = {"server_keys": [response]}
  457. keys = get_key_from_perspectives(response)
  458. self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
  459. # remove the origin server's signature
  460. response = build_response()
  461. del response["signatures"][SERVER_NAME]
  462. self.http_client.post_json.return_value = {"server_keys": [response]}
  463. keys = get_key_from_perspectives(response)
  464. self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
  465. def get_key_id(key):
  466. """Get the matrix ID tag for a given SigningKey or VerifyKey"""
  467. return "%s:%s" % (key.alg, key.version)
  468. @defer.inlineCallbacks
  469. def run_in_context(f, *args, **kwargs):
  470. with LoggingContext("testctx") as ctx:
  471. # we set the "request" prop to make it easier to follow what's going on in the
  472. # logs.
  473. ctx.request = "testctx"
  474. rv = yield f(*args, **kwargs)
  475. defer.returnValue(rv)
  476. def _verify_json_for_server(kr, *args):
  477. """thin wrapper around verify_json_for_server which makes sure it is wrapped
  478. with the patched defer.inlineCallbacks.
  479. """
  480. @defer.inlineCallbacks
  481. def v():
  482. rv1 = yield kr.verify_json_for_server(*args)
  483. defer.returnValue(rv1)
  484. return run_in_context(v)