|
@@ -16,12 +16,23 @@ from io import BytesIO
|
|
|
|
|
|
from mock import Mock
|
|
|
|
|
|
+from netaddr import IPSet
|
|
|
+
|
|
|
+from twisted.internet.error import DNSLookupError
|
|
|
from twisted.python.failure import Failure
|
|
|
-from twisted.web.client import ResponseDone
|
|
|
+from twisted.test.proto_helpers import AccumulatingProtocol
|
|
|
+from twisted.web.client import Agent, ResponseDone
|
|
|
from twisted.web.iweb import UNKNOWN_LENGTH
|
|
|
|
|
|
-from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
|
|
|
+from synapse.api.errors import SynapseError
|
|
|
+from synapse.http.client import (
|
|
|
+ BlacklistingAgentWrapper,
|
|
|
+ BlacklistingReactorWrapper,
|
|
|
+ BodyExceededMaxSize,
|
|
|
+ read_body_with_max_size,
|
|
|
+)
|
|
|
|
|
|
+from tests.server import FakeTransport, get_clock
|
|
|
from tests.unittest import TestCase
|
|
|
|
|
|
|
|
@@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase):
|
|
|
|
|
|
# The data is never consumed.
|
|
|
self.assertEqual(result.getvalue(), b"")
|
|
|
+
|
|
|
+
|
|
|
+class BlacklistingAgentTest(TestCase):
|
|
|
+ def setUp(self):
|
|
|
+ self.reactor, self.clock = get_clock()
|
|
|
+
|
|
|
+ self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
|
|
|
+ self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
|
|
|
+ self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
|
|
|
+
|
|
|
+ # Configure the reactor's DNS resolver.
|
|
|
+ for (domain, ip) in (
|
|
|
+ (self.safe_domain, self.safe_ip),
|
|
|
+ (self.unsafe_domain, self.unsafe_ip),
|
|
|
+ (self.allowed_domain, self.allowed_ip),
|
|
|
+ ):
|
|
|
+ self.reactor.lookups[domain.decode()] = ip.decode()
|
|
|
+ self.reactor.lookups[ip.decode()] = ip.decode()
|
|
|
+
|
|
|
+ self.ip_whitelist = IPSet([self.allowed_ip.decode()])
|
|
|
+ self.ip_blacklist = IPSet(["5.0.0.0/8"])
|
|
|
+
|
|
|
+ def test_reactor(self):
|
|
|
+ """Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
|
|
|
+ agent = Agent(
|
|
|
+ BlacklistingReactorWrapper(
|
|
|
+ self.reactor,
|
|
|
+ ip_whitelist=self.ip_whitelist,
|
|
|
+ ip_blacklist=self.ip_blacklist,
|
|
|
+ ),
|
|
|
+ )
|
|
|
+
|
|
|
+ # The unsafe domains and IPs should be rejected.
|
|
|
+ for domain in (self.unsafe_domain, self.unsafe_ip):
|
|
|
+ self.failureResultOf(
|
|
|
+ agent.request(b"GET", b"http://" + domain), DNSLookupError
|
|
|
+ )
|
|
|
+
|
|
|
+ # The safe domains IPs should be accepted.
|
|
|
+ for domain in (
|
|
|
+ self.safe_domain,
|
|
|
+ self.allowed_domain,
|
|
|
+ self.safe_ip,
|
|
|
+ self.allowed_ip,
|
|
|
+ ):
|
|
|
+ d = agent.request(b"GET", b"http://" + domain)
|
|
|
+
|
|
|
+ # Grab the latest TCP connection.
|
|
|
+ (
|
|
|
+ host,
|
|
|
+ port,
|
|
|
+ client_factory,
|
|
|
+ _timeout,
|
|
|
+ _bindAddress,
|
|
|
+ ) = self.reactor.tcpClients[-1]
|
|
|
+
|
|
|
+ # Make the connection and pump data through it.
|
|
|
+ client = client_factory.buildProtocol(None)
|
|
|
+ server = AccumulatingProtocol()
|
|
|
+ server.makeConnection(FakeTransport(client, self.reactor))
|
|
|
+ client.makeConnection(FakeTransport(server, self.reactor))
|
|
|
+ client.dataReceived(
|
|
|
+ b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
|
|
|
+ )
|
|
|
+
|
|
|
+ response = self.successResultOf(d)
|
|
|
+ self.assertEqual(response.code, 200)
|
|
|
+
|
|
|
+ def test_agent(self):
|
|
|
+ """Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
|
|
|
+ agent = BlacklistingAgentWrapper(
|
|
|
+ Agent(self.reactor),
|
|
|
+ ip_whitelist=self.ip_whitelist,
|
|
|
+ ip_blacklist=self.ip_blacklist,
|
|
|
+ )
|
|
|
+
|
|
|
+ # The unsafe IPs should be rejected.
|
|
|
+ self.failureResultOf(
|
|
|
+ agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
|
|
|
+ )
|
|
|
+
|
|
|
+ # The safe and unsafe domains and safe IPs should be accepted.
|
|
|
+ for domain in (
|
|
|
+ self.safe_domain,
|
|
|
+ self.unsafe_domain,
|
|
|
+ self.allowed_domain,
|
|
|
+ self.safe_ip,
|
|
|
+ self.allowed_ip,
|
|
|
+ ):
|
|
|
+ d = agent.request(b"GET", b"http://" + domain)
|
|
|
+
|
|
|
+ # Grab the latest TCP connection.
|
|
|
+ (
|
|
|
+ host,
|
|
|
+ port,
|
|
|
+ client_factory,
|
|
|
+ _timeout,
|
|
|
+ _bindAddress,
|
|
|
+ ) = self.reactor.tcpClients[-1]
|
|
|
+
|
|
|
+ # Make the connection and pump data through it.
|
|
|
+ client = client_factory.buildProtocol(None)
|
|
|
+ server = AccumulatingProtocol()
|
|
|
+ server.makeConnection(FakeTransport(client, self.reactor))
|
|
|
+ client.makeConnection(FakeTransport(server, self.reactor))
|
|
|
+ client.dataReceived(
|
|
|
+ b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
|
|
|
+ )
|
|
|
+
|
|
|
+ response = self.successResultOf(d)
|
|
|
+ self.assertEqual(response.code, 200)
|