endpoint.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
  16. from twisted.internet import defer, reactor
  17. from twisted.internet.error import ConnectError
  18. from twisted.names import client, dns
  19. from twisted.names.error import DNSNameError, DomainError
  20. import collections
  21. import logging
  22. import random
  23. import time
  24. logger = logging.getLogger(__name__)
  25. SERVER_CACHE = {}
  26. _Server = collections.namedtuple(
  27. "_Server", "priority weight host port expires"
  28. )
  29. def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
  30. timeout=None):
  31. """Construct an endpoint for the given matrix destination.
  32. Args:
  33. reactor: Twisted reactor.
  34. destination (bytes): The name of the server to connect to.
  35. ssl_context_factory (twisted.internet.ssl.ContextFactory): Factory
  36. which generates SSL contexts to use for TLS.
  37. timeout (int): connection timeout in seconds
  38. """
  39. domain_port = destination.split(":")
  40. domain = domain_port[0]
  41. port = int(domain_port[1]) if domain_port[1:] else None
  42. endpoint_kw_args = {}
  43. if timeout is not None:
  44. endpoint_kw_args.update(timeout=timeout)
  45. if ssl_context_factory is None:
  46. transport_endpoint = TCP4ClientEndpoint
  47. default_port = 8008
  48. else:
  49. transport_endpoint = SSL4ClientEndpoint
  50. endpoint_kw_args.update(sslContextFactory=ssl_context_factory)
  51. default_port = 8448
  52. if port is None:
  53. return _WrappingEndpointFac(SRVClientEndpoint(
  54. reactor, "matrix", domain, protocol="tcp",
  55. default_port=default_port, endpoint=transport_endpoint,
  56. endpoint_kw_args=endpoint_kw_args
  57. ))
  58. else:
  59. return _WrappingEndpointFac(transport_endpoint(
  60. reactor, domain, port, **endpoint_kw_args
  61. ))
  62. class _WrappingEndpointFac(object):
  63. def __init__(self, endpoint_fac):
  64. self.endpoint_fac = endpoint_fac
  65. @defer.inlineCallbacks
  66. def connect(self, protocolFactory):
  67. conn = yield self.endpoint_fac.connect(protocolFactory)
  68. conn = _WrappedConnection(conn)
  69. defer.returnValue(conn)
  70. class _WrappedConnection(object):
  71. """Wraps a connection and calls abort on it if it hasn't seen any actio
  72. for 5 minutes
  73. """
  74. __slots__ = ["conn", "last_request"]
  75. def __init__(self, conn):
  76. object.__setattr__(self, "conn", conn)
  77. object.__setattr__(self, "last_request", time.time())
  78. def __getattr__(self, name):
  79. return getattr(self.conn, name)
  80. def __setattr__(self, name, value):
  81. setattr(self.conn, name, value)
  82. def _time_things_out_maybe(self):
  83. # We use a slightly shorter timeout here just in case the callLater is
  84. # triggered early. Paranoia ftw.
  85. if time.time() - self.last_request >= 2.5 * 60:
  86. self.abort()
  87. def request(self, request):
  88. self.last_request = time.time()
  89. # Time this connection out if we haven't send a request in the last
  90. # N minutes
  91. reactor.callLater(3 * 60, self._time_things_out_maybe)
  92. d = self.conn.request(request)
  93. def update_request_time(res):
  94. self.last_request = time.time()
  95. reactor.callLater(3 * 60, self._time_things_out_maybe)
  96. return res
  97. d.addCallback(update_request_time)
  98. return d
  99. class SpiderEndpoint(object):
  100. """An endpoint which refuses to connect to blacklisted IP addresses
  101. Implements twisted.internet.interfaces.IStreamClientEndpoint.
  102. """
  103. def __init__(self, reactor, host, port, blacklist, whitelist,
  104. endpoint=TCP4ClientEndpoint, endpoint_kw_args={}):
  105. self.reactor = reactor
  106. self.host = host
  107. self.port = port
  108. self.blacklist = blacklist
  109. self.whitelist = whitelist
  110. self.endpoint = endpoint
  111. self.endpoint_kw_args = endpoint_kw_args
  112. @defer.inlineCallbacks
  113. def connect(self, protocolFactory):
  114. address = yield self.reactor.resolve(self.host)
  115. from netaddr import IPAddress
  116. ip_address = IPAddress(address)
  117. if ip_address in self.blacklist:
  118. if self.whitelist is None or ip_address not in self.whitelist:
  119. raise ConnectError(
  120. "Refusing to spider blacklisted IP address %s" % address
  121. )
  122. logger.info("Connecting to %s:%s", address, self.port)
  123. endpoint = self.endpoint(
  124. self.reactor, address, self.port, **self.endpoint_kw_args
  125. )
  126. connection = yield endpoint.connect(protocolFactory)
  127. defer.returnValue(connection)
  128. class SRVClientEndpoint(object):
  129. """An endpoint which looks up SRV records for a service.
  130. Cycles through the list of servers starting with each call to connect
  131. picking the next server.
  132. Implements twisted.internet.interfaces.IStreamClientEndpoint.
  133. """
  134. def __init__(self, reactor, service, domain, protocol="tcp",
  135. default_port=None, endpoint=TCP4ClientEndpoint,
  136. endpoint_kw_args={}):
  137. self.reactor = reactor
  138. self.service_name = "_%s._%s.%s" % (service, protocol, domain)
  139. if default_port is not None:
  140. self.default_server = _Server(
  141. host=domain,
  142. port=default_port,
  143. priority=0,
  144. weight=0,
  145. expires=0,
  146. )
  147. else:
  148. self.default_server = None
  149. self.endpoint = endpoint
  150. self.endpoint_kw_args = endpoint_kw_args
  151. self.servers = None
  152. self.used_servers = None
  153. @defer.inlineCallbacks
  154. def fetch_servers(self):
  155. self.used_servers = []
  156. self.servers = yield resolve_service(self.service_name)
  157. def pick_server(self):
  158. if not self.servers:
  159. if self.used_servers:
  160. self.servers = self.used_servers
  161. self.used_servers = []
  162. self.servers.sort()
  163. elif self.default_server:
  164. return self.default_server
  165. else:
  166. raise ConnectError(
  167. "Not server available for %s" % self.service_name
  168. )
  169. min_priority = self.servers[0].priority
  170. weight_indexes = list(
  171. (index, server.weight + 1)
  172. for index, server in enumerate(self.servers)
  173. if server.priority == min_priority
  174. )
  175. total_weight = sum(weight for index, weight in weight_indexes)
  176. target_weight = random.randint(0, total_weight)
  177. for index, weight in weight_indexes:
  178. target_weight -= weight
  179. if target_weight <= 0:
  180. server = self.servers[index]
  181. del self.servers[index]
  182. self.used_servers.append(server)
  183. return server
  184. @defer.inlineCallbacks
  185. def connect(self, protocolFactory):
  186. if self.servers is None:
  187. yield self.fetch_servers()
  188. server = self.pick_server()
  189. logger.info("Connecting to %s:%s", server.host, server.port)
  190. endpoint = self.endpoint(
  191. self.reactor, server.host, server.port, **self.endpoint_kw_args
  192. )
  193. connection = yield endpoint.connect(protocolFactory)
  194. defer.returnValue(connection)
  195. @defer.inlineCallbacks
  196. def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
  197. cache_entry = cache.get(service_name, None)
  198. if cache_entry:
  199. if all(s.expires > int(clock.time()) for s in cache_entry):
  200. servers = list(cache_entry)
  201. defer.returnValue(servers)
  202. servers = []
  203. try:
  204. try:
  205. answers, _, _ = yield dns_client.lookupService(service_name)
  206. except DNSNameError:
  207. defer.returnValue([])
  208. if (len(answers) == 1
  209. and answers[0].type == dns.SRV
  210. and answers[0].payload
  211. and answers[0].payload.target == dns.Name('.')):
  212. raise ConnectError("Service %s unavailable" % service_name)
  213. for answer in answers:
  214. if answer.type != dns.SRV or not answer.payload:
  215. continue
  216. payload = answer.payload
  217. host = str(payload.target)
  218. srv_ttl = answer.ttl
  219. try:
  220. answers, _, _ = yield dns_client.lookupAddress(host)
  221. except DNSNameError:
  222. continue
  223. for answer in answers:
  224. if answer.type == dns.A and answer.payload:
  225. ip = answer.payload.dottedQuad()
  226. host_ttl = min(srv_ttl, answer.ttl)
  227. servers.append(_Server(
  228. host=ip,
  229. port=int(payload.port),
  230. priority=int(payload.priority),
  231. weight=int(payload.weight),
  232. expires=int(clock.time()) + host_ttl,
  233. ))
  234. servers.sort()
  235. cache[service_name] = list(servers)
  236. except DomainError as e:
  237. # We failed to resolve the name (other than a NameError)
  238. # Try something in the cache, else rereaise
  239. cache_entry = cache.get(service_name, None)
  240. if cache_entry:
  241. logger.warn(
  242. "Failed to resolve %r, falling back to cache. %r",
  243. service_name, e
  244. )
  245. servers = list(cache_entry)
  246. else:
  247. raise e
  248. defer.returnValue(servers)