|
@@ -12,6 +12,7 @@
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
+import socket
|
|
|
|
|
|
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
|
|
from twisted.internet import defer, reactor
|
|
@@ -30,7 +31,10 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
SERVER_CACHE = {}
|
|
|
|
|
|
-
|
|
|
+# our record of an individual server which can be tried to reach a destination.
|
|
|
+#
|
|
|
+# "host" is actually a dotted-quad or ipv6 address string. Except when there's
|
|
|
+# no SRV record, in which case it is the original hostname.
|
|
|
_Server = collections.namedtuple(
|
|
|
"_Server", "priority weight host port expires"
|
|
|
)
|
|
@@ -219,9 +223,10 @@ class SRVClientEndpoint(object):
|
|
|
return self.default_server
|
|
|
else:
|
|
|
raise ConnectError(
|
|
|
- "Not server available for %s" % self.service_name
|
|
|
+ "No server available for %s" % self.service_name
|
|
|
)
|
|
|
|
|
|
+ # look for all servers with the same priority
|
|
|
min_priority = self.servers[0].priority
|
|
|
weight_indexes = list(
|
|
|
(index, server.weight + 1)
|
|
@@ -231,11 +236,22 @@ class SRVClientEndpoint(object):
|
|
|
|
|
|
total_weight = sum(weight for index, weight in weight_indexes)
|
|
|
target_weight = random.randint(0, total_weight)
|
|
|
-
|
|
|
for index, weight in weight_indexes:
|
|
|
target_weight -= weight
|
|
|
if target_weight <= 0:
|
|
|
server = self.servers[index]
|
|
|
+ # XXX: this looks totally dubious:
|
|
|
+ #
|
|
|
+ # (a) we never reuse a server until we have been through
|
|
|
+ # all of the servers at the same priority, so if the
|
|
|
+ # weights are A: 100, B:1, we always do ABABAB instead of
|
|
|
+ # AAAA...AAAB (approximately).
|
|
|
+ #
|
|
|
+ # (b) After using all the servers at the lowest priority,
|
|
|
+ # we move onto the next priority. We should only use the
|
|
|
+ # second priority if servers at the top priority are
|
|
|
+ # unreachable.
|
|
|
+ #
|
|
|
del self.servers[index]
|
|
|
self.used_servers.append(server)
|
|
|
return server
|
|
@@ -280,26 +296,21 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
|
|
|
continue
|
|
|
|
|
|
payload = answer.payload
|
|
|
- host = str(payload.target)
|
|
|
- srv_ttl = answer.ttl
|
|
|
|
|
|
- try:
|
|
|
- answers, _, _ = yield dns_client.lookupAddress(host)
|
|
|
- except DNSNameError:
|
|
|
- continue
|
|
|
+ hosts = yield _get_hosts_for_srv_record(
|
|
|
+ dns_client, str(payload.target)
|
|
|
+ )
|
|
|
|
|
|
- for answer in answers:
|
|
|
- if answer.type == dns.A and answer.payload:
|
|
|
- ip = answer.payload.dottedQuad()
|
|
|
- host_ttl = min(srv_ttl, answer.ttl)
|
|
|
+ for (ip, ttl) in hosts:
|
|
|
+ host_ttl = min(answer.ttl, ttl)
|
|
|
|
|
|
- servers.append(_Server(
|
|
|
- host=ip,
|
|
|
- port=int(payload.port),
|
|
|
- priority=int(payload.priority),
|
|
|
- weight=int(payload.weight),
|
|
|
- expires=int(clock.time()) + host_ttl,
|
|
|
- ))
|
|
|
+ servers.append(_Server(
|
|
|
+ host=ip,
|
|
|
+ port=int(payload.port),
|
|
|
+ priority=int(payload.priority),
|
|
|
+ weight=int(payload.weight),
|
|
|
+ expires=int(clock.time()) + host_ttl,
|
|
|
+ ))
|
|
|
|
|
|
servers.sort()
|
|
|
cache[service_name] = list(servers)
|
|
@@ -317,3 +328,68 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
|
|
|
raise e
|
|
|
|
|
|
defer.returnValue(servers)
|
|
|
+
|
|
|
+
|
|
|
+@defer.inlineCallbacks
|
|
|
+def _get_hosts_for_srv_record(dns_client, host):
|
|
|
+ """Look up each of the hosts in a SRV record
|
|
|
+
|
|
|
+ Args:
|
|
|
+ dns_client (twisted.names.dns.IResolver):
|
|
|
+ host (basestring): host to look up
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Deferred[list[(str, int)]]: a list of (host, ttl) pairs
|
|
|
+
|
|
|
+ """
|
|
|
+ ip4_servers = []
|
|
|
+ ip6_servers = []
|
|
|
+
|
|
|
+ def cb(res):
|
|
|
+ # lookupAddress and lookupIP6Address return a three-tuple
|
|
|
+ # giving the answer, authority, and additional sections of the
|
|
|
+ # response.
|
|
|
+ #
|
|
|
+ # we only care about the answers.
|
|
|
+
|
|
|
+ return res[0]
|
|
|
+
|
|
|
+ def eb(res):
|
|
|
+ res.trap(DNSNameError)
|
|
|
+ return []
|
|
|
+
|
|
|
+ # no logcontexts here, so we can safely fire these off and gatherResults
|
|
|
+ d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
|
|
|
+ d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
|
|
|
+ results = yield defer.gatherResults([d1, d2], consumeErrors=True)
|
|
|
+
|
|
|
+ for result in results:
|
|
|
+ for answer in result:
|
|
|
+ if not answer.payload:
|
|
|
+ continue
|
|
|
+
|
|
|
+ try:
|
|
|
+ if answer.type == dns.A:
|
|
|
+ ip = answer.payload.dottedQuad()
|
|
|
+ ip4_servers.append((ip, answer.ttl))
|
|
|
+ elif answer.type == dns.AAAA:
|
|
|
+ ip = socket.inet_ntop(
|
|
|
+ socket.AF_INET6, answer.payload.address,
|
|
|
+ )
|
|
|
+ ip6_servers.append((ip, answer.ttl))
|
|
|
+ else:
|
|
|
+ # the most likely candidate here is a CNAME record.
|
|
|
+ # rfc2782 says srvs may not point to aliases.
|
|
|
+ logger.warn(
|
|
|
+ "Ignoring unexpected DNS record type %s for %s",
|
|
|
+ answer.type, host,
|
|
|
+ )
|
|
|
+ continue
|
|
|
+ except Exception as e:
|
|
|
+ logger.warn("Ignoring invalid DNS response for %s: %s",
|
|
|
+ host, e)
|
|
|
+ continue
|
|
|
+
|
|
|
+ # keep the ipv4 results before the ipv6 results, mostly to match historical
|
|
|
+ # behaviour.
|
|
|
+ defer.returnValue(ip4_servers + ip6_servers)
|