keys.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from ._base import SQLBaseStore
  16. from synapse.util.caches.descriptors import cachedInlineCallbacks
  17. from twisted.internet import defer
  18. import OpenSSL
  19. from signedjson.key import decode_verify_key_bytes
  20. import hashlib
  21. import logging
  22. logger = logging.getLogger(__name__)
  23. class KeyStore(SQLBaseStore):
  24. """Persistence for signature verification keys and tls X.509 certificates
  25. """
  26. @defer.inlineCallbacks
  27. def get_server_certificate(self, server_name):
  28. """Retrieve the TLS X.509 certificate for the given server
  29. Args:
  30. server_name (bytes): The name of the server.
  31. Returns:
  32. (OpenSSL.crypto.X509): The tls certificate.
  33. """
  34. tls_certificate_bytes, = yield self._simple_select_one(
  35. table="server_tls_certificates",
  36. keyvalues={"server_name": server_name},
  37. retcols=("tls_certificate",),
  38. desc="get_server_certificate",
  39. )
  40. tls_certificate = OpenSSL.crypto.load_certificate(
  41. OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes,
  42. )
  43. defer.returnValue(tls_certificate)
  44. def store_server_certificate(self, server_name, from_server, time_now_ms,
  45. tls_certificate):
  46. """Stores the TLS X.509 certificate for the given server
  47. Args:
  48. server_name (str): The name of the server.
  49. from_server (str): Where the certificate was looked up
  50. time_now_ms (int): The time now in milliseconds
  51. tls_certificate (OpenSSL.crypto.X509): The X.509 certificate.
  52. """
  53. tls_certificate_bytes = OpenSSL.crypto.dump_certificate(
  54. OpenSSL.crypto.FILETYPE_ASN1, tls_certificate
  55. )
  56. fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest()
  57. return self._simple_upsert(
  58. table="server_tls_certificates",
  59. keyvalues={
  60. "server_name": server_name,
  61. "fingerprint": fingerprint,
  62. },
  63. values={
  64. "from_server": from_server,
  65. "ts_added_ms": time_now_ms,
  66. "tls_certificate": buffer(tls_certificate_bytes),
  67. },
  68. desc="store_server_certificate",
  69. )
  70. @cachedInlineCallbacks()
  71. def _get_server_verify_key(self, server_name, key_id):
  72. verify_key_bytes = yield self._simple_select_one_onecol(
  73. table="server_signature_keys",
  74. keyvalues={
  75. "server_name": server_name,
  76. "key_id": key_id,
  77. },
  78. retcol="verify_key",
  79. desc="_get_server_verify_key",
  80. allow_none=True,
  81. )
  82. if verify_key_bytes:
  83. defer.returnValue(decode_verify_key_bytes(
  84. key_id, str(verify_key_bytes)
  85. ))
  86. @defer.inlineCallbacks
  87. def get_server_verify_keys(self, server_name, key_ids):
  88. """Retrieve the NACL verification key for a given server for the given
  89. key_ids
  90. Args:
  91. server_name (str): The name of the server.
  92. key_ids (iterable[str]): key_ids to try and look up.
  93. Returns:
  94. Deferred: resolves to dict[str, VerifyKey]: map from
  95. key_id to verification key.
  96. """
  97. keys = {}
  98. for key_id in key_ids:
  99. key = yield self._get_server_verify_key(server_name, key_id)
  100. if key:
  101. keys[key_id] = key
  102. defer.returnValue(keys)
  103. def store_server_verify_key(self, server_name, from_server, time_now_ms,
  104. verify_key):
  105. """Stores a NACL verification key for the given server.
  106. Args:
  107. server_name (str): The name of the server.
  108. from_server (str): Where the verification key was looked up
  109. time_now_ms (int): The time now in milliseconds
  110. verify_key (nacl.signing.VerifyKey): The NACL verify key.
  111. """
  112. key_id = "%s:%s" % (verify_key.alg, verify_key.version)
  113. def _txn(txn):
  114. self._simple_upsert_txn(
  115. txn,
  116. table="server_signature_keys",
  117. keyvalues={
  118. "server_name": server_name,
  119. "key_id": key_id,
  120. },
  121. values={
  122. "from_server": from_server,
  123. "ts_added_ms": time_now_ms,
  124. "verify_key": buffer(verify_key.encode()),
  125. },
  126. )
  127. txn.call_after(
  128. self._get_server_verify_key.invalidate,
  129. (server_name, key_id)
  130. )
  131. return self.runInteraction("store_server_verify_key", _txn)
  132. def store_server_keys_json(self, server_name, key_id, from_server,
  133. ts_now_ms, ts_expires_ms, key_json_bytes):
  134. """Stores the JSON bytes for a set of keys from a server
  135. The JSON should be signed by the originating server, the intermediate
  136. server, and by this server. Updates the value for the
  137. (server_name, key_id, from_server) triplet if one already existed.
  138. Args:
  139. server_name (str): The name of the server.
  140. key_id (str): The identifer of the key this JSON is for.
  141. from_server (str): The server this JSON was fetched from.
  142. ts_now_ms (int): The time now in milliseconds.
  143. ts_valid_until_ms (int): The time when this json stops being valid.
  144. key_json (bytes): The encoded JSON.
  145. """
  146. return self._simple_upsert(
  147. table="server_keys_json",
  148. keyvalues={
  149. "server_name": server_name,
  150. "key_id": key_id,
  151. "from_server": from_server,
  152. },
  153. values={
  154. "server_name": server_name,
  155. "key_id": key_id,
  156. "from_server": from_server,
  157. "ts_added_ms": ts_now_ms,
  158. "ts_valid_until_ms": ts_expires_ms,
  159. "key_json": buffer(key_json_bytes),
  160. },
  161. desc="store_server_keys_json",
  162. )
  163. def get_server_keys_json(self, server_keys):
  164. """Retrive the key json for a list of server_keys and key ids.
  165. If no keys are found for a given server, key_id and source then
  166. that server, key_id, and source triplet entry will be an empty list.
  167. The JSON is returned as a byte array so that it can be efficiently
  168. used in an HTTP response.
  169. Args:
  170. server_keys (list): List of (server_name, key_id, source) triplets.
  171. Returns:
  172. Dict mapping (server_name, key_id, source) triplets to dicts with
  173. "ts_valid_until_ms" and "key_json" keys.
  174. """
  175. def _get_server_keys_json_txn(txn):
  176. results = {}
  177. for server_name, key_id, from_server in server_keys:
  178. keyvalues = {"server_name": server_name}
  179. if key_id is not None:
  180. keyvalues["key_id"] = key_id
  181. if from_server is not None:
  182. keyvalues["from_server"] = from_server
  183. rows = self._simple_select_list_txn(
  184. txn,
  185. "server_keys_json",
  186. keyvalues=keyvalues,
  187. retcols=(
  188. "key_id",
  189. "from_server",
  190. "ts_added_ms",
  191. "ts_valid_until_ms",
  192. "key_json",
  193. ),
  194. )
  195. results[(server_name, key_id, from_server)] = rows
  196. return results
  197. return self.runInteraction(
  198. "get_server_keys_json", _get_server_keys_json_txn
  199. )