|
@@ -14,7 +14,7 @@
|
|
|
# limitations under the License.
|
|
|
|
|
|
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
|
|
-from twisted.internet import defer
|
|
|
+from twisted.internet import defer, reactor
|
|
|
from twisted.internet.error import ConnectError
|
|
|
from twisted.names import client, dns
|
|
|
from twisted.names.error import DNSNameError, DomainError
|
|
@@ -68,13 +68,75 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
|
|
|
default_port = 8448
|
|
|
|
|
|
if port is None:
|
|
|
- return SRVClientEndpoint(
|
|
|
+ return _WrappingEndpointFac(SRVClientEndpoint(
|
|
|
reactor, "matrix", domain, protocol="tcp",
|
|
|
default_port=default_port, endpoint=transport_endpoint,
|
|
|
endpoint_kw_args=endpoint_kw_args
|
|
|
- )
|
|
|
+ ))
|
|
|
else:
|
|
|
- return transport_endpoint(reactor, domain, port, **endpoint_kw_args)
|
|
|
+ return _WrappingEndpointFac(transport_endpoint(
|
|
|
+ reactor, domain, port, **endpoint_kw_args
|
|
|
+ ))
|
|
|
+
|
|
|
+
|
|
|
+class _WrappingEndpointFac(object):
|
|
|
+ def __init__(self, endpoint_fac):
|
|
|
+ self.endpoint_fac = endpoint_fac
|
|
|
+
|
|
|
+ @defer.inlineCallbacks
|
|
|
+ def connect(self, protocolFactory):
|
|
|
+ conn = yield self.endpoint_fac.connect(protocolFactory)
|
|
|
+ conn = _WrappedConnection(conn)
|
|
|
+ defer.returnValue(conn)
|
|
|
+
|
|
|
+
|
|
|
+class _WrappedConnection(object):
|
|
|
+ """Wraps a connection and calls abort on it if it hasn't seen any action
|
|
|
+ for 2.5-3 minutes.
|
|
|
+ """
|
|
|
+ __slots__ = ["conn", "last_request"]
|
|
|
+
|
|
|
+ def __init__(self, conn):
|
|
|
+ object.__setattr__(self, "conn", conn)
|
|
|
+ object.__setattr__(self, "last_request", time.time())
|
|
|
+
|
|
|
+ def __getattr__(self, name):
|
|
|
+ return getattr(self.conn, name)
|
|
|
+
|
|
|
+ def __setattr__(self, name, value):
|
|
|
+ setattr(self.conn, name, value)
|
|
|
+
|
|
|
+ def _time_things_out_maybe(self):
|
|
|
+ # We use a slightly shorter timeout here just in case the callLater is
|
|
|
+ # triggered early. Paranoia ftw.
|
|
|
+ # TODO: Cancel the previous callLater rather than comparing time.time()?
|
|
|
+ if time.time() - self.last_request >= 2.5 * 60:
|
|
|
+ self.abort()
|
|
|
+ # Abort the underlying TLS connection. The abort() method calls
|
|
|
+ # loseConnection() on the underlying TLS connection which tries to
|
|
|
+ # shutdown the connection cleanly. We call abortConnection()
|
|
|
+ # since that will promptly close the underlying TCP connection.
|
|
|
+ self.transport.abortConnection()
|
|
|
+
|
|
|
+ def request(self, request):
|
|
|
+ self.last_request = time.time()
|
|
|
+
|
|
|
+ # Time this connection out if we haven't send a request in the last
|
|
|
+ # N minutes
|
|
|
+ # TODO: Cancel the previous callLater?
|
|
|
+ reactor.callLater(3 * 60, self._time_things_out_maybe)
|
|
|
+
|
|
|
+ d = self.conn.request(request)
|
|
|
+
|
|
|
+ def update_request_time(res):
|
|
|
+ self.last_request = time.time()
|
|
|
+ # TODO: Cancel the previous callLater?
|
|
|
+ reactor.callLater(3 * 60, self._time_things_out_maybe)
|
|
|
+ return res
|
|
|
+
|
|
|
+ d.addCallback(update_request_time)
|
|
|
+
|
|
|
+ return d
|
|
|
|
|
|
|
|
|
class SpiderEndpoint(object):
|