Просмотр исходного кода

First pass at mypy for sydent.http (#439)

Get `sydent.http.{httpclient,httpsclient,httpcommon}` passing `mypy --strict`.

* Whoops, rename a bunch of .py stubs to .pyi
* Use `Response` instead of `IResponse`

  Rationale: I want to annotate one of these, so that we know the type of
  headers (which would allow us to detect the problem fixed by #415). if I
  write a stub for `IResponse` I'll have to write a stub for the entire
  `iweb` interface module, which seems like a PITA.

  Also: AFAICS `Response` is the only implementation of `IResponse`. So I
  feel a bit naughty, but not too naughty.

* Recognise that we're using a TCP transport
  
  otherwise we can't call `abortConnection`.

* Stub `Failure.check`
* Type safety for `response.length`
David Robertson 2 лет назад
Родитель
Сommit
3f6a0ac794

+ 1 - 0
changelog.d/439.misc

@@ -0,0 +1 @@
+Make `sydent.http.{httpclient, httpsclient, httpcommon}` pass `mypy --strict`.

+ 3 - 0
pyproject.toml

@@ -54,6 +54,9 @@ files = [
     "sydent/http/auth.py",
     "sydent/http/blacklisting_reactor.py",
     "sydent/http/federation_tls_options.py",
+    "sydent/http/httpclient.py",
+    "sydent/http/httpcommon.py",
+    "sydent/http/httpsclient.py",
     "sydent/http/srvresolver.py",
     "sydent/hs_federation",
     "sydent/replication",

+ 0 - 16
stubs/twisted/internet/ssl.py

@@ -1,16 +0,0 @@
-from typing import Optional, Any
-
-import OpenSSL.SSL
-from twisted.internet._sslverify import IOpenSSLTrustRoot
-
-
-def platformTrust() -> IOpenSSLTrustRoot:
-    ...
-
-
-class CertificateOptions:
-    def __init__(self, trustRoot: Optional[IOpenSSLTrustRoot] = None, **kwargs: Any):
-        ...
-
-    def _makeContext(self) -> OpenSSL.SSL.Context:
-        ...

+ 44 - 0
stubs/twisted/internet/ssl.pyi

@@ -0,0 +1,44 @@
+from typing import Optional, Any, List, Dict, AnyStr, TypeVar, Type
+
+import OpenSSL.SSL
+
+# I don't like importing from _sslverify, but IOpenSSLTrustRoot isn't re-exported
+# anywhere else in twisted.
+from twisted.internet._sslverify import IOpenSSLTrustRoot
+from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
+
+from zope.interface import implementer
+
+C = TypeVar("C")
+
+class Certificate:
+    original: OpenSSL.crypto.X509
+    @classmethod
+    def loadPEM(cls: Type[C], data: AnyStr) -> C: ...
+
+def platformTrust() -> IOpenSSLTrustRoot: ...
+
+class PrivateCertificate(Certificate): ...
+
+class CertificateOptions:
+    def __init__(
+        self, trustRoot: Optional[IOpenSSLTrustRoot] = None, **kwargs: Any
+    ): ...
+    def _makeContext(self) -> OpenSSL.SSL.Context: ...
+
+def optionsForClientTLS(
+    hostname: str,
+    trustRoot: Optional[IOpenSSLTrustRoot] = None,
+    clientCertificate: Optional[PrivateCertificate] = None,
+    acceptableProtocols: Optional[List[bytes]] = None,
+    *,
+    # Shouldn't use extraCertificateOptions:
+    # "any time you need to pass an option here that is a bug in this interface."
+    extraCertificateOptions: Optional[Dict[Any, Any]] = None,
+) -> IOpenSSLClientConnectionCreator: ...
+
+
+# Type safety: I don't want to respecify the methods on the interface that we
+# don't use.
+@implementer(IOpenSSLTrustRoot)  # type: ignore[misc]
+class OpenSSLDefaultPaths: ...

+ 0 - 0
stubs/twisted/names/__init__.py → stubs/twisted/names/__init__.pyi


+ 8 - 1
stubs/twisted/python/failure.pyi

@@ -1,5 +1,7 @@
 from types import TracebackType
-from typing import Type, Optional
+from typing import Type, Optional, Union, TypeVar, overload
+
+E = TypeVar("E")
 
 
 class Failure(BaseException):
@@ -12,3 +14,8 @@ class Failure(BaseException):
         captureVars: bool = False,
     ):
         ...
+
+    @overload
+    def check(self, singleErrorType: Type[E]) -> Optional[E]: ...
+    @overload
+    def check(self, *errorTypes: Union[str, Type[Exception]]) -> Optional[Exception]: ...

+ 70 - 1
stubs/twisted/web/client.pyi

@@ -1,4 +1,73 @@
+from typing import BinaryIO, Any, Optional
+
+import twisted.internet
 from twisted.internet.defer import Deferred
-from twisted.web.iweb import IResponse
+from twisted.internet.interfaces import (
+    IOpenSSLClientConnectionCreator,
+    IConsumer,
+    IProtocol,
+)
+from twisted.internet.task import Cooperator
+from twisted.web.http_headers import Headers
+
+from twisted.web.iweb import IResponse, IAgent, IBodyProducer, IPolicyForHTTPS
+from zope.interface import implementer
+
+@implementer(IPolicyForHTTPS)
+class BrowserLikePolicyForHTTPS:
+    def creatorForNetloc(
+        self, hostname: bytes, port: int
+    ) -> IOpenSSLClientConnectionCreator: ...
+
+class HTTPConnectionPool: ...
+
+@implementer(IAgent)
+class Agent:
+    def __init__(
+        self,
+        reactor: Any,
+        contextFactory: IPolicyForHTTPS = BrowserLikePolicyForHTTPS(),
+        connectTimeout: Optional[float] = None,
+        bindAddress: Optional[bytes] = None,
+        pool: Optional[HTTPConnectionPool] = None,
+    ): ...
+    # Type safety: IAgent says this returns a Deferred[IResponse].
+    # I'm narrowing it here (but that's not strictly allowed because Deferred[T]
+    # is _contra_variant in T. It's all muddling.
+    def request(  # type: ignore[override]
+        self,
+        method: bytes,
+        uri: bytes,
+        headers: Optional[Headers] = None,
+        bodyProducer: Optional[IBodyProducer] = None,
+    ) -> Deferred[Response]: ...
+
+@implementer(IBodyProducer)
+class FileBodyProducer:
+    def __init__(
+        self,
+        inputFile: BinaryIO,
+        # Type safety: twisted.internet.task.cooperate is a function with the
+        # same signature as Cooperator.cooperate. (It just wraps a module-level
+        # global cooperator.) But there's no easy way to annotate "either this
+        # type or a specific module".
+        cooperator: Cooperator = twisted.internet.task,  # type: ignore[assignment]
+        readSize: int = 2 ** 16,
+    ): ...
+    # Length is either `int` or the opaque object UNKNOWN_LENGTH.
+    length: int | object
+    def startProducing(self, consumer: IConsumer) -> Deferred[None]: ...
+    def stopProducing(self) -> None: ...
+    def pauseProducing(self) -> None: ...
+    def resumeProducing(self) -> None: ...
 
 def readBody(response: IResponse) -> Deferred[bytes]: ...
+
+class Response:
+    code: int
+    headers: Headers
+    # Length is either `int` or the opaque object UNKNOWN_LENGTH.
+    length: int | object
+    def deliverBody(self, protocol: IProtocol) -> None: ...
+
+class ResponseDone: ...

+ 0 - 23
stubs/twisted/web/http.py

@@ -1,23 +0,0 @@
-import typing
-from typing import AnyStr, Optional, Dict, List
-
-from twisted.internet.defer import Deferred
-from twisted.logger import Logger
-from twisted.web.http_headers import Headers
-
-
-class Request:
-
-    method: bytes
-    uri: bytes
-    path: bytes
-    args: Dict[bytes, List[bytes]]
-    content: typing.BinaryIO
-    cookies: List[bytes]
-    requestHeaders: Headers
-    responseHeaders: Headers
-    notifications: List[Deferred[None]]
-    _disconnected: bool
-    _log: Logger
-
-    def getHeader(self, key: AnyStr) -> Optional[AnyStr]: ...

+ 48 - 0
stubs/twisted/web/http.pyi

@@ -0,0 +1,48 @@
+import typing
+from typing import AnyStr, Optional, Dict, List
+
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import ITCPTransport, IAddress
+from twisted.logger import Logger
+from twisted.web.http_headers import Headers
+
+class HTTPChannel: ...
+
+class Request:
+    # Instance attributes mentioned in the docstring
+    method: bytes
+    uri: bytes
+    path: bytes
+    args: Dict[bytes, List[bytes]]
+    content: typing.BinaryIO
+    cookies: List[bytes]
+    requestHeaders: Headers
+    responseHeaders: Headers
+    notifications: List[Deferred[None]]
+    _disconnected: bool
+    _log: Logger
+
+    # Other instance attributes set in __init__
+    channel: HTTPChannel
+    client: IAddress
+    # This was hard to derive.
+    # - `transport` is `self.channel.transport`
+    # - `self.channel` is set in the constructor, and looks like it's always
+    #   an `HTTPChannel`.
+    # - `HTTPChannel` is a `LineReceiver` is a `Protocol` is a `BaseProtocol`.
+    # - `BaseProtocol` sets `self.transport` to initially `None`.
+    #
+    # Note that `transport` is set to an ITransport in makeConnection,
+    # so is almost certainly not None by the time it reaches our code.
+    #
+    # I've narrowed this to ITCPTransport because
+    # - we use `self.transport.abortConnection`, which belongs to that interface
+    # - twisted does too! in its implementation of HTTPChannel.forceAbortClient
+    transport: Optional[ITCPTransport]
+
+
+    def getHeader(self, key: AnyStr) -> Optional[AnyStr]: ...
+
+    def handleContentChunk(self, data: bytes) -> None: ...
+
+class PotentialDataLoss(Exception): ...

+ 18 - 12
sydent/http/httpclient.py

@@ -15,11 +15,10 @@
 import json
 import logging
 from io import BytesIO
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Tuple, TypeVar, cast
 
-from twisted.web.client import Agent, FileBodyProducer
+from twisted.web.client import Agent, FileBodyProducer, Response
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IResponse
 
 from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper
 from sydent.http.federation_tls_options import ClientTLSOptionsFactory
@@ -34,12 +33,15 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class HTTPClient:
+AgentType = TypeVar("AgentType", Agent, MatrixFederationAgent)
+
+
+class HTTPClient(Generic[AgentType]):
     """A base HTTP class that contains methods for making GET and POST HTTP
     requests.
     """
 
-    agent: IAgent
+    agent: AgentType
 
     async def get_json(self, uri: str, max_size: Optional[int] = None) -> JsonDict:
         """Make a GET request to an endpoint returning JSON and parse result
@@ -52,7 +54,7 @@ class HTTPClient:
         """
         logger.debug("HTTP GET %s", uri)
 
-        response = await self.agent.request(
+        response: Response = await self.agent.request(
             b"GET",
             uri.encode("utf8"),
         )
@@ -63,11 +65,15 @@ class HTTPClient:
         except Exception:
             logger.exception("Error parsing JSON from %s", uri)
             raise
-        return json_body
+        if not isinstance(json_body, dict):
+            raise TypeError
+        # Cast safety: json only permits strings as object keys, so `json_body`
+        # must be Dict[str, Any] rather than Dict[Any, Any].
+        return cast(JsonDict, json_body)
 
     async def post_json_get_nothing(
         self, uri: str, post_json: JsonDict, opts: Dict[str, Any]
-    ) -> IResponse:
+    ) -> Response:
         """Make a POST request to an endpoint returning nothing
 
         :param uri: The URI to make a POST request to.
@@ -89,7 +95,7 @@ class HTTPClient:
         post_json: Dict[str, Any],
         opts: Dict[str, Any],
         max_size: Optional[int] = None,
-    ) -> Tuple[IResponse, Optional[JsonDict]]:
+    ) -> Tuple[Response, Optional[JsonDict]]:
         """Make a POST request to an endpoint that might be returning JSON and parse
         result
 
@@ -119,7 +125,7 @@ class HTTPClient:
 
         logger.debug("HTTP POST %s -> %s", json_bytes, uri)
 
-        response = await self.agent.request(
+        response: Response = await self.agent.request(
             b"POST",
             uri.encode("utf8"),
             headers,
@@ -142,7 +148,7 @@ class HTTPClient:
         return response, json_body
 
 
-class SimpleHttpClient(HTTPClient):
+class SimpleHttpClient(HTTPClient[Agent]):
     """A simple, no-frills HTTP client based on the class of the same name
     from Synapse.
     """
@@ -162,7 +168,7 @@ class SimpleHttpClient(HTTPClient):
         )
 
 
-class FederationHttpClient(HTTPClient):
+class FederationHttpClient(HTTPClient[MatrixFederationAgent]):
     """HTTP client for federation requests to homeservers. Uses a
     MatrixFederationAgent.
     """

+ 34 - 16
sydent/http/httpcommon.py

@@ -14,13 +14,16 @@
 
 import logging
 from io import BytesIO
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, cast
 
 import twisted.internet.ssl
 from twisted.internet import defer, protocol
+from twisted.internet._sslverify import IOpenSSLTrustRoot
+from twisted.internet.interfaces import ITCPTransport
 from twisted.internet.protocol import connectionDone
+from twisted.python.failure import Failure
 from twisted.web import server
-from twisted.web._newclient import ResponseDone
+from twisted.web.client import Response, ResponseDone
 from twisted.web.http import PotentialDataLoss
 from twisted.web.iweb import UNKNOWN_LENGTH
 
@@ -41,7 +44,7 @@ class SslComponents:
         self.myPrivateCertificate = self.makeMyCertificate()
         self.trustRoot = self.makeTrustRoot()
 
-    def makeMyCertificate(self):
+    def makeMyCertificate(self) -> Optional[twisted.internet.ssl.PrivateCertificate]:
         # TODO Move some of this loading into parse_config
         privKeyAndCertFilename = self.sydent.config.http.cert_file
 
@@ -66,7 +69,7 @@ class SslComponents:
         fp.close()
         return twisted.internet.ssl.PrivateCertificate.loadPEM(authData)
 
-    def makeTrustRoot(self):
+    def makeTrustRoot(self) -> IOpenSSLTrustRoot:
         # If this option is specified, use a specific root CA cert. This is useful for testing when it's not
         # practical to get the client cert signed by a real root CA but should never be used on a production server.
         caCertFilename = self.sydent.config.http.ca_cert_file
@@ -79,7 +82,9 @@ class SslComponents:
                 logger.warning("Failed to open CA cert file %s", caCertFilename)
                 raise
             logger.warning("Using custom CA cert file: %s", caCertFilename)
-            return twisted.internet._sslverify.OpenSSLCertificateAuthorities(
+            # Type ignore: I'm not going to add a stub for the semiprivate
+            # _sslverify module. I've already taken on too much stubbing as it is!
+            return twisted.internet._sslverify.OpenSSLCertificateAuthorities(  # type: ignore
                 [caCert.original]
             )
         else:
@@ -93,10 +98,12 @@ class BodyExceededMaxSize(Exception):
 class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
     """A protocol which immediately errors upon receiving data."""
 
-    def __init__(self, deferred):
+    transport: ITCPTransport
+
+    def __init__(self, deferred: "defer.Deferred[bytes]") -> None:
         self.deferred = deferred
 
-    def _maybe_fail(self):
+    def _maybe_fail(self) -> None:
         """
         Report a max size exceed error and disconnect the first time this is called.
         """
@@ -106,23 +113,27 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
             # discarded anyway.
             self.transport.abortConnection()
 
-    def dataReceived(self, data) -> None:
+    def dataReceived(self, data: bytes) -> None:
         self._maybe_fail()
 
-    def connectionLost(self, reason) -> None:
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
         self._maybe_fail()
 
 
 class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
     """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
 
-    def __init__(self, deferred, max_size):
+    transport: ITCPTransport
+
+    def __init__(
+        self, deferred: "defer.Deferred[bytes]", max_size: Optional[int]
+    ) -> None:
         self.stream = BytesIO()
         self.deferred = deferred
         self.length = 0
         self.max_size = max_size
 
-    def dataReceived(self, data) -> None:
+    def dataReceived(self, data: bytes) -> None:
         # If the deferred was called, bail early.
         if self.deferred.called:
             return
@@ -139,7 +150,7 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
             if self.transport is not None:
                 self.transport.abortConnection()
 
-    def connectionLost(self, reason=connectionDone) -> None:
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
         # If the maximum size was already exceeded, there's nothing to do.
         if self.deferred.called:
             return
@@ -154,7 +165,9 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
             self.deferred.errback(reason)
 
 
-def read_body_with_max_size(response, max_size):
+def read_body_with_max_size(
+    response: Response, max_size: Optional[int]
+) -> "defer.Deferred[bytes]":
     """
     Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
 
@@ -168,11 +181,14 @@ def read_body_with_max_size(response, max_size):
     Returns:
         A Deferred which resolves to the read body.
     """
-    d = defer.Deferred()
+    d: "defer.Deferred[bytes]" = defer.Deferred()
 
     # If the Content-Length header gives a size larger than the maximum allowed
     # size, do not bother downloading the body.
+    # Type safety: twisted guarantees that response.length is either the
+    # "opaque" object UNKNOWN_LENGTH, or else an int.
     if max_size is not None and response.length != UNKNOWN_LENGTH:
+        response.length = cast(int, response.length)
         if response.length > max_size:
             response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
             return d
@@ -182,12 +198,14 @@ def read_body_with_max_size(response, max_size):
 
 
 class SizeLimitingRequest(server.Request):
-    def handleContentChunk(self, data):
+    def handleContentChunk(self, data: bytes) -> None:
         if self.content.tell() + len(data) > MAX_REQUEST_SIZE:
             logger.info(
                 "Aborting connection from %s because the request exceeds maximum size",
-                self.client.host,
+                # Formerly `self.client.host`, but `host` isn't provided by `IAddress`
+                self.client,
             )
+            assert self.transport is not None
             self.transport.abortConnection()
             return
 

+ 8 - 6
sydent/http/httpsclient.py

@@ -18,10 +18,11 @@ from io import BytesIO
 from typing import TYPE_CHECKING, Optional
 
 from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
 from twisted.internet.ssl import optionsForClientTLS
-from twisted.web.client import Agent, FileBodyProducer
+from twisted.web.client import Agent, FileBodyProducer, Response
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IPolicyForHTTPS, IResponse
+from twisted.web.iweb import IPolicyForHTTPS
 from zope.interface import implementer
 
 from sydent.types import JsonDict
@@ -41,7 +42,7 @@ class ReplicationHttpsClient:
 
     def __init__(self, sydent: "Sydent") -> None:
         self.sydent = sydent
-        self.agent = None
+        self.agent: Optional[Agent] = None
 
         if self.sydent.sslComponents.myPrivateCertificate:
             # We will already have logged a warn if this is absent, so don't do it again
@@ -53,7 +54,7 @@ class ReplicationHttpsClient:
 
     def postJson(
         self, uri: str, jsonObject: JsonDict
-    ) -> Optional["Deferred[IResponse]"]:
+    ) -> Optional["Deferred[Response]"]:
         """
         Sends an POST request over HTTPS.
 
@@ -61,7 +62,6 @@ class ReplicationHttpsClient:
         :param jsonObject: The request's body.
 
         :return: The request's response.
-        :rtype: twisted.internet.defer.Deferred[twisted.web.iweb.IResponse]
         """
         logger.debug("POSTing request to %s", uri)
         if not self.agent:
@@ -85,7 +85,9 @@ class SydentPolicyForHTTPS:
     def __init__(self, sydent: "Sydent") -> None:
         self.sydent = sydent
 
-    def creatorForNetloc(self, hostname, port):
+    def creatorForNetloc(
+        self, hostname: bytes, port: int
+    ) -> IOpenSSLClientConnectionCreator:
         return optionsForClientTLS(
             hostname.decode("ascii"),
             trustRoot=self.sydent.sslComponents.trustRoot,

+ 3 - 3
sydent/http/matrixfederationagent.py

@@ -22,10 +22,10 @@ from netaddr import IPAddress  # type: ignore
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent
+from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, Response
 from twisted.web.http import stringToDatetime
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import IAgent, IBodyProducer
 from zope.interface import implementer
 
 from sydent.http.httpcommon import read_body_with_max_size
@@ -121,7 +121,7 @@ class MatrixFederationAgent:
         uri: bytes,
         headers: Optional["Headers"] = None,
         bodyProducer: Optional["IBodyProducer"] = None,
-    ) -> IResponse:
+    ) -> Response:
         """
         :param method: HTTP method (GET/POST/etc).
 

+ 1 - 1
sydent/sydent.py

@@ -178,7 +178,7 @@ class Sydent:
 
         self.threepidBinder = ThreepidBinder(self)
 
-        self.sslComponents = SslComponents(self)
+        self.sslComponents: SslComponents = SslComponents(self)
 
         self.clientApiHttpServer = ClientApiHttpServer(self)
         self.replicationHttpsServer = ReplicationHttpsServer(self)