|
@@ -15,19 +15,31 @@
|
|
|
import logging
|
|
|
import random
|
|
|
import time
|
|
|
-from typing import Optional, Tuple, Union
|
|
|
+from typing import Any, Callable, Dict, Generator, Optional, Tuple
|
|
|
|
|
|
import attr
|
|
|
-from netaddr import IPAddress # type: ignore
|
|
|
+from netaddr import IPAddress
|
|
|
from twisted.internet import defer
|
|
|
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
|
|
-from twisted.internet.interfaces import IStreamClientEndpoint
|
|
|
-from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, Response
|
|
|
+from twisted.internet.interfaces import (
|
|
|
+ IProtocol,
|
|
|
+ IProtocolFactory,
|
|
|
+ IReactorTime,
|
|
|
+ IStreamClientEndpoint,
|
|
|
+)
|
|
|
+from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent
|
|
|
from twisted.web.http import stringToDatetime
|
|
|
from twisted.web.http_headers import Headers
|
|
|
-from twisted.web.iweb import IAgent, IBodyProducer
|
|
|
+from twisted.web.iweb import (
|
|
|
+ IAgent,
|
|
|
+ IAgentEndpointFactory,
|
|
|
+ IBodyProducer,
|
|
|
+ IPolicyForHTTPS,
|
|
|
+ IResponse,
|
|
|
+)
|
|
|
from zope.interface import implementer
|
|
|
|
|
|
+from sydent.http.federation_tls_options import ClientTLSOptionsFactory
|
|
|
from sydent.http.httpcommon import read_body_with_max_size
|
|
|
from sydent.http.srvresolver import SrvResolver, pick_server_from_list
|
|
|
from sydent.util import json_decoder
|
|
@@ -49,7 +61,7 @@ WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
|
|
|
WELL_KNOWN_MAX_SIZE = 50 * 1024 # 50 KiB
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
-well_known_cache = TTLCache("well-known")
|
|
|
+well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
|
|
|
|
|
|
|
|
|
@implementer(IAgent)
|
|
@@ -59,32 +71,28 @@ class MatrixFederationAgent:
|
|
|
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
|
|
|
|
|
|
:param reactor: twisted reactor to use for underlying requests
|
|
|
- :type reactor: IReactor
|
|
|
|
|
|
:param tls_client_options_factory: Factory to use for fetching client tls
|
|
|
options, or none to disable TLS.
|
|
|
- :type tls_client_options_factory: ClientTLSOptionsFactory, None
|
|
|
|
|
|
:param _well_known_tls_policy: TLS policy to use for fetching .well-known
|
|
|
files. None to use a default (browser-like) implementation.
|
|
|
- :type _well_known_tls_policy: IPolicyForHTTPS, None
|
|
|
-
|
|
|
- :param _srv_resolver: SRVResolver impl to use for looking up SRV records.
|
|
|
- None to use a default implementation.
|
|
|
- :type _srv_resolver: SrvResolver, None
|
|
|
|
|
|
:param _well_known_cache: TTLCache impl for storing cached well-known
|
|
|
lookups. Omit to use a default implementation.
|
|
|
- :type _well_known_cache: TTLCache
|
|
|
"""
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
- reactor,
|
|
|
- tls_client_options_factory,
|
|
|
- _well_known_tls_policy=None,
|
|
|
- _srv_resolver: Optional["SrvResolver"] = None,
|
|
|
- _well_known_cache: "TTLCache" = well_known_cache,
|
|
|
+ # This reactor should also be IReactorTCP and IReactorPluggableNameResolver
|
|
|
+ # because it eventually makes its way to HostnameEndpoint.__init__.
|
|
|
+ # But that's not easy to express with an annotation. We use the
|
|
|
+ # `seconds` attribute below, so mark this as IReactorTime for now.
|
|
|
+ reactor: IReactorTime,
|
|
|
+ tls_client_options_factory: Optional[ClientTLSOptionsFactory],
|
|
|
+ _well_known_tls_policy: Optional[IPolicyForHTTPS] = None,
|
|
|
+ _srv_resolver: Optional[SrvResolver] = None,
|
|
|
+ _well_known_cache: TTLCache[bytes, Optional[bytes]] = well_known_cache,
|
|
|
) -> None:
|
|
|
self._reactor = reactor
|
|
|
|
|
@@ -98,15 +106,15 @@ class MatrixFederationAgent:
|
|
|
self._pool.maxPersistentPerHost = 5
|
|
|
self._pool.cachedConnectionTimeout = 2 * 60
|
|
|
|
|
|
- agent_args = {}
|
|
|
if _well_known_tls_policy is not None:
|
|
|
# the param is called 'contextFactory', but actually passing a
|
|
|
# contextfactory is deprecated, and it expects an IPolicyForHTTPS.
|
|
|
- agent_args["contextFactory"] = _well_known_tls_policy
|
|
|
- _well_known_agent = RedirectAgent(
|
|
|
- Agent(self._reactor, pool=self._pool, **agent_args),
|
|
|
- )
|
|
|
- self._well_known_agent = _well_known_agent
|
|
|
+ _well_known_agent = Agent(
|
|
|
+ self._reactor, pool=self._pool, contextFactory=_well_known_tls_policy
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ _well_known_agent = Agent(self._reactor, pool=self._pool)
|
|
|
+ self._well_known_agent = RedirectAgent(_well_known_agent)
|
|
|
|
|
|
# our cache of .well-known lookup results, mapping from server name
|
|
|
# to delegated name. The values can be:
|
|
@@ -121,7 +129,7 @@ class MatrixFederationAgent:
|
|
|
uri: bytes,
|
|
|
headers: Optional["Headers"] = None,
|
|
|
bodyProducer: Optional["IBodyProducer"] = None,
|
|
|
- ) -> Response:
|
|
|
+ ) -> Generator["defer.Deferred[Any]", Any, IResponse]:
|
|
|
"""
|
|
|
:param method: HTTP method (GET/POST/etc).
|
|
|
|
|
@@ -141,7 +149,8 @@ class MatrixFederationAgent:
|
|
|
(including problems that prevent the request from being sent).
|
|
|
"""
|
|
|
parsed_uri = URI.fromBytes(uri, defaultPort=-1)
|
|
|
- res = yield defer.ensureDeferred(self._route_matrix_uri(parsed_uri))
|
|
|
+ routing: _RoutingResult
|
|
|
+ routing = yield defer.ensureDeferred(self._route_matrix_uri(parsed_uri))
|
|
|
|
|
|
# set up the TLS connection params
|
|
|
#
|
|
@@ -152,32 +161,37 @@ class MatrixFederationAgent:
|
|
|
tls_options = None
|
|
|
else:
|
|
|
tls_options = self._tls_client_options_factory.get_options(
|
|
|
- res.tls_server_name.decode("ascii")
|
|
|
+ routing.tls_server_name.decode("ascii")
|
|
|
)
|
|
|
|
|
|
# make sure that the Host header is set correctly
|
|
|
if headers is None:
|
|
|
headers = Headers()
|
|
|
else:
|
|
|
- headers = headers.copy()
|
|
|
+ # Type safety: Headers.copy doesn't have a return type annotated,
|
|
|
+ # and I don't want to stub web.http_headers. Could use stubgen? It's
|
|
|
+ # a pretty simple file.
|
|
|
+ headers = headers.copy() # type: ignore[no-untyped-call]
|
|
|
assert headers is not None
|
|
|
|
|
|
if not headers.hasHeader(b"host"):
|
|
|
- headers.addRawHeader(b"host", res.host_header)
|
|
|
+ headers.addRawHeader(b"host", routing.host_header)
|
|
|
|
|
|
+ @implementer(IAgentEndpointFactory)
|
|
|
class EndpointFactory:
|
|
|
@staticmethod
|
|
|
- def endpointForURI(_uri):
|
|
|
- ep = LoggingHostnameEndpoint(
|
|
|
+ def endpointForURI(_uri: URI) -> IStreamClientEndpoint:
|
|
|
+ ep: IStreamClientEndpoint = LoggingHostnameEndpoint(
|
|
|
self._reactor,
|
|
|
- res.target_host,
|
|
|
- res.target_port,
|
|
|
+ routing.target_host,
|
|
|
+ routing.target_port,
|
|
|
)
|
|
|
if tls_options is not None:
|
|
|
ep = wrapClientTLS(tls_options, ep)
|
|
|
return ep
|
|
|
|
|
|
agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
|
|
|
+ res: IResponse
|
|
|
res = yield agent.request(method, uri, headers, bodyProducer)
|
|
|
return res
|
|
|
|
|
@@ -232,9 +246,11 @@ class MatrixFederationAgent:
|
|
|
# parse the server name in the .well-known response into host/port.
|
|
|
# (This code is lifted from twisted.web.client.URI.fromBytes).
|
|
|
if b":" in well_known_server:
|
|
|
- well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
|
|
|
+ well_known_host, well_known_port_raw = well_known_server.rsplit(
|
|
|
+ b":", 1
|
|
|
+ )
|
|
|
try:
|
|
|
- well_known_port = int(well_known_port)
|
|
|
+ well_known_port = int(well_known_port_raw)
|
|
|
except ValueError:
|
|
|
# the part after the colon could not be parsed as an int
|
|
|
# - we assume it is an IPv6 literal with no port (the closing
|
|
@@ -308,7 +324,7 @@ class MatrixFederationAgent:
|
|
|
|
|
|
async def _do_get_well_known(
|
|
|
self, server_name: bytes
|
|
|
- ) -> Tuple[Union[bytes, None, object], int]:
|
|
|
+ ) -> Tuple[Optional[bytes], float]:
|
|
|
"""Actually fetch and parse a .well-known, without checking the cache
|
|
|
|
|
|
:param server_name: Name of the server, from the requested url
|
|
@@ -321,6 +337,7 @@ class MatrixFederationAgent:
|
|
|
uri = b"https://%s/.well-known/matrix/server" % (server_name,)
|
|
|
uri_str = uri.decode("ascii")
|
|
|
logger.info("Fetching %s", uri_str)
|
|
|
+ cache_period: Optional[float]
|
|
|
try:
|
|
|
response = await self._well_known_agent.request(b"GET", uri)
|
|
|
body = await read_body_with_max_size(response, WELL_KNOWN_MAX_SIZE)
|
|
@@ -338,7 +355,7 @@ class MatrixFederationAgent:
|
|
|
|
|
|
# add some randomness to the TTL to avoid a stampeding herd every hour
|
|
|
# after startup
|
|
|
- cache_period: float = WELL_KNOWN_INVALID_CACHE_PERIOD
|
|
|
+ cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
|
|
|
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
|
|
|
return (None, cache_period)
|
|
|
|
|
@@ -363,27 +380,33 @@ class MatrixFederationAgent:
|
|
|
class LoggingHostnameEndpoint:
|
|
|
"""A wrapper for HostnameEndpoint which logs when it connects"""
|
|
|
|
|
|
- def __init__(self, reactor, host, port, *args, **kwargs):
|
|
|
+ def __init__(
|
|
|
+ self, reactor: IReactorTime, host: bytes, port: int, *args: Any, **kwargs: Any
|
|
|
+ ):
|
|
|
self.host = host
|
|
|
self.port = port
|
|
|
self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
|
|
|
logger.info("Endpoint created with %s:%d", host, port)
|
|
|
|
|
|
- def connect(self, protocol_factory):
|
|
|
+ def connect(
|
|
|
+ self, protocol_factory: IProtocolFactory
|
|
|
+ ) -> "defer.Deferred[IProtocol]":
|
|
|
logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
|
|
|
return self.ep.connect(protocol_factory)
|
|
|
|
|
|
|
|
|
-def _cache_period_from_headers(headers, time_now=time.time):
|
|
|
+def _cache_period_from_headers(
|
|
|
+ headers: Headers, time_now: Callable[[], float] = time.time
|
|
|
+) -> Optional[float]:
|
|
|
cache_controls = _parse_cache_control(headers)
|
|
|
|
|
|
if b"no-store" in cache_controls:
|
|
|
return 0
|
|
|
|
|
|
- if b"max-age" in cache_controls:
|
|
|
+ max_age = cache_controls.get(b"max-age")
|
|
|
+ if max_age is not None:
|
|
|
try:
|
|
|
- max_age = int(cache_controls[b"max-age"])
|
|
|
- return max_age
|
|
|
+ return int(max_age)
|
|
|
except ValueError:
|
|
|
pass
|
|
|
|
|
@@ -401,8 +424,8 @@ def _cache_period_from_headers(headers, time_now=time.time):
|
|
|
return None
|
|
|
|
|
|
|
|
|
-def _parse_cache_control(headers):
|
|
|
- cache_controls = {}
|
|
|
+def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
|
|
|
+ cache_controls: Dict[bytes, Optional[bytes]] = {}
|
|
|
for hdr in headers.getRawHeaders(b"cache-control", []):
|
|
|
for directive in hdr.split(b","):
|
|
|
splits = [x.strip() for x in directive.split(b"=", 1)]
|
|
@@ -412,7 +435,7 @@ def _parse_cache_control(headers):
|
|
|
return cache_controls
|
|
|
|
|
|
|
|
|
-@attr.s
|
|
|
+@attr.s(frozen=True, slots=True, auto_attribs=True)
|
|
|
class _RoutingResult:
|
|
|
"""The result returned by `_route_matrix_uri`.
|
|
|
Contains the parameters needed to direct a federation connection to a particular
|
|
@@ -421,30 +444,26 @@ class _RoutingResult:
|
|
|
chosen from the list.
|
|
|
"""
|
|
|
|
|
|
- host_header = attr.ib()
|
|
|
+ host_header: bytes
|
|
|
"""
|
|
|
The value we should assign to the Host header (host:port from the matrix
|
|
|
URI, or .well-known).
|
|
|
- :type: bytes
|
|
|
"""
|
|
|
|
|
|
- tls_server_name = attr.ib()
|
|
|
+ tls_server_name: bytes
|
|
|
"""
|
|
|
The server name we should set in the SNI (typically host, without port, from the
|
|
|
matrix URI or .well-known)
|
|
|
- :type: bytes
|
|
|
"""
|
|
|
|
|
|
- target_host = attr.ib()
|
|
|
+ target_host: bytes
|
|
|
"""
|
|
|
The hostname (or IP literal) we should route the TCP connection to (the target of the
|
|
|
SRV record, or the hostname from the URL/.well-known)
|
|
|
- :type: bytes
|
|
|
"""
|
|
|
|
|
|
- target_port = attr.ib()
|
|
|
+ target_port: int
|
|
|
"""
|
|
|
The port we should route the TCP connection to (the target of the SRV record, or
|
|
|
the port from the URL/.well-known, or 8448)
|
|
|
- :type: int
|
|
|
"""
|