1
0

test_matrixfederationclient.py 21 KB


  1. # Copyright 2018 New Vector Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from unittest.mock import Mock
  15. from netaddr import IPSet
  16. from parameterized import parameterized
  17. from twisted.internet import defer
  18. from twisted.internet.defer import TimeoutError
  19. from twisted.internet.error import ConnectingCancelledError, DNSLookupError
  20. from twisted.test.proto_helpers import StringTransport
  21. from twisted.web.client import ResponseNeverReceived
  22. from twisted.web.http import HTTPChannel
  23. from synapse.api.errors import RequestSendFailed
  24. from synapse.http.matrixfederationclient import (
  25. JsonParser,
  26. MatrixFederationHttpClient,
  27. MatrixFederationRequest,
  28. )
  29. from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_context
  30. from tests.server import FakeTransport
  31. from tests.unittest import HomeserverTestCase
  def check_logcontext(context):
      """Assert that the currently active logging context is `context`.

      The comparison is by identity, not equality.

      Raises:
          AssertionError: if the current logcontext is a different object.
      """
      active = current_context()
      if active is context:
          return
      raise AssertionError("Expected logcontext %s but was %s" % (context, active))
  class FederationClientTests(HomeserverTestCase):
      """Tests for MatrixFederationHttpClient, driven through the test reactor.

      Each test issues a request via ``self.cl``, then completes (or fails) the
      outgoing TCP connection recorded in ``self.reactor.tcpClients`` by hand
      and feeds raw HTTP bytes to the client protocol.
      """

      def make_homeserver(self, reactor, clock):
          hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
          return hs

      def prepare(self, reactor, clock, homeserver):
          self.cl = MatrixFederationHttpClient(self.hs, None)
          # "testserv" is the only name the tests expect to resolve by default;
          # the tests below assert that connections go to this address.
          self.reactor.lookups["testserv"] = "1.2.3.4"

      def test_client_get(self):
          """
          happy-path test of a GET request
          """

          @defer.inlineCallbacks
          def do_request():
              with LoggingContext("one") as context:
                  fetch_d = defer.ensureDeferred(
                      self.cl.get_json("testserv:8008", "foo/bar")
                  )

                  # Nothing happened yet
                  self.assertNoResult(fetch_d)

                  # should have reset logcontext to the sentinel
                  check_logcontext(SENTINEL_CONTEXT)

                  try:
                      fetch_res = yield fetch_d
                      return fetch_res
                  finally:
                      # whether the request succeeded or failed, we must be
                      # back in the logcontext we started in
                      check_logcontext(context)

          test_d = do_request()

          self.pump()

          # Nothing happened yet
          self.assertNoResult(test_d)

          # Make sure treq is trying to connect
          clients = self.reactor.tcpClients
          self.assertEqual(len(clients), 1)
          (host, port, factory, _timeout, _bindAddress) = clients[0]
          self.assertEqual(host, "1.2.3.4")
          self.assertEqual(port, 8008)

          # complete the connection and wire it up to a fake transport
          protocol = factory.buildProtocol(None)
          transport = StringTransport()
          protocol.makeConnection(transport)

          # that should have made it send the request to the transport
          self.assertRegex(transport.value(), b"^GET /foo/bar")
          self.assertRegex(transport.value(), b"Host: testserv:8008")

          # Deferred is still without a result
          self.assertNoResult(test_d)

          # Send it the HTTP response
          res_json = b'{ "a": 1 }'
          protocol.dataReceived(
              b"HTTP/1.1 200 OK\r\n"
              b"Server: Fake\r\n"
              b"Content-Type: application/json\r\n"
              b"Content-Length: %i\r\n"
              b"\r\n"
              b"%s" % (len(res_json), res_json)
          )

          self.pump()

          res = self.successResultOf(test_d)

          # check the response is as expected
          self.assertEqual(res, {"a": 1})

      def test_dns_error(self):
          """
          If the DNS lookup returns an error, it will bubble up.
          """
          # "testserv2" has no entry in self.reactor.lookups
          d = defer.ensureDeferred(
              self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000)
          )
          self.pump()

          f = self.failureResultOf(d)
          self.assertIsInstance(f.value, RequestSendFailed)
          self.assertIsInstance(f.value.inner_exception, DNSLookupError)

      def test_client_connection_refused(self):
          """
          If the TCP connection fails, the underlying exception is wrapped in a
          RequestSendFailed.
          """
          d = defer.ensureDeferred(
              self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
          )

          self.pump()

          # Nothing happened yet
          self.assertNoResult(d)

          clients = self.reactor.tcpClients
          self.assertEqual(len(clients), 1)
          (host, port, factory, _timeout, _bindAddress) = clients[0]
          self.assertEqual(host, "1.2.3.4")
          self.assertEqual(port, 8008)

          # Fail the connection attempt with a recognisable exception
          e = Exception("go away")
          factory.clientConnectionFailed(None, e)
          self.pump(0.5)

          f = self.failureResultOf(d)

          self.assertIsInstance(f.value, RequestSendFailed)
          # The very same exception object must be exposed to the caller
          self.assertIs(f.value.inner_exception, e)

      def test_client_never_connect(self):
          """
          If the HTTP request is not connected and is timed out, it'll give a
          ConnectingCancelledError or TimeoutError.
          """
          d = defer.ensureDeferred(
              self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
          )

          self.pump()

          # Nothing happened yet
          self.assertNoResult(d)

          # Make sure treq is trying to connect
          clients = self.reactor.tcpClients
          self.assertEqual(len(clients), 1)
          self.assertEqual(clients[0][0], "1.2.3.4")
          self.assertEqual(clients[0][1], 8008)

          # Deferred is still without a result
          self.assertNoResult(d)

          # Push by enough to time it out
          self.reactor.advance(10.5)
          f = self.failureResultOf(d)

          self.assertIsInstance(f.value, RequestSendFailed)
          # Spelled `defer.TimeoutError` (Twisted's class) to match
          # test_timeout_reading_body and to avoid confusion with the builtin.
          self.assertIsInstance(
              f.value.inner_exception, (ConnectingCancelledError, defer.TimeoutError)
          )

      def test_client_connect_no_response(self):
          """
          If the HTTP request is connected, but gets no response before being
          timed out, it'll give a ResponseNeverReceived.
          """
          d = defer.ensureDeferred(
              self.cl.get_json("testserv:8008", "foo/bar", timeout=10000)
          )

          self.pump()

          # Nothing happened yet
          self.assertNoResult(d)

          # Make sure treq is trying to connect
          clients = self.reactor.tcpClients
          self.assertEqual(len(clients), 1)
          self.assertEqual(clients[0][0], "1.2.3.4")
          self.assertEqual(clients[0][1], 8008)

          # Complete the connection, but never send anything back
          conn = Mock()
          client = clients[0][2].buildProtocol(None)
          client.makeConnection(conn)

          # Deferred is still without a result
          self.assertNoResult(d)

          # Push by enough to time it out
          self.reactor.advance(10.5)
          f = self.failureResultOf(d)

          self.assertIsInstance(f.value, RequestSendFailed)
          self.assertIsInstance(f.value.inner_exception, ResponseNeverReceived)

      def test_client_ip_range_blacklist(self):
          """Ensure that Synapse does not try to connect to blacklisted IPs"""

          # Set up the ip_range blacklist
          self.hs.config.server.federation_ip_range_blacklist = IPSet(
              ["127.0.0.0/8", "fe80::/64"]
          )
          self.reactor.lookups["internal"] = "127.0.0.1"
          self.reactor.lookups["internalv6"] = "fe80:0:0:0:0:8a2e:370:7337"
          self.reactor.lookups["fine"] = "10.20.30.40"

          # Build a fresh client after the blacklist has been configured
          cl = MatrixFederationHttpClient(self.hs, None)

          # Try making a GET request to a blacklisted IPv4 address
          # ------------------------------------------------------
          # Make the request
          d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000))

          # Nothing happened yet
          self.assertNoResult(d)

          self.pump(1)

          # Check that it was unable to resolve the address
          clients = self.reactor.tcpClients
          self.assertEqual(len(clients), 0)

          f = self.failureResultOf(d)
          self.assertIsInstance(f.value, RequestSendFailed)
          self.assertIsInstance(f.value.inner_exception, DNSLookupError)

          # Try making a POST request to a blacklisted IPv6 address
          # -------------------------------------------------------
          # Make the request
          d = defer.ensureDeferred(
              cl.post_json("internalv6:8008", "foo/bar", timeout=10000)
          )

          # Nothing has happened yet
          self.assertNoResult(d)

          # Move the reactor forwards
          self.pump(1)

          # Check that it was unable to resolve the address
          clients = self.reactor.tcpClients
          self.assertEqual(len(clients), 0)

          # Check that it was due to a blacklisted DNS lookup
          f = self.failureResultOf(d, RequestSendFailed)
          self.assertIsInstance(f.value.inner_exception, DNSLookupError)

          # Try making a POST request to a non-blacklisted IPv4 address
          # -----------------------------------------------------------
          # Make the request
          d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000))

          # Nothing has happened yet
          self.assertNoResult(d)

          # Move the reactor forwards
          self.pump(1)

          # Check that it was able to resolve the address
          clients = self.reactor.tcpClients
          self.assertNotEqual(len(clients), 0)

          # The request still fails, because this test never completes the
          # connection attempt
          f = self.failureResultOf(d, RequestSendFailed)
          self.assertIsInstance(f.value.inner_exception, ConnectingCancelledError)

      def test_client_gets_headers(self):
          """
          Once the client gets the headers, _request returns successfully.
          """
          request = MatrixFederationRequest(
              method="GET", destination="testserv:8008", path="foo/bar"
          )
          d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000))

          self.pump()

          conn = Mock()
          clients = self.reactor.tcpClients
          client = clients[0][2].buildProtocol(None)
          client.makeConnection(conn)

          # Deferred does not have a result
          self.assertNoResult(d)

          # Send it the HTTP response (headers only, no body)
          client.dataReceived(b"HTTP/1.1 200 OK\r\nServer: Fake\r\n\r\n")

          # We should get a successful response
          r = self.successResultOf(d)
          self.assertEqual(r.code, 200)

      @parameterized.expand(["get_json", "post_json", "delete_json", "put_json"])
      def test_timeout_reading_body(self, method_name: str):
          """
          If the HTTP request is connected and the response headers arrive, but
          the body is not received before the timeout, it'll give a
          RequestSendFailed with can_retry.
          """
          method = getattr(self.cl, method_name)
          d = defer.ensureDeferred(method("testserv:8008", "foo/bar", timeout=10000))

          self.pump()

          conn = Mock()
          clients = self.reactor.tcpClients
          client = clients[0][2].buildProtocol(None)
          client.makeConnection(conn)

          # Deferred does not have a result
          self.assertNoResult(d)

          # Send it the HTTP response headers, but no body
          client.dataReceived(
              b"HTTP/1.1 200 OK\r\nContent-Type: application/json\r\n"
              b"Server: Fake\r\n\r\n"
          )

          # Push by enough to time it out
          self.reactor.advance(10.5)
          f = self.failureResultOf(d)

          self.assertIsInstance(f.value, RequestSendFailed)
          self.assertTrue(f.value.can_retry)
          self.assertIsInstance(f.value.inner_exception, defer.TimeoutError)

      def test_client_requires_trailing_slashes(self):
          """
          If a connection is made to a remote server but the server rejects the
          request due to requiring a trailing slash, we need to retry the
          request with a trailing slash. Workaround for Synapse <= v0.99.3,
          explained in #3622.
          """
          d = defer.ensureDeferred(
              self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
          )

          # Send the request
          self.pump()

          # there should have been a call to connectTCP
          clients = self.reactor.tcpClients
          self.assertEqual(len(clients), 1)
          (_host, _port, factory, _timeout, _bindAddress) = clients[0]

          # complete the connection and wire it up to a fake transport
          client = factory.buildProtocol(None)
          conn = StringTransport()
          client.makeConnection(conn)

          # that should have made it send the request to the connection
          self.assertRegex(conn.value(), b"^GET /foo/bar")

          # Clear the original request data before sending a response
          conn.clear()

          # Send the HTTP response
          client.dataReceived(
              b"HTTP/1.1 400 Bad Request\r\n"
              b"Content-Type: application/json\r\n"
              b"Content-Length: 59\r\n"
              b"\r\n"
              b'{"errcode":"M_UNRECOGNIZED","error":"Unrecognized request"}'
          )

          # We should get another request with a trailing slash
          self.assertRegex(conn.value(), b"^GET /foo/bar/")

          # Send a happy response this time
          client.dataReceived(
              b"HTTP/1.1 200 OK\r\n"
              b"Content-Type: application/json\r\n"
              b"Content-Length: 2\r\n"
              b"\r\n"
              b"{}"
          )

          # We should get a successful response
          r = self.successResultOf(d)
          self.assertEqual(r, {})

      def test_client_does_not_retry_on_400_plus(self):
          """
          Another test for trailing slashes but now test that we don't retry on
          trailing slashes on a non-400/M_UNRECOGNIZED response.

          See test_client_requires_trailing_slashes() for context.
          """
          d = defer.ensureDeferred(
              self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True)
          )

          # Send the request
          self.pump()

          # there should have been a call to connectTCP
          clients = self.reactor.tcpClients
          self.assertEqual(len(clients), 1)
          (_host, _port, factory, _timeout, _bindAddress) = clients[0]

          # complete the connection and wire it up to a fake transport
          client = factory.buildProtocol(None)
          conn = StringTransport()
          client.makeConnection(conn)

          # that should have made it send the request to the connection
          self.assertRegex(conn.value(), b"^GET /foo/bar")

          # Clear the original request data before sending a response
          conn.clear()

          # Send the HTTP response
          client.dataReceived(
              b"HTTP/1.1 404 Not Found\r\n"
              b"Content-Type: application/json\r\n"
              b"Content-Length: 2\r\n"
              b"\r\n"
              b"{}"
          )

          # We should not get another request
          self.assertEqual(conn.value(), b"")

          # We should get a 404 failure response
          self.failureResultOf(d)

      def test_client_sends_body(self):
          """The JSON body passed to post_json is what the server receives."""
          defer.ensureDeferred(
              self.cl.post_json(
                  "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}
              )
          )

          self.pump()

          clients = self.reactor.tcpClients
          self.assertEqual(len(clients), 1)
          client = clients[0][2].buildProtocol(None)

          # Wire the client up to a real HTTPChannel so that the request is
          # parsed by an HTTP server implementation.
          server = HTTPChannel()

          client.makeConnection(FakeTransport(server, self.reactor))
          server.makeConnection(FakeTransport(client, self.reactor))

          self.pump(0.1)

          self.assertEqual(len(server.requests), 1)
          request = server.requests[0]
          content = request.content.read()
          self.assertEqual(content, b'{"a":"b"}')

      def test_closes_connection(self):
          """Check that the client closes unused HTTP connections"""
          d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))

          self.pump()

          # there should have been a call to connectTCP
          clients = self.reactor.tcpClients
          self.assertEqual(len(clients), 1)
          (_host, _port, factory, _timeout, _bindAddress) = clients[0]

          # complete the connection and wire it up to a fake transport
          client = factory.buildProtocol(None)
          conn = StringTransport()
          client.makeConnection(conn)

          # that should have made it send the request to the connection
          self.assertRegex(conn.value(), b"^GET /foo/bar")

          # Send the HTTP response
          client.dataReceived(
              b"HTTP/1.1 200 OK\r\n"
              b"Content-Type: application/json\r\n"
              b"Content-Length: 2\r\n"
              b"\r\n"
              b"{}"
          )

          # We should get a successful response
          r = self.successResultOf(d)
          self.assertEqual(r, {})

          # The connection is kept open immediately after the response...
          self.assertFalse(conn.disconnecting)

          # wait for a while
          self.reactor.advance(120)

          # ...but is closed once it has sat idle
          self.assertTrue(conn.disconnecting)

      @parameterized.expand([(b"",), (b"foo",), (b'{"a": Infinity}',)])
      def test_json_error(self, return_value):
          """
          Test what happens if invalid JSON is returned from the remote endpoint.
          """

          test_d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))

          self.pump()

          # Nothing happened yet
          self.assertNoResult(test_d)

          # Make sure treq is trying to connect
          clients = self.reactor.tcpClients
          self.assertEqual(len(clients), 1)
          (host, port, factory, _timeout, _bindAddress) = clients[0]
          self.assertEqual(host, "1.2.3.4")
          self.assertEqual(port, 8008)

          # complete the connection and wire it up to a fake transport
          protocol = factory.buildProtocol(None)
          transport = StringTransport()
          protocol.makeConnection(transport)

          # that should have made it send the request to the transport
          self.assertRegex(transport.value(), b"^GET /foo/bar")
          self.assertRegex(transport.value(), b"Host: testserv:8008")

          # Deferred is still without a result
          self.assertNoResult(test_d)

          # Send it the HTTP response
          protocol.dataReceived(
              b"HTTP/1.1 200 OK\r\n"
              b"Server: Fake\r\n"
              b"Content-Type: application/json\r\n"
              b"Content-Length: %i\r\n"
              b"\r\n"
              b"%s" % (len(return_value), return_value)
          )

          self.pump()

          f = self.failureResultOf(test_d)
          self.assertIsInstance(f.value, RequestSendFailed)

      def test_too_big(self):
          """
          Test what happens if a huge response is returned from the remote endpoint.
          """

          test_d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar"))

          self.pump()

          # Nothing happened yet
          self.assertNoResult(test_d)

          # Make sure treq is trying to connect
          clients = self.reactor.tcpClients
          self.assertEqual(len(clients), 1)
          (host, port, factory, _timeout, _bindAddress) = clients[0]
          self.assertEqual(host, "1.2.3.4")
          self.assertEqual(port, 8008)

          # complete the connection and wire it up to a fake transport
          protocol = factory.buildProtocol(None)
          transport = StringTransport()
          protocol.makeConnection(transport)

          # that should have made it send the request to the transport
          self.assertRegex(transport.value(), b"^GET /foo/bar")
          self.assertRegex(transport.value(), b"Host: testserv:8008")

          # Deferred is still without a result
          self.assertNoResult(test_d)

          # Send it a huge HTTP response
          # (no Content-Length header: the body is streamed in chunks below)
          protocol.dataReceived(
              b"HTTP/1.1 200 OK\r\n"
              b"Server: Fake\r\n"
              b"Content-Type: application/json\r\n"
              b"\r\n"
          )

          self.pump()

          # should still be waiting
          self.assertNoResult(test_d)

          # Keep feeding body data until the request gives up. The in-loop
          # assertion stops the test if the client reads past the limit.
          sent = 0
          chunk_size = 1024 * 512
          while not test_d.called:
              protocol.dataReceived(b"a" * chunk_size)
              sent += chunk_size
              self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE)

          # The request must have failed exactly when the limit was reached
          self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE)

          f = self.failureResultOf(test_d)
          self.assertIsInstance(f.value, RequestSendFailed)

          # and the client should have started closing the connection
          self.assertTrue(transport.disconnecting)

      def test_build_auth_headers_rejects_falsey_destinations(self) -> None:
          """build_auth_headers raises ValueError for a None or empty destination."""
          with self.assertRaises(ValueError):
              self.cl.build_auth_headers(None, b"GET", b"https://example.com")
          with self.assertRaises(ValueError):
              self.cl.build_auth_headers(b"", b"GET", b"https://example.com")
          # An empty `destination_is` does not make up for a missing destination
          with self.assertRaises(ValueError):
              self.cl.build_auth_headers(
                  None, b"GET", b"https://example.com", destination_is=b""
              )
          with self.assertRaises(ValueError):
              self.cl.build_auth_headers(
                  b"", b"GET", b"https://example.com", destination_is=b""
              )