Browse Source

`mypy --strict` for `sydent.http.matrixfederationagent` (#444)

* Let's typecheck matrix federation agent

* Improve type stub for twisted.python.log.err

* Additional stubs that apply to matrixfederationagent

* Annotate the _RoutingResult struct

* Remove unused typeignore (now in mypy config)

* Annotate well_known_cache

* Annotate _parse_cache_control

* Annotate _cache_period_from_headers

* Annotate LoggingHostnameEndpoint

* Annotate _do_get_well_known

* Workaround no annotation for Headers.copy

* annotate EndpointFactory

* Avoid str/bytes confusion in well_known handling

* Annotations for MatrixFederationAgent

* Suppress reactor complaint for now

* sydent.http/*.py now passes mypy --strict

* Isort

* Changelog

* Additional linting --- looks like it didn't fully run?

* Keep 3.6 flake8 happy with annotations on previous line

* Review fixup
David Robertson 2 years ago

+ 1 - 0

@@ -0,0 +1 @@
+Get `sydent.http.matrixfederationagent` to pass `mypy --strict`.

+ 3 - 8

@@ -49,16 +49,11 @@ strict = true
 files = [
     # Find files that pass with
     #     find sydent tests -type d -not -name __pycache__ -exec bash -c "mypy --strict '{}' > /dev/null"  \; -print
+    # TODO "sydent/*.py"
-    "sydent/http/",
-    "sydent/http/",
-    "sydent/http/",
-    "sydent/http/",
-    "sydent/http/",
-    "sydent/http/",
-    "sydent/http/",
-    "sydent/http/",
+    "sydent/http/*.py",
+    # TODO "sydent/http/servlets",

+ 32 - 0

@@ -0,0 +1,32 @@
+from typing import Any, AnyStr, Optional
+from twisted.internet import interfaces
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import (
+    IOpenSSLClientConnectionCreator,
+    IProtocol,
+    IProtocolFactory,
+    IStreamClientEndpoint,
+from zope.interface import implementer
+class HostnameEndpoint:
+    # Reactor should be a "provider of L{IReactorTCP}, L{IReactorTime} and
+    # either L{IReactorPluggableNameResolver} or L{IReactorPluggableResolver}."
+    # I don't know how to encode that in the type system.
+    def __init__(
+        self,
+        reactor: object,
+        host: AnyStr,
+        port: int,
+        timeout: float = 30,
+        bindAddress: Optional[bytes] = None,
+        attemptDelay: Optional[float] = None,
+    ): ...
+    def connect(self, protocol_factory: IProtocolFactory) -> Deferred[IProtocol]: ...
+def wrapClientTLS(
+    connectionCreator: IOpenSSLClientConnectionCreator,
+    wrappedEndpoint: IStreamClientEndpoint,
+) -> IStreamClientEndpoint: ...

+ 1 - 1

@@ -5,5 +5,5 @@ from twisted.python.failure import Failure
 def err(
     _stuff: Union[None, Exception, Failure] = None,
     _why: Optional[str] = None,
-    **kw: Any,
+    **kw: object,
 ) -> None: ...

+ 38 - 8

@@ -9,7 +9,13 @@ from twisted.internet.interfaces import (
 from twisted.internet.task import Cooperator
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse
+from twisted.web.iweb import (
+    IAgent,
+    IAgentEndpointFactory,
+    IBodyProducer,
+    IPolicyForHTTPS,
+    IResponse,
 from zope.interface import implementer
 C = TypeVar("C")
@@ -20,13 +26,23 @@ class BrowserLikePolicyForHTTPS:
         self, hostname: bytes, port: int
     ) -> IOpenSSLClientConnectionCreator: ...
-class HTTPConnectionPool: ...
+class HTTPConnectionPool:
+    persistent: bool
+    maxPersistentPerHost: int
+    cachedConnectionTimeout: float
+    retryAutomatically: bool
+    def __init__(self, reactor: object, persistent: bool = True): ...
 class Agent:
+    # Here and in `usingEndpointFactory`, reactor should be a "provider of
+    # L{IReactorTCP}, L{IReactorTime} and either
+    # L{IReactorPluggableNameResolver} or L{IReactorPluggableResolver}."
+    # I don't know how to encode that in the type system.
+    #
     def __init__(
-        reactor: Any,
+        reactor: object,
         contextFactory: IPolicyForHTTPS = BrowserLikePolicyForHTTPS(),
         connectTimeout: Optional[float] = None,
         bindAddress: Optional[bytes] = None,
@@ -39,17 +55,20 @@ class Agent:
         headers: Optional[Headers] = None,
         bodyProducer: Optional[IBodyProducer] = None,
     ) -> Deferred[IResponse]: ...
+    @classmethod
+    def usingEndpointFactory(
+        cls: Type[C],
+        reactor: object,
+        endpointFactory: IAgentEndpointFactory,
+        pool: Optional[HTTPConnectionPool] = None,
+    ) -> C: ...
 class FileBodyProducer:
     def __init__(
         inputFile: BinaryIO,
-        # Type safety: twisted.internet.task.cooperate is a function with the
-        # same signature as Cooperator.cooperate. (It just wraps a module-level
-        # global cooperator.) But there's no easy way to annotate "either this
-        # type or a specific module".
-        cooperator: Cooperator = twisted.internet.task,  # type: ignore[assignment]
+        cooperator: Cooperator = ...,
         readSize: int = 2 ** 16,
     ): ...
     # Length is either `int` or the opaque object UNKNOWN_LENGTH.
@@ -95,3 +114,14 @@ class URI:
     ): ...
     def fromBytes(cls: Type[C], uri: bytes, defaultPort: Optional[int] = None) -> C: ...
+class RedirectAgent:
+    def __init__(self, Agent: Agent, redirectLimit: int = 20): ...
+    def request(
+        self,
+        method: bytes,
+        uri: bytes,
+        headers: Optional[Headers] = None,
+        bodyProducer: Optional[IBodyProducer] = None,
+    ) -> Deferred[IResponse]: ...

+ 2 - 0

@@ -51,3 +51,5 @@ class Request:
 class PotentialDataLoss(Exception): ...
 CACHED: object
+def stringToDatetime(dateString: bytes) -> int: ...

+ 6 - 1

@@ -177,7 +177,12 @@ class FederationHttpClient(HTTPClient[MatrixFederationAgent]):
     def __init__(self, sydent: "Sydent") -> None:
         self.sydent = sydent
         self.agent = MatrixFederationAgent(
-            BlacklistingReactorWrapper(
+            # Type-safety: I don't have a good way of expressing that
+            # the reactor is IReactorTCP, IReactorTime and
+            # IReactorPluggableNameResolver all at once. But it is, because
+            # it wraps the sydent reactor.
+            # TODO: can we introduce a SydentReactor type like SynapseReactor?
+            BlacklistingReactorWrapper(  # type: ignore[arg-type]

+ 74 - 55

@@ -15,19 +15,31 @@
 import logging
 import random
 import time
-from typing import Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Optional, Tuple
 import attr
-from netaddr import IPAddress  # type: ignore
+from netaddr import IPAddress
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, Response
+from twisted.internet.interfaces import (
+    IProtocol,
+    IProtocolFactory,
+    IReactorTime,
+    IStreamClientEndpoint,
+from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent
 from twisted.web.http import stringToDatetime
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer
+from twisted.web.iweb import (
+    IAgent,
+    IAgentEndpointFactory,
+    IBodyProducer,
+    IPolicyForHTTPS,
+    IResponse,
 from zope.interface import implementer
+from sydent.http.federation_tls_options import ClientTLSOptionsFactory
 from sydent.http.httpcommon import read_body_with_max_size
 from sydent.http.srvresolver import SrvResolver, pick_server_from_list
 from sydent.util import json_decoder
@@ -49,7 +61,7 @@ WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
 WELL_KNOWN_MAX_SIZE = 50 * 1024  # 50 KiB
 logger = logging.getLogger(__name__)
-well_known_cache = TTLCache("well-known")
+well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
@@ -59,32 +71,28 @@ class MatrixFederationAgent:
     Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
     :param reactor: twisted reactor to use for underlying requests
-    :type reactor: IReactor
     :param tls_client_options_factory: Factory to use for fetching client tls
         options, or none to disable TLS.
-    :type tls_client_options_factory: ClientTLSOptionsFactory, None
     :param _well_known_tls_policy: TLS policy to use for fetching .well-known
         files. None to use a default (browser-like) implementation.
-    :type _well_known_tls_policy: IPolicyForHTTPS, None
-    :param _srv_resolver: SRVResolver impl to use for looking up SRV records.
-        None to use a default implementation.
-    :type _srv_resolver: SrvResolver, None
     :param _well_known_cache: TTLCache impl for storing cached well-known
         lookups. Omit to use a default implementation.
-    :type _well_known_cache: TTLCache
     def __init__(
-        reactor,
-        tls_client_options_factory,
-        _well_known_tls_policy=None,
-        _srv_resolver: Optional["SrvResolver"] = None,
-        _well_known_cache: "TTLCache" = well_known_cache,
+        # This reactor should also be IReactorTCP and IReactorPluggableNameResolver
+        # because it eventually makes its way to HostnameEndpoint.__init__.
+        # But that's not easy to express with an annotation. We use the
+        # `seconds` attribute below, so mark this as IReactorTime for now.
+        reactor: IReactorTime,
+        tls_client_options_factory: Optional[ClientTLSOptionsFactory],
+        _well_known_tls_policy: Optional[IPolicyForHTTPS] = None,
+        _srv_resolver: Optional[SrvResolver] = None,
+        _well_known_cache: TTLCache[bytes, Optional[bytes]] = well_known_cache,
     ) -> None:
         self._reactor = reactor
@@ -98,15 +106,15 @@ class MatrixFederationAgent:
         self._pool.maxPersistentPerHost = 5
         self._pool.cachedConnectionTimeout = 2 * 60
-        agent_args = {}
         if _well_known_tls_policy is not None:
             # the param is called 'contextFactory', but actually passing a
             # contextfactory is deprecated, and it expects an IPolicyForHTTPS.
-            agent_args["contextFactory"] = _well_known_tls_policy
-        _well_known_agent = RedirectAgent(
-            Agent(self._reactor, pool=self._pool, **agent_args),
-        )
-        self._well_known_agent = _well_known_agent
+            _well_known_agent = Agent(
+                self._reactor, pool=self._pool, contextFactory=_well_known_tls_policy
+            )
+        else:
+            _well_known_agent = Agent(self._reactor, pool=self._pool)
+        self._well_known_agent = RedirectAgent(_well_known_agent)
         # our cache of .well-known lookup results, mapping from server name
         # to delegated name. The values can be:
@@ -121,7 +129,7 @@ class MatrixFederationAgent:
         uri: bytes,
         headers: Optional["Headers"] = None,
         bodyProducer: Optional["IBodyProducer"] = None,
-    ) -> Response:
+    ) -> Generator["defer.Deferred[Any]", Any, IResponse]:
         :param method: HTTP method (GET/POST/etc).
@@ -141,7 +149,8 @@ class MatrixFederationAgent:
             (including problems that prevent the request from being sent).
         parsed_uri = URI.fromBytes(uri, defaultPort=-1)
-        res = yield defer.ensureDeferred(self._route_matrix_uri(parsed_uri))
+        routing: _RoutingResult
+        routing = yield defer.ensureDeferred(self._route_matrix_uri(parsed_uri))
         # set up the TLS connection params
@@ -152,32 +161,37 @@ class MatrixFederationAgent:
             tls_options = None
             tls_options = self._tls_client_options_factory.get_options(
-                res.tls_server_name.decode("ascii")
+                routing.tls_server_name.decode("ascii")
         # make sure that the Host header is set correctly
         if headers is None:
             headers = Headers()
-            headers = headers.copy()
+            # Type safety: Headers.copy doesn't have a return type annotated,
+            # and I don't want to stub web.http_headers. Could use stubgen? It's
+            # a pretty simple file.
+            headers = headers.copy()  # type: ignore[no-untyped-call]
             assert headers is not None
         if not headers.hasHeader(b"host"):
-            headers.addRawHeader(b"host", res.host_header)
+            headers.addRawHeader(b"host", routing.host_header)
+        @implementer(IAgentEndpointFactory)
         class EndpointFactory:
-            def endpointForURI(_uri):
-                ep = LoggingHostnameEndpoint(
+            def endpointForURI(_uri: URI) -> IStreamClientEndpoint:
+                ep: IStreamClientEndpoint = LoggingHostnameEndpoint(
-                    res.target_host,
-                    res.target_port,
+                    routing.target_host,
+                    routing.target_port,
                 if tls_options is not None:
                     ep = wrapClientTLS(tls_options, ep)
                 return ep
         agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
+        res: IResponse
         res = yield agent.request(method, uri, headers, bodyProducer)
         return res
@@ -232,9 +246,11 @@ class MatrixFederationAgent:
                 # parse the server name in the .well-known response into host/port.
                 # (This code is lifted from twisted.web.client.URI.fromBytes).
                 if b":" in well_known_server:
-                    well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
+                    well_known_host, well_known_port_raw = well_known_server.rsplit(
+                        b":", 1
+                    )
-                        well_known_port = int(well_known_port)
+                        well_known_port = int(well_known_port_raw)
                     except ValueError:
                         # the part after the colon could not be parsed as an int
                         # - we assume it is an IPv6 literal with no port (the closing
@@ -308,7 +324,7 @@ class MatrixFederationAgent:
     async def _do_get_well_known(
         self, server_name: bytes
-    ) -> Tuple[Union[bytes, None, object], int]:
+    ) -> Tuple[Optional[bytes], float]:
         """Actually fetch and parse a .well-known, without checking the cache
         :param server_name: Name of the server, from the requested url
@@ -321,6 +337,7 @@ class MatrixFederationAgent:
         uri = b"https://%s/.well-known/matrix/server" % (server_name,)
         uri_str = uri.decode("ascii")
         logger.info("Fetching %s", uri_str)
+        cache_period: Optional[float]
             response = await self._well_known_agent.request(b"GET", uri)
             body = await read_body_with_max_size(response, WELL_KNOWN_MAX_SIZE)
@@ -338,7 +355,7 @@ class MatrixFederationAgent:
             # add some randomness to the TTL to avoid a stampeding herd every hour
             # after startup
-            cache_period: float = WELL_KNOWN_INVALID_CACHE_PERIOD
+            cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
             cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
             return (None, cache_period)
@@ -363,27 +380,33 @@ class MatrixFederationAgent:
 class LoggingHostnameEndpoint:
     """A wrapper for HostnameEndpoint which logs when it connects"""
-    def __init__(self, reactor, host, port, *args, **kwargs):
+    def __init__(
+        self, reactor: IReactorTime, host: bytes, port: int, *args: Any, **kwargs: Any
+    ):
         self.host = host
         self.port = port
         self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
         logger.info("Endpoint created with %s:%d", host, port)
-    def connect(self, protocol_factory):
+    def connect(
+        self, protocol_factory: IProtocolFactory
+    ) -> "defer.Deferred[IProtocol]":
         logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
         return self.ep.connect(protocol_factory)
-def _cache_period_from_headers(headers, time_now=time.time):
+def _cache_period_from_headers(
+    headers: Headers, time_now: Callable[[], float] = time.time
+) -> Optional[float]:
     cache_controls = _parse_cache_control(headers)
     if b"no-store" in cache_controls:
         return 0
-    if b"max-age" in cache_controls:
+    max_age = cache_controls.get(b"max-age")
+    if max_age is not None:
-            max_age = int(cache_controls[b"max-age"])
-            return max_age
+            return int(max_age)
         except ValueError:
@@ -401,8 +424,8 @@ def _cache_period_from_headers(headers, time_now=time.time):
     return None
-def _parse_cache_control(headers):
-    cache_controls = {}
+def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
+    cache_controls: Dict[bytes, Optional[bytes]] = {}
     for hdr in headers.getRawHeaders(b"cache-control", []):
         for directive in hdr.split(b","):
             splits = [x.strip() for x in directive.split(b"=", 1)]
@@ -412,7 +435,7 @@ def _parse_cache_control(headers):
     return cache_controls
+@attr.s(frozen=True, slots=True, auto_attribs=True)
 class _RoutingResult:
     """The result returned by `_route_matrix_uri`.
     Contains the parameters needed to direct a federation connection to a particular
@@ -421,30 +444,26 @@ class _RoutingResult:
     chosen from the list.
-    host_header = attr.ib()
+    host_header: bytes
     The value we should assign to the Host header (host:port from the matrix
     URI, or .well-known).
-    :type: bytes
-    tls_server_name = attr.ib()
+    tls_server_name: bytes
     The server name we should set in the SNI (typically host, without port, from the
     matrix URI or .well-known)
-    :type: bytes
-    target_host = attr.ib()
+    target_host: bytes
     The hostname (or IP literal) we should route the TCP connection to (the target of the
     SRV record, or the hostname from the URL/.well-known)
-    :type: bytes
-    target_port = attr.ib()
+    target_port: int
     The port we should route the TCP connection to (the target of the SRV record, or
     the port from the URL/.well-known, or 8448)
-    :type: int