|
@@ -14,13 +14,16 @@
|
|
|
|
|
|
import logging
|
|
|
from io import BytesIO
|
|
|
-from typing import TYPE_CHECKING
|
|
|
+from typing import TYPE_CHECKING, Optional, cast
|
|
|
|
|
|
import twisted.internet.ssl
|
|
|
from twisted.internet import defer, protocol
|
|
|
+from twisted.internet._sslverify import IOpenSSLTrustRoot
|
|
|
+from twisted.internet.interfaces import ITCPTransport
|
|
|
from twisted.internet.protocol import connectionDone
|
|
|
+from twisted.python.failure import Failure
|
|
|
from twisted.web import server
|
|
|
-from twisted.web._newclient import ResponseDone
|
|
|
+from twisted.web.client import Response, ResponseDone
|
|
|
from twisted.web.http import PotentialDataLoss
|
|
|
from twisted.web.iweb import UNKNOWN_LENGTH
|
|
|
|
|
@@ -41,7 +44,7 @@ class SslComponents:
|
|
|
self.myPrivateCertificate = self.makeMyCertificate()
|
|
|
self.trustRoot = self.makeTrustRoot()
|
|
|
|
|
|
- def makeMyCertificate(self):
|
|
|
+ def makeMyCertificate(self) -> Optional[twisted.internet.ssl.PrivateCertificate]:
|
|
|
# TODO Move some of this loading into parse_config
|
|
|
privKeyAndCertFilename = self.sydent.config.http.cert_file
|
|
|
|
|
@@ -66,7 +69,7 @@ class SslComponents:
|
|
|
fp.close()
|
|
|
return twisted.internet.ssl.PrivateCertificate.loadPEM(authData)
|
|
|
|
|
|
- def makeTrustRoot(self):
|
|
|
+ def makeTrustRoot(self) -> IOpenSSLTrustRoot:
|
|
|
# If this option is specified, use a specific root CA cert. This is useful for testing when it's not
|
|
|
# practical to get the client cert signed by a real root CA but should never be used on a production server.
|
|
|
caCertFilename = self.sydent.config.http.ca_cert_file
|
|
@@ -79,7 +82,9 @@ class SslComponents:
|
|
|
logger.warning("Failed to open CA cert file %s", caCertFilename)
|
|
|
raise
|
|
|
logger.warning("Using custom CA cert file: %s", caCertFilename)
|
|
|
- return twisted.internet._sslverify.OpenSSLCertificateAuthorities(
|
|
|
+ # Type ignore: I'm not going to add a stub for the semiprivate
|
|
|
+ # _sslverify module. I've already taken on too much stubbing as it is!
|
|
|
+ return twisted.internet._sslverify.OpenSSLCertificateAuthorities( # type: ignore
|
|
|
[caCert.original]
|
|
|
)
|
|
|
else:
|
|
@@ -93,10 +98,12 @@ class BodyExceededMaxSize(Exception):
|
|
|
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
|
|
"""A protocol which immediately errors upon receiving data."""
|
|
|
|
|
|
- def __init__(self, deferred):
|
|
|
+ transport: ITCPTransport
|
|
|
+
|
|
|
+ def __init__(self, deferred: "defer.Deferred[bytes]") -> None:
|
|
|
self.deferred = deferred
|
|
|
|
|
|
- def _maybe_fail(self):
|
|
|
+ def _maybe_fail(self) -> None:
|
|
|
"""
|
|
|
Report a max size exceed error and disconnect the first time this is called.
|
|
|
"""
|
|
@@ -106,23 +113,27 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
|
|
# discarded anyway.
|
|
|
self.transport.abortConnection()
|
|
|
|
|
|
- def dataReceived(self, data) -> None:
|
|
|
+ def dataReceived(self, data: bytes) -> None:
|
|
|
self._maybe_fail()
|
|
|
|
|
|
- def connectionLost(self, reason) -> None:
|
|
|
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
|
|
|
self._maybe_fail()
|
|
|
|
|
|
|
|
|
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
|
|
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
|
|
|
|
|
|
- def __init__(self, deferred, max_size):
|
|
|
+ transport: ITCPTransport
|
|
|
+
|
|
|
+ def __init__(
|
|
|
+ self, deferred: "defer.Deferred[bytes]", max_size: Optional[int]
|
|
|
+ ) -> None:
|
|
|
self.stream = BytesIO()
|
|
|
self.deferred = deferred
|
|
|
self.length = 0
|
|
|
self.max_size = max_size
|
|
|
|
|
|
- def dataReceived(self, data) -> None:
|
|
|
+ def dataReceived(self, data: bytes) -> None:
|
|
|
# If the deferred was called, bail early.
|
|
|
if self.deferred.called:
|
|
|
return
|
|
@@ -139,7 +150,7 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
|
|
if self.transport is not None:
|
|
|
self.transport.abortConnection()
|
|
|
|
|
|
- def connectionLost(self, reason=connectionDone) -> None:
|
|
|
+ def connectionLost(self, reason: Failure = connectionDone) -> None:
|
|
|
# If the maximum size was already exceeded, there's nothing to do.
|
|
|
if self.deferred.called:
|
|
|
return
|
|
@@ -154,7 +165,9 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
|
|
self.deferred.errback(reason)
|
|
|
|
|
|
|
|
|
-def read_body_with_max_size(response, max_size):
|
|
|
+def read_body_with_max_size(
|
|
|
+ response: Response, max_size: Optional[int]
|
|
|
+) -> "defer.Deferred[bytes]":
|
|
|
"""
|
|
|
Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
|
|
|
|
|
@@ -168,11 +181,14 @@ def read_body_with_max_size(response, max_size):
|
|
|
Returns:
|
|
|
A Deferred which resolves to the read body.
|
|
|
"""
|
|
|
- d = defer.Deferred()
|
|
|
+ d: "defer.Deferred[bytes]" = defer.Deferred()
|
|
|
|
|
|
# If the Content-Length header gives a size larger than the maximum allowed
|
|
|
# size, do not bother downloading the body.
|
|
|
+ # Type safety: twisted guarantees that response.length is either the
|
|
|
+ # "opaque" object UNKNOWN_LENGTH, or else an int.
|
|
|
if max_size is not None and response.length != UNKNOWN_LENGTH:
|
|
|
+ response.length = cast(int, response.length)
|
|
|
if response.length > max_size:
|
|
|
response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
|
|
|
return d
|
|
@@ -182,12 +198,14 @@ def read_body_with_max_size(response, max_size):
|
|
|
|
|
|
|
|
|
class SizeLimitingRequest(server.Request):
|
|
|
- def handleContentChunk(self, data):
|
|
|
+ def handleContentChunk(self, data: bytes) -> None:
|
|
|
if self.content.tell() + len(data) > MAX_REQUEST_SIZE:
|
|
|
logger.info(
|
|
|
"Aborting connection from %s because the request exceeds maximum size",
|
|
|
- self.client.host,
|
|
|
+ # Formerly `self.client.host`, but `host` isn't provided by `IAddress`
|
|
|
+ self.client,
|
|
|
)
|
|
|
+ assert self.transport is not None
|
|
|
self.transport.abortConnection()
|
|
|
return
|
|
|
|