test_blacklisting.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import patch
from netaddr import IPSet
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.test.proto_helpers import StringTransport
from twisted.trial.unittest import TestCase
from twisted.web.client import Agent

from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper
from sydent.http.srvresolver import Server

from tests.utils import make_request, make_sydent
  24. class BlacklistingAgentTest(TestCase):
  25. def setUp(self):
  26. config = {
  27. "general": {
  28. "ip.blacklist": "5.0.0.0/8",
  29. "ip.whitelist": "5.1.1.1",
  30. },
  31. }
  32. self.sydent = make_sydent(test_config=config)
  33. self.reactor = self.sydent.reactor
  34. self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
  35. self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
  36. self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
  37. # Configure the reactor's DNS resolver.
  38. for (domain, ip) in (
  39. (self.safe_domain, self.safe_ip),
  40. (self.unsafe_domain, self.unsafe_ip),
  41. (self.allowed_domain, self.allowed_ip),
  42. ):
  43. self.reactor.lookups[domain.decode()] = ip.decode()
  44. self.reactor.lookups[ip.decode()] = ip.decode()
  45. self.ip_whitelist = self.sydent.ip_whitelist
  46. self.ip_blacklist = self.sydent.ip_blacklist
  47. def test_reactor(self):
  48. """Apply the blacklisting reactor and ensure it properly blocks
  49. connections to particular domains and IPs.
  50. """
  51. agent = Agent(
  52. BlacklistingReactorWrapper(
  53. self.reactor,
  54. ip_whitelist=self.ip_whitelist,
  55. ip_blacklist=self.ip_blacklist,
  56. ),
  57. )
  58. # The unsafe domains and IPs should be rejected.
  59. for domain in (self.unsafe_domain, self.unsafe_ip):
  60. self.failureResultOf(
  61. agent.request(b"GET", b"http://" + domain), DNSLookupError
  62. )
  63. self.reactor.tcpClients = []
  64. # The safe domains IPs should be accepted.
  65. for domain in (
  66. self.safe_domain,
  67. self.allowed_domain,
  68. self.safe_ip,
  69. self.allowed_ip,
  70. ):
  71. agent.request(b"GET", b"http://" + domain)
  72. # Grab the latest TCP connection.
  73. (
  74. host,
  75. port,
  76. client_factory,
  77. _timeout,
  78. _bindAddress,
  79. ) = self.reactor.tcpClients.pop()
  80. @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
  81. def test_federation_client_allowed_ip(self, resolver):
  82. self.sydent.run()
  83. request, channel = make_request(
  84. self.sydent.reactor,
  85. "POST",
  86. "/_matrix/identity/v2/account/register",
  87. {
  88. "access_token": "foo",
  89. "expires_in": 300,
  90. "matrix_server_name": "example.com",
  91. "token_type": "Bearer",
  92. },
  93. )
  94. resolver.return_value = defer.succeed(
  95. [
  96. Server(
  97. host=self.allowed_domain,
  98. port=443,
  99. priority=1,
  100. weight=1,
  101. expires=100,
  102. )
  103. ]
  104. )
  105. request.render(self.sydent.servlets.registerServlet)
  106. transport, protocol = self._get_http_request(
  107. self.allowed_ip.decode("ascii"), 443
  108. )
  109. self.assertRegex(
  110. transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
  111. )
  112. self.assertRegex(transport.value(), b"Host: example.com")
  113. # Send it the HTTP response
  114. res_json = '{ "sub": "@test:example.com" }'.encode("ascii")
  115. protocol.dataReceived(
  116. b"HTTP/1.1 200 OK\r\n"
  117. b"Server: Fake\r\n"
  118. b"Content-Type: application/json\r\n"
  119. b"Content-Length: %i\r\n"
  120. b"\r\n"
  121. b"%s" % (len(res_json), res_json)
  122. )
  123. self.assertEqual(channel.code, 200)
  124. @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
  125. def test_federation_client_safe_ip(self, resolver):
  126. self.sydent.run()
  127. request, channel = make_request(
  128. self.sydent.reactor,
  129. "POST",
  130. "/_matrix/identity/v2/account/register",
  131. {
  132. "access_token": "foo",
  133. "expires_in": 300,
  134. "matrix_server_name": "example.com",
  135. "token_type": "Bearer",
  136. },
  137. )
  138. resolver.return_value = defer.succeed(
  139. [
  140. Server(
  141. host=self.safe_domain,
  142. port=443,
  143. priority=1,
  144. weight=1,
  145. expires=100,
  146. )
  147. ]
  148. )
  149. request.render(self.sydent.servlets.registerServlet)
  150. transport, protocol = self._get_http_request(self.safe_ip.decode("ascii"), 443)
  151. self.assertRegex(
  152. transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
  153. )
  154. self.assertRegex(transport.value(), b"Host: example.com")
  155. # Send it the HTTP response
  156. res_json = '{ "sub": "@test:example.com" }'.encode("ascii")
  157. protocol.dataReceived(
  158. b"HTTP/1.1 200 OK\r\n"
  159. b"Server: Fake\r\n"
  160. b"Content-Type: application/json\r\n"
  161. b"Content-Length: %i\r\n"
  162. b"\r\n"
  163. b"%s" % (len(res_json), res_json)
  164. )
  165. self.assertEqual(channel.code, 200)
  166. @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
  167. def test_federation_client_unsafe_ip(self, resolver):
  168. self.sydent.run()
  169. request, channel = make_request(
  170. self.sydent.reactor,
  171. "POST",
  172. "/_matrix/identity/v2/account/register",
  173. {
  174. "access_token": "foo",
  175. "expires_in": 300,
  176. "matrix_server_name": "example.com",
  177. "token_type": "Bearer",
  178. },
  179. )
  180. resolver.return_value = defer.succeed(
  181. [
  182. Server(
  183. host=self.unsafe_domain,
  184. port=443,
  185. priority=1,
  186. weight=1,
  187. expires=100,
  188. )
  189. ]
  190. )
  191. request.render(self.sydent.servlets.registerServlet)
  192. self.assertNot(self.reactor.tcpClients)
  193. self.assertEqual(channel.code, 500)
  194. def _get_http_request(self, expected_host, expected_port):
  195. clients = self.reactor.tcpClients
  196. (host, port, factory, _timeout, _bindAddress) = clients[-1]
  197. self.assertEqual(host, expected_host)
  198. self.assertEqual(port, expected_port)
  199. # complete the connection and wire it up to a fake transport
  200. protocol = factory.buildProtocol(None)
  201. transport = StringTransport()
  202. protocol.makeConnection(transport)
  203. return transport, protocol