Browse Source

Merge branch 'erikj/reduce_size_of_cache' into erikj/merge_cache_prs

Erik Johnston 3 years ago
parent
commit
a99c692906

+ 1 - 0
changelog.d/9726.bugfix

@@ -0,0 +1 @@
+Fixes the OIDC SSO flow when using a `public_baseurl` value including a non-root URL path.

+ 1 - 0
changelog.d/9817.misc

@@ -0,0 +1 @@
+Fix a long-standing bug which caused `max_upload_size` to not be correctly enforced.

+ 1 - 0
changelog.d/9874.misc

@@ -0,0 +1 @@
+Pass a reactor into `SynapseSite` to make testing easier.

+ 1 - 0
changelog.d/9876.misc

@@ -0,0 +1 @@
+Add type hints to `synapse.api.auth` and `synapse.api.auth_blocking` modules.

+ 1 - 0
changelog.d/9878.misc

@@ -0,0 +1 @@
+Remove redundant `_PushHTTPChannel` test class.

+ 39 - 39
synapse/api/auth.py

@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 import pymacaroons
 from netaddr import IPAddress
 
 from twisted.web.server import Request
 
-import synapse.types
 from synapse import event_auth
 from synapse.api.auth_blocking import AuthBlocking
 from synapse.api.constants import EventTypes, HistoryVisibility, Membership
@@ -36,11 +35,14 @@ from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing as opentracing
 from synapse.storage.databases.main.registration import TokenLookupResult
-from synapse.types import StateMap, UserID
+from synapse.types import Requester, StateMap, UserID, create_requester
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -68,7 +70,7 @@ class Auth:
     The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
@@ -88,13 +90,13 @@ class Auth:
 
     async def check_from_context(
         self, room_version: str, event, context, do_sig_check=True
-    ):
+    ) -> None:
         prev_state_ids = await context.get_prev_state_ids()
         auth_events_ids = self.compute_auth_events(
             event, prev_state_ids, for_verification=True
         )
-        auth_events = await self.store.get_events(auth_events_ids)
-        auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+        auth_events_by_id = await self.store.get_events(auth_events_ids)
+        auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
 
         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
         event_auth.check(
@@ -151,17 +153,11 @@ class Auth:
 
         raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
 
-    async def check_host_in_room(self, room_id, host):
+    async def check_host_in_room(self, room_id: str, host: str) -> bool:
         with Measure(self.clock, "check_host_in_room"):
-            latest_event_ids = await self.store.is_host_joined(room_id, host)
-            return latest_event_ids
-
-    def can_federate(self, event, auth_events):
-        creation_event = auth_events.get((EventTypes.Create, ""))
+            return await self.store.is_host_joined(room_id, host)
 
-        return creation_event.content.get("m.federate", True) is True
-
-    def get_public_keys(self, invite_event):
+    def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
         return event_auth.get_public_keys(invite_event)
 
     async def get_user_by_req(
@@ -170,7 +166,7 @@ class Auth:
         allow_guest: bool = False,
         rights: str = "access",
         allow_expired: bool = False,
-    ) -> synapse.types.Requester:
+    ) -> Requester:
         """Get a registered user's ID.
 
         Args:
@@ -196,7 +192,7 @@ class Auth:
             access_token = self.get_access_token_from_request(request)
 
             user_id, app_service = await self._get_appservice_user_id(request)
-            if user_id:
+            if user_id and app_service:
                 if ip_addr and self._track_appservice_user_ips:
                     await self.store.insert_client_ip(
                         user_id=user_id,
@@ -206,9 +202,7 @@ class Auth:
                         device_id="dummy-device",  # stubbed
                     )
 
-                requester = synapse.types.create_requester(
-                    user_id, app_service=app_service
-                )
+                requester = create_requester(user_id, app_service=app_service)
 
                 request.requester = user_id
                 opentracing.set_tag("authenticated_entity", user_id)
@@ -251,7 +245,7 @@ class Auth:
                     errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                 )
 
-            requester = synapse.types.create_requester(
+            requester = create_requester(
                 user_info.user_id,
                 token_id,
                 is_guest,
@@ -271,7 +265,9 @@ class Auth:
         except KeyError:
             raise MissingClientTokenError()
 
-    async def _get_appservice_user_id(self, request):
+    async def _get_appservice_user_id(
+        self, request: Request
+    ) -> Tuple[Optional[str], Optional[ApplicationService]]:
         app_service = self.store.get_app_service_by_token(
             self.get_access_token_from_request(request)
         )
@@ -283,6 +279,9 @@ class Auth:
             if ip_address not in app_service.ip_range_whitelist:
                 return None, None
 
+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
+
         if b"user_id" not in request.args:
             return app_service.sender, app_service
 
@@ -387,7 +386,9 @@ class Auth:
             logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
             raise InvalidClientTokenError("Invalid macaroon passed.")
 
-    def _parse_and_validate_macaroon(self, token, rights="access"):
+    def _parse_and_validate_macaroon(
+        self, token: str, rights: str = "access"
+    ) -> Tuple[str, bool]:
         """Takes a macaroon and tries to parse and validate it. This is cached
         if and only if rights == access and there isn't an expiry.
 
@@ -432,15 +433,16 @@ class Auth:
 
         return user_id, guest
 
-    def validate_macaroon(self, macaroon, type_string, user_id):
+    def validate_macaroon(
+        self, macaroon: pymacaroons.Macaroon, type_string: str, user_id: str
+    ) -> None:
         """
         validate that a Macaroon is understood by and was signed by this server.
 
         Args:
-            macaroon(pymacaroons.Macaroon): The macaroon to validate
-            type_string(str): The kind of token required (e.g. "access",
-                              "delete_pusher")
-            user_id (str): The user_id required
+            macaroon: The macaroon to validate
+            type_string: The kind of token required (e.g. "access", "delete_pusher")
+            user_id: The user_id required
         """
         v = pymacaroons.Verifier()
 
@@ -465,9 +467,7 @@ class Auth:
         if not service:
             logger.warning("Unrecognised appservice access token.")
             raise InvalidClientTokenError()
-        request.requester = synapse.types.create_requester(
-            service.sender, app_service=service
-        )
+        request.requester = create_requester(service.sender, app_service=service)
         return service
 
     async def is_server_admin(self, user: UserID) -> bool:
@@ -519,7 +519,7 @@ class Auth:
 
         return auth_ids
 
-    async def check_can_change_room_list(self, room_id: str, user: UserID):
+    async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
         """Determine whether the user is allowed to edit the room's entry in the
         published room list.
 
@@ -554,11 +554,11 @@ class Auth:
         return user_level >= send_level
 
     @staticmethod
-    def has_access_token(request: Request):
+    def has_access_token(request: Request) -> bool:
         """Checks if the request has an access_token.
 
         Returns:
-            bool: False if no access_token was given, True otherwise.
+            False if no access_token was given, True otherwise.
         """
         # This will always be set by the time Twisted calls us.
         assert request.args is not None
@@ -568,13 +568,13 @@ class Auth:
         return bool(query_params) or bool(auth_headers)
 
     @staticmethod
-    def get_access_token_from_request(request: Request):
+    def get_access_token_from_request(request: Request) -> str:
         """Extracts the access_token from the request.
 
         Args:
             request: The http request.
         Returns:
-            unicode: The access_token
+            The access_token
         Raises:
             MissingClientTokenError: If there isn't a single access_token in the
                 request
@@ -649,5 +649,5 @@ class Auth:
                 % (user_id, room_id),
             )
 
-    def check_auth_blocking(self, *args, **kwargs):
-        return self._auth_blocking.check_auth_blocking(*args, **kwargs)
+    async def check_auth_blocking(self, *args, **kwargs) -> None:
+        await self._auth_blocking.check_auth_blocking(*args, **kwargs)

+ 6 - 3
synapse/api/auth_blocking.py

@@ -13,18 +13,21 @@
 # limitations under the License.
 
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.constants import LimitBlockingTypes, UserTypes
 from synapse.api.errors import Codes, ResourceLimitError
 from synapse.config.server import is_threepid_reserved
 from synapse.types import Requester
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class AuthBlocking:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
         self._server_notices_mxid = hs.config.server_notices_mxid
@@ -43,7 +46,7 @@ class AuthBlocking:
         threepid: Optional[dict] = None,
         user_type: Optional[str] = None,
         requester: Optional[Requester] = None,
-    ):
+    ) -> None:
         """Checks if the user should be rejected for some external reason,
         such as monthly active user limiting or global disable flag
 

+ 3 - 0
synapse/api/constants.py

@@ -17,6 +17,9 @@
 
 """Contains constants from the specification."""
 
+# the max size of a (canonical-json-encoded) event
+MAX_PDU_SIZE = 65536
+
 # the "depth" field on events is limited to 2**63 - 1
 MAX_DEPTH = 2 ** 63 - 1
 

+ 26 - 4
synapse/app/_base.py

@@ -30,9 +30,10 @@ from twisted.internet import defer, error, reactor
 from twisted.protocols.tls import TLSMemoryBIOFactory
 
 import synapse
+from synapse.api.constants import MAX_PDU_SIZE
 from synapse.app import check_bind_error
 from synapse.app.phone_stats_home import start_phone_stats_home
-from synapse.config.server import ListenerConfig
+from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory
 from synapse.logging.context import PreserveLoggingContext
 from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -288,7 +289,7 @@ def refresh_certificate(hs):
         logger.info("Context factories updated.")
 
 
-async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
+async def start(hs: "synapse.server.HomeServer"):
     """
     Start a Synapse server or worker.
 
@@ -300,7 +301,6 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
 
     Args:
         hs: homeserver instance
-        listeners: Listener configuration ('listeners' in homeserver.yaml)
     """
     # Set up the SIGHUP machinery.
     if hasattr(signal, "SIGHUP"):
@@ -336,7 +336,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
     synapse.logging.opentracing.init_tracer(hs)  # type: ignore[attr-defined] # noqa
 
     # It is now safe to start your Synapse.
-    hs.start_listening(listeners)
+    hs.start_listening()
     hs.get_datastore().db_pool.start_profiling()
     hs.get_pusherpool().start()
 
@@ -530,3 +530,25 @@ def sdnotify(state):
         # this is a bit surprising, since we don't expect to have a NOTIFY_SOCKET
         # unless systemd is expecting us to notify it.
         logger.warning("Unable to send notification to systemd: %s", e)
+
+
+def max_request_body_size(config: HomeServerConfig) -> int:
+    """Get a suitable maximum size for incoming HTTP requests"""
+
+    # Other than media uploads, the biggest request we expect to see is a fully-loaded
+    # /federation/v1/send request.
+    #
+    # The main thing in such a request is up to 50 PDUs, and up to 100 EDUs. PDUs are
+    # limited to 65536 bytes (possibly slightly more if the sender didn't use canonical
+    # json encoding); there is no specced limit to EDUs (see
+    # https://github.com/matrix-org/matrix-doc/issues/3121).
+    #
+    # in short, we somewhat arbitrarily limit requests to 200 * 64K (about 12.5M)
+    #
+    max_request_size = 200 * MAX_PDU_SIZE
+
+    # if we have a media repo enabled, we may need to allow larger uploads than that
+    if config.media.can_load_media_repo:
+        max_request_size = max(max_request_size, config.media.max_upload_size)
+
+    return max_request_size

+ 1 - 7
synapse/app/admin_cmd.py

@@ -70,12 +70,6 @@ class AdminCmdSlavedStore(
 class AdminCmdServer(HomeServer):
     DATASTORE_CLASS = AdminCmdSlavedStore
 
-    def _listen_http(self, listener_config):
-        pass
-
-    def start_listening(self, listeners):
-        pass
-
 
 async def export_data_command(hs, args):
     """Export data for a user.
@@ -232,7 +226,7 @@ def start(config_options):
 
     async def run():
         with LoggingContext("command"):
-            _base.start(ss, [])
+            _base.start(ss)
             await args.func(ss, args)
 
     _base.start_worker_reactor(

+ 7 - 5
synapse/app/generic_worker.py

@@ -15,7 +15,7 @@
 # limitations under the License.
 import logging
 import sys
-from typing import Dict, Iterable, Optional
+from typing import Dict, Optional
 
 from twisted.internet import address
 from twisted.web.resource import IResource
@@ -32,7 +32,7 @@ from synapse.api.urls import (
     SERVER_KEY_V2_PREFIX,
 )
 from synapse.app import _base
-from synapse.app._base import register_start
+from synapse.app._base import max_request_body_size, register_start
 from synapse.config._base import ConfigError
 from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
@@ -367,14 +367,16 @@ class GenericWorkerServer(HomeServer):
                 listener_config,
                 root_resource,
                 self.version_string,
+                max_request_body_size=max_request_body_size(self.config),
+                reactor=self.get_reactor(),
             ),
             reactor=self.get_reactor(),
         )
 
         logger.info("Synapse worker now listening on port %d", port)
 
-    def start_listening(self, listeners: Iterable[ListenerConfig]):
-        for listener in listeners:
+    def start_listening(self):
+        for listener in self.config.worker_listeners:
             if listener.type == "http":
                 self._listen_http(listener)
             elif listener.type == "manhole":
@@ -468,7 +470,7 @@ def start(config_options):
     # streams. Will no-op if no streams can be written to by this worker.
     hs.get_replication_streamer()
 
-    register_start(_base.start, hs, config.worker_listeners)
+    register_start(_base.start, hs)
 
     _base.start_worker_reactor("synapse-generic-worker", config)
 

+ 22 - 20
synapse/app/homeserver.py

@@ -17,7 +17,7 @@
 import logging
 import os
 import sys
-from typing import Iterable, Iterator
+from typing import Iterator
 
 from twisted.internet import reactor
 from twisted.web.resource import EncodingResourceWrapper, IResource
@@ -36,7 +36,13 @@ from synapse.api.urls import (
     WEB_CLIENT_PREFIX,
 )
 from synapse.app import _base
-from synapse.app._base import listen_ssl, listen_tcp, quit_with_error, register_start
+from synapse.app._base import (
+    listen_ssl,
+    listen_tcp,
+    max_request_body_size,
+    quit_with_error,
+    register_start,
+)
 from synapse.config._base import ConfigError
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.config.homeserver import HomeServerConfig
@@ -126,19 +132,21 @@ class SynapseHomeServer(HomeServer):
         else:
             root_resource = OptionsResource()
 
-        root_resource = create_resource_tree(resources, root_resource)
+        site = SynapseSite(
+            "synapse.access.%s.%s" % ("https" if tls else "http", site_tag),
+            site_tag,
+            listener_config,
+            create_resource_tree(resources, root_resource),
+            self.version_string,
+            max_request_body_size=max_request_body_size(self.config),
+            reactor=self.get_reactor(),
+        )
 
         if tls:
             ports = listen_ssl(
                 bind_addresses,
                 port,
-                SynapseSite(
-                    "synapse.access.https.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                    self.version_string,
-                ),
+                site,
                 self.tls_server_context_factory,
                 reactor=self.get_reactor(),
             )
@@ -148,13 +156,7 @@ class SynapseHomeServer(HomeServer):
             ports = listen_tcp(
                 bind_addresses,
                 port,
-                SynapseSite(
-                    "synapse.access.http.%s" % (site_tag,),
-                    site_tag,
-                    listener_config,
-                    root_resource,
-                    self.version_string,
-                ),
+                site,
                 reactor=self.get_reactor(),
             )
             logger.info("Synapse now listening on TCP port %d", port)
@@ -273,14 +275,14 @@ class SynapseHomeServer(HomeServer):
 
         return resources
 
-    def start_listening(self, listeners: Iterable[ListenerConfig]):
+    def start_listening(self):
         if self.config.redis_enabled:
             # If redis is enabled we connect via the replication command handler
             # in the same way as the workers (since we're effectively a client
             # rather than a server).
             self.get_tcp_replication().start_replication(self)
 
-        for listener in listeners:
+        for listener in self.config.server.listeners:
             if listener.type == "http":
                 self._listening_services.extend(
                     self._listener_http(self.config, listener)
@@ -413,7 +415,7 @@ def setup(config_options):
             # Loading the provider metadata also ensures the provider config is valid.
             await oidc.load_metadata()
 
-        await _base.start(hs, config.listeners)
+        await _base.start(hs)
 
         hs.get_datastore().db_pool.updates.start_doing_background_updates()
 

+ 2 - 1
synapse/config/logger.py

@@ -31,7 +31,6 @@ from twisted.logger import (
 )
 
 import synapse
-from synapse.app import _base as appbase
 from synapse.logging._structured import setup_structured_logging
 from synapse.logging.context import LoggingContextFilter
 from synapse.logging.filter import MetadataFilter
@@ -318,6 +317,8 @@ def setup_logging(
     # Perform one-time logging configuration.
     _setup_stdlib_logging(config, log_config_path, logBeginner=logBeginner)
     # Add a SIGHUP handler to reload the logging configuration, if one is available.
+    from synapse.app import _base as appbase
+
     appbase.register_sighup(_reload_logging_config, log_config_path)
 
     # Log immediately so we can grep backwards.

+ 4 - 4
synapse/config/server.py

@@ -235,7 +235,11 @@ class ServerConfig(Config):
         self.print_pidfile = config.get("print_pidfile")
         self.user_agent_suffix = config.get("user_agent_suffix")
         self.use_frozen_dicts = config.get("use_frozen_dicts", False)
+
         self.public_baseurl = config.get("public_baseurl")
+        if self.public_baseurl is not None:
+            if self.public_baseurl[-1] != "/":
+                self.public_baseurl += "/"
 
         # Whether to enable user presence.
         presence_config = config.get("presence") or {}
@@ -407,10 +411,6 @@ class ServerConfig(Config):
             config_path=("federation_ip_range_blacklist",),
         )
 
-        if self.public_baseurl is not None:
-            if self.public_baseurl[-1] != "/":
-                self.public_baseurl += "/"
-
         # (undocumented) option for torturing the worker-mode replication a bit,
         # for testing. The value defines the number of milliseconds to pause before
         # sending out any replication updates.

+ 4 - 4
synapse/event_auth.py

@@ -14,14 +14,14 @@
 # limitations under the License.
 
 import logging
-from typing import List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import SignatureVerifyException, verify_signed_json
 from unpaddedbase64 import decode_base64
 
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import MAX_PDU_SIZE, EventTypes, JoinRules, Membership
 from synapse.api.errors import AuthError, EventSizeError, SynapseError
 from synapse.api.room_versions import (
     KNOWN_ROOM_VERSIONS,
@@ -205,7 +205,7 @@ def _check_size_limits(event: EventBase) -> None:
         too_big("type")
     if len(event.event_id) > 255:
         too_big("event_id")
-    if len(encode_canonical_json(event.get_pdu_json())) > 65536:
+    if len(encode_canonical_json(event.get_pdu_json())) > MAX_PDU_SIZE:
         too_big("event")
 
 
@@ -688,7 +688,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase
     return False
 
 
-def get_public_keys(invite_event):
+def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
     public_keys = []
     if "public_key" in invite_event.content:
         o = {"public_key": invite_event.content["public_key"]}

+ 17 - 5
synapse/handlers/oidc.py

@@ -15,7 +15,7 @@
 import inspect
 import logging
 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
-from urllib.parse import urlencode
+from urllib.parse import urlencode, urlparse
 
 import attr
 import pymacaroons
@@ -68,8 +68,8 @@ logger = logging.getLogger(__name__)
 #
 # Here we have the names of the cookies, and the options we use to set them.
 _SESSION_COOKIES = [
-    (b"oidc_session", b"Path=/_synapse/client/oidc; HttpOnly; Secure; SameSite=None"),
-    (b"oidc_session_no_samesite", b"Path=/_synapse/client/oidc; HttpOnly"),
+    (b"oidc_session", b"HttpOnly; Secure; SameSite=None"),
+    (b"oidc_session_no_samesite", b"HttpOnly"),
 ]
 
 #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
@@ -279,6 +279,13 @@ class OidcProvider:
         self._config = provider
         self._callback_url = hs.config.oidc_callback_url  # type: str
 
+        # Calculate the prefix for OIDC callback paths based on the public_baseurl.
+        # We'll insert this into the Path= parameter of any session cookies we set.
+        public_baseurl_path = urlparse(hs.config.server.public_baseurl).path
+        self._callback_path_prefix = (
+            public_baseurl_path.encode("utf-8") + b"_synapse/client/oidc"
+        )
+
         self._oidc_attribute_requirements = provider.attribute_requirements
         self._scopes = provider.scopes
         self._user_profile_method = provider.user_profile_method
@@ -779,8 +786,13 @@ class OidcProvider:
 
         for cookie_name, options in _SESSION_COOKIES:
             request.cookies.append(
-                b"%s=%s; Max-Age=3600; %s"
-                % (cookie_name, cookie.encode("utf-8"), options)
+                b"%s=%s; Max-Age=3600; Path=%s; %s"
+                % (
+                    cookie_name,
+                    cookie.encode("utf-8"),
+                    self._callback_path_prefix,
+                    options,
+                )
             )
 
         metadata = await self.load_metadata()

+ 55 - 14
synapse/http/site.py

@@ -14,13 +14,14 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Tuple, Type, Union
+from typing import Optional, Tuple, Union
 
 import attr
 from zope.interface import implementer
 
-from twisted.internet.interfaces import IAddress
+from twisted.internet.interfaces import IAddress, IReactorTime
 from twisted.python.failure import Failure
+from twisted.web.resource import IResource
 from twisted.web.server import Request, Site
 
 from synapse.config.server import ListenerConfig
@@ -49,6 +50,7 @@ class SynapseRequest(Request):
      * Redaction of access_token query-params in __repr__
      * Logging at start and end
      * Metrics to record CPU, wallclock and DB time by endpoint.
+     * A limit to the size of request which will be accepted
 
     It also provides a method `processing`, which returns a context manager. If this
     method is called, the request won't be logged until the context manager is closed;
@@ -59,8 +61,9 @@ class SynapseRequest(Request):
         logcontext: the log context for this request
     """
 
-    def __init__(self, channel, *args, **kw):
+    def __init__(self, channel, *args, max_request_body_size=1024, **kw):
         Request.__init__(self, channel, *args, **kw)
+        self._max_request_body_size = max_request_body_size
         self.site = channel.site  # type: SynapseSite
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
@@ -97,6 +100,18 @@ class SynapseRequest(Request):
             self.site.site_tag,
         )
 
+    def handleContentChunk(self, data):
+        # we should have a `content` by now.
+        assert self.content, "handleContentChunk() called before gotLength()"
+        if self.content.tell() + len(data) > self._max_request_body_size:
+            logger.warning(
+                "Aborting connection from %s because the request exceeds maximum size",
+                self.client,
+            )
+            self.transport.abortConnection()
+            return
+        super().handleContentChunk(data)
+
     @property
     def requester(self) -> Optional[Union[Requester, str]]:
         return self._requester
@@ -485,29 +500,55 @@ class _XForwardedForAddress:
 
 class SynapseSite(Site):
     """
-    Subclass of a twisted http Site that does access logging with python's
-    standard logging
+    Synapse-specific twisted http Site
+
+    This does two main things.
+
+    First, it replaces the requestFactory in use so that we build SynapseRequests
+    instead of regular t.w.server.Requests. All of the  constructor params are really
+    just parameters for SynapseRequest.
+
+    Second, it inhibits the log() method called by Request.finish, since SynapseRequest
+    does its own logging.
     """
 
     def __init__(
         self,
-        logger_name,
-        site_tag,
+        logger_name: str,
+        site_tag: str,
         config: ListenerConfig,
-        resource,
+        resource: IResource,
         server_version_string,
-        *args,
-        **kwargs,
+        max_request_body_size: int,
+        reactor: IReactorTime,
     ):
-        Site.__init__(self, resource, *args, **kwargs)
+        """
+
+        Args:
+            logger_name:  The name of the logger to use for access logs.
+            site_tag:  A tag to use for this site - mostly in access logs.
+            config:  Configuration for the HTTP listener corresponding to this site
+            resource:  The base of the resource tree to be used for serving requests on
+                this site
+            server_version_string: A string to present for the Server header
+            max_request_body_size: Maximum request body length to allow before
+                dropping the connection
+            reactor: reactor to be used to manage connection timeouts
+        """
+        Site.__init__(self, resource, reactor=reactor)
 
         self.site_tag = site_tag
 
         assert config.http_options is not None
         proxied = config.http_options.x_forwarded
-        self.requestFactory = (
-            XForwardedForRequest if proxied else SynapseRequest
-        )  # type: Type[Request]
+        request_class = XForwardedForRequest if proxied else SynapseRequest
+
+        def request_factory(channel, queued) -> Request:
+            return request_class(
+                channel, max_request_body_size=max_request_body_size, queued=queued
+            )
+
+        self.requestFactory = request_factory  # type: ignore
         self.access_logger = logging.getLogger(logger_name)
         self.server_version_string = server_version_string.encode("ascii")
 

+ 0 - 2
synapse/rest/media/v1/upload_resource.py

@@ -51,8 +51,6 @@ class UploadResource(DirectServeJsonResource):
 
     async def _async_render_POST(self, request: SynapseRequest) -> None:
         requester = await self.auth.get_user_by_req(request)
-        # TODO: The checks here are a bit late. The content will have
-        # already been uploaded to a tmp file at this point
         content_length = request.getHeader("Content-Length")
         if content_length is None:
             raise SynapseError(msg="Request must specify a Content-Length", code=400)

+ 8 - 0
synapse/server.py

@@ -287,6 +287,14 @@ class HomeServer(metaclass=abc.ABCMeta):
         if self.config.run_background_tasks:
             self.setup_background_tasks()
 
+    def start_listening(self) -> None:
+        """Start the HTTP, manhole, metrics, etc listeners
+
+        Does nothing in this base class; overridden in derived classes to start the
+        appropriate listeners.
+        """
+        pass
+
     def setup_background_tasks(self) -> None:
         """
         Some handlers have side effects on instantiation (like registering

+ 57 - 19
synapse/util/caches/lrucache.py

@@ -17,8 +17,10 @@ from functools import wraps
 from typing import (
     Any,
     Callable,
+    Collection,
     Generic,
     Iterable,
+    List,
     Optional,
     Type,
     TypeVar,
@@ -83,15 +85,30 @@ class _Node:
     __slots__ = ["prev_node", "next_node", "key", "value", "callbacks", "memory"]
 
     def __init__(
-        self, prev_node, next_node, key, value, callbacks: Optional[set] = None
+        self,
+        prev_node,
+        next_node,
+        key,
+        value,
+        callbacks: Collection[Callable[[], None]] = (),
     ):
         self.prev_node = prev_node
         self.next_node = next_node
         self.key = key
         self.value = value
-        self.callbacks = callbacks or set()
-
         self.memory = 0
+
+        # Set of callbacks to run when the node gets deleted. We store as a list
+        # rather than a set to keep memory usage down (and since we expect few
+        # entries per node the performance of checking for duplication in a list
+        # vs using a set is negligible).
+        #
+        # Note that we store this as an optional list to keep the memory
+        # footprint down. Empty lists are 56 bytes (and empty sets are 216 bytes).
+        self.callbacks = None  # type: Optional[List[Callable[[], None]]]
+
+        self.add_callbacks(callbacks)
+
         if TRACK_MEMORY_USAGE:
             self.memory = (
                 _get_size_of(key)
@@ -101,6 +118,32 @@ class _Node:
             )
             self.memory += _get_size_of(self.memory, recurse=False)
 
+    def add_callbacks(self, callbacks: Collection[Callable[[], None]]) -> None:
+        """Add to stored list of callbacks, removing duplicates."""
+
+        if not callbacks:
+            return
+
+        if not self.callbacks:
+            self.callbacks = []
+
+        for callback in callbacks:
+            if callback not in self.callbacks:
+                self.callbacks.append(callback)
+
+    def run_and_clear_callbacks(self) -> None:
+        """Run all callbacks and clear the stored set of callbacks. Used when
+        the node is being deleted.
+        """
+
+        if not self.callbacks:
+            return
+
+        for callback in self.callbacks:
+            callback()
+
+        self.callbacks = None
+
 
 class LruCache(Generic[KT, VT]):
     """
@@ -213,10 +256,10 @@ class LruCache(Generic[KT, VT]):
 
         self.len = synchronized(cache_len)
 
-        def add_node(key, value, callbacks: Optional[set] = None):
+        def add_node(key, value, callbacks: Collection[Callable[[], None]] = ()):
             prev_node = list_root
             next_node = prev_node.next_node
-            node = _Node(prev_node, next_node, key, value, callbacks or set())
+            node = _Node(prev_node, next_node, key, value, callbacks)
             prev_node.next_node = node
             next_node.prev_node = node
             cache[key] = node
@@ -250,9 +293,7 @@ class LruCache(Generic[KT, VT]):
                 deleted_len = size_callback(node.value)
                 cached_cache_len[0] -= deleted_len
 
-            for cb in node.callbacks:
-                cb()
-            node.callbacks.clear()
+            node.run_and_clear_callbacks()
 
             if TRACK_MEMORY_USAGE and metrics:
                 metrics.dec_memory_usage(node.memory)
@@ -263,7 +304,7 @@ class LruCache(Generic[KT, VT]):
         def cache_get(
             key: KT,
             default: Literal[None] = None,
-            callbacks: Iterable[Callable[[], None]] = ...,
+            callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
         ) -> Optional[VT]:
             ...
@@ -272,7 +313,7 @@ class LruCache(Generic[KT, VT]):
         def cache_get(
             key: KT,
             default: T,
-            callbacks: Iterable[Callable[[], None]] = ...,
+            callbacks: Collection[Callable[[], None]] = ...,
             update_metrics: bool = ...,
         ) -> Union[T, VT]:
             ...
@@ -281,13 +322,13 @@ class LruCache(Generic[KT, VT]):
         def cache_get(
             key: KT,
             default: Optional[T] = None,
-            callbacks: Iterable[Callable[[], None]] = (),
+            callbacks: Collection[Callable[[], None]] = (),
             update_metrics: bool = True,
         ):
             node = cache.get(key, None)
             if node is not None:
                 move_node_to_front(node)
-                node.callbacks.update(callbacks)
+                node.add_callbacks(callbacks)
                 if update_metrics and metrics:
                     metrics.inc_hits()
                 return node.value
@@ -303,10 +344,8 @@ class LruCache(Generic[KT, VT]):
                 # We sometimes store large objects, e.g. dicts, which cause
                 # the inequality check to take a long time. So let's only do
                 # the check if we have some callbacks to call.
-                if node.callbacks and value != node.value:
-                    for cb in node.callbacks:
-                        cb()
-                    node.callbacks.clear()
+                if value != node.value:
+                    node.run_and_clear_callbacks()
 
                 # We don't bother to protect this by value != node.value as
                 # generally size_callback will be cheap compared with equality
@@ -316,7 +355,7 @@ class LruCache(Generic[KT, VT]):
                     cached_cache_len[0] -= size_callback(node.value)
                     cached_cache_len[0] += size_callback(value)
 
-                node.callbacks.update(callbacks)
+                node.add_callbacks(callbacks)
 
                 move_node_to_front(node)
                 node.value = value
@@ -369,8 +408,7 @@ class LruCache(Generic[KT, VT]):
             list_root.next_node = list_root
             list_root.prev_node = list_root
             for node in cache.values():
-                for cb in node.callbacks:
-                    cb()
+                node.run_and_clear_callbacks()
             cache.clear()
             if size_callback:
                 cached_cache_len[0] = 0

+ 83 - 0
tests/http/test_site.py

@@ -0,0 +1,83 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet.address import IPv6Address
+from twisted.test.proto_helpers import StringTransport
+
+from synapse.app.homeserver import SynapseHomeServer
+
+from tests.unittest import HomeserverTestCase
+
+
+class SynapseRequestTestCase(HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer)
+
+    def test_large_request(self):
+        """overlarge HTTP requests should be rejected"""
+        self.hs.start_listening()
+
+        # find the HTTP server which is configured to listen on port 0
+        (port, factory, _backlog, interface) = self.reactor.tcpServers[0]
+        self.assertEqual(interface, "::")
+        self.assertEqual(port, 0)
+
+        # as a control case, first send a regular request.
+
+        # complete the connection and wire it up to a fake transport
+        client_address = IPv6Address("TCP", "::1", "2345")
+        protocol = factory.buildProtocol(client_address)
+        transport = StringTransport()
+        protocol.makeConnection(transport)
+
+        protocol.dataReceived(
+            b"POST / HTTP/1.1\r\n"
+            b"Connection: close\r\n"
+            b"Transfer-Encoding: chunked\r\n"
+            b"\r\n"
+            b"0\r\n"
+            b"\r\n"
+        )
+
+        while not transport.disconnecting:
+            self.reactor.advance(1)
+
+        # we should get a 404
+        self.assertRegex(transport.value().decode(), r"^HTTP/1\.1 404 ")
+
+        # now send an oversized request
+        protocol = factory.buildProtocol(client_address)
+        transport = StringTransport()
+        protocol.makeConnection(transport)
+
+        protocol.dataReceived(
+            b"POST / HTTP/1.1\r\n"
+            b"Connection: close\r\n"
+            b"Transfer-Encoding: chunked\r\n"
+            b"\r\n"
+        )
+
+        # we deliberately send all the data in one big chunk, to ensure that
+        # twisted isn't buffering the data in the chunked transfer decoder.
+        # we start with the chunk size, in hex. (We won't actually send this much)
+        protocol.dataReceived(b"10000000\r\n")
+        sent = 0
+        while not transport.disconnected:
+            self.assertLess(sent, 0x10000000, "connection did not drop")
+            protocol.dataReceived(b"\0" * 1024)
+            sent += 1024
+
+        # default max upload size is 50M, so it should drop on the next buffer after
+        # that.
+        self.assertEqual(sent, 50 * 1024 * 1024 + 1024)

+ 21 - 115
tests/replication/_base.py

@@ -12,14 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
-from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.protocol import Protocol
-from twisted.internet.task import LoopingCall
-from twisted.web.http import HTTPChannel
 from twisted.web.resource import Resource
-from twisted.web.server import Request, Site
 
 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.http.server import JsonResource
@@ -33,7 +29,6 @@ from synapse.replication.tcp.resource import (
     ServerReplicationStreamProtocol,
 )
 from synapse.server import HomeServer
-from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeTransport
@@ -154,7 +149,19 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         client_protocol = client_factory.buildProtocol(None)
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site)
+        channel = self.site.buildProtocol(None)
+
+        # hook into the channel's request factory so that we can keep a record
+        # of the requests
+        requests: List[SynapseRequest] = []
+        real_request_factory = channel.requestFactory
+
+        def request_factory(*args, **kwargs):
+            request = real_request_factory(*args, **kwargs)
+            requests.append(request)
+            return request
+
+        channel.requestFactory = request_factory
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -176,7 +183,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         server_to_client_transport.loseConnection()
         client_to_server_transport.loseConnection()
 
-        return channel.request
+        # there should have been exactly one request
+        self.assertEqual(len(requests), 1)
+
+        return requests[0]
 
     def assert_request_is_get_repl_stream_updates(
         self, request: SynapseRequest, stream_name: str
@@ -349,6 +359,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             config=worker_hs.config.server.listeners[0],
             resource=resource,
             server_version_string="1",
+            max_request_body_size=4096,
+            reactor=self.reactor,
         )
 
         if worker_hs.config.redis.redis_enabled:
@@ -386,7 +398,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         client_protocol = client_factory.buildProtocol(None)
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs])
+        channel = self._hs_to_site[hs].buildProtocol(None)
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -444,112 +456,6 @@ class TestReplicationDataHandler(ReplicationDataHandler):
             self.received_rdata_rows.append((stream_name, token, r))
 
 
-class _PushHTTPChannel(HTTPChannel):
-    """A HTTPChannel that wraps pull producers to push producers.
-
-    This is a hack to get around the fact that HTTPChannel transparently wraps a
-    pull producer (which is what Synapse uses to reply to requests) with
-    `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
-    uses the standard reactor rather than letting us use our test reactor, which
-    makes it very hard to test.
-    """
-
-    def __init__(
-        self, reactor: IReactorTime, request_factory: Type[Request], site: Site
-    ):
-        super().__init__()
-        self.reactor = reactor
-        self.requestFactory = request_factory
-        self.site = site
-
-        self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]
-
-    def registerProducer(self, producer, streaming):
-        # Convert pull producers to push producer.
-        if not streaming:
-            self._pull_to_push_producer = _PullToPushProducer(
-                self.reactor, producer, self
-            )
-            producer = self._pull_to_push_producer
-
-        super().registerProducer(producer, True)
-
-    def unregisterProducer(self):
-        if self._pull_to_push_producer:
-            # We need to manually stop the _PullToPushProducer.
-            self._pull_to_push_producer.stop()
-
-    def checkPersistence(self, request, version):
-        """Check whether the connection can be re-used"""
-        # We hijack this to always say no for ease of wiring stuff up in
-        # `handle_http_replication_attempt`.
-        request.responseHeaders.setRawHeaders(b"connection", [b"close"])
-        return False
-
-    def requestDone(self, request):
-        # Store the request for inspection.
-        self.request = request
-        super().requestDone(request)
-
-
-class _PullToPushProducer:
-    """A push producer that wraps a pull producer."""
-
-    def __init__(
-        self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
-    ):
-        self._clock = Clock(reactor)
-        self._producer = producer
-        self._consumer = consumer
-
-        # While running we use a looping call with a zero delay to call
-        # resumeProducing on given producer.
-        self._looping_call = None  # type: Optional[LoopingCall]
-
-        # We start writing next reactor tick.
-        self._start_loop()
-
-    def _start_loop(self):
-        """Start the looping call to"""
-
-        if not self._looping_call:
-            # Start a looping call which runs every tick.
-            self._looping_call = self._clock.looping_call(self._run_once, 0)
-
-    def stop(self):
-        """Stops calling resumeProducing."""
-        if self._looping_call:
-            self._looping_call.stop()
-            self._looping_call = None
-
-    def pauseProducing(self):
-        """Implements IPushProducer"""
-        self.stop()
-
-    def resumeProducing(self):
-        """Implements IPushProducer"""
-        self._start_loop()
-
-    def stopProducing(self):
-        """Implements IPushProducer"""
-        self.stop()
-        self._producer.stopProducing()
-
-    def _run_once(self):
-        """Calls resumeProducing on producer once."""
-
-        try:
-            self._producer.resumeProducing()
-        except Exception:
-            logger.exception("Failed to call resumeProducing")
-            try:
-                self._consumer.unregisterProducer()
-            except Exception:
-                pass
-
-            self.stopProducing()
-
-
 class FakeRedisPubSubServer:
     """A fake Redis server for pub/sub."""
 

+ 0 - 6
tests/server.py

@@ -603,12 +603,6 @@ class FakeTransport:
         if self.disconnected:
             return
 
-        if not hasattr(self.other, "transport"):
-            # the other has no transport yet; reschedule
-            if self.autoflush:
-                self._reactor.callLater(0.0, self.flush)
-            return
-
         if maxbytes is not None:
             to_write = self.buffer[:maxbytes]
         else:

+ 2 - 0
tests/test_server.py

@@ -202,6 +202,8 @@ class OptionsResourceTests(unittest.TestCase):
             parse_listener_def({"type": "http", "port": 0}),
             self.resource,
             "1.0",
+            max_request_body_size=1234,
+            reactor=self.reactor,
         )
 
         # render the request and return the channel

+ 2 - 0
tests/unittest.py

@@ -247,6 +247,8 @@ class HomeserverTestCase(TestCase):
             config=self.hs.config.server.listeners[0],
             resource=self.resource,
             server_version_string="1",
+            max_request_body_size=1234,
+            reactor=self.reactor,
         )
 
         from tests.rest.client.v1.utils import RestHelper