Browse Source

Merge remote-tracking branch 'hera/rav/server_acls' into develop

Neil Johnson 5 years ago
parent
commit
feef8461d1

+ 2 - 0
synapse/api/constants.py

@@ -76,6 +76,8 @@ class EventTypes(object):
     Topic = "m.room.topic"
     Name = "m.room.name"
 
+    ServerACL = "m.room.server_acl"
+
 
 class RejectedReason(object):
     AUTH_ERROR = "auth_error"

+ 148 - 2
synapse/federation/federation_server.py

@@ -14,10 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import re
 
 from canonicaljson import json
+import six
 from twisted.internet import defer
+from twisted.internet.abstract import isIPAddress
 
+from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.federation.federation_base import (
@@ -27,6 +31,7 @@ from synapse.federation.federation_base import (
 
 from synapse.federation.persistence import TransactionActions
 from synapse.federation.units import Edu, Transaction
+from synapse.http.endpoint import parse_server_name
 from synapse.types import get_domain_from_id
 from synapse.util import async
 from synapse.util.caches.response_cache import ResponseCache
@@ -74,6 +79,9 @@ class FederationServer(FederationBase):
     @log_function
     def on_backfill_request(self, origin, room_id, versions, limit):
         with (yield self._server_linearizer.queue((origin, room_id))):
+            origin_host, _ = parse_server_name(origin)
+            yield self.check_server_matches_acl(origin_host, room_id)
+
             pdus = yield self.handler.on_backfill_request(
                 origin, room_id, versions, limit
             )
@@ -134,6 +142,8 @@ class FederationServer(FederationBase):
 
         received_pdus_counter.inc(len(transaction.pdus))
 
+        origin_host, _ = parse_server_name(transaction.origin)
+
         pdus_by_room = {}
 
         for p in transaction.pdus:
@@ -154,9 +164,21 @@ class FederationServer(FederationBase):
         # we can process different rooms in parallel (which is useful if they
         # require callouts to other servers to fetch missing events), but
         # impose a limit to avoid going too crazy with ram/cpu.
+
         @defer.inlineCallbacks
         def process_pdus_for_room(room_id):
             logger.debug("Processing PDUs for %s", room_id)
+            try:
+                yield self.check_server_matches_acl(origin_host, room_id)
+            except AuthError as e:
+                logger.warn(
+                    "Ignoring PDUs for room %s from banned server", room_id,
+                )
+                for pdu in pdus_by_room[room_id]:
+                    event_id = pdu.event_id
+                    pdu_results[event_id] = e.error_dict()
+                return
+
             for pdu in pdus_by_room[room_id]:
                 event_id = pdu.event_id
                 try:
@@ -211,6 +233,9 @@ class FederationServer(FederationBase):
         if not event_id:
             raise NotImplementedError("Specify an event")
 
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, room_id)
+
         in_room = yield self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
@@ -234,6 +259,9 @@ class FederationServer(FederationBase):
         if not event_id:
             raise NotImplementedError("Specify an event")
 
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, room_id)
+
         in_room = yield self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
@@ -298,7 +326,9 @@ class FederationServer(FederationBase):
         defer.returnValue((200, resp))
 
     @defer.inlineCallbacks
-    def on_make_join_request(self, room_id, user_id):
+    def on_make_join_request(self, origin, room_id, user_id):
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, room_id)
         pdu = yield self.handler.on_make_join_request(room_id, user_id)
         time_now = self._clock.time_msec()
         defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@@ -306,6 +336,8 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_invite_request(self, origin, content):
         pdu = event_from_pdu_json(content)
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, pdu.room_id)
         ret_pdu = yield self.handler.on_invite_request(origin, pdu)
         time_now = self._clock.time_msec()
         defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@@ -314,6 +346,10 @@ class FederationServer(FederationBase):
     def on_send_join_request(self, origin, content):
         logger.debug("on_send_join_request: content: %s", content)
         pdu = event_from_pdu_json(content)
+
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, pdu.room_id)
+
         logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
         res_pdus = yield self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
@@ -325,7 +361,9 @@ class FederationServer(FederationBase):
         }))
 
     @defer.inlineCallbacks
-    def on_make_leave_request(self, room_id, user_id):
+    def on_make_leave_request(self, origin, room_id, user_id):
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, room_id)
         pdu = yield self.handler.on_make_leave_request(room_id, user_id)
         time_now = self._clock.time_msec()
         defer.returnValue({"event": pdu.get_pdu_json(time_now)})
@@ -334,6 +372,10 @@ class FederationServer(FederationBase):
     def on_send_leave_request(self, origin, content):
         logger.debug("on_send_leave_request: content: %s", content)
         pdu = event_from_pdu_json(content)
+
+        origin_host, _ = parse_server_name(origin)
+        yield self.check_server_matches_acl(origin_host, pdu.room_id)
+
         logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
         yield self.handler.on_send_leave_request(origin, pdu)
         defer.returnValue((200, {}))
@@ -341,6 +383,9 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_event_auth(self, origin, room_id, event_id):
         with (yield self._server_linearizer.queue((origin, room_id))):
+            origin_host, _ = parse_server_name(origin)
+            yield self.check_server_matches_acl(origin_host, room_id)
+
             time_now = self._clock.time_msec()
             auth_pdus = yield self.handler.on_event_auth(event_id)
             res = {
@@ -369,6 +414,9 @@ class FederationServer(FederationBase):
             Deferred: Results in `dict` with the same format as `content`
         """
         with (yield self._server_linearizer.queue((origin, room_id))):
+            origin_host, _ = parse_server_name(origin)
+            yield self.check_server_matches_acl(origin_host, room_id)
+
             auth_chain = [
                 event_from_pdu_json(e)
                 for e in content["auth_chain"]
@@ -442,6 +490,9 @@ class FederationServer(FederationBase):
     def on_get_missing_events(self, origin, room_id, earliest_events,
                               latest_events, limit, min_depth):
         with (yield self._server_linearizer.queue((origin, room_id))):
+            origin_host, _ = parse_server_name(origin)
+            yield self.check_server_matches_acl(origin_host, room_id)
+
             logger.info(
                 "on_get_missing_events: earliest_events: %r, latest_events: %r,"
                 " limit: %d, min_depth: %d",
@@ -579,6 +630,101 @@ class FederationServer(FederationBase):
         )
         defer.returnValue(ret)
 
+    @defer.inlineCallbacks
+    def check_server_matches_acl(self, server_name, room_id):
+        """Check if the given server is allowed by the server ACLs in the room
+
+        Args:
+            server_name (str): name of server, *without any port part*
+            room_id (str): ID of the room to check
+
+        Raises:
+            AuthError if the server does not match the ACL
+        """
+        state_ids = yield self.store.get_current_state_ids(room_id)
+        acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
+
+        if not acl_event_id:
+            return
+
+        acl_event = yield self.store.get_event(acl_event_id)
+        if server_matches_acl_event(server_name, acl_event):
+            return
+
+        raise AuthError(code=403, msg="Server is banned from room")
+
+
+def server_matches_acl_event(server_name, acl_event):
+    """Check if the given server is allowed by the ACL event
+
+    Args:
+        server_name (str): name of server, without any port part
+        acl_event (EventBase): m.room.server_acl event
+
+    Returns:
+        bool: True if this server is allowed by the ACLs
+    """
+    logger.debug("Checking %s against acl %s", server_name, acl_event.content)
+
+    # first of all, check if literal IPs are blocked, and if so, whether the
+    # server name is a literal IP
+    allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
+    if not isinstance(allow_ip_literals, bool):
+        logger.warn("Ignorning non-bool allow_ip_literals flag")
+        allow_ip_literals = True
+    if not allow_ip_literals:
+        # check for ipv6 literals. These start with '['.
+        if server_name[0] == '[':
+            return False
+
+        # check for ipv4 literals. We can just lift the routine from twisted.
+        if isIPAddress(server_name):
+            return False
+
+    # next,  check the deny list
+    deny = acl_event.content.get("deny", [])
+    if not isinstance(deny, (list, tuple)):
+        logger.warn("Ignorning non-list deny ACL %s", deny)
+        deny = []
+    for e in deny:
+        if _acl_entry_matches(server_name, e):
+            # logger.info("%s matched deny rule %s", server_name, e)
+            return False
+
+    # then the allow list.
+    allow = acl_event.content.get("allow", [])
+    if not isinstance(allow, (list, tuple)):
+        logger.warn("Ignorning non-list allow ACL %s", allow)
+        allow = []
+    for e in allow:
+        if _acl_entry_matches(server_name, e):
+            # logger.info("%s matched allow rule %s", server_name, e)
+            return True
+
+    # everything else should be rejected.
+    # logger.info("%s fell through", server_name)
+    return False
+
+
+def _acl_entry_matches(server_name, acl_entry):
+    if not isinstance(acl_entry, six.string_types):
+        logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
+        return False
+    regex = _glob_to_regex(acl_entry)
+    return regex.match(server_name)
+
+
+def _glob_to_regex(glob):
+    res = ''
+    for c in glob:
+        if c == '*':
+            res = res + '.*'
+        elif c == '?':
+            res = res + '.'
+        else:
+            res = res + re.escape(c)
+    return re.compile(res + "\\Z", re.IGNORECASE)
+
 
 class FederationHandlerRegistry(object):
     """Allows classes to register themselves as handlers for a given EDU or

+ 6 - 2
synapse/federation/transport/server.py

@@ -385,7 +385,9 @@ class FederationMakeJoinServlet(BaseFederationServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, context, user_id):
-        content = yield self.handler.on_make_join_request(context, user_id)
+        content = yield self.handler.on_make_join_request(
+            origin, context, user_id,
+        )
         defer.returnValue((200, content))
 
 
@@ -394,7 +396,9 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, context, user_id):
-        content = yield self.handler.on_make_leave_request(context, user_id)
+        content = yield self.handler.on_make_leave_request(
+            origin, context, user_id,
+        )
         defer.returnValue((200, content))
 
 

+ 0 - 0
tests/federation/__init__.py


+ 57 - 0
tests/federation/test_federation_server.py

@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.events import FrozenEvent
+from synapse.federation.federation_server import server_matches_acl_event
+from tests import unittest
+
+
+@unittest.DEBUG
+class ServerACLsTestCase(unittest.TestCase):
+    def test_blacklisted_server(self):
+        e = _create_acl_event({
+            "allow": ["*"],
+            "deny": ["evil.com"],
+        })
+        logging.info("ACL event: %s", e.content)
+
+        self.assertFalse(server_matches_acl_event("evil.com", e))
+        self.assertFalse(server_matches_acl_event("EVIL.COM", e))
+
+        self.assertTrue(server_matches_acl_event("evil.com.au", e))
+        self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e))
+
+    def test_block_ip_literals(self):
+        e = _create_acl_event({
+            "allow_ip_literals": False,
+            "allow": ["*"],
+        })
+        logging.info("ACL event: %s", e.content)
+
+        self.assertFalse(server_matches_acl_event("1.2.3.4", e))
+        self.assertTrue(server_matches_acl_event("1a.2.3.4", e))
+        self.assertFalse(server_matches_acl_event("[1:2::]", e))
+        self.assertTrue(server_matches_acl_event("1:2:3:4", e))
+
+
+def _create_acl_event(content):
+    return FrozenEvent({
+        "room_id": "!a:b",
+        "event_id": "$a:b",
+        "type": "m.room.server_acls",
+        "sender": "@a:b",
+        "content": content
+    })