test_proxyagent.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2019 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. import treq
  17. from twisted.internet import interfaces # noqa: F401
  18. from twisted.internet.protocol import Factory
  19. from twisted.protocols.tls import TLSMemoryBIOFactory
  20. from twisted.web.http import HTTPChannel
  21. from synapse.http.proxyagent import ProxyAgent
  22. from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
  23. from tests.server import FakeTransport, ThreadedMemoryReactorClock
  24. from tests.unittest import TestCase
  25. logger = logging.getLogger(__name__)
  26. HTTPFactory = Factory.forProtocol(HTTPChannel)
  27. class MatrixFederationAgentTests(TestCase):
  28. def setUp(self):
  29. self.reactor = ThreadedMemoryReactorClock()
  30. def _make_connection(
  31. self, client_factory, server_factory, ssl=False, expected_sni=None
  32. ):
  33. """Builds a test server, and completes the outgoing client connection
  34. Args:
  35. client_factory (interfaces.IProtocolFactory): the the factory that the
  36. application is trying to use to make the outbound connection. We will
  37. invoke it to build the client Protocol
  38. server_factory (interfaces.IProtocolFactory): a factory to build the
  39. server-side protocol
  40. ssl (bool): If true, we will expect an ssl connection and wrap
  41. server_factory with a TLSMemoryBIOFactory
  42. expected_sni (bytes|None): the expected SNI value
  43. Returns:
  44. IProtocol: the server Protocol returned by server_factory
  45. """
  46. if ssl:
  47. server_factory = _wrap_server_factory_for_tls(server_factory)
  48. server_protocol = server_factory.buildProtocol(None)
  49. # now, tell the client protocol factory to build the client protocol,
  50. # and wire the output of said protocol up to the server via
  51. # a FakeTransport.
  52. #
  53. # Normally this would be done by the TCP socket code in Twisted, but we are
  54. # stubbing that out here.
  55. client_protocol = client_factory.buildProtocol(None)
  56. client_protocol.makeConnection(
  57. FakeTransport(server_protocol, self.reactor, client_protocol)
  58. )
  59. # tell the server protocol to send its stuff back to the client, too
  60. server_protocol.makeConnection(
  61. FakeTransport(client_protocol, self.reactor, server_protocol)
  62. )
  63. if ssl:
  64. http_protocol = server_protocol.wrappedProtocol
  65. tls_connection = server_protocol._tlsConnection
  66. else:
  67. http_protocol = server_protocol
  68. tls_connection = None
  69. # give the reactor a pump to get the TLS juices flowing (if needed)
  70. self.reactor.advance(0)
  71. if expected_sni is not None:
  72. server_name = tls_connection.get_servername()
  73. self.assertEqual(
  74. server_name,
  75. expected_sni,
  76. "Expected SNI %s but got %s" % (expected_sni, server_name),
  77. )
  78. return http_protocol
  79. def test_http_request(self):
  80. agent = ProxyAgent(self.reactor)
  81. self.reactor.lookups["test.com"] = "1.2.3.4"
  82. d = agent.request(b"GET", b"http://test.com")
  83. # there should be a pending TCP connection
  84. clients = self.reactor.tcpClients
  85. self.assertEqual(len(clients), 1)
  86. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  87. self.assertEqual(host, "1.2.3.4")
  88. self.assertEqual(port, 80)
  89. # make a test server, and wire up the client
  90. http_server = self._make_connection(
  91. client_factory, _get_test_protocol_factory()
  92. )
  93. # the FakeTransport is async, so we need to pump the reactor
  94. self.reactor.advance(0)
  95. # now there should be a pending request
  96. self.assertEqual(len(http_server.requests), 1)
  97. request = http_server.requests[0]
  98. self.assertEqual(request.method, b"GET")
  99. self.assertEqual(request.path, b"/")
  100. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  101. request.write(b"result")
  102. request.finish()
  103. self.reactor.advance(0)
  104. resp = self.successResultOf(d)
  105. body = self.successResultOf(treq.content(resp))
  106. self.assertEqual(body, b"result")
  107. def test_https_request(self):
  108. agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
  109. self.reactor.lookups["test.com"] = "1.2.3.4"
  110. d = agent.request(b"GET", b"https://test.com/abc")
  111. # there should be a pending TCP connection
  112. clients = self.reactor.tcpClients
  113. self.assertEqual(len(clients), 1)
  114. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  115. self.assertEqual(host, "1.2.3.4")
  116. self.assertEqual(port, 443)
  117. # make a test server, and wire up the client
  118. http_server = self._make_connection(
  119. client_factory,
  120. _get_test_protocol_factory(),
  121. ssl=True,
  122. expected_sni=b"test.com",
  123. )
  124. # the FakeTransport is async, so we need to pump the reactor
  125. self.reactor.advance(0)
  126. # now there should be a pending request
  127. self.assertEqual(len(http_server.requests), 1)
  128. request = http_server.requests[0]
  129. self.assertEqual(request.method, b"GET")
  130. self.assertEqual(request.path, b"/abc")
  131. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  132. request.write(b"result")
  133. request.finish()
  134. self.reactor.advance(0)
  135. resp = self.successResultOf(d)
  136. body = self.successResultOf(treq.content(resp))
  137. self.assertEqual(body, b"result")
  138. def test_http_request_via_proxy(self):
  139. agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
  140. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  141. d = agent.request(b"GET", b"http://test.com")
  142. # there should be a pending TCP connection
  143. clients = self.reactor.tcpClients
  144. self.assertEqual(len(clients), 1)
  145. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  146. self.assertEqual(host, "1.2.3.5")
  147. self.assertEqual(port, 8888)
  148. # make a test server, and wire up the client
  149. http_server = self._make_connection(
  150. client_factory, _get_test_protocol_factory()
  151. )
  152. # the FakeTransport is async, so we need to pump the reactor
  153. self.reactor.advance(0)
  154. # now there should be a pending request
  155. self.assertEqual(len(http_server.requests), 1)
  156. request = http_server.requests[0]
  157. self.assertEqual(request.method, b"GET")
  158. self.assertEqual(request.path, b"http://test.com")
  159. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  160. request.write(b"result")
  161. request.finish()
  162. self.reactor.advance(0)
  163. resp = self.successResultOf(d)
  164. body = self.successResultOf(treq.content(resp))
  165. self.assertEqual(body, b"result")
  166. def test_https_request_via_proxy(self):
  167. agent = ProxyAgent(
  168. self.reactor,
  169. contextFactory=get_test_https_policy(),
  170. https_proxy=b"proxy.com",
  171. )
  172. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  173. d = agent.request(b"GET", b"https://test.com/abc")
  174. # there should be a pending TCP connection
  175. clients = self.reactor.tcpClients
  176. self.assertEqual(len(clients), 1)
  177. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  178. self.assertEqual(host, "1.2.3.5")
  179. self.assertEqual(port, 1080)
  180. # make a test HTTP server, and wire up the client
  181. proxy_server = self._make_connection(
  182. client_factory, _get_test_protocol_factory()
  183. )
  184. # fish the transports back out so that we can do the old switcheroo
  185. s2c_transport = proxy_server.transport
  186. client_protocol = s2c_transport.other
  187. c2s_transport = client_protocol.transport
  188. # the FakeTransport is async, so we need to pump the reactor
  189. self.reactor.advance(0)
  190. # now there should be a pending CONNECT request
  191. self.assertEqual(len(proxy_server.requests), 1)
  192. request = proxy_server.requests[0]
  193. self.assertEqual(request.method, b"CONNECT")
  194. self.assertEqual(request.path, b"test.com:443")
  195. # tell the proxy server not to close the connection
  196. proxy_server.persistent = True
  197. # this just stops the http Request trying to do a chunked response
  198. # request.setHeader(b"Content-Length", b"0")
  199. request.finish()
  200. # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
  201. ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
  202. ssl_protocol = ssl_factory.buildProtocol(None)
  203. http_server = ssl_protocol.wrappedProtocol
  204. ssl_protocol.makeConnection(
  205. FakeTransport(client_protocol, self.reactor, ssl_protocol)
  206. )
  207. c2s_transport.other = ssl_protocol
  208. self.reactor.advance(0)
  209. server_name = ssl_protocol._tlsConnection.get_servername()
  210. expected_sni = b"test.com"
  211. self.assertEqual(
  212. server_name,
  213. expected_sni,
  214. "Expected SNI %s but got %s" % (expected_sni, server_name),
  215. )
  216. # now there should be a pending request
  217. self.assertEqual(len(http_server.requests), 1)
  218. request = http_server.requests[0]
  219. self.assertEqual(request.method, b"GET")
  220. self.assertEqual(request.path, b"/abc")
  221. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  222. request.write(b"result")
  223. request.finish()
  224. self.reactor.advance(0)
  225. resp = self.successResultOf(d)
  226. body = self.successResultOf(treq.content(resp))
  227. self.assertEqual(body, b"result")
  228. def _wrap_server_factory_for_tls(factory, sanlist=None):
  229. """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
  230. The resultant factory will create a TLS server which presents a certificate
  231. signed by our test CA, valid for the domains in `sanlist`
  232. Args:
  233. factory (interfaces.IProtocolFactory): protocol factory to wrap
  234. sanlist (iterable[bytes]): list of domains the cert should be valid for
  235. Returns:
  236. interfaces.IProtocolFactory
  237. """
  238. if sanlist is None:
  239. sanlist = [b"DNS:test.com"]
  240. connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
  241. return TLSMemoryBIOFactory(
  242. connection_creator, isClient=False, wrappedFactory=factory
  243. )
  244. def _get_test_protocol_factory():
  245. """Get a protocol Factory which will build an HTTPChannel
  246. Returns:
  247. interfaces.IProtocolFactory
  248. """
  249. server_factory = Factory.forProtocol(HTTPChannel)
  250. # Request.finish expects the factory to have a 'log' method.
  251. server_factory.log = _log_request
  252. return server_factory
  253. def _log_request(request):
  254. """Implements Factory.log, which is expected by Request.finish"""
  255. logger.info("Completed request %s", request)