1
0

test_blacklisting.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
# Copyright 2021 The Matrix.org Foundation C.I.C.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  13. from unittest.mock import patch
  14. from twisted.internet.error import DNSLookupError
  15. from twisted.test.proto_helpers import StringTransport
  16. from twisted.trial.unittest import TestCase
  17. from twisted.web.client import Agent
  18. from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper
  19. from sydent.http.srvresolver import Server
  20. from tests.utils import AsyncMock, make_request, make_sydent
  21. class BlacklistingAgentTest(TestCase):
  22. def setUp(self):
  23. config = {
  24. "general": {
  25. "ip.blacklist": "5.0.0.0/8",
  26. "ip.whitelist": "5.1.1.1",
  27. },
  28. }
  29. self.sydent = make_sydent(test_config=config)
  30. self.reactor = self.sydent.reactor
  31. self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
  32. self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
  33. self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
  34. # Configure the reactor's DNS resolver.
  35. for (domain, ip) in (
  36. (self.safe_domain, self.safe_ip),
  37. (self.unsafe_domain, self.unsafe_ip),
  38. (self.allowed_domain, self.allowed_ip),
  39. ):
  40. self.reactor.lookups[domain.decode()] = ip.decode()
  41. self.reactor.lookups[ip.decode()] = ip.decode()
  42. self.ip_whitelist = self.sydent.config.general.ip_whitelist
  43. self.ip_blacklist = self.sydent.config.general.ip_blacklist
  44. def test_reactor(self):
  45. """Apply the blacklisting reactor and ensure it properly blocks
  46. connections to particular domains and IPs.
  47. """
  48. agent = Agent(
  49. BlacklistingReactorWrapper(
  50. self.reactor,
  51. ip_whitelist=self.ip_whitelist,
  52. ip_blacklist=self.ip_blacklist,
  53. ),
  54. )
  55. # The unsafe domains and IPs should be rejected.
  56. for domain in (self.unsafe_domain, self.unsafe_ip):
  57. self.failureResultOf(
  58. agent.request(b"GET", b"http://" + domain), DNSLookupError
  59. )
  60. self.reactor.tcpClients = []
  61. # The safe domains IPs should be accepted.
  62. for domain in (
  63. self.safe_domain,
  64. self.allowed_domain,
  65. self.safe_ip,
  66. self.allowed_ip,
  67. ):
  68. agent.request(b"GET", b"http://" + domain)
  69. # Grab the latest TCP connection.
  70. (
  71. host,
  72. port,
  73. client_factory,
  74. _timeout,
  75. _bindAddress,
  76. ) = self.reactor.tcpClients.pop()
  77. @patch(
  78. "sydent.http.srvresolver.SrvResolver.resolve_service", new_callable=AsyncMock
  79. )
  80. def test_federation_client_allowed_ip(self, resolver):
  81. self.sydent.run()
  82. resolver.return_value = [
  83. Server(
  84. host=self.allowed_domain,
  85. port=443,
  86. priority=1,
  87. weight=1,
  88. expires=100,
  89. )
  90. ]
  91. request, channel = make_request(
  92. self.sydent.reactor,
  93. self.sydent.clientApiHttpServer.factory,
  94. "POST",
  95. "/_matrix/identity/v2/account/register",
  96. {
  97. "access_token": "foo",
  98. "expires_in": 300,
  99. "matrix_server_name": "example.com",
  100. "token_type": "Bearer",
  101. },
  102. )
  103. transport, protocol = self._get_http_request(
  104. self.allowed_ip.decode("ascii"), 443
  105. )
  106. self.assertRegex(
  107. transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
  108. )
  109. self.assertRegex(transport.value(), b"Host: example.com")
  110. # Send it the HTTP response
  111. res_json = b'{ "sub": "@test:example.com" }'
  112. protocol.dataReceived(
  113. b"HTTP/1.1 200 OK\r\n"
  114. b"Server: Fake\r\n"
  115. b"Content-Type: application/json\r\n"
  116. b"Content-Length: %i\r\n"
  117. b"\r\n"
  118. b"%s" % (len(res_json), res_json)
  119. )
  120. self.assertEqual(channel.code, 200)
  121. @patch(
  122. "sydent.http.srvresolver.SrvResolver.resolve_service", new_callable=AsyncMock
  123. )
  124. def test_federation_client_safe_ip(self, resolver):
  125. self.sydent.run()
  126. resolver.return_value = [
  127. Server(
  128. host=self.safe_domain,
  129. port=443,
  130. priority=1,
  131. weight=1,
  132. expires=100,
  133. )
  134. ]
  135. request, channel = make_request(
  136. self.sydent.reactor,
  137. self.sydent.clientApiHttpServer.factory,
  138. "POST",
  139. "/_matrix/identity/v2/account/register",
  140. {
  141. "access_token": "foo",
  142. "expires_in": 300,
  143. "matrix_server_name": "example.com",
  144. "token_type": "Bearer",
  145. },
  146. )
  147. transport, protocol = self._get_http_request(self.safe_ip.decode("ascii"), 443)
  148. self.assertRegex(
  149. transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
  150. )
  151. self.assertRegex(transport.value(), b"Host: example.com")
  152. # Send it the HTTP response
  153. res_json = b'{ "sub": "@test:example.com" }'
  154. protocol.dataReceived(
  155. b"HTTP/1.1 200 OK\r\n"
  156. b"Server: Fake\r\n"
  157. b"Content-Type: application/json\r\n"
  158. b"Content-Length: %i\r\n"
  159. b"\r\n"
  160. b"%s" % (len(res_json), res_json)
  161. )
  162. self.assertEqual(channel.code, 200)
  163. @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
  164. def test_federation_client_unsafe_ip(self, resolver):
  165. self.sydent.run()
  166. resolver.return_value = [
  167. Server(
  168. host=self.unsafe_domain,
  169. port=443,
  170. priority=1,
  171. weight=1,
  172. expires=100,
  173. )
  174. ]
  175. request, channel = make_request(
  176. self.sydent.reactor,
  177. self.sydent.clientApiHttpServer.factory,
  178. "POST",
  179. "/_matrix/identity/v2/account/register",
  180. {
  181. "access_token": "foo",
  182. "expires_in": 300,
  183. "matrix_server_name": "example.com",
  184. "token_type": "Bearer",
  185. },
  186. )
  187. self.assertNot(self.reactor.tcpClients)
  188. self.assertEqual(channel.code, 500)
  189. def _get_http_request(self, expected_host, expected_port):
  190. clients = self.reactor.tcpClients
  191. (host, port, factory, _timeout, _bindAddress) = clients[-1]
  192. self.assertEqual(host, expected_host)
  193. self.assertEqual(port, expected_port)
  194. # complete the connection and wire it up to a fake transport
  195. protocol = factory.buildProtocol(None)
  196. transport = StringTransport()
  197. protocol.makeConnection(transport)
  198. return transport, protocol