Просмотр исходного кода

Make `sydent.http.federation_tls_options` pass `mypy --strict` (#434)

And blacklisting_reactor already passes, so chuck that in the config for
free!
David Robertson 2 лет назад
Родитель
Сommit
4cabc4b8c3

+ 1 - 0
changelog.d/434.misc

@@ -0,0 +1 @@
+Add type annotations to `mypy.http.federation_tls_options`.

+ 2 - 0
pyproject.toml

@@ -51,6 +51,8 @@ files = [
     #     find sydent tests -type d -not -name __pycache__ -exec bash -c "mypy --strict '{}' > /dev/null"  \; -print
     "sydent/config",
     "sydent/db",
+    "sydent/http/blacklisting_reactor.py",
+    "sydent/http/federation_tls_options.py",
     "sydent/hs_federation",
     "sydent/replication",
     "sydent/sms",

+ 0 - 0
stubs/twisted/internet/__init__.pyi


+ 16 - 0
stubs/twisted/internet/ssl.py

@@ -0,0 +1,16 @@
+from typing import Optional, Any
+
+import OpenSSL.SSL
+from twisted.internet._sslverify import IOpenSSLTrustRoot
+
+
+def platformTrust() -> IOpenSSLTrustRoot:
+    ...
+
+
+class CertificateOptions:
+    def __init__(self, trustRoot: Optional[IOpenSSLTrustRoot] = None, **kwargs: Any):
+        ...
+
+    def _makeContext(self) -> OpenSSL.SSL.Context:
+        ...

+ 14 - 0
stubs/twisted/python/failure.pyi

@@ -0,0 +1,14 @@
+from types import TracebackType
+from typing import Type, Optional
+
+
+class Failure(BaseException):
+
+    def __init__(
+        self,
+        exc_value: Optional[BaseException] = None,
+        exc_type: Optional[Type[BaseException]] = None,
+        exc_tb: Optional[TracebackType] = None,
+        captureVars: bool = False,
+    ):
+        ...

+ 20 - 8
sydent/http/federation_tls_options.py

@@ -11,15 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+from typing import Callable
 
 from OpenSSL import SSL
 from twisted.internet import ssl
 from twisted.internet.abstract import isIPAddress, isIPv6Address
 from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
+from twisted.protocols.tls import TLSMemoryBIOProtocol
+from twisted.python.failure import Failure
 from zope.interface import implementer
 
+logger = logging.getLogger(__name__)
 
-def _tolerateErrors(wrapped):
+F = Callable[[SSL.Connection, int, int], None]
+
+
+def _tolerateErrors(wrapped: F) -> F:
     """
     Wrap up an info_callback for pyOpenSSL so that if something goes wrong
     the error is immediately logged and the connection is dropped if possible.
@@ -27,10 +35,10 @@ def _tolerateErrors(wrapped):
     documentation, see the twisted documentation.
     """
 
-    def infoCallback(connection, where, ret):
+    def infoCallback(connection: SSL.Connection, where: int, ret: int) -> None:
         try:
             return wrapped(connection, where, ret)
-        except:  # noqa: E722, taken from the twisted implementation
+        except BaseException:
             f = Failure()
             logger.exception("Error during info_callback")
             connection.get_app_data().failVerification(f)
@@ -38,7 +46,7 @@ def _tolerateErrors(wrapped):
     return infoCallback
 
 
-def _idnaBytes(text):
+def _idnaBytes(text: str) -> bytes:
     """
     Convert some text typed by a human into some ASCII bytes. This is a
     copy of twisted.internet._idna._idnaBytes. For documentation, see the
@@ -60,7 +68,7 @@ class ClientTLSOptions:
     verification left out. For documentation, see the twisted documentation.
     """
 
-    def __init__(self, hostname, ctx):
+    def __init__(self, hostname: str, ctx: SSL.Context):
         self._ctx = ctx
 
         if isIPAddress(hostname) or isIPv6Address(hostname):
@@ -72,13 +80,17 @@ class ClientTLSOptions:
 
         ctx.set_info_callback(_tolerateErrors(self._identityVerifyingInfoCallback))
 
-    def clientConnectionForTLS(self, tlsProtocol):
+    def clientConnectionForTLS(
+        self, tlsProtocol: TLSMemoryBIOProtocol
+    ) -> SSL.Connection:
         context = self._ctx
         connection = SSL.Connection(context, None)
         connection.set_app_data(tlsProtocol)
         return connection
 
-    def _identityVerifyingInfoCallback(self, connection, where, ret):
+    def _identityVerifyingInfoCallback(
+        self, connection: SSL.Connection, where: int, ret: int
+    ) -> None:
         # Literal IPv4 and IPv6 addresses are not permitted
         # as host names according to the RFCs
         if where & SSL.SSL_CB_HANDSHAKE_START and self._sendSNI:
@@ -95,6 +107,6 @@ class ClientTLSOptionsFactory:
         else:
             self._options = ssl.CertificateOptions()
 
-    def get_options(self, host):
+    def get_options(self, host: str) -> ClientTLSOptions:
         # Use _makeContext so that we get a fresh OpenSSL CTX each time.
         return ClientTLSOptions(host, self._options._makeContext())