test_keyring.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2017 New Vector Ltd.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import time
  16. from mock import Mock
  17. import signedjson.key
  18. import signedjson.sign
  19. from twisted.internet import defer, reactor
  20. from synapse.api.errors import SynapseError
  21. from synapse.crypto import keyring
  22. from synapse.util import Clock, logcontext
  23. from synapse.util.logcontext import LoggingContext
  24. from tests import unittest, utils
  25. class MockPerspectiveServer(object):
  26. def __init__(self):
  27. self.server_name = "mock_server"
  28. self.key = signedjson.key.generate_signing_key(0)
  29. def get_verify_keys(self):
  30. vk = signedjson.key.get_verify_key(self.key)
  31. return {"%s:%s" % (vk.alg, vk.version): vk}
  32. def get_signed_key(self, server_name, verify_key):
  33. key_id = "%s:%s" % (verify_key.alg, verify_key.version)
  34. res = {
  35. "server_name": server_name,
  36. "old_verify_keys": {},
  37. "valid_until_ts": time.time() * 1000 + 3600,
  38. "verify_keys": {
  39. key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
  40. },
  41. }
  42. signedjson.sign.sign_json(res, self.server_name, self.key)
  43. return res
  44. class KeyringTestCase(unittest.TestCase):
  45. @defer.inlineCallbacks
  46. def setUp(self):
  47. self.mock_perspective_server = MockPerspectiveServer()
  48. self.http_client = Mock()
  49. self.hs = yield utils.setup_test_homeserver(
  50. self.addCleanup, handlers=None, http_client=self.http_client
  51. )
  52. keys = self.mock_perspective_server.get_verify_keys()
  53. self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
  54. def check_context(self, _, expected):
  55. self.assertEquals(
  56. getattr(LoggingContext.current_context(), "request", None), expected
  57. )
  58. @defer.inlineCallbacks
  59. def test_wait_for_previous_lookups(self):
  60. sentinel_context = LoggingContext.current_context()
  61. kr = keyring.Keyring(self.hs)
  62. lookup_1_deferred = defer.Deferred()
  63. lookup_2_deferred = defer.Deferred()
  64. with LoggingContext("one") as context_one:
  65. context_one.request = "one"
  66. wait_1_deferred = kr.wait_for_previous_lookups(
  67. ["server1"], {"server1": lookup_1_deferred}
  68. )
  69. # there were no previous lookups, so the deferred should be ready
  70. self.assertTrue(wait_1_deferred.called)
  71. # ... so we should have preserved the LoggingContext.
  72. self.assertIs(LoggingContext.current_context(), context_one)
  73. wait_1_deferred.addBoth(self.check_context, "one")
  74. with LoggingContext("two") as context_two:
  75. context_two.request = "two"
  76. # set off another wait. It should block because the first lookup
  77. # hasn't yet completed.
  78. wait_2_deferred = kr.wait_for_previous_lookups(
  79. ["server1"], {"server1": lookup_2_deferred}
  80. )
  81. self.assertFalse(wait_2_deferred.called)
  82. # ... so we should have reset the LoggingContext.
  83. self.assertIs(LoggingContext.current_context(), sentinel_context)
  84. wait_2_deferred.addBoth(self.check_context, "two")
  85. # let the first lookup complete (in the sentinel context)
  86. lookup_1_deferred.callback(None)
  87. # now the second wait should complete and restore our
  88. # loggingcontext.
  89. yield wait_2_deferred
  90. @defer.inlineCallbacks
  91. def test_verify_json_objects_for_server_awaits_previous_requests(self):
  92. clock = Clock(reactor)
  93. key1 = signedjson.key.generate_signing_key(1)
  94. kr = keyring.Keyring(self.hs)
  95. json1 = {}
  96. signedjson.sign.sign_json(json1, "server10", key1)
  97. persp_resp = {
  98. "server_keys": [
  99. self.mock_perspective_server.get_signed_key(
  100. "server10", signedjson.key.get_verify_key(key1)
  101. )
  102. ]
  103. }
  104. persp_deferred = defer.Deferred()
  105. @defer.inlineCallbacks
  106. def get_perspectives(**kwargs):
  107. self.assertEquals(LoggingContext.current_context().request, "11")
  108. with logcontext.PreserveLoggingContext():
  109. yield persp_deferred
  110. defer.returnValue(persp_resp)
  111. self.http_client.post_json.side_effect = get_perspectives
  112. with LoggingContext("11") as context_11:
  113. context_11.request = "11"
  114. # start off a first set of lookups
  115. res_deferreds = kr.verify_json_objects_for_server(
  116. [("server10", json1), ("server11", {})]
  117. )
  118. # the unsigned json should be rejected pretty quickly
  119. self.assertTrue(res_deferreds[1].called)
  120. try:
  121. yield res_deferreds[1]
  122. self.assertFalse("unsigned json didn't cause a failure")
  123. except SynapseError:
  124. pass
  125. self.assertFalse(res_deferreds[0].called)
  126. res_deferreds[0].addBoth(self.check_context, None)
  127. # wait a tick for it to send the request to the perspectives server
  128. # (it first tries the datastore)
  129. yield clock.sleep(1) # XXX find out why this takes so long!
  130. self.http_client.post_json.assert_called_once()
  131. self.assertIs(LoggingContext.current_context(), context_11)
  132. context_12 = LoggingContext("12")
  133. context_12.request = "12"
  134. with logcontext.PreserveLoggingContext(context_12):
  135. # a second request for a server with outstanding requests
  136. # should block rather than start a second call
  137. self.http_client.post_json.reset_mock()
  138. self.http_client.post_json.return_value = defer.Deferred()
  139. res_deferreds_2 = kr.verify_json_objects_for_server(
  140. [("server10", json1)]
  141. )
  142. yield clock.sleep(1)
  143. self.http_client.post_json.assert_not_called()
  144. res_deferreds_2[0].addBoth(self.check_context, None)
  145. # complete the first request
  146. with logcontext.PreserveLoggingContext():
  147. persp_deferred.callback(persp_resp)
  148. self.assertIs(LoggingContext.current_context(), context_11)
  149. with logcontext.PreserveLoggingContext():
  150. yield res_deferreds[0]
  151. yield res_deferreds_2[0]
  152. @defer.inlineCallbacks
  153. def test_verify_json_for_server(self):
  154. kr = keyring.Keyring(self.hs)
  155. key1 = signedjson.key.generate_signing_key(1)
  156. yield self.hs.datastore.store_server_verify_key(
  157. "server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
  158. )
  159. json1 = {}
  160. signedjson.sign.sign_json(json1, "server9", key1)
  161. sentinel_context = LoggingContext.current_context()
  162. with LoggingContext("one") as context_one:
  163. context_one.request = "one"
  164. defer = kr.verify_json_for_server("server9", {})
  165. try:
  166. yield defer
  167. self.fail("should fail on unsigned json")
  168. except SynapseError:
  169. pass
  170. self.assertIs(LoggingContext.current_context(), context_one)
  171. defer = kr.verify_json_for_server("server9", json1)
  172. self.assertFalse(defer.called)
  173. self.assertIs(LoggingContext.current_context(), sentinel_context)
  174. yield defer
  175. self.assertIs(LoggingContext.current_context(), context_one)