endpoint.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014, 2015 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
  16. from twisted.internet import defer
  17. from twisted.internet.error import ConnectError
  18. from twisted.names import client, dns
  19. from twisted.names.error import DNSNameError
  20. import collections
  21. import logging
  22. import random
# Module-level logger named after this module; used by SRVClientEndpoint.connect.
logger = logging.getLogger(__name__)
  24. def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
  25. timeout=None):
  26. """Construct an endpoint for the given matrix destination.
  27. Args:
  28. reactor: Twisted reactor.
  29. destination (bytes): The name of the server to connect to.
  30. ssl_context_factory (twisted.internet.ssl.ContextFactory): Factory
  31. which generates SSL contexts to use for TLS.
  32. timeout (int): connection timeout in seconds
  33. """
  34. domain_port = destination.split(":")
  35. domain = domain_port[0]
  36. port = int(domain_port[1]) if domain_port[1:] else None
  37. endpoint_kw_args = {}
  38. if timeout is not None:
  39. endpoint_kw_args.update(timeout=timeout)
  40. if ssl_context_factory is None:
  41. transport_endpoint = TCP4ClientEndpoint
  42. default_port = 8008
  43. else:
  44. transport_endpoint = SSL4ClientEndpoint
  45. endpoint_kw_args.update(sslContextFactory=ssl_context_factory)
  46. default_port = 8448
  47. if port is None:
  48. return SRVClientEndpoint(
  49. reactor, "matrix", domain, protocol="tcp",
  50. default_port=default_port, endpoint=transport_endpoint,
  51. endpoint_kw_args=endpoint_kw_args
  52. )
  53. else:
  54. return transport_endpoint(reactor, domain, port, **endpoint_kw_args)
  55. class SRVClientEndpoint(object):
  56. """An endpoint which looks up SRV records for a service.
  57. Cycles through the list of servers starting with each call to connect
  58. picking the next server.
  59. Implements twisted.internet.interfaces.IStreamClientEndpoint.
  60. """
  61. _Server = collections.namedtuple(
  62. "_Server", "priority weight host port"
  63. )
  64. def __init__(self, reactor, service, domain, protocol="tcp",
  65. default_port=None, endpoint=TCP4ClientEndpoint,
  66. endpoint_kw_args={}):
  67. self.reactor = reactor
  68. self.service_name = "_%s._%s.%s" % (service, protocol, domain)
  69. if default_port is not None:
  70. self.default_server = self._Server(
  71. host=domain,
  72. port=default_port,
  73. priority=0,
  74. weight=0
  75. )
  76. else:
  77. self.default_server = None
  78. self.endpoint = endpoint
  79. self.endpoint_kw_args = endpoint_kw_args
  80. self.servers = None
  81. self.used_servers = None
  82. @defer.inlineCallbacks
  83. def fetch_servers(self):
  84. try:
  85. answers, auth, add = yield client.lookupService(self.service_name)
  86. except DNSNameError:
  87. answers = []
  88. if (len(answers) == 1
  89. and answers[0].type == dns.SRV
  90. and answers[0].payload
  91. and answers[0].payload.target == dns.Name('.')):
  92. raise ConnectError("Service %s unavailable", self.service_name)
  93. self.servers = []
  94. self.used_servers = []
  95. for answer in answers:
  96. if answer.type != dns.SRV or not answer.payload:
  97. continue
  98. payload = answer.payload
  99. self.servers.append(self._Server(
  100. host=str(payload.target),
  101. port=int(payload.port),
  102. priority=int(payload.priority),
  103. weight=int(payload.weight)
  104. ))
  105. self.servers.sort()
  106. def pick_server(self):
  107. if not self.servers:
  108. if self.used_servers:
  109. self.servers = self.used_servers
  110. self.used_servers = []
  111. self.servers.sort()
  112. elif self.default_server:
  113. return self.default_server
  114. else:
  115. raise ConnectError(
  116. "Not server available for %s", self.service_name
  117. )
  118. min_priority = self.servers[0].priority
  119. weight_indexes = list(
  120. (index, server.weight + 1)
  121. for index, server in enumerate(self.servers)
  122. if server.priority == min_priority
  123. )
  124. total_weight = sum(weight for index, weight in weight_indexes)
  125. target_weight = random.randint(0, total_weight)
  126. for index, weight in weight_indexes:
  127. target_weight -= weight
  128. if target_weight <= 0:
  129. server = self.servers[index]
  130. del self.servers[index]
  131. self.used_servers.append(server)
  132. return server
  133. @defer.inlineCallbacks
  134. def connect(self, protocolFactory):
  135. if self.servers is None:
  136. yield self.fetch_servers()
  137. server = self.pick_server()
  138. logger.info("Connecting to %s:%s", server.host, server.port)
  139. endpoint = self.endpoint(
  140. self.reactor, server.host, server.port, **self.endpoint_kw_args
  141. )
  142. connection = yield endpoint.connect(protocolFactory)
  143. defer.returnValue(connection)