|
@@ -0,0 +1,243 @@
|
|
|
+# Copyright 2021 The Matrix.org Foundation C.I.C.
|
|
|
+#
|
|
|
+# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
+# you may not use this file except in compliance with the License.
|
|
|
+# You may obtain a copy of the License at
|
|
|
+#
|
|
|
+# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
+#
|
|
|
+# Unless required by applicable law or agreed to in writing, software
|
|
|
+# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
+# See the License for the specific language governing permissions and
|
|
|
+# limitations under the License.
|
|
|
+
|
|
|
+
|
|
|
+from mock import patch
|
|
|
+from netaddr import IPSet
|
|
|
+from twisted.internet import defer
|
|
|
+from twisted.internet.error import DNSLookupError
|
|
|
+from twisted.test.proto_helpers import StringTransport
|
|
|
+from twisted.trial.unittest import TestCase
|
|
|
+from twisted.web.client import Agent
|
|
|
+
|
|
|
+from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper
|
|
|
+from sydent.http.srvresolver import Server
|
|
|
+from tests.utils import make_request, make_sydent
|
|
|
+
|
|
|
+
|
|
|
+class BlacklistingAgentTest(TestCase):
|
|
|
+ def setUp(self):
|
|
|
+ config = {
|
|
|
+ "general": {
|
|
|
+ "ip.blacklist": "5.0.0.0/8",
|
|
|
+ "ip.whitelist": "5.1.1.1",
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ self.sydent = make_sydent(test_config=config)
|
|
|
+
|
|
|
+ self.reactor = self.sydent.reactor
|
|
|
+
|
|
|
+ self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
|
|
|
+ self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
|
|
|
+ self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
|
|
|
+
|
|
|
+ # Configure the reactor's DNS resolver.
|
|
|
+ for (domain, ip) in (
|
|
|
+ (self.safe_domain, self.safe_ip),
|
|
|
+ (self.unsafe_domain, self.unsafe_ip),
|
|
|
+ (self.allowed_domain, self.allowed_ip),
|
|
|
+ ):
|
|
|
+ self.reactor.lookups[domain.decode()] = ip.decode()
|
|
|
+ self.reactor.lookups[ip.decode()] = ip.decode()
|
|
|
+
|
|
|
+ self.ip_whitelist = self.sydent.ip_whitelist
|
|
|
+ self.ip_blacklist = self.sydent.ip_blacklist
|
|
|
+
|
|
|
+ def test_reactor(self):
|
|
|
+ """Apply the blacklisting reactor and ensure it properly blocks
|
|
|
+ connections to particular domains and IPs.
|
|
|
+ """
|
|
|
+ agent = Agent(
|
|
|
+ BlacklistingReactorWrapper(
|
|
|
+ self.reactor,
|
|
|
+ ip_whitelist=self.ip_whitelist,
|
|
|
+ ip_blacklist=self.ip_blacklist,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ # The unsafe domains and IPs should be rejected.
|
|
|
+ for domain in (self.unsafe_domain, self.unsafe_ip):
|
|
|
+ self.failureResultOf(
|
|
|
+ agent.request(b"GET", b"http://" + domain), DNSLookupError
|
|
|
+ )
|
|
|
+
|
|
|
+ self.reactor.tcpClients = []
|
|
|
+
|
|
|
+ # The safe domains IPs should be accepted.
|
|
|
+ for domain in (
|
|
|
+ self.safe_domain,
|
|
|
+ self.allowed_domain,
|
|
|
+ self.safe_ip,
|
|
|
+ self.allowed_ip,
|
|
|
+ ):
|
|
|
+ agent.request(b"GET", b"http://" + domain)
|
|
|
+
|
|
|
+ # Grab the latest TCP connection.
|
|
|
+ (
|
|
|
+ host,
|
|
|
+ port,
|
|
|
+ client_factory,
|
|
|
+ _timeout,
|
|
|
+ _bindAddress,
|
|
|
+ ) = self.reactor.tcpClients.pop()
|
|
|
+
|
|
|
+ @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
|
|
|
+ def test_federation_client_allowed_ip(self, resolver):
|
|
|
+ self.sydent.run()
|
|
|
+
|
|
|
+ request, channel = make_request(
|
|
|
+ self.sydent.reactor,
|
|
|
+ "POST",
|
|
|
+ "/_matrix/identity/v2/account/register",
|
|
|
+ {
|
|
|
+ "access_token": "foo",
|
|
|
+ "expires_in": 300,
|
|
|
+ "matrix_server_name": "example.com",
|
|
|
+ "token_type": "Bearer",
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ resolver.return_value = defer.succeed(
|
|
|
+ [
|
|
|
+ Server(
|
|
|
+ host=self.allowed_domain,
|
|
|
+ port=443,
|
|
|
+ priority=1,
|
|
|
+ weight=1,
|
|
|
+ expires=100,
|
|
|
+ )
|
|
|
+ ]
|
|
|
+ )
|
|
|
+
|
|
|
+ request.render(self.sydent.servlets.registerServlet)
|
|
|
+
|
|
|
+ transport, protocol = self._get_http_request(
|
|
|
+ self.allowed_ip.decode("ascii"), 443
|
|
|
+ )
|
|
|
+
|
|
|
+ self.assertRegex(
|
|
|
+ transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
|
|
|
+ )
|
|
|
+ self.assertRegex(transport.value(), b"Host: example.com")
|
|
|
+
|
|
|
+ # Send it the HTTP response
|
|
|
+ res_json = '{ "sub": "@test:example.com" }'.encode("ascii")
|
|
|
+ protocol.dataReceived(
|
|
|
+ b"HTTP/1.1 200 OK\r\n"
|
|
|
+ b"Server: Fake\r\n"
|
|
|
+ b"Content-Type: application/json\r\n"
|
|
|
+ b"Content-Length: %i\r\n"
|
|
|
+ b"\r\n"
|
|
|
+ b"%s" % (len(res_json), res_json)
|
|
|
+ )
|
|
|
+
|
|
|
+ self.assertEqual(channel.code, 200)
|
|
|
+
|
|
|
+ @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
|
|
|
+ def test_federation_client_safe_ip(self, resolver):
|
|
|
+ self.sydent.run()
|
|
|
+
|
|
|
+ request, channel = make_request(
|
|
|
+ self.sydent.reactor,
|
|
|
+ "POST",
|
|
|
+ "/_matrix/identity/v2/account/register",
|
|
|
+ {
|
|
|
+ "access_token": "foo",
|
|
|
+ "expires_in": 300,
|
|
|
+ "matrix_server_name": "example.com",
|
|
|
+ "token_type": "Bearer",
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ resolver.return_value = defer.succeed(
|
|
|
+ [
|
|
|
+ Server(
|
|
|
+ host=self.safe_domain,
|
|
|
+ port=443,
|
|
|
+ priority=1,
|
|
|
+ weight=1,
|
|
|
+ expires=100,
|
|
|
+ )
|
|
|
+ ]
|
|
|
+ )
|
|
|
+
|
|
|
+ request.render(self.sydent.servlets.registerServlet)
|
|
|
+
|
|
|
+ transport, protocol = self._get_http_request(self.safe_ip.decode("ascii"), 443)
|
|
|
+
|
|
|
+ self.assertRegex(
|
|
|
+ transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
|
|
|
+ )
|
|
|
+ self.assertRegex(transport.value(), b"Host: example.com")
|
|
|
+
|
|
|
+ # Send it the HTTP response
|
|
|
+ res_json = '{ "sub": "@test:example.com" }'.encode("ascii")
|
|
|
+ protocol.dataReceived(
|
|
|
+ b"HTTP/1.1 200 OK\r\n"
|
|
|
+ b"Server: Fake\r\n"
|
|
|
+ b"Content-Type: application/json\r\n"
|
|
|
+ b"Content-Length: %i\r\n"
|
|
|
+ b"\r\n"
|
|
|
+ b"%s" % (len(res_json), res_json)
|
|
|
+ )
|
|
|
+
|
|
|
+ self.assertEqual(channel.code, 200)
|
|
|
+
|
|
|
+ @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
|
|
|
+ def test_federation_client_unsafe_ip(self, resolver):
|
|
|
+ self.sydent.run()
|
|
|
+
|
|
|
+ request, channel = make_request(
|
|
|
+ self.sydent.reactor,
|
|
|
+ "POST",
|
|
|
+ "/_matrix/identity/v2/account/register",
|
|
|
+ {
|
|
|
+ "access_token": "foo",
|
|
|
+ "expires_in": 300,
|
|
|
+ "matrix_server_name": "example.com",
|
|
|
+ "token_type": "Bearer",
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ resolver.return_value = defer.succeed(
|
|
|
+ [
|
|
|
+ Server(
|
|
|
+ host=self.unsafe_domain,
|
|
|
+ port=443,
|
|
|
+ priority=1,
|
|
|
+ weight=1,
|
|
|
+ expires=100,
|
|
|
+ )
|
|
|
+ ]
|
|
|
+ )
|
|
|
+
|
|
|
+ request.render(self.sydent.servlets.registerServlet)
|
|
|
+
|
|
|
+ self.assertNot(self.reactor.tcpClients)
|
|
|
+
|
|
|
+ self.assertEqual(channel.code, 500)
|
|
|
+
|
|
|
+ def _get_http_request(self, expected_host, expected_port):
|
|
|
+ clients = self.reactor.tcpClients
|
|
|
+ (host, port, factory, _timeout, _bindAddress) = clients[-1]
|
|
|
+ self.assertEqual(host, expected_host)
|
|
|
+ self.assertEqual(port, expected_port)
|
|
|
+
|
|
|
+ # complete the connection and wire it up to a fake transport
|
|
|
+ protocol = factory.buildProtocol(None)
|
|
|
+ transport = StringTransport()
|
|
|
+ protocol.makeConnection(transport)
|
|
|
+
|
|
|
+ return transport, protocol
|