keys.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from ._base import SQLBaseStore
  16. from synapse.util.caches.descriptors import cachedInlineCallbacks
  17. from twisted.internet import defer
  18. import six
  19. import OpenSSL
  20. from signedjson.key import decode_verify_key_bytes
  21. import hashlib
  22. import logging
  23. logger = logging.getLogger(__name__)
  24. # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
  25. # despite being deprecated and removed in favor of memoryview
  26. if six.PY2:
  27. db_binary_type = buffer
  28. else:
  29. db_binary_type = memoryview
  30. class KeyStore(SQLBaseStore):
  31. """Persistence for signature verification keys and tls X.509 certificates
  32. """
  33. @defer.inlineCallbacks
  34. def get_server_certificate(self, server_name):
  35. """Retrieve the TLS X.509 certificate for the given server
  36. Args:
  37. server_name (bytes): The name of the server.
  38. Returns:
  39. (OpenSSL.crypto.X509): The tls certificate.
  40. """
  41. tls_certificate_bytes, = yield self._simple_select_one(
  42. table="server_tls_certificates",
  43. keyvalues={"server_name": server_name},
  44. retcols=("tls_certificate",),
  45. desc="get_server_certificate",
  46. )
  47. tls_certificate = OpenSSL.crypto.load_certificate(
  48. OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes,
  49. )
  50. defer.returnValue(tls_certificate)
  51. def store_server_certificate(self, server_name, from_server, time_now_ms,
  52. tls_certificate):
  53. """Stores the TLS X.509 certificate for the given server
  54. Args:
  55. server_name (str): The name of the server.
  56. from_server (str): Where the certificate was looked up
  57. time_now_ms (int): The time now in milliseconds
  58. tls_certificate (OpenSSL.crypto.X509): The X.509 certificate.
  59. """
  60. tls_certificate_bytes = OpenSSL.crypto.dump_certificate(
  61. OpenSSL.crypto.FILETYPE_ASN1, tls_certificate
  62. )
  63. fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest()
  64. return self._simple_upsert(
  65. table="server_tls_certificates",
  66. keyvalues={
  67. "server_name": server_name,
  68. "fingerprint": fingerprint,
  69. },
  70. values={
  71. "from_server": from_server,
  72. "ts_added_ms": time_now_ms,
  73. "tls_certificate": db_binary_type(tls_certificate_bytes),
  74. },
  75. desc="store_server_certificate",
  76. )
  77. @cachedInlineCallbacks()
  78. def _get_server_verify_key(self, server_name, key_id):
  79. verify_key_bytes = yield self._simple_select_one_onecol(
  80. table="server_signature_keys",
  81. keyvalues={
  82. "server_name": server_name,
  83. "key_id": key_id,
  84. },
  85. retcol="verify_key",
  86. desc="_get_server_verify_key",
  87. allow_none=True,
  88. )
  89. if verify_key_bytes:
  90. defer.returnValue(decode_verify_key_bytes(
  91. key_id, bytes(verify_key_bytes)
  92. ))
  93. @defer.inlineCallbacks
  94. def get_server_verify_keys(self, server_name, key_ids):
  95. """Retrieve the NACL verification key for a given server for the given
  96. key_ids
  97. Args:
  98. server_name (str): The name of the server.
  99. key_ids (iterable[str]): key_ids to try and look up.
  100. Returns:
  101. Deferred: resolves to dict[str, VerifyKey]: map from
  102. key_id to verification key.
  103. """
  104. keys = {}
  105. for key_id in key_ids:
  106. key = yield self._get_server_verify_key(server_name, key_id)
  107. if key:
  108. keys[key_id] = key
  109. defer.returnValue(keys)
  110. def store_server_verify_key(self, server_name, from_server, time_now_ms,
  111. verify_key):
  112. """Stores a NACL verification key for the given server.
  113. Args:
  114. server_name (str): The name of the server.
  115. from_server (str): Where the verification key was looked up
  116. time_now_ms (int): The time now in milliseconds
  117. verify_key (nacl.signing.VerifyKey): The NACL verify key.
  118. """
  119. key_id = "%s:%s" % (verify_key.alg, verify_key.version)
  120. def _txn(txn):
  121. self._simple_upsert_txn(
  122. txn,
  123. table="server_signature_keys",
  124. keyvalues={
  125. "server_name": server_name,
  126. "key_id": key_id,
  127. },
  128. values={
  129. "from_server": from_server,
  130. "ts_added_ms": time_now_ms,
  131. "verify_key": db_binary_type(verify_key.encode()),
  132. },
  133. )
  134. txn.call_after(
  135. self._get_server_verify_key.invalidate,
  136. (server_name, key_id)
  137. )
  138. return self.runInteraction("store_server_verify_key", _txn)
  139. def store_server_keys_json(self, server_name, key_id, from_server,
  140. ts_now_ms, ts_expires_ms, key_json_bytes):
  141. """Stores the JSON bytes for a set of keys from a server
  142. The JSON should be signed by the originating server, the intermediate
  143. server, and by this server. Updates the value for the
  144. (server_name, key_id, from_server) triplet if one already existed.
  145. Args:
  146. server_name (str): The name of the server.
  147. key_id (str): The identifer of the key this JSON is for.
  148. from_server (str): The server this JSON was fetched from.
  149. ts_now_ms (int): The time now in milliseconds.
  150. ts_valid_until_ms (int): The time when this json stops being valid.
  151. key_json (bytes): The encoded JSON.
  152. """
  153. return self._simple_upsert(
  154. table="server_keys_json",
  155. keyvalues={
  156. "server_name": server_name,
  157. "key_id": key_id,
  158. "from_server": from_server,
  159. },
  160. values={
  161. "server_name": server_name,
  162. "key_id": key_id,
  163. "from_server": from_server,
  164. "ts_added_ms": ts_now_ms,
  165. "ts_valid_until_ms": ts_expires_ms,
  166. "key_json": db_binary_type(key_json_bytes),
  167. },
  168. desc="store_server_keys_json",
  169. )
  170. def get_server_keys_json(self, server_keys):
  171. """Retrive the key json for a list of server_keys and key ids.
  172. If no keys are found for a given server, key_id and source then
  173. that server, key_id, and source triplet entry will be an empty list.
  174. The JSON is returned as a byte array so that it can be efficiently
  175. used in an HTTP response.
  176. Args:
  177. server_keys (list): List of (server_name, key_id, source) triplets.
  178. Returns:
  179. Dict mapping (server_name, key_id, source) triplets to dicts with
  180. "ts_valid_until_ms" and "key_json" keys.
  181. """
  182. def _get_server_keys_json_txn(txn):
  183. results = {}
  184. for server_name, key_id, from_server in server_keys:
  185. keyvalues = {"server_name": server_name}
  186. if key_id is not None:
  187. keyvalues["key_id"] = key_id
  188. if from_server is not None:
  189. keyvalues["from_server"] = from_server
  190. rows = self._simple_select_list_txn(
  191. txn,
  192. "server_keys_json",
  193. keyvalues=keyvalues,
  194. retcols=(
  195. "key_id",
  196. "from_server",
  197. "ts_added_ms",
  198. "ts_valid_until_ms",
  199. "key_json",
  200. ),
  201. )
  202. results[(server_name, key_id, from_server)] = rows
  203. return results
  204. return self.runInteraction(
  205. "get_server_keys_json", _get_server_keys_json_txn
  206. )