|
@@ -16,12 +16,12 @@
|
|
|
import logging
|
|
|
import random
|
|
|
import time
|
|
|
-from typing import Callable, Dict, List, SupportsInt, Tuple
|
|
|
+from typing import Awaitable, Callable, Dict, List, SupportsInt, Tuple
|
|
|
|
|
|
import attr
|
|
|
from twisted.internet.error import ConnectError
|
|
|
-from twisted.internet.interfaces import IResolver
|
|
|
from twisted.names import client, dns
|
|
|
+from twisted.names.dns import Record_SRV, RRHeader
|
|
|
from twisted.names.error import DNSNameError, DomainError
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
|
|
|
SERVER_CACHE: Dict[bytes, List["Server"]] = {}
|
|
|
|
|
|
|
|
|
-@attr.s
|
|
|
+@attr.s(frozen=True, slots=True, auto_attribs=True)
|
|
|
class Server:
|
|
|
"""
|
|
|
Our record of an individual server which can be tried to reach a destination.
|
|
@@ -42,11 +42,11 @@ class Server:
|
|
|
the epoch
|
|
|
"""
|
|
|
|
|
|
- host = attr.ib()
|
|
|
- port = attr.ib()
|
|
|
- priority = attr.ib(default=0)
|
|
|
- weight = attr.ib(default=0)
|
|
|
- expires = attr.ib(default=0)
|
|
|
+ host: bytes
|
|
|
+ port: int
|
|
|
+ priority: int = 0
|
|
|
+ weight: int = 0
|
|
|
+ expires: int = 0
|
|
|
|
|
|
|
|
|
def pick_server_from_list(server_list: List[Server]) -> Tuple[bytes, int]:
|
|
@@ -79,6 +79,28 @@ def pick_server_from_list(server_list: List[Server]) -> Tuple[bytes, int]:
|
|
|
)
|
|
|
|
|
|
|
|
|
+# The signature of twisted.names.client.lookupService, if you omit the timeout
|
|
|
+# argument. This is unannotated, but we can deduce the signature as follows:
|
|
|
+# 1. Return type is the return type of
|
|
|
+# twisted.internet.interfaces.IResolver.lookupService. Its type annotation
|
|
|
+# is incorrect; its docstring says that tuple entries are a _list_ of RRHeaders,
|
|
|
+# but the annotation says the entries are individual RRHeaders.
|
|
|
+# 2. Because we're looking up SRV records, we know that the payload of the RRHeaders
|
|
|
+# will be Record_SRVs. I made RRHeader's stub generic over the type of its
|
|
|
+# payload to reflect this. But that's a lie compared to Twisted's actual
|
|
|
+# RRHeader Type, so we need to enclose these in strings.
|
|
|
+LookupService = Callable[
|
|
|
+ [str],
|
|
|
+ Awaitable[
|
|
|
+ Tuple[
|
|
|
+ List["RRHeader[Record_SRV]"],
|
|
|
+ List["RRHeader[object]"],
|
|
|
+ List["RRHeader[object]"],
|
|
|
+ ]
|
|
|
+ ],
|
|
|
+]
|
|
|
+
|
|
|
+
|
|
|
class SrvResolver:
|
|
|
"""Interface to the dns client to do SRV lookups, with result caching.
|
|
|
The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
|
|
@@ -93,11 +115,11 @@ class SrvResolver:
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
- dns_client: "IResolver" = client,
|
|
|
+ lookup_service: LookupService = client.lookupService,
|
|
|
cache: Dict[bytes, List[Server]] = SERVER_CACHE,
|
|
|
get_time: Callable[[], SupportsInt] = time.time,
|
|
|
) -> None:
|
|
|
- self._dns_client = dns_client
|
|
|
+ self._lookup_service = lookup_service
|
|
|
self._cache = cache
|
|
|
self._get_time = get_time
|
|
|
|
|
@@ -106,14 +128,10 @@ class SrvResolver:
|
|
|
|
|
|
:param service_name: The record to look up.
|
|
|
|
|
|
-
|
|
|
:returns a list of the SRV records, or an empty list if none found.
|
|
|
"""
|
|
|
now = int(self._get_time())
|
|
|
|
|
|
- if not isinstance(service_name, bytes):
|
|
|
- raise TypeError("%r is not a byte string" % (service_name,))
|
|
|
-
|
|
|
cache_entry = self._cache.get(service_name, None)
|
|
|
if cache_entry:
|
|
|
if all(s.expires > now for s in cache_entry):
|
|
@@ -121,7 +139,7 @@ class SrvResolver:
|
|
|
return servers
|
|
|
|
|
|
try:
|
|
|
- answers, _, _ = await self._dns_client.lookupService(service_name)
|
|
|
+ answers, _, _ = await self._lookup_service(service_name.decode())
|
|
|
except DNSNameError:
|
|
|
# TODO: cache this. We can get the SOA out of the exception, and use
|
|
|
# the negative-TTL value.
|
|
@@ -144,7 +162,7 @@ class SrvResolver:
|
|
|
and answers[0].payload
|
|
|
and answers[0].payload.target == dns.Name(b".")
|
|
|
):
|
|
|
- raise ConnectError("Service %s unavailable" % service_name)
|
|
|
+ raise ConnectError("Service %s unavailable" % service_name.decode())
|
|
|
|
|
|
servers = []
|
|
|
|