keys.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. # Copyright 2019 New Vector Ltd.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import itertools
  17. import logging
  18. import six
  19. import attr
  20. from signedjson.key import decode_verify_key_bytes
  21. from synapse.util import batch_iter
  22. from synapse.util.caches.descriptors import cached, cachedList
  23. from ._base import SQLBaseStore
  24. logger = logging.getLogger(__name__)
  25. # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
  26. # despite being deprecated and removed in favor of memoryview
  27. if six.PY2:
  28. db_binary_type = six.moves.builtins.buffer
  29. else:
  30. db_binary_type = memoryview
  31. @attr.s(slots=True, frozen=True)
  32. class FetchKeyResult(object):
  33. verify_key = attr.ib() # VerifyKey: the key itself
  34. valid_until_ts = attr.ib() # int: how long we can use this key for
  35. class KeyStore(SQLBaseStore):
  36. """Persistence for signature verification keys
  37. """
  38. @cached()
  39. def _get_server_verify_key(self, server_name_and_key_id):
  40. raise NotImplementedError()
  41. @cachedList(
  42. cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
  43. )
  44. def get_server_verify_keys(self, server_name_and_key_ids):
  45. """
  46. Args:
  47. server_name_and_key_ids (iterable[Tuple[str, str]]):
  48. iterable of (server_name, key-id) tuples to fetch keys for
  49. Returns:
  50. Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
  51. map from (server_name, key_id) -> FetchKeyResult, or None if the key is
  52. unknown
  53. """
  54. keys = {}
  55. def _get_keys(txn, batch):
  56. """Processes a batch of keys to fetch, and adds the result to `keys`."""
  57. # batch_iter always returns tuples so it's safe to do len(batch)
  58. sql = (
  59. "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
  60. "FROM server_signature_keys WHERE 1=0"
  61. ) + " OR (server_name=? AND key_id=?)" * len(batch)
  62. txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
  63. for row in txn:
  64. server_name, key_id, key_bytes, ts_valid_until_ms = row
  65. if ts_valid_until_ms is None:
  66. # Old keys may be stored with a ts_valid_until_ms of null,
  67. # in which case we treat this as if it was set to `0`, i.e.
  68. # it won't match key requests that define a minimum
  69. # `ts_valid_until_ms`.
  70. ts_valid_until_ms = 0
  71. res = FetchKeyResult(
  72. verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
  73. valid_until_ts=ts_valid_until_ms,
  74. )
  75. keys[(server_name, key_id)] = res
  76. def _txn(txn):
  77. for batch in batch_iter(server_name_and_key_ids, 50):
  78. _get_keys(txn, batch)
  79. return keys
  80. return self.runInteraction("get_server_verify_keys", _txn)
  81. def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
  82. """Stores NACL verification keys for remote servers.
  83. Args:
  84. from_server (str): Where the verification keys were looked up
  85. ts_added_ms (int): The time to record that the key was added
  86. verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
  87. keys to be stored. Each entry is a triplet of
  88. (server_name, key_id, key).
  89. """
  90. key_values = []
  91. value_values = []
  92. invalidations = []
  93. for server_name, key_id, fetch_result in verify_keys:
  94. key_values.append((server_name, key_id))
  95. value_values.append(
  96. (
  97. from_server,
  98. ts_added_ms,
  99. fetch_result.valid_until_ts,
  100. db_binary_type(fetch_result.verify_key.encode()),
  101. )
  102. )
  103. # invalidate takes a tuple corresponding to the params of
  104. # _get_server_verify_key. _get_server_verify_key only takes one
  105. # param, which is itself the 2-tuple (server_name, key_id).
  106. invalidations.append((server_name, key_id))
  107. def _invalidate(res):
  108. f = self._get_server_verify_key.invalidate
  109. for i in invalidations:
  110. f((i,))
  111. return res
  112. return self.runInteraction(
  113. "store_server_verify_keys",
  114. self._simple_upsert_many_txn,
  115. table="server_signature_keys",
  116. key_names=("server_name", "key_id"),
  117. key_values=key_values,
  118. value_names=(
  119. "from_server",
  120. "ts_added_ms",
  121. "ts_valid_until_ms",
  122. "verify_key",
  123. ),
  124. value_values=value_values,
  125. ).addCallback(_invalidate)
  126. def store_server_keys_json(
  127. self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
  128. ):
  129. """Stores the JSON bytes for a set of keys from a server
  130. The JSON should be signed by the originating server, the intermediate
  131. server, and by this server. Updates the value for the
  132. (server_name, key_id, from_server) triplet if one already existed.
  133. Args:
  134. server_name (str): The name of the server.
  135. key_id (str): The identifer of the key this JSON is for.
  136. from_server (str): The server this JSON was fetched from.
  137. ts_now_ms (int): The time now in milliseconds.
  138. ts_valid_until_ms (int): The time when this json stops being valid.
  139. key_json (bytes): The encoded JSON.
  140. """
  141. return self._simple_upsert(
  142. table="server_keys_json",
  143. keyvalues={
  144. "server_name": server_name,
  145. "key_id": key_id,
  146. "from_server": from_server,
  147. },
  148. values={
  149. "server_name": server_name,
  150. "key_id": key_id,
  151. "from_server": from_server,
  152. "ts_added_ms": ts_now_ms,
  153. "ts_valid_until_ms": ts_expires_ms,
  154. "key_json": db_binary_type(key_json_bytes),
  155. },
  156. desc="store_server_keys_json",
  157. )
  158. def get_server_keys_json(self, server_keys):
  159. """Retrive the key json for a list of server_keys and key ids.
  160. If no keys are found for a given server, key_id and source then
  161. that server, key_id, and source triplet entry will be an empty list.
  162. The JSON is returned as a byte array so that it can be efficiently
  163. used in an HTTP response.
  164. Args:
  165. server_keys (list): List of (server_name, key_id, source) triplets.
  166. Returns:
  167. Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
  168. Dict mapping (server_name, key_id, source) triplets to lists of dicts
  169. """
  170. def _get_server_keys_json_txn(txn):
  171. results = {}
  172. for server_name, key_id, from_server in server_keys:
  173. keyvalues = {"server_name": server_name}
  174. if key_id is not None:
  175. keyvalues["key_id"] = key_id
  176. if from_server is not None:
  177. keyvalues["from_server"] = from_server
  178. rows = self._simple_select_list_txn(
  179. txn,
  180. "server_keys_json",
  181. keyvalues=keyvalues,
  182. retcols=(
  183. "key_id",
  184. "from_server",
  185. "ts_added_ms",
  186. "ts_valid_until_ms",
  187. "key_json",
  188. ),
  189. )
  190. results[(server_name, key_id, from_server)] = rows
  191. return results
  192. return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)