test_keyring.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2017 New Vector Ltd.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import time
  16. import signedjson.key
  17. import signedjson.sign
  18. from mock import Mock
  19. from synapse.api.errors import SynapseError
  20. from synapse.crypto import keyring
  21. from synapse.util import async, logcontext
  22. from synapse.util.logcontext import LoggingContext
  23. from tests import unittest, utils
  24. from twisted.internet import defer
  25. class MockPerspectiveServer(object):
  26. def __init__(self):
  27. self.server_name = "mock_server"
  28. self.key = signedjson.key.generate_signing_key(0)
  29. def get_verify_keys(self):
  30. vk = signedjson.key.get_verify_key(self.key)
  31. return {
  32. "%s:%s" % (vk.alg, vk.version): vk,
  33. }
  34. def get_signed_key(self, server_name, verify_key):
  35. key_id = "%s:%s" % (verify_key.alg, verify_key.version)
  36. res = {
  37. "server_name": server_name,
  38. "old_verify_keys": {},
  39. "valid_until_ts": time.time() * 1000 + 3600,
  40. "verify_keys": {
  41. key_id: {
  42. "key": signedjson.key.encode_verify_key_base64(verify_key)
  43. }
  44. }
  45. }
  46. signedjson.sign.sign_json(res, self.server_name, self.key)
  47. return res
  48. class KeyringTestCase(unittest.TestCase):
  49. @defer.inlineCallbacks
  50. def setUp(self):
  51. self.mock_perspective_server = MockPerspectiveServer()
  52. self.http_client = Mock()
  53. self.hs = yield utils.setup_test_homeserver(
  54. handlers=None,
  55. http_client=self.http_client,
  56. )
  57. self.hs.config.perspectives = {
  58. self.mock_perspective_server.server_name:
  59. self.mock_perspective_server.get_verify_keys()
  60. }
  61. def check_context(self, _, expected):
  62. self.assertEquals(
  63. getattr(LoggingContext.current_context(), "request", None),
  64. expected
  65. )
  66. @defer.inlineCallbacks
  67. def test_wait_for_previous_lookups(self):
  68. sentinel_context = LoggingContext.current_context()
  69. kr = keyring.Keyring(self.hs)
  70. lookup_1_deferred = defer.Deferred()
  71. lookup_2_deferred = defer.Deferred()
  72. with LoggingContext("one") as context_one:
  73. context_one.request = "one"
  74. wait_1_deferred = kr.wait_for_previous_lookups(
  75. ["server1"],
  76. {"server1": lookup_1_deferred},
  77. )
  78. # there were no previous lookups, so the deferred should be ready
  79. self.assertTrue(wait_1_deferred.called)
  80. # ... so we should have preserved the LoggingContext.
  81. self.assertIs(LoggingContext.current_context(), context_one)
  82. wait_1_deferred.addBoth(self.check_context, "one")
  83. with LoggingContext("two") as context_two:
  84. context_two.request = "two"
  85. # set off another wait. It should block because the first lookup
  86. # hasn't yet completed.
  87. wait_2_deferred = kr.wait_for_previous_lookups(
  88. ["server1"],
  89. {"server1": lookup_2_deferred},
  90. )
  91. self.assertFalse(wait_2_deferred.called)
  92. # ... so we should have reset the LoggingContext.
  93. self.assertIs(LoggingContext.current_context(), sentinel_context)
  94. wait_2_deferred.addBoth(self.check_context, "two")
  95. # let the first lookup complete (in the sentinel context)
  96. lookup_1_deferred.callback(None)
  97. # now the second wait should complete and restore our
  98. # loggingcontext.
  99. yield wait_2_deferred
  100. @defer.inlineCallbacks
  101. def test_verify_json_objects_for_server_awaits_previous_requests(self):
  102. key1 = signedjson.key.generate_signing_key(1)
  103. kr = keyring.Keyring(self.hs)
  104. json1 = {}
  105. signedjson.sign.sign_json(json1, "server10", key1)
  106. persp_resp = {
  107. "server_keys": [
  108. self.mock_perspective_server.get_signed_key(
  109. "server10",
  110. signedjson.key.get_verify_key(key1)
  111. ),
  112. ]
  113. }
  114. persp_deferred = defer.Deferred()
  115. @defer.inlineCallbacks
  116. def get_perspectives(**kwargs):
  117. self.assertEquals(
  118. LoggingContext.current_context().request, "11",
  119. )
  120. with logcontext.PreserveLoggingContext():
  121. yield persp_deferred
  122. defer.returnValue(persp_resp)
  123. self.http_client.post_json.side_effect = get_perspectives
  124. with LoggingContext("11") as context_11:
  125. context_11.request = "11"
  126. # start off a first set of lookups
  127. res_deferreds = kr.verify_json_objects_for_server(
  128. [("server10", json1),
  129. ("server11", {})
  130. ]
  131. )
  132. # the unsigned json should be rejected pretty quickly
  133. self.assertTrue(res_deferreds[1].called)
  134. try:
  135. yield res_deferreds[1]
  136. self.assertFalse("unsigned json didn't cause a failure")
  137. except SynapseError:
  138. pass
  139. self.assertFalse(res_deferreds[0].called)
  140. res_deferreds[0].addBoth(self.check_context, None)
  141. # wait a tick for it to send the request to the perspectives server
  142. # (it first tries the datastore)
  143. yield async.sleep(1) # XXX find out why this takes so long!
  144. self.http_client.post_json.assert_called_once()
  145. self.assertIs(LoggingContext.current_context(), context_11)
  146. context_12 = LoggingContext("12")
  147. context_12.request = "12"
  148. with logcontext.PreserveLoggingContext(context_12):
  149. # a second request for a server with outstanding requests
  150. # should block rather than start a second call
  151. self.http_client.post_json.reset_mock()
  152. self.http_client.post_json.return_value = defer.Deferred()
  153. res_deferreds_2 = kr.verify_json_objects_for_server(
  154. [("server10", json1)],
  155. )
  156. yield async.sleep(01)
  157. self.http_client.post_json.assert_not_called()
  158. res_deferreds_2[0].addBoth(self.check_context, None)
  159. # complete the first request
  160. with logcontext.PreserveLoggingContext():
  161. persp_deferred.callback(persp_resp)
  162. self.assertIs(LoggingContext.current_context(), context_11)
  163. with logcontext.PreserveLoggingContext():
  164. yield res_deferreds[0]
  165. yield res_deferreds_2[0]
  166. @defer.inlineCallbacks
  167. def test_verify_json_for_server(self):
  168. kr = keyring.Keyring(self.hs)
  169. key1 = signedjson.key.generate_signing_key(1)
  170. yield self.hs.datastore.store_server_verify_key(
  171. "server9", "", time.time() * 1000,
  172. signedjson.key.get_verify_key(key1),
  173. )
  174. json1 = {}
  175. signedjson.sign.sign_json(json1, "server9", key1)
  176. sentinel_context = LoggingContext.current_context()
  177. with LoggingContext("one") as context_one:
  178. context_one.request = "one"
  179. defer = kr.verify_json_for_server("server9", {})
  180. try:
  181. yield defer
  182. self.fail("should fail on unsigned json")
  183. except SynapseError:
  184. pass
  185. self.assertIs(LoggingContext.current_context(), context_one)
  186. defer = kr.verify_json_for_server("server9", json1)
  187. self.assertFalse(defer.called)
  188. self.assertIs(LoggingContext.current_context(), sentinel_context)
  189. yield defer
  190. self.assertIs(LoggingContext.current_context(), context_one)