acme.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2019 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. import attr
  17. from zope.interface import implementer
  18. from twisted.internet import defer
  19. from twisted.internet.endpoints import serverFromString
  20. from twisted.python.filepath import FilePath
  21. from twisted.python.url import URL
  22. from twisted.web import server, static
  23. from twisted.web.resource import Resource
  24. logger = logging.getLogger(__name__)
  25. try:
  26. from txacme.interfaces import ICertificateStore
  27. @attr.s
  28. @implementer(ICertificateStore)
  29. class ErsatzStore(object):
  30. """
  31. A store that only stores in memory.
  32. """
  33. certs = attr.ib(default=attr.Factory(dict))
  34. def store(self, server_name, pem_objects):
  35. self.certs[server_name] = [o.as_bytes() for o in pem_objects]
  36. return defer.succeed(None)
  37. except ImportError:
  38. # txacme is missing
  39. pass
  40. class AcmeHandler(object):
  41. def __init__(self, hs):
  42. self.hs = hs
  43. self.reactor = hs.get_reactor()
  44. @defer.inlineCallbacks
  45. def start_listening(self):
  46. # Configure logging for txacme, if you need to debug
  47. # from eliot import add_destinations
  48. # from eliot.twisted import TwistedDestination
  49. #
  50. # add_destinations(TwistedDestination())
  51. from txacme.challenges import HTTP01Responder
  52. from txacme.service import AcmeIssuingService
  53. from txacme.endpoint import load_or_create_client_key
  54. from txacme.client import Client
  55. from josepy.jwa import RS256
  56. self._store = ErsatzStore()
  57. responder = HTTP01Responder()
  58. self._issuer = AcmeIssuingService(
  59. cert_store=self._store,
  60. client_creator=(
  61. lambda: Client.from_url(
  62. reactor=self.reactor,
  63. url=URL.from_text(self.hs.config.acme_url),
  64. key=load_or_create_client_key(
  65. FilePath(self.hs.config.config_dir_path)
  66. ),
  67. alg=RS256,
  68. )
  69. ),
  70. clock=self.reactor,
  71. responders=[responder],
  72. )
  73. well_known = Resource()
  74. well_known.putChild(b'acme-challenge', responder.resource)
  75. responder_resource = Resource()
  76. responder_resource.putChild(b'.well-known', well_known)
  77. responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
  78. srv = server.Site(responder_resource)
  79. listeners = []
  80. for host in self.hs.config.acme_bind_addresses:
  81. logger.info(
  82. "Listening for ACME requests on %s:%s", host, self.hs.config.acme_port
  83. )
  84. endpoint = serverFromString(
  85. self.reactor, "tcp:%s:interface=%s" % (self.hs.config.acme_port, host)
  86. )
  87. listeners.append(endpoint.listen(srv))
  88. # Make sure we are registered to the ACME server. There's no public API
  89. # for this, it is usually triggered by startService, but since we don't
  90. # want it to control where we save the certificates, we have to reach in
  91. # and trigger the registration machinery ourselves.
  92. self._issuer._registered = False
  93. yield self._issuer._ensure_registered()
  94. # Return a Deferred that will fire when all the servers have started up.
  95. yield defer.DeferredList(listeners, fireOnOneErrback=True, consumeErrors=True)
  96. @defer.inlineCallbacks
  97. def provision_certificate(self):
  98. logger.warning("Reprovisioning %s", self.hs.hostname)
  99. try:
  100. yield self._issuer.issue_cert(self.hs.hostname)
  101. except Exception:
  102. logger.exception("Fail!")
  103. raise
  104. logger.warning("Reprovisioned %s, saving.", self.hs.hostname)
  105. cert_chain = self._store.certs[self.hs.hostname]
  106. try:
  107. with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
  108. for x in cert_chain:
  109. if x.startswith(b"-----BEGIN RSA PRIVATE KEY-----"):
  110. private_key_file.write(x)
  111. with open(self.hs.config.tls_certificate_file, "wb") as certificate_file:
  112. for x in cert_chain:
  113. if x.startswith(b"-----BEGIN CERTIFICATE-----"):
  114. certificate_file.write(x)
  115. except Exception:
  116. logger.exception("Failed saving!")
  117. raise
  118. defer.returnValue(True)