test_keys.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2017 Vector Creations Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import signedjson.key
  16. from twisted.internet.defer import Deferred
  17. from synapse.storage.keys import FetchKeyResult
  18. import tests.unittest
  19. KEY_1 = signedjson.key.decode_verify_key_base64(
  20. "ed25519", "key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
  21. )
  22. KEY_2 = signedjson.key.decode_verify_key_base64(
  23. "ed25519", "key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
  24. )
  25. class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
  26. def test_get_server_verify_keys(self):
  27. store = self.hs.get_datastore()
  28. key_id_1 = "ed25519:key1"
  29. key_id_2 = "ed25519:KEY_ID_2"
  30. d = store.store_server_verify_keys(
  31. "from_server",
  32. 10,
  33. [
  34. ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
  35. ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
  36. ],
  37. )
  38. self.get_success(d)
  39. d = store.get_server_verify_keys(
  40. [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
  41. )
  42. res = self.get_success(d)
  43. self.assertEqual(len(res.keys()), 3)
  44. res1 = res[("server1", key_id_1)]
  45. self.assertEqual(res1.verify_key, KEY_1)
  46. self.assertEqual(res1.verify_key.version, "key1")
  47. self.assertEqual(res1.valid_until_ts, 100)
  48. res2 = res[("server1", key_id_2)]
  49. self.assertEqual(res2.verify_key, KEY_2)
  50. # version comes from the ID it was stored with
  51. self.assertEqual(res2.verify_key.version, "KEY_ID_2")
  52. self.assertEqual(res2.valid_until_ts, 200)
  53. # non-existent result gives None
  54. self.assertIsNone(res[("server1", "ed25519:key3")])
  55. def test_cache(self):
  56. """Check that updates correctly invalidate the cache."""
  57. store = self.hs.get_datastore()
  58. key_id_1 = "ed25519:key1"
  59. key_id_2 = "ed25519:key2"
  60. d = store.store_server_verify_keys(
  61. "from_server",
  62. 0,
  63. [
  64. ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
  65. ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
  66. ],
  67. )
  68. self.get_success(d)
  69. d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
  70. res = self.get_success(d)
  71. self.assertEqual(len(res.keys()), 2)
  72. res1 = res[("srv1", key_id_1)]
  73. self.assertEqual(res1.verify_key, KEY_1)
  74. self.assertEqual(res1.valid_until_ts, 100)
  75. res2 = res[("srv1", key_id_2)]
  76. self.assertEqual(res2.verify_key, KEY_2)
  77. self.assertEqual(res2.valid_until_ts, 200)
  78. # we should be able to look up the same thing again without a db hit
  79. res = store.get_server_verify_keys([("srv1", key_id_1)])
  80. if isinstance(res, Deferred):
  81. res = self.successResultOf(res)
  82. self.assertEqual(len(res.keys()), 1)
  83. self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
  84. new_key_2 = signedjson.key.get_verify_key(
  85. signedjson.key.generate_signing_key("key2")
  86. )
  87. d = store.store_server_verify_keys(
  88. "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
  89. )
  90. self.get_success(d)
  91. d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
  92. res = self.get_success(d)
  93. self.assertEqual(len(res.keys()), 2)
  94. res1 = res[("srv1", key_id_1)]
  95. self.assertEqual(res1.verify_key, KEY_1)
  96. self.assertEqual(res1.valid_until_ts, 100)
  97. res2 = res[("srv1", key_id_2)]
  98. self.assertEqual(res2.verify_key, new_key_2)
  99. self.assertEqual(res2.valid_until_ts, 300)