Browse Source

Do proper SRV lookup

So the Host header will be correct in requests to homeservers.
David Baker 5 years ago
parent
commit
a90ec3a469
4 changed files with 125 additions and 51 deletions
  1. 1 1
      setup.py
  2. 3 33
      sydent/hs_federation/verifier.py
  3. 111 17
      sydent/http/httpclient.py
  4. 10 0
      sydent/http/servlets/threepidunbindservlet.py

+ 1 - 1
setup.py

@@ -33,7 +33,7 @@ setup(
     install_requires=[
         "signedjson==1.0.0",
         "unpaddedbase64==1.1.0",
-        "Twisted>=14.0.0",
+        "Twisted>=16.0.0",
         "service_identity>=1.0.0",
         "pyasn1",
         "pynacl",

+ 3 - 33
sydent/hs_federation/verifier.py

@@ -53,32 +53,6 @@ class Verifier(object):
             # server_name: <result from keys query>,
         }
 
-    @defer.inlineCallbacks
-    def _getEndpointForServer(self, server_name):
-        if ':' in server_name:
-            defer.returnValue(tuple(server_name.rsplit(':', 1)))
-
-        service_name = "%s.%s.%s" % ('_matrix', '_tcp', server_name)
-
-        default = server_name, 8448
-
-        try:
-            answers, _, _ = yield twisted.names.client.lookupService(service_name)
-        except DNSNameError:
-            logger.info("DNSNameError doing SRV lookup for %s - using default", service_name)
-            defer.returnValue(default)
-
-        for answer in answers:
-            if answer.type != twisted.names.dns.SRV or not answer.payload:
-                continue
-
-            # XXX we just use the first
-            logger.info("Got SRV answer: %r / %d for %s", str(answer.payload.target), answer.payload.port, service_name)
-            defer.returnValue((str(answer.payload.target), answer.payload.port))
-
-        logger.info("No valid answers found in response from %s (%r)", server_name, answers)
-        defer.returnValue(default)
-
     @defer.inlineCallbacks
     def _getKeysForServer(self, server_name):
         """Get the signing key data from a home server.
@@ -90,10 +64,8 @@ class Verifier(object):
             if cached['valid_until_ts'] > now:
                 defer.returnValue(self.cache[server_name]['verify_keys'])
 
-        host_port = yield self._getEndpointForServer(server_name)
-        logger.info("Got host/port %s/%s for %s", host_port[0], host_port[1], server_name)
         client = FederationHttpClient(self.sydent)
-        result = yield client.get_json("https://%s:%s/_matrix/key/v2/server/" % host_port)
+        result = yield client.get_json("https://%s/_matrix/key/v2/server/" % server_name)
         if 'verify_keys' not in result:
             raise SignatureVerifyException("No key found in response")
 
@@ -109,10 +81,8 @@ class Verifier(object):
     def verifyServerSignedJson(self, signed_json, acceptable_server_names=None):
         """Given a signed json object, try to verify any one
         of the signatures on it
-        XXX: This contains a very noddy version of the home server
-        SRV lookup and signature verification. It forms HTTPS URLs
-        from the result of the SRV lookup which will mean the Host:
-        parameter in the request will be wrong. It only looks at
+        XXX: This contains a fairly noddy version of the home server
+        SRV lookup and signature verification. It only looks at
         the first SRV result. It does no caching (just fetches the
         signature each time and does not contact any other servers
         to do perspectives checks.

+ 111 - 17
sydent/http/httpclient.py

@@ -19,12 +19,13 @@ import logging
 
 from StringIO import StringIO
 from twisted.internet import defer, reactor, ssl
-from twisted.internet._sslverify import _OpenSSLECCurve, _defaultCurveName, ClientTLSOptions
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.internet._sslverify import _defaultCurveName, ClientTLSOptions
 from twisted.web.client import FileBodyProducer, Agent, readBody
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IPolicyForHTTPS
-from zope.interface import implementer
-from OpenSSL import SSL
+import twisted.names.client
+from twisted.names.error import DNSNameError
+from OpenSSL import SSL, crypto
 
 logger = logging.getLogger(__name__)
 
@@ -33,21 +34,20 @@ class SimpleHttpClient(object):
     A simple, no-frills HTTP client based on the class of the same name
     from synapse
     """
-    def __init__(self, sydent, context_factory=None):
+    def __init__(self, sydent, endpoint_factory=None):
         self.sydent = sydent
-        if context_factory is None:
-            # The default context factory in Twisted 14.0.0 (which we require) is
-            # BrowserLikePolicyForHTTPS which will do regular cert validation
+        if endpoint_factory is None:
+            # The default endpoint factory in Twisted 14.0.0 (which we require) uses the
+            # BrowserLikePolicyForHTTPS context factory which will do regular cert validation
             # 'like a browser'
             self.agent = Agent(
                 reactor,
                 connectTimeout=15,
             )
         else:
-            self.agent = Agent(
+            self.agent = Agent.usingEndpointFactory(
                 reactor,
-                connectTimeout=15,
-                contextFactory=context_factory
+                endpoint_factory,
             )
 
     @defer.inlineCallbacks
@@ -79,9 +79,104 @@ class SimpleHttpClient(object):
         )
         defer.returnValue(response)
 
-@implementer(IPolicyForHTTPS)
-class FederationPolicyForHTTPS(object):
-    def creatorForNetloc(self, hostname, port):
+class SRVClientEndpoint(object):
+    def __init__(self, reactor, service, domain, protocol="tcp",
+                 default_port=None, endpoint=HostnameEndpoint,
+                 endpoint_kw_args={}):
+        self.reactor = reactor
+        self.domain = domain
+
+        self.endpoint = endpoint
+        self.endpoint_kw_args = endpoint_kw_args
+
+    @defer.inlineCallbacks
+    def lookup_server(self):
+        service_name = "%s.%s.%s" % ('_matrix', '_tcp', self.domain)
+
+        default = self.domain, 8448
+
+        try:
+            answers, _, _ = yield twisted.names.client.lookupService(service_name)
+        except DNSNameError:
+            logger.info("DNSNameError doing SRV lookup for %s - using default", service_name)
+            defer.returnValue(default)
+
+        for answer in answers:
+            if answer.type != twisted.names.dns.SRV or not answer.payload:
+                continue
+
+            # XXX we just use the first
+            logger.info("Got SRV answer: %r / %d for %s", str(answer.payload.target), answer.payload.port, service_name)
+            defer.returnValue((str(answer.payload.target), answer.payload.port))
+
+        logger.info("No valid answers found in response from %s (%r)", self.domain, answers)
+        defer.returnValue(default)
+
+    @defer.inlineCallbacks
+    def connect(self, protocolFactory):
+        server = yield self.lookup_server()
+        logger.info("Connecting to %s:%s", server[0], server[1])
+        endpoint = self.endpoint(
+            self.reactor, server[0], server[1], **self.endpoint_kw_args
+        )
+        connection = yield endpoint.connect(protocolFactory)
+        defer.returnValue(connection)
+
+def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
+                               timeout=None):
+    """Construct an endpoint for the given matrix destination.
+
+    :param reactor: Twisted reactor.
+    :param destination: The name of the server to connect to.
+    :type destination: bytes
+    :param ssl_context_factory: Factory which generates SSL contexts to use for TLS.
+    :type ssl_context_factory: twisted.internet.ssl.ContextFactory
+    :param timeout (int): connection timeout in seconds
+    :type timeout: int
+    """
+
+    domain_port = destination.split(":")
+    domain = domain_port[0]
+    port = int(domain_port[1]) if domain_port[1:] else None
+
+    endpoint_kw_args = {}
+
+    if timeout is not None:
+        endpoint_kw_args.update(timeout=timeout)
+
+    if ssl_context_factory is None:
+        transport_endpoint = HostnameEndpoint
+        default_port = 8008
+    else:
+        def transport_endpoint(reactor, host, port, timeout):
+            return wrapClientTLS(
+                ssl_context_factory,
+                HostnameEndpoint(reactor, host, port, timeout=timeout))
+        default_port = 8448
+
+    if port is None:
+        return SRVClientEndpoint(
+            reactor, "matrix", domain, protocol="tcp",
+            default_port=default_port, endpoint=transport_endpoint,
+            endpoint_kw_args=endpoint_kw_args
+        )
+    else:
+        return transport_endpoint(
+            reactor, domain, port, **endpoint_kw_args
+        )
+
+class FederationEndpointFactory(object):
+    def endpointForURI(self, uri):
+        destination = uri.netloc
+        context_factory = FederationContextFactory()
+
+        return matrix_federation_endpoint(
+            reactor, destination, timeout=10,
+            ssl_context_factory=context_factory,
+        )
+
+class FederationContextFactory(object):
+    def getContext(self):
         context = SSL.Context(SSL.SSLv23_METHOD)
         try:
             _ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
@@ -91,9 +186,8 @@ class FederationPolicyForHTTPS(object):
         context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
 
         context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")
-        return ClientTLSOptions(hostname, context)
-
+        return context
 
 class FederationHttpClient(SimpleHttpClient):
     def __init__(self, sydent):
-        super(FederationHttpClient, self).__init__(sydent, FederationPolicyForHTTPS())
+        super(FederationHttpClient, self).__init__(sydent, FederationEndpointFactory())

+ 10 - 0
sydent/http/servlets/threepidunbindservlet.py

@@ -16,6 +16,7 @@
 # limitations under the License.
 
 import json
+import logging
 
 from sydent.http.servlets import get_args, jsonwrap
 from sydent.hs_federation.verifier import NoAuthenticationError
@@ -25,6 +26,8 @@ from twisted.web.resource import Resource
 from twisted.web import server
 from twisted.internet import defer
 
+logger = logging.getLogger(__name__)
+
 class ThreePidUnbindServlet(Resource):
     def __init__(self, sydent):
         self.sydent = sydent
@@ -73,6 +76,12 @@ class ThreePidUnbindServlet(Resource):
                 request.write(json.dumps({'errcode': 'M_FORBIDDEN', 'error': ex.message}))
                 request.finish()
                 return
+            except:
+                logger.exception("Exception whilst authenticating unbind request")
+                request.setResponseCode(500)
+                request.write(json.dumps({'errcode': 'M_UNKNOWN', 'error': 'Internal Server Error'}))
+                request.finish()
+                return
 
             if not mxid.endswith(':' + origin_server_name):
                 request.setResponseCode(403)
@@ -83,6 +92,7 @@ class ThreePidUnbindServlet(Resource):
             request.write(json.dumps({}))
             request.finish()
         except Exception as ex:
+            logger.exception("Exception whilst handling unbind")
             request.setResponseCode(500)
             request.write(json.dumps({'errcode': 'M_UNKNOWN', 'error': ex.message}))
             request.finish()