فهرست منبع

Mypy for httpserver (#442)

* Make Servlets more like a proper class, and not a dictionary with attribute syntax
David Robertson 2 سال پیش
والد
کامیت
03f23b22a8

+ 1 - 0
changelog.d/442.misc

@@ -0,0 +1 @@
+Make `sydent.http.httpserver` pass `mypy --strict`.

+ 1 - 0
pyproject.toml

@@ -57,6 +57,7 @@ files = [
     "sydent/http/httpclient.py",
     "sydent/http/httpcommon.py",
     "sydent/http/httpsclient.py",
+    "sydent/http/httpserver.py",
     "sydent/http/srvresolver.py",
     "sydent/hs_federation",
     "sydent/replication",

+ 4 - 2
stubs/twisted/internet/ssl.pyi

@@ -4,7 +4,7 @@ import OpenSSL.SSL
 
 # I don't like importing from _sslverify, but IOpenSSLTrustRoot isn't re-exported
 # anywhere else in twisted.
-from twisted.internet._sslverify import IOpenSSLTrustRoot
+from twisted.internet._sslverify import IOpenSSLTrustRoot, KeyPair
 from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
 from zope.interface import implementer
 
@@ -17,7 +17,9 @@ class Certificate:
 
 def platformTrust() -> IOpenSSLTrustRoot: ...
 
-class PrivateCertificate(Certificate): ...
+class PrivateCertificate(Certificate):
+    # PrivateKey is not set until you call _setPrivateKey, e.g. via load()
+    privateKey: KeyPair
 
 class CertificateOptions:
     def __init__(

+ 1 - 0
stubs/twisted/web/http.pyi

@@ -44,6 +44,7 @@ class Request:
     # - we use `self.transport.abortConnection`, which belongs to that interface
     # - twisted does too! in its implementation of HTTPChannel.forceAbortClient
     transport: Optional[ITCPTransport]
+    def __init__(self, channel: HTTPChannel): ...
     def getHeader(self, key: AnyStr) -> Optional[AnyStr]: ...
     def handleContentChunk(self, data: bytes) -> None: ...
 

+ 13 - 0
stubs/twisted/web/resource.pyi

@@ -0,0 +1,13 @@
+from typing import ClassVar
+
+from zope.interface import Interface, implementer
+
+class IResource(Interface):
+    isLeaf: ClassVar[bool]
+    def __init__() -> None: ...
+    def putChild(path: bytes, child: IResource) -> None: ...
+
+@implementer(IResource)
+class Resource:
+    isLeaf: ClassVar[bool]
+    def putChild(self, path: bytes, child: IResource) -> None: ...

+ 22 - 0
stubs/twisted/web/server.pyi

@@ -0,0 +1,22 @@
+from typing import Callable, Optional, Type, Union
+
+from twisted.web import http
+from twisted.web.resource import IResource
+
+class Request(http.Request): ...
+
+# A requestFactory is allowed to be "[a] factory which is called with (channel)
+# and creates L{Request} instances.".
+RequestFactory = Callable[[http.HTTPChannel], Request]
+
+# should really inherit from http.HTTPFactory
+class Site:
+    displayTracebacks: bool
+    def __init__(
+        self,
+        resource: IResource,
+        requestFactory: Optional[RequestFactory] = None,
+        # Args and kwargs get passed to http.HTTPFactory. But we don't use them.
+        *args: object,
+        **kwargs: object,
+    ): ...

+ 4 - 5
sydent/http/httpserver.py

@@ -134,11 +134,10 @@ class ClientApiHttpServer:
         v2.putChild(b"lookup", self.sydent.servlets.lookup_v2)
         v2.putChild(b"hash_details", self.sydent.servlets.hash_details)
 
-        self.factory = Site(root)
-        self.factory.requestFactory = SizeLimitingRequest
+        self.factory = Site(root, SizeLimitingRequest)
         self.factory.displayTracebacks = False
 
-    def setup(self):
+    def setup(self) -> None:
         httpPort = self.sydent.config.http.client_port
         interface = self.sydent.config.http.client_bind_address
 
@@ -154,7 +153,7 @@ class InternalApiHttpServer:
     def __init__(self, sydent: "Sydent") -> None:
         self.sydent = sydent
 
-    def setup(self, interface, port):
+    def setup(self, interface: str, port: int) -> None:
         logger.info("Starting Internal API HTTP server on %s:%d", interface, port)
         root = Resource()
 
@@ -199,7 +198,7 @@ class ReplicationHttpsServer:
         self.factory = Site(root)
         self.factory.displayTracebacks = False
 
-    def setup(self):
+    def setup(self) -> None:
         httpPort = self.sydent.config.http.replication_port
         interface = self.sydent.config.http.replication_bind_address
 

+ 35 - 45
sydent/sydent.py

@@ -131,50 +131,7 @@ class Sydent:
 
         self.sig_verifier = Verifier(self)
 
-        self.servlets = Servlets()
-        self.servlets.v1 = V1Servlet(self)
-        self.servlets.v2 = V2Servlet(self)
-        self.servlets.emailRequestCode = EmailRequestCodeServlet(self)
-        self.servlets.emailRequestCodeV2 = EmailRequestCodeServlet(
-            self, require_auth=True
-        )
-        self.servlets.emailValidate = EmailValidateCodeServlet(self)
-        self.servlets.emailValidateV2 = EmailValidateCodeServlet(
-            self, require_auth=True
-        )
-        self.servlets.msisdnRequestCode = MsisdnRequestCodeServlet(self)
-        self.servlets.msisdnRequestCodeV2 = MsisdnRequestCodeServlet(
-            self, require_auth=True
-        )
-        self.servlets.msisdnValidate = MsisdnValidateCodeServlet(self)
-        self.servlets.msisdnValidateV2 = MsisdnValidateCodeServlet(
-            self, require_auth=True
-        )
-        self.servlets.lookup = LookupServlet(self)
-        self.servlets.bulk_lookup = BulkLookupServlet(self)
-        self.servlets.hash_details = HashDetailsServlet(self, lookup_pepper)
-        self.servlets.lookup_v2 = LookupV2Servlet(self, lookup_pepper)
-        self.servlets.pubkey_ed25519 = Ed25519Servlet(self)
-        self.servlets.pubkeyIsValid = PubkeyIsValidServlet(self)
-        self.servlets.ephemeralPubkeyIsValid = EphemeralPubkeyIsValidServlet(self)
-        self.servlets.threepidBind = ThreePidBindServlet(self)
-        self.servlets.threepidBindV2 = ThreePidBindServlet(self, require_auth=True)
-        self.servlets.threepidUnbind = ThreePidUnbindServlet(self)
-        self.servlets.replicationPush = ReplicationPushServlet(self)
-        self.servlets.getValidated3pid = GetValidated3pidServlet(self)
-        self.servlets.getValidated3pidV2 = GetValidated3pidServlet(
-            self, require_auth=True
-        )
-        self.servlets.storeInviteServlet = StoreInviteServlet(self)
-        self.servlets.storeInviteServletV2 = StoreInviteServlet(self, require_auth=True)
-        self.servlets.blindlySignStuffServlet = BlindlySignStuffServlet(self)
-        self.servlets.blindlySignStuffServletV2 = BlindlySignStuffServlet(
-            self, require_auth=True
-        )
-        self.servlets.termsServlet = TermsServlet(self)
-        self.servlets.accountServlet = AccountServlet(self)
-        self.servlets.registerServlet = RegisterServlet(self)
-        self.servlets.logoutServlet = LogoutServlet(self)
+        self.servlets: Servlets = Servlets(self, lookup_pepper)
 
         self.threepidBinder = ThreepidBinder(self)
 
@@ -292,7 +249,40 @@ class Validators:
 
 
 class Servlets:
-    pass
+    def __init__(self, sydent: Sydent, lookup_pepper: str):
+        self.v1 = V1Servlet(sydent)
+        self.v2 = V2Servlet(sydent)
+        self.emailRequestCode = EmailRequestCodeServlet(sydent)
+        self.emailRequestCodeV2 = EmailRequestCodeServlet(sydent, require_auth=True)
+        self.emailValidate = EmailValidateCodeServlet(sydent)
+        self.emailValidateV2 = EmailValidateCodeServlet(sydent, require_auth=True)
+        self.msisdnRequestCode = MsisdnRequestCodeServlet(sydent)
+        self.msisdnRequestCodeV2 = MsisdnRequestCodeServlet(sydent, require_auth=True)
+        self.msisdnValidate = MsisdnValidateCodeServlet(sydent)
+        self.msisdnValidateV2 = MsisdnValidateCodeServlet(sydent, require_auth=True)
+        self.lookup = LookupServlet(sydent)
+        self.bulk_lookup = BulkLookupServlet(sydent)
+        self.hash_details = HashDetailsServlet(sydent, lookup_pepper)
+        self.lookup_v2 = LookupV2Servlet(sydent, lookup_pepper)
+        self.pubkey_ed25519 = Ed25519Servlet(sydent)
+        self.pubkeyIsValid = PubkeyIsValidServlet(sydent)
+        self.ephemeralPubkeyIsValid = EphemeralPubkeyIsValidServlet(sydent)
+        self.threepidBind = ThreePidBindServlet(sydent)
+        self.threepidBindV2 = ThreePidBindServlet(sydent, require_auth=True)
+        self.threepidUnbind = ThreePidUnbindServlet(sydent)
+        self.replicationPush = ReplicationPushServlet(sydent)
+        self.getValidated3pid = GetValidated3pidServlet(sydent)
+        self.getValidated3pidV2 = GetValidated3pidServlet(sydent, require_auth=True)
+        self.storeInviteServlet = StoreInviteServlet(sydent)
+        self.storeInviteServletV2 = StoreInviteServlet(sydent, require_auth=True)
+        self.blindlySignStuffServlet = BlindlySignStuffServlet(sydent)
+        self.blindlySignStuffServletV2 = BlindlySignStuffServlet(
+            sydent, require_auth=True
+        )
+        self.termsServlet = TermsServlet(sydent)
+        self.accountServlet = AccountServlet(sydent)
+        self.registerServlet = RegisterServlet(sydent)
+        self.logoutServlet = LogoutServlet(sydent)
 
 
 @attr.s(frozen=True, slots=True, auto_attribs=True)