Browse Source

SIGHUP for TLS cert reloading (#4495)

Amber Brown 5 years ago
parent
commit
f6813919e8
5 changed files with 81 additions and 20 deletions
  1. 1 0
      .gitignore
  2. 1 0
      changelog.d/4495.feature
  3. 22 7
      synapse/app/_base.py
  4. 46 5
      synapse/app/homeserver.py
  5. 11 8
      synapse/config/logger.py

+ 1 - 0
.gitignore

@@ -12,6 +12,7 @@ dbs/
 dist/
 docs/build/
 *.egg-info
+pip-wheel-metadata/
 
 cmdclient_config.json
 homeserver*.db

+ 1 - 0
changelog.d/4495.feature

@@ -0,0 +1 @@
+Synapse will now reload TLS certificates from disk upon SIGHUP.

+ 22 - 7
synapse/app/_base.py

@@ -143,6 +143,9 @@ def listen_metrics(bind_addresses, port):
 def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
     """
     Create a TCP socket for a port and several addresses
+
+    Returns:
+        list (empty)
     """
     for address in bind_addresses:
         try:
@@ -155,25 +158,37 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
         except error.CannotListenError as e:
             check_bind_error(e, address, bind_addresses)
 
+    logger.info("Synapse now listening on TCP port %d", port)
+    return []
+
 
 def listen_ssl(
     bind_addresses, port, factory, context_factory, reactor=reactor, backlog=50
 ):
     """
-    Create an SSL socket for a port and several addresses
+    Create an TLS-over-TCP socket for a port and several addresses
+
+    Returns:
+        list of twisted.internet.tcp.Port listening for TLS connections
     """
+    r = []
     for address in bind_addresses:
         try:
-            reactor.listenSSL(
-                port,
-                factory,
-                context_factory,
-                backlog,
-                address
+            r.append(
+                reactor.listenSSL(
+                    port,
+                    factory,
+                    context_factory,
+                    backlog,
+                    address
+                )
             )
         except error.CannotListenError as e:
             check_bind_error(e, address, bind_addresses)
 
+    logger.info("Synapse now listening on port %d (TLS)", port)
+    return r
+
 
 def check_bind_error(e, address, bind_addresses):
     """

+ 46 - 5
synapse/app/homeserver.py

@@ -17,6 +17,7 @@
 import gc
 import logging
 import os
+import signal
 import sys
 import traceback
 
@@ -27,6 +28,7 @@ from prometheus_client import Gauge
 
 from twisted.application import service
 from twisted.internet import defer, reactor
+from twisted.protocols.tls import TLSMemoryBIOFactory
 from twisted.web.resource import EncodingResourceWrapper, NoResource
 from twisted.web.server import GzipEncoderFactory
 from twisted.web.static import File
@@ -84,6 +86,7 @@ def gz_wrap(r):
 
 class SynapseHomeServer(HomeServer):
     DATASTORE_CLASS = DataStore
+    _listening_services = []
 
     def _listener_http(self, config, listener_config):
         port = listener_config["port"]
@@ -121,7 +124,7 @@ class SynapseHomeServer(HomeServer):
         root_resource = create_resource_tree(resources, root_resource)
 
         if tls:
-            listen_ssl(
+            return listen_ssl(
                 bind_addresses,
                 port,
                 SynapseSite(
@@ -135,7 +138,7 @@ class SynapseHomeServer(HomeServer):
             )
 
         else:
-            listen_tcp(
+            return listen_tcp(
                 bind_addresses,
                 port,
                 SynapseSite(
@@ -146,7 +149,6 @@ class SynapseHomeServer(HomeServer):
                     self.version_string,
                 )
             )
-        logger.info("Synapse now listening on port %d", port)
 
     def _configure_named_resource(self, name, compress=False):
         """Build a resource map for a named resource
@@ -242,7 +244,9 @@ class SynapseHomeServer(HomeServer):
 
         for listener in config.listeners:
             if listener["type"] == "http":
-                self._listener_http(config, listener)
+                self._listening_services.extend(
+                    self._listener_http(config, listener)
+                )
             elif listener["type"] == "manhole":
                 listen_tcp(
                     listener["bind_addresses"],
@@ -322,7 +326,19 @@ def setup(config_options):
         # generating config files and shouldn't try to continue.
         sys.exit(0)
 
-    synapse.config.logger.setup_logging(config, use_worker_options=False)
+    sighup_callbacks = []
+    synapse.config.logger.setup_logging(
+        config,
+        use_worker_options=False,
+        register_sighup=sighup_callbacks.append
+    )
+
+    def handle_sighup(*args, **kwargs):
+        for i in sighup_callbacks:
+            i(*args, **kwargs)
+
+    if hasattr(signal, "SIGHUP"):
+        signal.signal(signal.SIGHUP, handle_sighup)
 
     events.USE_FROZEN_DICTS = config.use_frozen_dicts
 
@@ -359,6 +375,31 @@ def setup(config_options):
 
     hs.setup()
 
+    def refresh_certificate(*args):
+        """
+        Refresh the TLS certificates that Synapse is using by re-reading them
+        from disk and updating the TLS context factories to use them.
+        """
+        logging.info("Reloading certificate from disk...")
+        hs.config.read_certificate_from_disk()
+        hs.tls_server_context_factory = context_factory.ServerContextFactory(config)
+        hs.tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
+            config
+        )
+        logging.info("Certificate reloaded.")
+
+        logging.info("Updating context factories...")
+        for i in hs._listening_services:
+            if isinstance(i.factory, TLSMemoryBIOFactory):
+                i.factory = TLSMemoryBIOFactory(
+                    hs.tls_server_context_factory,
+                    False,
+                    i.factory.wrappedFactory
+                )
+        logging.info("Context factories updated.")
+
+    sighup_callbacks.append(refresh_certificate)
+
     @defer.inlineCallbacks
     def start():
         try:

+ 11 - 8
synapse/config/logger.py

@@ -127,7 +127,7 @@ class LoggingConfig(Config):
                 )
 
 
-def setup_logging(config, use_worker_options=False):
+def setup_logging(config, use_worker_options=False, register_sighup=None):
     """ Set up python logging
 
     Args:
@@ -136,7 +136,16 @@ def setup_logging(config, use_worker_options=False):
 
         use_worker_options (bool): True to use 'worker_log_config' and
             'worker_log_file' options instead of 'log_config' and 'log_file'.
+
+        register_sighup (func | None): Function to call to register a
+            sighup handler.
     """
+    if not register_sighup:
+        if getattr(signal, "SIGHUP"):
+            register_sighup = lambda x: signal.signal(signal.SIGHUP, x)
+        else:
+            register_sighup = lambda x: None
+
     log_config = (config.worker_log_config if use_worker_options
                   else config.log_config)
     log_file = (config.worker_log_file if use_worker_options
@@ -198,13 +207,7 @@ def setup_logging(config, use_worker_options=False):
 
         load_log_config()
 
-    # TODO(paul): obviously this is a terrible mechanism for
-    #   stealing SIGHUP, because it means no other part of synapse
-    #   can use it instead. If we want to catch SIGHUP anywhere
-    #   else as well, I'd suggest we find a nicer way to broadcast
-    #   it around.
-    if getattr(signal, "SIGHUP"):
-        signal.signal(signal.SIGHUP, sighup)
+    register_sighup(sighup)
 
     # make sure that the first thing we log is a thing we can grep backwards
     # for