test_keyring.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2017 New Vector Ltd.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import time
  16. from mock import Mock
  17. import signedjson.key
  18. import signedjson.sign
  19. from twisted.internet import defer, reactor
  20. from synapse.api.errors import SynapseError
  21. from synapse.crypto import keyring
  22. from synapse.util import Clock, logcontext
  23. from synapse.util.logcontext import LoggingContext
  24. from tests import unittest, utils
  25. class MockPerspectiveServer(object):
  26. def __init__(self):
  27. self.server_name = "mock_server"
  28. self.key = signedjson.key.generate_signing_key(0)
  29. def get_verify_keys(self):
  30. vk = signedjson.key.get_verify_key(self.key)
  31. return {"%s:%s" % (vk.alg, vk.version): vk}
  32. def get_signed_key(self, server_name, verify_key):
  33. key_id = "%s:%s" % (verify_key.alg, verify_key.version)
  34. res = {
  35. "server_name": server_name,
  36. "old_verify_keys": {},
  37. "valid_until_ts": time.time() * 1000 + 3600,
  38. "verify_keys": {
  39. key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
  40. },
  41. }
  42. signedjson.sign.sign_json(res, self.server_name, self.key)
  43. return res
  44. class KeyringTestCase(unittest.TestCase):
  45. @defer.inlineCallbacks
  46. def setUp(self):
  47. self.mock_perspective_server = MockPerspectiveServer()
  48. self.http_client = Mock()
  49. self.hs = yield utils.setup_test_homeserver(
  50. self.addCleanup, handlers=None, http_client=self.http_client
  51. )
  52. keys = self.mock_perspective_server.get_verify_keys()
  53. self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
  54. def assert_sentinel_context(self):
  55. if LoggingContext.current_context() != LoggingContext.sentinel:
  56. self.fail(
  57. "Expected sentinel context but got %s" % (
  58. LoggingContext.current_context(),
  59. )
  60. )
  61. def check_context(self, _, expected):
  62. self.assertEquals(
  63. getattr(LoggingContext.current_context(), "request", None), expected
  64. )
  65. @defer.inlineCallbacks
  66. def test_wait_for_previous_lookups(self):
  67. kr = keyring.Keyring(self.hs)
  68. lookup_1_deferred = defer.Deferred()
  69. lookup_2_deferred = defer.Deferred()
  70. with LoggingContext("one") as context_one:
  71. context_one.request = "one"
  72. wait_1_deferred = kr.wait_for_previous_lookups(
  73. ["server1"], {"server1": lookup_1_deferred}
  74. )
  75. # there were no previous lookups, so the deferred should be ready
  76. self.assertTrue(wait_1_deferred.called)
  77. # ... so we should have preserved the LoggingContext.
  78. self.assertIs(LoggingContext.current_context(), context_one)
  79. wait_1_deferred.addBoth(self.check_context, "one")
  80. with LoggingContext("two") as context_two:
  81. context_two.request = "two"
  82. # set off another wait. It should block because the first lookup
  83. # hasn't yet completed.
  84. wait_2_deferred = kr.wait_for_previous_lookups(
  85. ["server1"], {"server1": lookup_2_deferred}
  86. )
  87. self.assertFalse(wait_2_deferred.called)
  88. # ... so we should have reset the LoggingContext.
  89. self.assert_sentinel_context()
  90. wait_2_deferred.addBoth(self.check_context, "two")
  91. # let the first lookup complete (in the sentinel context)
  92. lookup_1_deferred.callback(None)
  93. # now the second wait should complete and restore our
  94. # loggingcontext.
  95. yield wait_2_deferred
  96. @defer.inlineCallbacks
  97. def test_verify_json_objects_for_server_awaits_previous_requests(self):
  98. clock = Clock(reactor)
  99. key1 = signedjson.key.generate_signing_key(1)
  100. kr = keyring.Keyring(self.hs)
  101. json1 = {}
  102. signedjson.sign.sign_json(json1, "server10", key1)
  103. persp_resp = {
  104. "server_keys": [
  105. self.mock_perspective_server.get_signed_key(
  106. "server10", signedjson.key.get_verify_key(key1)
  107. )
  108. ]
  109. }
  110. persp_deferred = defer.Deferred()
  111. @defer.inlineCallbacks
  112. def get_perspectives(**kwargs):
  113. self.assertEquals(LoggingContext.current_context().request, "11")
  114. with logcontext.PreserveLoggingContext():
  115. yield persp_deferred
  116. defer.returnValue(persp_resp)
  117. self.http_client.post_json.side_effect = get_perspectives
  118. with LoggingContext("11") as context_11:
  119. context_11.request = "11"
  120. # start off a first set of lookups
  121. res_deferreds = kr.verify_json_objects_for_server(
  122. [("server10", json1), ("server11", {})]
  123. )
  124. # the unsigned json should be rejected pretty quickly
  125. self.assertTrue(res_deferreds[1].called)
  126. try:
  127. yield res_deferreds[1]
  128. self.assertFalse("unsigned json didn't cause a failure")
  129. except SynapseError:
  130. pass
  131. self.assertFalse(res_deferreds[0].called)
  132. res_deferreds[0].addBoth(self.check_context, None)
  133. # wait a tick for it to send the request to the perspectives server
  134. # (it first tries the datastore)
  135. yield clock.sleep(1) # XXX find out why this takes so long!
  136. self.http_client.post_json.assert_called_once()
  137. self.assertIs(LoggingContext.current_context(), context_11)
  138. context_12 = LoggingContext("12")
  139. context_12.request = "12"
  140. with logcontext.PreserveLoggingContext(context_12):
  141. # a second request for a server with outstanding requests
  142. # should block rather than start a second call
  143. self.http_client.post_json.reset_mock()
  144. self.http_client.post_json.return_value = defer.Deferred()
  145. res_deferreds_2 = kr.verify_json_objects_for_server(
  146. [("server10", json1)]
  147. )
  148. yield clock.sleep(1)
  149. self.http_client.post_json.assert_not_called()
  150. res_deferreds_2[0].addBoth(self.check_context, None)
  151. # complete the first request
  152. with logcontext.PreserveLoggingContext():
  153. persp_deferred.callback(persp_resp)
  154. self.assertIs(LoggingContext.current_context(), context_11)
  155. with logcontext.PreserveLoggingContext():
  156. yield res_deferreds[0]
  157. yield res_deferreds_2[0]
  158. @defer.inlineCallbacks
  159. def test_verify_json_for_server(self):
  160. kr = keyring.Keyring(self.hs)
  161. key1 = signedjson.key.generate_signing_key(1)
  162. yield self.hs.datastore.store_server_verify_key(
  163. "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
  164. )
  165. json1 = {}
  166. signedjson.sign.sign_json(json1, "server9", key1)
  167. with LoggingContext("one") as context_one:
  168. context_one.request = "one"
  169. defer = kr.verify_json_for_server("server9", {})
  170. try:
  171. yield defer
  172. self.fail("should fail on unsigned json")
  173. except SynapseError:
  174. pass
  175. self.assertIs(LoggingContext.current_context(), context_one)
  176. defer = kr.verify_json_for_server("server9", json1)
  177. self.assertFalse(defer.called)
  178. self.assert_sentinel_context()
  179. yield defer
  180. self.assertIs(LoggingContext.current_context(), context_one)