Browse Source

Make `sydent.http.srvresolver` pass `mypy --strict` (#435)

This one was a bit uglier. We're back in twisted territory.
David Robertson 2 years ago
parent
commit
8331f94e1e

+ 1 - 0
changelog.d/435.misc

@@ -0,0 +1 @@
+Make `sydent.http.srvresolver` pass `mypy --strict`.

+ 1 - 0
pyproject.toml

@@ -53,6 +53,7 @@ files = [
     "sydent/db",
     "sydent/http/blacklisting_reactor.py",
     "sydent/http/federation_tls_options.py",
+    "sydent/http/srvresolver.py",
     "sydent/hs_federation",
     "sydent/replication",
     "sydent/sms",

+ 5 - 0
stubs/twisted/internet/error.pyi

@@ -0,0 +1,5 @@
+from typing import Optional, Any
+
+
+class ConnectError(Exception):
+    def __init__(self, osError: Optional[Any] = None, string: str = ""): ...

+ 0 - 0
stubs/twisted/names/__init__.py


+ 28 - 0
stubs/twisted/names/dns.pyi

@@ -0,0 +1,28 @@
+from typing import ClassVar, Generic, TypeVar, Optional
+
+
+class Name:
+    name: bytes
+
+    def __init__(self, name: bytes = b""): ...
+
+SRV: int
+
+
+class Record_SRV:
+    priority: int
+    weight: int
+    port: int
+    target: Name
+    ttl: int
+
+
+Payload = TypeVar("Payload")  # should be bound to IEncodableRecord
+class RRHeader(Generic[Payload]):
+    fmt: ClassVar[str]
+    name: Name
+    type: int
+    cls: int
+    ttl: int
+    payload: Optional[Payload]
+    auth: bool

+ 34 - 16
sydent/http/srvresolver.py

@@ -16,12 +16,12 @@
 import logging
 import random
 import time
-from typing import Callable, Dict, List, SupportsInt, Tuple
+from typing import Awaitable, Callable, Dict, List, SupportsInt, Tuple
 
 import attr
 from twisted.internet.error import ConnectError
-from twisted.internet.interfaces import IResolver
 from twisted.names import client, dns
+from twisted.names.dns import Record_SRV, RRHeader
 from twisted.names.error import DNSNameError, DomainError
 
 logger = logging.getLogger(__name__)
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 SERVER_CACHE: Dict[bytes, List["Server"]] = {}
 
 
-@attr.s
+@attr.s(frozen=True, slots=True, auto_attribs=True)
 class Server:
     """
     Our record of an individual server which can be tried to reach a destination.
@@ -42,11 +42,11 @@ class Server:
             the epoch
     """
 
-    host = attr.ib()
-    port = attr.ib()
-    priority = attr.ib(default=0)
-    weight = attr.ib(default=0)
-    expires = attr.ib(default=0)
+    host: bytes
+    port: int
+    priority: int = 0
+    weight: int = 0
+    expires: int = 0
 
 
 def pick_server_from_list(server_list: List[Server]) -> Tuple[bytes, int]:
@@ -79,6 +79,28 @@ def pick_server_from_list(server_list: List[Server]) -> Tuple[bytes, int]:
     )
 
 
+# The signature of twisted.names.client.lookupService, if you omit the timeout
+# argument. This is unannotated, but we can deduce the signature as follows:
+# 1. Return type is the return type of
+#    twisted.internet.interfaces.IResolver.lookupService. Its type annotation
+#    is incorrect; its docstring says that tuple entries are a _list_ of RRHeaders,
+#    but the annotation says the entries are individual RRHeaders.
+# 2. Because we're looking up SRV records, we know that the payload of the RRHeaders
+#    will be Record_SRVs. I made RRHeader's stub generic over the type of its
+#    payload to reflect this. But that's a lie compared to Twisted's actual
+#    RRHeader Type, so we need to enclose these in strings.
+LookupService = Callable[
+    [str],
+    Awaitable[
+        Tuple[
+            List["RRHeader[Record_SRV]"],
+            List["RRHeader[object]"],
+            List["RRHeader[object]"],
+        ]
+    ],
+]
+
+
 class SrvResolver:
     """Interface to the dns client to do SRV lookups, with result caching.
     The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
@@ -93,11 +115,11 @@ class SrvResolver:
 
     def __init__(
         self,
-        dns_client: "IResolver" = client,
+        lookup_service: LookupService = client.lookupService,
         cache: Dict[bytes, List[Server]] = SERVER_CACHE,
         get_time: Callable[[], SupportsInt] = time.time,
     ) -> None:
-        self._dns_client = dns_client
+        self._lookup_service = lookup_service
         self._cache = cache
         self._get_time = get_time
 
@@ -106,14 +128,10 @@ class SrvResolver:
 
         :param service_name: The record to look up.
 
-
         :returns a list of the SRV records, or an empty list if none found.
         """
         now = int(self._get_time())
 
-        if not isinstance(service_name, bytes):
-            raise TypeError("%r is not a byte string" % (service_name,))
-
         cache_entry = self._cache.get(service_name, None)
         if cache_entry:
             if all(s.expires > now for s in cache_entry):
@@ -121,7 +139,7 @@ class SrvResolver:
                 return servers
 
         try:
-            answers, _, _ = await self._dns_client.lookupService(service_name)
+            answers, _, _ = await self._lookup_service(service_name.decode())
         except DNSNameError:
             # TODO: cache this. We can get the SOA out of the exception, and use
             # the negative-TTL value.
@@ -144,7 +162,7 @@ class SrvResolver:
             and answers[0].payload
             and answers[0].payload.target == dns.Name(b".")
         ):
-            raise ConnectError("Service %s unavailable" % service_name)
+            raise ConnectError("Service %s unavailable" % service_name.decode())
 
         servers = []