test_proxyagent.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2019 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. import treq
  17. from netaddr import IPSet
  18. from twisted.internet import interfaces # noqa: F401
  19. from twisted.internet.protocol import Factory
  20. from twisted.protocols.tls import TLSMemoryBIOFactory
  21. from twisted.web.http import HTTPChannel
  22. from synapse.http.client import BlacklistingReactorWrapper
  23. from synapse.http.proxyagent import ProxyAgent
  24. from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
  25. from tests.server import FakeTransport, ThreadedMemoryReactorClock
  26. from tests.unittest import TestCase
  27. logger = logging.getLogger(__name__)
  28. HTTPFactory = Factory.forProtocol(HTTPChannel)
  29. class MatrixFederationAgentTests(TestCase):
  30. def setUp(self):
  31. self.reactor = ThreadedMemoryReactorClock()
  32. def _make_connection(
  33. self, client_factory, server_factory, ssl=False, expected_sni=None
  34. ):
  35. """Builds a test server, and completes the outgoing client connection
  36. Args:
  37. client_factory (interfaces.IProtocolFactory): the the factory that the
  38. application is trying to use to make the outbound connection. We will
  39. invoke it to build the client Protocol
  40. server_factory (interfaces.IProtocolFactory): a factory to build the
  41. server-side protocol
  42. ssl (bool): If true, we will expect an ssl connection and wrap
  43. server_factory with a TLSMemoryBIOFactory
  44. expected_sni (bytes|None): the expected SNI value
  45. Returns:
  46. IProtocol: the server Protocol returned by server_factory
  47. """
  48. if ssl:
  49. server_factory = _wrap_server_factory_for_tls(server_factory)
  50. server_protocol = server_factory.buildProtocol(None)
  51. # now, tell the client protocol factory to build the client protocol,
  52. # and wire the output of said protocol up to the server via
  53. # a FakeTransport.
  54. #
  55. # Normally this would be done by the TCP socket code in Twisted, but we are
  56. # stubbing that out here.
  57. client_protocol = client_factory.buildProtocol(None)
  58. client_protocol.makeConnection(
  59. FakeTransport(server_protocol, self.reactor, client_protocol)
  60. )
  61. # tell the server protocol to send its stuff back to the client, too
  62. server_protocol.makeConnection(
  63. FakeTransport(client_protocol, self.reactor, server_protocol)
  64. )
  65. if ssl:
  66. http_protocol = server_protocol.wrappedProtocol
  67. tls_connection = server_protocol._tlsConnection
  68. else:
  69. http_protocol = server_protocol
  70. tls_connection = None
  71. # give the reactor a pump to get the TLS juices flowing (if needed)
  72. self.reactor.advance(0)
  73. if expected_sni is not None:
  74. server_name = tls_connection.get_servername()
  75. self.assertEqual(
  76. server_name,
  77. expected_sni,
  78. "Expected SNI %s but got %s" % (expected_sni, server_name),
  79. )
  80. return http_protocol
  81. def test_http_request(self):
  82. agent = ProxyAgent(self.reactor)
  83. self.reactor.lookups["test.com"] = "1.2.3.4"
  84. d = agent.request(b"GET", b"http://test.com")
  85. # there should be a pending TCP connection
  86. clients = self.reactor.tcpClients
  87. self.assertEqual(len(clients), 1)
  88. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  89. self.assertEqual(host, "1.2.3.4")
  90. self.assertEqual(port, 80)
  91. # make a test server, and wire up the client
  92. http_server = self._make_connection(
  93. client_factory, _get_test_protocol_factory()
  94. )
  95. # the FakeTransport is async, so we need to pump the reactor
  96. self.reactor.advance(0)
  97. # now there should be a pending request
  98. self.assertEqual(len(http_server.requests), 1)
  99. request = http_server.requests[0]
  100. self.assertEqual(request.method, b"GET")
  101. self.assertEqual(request.path, b"/")
  102. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  103. request.write(b"result")
  104. request.finish()
  105. self.reactor.advance(0)
  106. resp = self.successResultOf(d)
  107. body = self.successResultOf(treq.content(resp))
  108. self.assertEqual(body, b"result")
  109. def test_https_request(self):
  110. agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
  111. self.reactor.lookups["test.com"] = "1.2.3.4"
  112. d = agent.request(b"GET", b"https://test.com/abc")
  113. # there should be a pending TCP connection
  114. clients = self.reactor.tcpClients
  115. self.assertEqual(len(clients), 1)
  116. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  117. self.assertEqual(host, "1.2.3.4")
  118. self.assertEqual(port, 443)
  119. # make a test server, and wire up the client
  120. http_server = self._make_connection(
  121. client_factory,
  122. _get_test_protocol_factory(),
  123. ssl=True,
  124. expected_sni=b"test.com",
  125. )
  126. # the FakeTransport is async, so we need to pump the reactor
  127. self.reactor.advance(0)
  128. # now there should be a pending request
  129. self.assertEqual(len(http_server.requests), 1)
  130. request = http_server.requests[0]
  131. self.assertEqual(request.method, b"GET")
  132. self.assertEqual(request.path, b"/abc")
  133. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  134. request.write(b"result")
  135. request.finish()
  136. self.reactor.advance(0)
  137. resp = self.successResultOf(d)
  138. body = self.successResultOf(treq.content(resp))
  139. self.assertEqual(body, b"result")
  140. def test_http_request_via_proxy(self):
  141. agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
  142. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  143. d = agent.request(b"GET", b"http://test.com")
  144. # there should be a pending TCP connection
  145. clients = self.reactor.tcpClients
  146. self.assertEqual(len(clients), 1)
  147. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  148. self.assertEqual(host, "1.2.3.5")
  149. self.assertEqual(port, 8888)
  150. # make a test server, and wire up the client
  151. http_server = self._make_connection(
  152. client_factory, _get_test_protocol_factory()
  153. )
  154. # the FakeTransport is async, so we need to pump the reactor
  155. self.reactor.advance(0)
  156. # now there should be a pending request
  157. self.assertEqual(len(http_server.requests), 1)
  158. request = http_server.requests[0]
  159. self.assertEqual(request.method, b"GET")
  160. self.assertEqual(request.path, b"http://test.com")
  161. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  162. request.write(b"result")
  163. request.finish()
  164. self.reactor.advance(0)
  165. resp = self.successResultOf(d)
  166. body = self.successResultOf(treq.content(resp))
  167. self.assertEqual(body, b"result")
  168. def test_https_request_via_proxy(self):
  169. agent = ProxyAgent(
  170. self.reactor,
  171. contextFactory=get_test_https_policy(),
  172. https_proxy=b"proxy.com",
  173. )
  174. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  175. d = agent.request(b"GET", b"https://test.com/abc")
  176. # there should be a pending TCP connection
  177. clients = self.reactor.tcpClients
  178. self.assertEqual(len(clients), 1)
  179. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  180. self.assertEqual(host, "1.2.3.5")
  181. self.assertEqual(port, 1080)
  182. # make a test HTTP server, and wire up the client
  183. proxy_server = self._make_connection(
  184. client_factory, _get_test_protocol_factory()
  185. )
  186. # fish the transports back out so that we can do the old switcheroo
  187. s2c_transport = proxy_server.transport
  188. client_protocol = s2c_transport.other
  189. c2s_transport = client_protocol.transport
  190. # the FakeTransport is async, so we need to pump the reactor
  191. self.reactor.advance(0)
  192. # now there should be a pending CONNECT request
  193. self.assertEqual(len(proxy_server.requests), 1)
  194. request = proxy_server.requests[0]
  195. self.assertEqual(request.method, b"CONNECT")
  196. self.assertEqual(request.path, b"test.com:443")
  197. # tell the proxy server not to close the connection
  198. proxy_server.persistent = True
  199. # this just stops the http Request trying to do a chunked response
  200. # request.setHeader(b"Content-Length", b"0")
  201. request.finish()
  202. # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
  203. ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
  204. ssl_protocol = ssl_factory.buildProtocol(None)
  205. http_server = ssl_protocol.wrappedProtocol
  206. ssl_protocol.makeConnection(
  207. FakeTransport(client_protocol, self.reactor, ssl_protocol)
  208. )
  209. c2s_transport.other = ssl_protocol
  210. self.reactor.advance(0)
  211. server_name = ssl_protocol._tlsConnection.get_servername()
  212. expected_sni = b"test.com"
  213. self.assertEqual(
  214. server_name,
  215. expected_sni,
  216. "Expected SNI %s but got %s" % (expected_sni, server_name),
  217. )
  218. # now there should be a pending request
  219. self.assertEqual(len(http_server.requests), 1)
  220. request = http_server.requests[0]
  221. self.assertEqual(request.method, b"GET")
  222. self.assertEqual(request.path, b"/abc")
  223. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  224. request.write(b"result")
  225. request.finish()
  226. self.reactor.advance(0)
  227. resp = self.successResultOf(d)
  228. body = self.successResultOf(treq.content(resp))
  229. self.assertEqual(body, b"result")
  230. def test_http_request_via_proxy_with_blacklist(self):
  231. # The blacklist includes the configured proxy IP.
  232. agent = ProxyAgent(
  233. BlacklistingReactorWrapper(
  234. self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
  235. ),
  236. self.reactor,
  237. http_proxy=b"proxy.com:8888",
  238. )
  239. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  240. d = agent.request(b"GET", b"http://test.com")
  241. # there should be a pending TCP connection
  242. clients = self.reactor.tcpClients
  243. self.assertEqual(len(clients), 1)
  244. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  245. self.assertEqual(host, "1.2.3.5")
  246. self.assertEqual(port, 8888)
  247. # make a test server, and wire up the client
  248. http_server = self._make_connection(
  249. client_factory, _get_test_protocol_factory()
  250. )
  251. # the FakeTransport is async, so we need to pump the reactor
  252. self.reactor.advance(0)
  253. # now there should be a pending request
  254. self.assertEqual(len(http_server.requests), 1)
  255. request = http_server.requests[0]
  256. self.assertEqual(request.method, b"GET")
  257. self.assertEqual(request.path, b"http://test.com")
  258. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  259. request.write(b"result")
  260. request.finish()
  261. self.reactor.advance(0)
  262. resp = self.successResultOf(d)
  263. body = self.successResultOf(treq.content(resp))
  264. self.assertEqual(body, b"result")
  265. def test_https_request_via_proxy_with_blacklist(self):
  266. # The blacklist includes the configured proxy IP.
  267. agent = ProxyAgent(
  268. BlacklistingReactorWrapper(
  269. self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
  270. ),
  271. self.reactor,
  272. contextFactory=get_test_https_policy(),
  273. https_proxy=b"proxy.com",
  274. )
  275. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  276. d = agent.request(b"GET", b"https://test.com/abc")
  277. # there should be a pending TCP connection
  278. clients = self.reactor.tcpClients
  279. self.assertEqual(len(clients), 1)
  280. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  281. self.assertEqual(host, "1.2.3.5")
  282. self.assertEqual(port, 1080)
  283. # make a test HTTP server, and wire up the client
  284. proxy_server = self._make_connection(
  285. client_factory, _get_test_protocol_factory()
  286. )
  287. # fish the transports back out so that we can do the old switcheroo
  288. s2c_transport = proxy_server.transport
  289. client_protocol = s2c_transport.other
  290. c2s_transport = client_protocol.transport
  291. # the FakeTransport is async, so we need to pump the reactor
  292. self.reactor.advance(0)
  293. # now there should be a pending CONNECT request
  294. self.assertEqual(len(proxy_server.requests), 1)
  295. request = proxy_server.requests[0]
  296. self.assertEqual(request.method, b"CONNECT")
  297. self.assertEqual(request.path, b"test.com:443")
  298. # tell the proxy server not to close the connection
  299. proxy_server.persistent = True
  300. # this just stops the http Request trying to do a chunked response
  301. # request.setHeader(b"Content-Length", b"0")
  302. request.finish()
  303. # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
  304. ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
  305. ssl_protocol = ssl_factory.buildProtocol(None)
  306. http_server = ssl_protocol.wrappedProtocol
  307. ssl_protocol.makeConnection(
  308. FakeTransport(client_protocol, self.reactor, ssl_protocol)
  309. )
  310. c2s_transport.other = ssl_protocol
  311. self.reactor.advance(0)
  312. server_name = ssl_protocol._tlsConnection.get_servername()
  313. expected_sni = b"test.com"
  314. self.assertEqual(
  315. server_name,
  316. expected_sni,
  317. "Expected SNI %s but got %s" % (expected_sni, server_name),
  318. )
  319. # now there should be a pending request
  320. self.assertEqual(len(http_server.requests), 1)
  321. request = http_server.requests[0]
  322. self.assertEqual(request.method, b"GET")
  323. self.assertEqual(request.path, b"/abc")
  324. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  325. request.write(b"result")
  326. request.finish()
  327. self.reactor.advance(0)
  328. resp = self.successResultOf(d)
  329. body = self.successResultOf(treq.content(resp))
  330. self.assertEqual(body, b"result")
  331. def _wrap_server_factory_for_tls(factory, sanlist=None):
  332. """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
  333. The resultant factory will create a TLS server which presents a certificate
  334. signed by our test CA, valid for the domains in `sanlist`
  335. Args:
  336. factory (interfaces.IProtocolFactory): protocol factory to wrap
  337. sanlist (iterable[bytes]): list of domains the cert should be valid for
  338. Returns:
  339. interfaces.IProtocolFactory
  340. """
  341. if sanlist is None:
  342. sanlist = [b"DNS:test.com"]
  343. connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
  344. return TLSMemoryBIOFactory(
  345. connection_creator, isClient=False, wrappedFactory=factory
  346. )
  347. def _get_test_protocol_factory():
  348. """Get a protocol Factory which will build an HTTPChannel
  349. Returns:
  350. interfaces.IProtocolFactory
  351. """
  352. server_factory = Factory.forProtocol(HTTPChannel)
  353. # Request.finish expects the factory to have a 'log' method.
  354. server_factory.log = _log_request
  355. return server_factory
  356. def _log_request(request):
  357. """Implements Factory.log, which is expected by Request.finish"""
  358. logger.info("Completed request %s", request)