test_blacklisting.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  13. from unittest.mock import patch
  14. from twisted.internet.error import DNSLookupError
  15. from twisted.test.proto_helpers import StringTransport
  16. from twisted.trial.unittest import TestCase
  17. from twisted.web.client import Agent
  18. from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper
  19. from sydent.http.srvresolver import Server
  20. from tests.utils import AsyncMock, make_request, make_sydent
  21. class BlacklistingAgentTest(TestCase):
  22. def setUp(self):
  23. config = {
  24. "general": {
  25. "ip.blacklist": "5.0.0.0/8",
  26. "ip.whitelist": "5.1.1.1",
  27. },
  28. }
  29. self.sydent = make_sydent(test_config=config)
  30. self.reactor = self.sydent.reactor
  31. self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
  32. self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
  33. self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
  34. # Configure the reactor's DNS resolver.
  35. for (domain, ip) in (
  36. (self.safe_domain, self.safe_ip),
  37. (self.unsafe_domain, self.unsafe_ip),
  38. (self.allowed_domain, self.allowed_ip),
  39. ):
  40. self.reactor.lookups[domain.decode()] = ip.decode()
  41. self.reactor.lookups[ip.decode()] = ip.decode()
  42. self.ip_whitelist = self.sydent.config.general.ip_whitelist
  43. self.ip_blacklist = self.sydent.config.general.ip_blacklist
  44. def test_reactor(self):
  45. """Apply the blacklisting reactor and ensure it properly blocks
  46. connections to particular domains and IPs.
  47. """
  48. agent = Agent(
  49. BlacklistingReactorWrapper(
  50. self.reactor,
  51. ip_whitelist=self.ip_whitelist,
  52. ip_blacklist=self.ip_blacklist,
  53. ),
  54. )
  55. # The unsafe domains and IPs should be rejected.
  56. for domain in (self.unsafe_domain, self.unsafe_ip):
  57. self.failureResultOf(
  58. agent.request(b"GET", b"http://" + domain), DNSLookupError
  59. )
  60. self.reactor.tcpClients = []
  61. # The safe domains IPs should be accepted.
  62. for domain in (
  63. self.safe_domain,
  64. self.allowed_domain,
  65. self.safe_ip,
  66. self.allowed_ip,
  67. ):
  68. agent.request(b"GET", b"http://" + domain)
  69. # Grab the latest TCP connection.
  70. (
  71. host,
  72. port,
  73. client_factory,
  74. _timeout,
  75. _bindAddress,
  76. ) = self.reactor.tcpClients.pop()
  77. @patch(
  78. "sydent.http.srvresolver.SrvResolver.resolve_service", new_callable=AsyncMock
  79. )
  80. def test_federation_client_allowed_ip(self, resolver):
  81. self.sydent.run()
  82. request, channel = make_request(
  83. self.sydent.reactor,
  84. "POST",
  85. "/_matrix/identity/v2/account/register",
  86. {
  87. "access_token": "foo",
  88. "expires_in": 300,
  89. "matrix_server_name": "example.com",
  90. "token_type": "Bearer",
  91. },
  92. )
  93. resolver.return_value = [
  94. Server(
  95. host=self.allowed_domain,
  96. port=443,
  97. priority=1,
  98. weight=1,
  99. expires=100,
  100. )
  101. ]
  102. request.render(self.sydent.servlets.registerServlet)
  103. transport, protocol = self._get_http_request(
  104. self.allowed_ip.decode("ascii"), 443
  105. )
  106. self.assertRegex(
  107. transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
  108. )
  109. self.assertRegex(transport.value(), b"Host: example.com")
  110. # Send it the HTTP response
  111. res_json = b'{ "sub": "@test:example.com" }'
  112. protocol.dataReceived(
  113. b"HTTP/1.1 200 OK\r\n"
  114. b"Server: Fake\r\n"
  115. b"Content-Type: application/json\r\n"
  116. b"Content-Length: %i\r\n"
  117. b"\r\n"
  118. b"%s" % (len(res_json), res_json)
  119. )
  120. self.assertEqual(channel.code, 200)
  121. @patch(
  122. "sydent.http.srvresolver.SrvResolver.resolve_service", new_callable=AsyncMock
  123. )
  124. def test_federation_client_safe_ip(self, resolver):
  125. self.sydent.run()
  126. request, channel = make_request(
  127. self.sydent.reactor,
  128. "POST",
  129. "/_matrix/identity/v2/account/register",
  130. {
  131. "access_token": "foo",
  132. "expires_in": 300,
  133. "matrix_server_name": "example.com",
  134. "token_type": "Bearer",
  135. },
  136. )
  137. resolver.return_value = [
  138. Server(
  139. host=self.safe_domain,
  140. port=443,
  141. priority=1,
  142. weight=1,
  143. expires=100,
  144. )
  145. ]
  146. request.render(self.sydent.servlets.registerServlet)
  147. transport, protocol = self._get_http_request(self.safe_ip.decode("ascii"), 443)
  148. self.assertRegex(
  149. transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
  150. )
  151. self.assertRegex(transport.value(), b"Host: example.com")
  152. # Send it the HTTP response
  153. res_json = b'{ "sub": "@test:example.com" }'
  154. protocol.dataReceived(
  155. b"HTTP/1.1 200 OK\r\n"
  156. b"Server: Fake\r\n"
  157. b"Content-Type: application/json\r\n"
  158. b"Content-Length: %i\r\n"
  159. b"\r\n"
  160. b"%s" % (len(res_json), res_json)
  161. )
  162. self.assertEqual(channel.code, 200)
  163. @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
  164. def test_federation_client_unsafe_ip(self, resolver):
  165. self.sydent.run()
  166. request, channel = make_request(
  167. self.sydent.reactor,
  168. "POST",
  169. "/_matrix/identity/v2/account/register",
  170. {
  171. "access_token": "foo",
  172. "expires_in": 300,
  173. "matrix_server_name": "example.com",
  174. "token_type": "Bearer",
  175. },
  176. )
  177. resolver.return_value = [
  178. Server(
  179. host=self.unsafe_domain,
  180. port=443,
  181. priority=1,
  182. weight=1,
  183. expires=100,
  184. )
  185. ]
  186. request.render(self.sydent.servlets.registerServlet)
  187. self.assertNot(self.reactor.tcpClients)
  188. self.assertEqual(channel.code, 500)
  189. def _get_http_request(self, expected_host, expected_port):
  190. clients = self.reactor.tcpClients
  191. (host, port, factory, _timeout, _bindAddress) = clients[-1]
  192. self.assertEqual(host, expected_host)
  193. self.assertEqual(port, expected_port)
  194. # complete the connection and wire it up to a fake transport
  195. protocol = factory.buildProtocol(None)
  196. transport = StringTransport()
  197. protocol.makeConnection(transport)
  198. return transport, protocol