test_keyring.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2017 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import time
  16. from mock import Mock
  17. import canonicaljson
  18. import signedjson.key
  19. import signedjson.sign
  20. from twisted.internet import defer
  21. from synapse.api.errors import SynapseError
  22. from synapse.crypto import keyring
  23. from synapse.crypto.keyring import KeyLookupError
  24. from synapse.util import logcontext
  25. from synapse.util.logcontext import LoggingContext
  26. from tests import unittest
  27. class MockPerspectiveServer(object):
  28. def __init__(self):
  29. self.server_name = "mock_server"
  30. self.key = signedjson.key.generate_signing_key(0)
  31. def get_verify_keys(self):
  32. vk = signedjson.key.get_verify_key(self.key)
  33. return {"%s:%s" % (vk.alg, vk.version): vk}
  34. def get_signed_key(self, server_name, verify_key):
  35. key_id = "%s:%s" % (verify_key.alg, verify_key.version)
  36. res = {
  37. "server_name": server_name,
  38. "old_verify_keys": {},
  39. "valid_until_ts": time.time() * 1000 + 3600,
  40. "verify_keys": {
  41. key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
  42. },
  43. }
  44. return self.get_signed_response(res)
  45. def get_signed_response(self, res):
  46. signedjson.sign.sign_json(res, self.server_name, self.key)
  47. return res
  48. class KeyringTestCase(unittest.HomeserverTestCase):
  49. def make_homeserver(self, reactor, clock):
  50. self.mock_perspective_server = MockPerspectiveServer()
  51. self.http_client = Mock()
  52. hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
  53. keys = self.mock_perspective_server.get_verify_keys()
  54. hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
  55. return hs
  56. def check_context(self, _, expected):
  57. self.assertEquals(
  58. getattr(LoggingContext.current_context(), "request", None), expected
  59. )
  60. def test_wait_for_previous_lookups(self):
  61. kr = keyring.Keyring(self.hs)
  62. lookup_1_deferred = defer.Deferred()
  63. lookup_2_deferred = defer.Deferred()
  64. # we run the lookup in a logcontext so that the patched inlineCallbacks can check
  65. # it is doing the right thing with logcontexts.
  66. wait_1_deferred = run_in_context(
  67. kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_1_deferred}
  68. )
  69. # there were no previous lookups, so the deferred should be ready
  70. self.successResultOf(wait_1_deferred)
  71. # set off another wait. It should block because the first lookup
  72. # hasn't yet completed.
  73. wait_2_deferred = run_in_context(
  74. kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_2_deferred}
  75. )
  76. self.assertFalse(wait_2_deferred.called)
  77. # let the first lookup complete (in the sentinel context)
  78. lookup_1_deferred.callback(None)
  79. # now the second wait should complete.
  80. self.successResultOf(wait_2_deferred)
  81. def test_verify_json_objects_for_server_awaits_previous_requests(self):
  82. key1 = signedjson.key.generate_signing_key(1)
  83. kr = keyring.Keyring(self.hs)
  84. json1 = {}
  85. signedjson.sign.sign_json(json1, "server10", key1)
  86. persp_resp = {
  87. "server_keys": [
  88. self.mock_perspective_server.get_signed_key(
  89. "server10", signedjson.key.get_verify_key(key1)
  90. )
  91. ]
  92. }
  93. persp_deferred = defer.Deferred()
  94. @defer.inlineCallbacks
  95. def get_perspectives(**kwargs):
  96. self.assertEquals(LoggingContext.current_context().request, "11")
  97. with logcontext.PreserveLoggingContext():
  98. yield persp_deferred
  99. defer.returnValue(persp_resp)
  100. self.http_client.post_json.side_effect = get_perspectives
  101. # start off a first set of lookups
  102. @defer.inlineCallbacks
  103. def first_lookup():
  104. with LoggingContext("11") as context_11:
  105. context_11.request = "11"
  106. res_deferreds = kr.verify_json_objects_for_server(
  107. [("server10", json1), ("server11", {})]
  108. )
  109. # the unsigned json should be rejected pretty quickly
  110. self.assertTrue(res_deferreds[1].called)
  111. try:
  112. yield res_deferreds[1]
  113. self.assertFalse("unsigned json didn't cause a failure")
  114. except SynapseError:
  115. pass
  116. self.assertFalse(res_deferreds[0].called)
  117. res_deferreds[0].addBoth(self.check_context, None)
  118. yield logcontext.make_deferred_yieldable(res_deferreds[0])
  119. # let verify_json_objects_for_server finish its work before we kill the
  120. # logcontext
  121. yield self.clock.sleep(0)
  122. d0 = first_lookup()
  123. # wait a tick for it to send the request to the perspectives server
  124. # (it first tries the datastore)
  125. self.pump()
  126. self.http_client.post_json.assert_called_once()
  127. # a second request for a server with outstanding requests
  128. # should block rather than start a second call
  129. @defer.inlineCallbacks
  130. def second_lookup():
  131. with LoggingContext("12") as context_12:
  132. context_12.request = "12"
  133. self.http_client.post_json.reset_mock()
  134. self.http_client.post_json.return_value = defer.Deferred()
  135. res_deferreds_2 = kr.verify_json_objects_for_server(
  136. [("server10", json1)]
  137. )
  138. res_deferreds_2[0].addBoth(self.check_context, None)
  139. yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
  140. # let verify_json_objects_for_server finish its work before we kill the
  141. # logcontext
  142. yield self.clock.sleep(0)
  143. d2 = second_lookup()
  144. self.pump()
  145. self.http_client.post_json.assert_not_called()
  146. # complete the first request
  147. persp_deferred.callback(persp_resp)
  148. self.get_success(d0)
  149. self.get_success(d2)
  150. def test_verify_json_for_server(self):
  151. kr = keyring.Keyring(self.hs)
  152. key1 = signedjson.key.generate_signing_key(1)
  153. r = self.hs.datastore.store_server_verify_key(
  154. "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
  155. )
  156. self.get_success(r)
  157. json1 = {}
  158. signedjson.sign.sign_json(json1, "server9", key1)
  159. # should fail immediately on an unsigned object
  160. d = _verify_json_for_server(kr, "server9", {})
  161. self.failureResultOf(d, SynapseError)
  162. d = _verify_json_for_server(kr, "server9", json1)
  163. self.assertFalse(d.called)
  164. self.get_success(d)
  165. def test_get_keys_from_server(self):
  166. # arbitrarily advance the clock a bit
  167. self.reactor.advance(100)
  168. SERVER_NAME = "server2"
  169. kr = keyring.Keyring(self.hs)
  170. testkey = signedjson.key.generate_signing_key("ver1")
  171. testverifykey = signedjson.key.get_verify_key(testkey)
  172. testverifykey_id = "ed25519:ver1"
  173. VALID_UNTIL_TS = 1000
  174. # valid response
  175. response = {
  176. "server_name": SERVER_NAME,
  177. "old_verify_keys": {},
  178. "valid_until_ts": VALID_UNTIL_TS,
  179. "verify_keys": {
  180. testverifykey_id: {
  181. "key": signedjson.key.encode_verify_key_base64(testverifykey)
  182. }
  183. },
  184. }
  185. signedjson.sign.sign_json(response, SERVER_NAME, testkey)
  186. def get_json(destination, path, **kwargs):
  187. self.assertEqual(destination, SERVER_NAME)
  188. self.assertEqual(path, "/_matrix/key/v2/server/key1")
  189. return response
  190. self.http_client.get_json.side_effect = get_json
  191. server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
  192. keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
  193. k = keys[SERVER_NAME][testverifykey_id]
  194. self.assertEqual(k, testverifykey)
  195. self.assertEqual(k.alg, "ed25519")
  196. self.assertEqual(k.version, "ver1")
  197. # check that the perspectives store is correctly updated
  198. lookup_triplet = (SERVER_NAME, testverifykey_id, None)
  199. key_json = self.get_success(
  200. self.hs.get_datastore().get_server_keys_json([lookup_triplet])
  201. )
  202. res = key_json[lookup_triplet]
  203. self.assertEqual(len(res), 1)
  204. res = res[0]
  205. self.assertEqual(res["key_id"], testverifykey_id)
  206. self.assertEqual(res["from_server"], SERVER_NAME)
  207. self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
  208. self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
  209. # we expect it to be encoded as canonical json *before* it hits the db
  210. self.assertEqual(
  211. bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
  212. )
  213. # change the server name: it should cause a rejection
  214. response["server_name"] = "OTHER_SERVER"
  215. self.get_failure(
  216. kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
  217. )
  218. def test_get_keys_from_perspectives(self):
  219. # arbitrarily advance the clock a bit
  220. self.reactor.advance(100)
  221. SERVER_NAME = "server2"
  222. kr = keyring.Keyring(self.hs)
  223. testkey = signedjson.key.generate_signing_key("ver1")
  224. testverifykey = signedjson.key.get_verify_key(testkey)
  225. testverifykey_id = "ed25519:ver1"
  226. VALID_UNTIL_TS = 200 * 1000
  227. # valid response
  228. response = {
  229. "server_name": SERVER_NAME,
  230. "old_verify_keys": {},
  231. "valid_until_ts": VALID_UNTIL_TS,
  232. "verify_keys": {
  233. testverifykey_id: {
  234. "key": signedjson.key.encode_verify_key_base64(testverifykey)
  235. }
  236. },
  237. }
  238. persp_resp = {
  239. "server_keys": [self.mock_perspective_server.get_signed_response(response)]
  240. }
  241. def post_json(destination, path, data, **kwargs):
  242. self.assertEqual(destination, self.mock_perspective_server.server_name)
  243. self.assertEqual(path, "/_matrix/key/v2/query")
  244. # check that the request is for the expected key
  245. q = data["server_keys"]
  246. self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
  247. return persp_resp
  248. self.http_client.post_json.side_effect = post_json
  249. server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
  250. keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
  251. self.assertIn(SERVER_NAME, keys)
  252. k = keys[SERVER_NAME][testverifykey_id]
  253. self.assertEqual(k, testverifykey)
  254. self.assertEqual(k.alg, "ed25519")
  255. self.assertEqual(k.version, "ver1")
  256. # check that the perspectives store is correctly updated
  257. lookup_triplet = (SERVER_NAME, testverifykey_id, None)
  258. key_json = self.get_success(
  259. self.hs.get_datastore().get_server_keys_json([lookup_triplet])
  260. )
  261. res = key_json[lookup_triplet]
  262. self.assertEqual(len(res), 1)
  263. res = res[0]
  264. self.assertEqual(res["key_id"], testverifykey_id)
  265. self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
  266. self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
  267. self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
  268. self.assertEqual(
  269. bytes(res["key_json"]),
  270. canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]),
  271. )
  272. @defer.inlineCallbacks
  273. def run_in_context(f, *args, **kwargs):
  274. with LoggingContext("testctx"):
  275. rv = yield f(*args, **kwargs)
  276. defer.returnValue(rv)
  277. def _verify_json_for_server(keyring, server_name, json_object):
  278. """thin wrapper around verify_json_for_server which makes sure it is wrapped
  279. with the patched defer.inlineCallbacks.
  280. """
  281. @defer.inlineCallbacks
  282. def v():
  283. rv1 = yield keyring.verify_json_for_server(server_name, json_object)
  284. defer.returnValue(rv1)
  285. return run_in_context(v)