httpclient.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import json
  16. import logging
  17. from StringIO import StringIO
  18. from twisted.internet import defer, reactor, ssl
  19. from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
  20. from twisted.internet._sslverify import _defaultCurveName, ClientTLSOptions
  21. from twisted.web.client import FileBodyProducer, Agent, readBody
  22. from twisted.web.http_headers import Headers
  23. import twisted.names.client
  24. from twisted.names.error import DNSNameError
  25. from OpenSSL import SSL, crypto
  26. logger = logging.getLogger(__name__)
  27. class SimpleHttpClient(object):
  28. """
  29. A simple, no-frills HTTP client based on the class of the same name
  30. from synapse
  31. """
  32. def __init__(self, sydent, endpoint_factory=None):
  33. self.sydent = sydent
  34. if endpoint_factory is None:
  35. # The default endpoint factory in Twisted 14.0.0 (which we require) uses the
  36. # BrowserLikePolicyForHTTPS context factory which will do regular cert validation
  37. # 'like a browser'
  38. self.agent = Agent(
  39. reactor,
  40. connectTimeout=15,
  41. )
  42. else:
  43. self.agent = Agent.usingEndpointFactory(
  44. reactor,
  45. endpoint_factory,
  46. )
  47. @defer.inlineCallbacks
  48. def get_json(self, uri):
  49. logger.debug("HTTP GET %s", uri)
  50. response = yield self.agent.request(
  51. "GET",
  52. uri.encode("ascii"),
  53. )
  54. body = yield readBody(response)
  55. defer.returnValue(json.loads(body))
  56. @defer.inlineCallbacks
  57. def post_json_get_nothing(self, uri, post_json, opts):
  58. json_str = json.dumps(post_json)
  59. headers = opts.get('headers', Headers({
  60. b"Content-Type": [b"application/json"],
  61. }))
  62. logger.debug("HTTP POST %s -> %s", json_str, uri)
  63. response = yield self.agent.request(
  64. "POST",
  65. uri.encode("ascii"),
  66. headers,
  67. bodyProducer=FileBodyProducer(StringIO(json_str))
  68. )
  69. defer.returnValue(response)
  70. class SRVClientEndpoint(object):
  71. def __init__(self, reactor, service, domain, protocol="tcp",
  72. default_port=None, endpoint=HostnameEndpoint,
  73. endpoint_kw_args={}):
  74. self.reactor = reactor
  75. self.domain = domain
  76. self.endpoint = endpoint
  77. self.endpoint_kw_args = endpoint_kw_args
  78. @defer.inlineCallbacks
  79. def lookup_server(self):
  80. service_name = "%s.%s.%s" % ('_matrix', '_tcp', self.domain)
  81. default = self.domain, 8448
  82. try:
  83. answers, _, _ = yield twisted.names.client.lookupService(service_name)
  84. except DNSNameError:
  85. logger.info("DNSNameError doing SRV lookup for %s - using default", service_name)
  86. defer.returnValue(default)
  87. for answer in answers:
  88. if answer.type != twisted.names.dns.SRV or not answer.payload:
  89. continue
  90. # XXX we just use the first
  91. logger.info("Got SRV answer: %r / %d for %s", str(answer.payload.target), answer.payload.port, service_name)
  92. defer.returnValue((str(answer.payload.target), answer.payload.port))
  93. logger.info("No valid answers found in response from %s (%r)", self.domain, answers)
  94. defer.returnValue(default)
  95. @defer.inlineCallbacks
  96. def connect(self, protocolFactory):
  97. server = yield self.lookup_server()
  98. logger.info("Connecting to %s:%s", server[0], server[1])
  99. endpoint = self.endpoint(
  100. self.reactor, server[0], server[1], **self.endpoint_kw_args
  101. )
  102. connection = yield endpoint.connect(protocolFactory)
  103. defer.returnValue(connection)
  104. def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
  105. timeout=None):
  106. """Construct an endpoint for the given matrix destination.
  107. :param reactor: Twisted reactor.
  108. :param destination: The name of the server to connect to.
  109. :type destination: bytes
  110. :param ssl_context_factory: Factory which generates SSL contexts to use for TLS.
  111. :type ssl_context_factory: twisted.internet.ssl.ContextFactory
  112. :param timeout (int): connection timeout in seconds
  113. :type timeout: int
  114. """
  115. domain_port = destination.split(":")
  116. domain = domain_port[0]
  117. port = int(domain_port[1]) if domain_port[1:] else None
  118. endpoint_kw_args = {}
  119. if timeout is not None:
  120. endpoint_kw_args.update(timeout=timeout)
  121. if ssl_context_factory is None:
  122. transport_endpoint = HostnameEndpoint
  123. default_port = 8008
  124. else:
  125. def transport_endpoint(reactor, host, port, timeout):
  126. return wrapClientTLS(
  127. ssl_context_factory,
  128. HostnameEndpoint(reactor, host, port, timeout=timeout))
  129. default_port = 8448
  130. if port is None:
  131. return SRVClientEndpoint(
  132. reactor, "matrix", domain, protocol="tcp",
  133. default_port=default_port, endpoint=transport_endpoint,
  134. endpoint_kw_args=endpoint_kw_args
  135. )
  136. else:
  137. return transport_endpoint(
  138. reactor, domain, port, **endpoint_kw_args
  139. )
  140. class FederationEndpointFactory(object):
  141. def endpointForURI(self, uri):
  142. destination = uri.netloc
  143. context_factory = FederationContextFactory()
  144. return matrix_federation_endpoint(
  145. reactor, destination, timeout=10,
  146. ssl_context_factory=context_factory,
  147. )
  148. class FederationContextFactory(object):
  149. def getContext(self):
  150. context = SSL.Context(SSL.SSLv23_METHOD)
  151. try:
  152. _ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
  153. context.set_tmp_ecdh(_ecCurve)
  154. except Exception:
  155. logger.exception("Failed to enable elliptic curve for TLS")
  156. context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
  157. context.set_cipher_list("!ADH:HIGH+kEDH:!AECDH:HIGH+kEECDH")
  158. return context
  159. class FederationHttpClient(SimpleHttpClient):
  160. def __init__(self, sydent):
  161. super(FederationHttpClient, self).__init__(sydent, FederationEndpointFactory())