keys.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. # Copyright 2019 New Vector Ltd.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import itertools
  17. import logging
  18. import six
  19. from signedjson.key import decode_verify_key_bytes
  20. from synapse.storage._base import SQLBaseStore
  21. from synapse.storage.keys import FetchKeyResult
  22. from synapse.util.caches.descriptors import cached, cachedList
  23. from synapse.util.iterutils import batch_iter
  24. logger = logging.getLogger(__name__)
  25. # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
  26. # despite being deprecated and removed in favor of memoryview
  27. if six.PY2:
  28. db_binary_type = six.moves.builtins.buffer
  29. else:
  30. db_binary_type = memoryview
  31. class KeyStore(SQLBaseStore):
  32. """Persistence for signature verification keys
  33. """
  34. @cached()
  35. def _get_server_verify_key(self, server_name_and_key_id):
  36. raise NotImplementedError()
  37. @cachedList(
  38. cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
  39. )
  40. def get_server_verify_keys(self, server_name_and_key_ids):
  41. """
  42. Args:
  43. server_name_and_key_ids (iterable[Tuple[str, str]]):
  44. iterable of (server_name, key-id) tuples to fetch keys for
  45. Returns:
  46. Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
  47. map from (server_name, key_id) -> FetchKeyResult, or None if the key is
  48. unknown
  49. """
  50. keys = {}
  51. def _get_keys(txn, batch):
  52. """Processes a batch of keys to fetch, and adds the result to `keys`."""
  53. # batch_iter always returns tuples so it's safe to do len(batch)
  54. sql = (
  55. "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
  56. "FROM server_signature_keys WHERE 1=0"
  57. ) + " OR (server_name=? AND key_id=?)" * len(batch)
  58. txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
  59. for row in txn:
  60. server_name, key_id, key_bytes, ts_valid_until_ms = row
  61. if ts_valid_until_ms is None:
  62. # Old keys may be stored with a ts_valid_until_ms of null,
  63. # in which case we treat this as if it was set to `0`, i.e.
  64. # it won't match key requests that define a minimum
  65. # `ts_valid_until_ms`.
  66. ts_valid_until_ms = 0
  67. res = FetchKeyResult(
  68. verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
  69. valid_until_ts=ts_valid_until_ms,
  70. )
  71. keys[(server_name, key_id)] = res
  72. def _txn(txn):
  73. for batch in batch_iter(server_name_and_key_ids, 50):
  74. _get_keys(txn, batch)
  75. return keys
  76. return self.db.runInteraction("get_server_verify_keys", _txn)
  77. def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
  78. """Stores NACL verification keys for remote servers.
  79. Args:
  80. from_server (str): Where the verification keys were looked up
  81. ts_added_ms (int): The time to record that the key was added
  82. verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
  83. keys to be stored. Each entry is a triplet of
  84. (server_name, key_id, key).
  85. """
  86. key_values = []
  87. value_values = []
  88. invalidations = []
  89. for server_name, key_id, fetch_result in verify_keys:
  90. key_values.append((server_name, key_id))
  91. value_values.append(
  92. (
  93. from_server,
  94. ts_added_ms,
  95. fetch_result.valid_until_ts,
  96. db_binary_type(fetch_result.verify_key.encode()),
  97. )
  98. )
  99. # invalidate takes a tuple corresponding to the params of
  100. # _get_server_verify_key. _get_server_verify_key only takes one
  101. # param, which is itself the 2-tuple (server_name, key_id).
  102. invalidations.append((server_name, key_id))
  103. def _invalidate(res):
  104. f = self._get_server_verify_key.invalidate
  105. for i in invalidations:
  106. f((i,))
  107. return res
  108. return self.db.runInteraction(
  109. "store_server_verify_keys",
  110. self.db.simple_upsert_many_txn,
  111. table="server_signature_keys",
  112. key_names=("server_name", "key_id"),
  113. key_values=key_values,
  114. value_names=(
  115. "from_server",
  116. "ts_added_ms",
  117. "ts_valid_until_ms",
  118. "verify_key",
  119. ),
  120. value_values=value_values,
  121. ).addCallback(_invalidate)
  122. def store_server_keys_json(
  123. self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
  124. ):
  125. """Stores the JSON bytes for a set of keys from a server
  126. The JSON should be signed by the originating server, the intermediate
  127. server, and by this server. Updates the value for the
  128. (server_name, key_id, from_server) triplet if one already existed.
  129. Args:
  130. server_name (str): The name of the server.
  131. key_id (str): The identifer of the key this JSON is for.
  132. from_server (str): The server this JSON was fetched from.
  133. ts_now_ms (int): The time now in milliseconds.
  134. ts_valid_until_ms (int): The time when this json stops being valid.
  135. key_json (bytes): The encoded JSON.
  136. """
  137. return self.db.simple_upsert(
  138. table="server_keys_json",
  139. keyvalues={
  140. "server_name": server_name,
  141. "key_id": key_id,
  142. "from_server": from_server,
  143. },
  144. values={
  145. "server_name": server_name,
  146. "key_id": key_id,
  147. "from_server": from_server,
  148. "ts_added_ms": ts_now_ms,
  149. "ts_valid_until_ms": ts_expires_ms,
  150. "key_json": db_binary_type(key_json_bytes),
  151. },
  152. desc="store_server_keys_json",
  153. )
  154. def get_server_keys_json(self, server_keys):
  155. """Retrive the key json for a list of server_keys and key ids.
  156. If no keys are found for a given server, key_id and source then
  157. that server, key_id, and source triplet entry will be an empty list.
  158. The JSON is returned as a byte array so that it can be efficiently
  159. used in an HTTP response.
  160. Args:
  161. server_keys (list): List of (server_name, key_id, source) triplets.
  162. Returns:
  163. Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
  164. Dict mapping (server_name, key_id, source) triplets to lists of dicts
  165. """
  166. def _get_server_keys_json_txn(txn):
  167. results = {}
  168. for server_name, key_id, from_server in server_keys:
  169. keyvalues = {"server_name": server_name}
  170. if key_id is not None:
  171. keyvalues["key_id"] = key_id
  172. if from_server is not None:
  173. keyvalues["from_server"] = from_server
  174. rows = self.db.simple_select_list_txn(
  175. txn,
  176. "server_keys_json",
  177. keyvalues=keyvalues,
  178. retcols=(
  179. "key_id",
  180. "from_server",
  181. "ts_added_ms",
  182. "ts_valid_until_ms",
  183. "key_json",
  184. ),
  185. )
  186. results[(server_name, key_id, from_server)] = rows
  187. return results
  188. return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn)