test_proxyagent.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. # Copyright 2019 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import base64
  15. import logging
  16. import os
  17. from typing import Optional
  18. from unittest.mock import patch
  19. import treq
  20. from netaddr import IPSet
  21. from twisted.internet import interfaces # noqa: F401
  22. from twisted.internet.protocol import Factory
  23. from twisted.protocols.tls import TLSMemoryBIOFactory
  24. from twisted.web.http import HTTPChannel
  25. from synapse.http.client import BlacklistingReactorWrapper
  26. from synapse.http.proxyagent import ProxyAgent
  27. from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
  28. from tests.server import FakeTransport, ThreadedMemoryReactorClock
  29. from tests.unittest import TestCase
  30. logger = logging.getLogger(__name__)
  31. HTTPFactory = Factory.forProtocol(HTTPChannel)
  32. class MatrixFederationAgentTests(TestCase):
  33. def setUp(self):
  34. self.reactor = ThreadedMemoryReactorClock()
  35. def _make_connection(
  36. self, client_factory, server_factory, ssl=False, expected_sni=None
  37. ):
  38. """Builds a test server, and completes the outgoing client connection
  39. Args:
  40. client_factory (interfaces.IProtocolFactory): the the factory that the
  41. application is trying to use to make the outbound connection. We will
  42. invoke it to build the client Protocol
  43. server_factory (interfaces.IProtocolFactory): a factory to build the
  44. server-side protocol
  45. ssl (bool): If true, we will expect an ssl connection and wrap
  46. server_factory with a TLSMemoryBIOFactory
  47. expected_sni (bytes|None): the expected SNI value
  48. Returns:
  49. IProtocol: the server Protocol returned by server_factory
  50. """
  51. if ssl:
  52. server_factory = _wrap_server_factory_for_tls(server_factory)
  53. server_protocol = server_factory.buildProtocol(None)
  54. # now, tell the client protocol factory to build the client protocol,
  55. # and wire the output of said protocol up to the server via
  56. # a FakeTransport.
  57. #
  58. # Normally this would be done by the TCP socket code in Twisted, but we are
  59. # stubbing that out here.
  60. client_protocol = client_factory.buildProtocol(None)
  61. client_protocol.makeConnection(
  62. FakeTransport(server_protocol, self.reactor, client_protocol)
  63. )
  64. # tell the server protocol to send its stuff back to the client, too
  65. server_protocol.makeConnection(
  66. FakeTransport(client_protocol, self.reactor, server_protocol)
  67. )
  68. if ssl:
  69. http_protocol = server_protocol.wrappedProtocol
  70. tls_connection = server_protocol._tlsConnection
  71. else:
  72. http_protocol = server_protocol
  73. tls_connection = None
  74. # give the reactor a pump to get the TLS juices flowing (if needed)
  75. self.reactor.advance(0)
  76. if expected_sni is not None:
  77. server_name = tls_connection.get_servername()
  78. self.assertEqual(
  79. server_name,
  80. expected_sni,
  81. "Expected SNI %s but got %s" % (expected_sni, server_name),
  82. )
  83. return http_protocol
  84. def _test_request_direct_connection(self, agent, scheme, hostname, path):
  85. """Runs a test case for a direct connection not going through a proxy.
  86. Args:
  87. agent (ProxyAgent): the proxy agent being tested
  88. scheme (bytes): expected to be either "http" or "https"
  89. hostname (bytes): the hostname to connect to in the test
  90. path (bytes): the path to connect to in the test
  91. """
  92. is_https = scheme == b"https"
  93. self.reactor.lookups[hostname.decode()] = "1.2.3.4"
  94. d = agent.request(b"GET", scheme + b"://" + hostname + b"/" + path)
  95. # there should be a pending TCP connection
  96. clients = self.reactor.tcpClients
  97. self.assertEqual(len(clients), 1)
  98. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  99. self.assertEqual(host, "1.2.3.4")
  100. self.assertEqual(port, 443 if is_https else 80)
  101. # make a test server, and wire up the client
  102. http_server = self._make_connection(
  103. client_factory,
  104. _get_test_protocol_factory(),
  105. ssl=is_https,
  106. expected_sni=hostname if is_https else None,
  107. )
  108. # the FakeTransport is async, so we need to pump the reactor
  109. self.reactor.advance(0)
  110. # now there should be a pending request
  111. self.assertEqual(len(http_server.requests), 1)
  112. request = http_server.requests[0]
  113. self.assertEqual(request.method, b"GET")
  114. self.assertEqual(request.path, b"/" + path)
  115. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [hostname])
  116. request.write(b"result")
  117. request.finish()
  118. self.reactor.advance(0)
  119. resp = self.successResultOf(d)
  120. body = self.successResultOf(treq.content(resp))
  121. self.assertEqual(body, b"result")
  122. def test_http_request(self):
  123. agent = ProxyAgent(self.reactor)
  124. self._test_request_direct_connection(agent, b"http", b"test.com", b"")
  125. def test_https_request(self):
  126. agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
  127. self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
  128. def test_http_request_use_proxy_empty_environment(self):
  129. agent = ProxyAgent(self.reactor, use_proxy=True)
  130. self._test_request_direct_connection(agent, b"http", b"test.com", b"")
  131. @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
  132. def test_http_request_via_uppercase_no_proxy(self):
  133. agent = ProxyAgent(self.reactor, use_proxy=True)
  134. self._test_request_direct_connection(agent, b"http", b"test.com", b"")
  135. @patch.dict(
  136. os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
  137. )
  138. def test_http_request_via_no_proxy(self):
  139. agent = ProxyAgent(self.reactor, use_proxy=True)
  140. self._test_request_direct_connection(agent, b"http", b"test.com", b"")
  141. @patch.dict(
  142. os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
  143. )
  144. def test_https_request_via_no_proxy(self):
  145. agent = ProxyAgent(
  146. self.reactor,
  147. contextFactory=get_test_https_policy(),
  148. use_proxy=True,
  149. )
  150. self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
  151. @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
  152. def test_http_request_via_no_proxy_star(self):
  153. agent = ProxyAgent(self.reactor, use_proxy=True)
  154. self._test_request_direct_connection(agent, b"http", b"test.com", b"")
  155. @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
  156. def test_https_request_via_no_proxy_star(self):
  157. agent = ProxyAgent(
  158. self.reactor,
  159. contextFactory=get_test_https_policy(),
  160. use_proxy=True,
  161. )
  162. self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
  163. @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
  164. def test_http_request_via_proxy(self):
  165. agent = ProxyAgent(self.reactor, use_proxy=True)
  166. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  167. d = agent.request(b"GET", b"http://test.com")
  168. # there should be a pending TCP connection
  169. clients = self.reactor.tcpClients
  170. self.assertEqual(len(clients), 1)
  171. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  172. self.assertEqual(host, "1.2.3.5")
  173. self.assertEqual(port, 8888)
  174. # make a test server, and wire up the client
  175. http_server = self._make_connection(
  176. client_factory, _get_test_protocol_factory()
  177. )
  178. # the FakeTransport is async, so we need to pump the reactor
  179. self.reactor.advance(0)
  180. # now there should be a pending request
  181. self.assertEqual(len(http_server.requests), 1)
  182. request = http_server.requests[0]
  183. self.assertEqual(request.method, b"GET")
  184. self.assertEqual(request.path, b"http://test.com")
  185. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  186. request.write(b"result")
  187. request.finish()
  188. self.reactor.advance(0)
  189. resp = self.successResultOf(d)
  190. body = self.successResultOf(treq.content(resp))
  191. self.assertEqual(body, b"result")
  192. @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
  193. def test_https_request_via_proxy(self):
  194. """Tests that TLS-encrypted requests can be made through a proxy"""
  195. self._do_https_request_via_proxy(auth_credentials=None)
  196. @patch.dict(
  197. os.environ,
  198. {"https_proxy": "bob:pinkponies@proxy.com", "no_proxy": "unused.com"},
  199. )
  200. def test_https_request_via_proxy_with_auth(self):
  201. """Tests that authenticated, TLS-encrypted requests can be made through a proxy"""
  202. self._do_https_request_via_proxy(auth_credentials="bob:pinkponies")
  203. def _do_https_request_via_proxy(
  204. self,
  205. auth_credentials: Optional[str] = None,
  206. ):
  207. agent = ProxyAgent(
  208. self.reactor,
  209. contextFactory=get_test_https_policy(),
  210. use_proxy=True,
  211. )
  212. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  213. d = agent.request(b"GET", b"https://test.com/abc")
  214. # there should be a pending TCP connection
  215. clients = self.reactor.tcpClients
  216. self.assertEqual(len(clients), 1)
  217. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  218. self.assertEqual(host, "1.2.3.5")
  219. self.assertEqual(port, 1080)
  220. # make a test HTTP server, and wire up the client
  221. proxy_server = self._make_connection(
  222. client_factory, _get_test_protocol_factory()
  223. )
  224. # fish the transports back out so that we can do the old switcheroo
  225. s2c_transport = proxy_server.transport
  226. client_protocol = s2c_transport.other
  227. c2s_transport = client_protocol.transport
  228. # the FakeTransport is async, so we need to pump the reactor
  229. self.reactor.advance(0)
  230. # now there should be a pending CONNECT request
  231. self.assertEqual(len(proxy_server.requests), 1)
  232. request = proxy_server.requests[0]
  233. self.assertEqual(request.method, b"CONNECT")
  234. self.assertEqual(request.path, b"test.com:443")
  235. # Check whether auth credentials have been supplied to the proxy
  236. proxy_auth_header_values = request.requestHeaders.getRawHeaders(
  237. b"Proxy-Authorization"
  238. )
  239. if auth_credentials is not None:
  240. # Compute the correct header value for Proxy-Authorization
  241. encoded_credentials = base64.b64encode(b"bob:pinkponies")
  242. expected_header_value = b"Basic " + encoded_credentials
  243. # Validate the header's value
  244. self.assertIn(expected_header_value, proxy_auth_header_values)
  245. else:
  246. # Check that the Proxy-Authorization header has not been supplied to the proxy
  247. self.assertIsNone(proxy_auth_header_values)
  248. # tell the proxy server not to close the connection
  249. proxy_server.persistent = True
  250. # this just stops the http Request trying to do a chunked response
  251. # request.setHeader(b"Content-Length", b"0")
  252. request.finish()
  253. # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
  254. ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
  255. ssl_protocol = ssl_factory.buildProtocol(None)
  256. http_server = ssl_protocol.wrappedProtocol
  257. ssl_protocol.makeConnection(
  258. FakeTransport(client_protocol, self.reactor, ssl_protocol)
  259. )
  260. c2s_transport.other = ssl_protocol
  261. self.reactor.advance(0)
  262. server_name = ssl_protocol._tlsConnection.get_servername()
  263. expected_sni = b"test.com"
  264. self.assertEqual(
  265. server_name,
  266. expected_sni,
  267. "Expected SNI %s but got %s" % (expected_sni, server_name),
  268. )
  269. # now there should be a pending request
  270. self.assertEqual(len(http_server.requests), 1)
  271. request = http_server.requests[0]
  272. self.assertEqual(request.method, b"GET")
  273. self.assertEqual(request.path, b"/abc")
  274. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  275. # Check that the destination server DID NOT receive proxy credentials
  276. proxy_auth_header_values = request.requestHeaders.getRawHeaders(
  277. b"Proxy-Authorization"
  278. )
  279. self.assertIsNone(proxy_auth_header_values)
  280. request.write(b"result")
  281. request.finish()
  282. self.reactor.advance(0)
  283. resp = self.successResultOf(d)
  284. body = self.successResultOf(treq.content(resp))
  285. self.assertEqual(body, b"result")
  286. @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
  287. def test_http_request_via_proxy_with_blacklist(self):
  288. # The blacklist includes the configured proxy IP.
  289. agent = ProxyAgent(
  290. BlacklistingReactorWrapper(
  291. self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
  292. ),
  293. self.reactor,
  294. use_proxy=True,
  295. )
  296. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  297. d = agent.request(b"GET", b"http://test.com")
  298. # there should be a pending TCP connection
  299. clients = self.reactor.tcpClients
  300. self.assertEqual(len(clients), 1)
  301. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  302. self.assertEqual(host, "1.2.3.5")
  303. self.assertEqual(port, 8888)
  304. # make a test server, and wire up the client
  305. http_server = self._make_connection(
  306. client_factory, _get_test_protocol_factory()
  307. )
  308. # the FakeTransport is async, so we need to pump the reactor
  309. self.reactor.advance(0)
  310. # now there should be a pending request
  311. self.assertEqual(len(http_server.requests), 1)
  312. request = http_server.requests[0]
  313. self.assertEqual(request.method, b"GET")
  314. self.assertEqual(request.path, b"http://test.com")
  315. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  316. request.write(b"result")
  317. request.finish()
  318. self.reactor.advance(0)
  319. resp = self.successResultOf(d)
  320. body = self.successResultOf(treq.content(resp))
  321. self.assertEqual(body, b"result")
  322. @patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
  323. def test_https_request_via_uppercase_proxy_with_blacklist(self):
  324. # The blacklist includes the configured proxy IP.
  325. agent = ProxyAgent(
  326. BlacklistingReactorWrapper(
  327. self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
  328. ),
  329. self.reactor,
  330. contextFactory=get_test_https_policy(),
  331. use_proxy=True,
  332. )
  333. self.reactor.lookups["proxy.com"] = "1.2.3.5"
  334. d = agent.request(b"GET", b"https://test.com/abc")
  335. # there should be a pending TCP connection
  336. clients = self.reactor.tcpClients
  337. self.assertEqual(len(clients), 1)
  338. (host, port, client_factory, _timeout, _bindAddress) = clients[0]
  339. self.assertEqual(host, "1.2.3.5")
  340. self.assertEqual(port, 1080)
  341. # make a test HTTP server, and wire up the client
  342. proxy_server = self._make_connection(
  343. client_factory, _get_test_protocol_factory()
  344. )
  345. # fish the transports back out so that we can do the old switcheroo
  346. s2c_transport = proxy_server.transport
  347. client_protocol = s2c_transport.other
  348. c2s_transport = client_protocol.transport
  349. # the FakeTransport is async, so we need to pump the reactor
  350. self.reactor.advance(0)
  351. # now there should be a pending CONNECT request
  352. self.assertEqual(len(proxy_server.requests), 1)
  353. request = proxy_server.requests[0]
  354. self.assertEqual(request.method, b"CONNECT")
  355. self.assertEqual(request.path, b"test.com:443")
  356. # tell the proxy server not to close the connection
  357. proxy_server.persistent = True
  358. # this just stops the http Request trying to do a chunked response
  359. # request.setHeader(b"Content-Length", b"0")
  360. request.finish()
  361. # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
  362. ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
  363. ssl_protocol = ssl_factory.buildProtocol(None)
  364. http_server = ssl_protocol.wrappedProtocol
  365. ssl_protocol.makeConnection(
  366. FakeTransport(client_protocol, self.reactor, ssl_protocol)
  367. )
  368. c2s_transport.other = ssl_protocol
  369. self.reactor.advance(0)
  370. server_name = ssl_protocol._tlsConnection.get_servername()
  371. expected_sni = b"test.com"
  372. self.assertEqual(
  373. server_name,
  374. expected_sni,
  375. "Expected SNI %s but got %s" % (expected_sni, server_name),
  376. )
  377. # now there should be a pending request
  378. self.assertEqual(len(http_server.requests), 1)
  379. request = http_server.requests[0]
  380. self.assertEqual(request.method, b"GET")
  381. self.assertEqual(request.path, b"/abc")
  382. self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
  383. request.write(b"result")
  384. request.finish()
  385. self.reactor.advance(0)
  386. resp = self.successResultOf(d)
  387. body = self.successResultOf(treq.content(resp))
  388. self.assertEqual(body, b"result")
  389. def _wrap_server_factory_for_tls(factory, sanlist=None):
  390. """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
  391. The resultant factory will create a TLS server which presents a certificate
  392. signed by our test CA, valid for the domains in `sanlist`
  393. Args:
  394. factory (interfaces.IProtocolFactory): protocol factory to wrap
  395. sanlist (iterable[bytes]): list of domains the cert should be valid for
  396. Returns:
  397. interfaces.IProtocolFactory
  398. """
  399. if sanlist is None:
  400. sanlist = [b"DNS:test.com"]
  401. connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
  402. return TLSMemoryBIOFactory(
  403. connection_creator, isClient=False, wrappedFactory=factory
  404. )
  405. def _get_test_protocol_factory():
  406. """Get a protocol Factory which will build an HTTPChannel
  407. Returns:
  408. interfaces.IProtocolFactory
  409. """
  410. server_factory = Factory.forProtocol(HTTPChannel)
  411. # Request.finish expects the factory to have a 'log' method.
  412. server_factory.log = _log_request
  413. return server_factory
  414. def _log_request(request):
  415. """Implements Factory.log, which is expected by Request.finish"""
  416. logger.info("Completed request %s", request)