test_proxyagent.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2019 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import base64
  16. import logging
  17. import os
  18. from typing import Optional
  19. from unittest.mock import patch
  20. import treq
  21. from netaddr import IPSet
  22. from twisted.internet import interfaces # noqa: F401
  23. from twisted.internet.protocol import Factory
  24. from twisted.protocols.tls import TLSMemoryBIOFactory
  25. from twisted.web.http import HTTPChannel
  26. from synapse.http.client import BlacklistingReactorWrapper
  27. from synapse.http.proxyagent import ProxyAgent
  28. from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
  29. from tests.server import FakeTransport, ThreadedMemoryReactorClock
  30. from tests.unittest import TestCase
# Module-level logger; `_log_request` (defined at the bottom of this file) uses
# it to record each request completed by the test HTTP servers.
logger = logging.getLogger(__name__)

# A protocol factory that builds `HTTPChannel` instances.
# NOTE(review): nothing in the visible part of this file references
# `HTTPFactory` (the tests build their factories via
# `_get_test_protocol_factory`) -- confirm whether it is still needed.
HTTPFactory = Factory.forProtocol(HTTPChannel)
  33. class MatrixFederationAgentTests(TestCase):
  34. def setUp(self):
  35. self.reactor = ThreadedMemoryReactorClock()
  36. def _make_connection(
  37. self, client_factory, server_factory, ssl=False, expected_sni=None
  38. ):
  39. """Builds a test server, and completes the outgoing client connection
  40. Args:
  41. client_factory (interfaces.IProtocolFactory): the the factory that the
  42. application is trying to use to make the outbound connection. We will
  43. invoke it to build the client Protocol
  44. server_factory (interfaces.IProtocolFactory): a factory to build the
  45. server-side protocol
  46. ssl (bool): If true, we will expect an ssl connection and wrap
  47. server_factory with a TLSMemoryBIOFactory
  48. expected_sni (bytes|None): the expected SNI value
  49. Returns:
  50. IProtocol: the server Protocol returned by server_factory
  51. """
  52. if ssl:
  53. server_factory = _wrap_server_factory_for_tls(server_factory)
  54. server_protocol = server_factory.buildProtocol(None)
  55. # now, tell the client protocol factory to build the client protocol,
  56. # and wire the output of said protocol up to the server via
  57. # a FakeTransport.
  58. #
  59. # Normally this would be done by the TCP socket code in Twisted, but we are
  60. # stubbing that out here.
  61. client_protocol = client_factory.buildProtocol(None)
  62. client_protocol.makeConnection(
  63. FakeTransport(server_protocol, self.reactor, client_protocol)
  64. )
  65. # tell the server protocol to send its stuff back to the client, too
  66. server_protocol.makeConnection(
  67. FakeTransport(client_protocol, self.reactor, server_protocol)
  68. )
  69. if ssl:
  70. http_protocol = server_protocol.wrappedProtocol
  71. tls_connection = server_protocol._tlsConnection
  72. else:
  73. http_protocol = server_protocol
  74. tls_connection = None
  75. # give the reactor a pump to get the TLS juices flowing (if needed)
  76. self.reactor.advance(0)
  77. if expected_sni is not None:
  78. server_name = tls_connection.get_servername()
  79. self.assertEqual(
  80. server_name,
  81. expected_sni,
  82. "Expected SNI %s but got %s" % (expected_sni, server_name),
  83. )
  84. return http_protocol
  85. def _test_request_direct_connection(self, agent, scheme, hostname, path):
  86. """Runs a test case for a direct connection not going through a proxy.
  87. Args:
  88. agent (ProxyAgent): the proxy agent being tested
  89. scheme (bytes): expected to be either "http" or "https"
  90. hostname (bytes): the hostname to connect to in the test
  91. path (bytes): the path to connect to in the test
  92. """
  93. is_https = scheme == b"https"
  94. self.reactor.lookups[hostname.decode()] = "1.2.3.4"
  95. d = agent.request(b"GET", scheme + b"://" + hostname + b"/" + path)
  96. # there should be a pending TCP connection
  97. clients = self.reactor.tcpClients
  98. self.assertEqual(len(clients), 1)
  99. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  100. self.assertEqual(host, "1.2.3.4")
  101. self.assertEqual(port, 443 if is_https else 80)
  102. # make a test server, and wire up the client
  103. http_server = self._make_connection(
  104. client_factory,
  105. _get_test_protocol_factory(),
  106. ssl=is_https,
  107. expected_sni=hostname if is_https else None,
  108. )
  109. # the FakeTransport is async, so we need to pump the reactor
  110. self.reactor.advance(0)
  111. # now there should be a pending request
  112. self.assertEqual(len(http_server.requests), 1)
  113. request = http_server.requests[0]
  114. self.assertEqual(request.method, b"GET")
  115. self.assertEqual(request.path, b"/" + path)
  116. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [hostname])
  117. request.write(b"result")
  118. request.finish()
  119. self.reactor.advance(0)
  120. resp = self.successResultOf(d)
  121. body = self.successResultOf(treq.content(resp))
  122. self.assertEqual(body, b"result")
  123. def test_http_request(self):
  124. agent = ProxyAgent(self.reactor)
  125. self._test_request_direct_connection(agent, b"http", b"test.com", b"")
  126. def test_https_request(self):
  127. agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
  128. self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
  129. def test_http_request_use_proxy_empty_environment(self):
  130. agent = ProxyAgent(self.reactor, use_proxy=True)
  131. self._test_request_direct_connection(agent, b"http", b"test.com", b"")
  132. @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
  133. def test_http_request_via_uppercase_no_proxy(self):
  134. agent = ProxyAgent(self.reactor, use_proxy=True)
  135. self._test_request_direct_connection(agent, b"http", b"test.com", b"")
  136. @patch.dict(
  137. os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
  138. )
  139. def test_http_request_via_no_proxy(self):
  140. agent = ProxyAgent(self.reactor, use_proxy=True)
  141. self._test_request_direct_connection(agent, b"http", b"test.com", b"")
  142. @patch.dict(
  143. os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
  144. )
  145. def test_https_request_via_no_proxy(self):
  146. agent = ProxyAgent(
  147. self.reactor,
  148. contextFactory=get_test_https_policy(),
  149. use_proxy=True,
  150. )
  151. self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
  152. @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
  153. def test_http_request_via_no_proxy_star(self):
  154. agent = ProxyAgent(self.reactor, use_proxy=True)
  155. self._test_request_direct_connection(agent, b"http", b"test.com", b"")
  156. @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
  157. def test_https_request_via_no_proxy_star(self):
  158. agent = ProxyAgent(
  159. self.reactor,
  160. contextFactory=get_test_https_policy(),
  161. use_proxy=True,
  162. )
  163. self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
  164. @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
  165. def test_http_request_via_proxy(self):
  166. agent = ProxyAgent(self.reactor, use_proxy=True)
  167. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  168. d = agent.request(b"GET", b"http://test.com")
  169. # there should be a pending TCP connection
  170. clients = self.reactor.tcpClients
  171. self.assertEqual(len(clients), 1)
  172. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  173. self.assertEqual(host, "1.2.3.5")
  174. self.assertEqual(port, 8888)
  175. # make a test server, and wire up the client
  176. http_server = self._make_connection(
  177. client_factory, _get_test_protocol_factory()
  178. )
  179. # the FakeTransport is async, so we need to pump the reactor
  180. self.reactor.advance(0)
  181. # now there should be a pending request
  182. self.assertEqual(len(http_server.requests), 1)
  183. request = http_server.requests[0]
  184. self.assertEqual(request.method, b"GET")
  185. self.assertEqual(request.path, b"http://test.com")
  186. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  187. request.write(b"result")
  188. request.finish()
  189. self.reactor.advance(0)
  190. resp = self.successResultOf(d)
  191. body = self.successResultOf(treq.content(resp))
  192. self.assertEqual(body, b"result")
  193. @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
  194. def test_https_request_via_proxy(self):
  195. """Tests that TLS-encrypted requests can be made through a proxy"""
  196. self._do_https_request_via_proxy(auth_credentials=None)
  197. @patch.dict(
  198. os.environ,
  199. {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
  200. )
  201. def test_https_request_via_proxy_with_auth(self):
  202. """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
  203. self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
  204. def _do_https_request_via_proxy(
  205. self,
  206. auth_credentials: Optional[str] = None,
  207. ):
  208. agent = ProxyAgent(
  209. self.reactor,
  210. contextFactory=get_test_https_policy(),
  211. use_proxy=True,
  212. )
  213. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  214. d = agent.request(b"GET", b"https://test.com/abc")
  215. # there should be a pending TCP connection
  216. clients = self.reactor.tcpClients
  217. self.assertEqual(len(clients), 1)
  218. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  219. self.assertEqual(host, "1.2.3.5")
  220. self.assertEqual(port, 1080)
  221. # make a test HTTP server, and wire up the client
  222. proxy_server = self._make_connection(
  223. client_factory, _get_test_protocol_factory()
  224. )
  225. # fish the transports back out so that we can do the old switcheroo
  226. s2c_transport = proxy_server.transport
  227. client_protocol = s2c_transport.other
  228. c2s_transport = client_protocol.transport
  229. # the FakeTransport is async, so we need to pump the reactor
  230. self.reactor.advance(0)
  231. # now there should be a pending CONNECT request
  232. self.assertEqual(len(proxy_server.requests), 1)
  233. request = proxy_server.requests[0]
  234. self.assertEqual(request.method, b"CONNECT")
  235. self.assertEqual(request.path, b"test.com:443")
  236. # Check whether auth credentials have been supplied to the proxy
  237. proxy_auth_header_values = request.requestHeaders.getRawHeaders(
  238. b"Proxy-Authorization"
  239. )
  240. if auth_credentials is not None:
  241. # Compute the correct header value for Proxy-Authorization
  242. encoded_credentials = base64.b64encode(b"bob:pinkponies")
  243. expected_header_value = b"Basic " + encoded_credentials
  244. # Validate the header's value
  245. self.assertIn(expected_header_value, proxy_auth_header_values)
  246. else:
  247. # Check that the Proxy-Authorization header has not been supplied to the proxy
  248. self.assertIsNone(proxy_auth_header_values)
  249. # tell the proxy server not to close the connection
  250. proxy_server.persistent = True
  251. # this just stops the http Request trying to do a chunked response
  252. # request.setHeader(b"Content-Length", b"0")
  253. request.finish()
  254. # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
  255. ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
  256. ssl_protocol = ssl_factory.buildProtocol(None)
  257. http_server = ssl_protocol.wrappedProtocol
  258. ssl_protocol.makeConnection(
  259. FakeTransport(client_protocol, self.reactor, ssl_protocol)
  260. )
  261. c2s_transport.other = ssl_protocol
  262. self.reactor.advance(0)
  263. server_name = ssl_protocol._tlsConnection.get_servername()
  264. expected_sni = b"test.com"
  265. self.assertEqual(
  266. server_name,
  267. expected_sni,
  268. "Expected SNI %s but got %s" % (expected_sni, server_name),
  269. )
  270. # now there should be a pending request
  271. self.assertEqual(len(http_server.requests), 1)
  272. request = http_server.requests[0]
  273. self.assertEqual(request.method, b"GET")
  274. self.assertEqual(request.path, b"/abc")
  275. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  276. # Check that the destination server DID NOT receive proxy credentials
  277. proxy_auth_header_values = request.requestHeaders.getRawHeaders(
  278. b"Proxy-Authorization"
  279. )
  280. self.assertIsNone(proxy_auth_header_values)
  281. request.write(b"result")
  282. request.finish()
  283. self.reactor.advance(0)
  284. resp = self.successResultOf(d)
  285. body = self.successResultOf(treq.content(resp))
  286. self.assertEqual(body, b"result")
  287. @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
  288. def test_http_request_via_proxy_with_blacklist(self):
  289. # The blacklist includes the configured proxy IP.
  290. agent = ProxyAgent(
  291. BlacklistingReactorWrapper(
  292. self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
  293. ),
  294. self.reactor,
  295. use_proxy=True,
  296. )
  297. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  298. d = agent.request(b"GET", b"http://test.com")
  299. # there should be a pending TCP connection
  300. clients = self.reactor.tcpClients
  301. self.assertEqual(len(clients), 1)
  302. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  303. self.assertEqual(host, "1.2.3.5")
  304. self.assertEqual(port, 8888)
  305. # make a test server, and wire up the client
  306. http_server = self._make_connection(
  307. client_factory, _get_test_protocol_factory()
  308. )
  309. # the FakeTransport is async, so we need to pump the reactor
  310. self.reactor.advance(0)
  311. # now there should be a pending request
  312. self.assertEqual(len(http_server.requests), 1)
  313. request = http_server.requests[0]
  314. self.assertEqual(request.method, b"GET")
  315. self.assertEqual(request.path, b"http://test.com")
  316. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  317. request.write(b"result")
  318. request.finish()
  319. self.reactor.advance(0)
  320. resp = self.successResultOf(d)
  321. body = self.successResultOf(treq.content(resp))
  322. self.assertEqual(body, b"result")
  323. @patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
  324. def test_https_request_via_uppercase_proxy_with_blacklist(self):
  325. # The blacklist includes the configured proxy IP.
  326. agent = ProxyAgent(
  327. BlacklistingReactorWrapper(
  328. self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
  329. ),
  330. self.reactor,
  331. contextFactory=get_test_https_policy(),
  332. use_proxy=True,
  333. )
  334. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  335. d = agent.request(b"GET", b"https://test.com/abc")
  336. # there should be a pending TCP connection
  337. clients = self.reactor.tcpClients
  338. self.assertEqual(len(clients), 1)
  339. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  340. self.assertEqual(host, "1.2.3.5")
  341. self.assertEqual(port, 1080)
  342. # make a test HTTP server, and wire up the client
  343. proxy_server = self._make_connection(
  344. client_factory, _get_test_protocol_factory()
  345. )
  346. # fish the transports back out so that we can do the old switcheroo
  347. s2c_transport = proxy_server.transport
  348. client_protocol = s2c_transport.other
  349. c2s_transport = client_protocol.transport
  350. # the FakeTransport is async, so we need to pump the reactor
  351. self.reactor.advance(0)
  352. # now there should be a pending CONNECT request
  353. self.assertEqual(len(proxy_server.requests), 1)
  354. request = proxy_server.requests[0]
  355. self.assertEqual(request.method, b"CONNECT")
  356. self.assertEqual(request.path, b"test.com:443")
  357. # tell the proxy server not to close the connection
  358. proxy_server.persistent = True
  359. # this just stops the http Request trying to do a chunked response
  360. # request.setHeader(b"Content-Length", b"0")
  361. request.finish()
  362. # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
  363. ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
  364. ssl_protocol = ssl_factory.buildProtocol(None)
  365. http_server = ssl_protocol.wrappedProtocol
  366. ssl_protocol.makeConnection(
  367. FakeTransport(client_protocol, self.reactor, ssl_protocol)
  368. )
  369. c2s_transport.other = ssl_protocol
  370. self.reactor.advance(0)
  371. server_name = ssl_protocol._tlsConnection.get_servername()
  372. expected_sni = b"test.com"
  373. self.assertEqual(
  374. server_name,
  375. expected_sni,
  376. "Expected SNI %s but got %s" % (expected_sni, server_name),
  377. )
  378. # now there should be a pending request
  379. self.assertEqual(len(http_server.requests), 1)
  380. request = http_server.requests[0]
  381. self.assertEqual(request.method, b"GET")
  382. self.assertEqual(request.path, b"/abc")
  383. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  384. request.write(b"result")
  385. request.finish()
  386. self.reactor.advance(0)
  387. resp = self.successResultOf(d)
  388. body = self.successResultOf(treq.content(resp))
  389. self.assertEqual(body, b"result")
  390. def _wrap_server_factory_for_tls(factory, sanlist=None):
  391. """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
  392. The resultant factory will create a TLS server which presents a certificate
  393. signed by our test CA, valid for the domains in `sanlist`
  394. Args:
  395. factory (interfaces.IProtocolFactory): protocol factory to wrap
  396. sanlist (iterable[bytes]): list of domains the cert should be valid for
  397. Returns:
  398. interfaces.IProtocolFactory
  399. """
  400. if sanlist is None:
  401. sanlist = [b"DNS:test.com"]
  402. connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
  403. return TLSMemoryBIOFactory(
  404. connection_creator, isClient=False, wrappedFactory=factory
  405. )
  406. def _get_test_protocol_factory():
  407. """Get a protocol Factory which will build an HTTPChannel
  408. Returns:
  409. interfaces.IProtocolFactory
  410. """
  411. server_factory = Factory.forProtocol(HTTPChannel)
  412. # Request.finish expects the factory to have a 'log' method.
  413. server_factory.log = _log_request
  414. return server_factory
  415. def _log_request(request):
  416. """Implements Factory.log, which is expected by Request.finish"""
  417. logger.info("Completed request %s", request)