test_client.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. # Copyright 2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from io import BytesIO
  15. from unittest.mock import Mock
  16. from netaddr import IPSet
  17. from twisted.internet.error import DNSLookupError
  18. from twisted.python.failure import Failure
  19. from twisted.test.proto_helpers import AccumulatingProtocol
  20. from twisted.web.client import Agent, ResponseDone
  21. from twisted.web.iweb import UNKNOWN_LENGTH
  22. from synapse.api.errors import SynapseError
  23. from synapse.http.client import (
  24. BlacklistingAgentWrapper,
  25. BlacklistingReactorWrapper,
  26. BodyExceededMaxSize,
  27. read_body_with_max_size,
  28. )
  29. from tests.server import FakeTransport, get_clock
  30. from tests.unittest import TestCase
  31. class ReadBodyWithMaxSizeTests(TestCase):
  32. def _build_response(self, length=UNKNOWN_LENGTH):
  33. """Start reading the body, returns the response, result and proto"""
  34. response = Mock(length=length)
  35. result = BytesIO()
  36. deferred = read_body_with_max_size(response, result, 6)
  37. # Fish the protocol out of the response.
  38. protocol = response.deliverBody.call_args[0][0]
  39. protocol.transport = Mock()
  40. return result, deferred, protocol
  41. def _assert_error(self, deferred, protocol):
  42. """Ensure that the expected error is received."""
  43. self.assertIsInstance(deferred.result, Failure)
  44. self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
  45. protocol.transport.abortConnection.assert_called_once()
  46. def _cleanup_error(self, deferred):
  47. """Ensure that the error in the Deferred is handled gracefully."""
  48. called = [False]
  49. def errback(f):
  50. called[0] = True
  51. deferred.addErrback(errback)
  52. self.assertTrue(called[0])
  53. def test_no_error(self):
  54. """A response that is NOT too large."""
  55. result, deferred, protocol = self._build_response()
  56. # Start sending data.
  57. protocol.dataReceived(b"12345")
  58. # Close the connection.
  59. protocol.connectionLost(Failure(ResponseDone()))
  60. self.assertEqual(result.getvalue(), b"12345")
  61. self.assertEqual(deferred.result, 5)
  62. def test_too_large(self):
  63. """A response which is too large raises an exception."""
  64. result, deferred, protocol = self._build_response()
  65. # Start sending data.
  66. protocol.dataReceived(b"1234567890")
  67. self.assertEqual(result.getvalue(), b"1234567890")
  68. self._assert_error(deferred, protocol)
  69. self._cleanup_error(deferred)
  70. def test_multiple_packets(self):
  71. """Data should be accumulated through mutliple packets."""
  72. result, deferred, protocol = self._build_response()
  73. # Start sending data.
  74. protocol.dataReceived(b"12")
  75. protocol.dataReceived(b"34")
  76. # Close the connection.
  77. protocol.connectionLost(Failure(ResponseDone()))
  78. self.assertEqual(result.getvalue(), b"1234")
  79. self.assertEqual(deferred.result, 4)
  80. def test_additional_data(self):
  81. """A connection can receive data after being closed."""
  82. result, deferred, protocol = self._build_response()
  83. # Start sending data.
  84. protocol.dataReceived(b"1234567890")
  85. self._assert_error(deferred, protocol)
  86. # More data might have come in.
  87. protocol.dataReceived(b"1234567890")
  88. self.assertEqual(result.getvalue(), b"1234567890")
  89. self._assert_error(deferred, protocol)
  90. self._cleanup_error(deferred)
  91. def test_content_length(self):
  92. """The body shouldn't be read (at all) if the Content-Length header is too large."""
  93. result, deferred, protocol = self._build_response(length=10)
  94. # Deferred shouldn't be called yet.
  95. self.assertFalse(deferred.called)
  96. # Start sending data.
  97. protocol.dataReceived(b"12345")
  98. self._assert_error(deferred, protocol)
  99. self._cleanup_error(deferred)
  100. # The data is never consumed.
  101. self.assertEqual(result.getvalue(), b"")
  102. class BlacklistingAgentTest(TestCase):
  103. def setUp(self):
  104. self.reactor, self.clock = get_clock()
  105. self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
  106. self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
  107. self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
  108. # Configure the reactor's DNS resolver.
  109. for (domain, ip) in (
  110. (self.safe_domain, self.safe_ip),
  111. (self.unsafe_domain, self.unsafe_ip),
  112. (self.allowed_domain, self.allowed_ip),
  113. ):
  114. self.reactor.lookups[domain.decode()] = ip.decode()
  115. self.reactor.lookups[ip.decode()] = ip.decode()
  116. self.ip_whitelist = IPSet([self.allowed_ip.decode()])
  117. self.ip_blacklist = IPSet(["5.0.0.0/8"])
  118. def test_reactor(self):
  119. """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
  120. agent = Agent(
  121. BlacklistingReactorWrapper(
  122. self.reactor,
  123. ip_whitelist=self.ip_whitelist,
  124. ip_blacklist=self.ip_blacklist,
  125. ),
  126. )
  127. # The unsafe domains and IPs should be rejected.
  128. for domain in (self.unsafe_domain, self.unsafe_ip):
  129. self.failureResultOf(
  130. agent.request(b"GET", b"http://" + domain), DNSLookupError
  131. )
  132. # The safe domains IPs should be accepted.
  133. for domain in (
  134. self.safe_domain,
  135. self.allowed_domain,
  136. self.safe_ip,
  137. self.allowed_ip,
  138. ):
  139. d = agent.request(b"GET", b"http://" + domain)
  140. # Grab the latest TCP connection.
  141. (
  142. host,
  143. port,
  144. client_factory,
  145. _timeout,
  146. _bindAddress,
  147. ) = self.reactor.tcpClients[-1]
  148. # Make the connection and pump data through it.
  149. client = client_factory.buildProtocol(None)
  150. server = AccumulatingProtocol()
  151. server.makeConnection(FakeTransport(client, self.reactor))
  152. client.makeConnection(FakeTransport(server, self.reactor))
  153. client.dataReceived(
  154. b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
  155. )
  156. response = self.successResultOf(d)
  157. self.assertEqual(response.code, 200)
  158. def test_agent(self):
  159. """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
  160. agent = BlacklistingAgentWrapper(
  161. Agent(self.reactor),
  162. ip_whitelist=self.ip_whitelist,
  163. ip_blacklist=self.ip_blacklist,
  164. )
  165. # The unsafe IPs should be rejected.
  166. self.failureResultOf(
  167. agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
  168. )
  169. # The safe and unsafe domains and safe IPs should be accepted.
  170. for domain in (
  171. self.safe_domain,
  172. self.unsafe_domain,
  173. self.allowed_domain,
  174. self.safe_ip,
  175. self.allowed_ip,
  176. ):
  177. d = agent.request(b"GET", b"http://" + domain)
  178. # Grab the latest TCP connection.
  179. (
  180. host,
  181. port,
  182. client_factory,
  183. _timeout,
  184. _bindAddress,
  185. ) = self.reactor.tcpClients[-1]
  186. # Make the connection and pump data through it.
  187. client = client_factory.buildProtocol(None)
  188. server = AccumulatingProtocol()
  189. server.makeConnection(FakeTransport(client, self.reactor))
  190. client.makeConnection(FakeTransport(server, self.reactor))
  191. client.dataReceived(
  192. b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
  193. )
  194. response = self.successResultOf(d)
  195. self.assertEqual(response.code, 200)