test_client.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. # Copyright 2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from io import BytesIO
  15. from typing import Tuple, Union
  16. from unittest.mock import Mock
  17. from netaddr import IPSet
  18. from twisted.internet.defer import Deferred
  19. from twisted.internet.error import DNSLookupError
  20. from twisted.python.failure import Failure
  21. from twisted.test.proto_helpers import AccumulatingProtocol
  22. from twisted.web.client import Agent, ResponseDone
  23. from twisted.web.iweb import UNKNOWN_LENGTH
  24. from synapse.api.errors import SynapseError
  25. from synapse.http.client import (
  26. BlacklistingAgentWrapper,
  27. BlacklistingReactorWrapper,
  28. BodyExceededMaxSize,
  29. _DiscardBodyWithMaxSizeProtocol,
  30. read_body_with_max_size,
  31. )
  32. from tests.server import FakeTransport, get_clock
  33. from tests.unittest import TestCase
  34. class ReadBodyWithMaxSizeTests(TestCase):
  35. def _build_response(
  36. self, length: Union[int, str] = UNKNOWN_LENGTH
  37. ) -> Tuple[BytesIO, "Deferred[int]", _DiscardBodyWithMaxSizeProtocol]:
  38. """Start reading the body, returns the response, result and proto"""
  39. response = Mock(length=length)
  40. result = BytesIO()
  41. deferred = read_body_with_max_size(response, result, 6)
  42. # Fish the protocol out of the response.
  43. protocol = response.deliverBody.call_args[0][0]
  44. protocol.transport = Mock()
  45. return result, deferred, protocol
  46. def _assert_error(
  47. self, deferred: "Deferred[int]", protocol: _DiscardBodyWithMaxSizeProtocol
  48. ) -> None:
  49. """Ensure that the expected error is received."""
  50. assert isinstance(deferred.result, Failure)
  51. self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
  52. assert protocol.transport is not None
  53. # type-ignore: presumably abortConnection has been replaced with a Mock.
  54. protocol.transport.abortConnection.assert_called_once() # type: ignore[attr-defined]
  55. def _cleanup_error(self, deferred: "Deferred[int]") -> None:
  56. """Ensure that the error in the Deferred is handled gracefully."""
  57. called = [False]
  58. def errback(f: Failure) -> None:
  59. called[0] = True
  60. deferred.addErrback(errback)
  61. self.assertTrue(called[0])
  62. def test_no_error(self) -> None:
  63. """A response that is NOT too large."""
  64. result, deferred, protocol = self._build_response()
  65. # Start sending data.
  66. protocol.dataReceived(b"12345")
  67. # Close the connection.
  68. protocol.connectionLost(Failure(ResponseDone()))
  69. self.assertEqual(result.getvalue(), b"12345")
  70. self.assertEqual(deferred.result, 5)
  71. def test_too_large(self) -> None:
  72. """A response which is too large raises an exception."""
  73. result, deferred, protocol = self._build_response()
  74. # Start sending data.
  75. protocol.dataReceived(b"1234567890")
  76. self.assertEqual(result.getvalue(), b"1234567890")
  77. self._assert_error(deferred, protocol)
  78. self._cleanup_error(deferred)
  79. def test_multiple_packets(self) -> None:
  80. """Data should be accumulated through mutliple packets."""
  81. result, deferred, protocol = self._build_response()
  82. # Start sending data.
  83. protocol.dataReceived(b"12")
  84. protocol.dataReceived(b"34")
  85. # Close the connection.
  86. protocol.connectionLost(Failure(ResponseDone()))
  87. self.assertEqual(result.getvalue(), b"1234")
  88. self.assertEqual(deferred.result, 4)
  89. def test_additional_data(self) -> None:
  90. """A connection can receive data after being closed."""
  91. result, deferred, protocol = self._build_response()
  92. # Start sending data.
  93. protocol.dataReceived(b"1234567890")
  94. self._assert_error(deferred, protocol)
  95. # More data might have come in.
  96. protocol.dataReceived(b"1234567890")
  97. self.assertEqual(result.getvalue(), b"1234567890")
  98. self._assert_error(deferred, protocol)
  99. self._cleanup_error(deferred)
  100. def test_content_length(self) -> None:
  101. """The body shouldn't be read (at all) if the Content-Length header is too large."""
  102. result, deferred, protocol = self._build_response(length=10)
  103. # Deferred shouldn't be called yet.
  104. self.assertFalse(deferred.called)
  105. # Start sending data.
  106. protocol.dataReceived(b"12345")
  107. self._assert_error(deferred, protocol)
  108. self._cleanup_error(deferred)
  109. # The data is never consumed.
  110. self.assertEqual(result.getvalue(), b"")
  111. class BlacklistingAgentTest(TestCase):
  112. def setUp(self) -> None:
  113. self.reactor, self.clock = get_clock()
  114. self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
  115. self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
  116. self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
  117. # Configure the reactor's DNS resolver.
  118. for domain, ip in (
  119. (self.safe_domain, self.safe_ip),
  120. (self.unsafe_domain, self.unsafe_ip),
  121. (self.allowed_domain, self.allowed_ip),
  122. ):
  123. self.reactor.lookups[domain.decode()] = ip.decode()
  124. self.reactor.lookups[ip.decode()] = ip.decode()
  125. self.ip_whitelist = IPSet([self.allowed_ip.decode()])
  126. self.ip_blacklist = IPSet(["5.0.0.0/8"])
  127. def test_reactor(self) -> None:
  128. """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
  129. agent = Agent(
  130. BlacklistingReactorWrapper(
  131. self.reactor,
  132. ip_whitelist=self.ip_whitelist,
  133. ip_blacklist=self.ip_blacklist,
  134. ),
  135. )
  136. # The unsafe domains and IPs should be rejected.
  137. for domain in (self.unsafe_domain, self.unsafe_ip):
  138. self.failureResultOf(
  139. agent.request(b"GET", b"http://" + domain), DNSLookupError
  140. )
  141. # The safe domains IPs should be accepted.
  142. for domain in (
  143. self.safe_domain,
  144. self.allowed_domain,
  145. self.safe_ip,
  146. self.allowed_ip,
  147. ):
  148. d = agent.request(b"GET", b"http://" + domain)
  149. # Grab the latest TCP connection.
  150. (
  151. host,
  152. port,
  153. client_factory,
  154. _timeout,
  155. _bindAddress,
  156. ) = self.reactor.tcpClients[-1]
  157. # Make the connection and pump data through it.
  158. client = client_factory.buildProtocol(None)
  159. server = AccumulatingProtocol()
  160. server.makeConnection(FakeTransport(client, self.reactor))
  161. client.makeConnection(FakeTransport(server, self.reactor))
  162. client.dataReceived(
  163. b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
  164. )
  165. response = self.successResultOf(d)
  166. self.assertEqual(response.code, 200)
  167. def test_agent(self) -> None:
  168. """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
  169. agent = BlacklistingAgentWrapper(
  170. Agent(self.reactor),
  171. ip_blacklist=self.ip_blacklist,
  172. ip_whitelist=self.ip_whitelist,
  173. )
  174. # The unsafe IPs should be rejected.
  175. self.failureResultOf(
  176. agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
  177. )
  178. # The safe and unsafe domains and safe IPs should be accepted.
  179. for domain in (
  180. self.safe_domain,
  181. self.unsafe_domain,
  182. self.allowed_domain,
  183. self.safe_ip,
  184. self.allowed_ip,
  185. ):
  186. d = agent.request(b"GET", b"http://" + domain)
  187. # Grab the latest TCP connection.
  188. (
  189. host,
  190. port,
  191. client_factory,
  192. _timeout,
  193. _bindAddress,
  194. ) = self.reactor.tcpClients[-1]
  195. # Make the connection and pump data through it.
  196. client = client_factory.buildProtocol(None)
  197. server = AccumulatingProtocol()
  198. server.makeConnection(FakeTransport(client, self.reactor))
  199. client.makeConnection(FakeTransport(server, self.reactor))
  200. client.dataReceived(
  201. b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
  202. )
  203. response = self.successResultOf(d)
  204. self.assertEqual(response.code, 200)