Browse Source

Merge different Resource implementation classes (#7732)

Erik Johnston 3 years ago
parent
commit
5cdca53aa0

+ 1 - 0
changelog.d/7732.bugfix

@@ -0,0 +1 @@
+Fix "Tried to close a non-active scope!" error messages when opentracing is enabled.

+ 1 - 5
synapse/federation/transport/server.py

@@ -361,11 +361,7 @@ class BaseFederationServlet(object):
                 continue
 
             server.register_paths(
-                method,
-                (pattern,),
-                self._wrap(code),
-                self.__class__.__name__,
-                trace=False,
+                method, (pattern,), self._wrap(code), self.__class__.__name__,
             )
 
 

+ 5 - 14
synapse/http/additional_resource.py

@@ -13,13 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
+from synapse.http.server import DirectServeJsonResource
 
-from synapse.http.server import wrap_json_request_handler
 
-
-class AdditionalResource(Resource):
+class AdditionalResource(DirectServeJsonResource):
     """Resource wrapper for additional_resources
 
     If the user has configured additional_resources, we need to wrap the
@@ -41,16 +38,10 @@ class AdditionalResource(Resource):
             handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
                 function to be called to handle the request.
         """
-        Resource.__init__(self)
+        super().__init__()
         self._handler = handler
 
-        # required by the request_handler wrapper
-        self.clock = hs.get_clock()
-
-    def render(self, request):
-        self._async_render(request)
-        return NOT_DONE_YET
-
-    @wrap_json_request_handler
     def _async_render(self, request):
+        # Cheekily pass the result straight through, so we don't need to worry
+        # if its an awaitable or not.
         return self._handler(request)

+ 190 - 175
synapse/http/server.py

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import collections
 import html
 import logging
@@ -21,7 +22,7 @@ import types
 import urllib
 from http import HTTPStatus
 from io import BytesIO
-from typing import Awaitable, Callable, TypeVar, Union
+from typing import Any, Callable, Dict, Tuple, Union
 
 import jinja2
 from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@@ -62,99 +63,43 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
 """
 
 
-def wrap_json_request_handler(h):
-    """Wraps a request handler method with exception handling.
-
-    Also does the wrapping with request.processing as per wrap_async_request_handler.
-
-    The handler method must have a signature of "handle_foo(self, request)",
-    where "request" must be a SynapseRequest.
-
-    The handler must return a deferred or a coroutine. If the deferred succeeds
-    we assume that a response has been sent. If the deferred fails with a SynapseError we use
-    it to send a JSON response with the appropriate HTTP reponse code. If the
-    deferred fails with any other type of error we send a 500 reponse.
+def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
+    """Sends a JSON error response to clients.
     """
 
-    async def wrapped_request_handler(self, request):
-        try:
-            await h(self, request)
-        except SynapseError as e:
-            code = e.code
-            logger.info("%s SynapseError: %s - %s", request, code, e.msg)
-
-            # Only respond with an error response if we haven't already started
-            # writing, otherwise lets just kill the connection
-            if request.startedWriting:
-                if request.transport:
-                    try:
-                        request.transport.abortConnection()
-                    except Exception:
-                        # abortConnection throws if the connection is already closed
-                        pass
-            else:
-                respond_with_json(
-                    request,
-                    code,
-                    e.error_dict(),
-                    send_cors=True,
-                    pretty_print=_request_user_agent_is_curl(request),
-                )
-
-        except Exception:
-            # failure.Failure() fishes the original Failure out
-            # of our stack, and thus gives us a sensible stack
-            # trace.
-            f = failure.Failure()
-            logger.error(
-                "Failed handle request via %r: %r",
-                request.request_metrics.name,
-                request,
-                exc_info=(f.type, f.value, f.getTracebackObject()),
-            )
-            # Only respond with an error response if we haven't already started
-            # writing, otherwise lets just kill the connection
-            if request.startedWriting:
-                if request.transport:
-                    try:
-                        request.transport.abortConnection()
-                    except Exception:
-                        # abortConnection throws if the connection is already closed
-                        pass
-            else:
-                respond_with_json(
-                    request,
-                    500,
-                    {"error": "Internal server error", "errcode": Codes.UNKNOWN},
-                    send_cors=True,
-                    pretty_print=_request_user_agent_is_curl(request),
-                )
-
-    return wrap_async_request_handler(wrapped_request_handler)
-
-
-TV = TypeVar("TV")
-
-
-def wrap_html_request_handler(
-    h: Callable[[TV, SynapseRequest], Awaitable]
-) -> Callable[[TV, SynapseRequest], Awaitable[None]]:
-    """Wraps a request handler method with exception handling.
+    if f.check(SynapseError):
+        error_code = f.value.code
+        error_dict = f.value.error_dict()
 
-    Also does the wrapping with request.processing as per wrap_async_request_handler.
-
-    The handler method must have a signature of "handle_foo(self, request)",
-    where "request" must be a SynapseRequest.
-    """
+        logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+    else:
+        error_code = 500
+        error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
 
-    async def wrapped_request_handler(self, request):
-        try:
-            await h(self, request)
-        except Exception:
-            f = failure.Failure()
-            return_html_error(f, request, HTML_ERROR_TEMPLATE)
+        logger.error(
+            "Failed handle request via %r: %r",
+            request.request_metrics.name,
+            request,
+            exc_info=(f.type, f.value, f.getTracebackObject()),
+        )
 
-    return wrap_async_request_handler(wrapped_request_handler)
+    # Only respond with an error response if we haven't already started writing,
+    # otherwise lets just kill the connection
+    if request.startedWriting:
+        if request.transport:
+            try:
+                request.transport.abortConnection()
+            except Exception:
+                # abortConnection throws if the connection is already closed
+                pass
+    else:
+        respond_with_json(
+            request,
+            error_code,
+            error_dict,
+            send_cors=True,
+            pretty_print=_request_user_agent_is_curl(request),
+        )
 
 
 def return_html_error(
@@ -249,7 +194,113 @@ class HttpServer(object):
         pass
 
 
-class JsonResource(HttpServer, resource.Resource):
+class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
+    """Base class for resources that have async handlers.
+
+    Sub classes can either implement `_async_render_<METHOD>` to handle
+    requests by method, or override `_async_render` to handle all requests.
+
+    Args:
+        extract_context: Whether to attempt to extract the opentracing
+            context from the request the servlet is handling.
+    """
+
+    def __init__(self, extract_context=False):
+        super().__init__()
+
+        self._extract_context = extract_context
+
+    def render(self, request):
+        """ This gets called by twisted every time someone sends us a request.
+        """
+        defer.ensureDeferred(self._async_render_wrapper(request))
+        return NOT_DONE_YET
+
+    @wrap_async_request_handler
+    async def _async_render_wrapper(self, request):
+        """This is a wrapper that delegates to `_async_render` and handles
+        exceptions, return values, metrics, etc.
+        """
+        try:
+            request.request_metrics.name = self.__class__.__name__
+
+            with trace_servlet(request, self._extract_context):
+                callback_return = await self._async_render(request)
+
+                if callback_return is not None:
+                    code, response = callback_return
+                    self._send_response(request, code, response)
+        except Exception:
+            # failure.Failure() fishes the original Failure out
+            # of our stack, and thus gives us a sensible stack
+            # trace.
+            f = failure.Failure()
+            self._send_error_response(f, request)
+
+    async def _async_render(self, request):
+        """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
+        no appropriate method exists. Can be overriden in sub classes for
+        different routing.
+        """
+
+        method_handler = getattr(
+            self, "_async_render_%s" % (request.method.decode("ascii"),), None
+        )
+        if method_handler:
+            raw_callback_return = method_handler(request)
+
+            # Is it synchronous? We'll allow this for now.
+            if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+                callback_return = await raw_callback_return
+            else:
+                callback_return = raw_callback_return
+
+            return callback_return
+
+        _unrecognised_request_handler(request)
+
+    @abc.abstractmethod
+    def _send_response(
+        self, request: SynapseRequest, code: int, response_object: Any,
+    ) -> None:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def _send_error_response(
+        self, f: failure.Failure, request: SynapseRequest,
+    ) -> None:
+        raise NotImplementedError()
+
+
+class DirectServeJsonResource(_AsyncResource):
+    """A resource that will call `self._async_on_<METHOD>` on new requests,
+    formatting responses and errors as JSON.
+    """
+
+    def _send_response(
+        self, request, code, response_object,
+    ):
+        """Implements _AsyncResource._send_response
+        """
+        # TODO: Only enable CORS for the requests that need it.
+        respond_with_json(
+            request,
+            code,
+            response_object,
+            send_cors=True,
+            pretty_print=_request_user_agent_is_curl(request),
+            canonical_json=self.canonical_json,
+        )
+
+    def _send_error_response(
+        self, f: failure.Failure, request: SynapseRequest,
+    ) -> None:
+        """Implements _AsyncResource._send_error_response
+        """
+        return_json_error(f, request)
+
+
+class JsonResource(DirectServeJsonResource):
     """ This implements the HttpServer interface and provides JSON support for
     Resources.
 
@@ -269,17 +320,15 @@ class JsonResource(HttpServer, resource.Resource):
         "_PathEntry", ["pattern", "callback", "servlet_classname"]
     )
 
-    def __init__(self, hs, canonical_json=True):
-        resource.Resource.__init__(self)
+    def __init__(self, hs, canonical_json=True, extract_context=False):
+        super().__init__(extract_context)
 
         self.canonical_json = canonical_json
         self.clock = hs.get_clock()
         self.path_regexs = {}
         self.hs = hs
 
-    def register_paths(
-        self, method, path_patterns, callback, servlet_classname, trace=True
-    ):
+    def register_paths(self, method, path_patterns, callback, servlet_classname):
         """
         Registers a request handler against a regular expression. Later request URLs are
         checked against these regular expressions in order to identify an appropriate
@@ -295,37 +344,42 @@ class JsonResource(HttpServer, resource.Resource):
 
             servlet_classname (str): The name of the handler to be used in prometheus
                 and opentracing logs.
-
-            trace (bool): Whether we should start a span to trace the servlet.
         """
         method = method.encode("utf-8")  # method is bytes on py3
 
-        if trace:
-            # We don't extract the context from the servlet because we can't
-            # trust the sender
-            callback = trace_servlet(servlet_classname)(callback)
-
         for path_pattern in path_patterns:
             logger.debug("Registering for %s %s", method, path_pattern.pattern)
             self.path_regexs.setdefault(method, []).append(
                 self._PathEntry(path_pattern, callback, servlet_classname)
             )
 
-    def render(self, request):
-        """ This gets called by twisted every time someone sends us a request.
+    def _get_handler_for_request(
+        self, request: SynapseRequest
+    ) -> Tuple[Callable, str, Dict[str, str]]:
+        """Finds a callback method to handle the given request.
+
+        Returns:
+            A tuple of the callback to use, the name of the servlet, and the
+            key word arguments to pass to the callback
         """
-        defer.ensureDeferred(self._async_render(request))
-        return NOT_DONE_YET
+        request_path = request.path.decode("ascii")
+
+        # Loop through all the registered callbacks to check if the method
+        # and path regex match
+        for path_entry in self.path_regexs.get(request.method, []):
+            m = path_entry.pattern.match(request_path)
+            if m:
+                # We found a match!
+                return path_entry.callback, path_entry.servlet_classname, m.groupdict()
+
+        # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
+        return _unrecognised_request_handler, "unrecognised_request_handler", {}
 
-    @wrap_json_request_handler
     async def _async_render(self, request):
-        """ This gets called from render() every time someone sends us a request.
-            This checks if anyone has registered a callback for that method and
-            path.
-        """
         callback, servlet_classname, group_dict = self._get_handler_for_request(request)
 
-        # Make sure we have a name for this handler in prometheus.
+        # Make sure we have an appopriate name for this handler in prometheus
+        # (rather than the default of JsonResource).
         request.request_metrics.name = servlet_classname
 
         # Now trigger the callback. If it returns a response, we send it
@@ -338,81 +392,42 @@ class JsonResource(HttpServer, resource.Resource):
             }
         )
 
-        callback_return = callback(request, **kwargs)
+        raw_callback_return = callback(request, **kwargs)
 
         # Is it synchronous? We'll allow this for now.
-        if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
-            callback_return = await callback_return
+        if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+            callback_return = await raw_callback_return
+        else:
+            callback_return = raw_callback_return
 
-        if callback_return is not None:
-            code, response = callback_return
-            self._send_response(request, code, response)
+        return callback_return
 
-    def _get_handler_for_request(self, request):
-        """Finds a callback method to handle the given request
 
-        Args:
-            request (twisted.web.http.Request):
+class DirectServeHtmlResource(_AsyncResource):
+    """A resource that will call `self._async_on_<METHOD>` on new requests,
+    formatting responses and errors as HTML.
+    """
 
-        Returns:
-            Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
-                label to use for that method in prometheus metrics, and the
-                dict mapping keys to path components as specified in the
-                handler's path match regexp.
-
-                The callback will normally be a method registered via
-                register_paths, so will return (possibly via Deferred) either
-                None, or a tuple of (http code, response body).
-        """
-        request_path = request.path.decode("ascii")
-
-        # Loop through all the registered callbacks to check if the method
-        # and path regex match
-        for path_entry in self.path_regexs.get(request.method, []):
-            m = path_entry.pattern.match(request_path)
-            if m:
-                # We found a match!
-                return path_entry.callback, path_entry.servlet_classname, m.groupdict()
-
-        # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
-        return _unrecognised_request_handler, "unrecognised_request_handler", {}
+    # The error template to use for this resource
+    ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
 
     def _send_response(
-        self, request, code, response_json_object, response_code_message=None
+        self, request: SynapseRequest, code: int, response_object: Any,
     ):
-        # TODO: Only enable CORS for the requests that need it.
-        respond_with_json(
-            request,
-            code,
-            response_json_object,
-            send_cors=True,
-            response_code_message=response_code_message,
-            pretty_print=_request_user_agent_is_curl(request),
-            canonical_json=self.canonical_json,
-        )
-
-
-class DirectServeResource(resource.Resource):
-    def render(self, request):
+        """Implements _AsyncResource._send_response
         """
-        Render the request, using an asynchronous render handler if it exists.
-        """
-        async_render_callback_name = "_async_render_" + request.method.decode("ascii")
-
-        # Try and get the async renderer
-        callback = getattr(self, async_render_callback_name, None)
+        # We expect to get bytes for us to write
+        assert isinstance(response_object, bytes)
+        html_bytes = response_object
 
-        # No async renderer for this request method.
-        if not callback:
-            return super().render(request)
+        respond_with_html_bytes(request, 200, html_bytes)
 
-        resp = trace_servlet(self.__class__.__name__)(callback)(request)
-
-        # If it's a coroutine, turn it into a Deferred
-        if isinstance(resp, types.CoroutineType):
-            defer.ensureDeferred(resp)
-
-        return NOT_DONE_YET
+    def _send_error_response(
+        self, f: failure.Failure, request: SynapseRequest,
+    ) -> None:
+        """Implements _AsyncResource._send_error_response
+        """
+        return_html_error(f, request, self.ERROR_TEMPLATE)
 
 
 class StaticResource(File):

+ 31 - 37
synapse/logging/opentracing.py

@@ -169,7 +169,6 @@ import contextlib
 import inspect
 import logging
 import re
-import types
 from functools import wraps
 from typing import TYPE_CHECKING, Dict, Optional, Type
 
@@ -182,6 +181,7 @@ from synapse.config import ConfigError
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.http.site import SynapseRequest
 
 # Helper class
 
@@ -793,48 +793,42 @@ def tag_args(func):
     return _tag_args_inner
 
 
-def trace_servlet(servlet_name, extract_context=False):
-    """Decorator which traces a serlet. It starts a span with some servlet specific
-    tags such as the servlet_name and request information
+@contextlib.contextmanager
+def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
+    """Returns a context manager which traces a request. It starts a span
+    with some servlet specific tags such as the request metrics name and
+    request information.
 
     Args:
-        servlet_name (str): The name to be used for the span's operation_name
-        extract_context (bool): Whether to attempt to extract the opentracing
+        request
+        extract_context: Whether to attempt to extract the opentracing
             context from the request the servlet is handling.
-
     """
 
-    def _trace_servlet_inner_1(func):
-        if not opentracing:
-            return func
-
-        @wraps(func)
-        async def _trace_servlet_inner(request, *args, **kwargs):
-            request_tags = {
-                "request_id": request.get_request_id(),
-                tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
-                tags.HTTP_METHOD: request.get_method(),
-                tags.HTTP_URL: request.get_redacted_uri(),
-                tags.PEER_HOST_IPV6: request.getClientIP(),
-            }
-
-            if extract_context:
-                scope = start_active_span_from_request(
-                    request, servlet_name, tags=request_tags
-                )
-            else:
-                scope = start_active_span(servlet_name, tags=request_tags)
-
-            with scope:
-                result = func(request, *args, **kwargs)
+    if opentracing is None:
+        yield
+        return
 
-                if not isinstance(result, (types.CoroutineType, defer.Deferred)):
-                    # Some servlets aren't async and just return results
-                    # directly, so we handle that here.
-                    return result
+    request_tags = {
+        "request_id": request.get_request_id(),
+        tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+        tags.HTTP_METHOD: request.get_method(),
+        tags.HTTP_URL: request.get_redacted_uri(),
+        tags.PEER_HOST_IPV6: request.getClientIP(),
+    }
 
-                return await result
+    request_name = request.request_metrics.name
+    if extract_context:
+        scope = start_active_span_from_request(request, request_name, tags=request_tags)
+    else:
+        scope = start_active_span(request_name, tags=request_tags)
 
-        return _trace_servlet_inner
+    with scope:
+        try:
+            yield
+        finally:
+            # We set the operation name again in case its changed (which happens
+            # with JsonResource).
+            scope.span.set_operation_name(request.request_metrics.name)
 
-    return _trace_servlet_inner_1
+            scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)

+ 2 - 1
synapse/replication/http/__init__.py

@@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication"
 
 class ReplicationRestResource(JsonResource):
     def __init__(self, hs):
-        JsonResource.__init__(self, hs, canonical_json=False)
+        # We enable extracting jaeger contexts here as these are internal APIs.
+        super().__init__(hs, canonical_json=False, extract_context=True)
         self.register_servlets(hs)
 
     def register_servlets(self, hs):

+ 2 - 9
synapse/replication/http/_base.py

@@ -28,11 +28,7 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
-from synapse.logging.opentracing import (
-    inject_active_span_byte_dict,
-    trace,
-    trace_servlet,
-)
+from synapse.logging.opentracing import inject_active_span_byte_dict, trace
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import random_string
 
@@ -240,11 +236,8 @@ class ReplicationEndpoint(object):
         args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
         pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
 
-        handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
-        # We don't let register paths trace this servlet using the default tracing
-        # options because we wish to extract the context explicitly.
         http_server.register_paths(
-            method, [pattern], handler, self.__class__.__name__, trace=False
+            method, [pattern], handler, self.__class__.__name__,
         )
 
     def _cached_handler(self, request, txn_id, **kwargs):

+ 2 - 8
synapse/rest/consent/consent_resource.py

@@ -26,11 +26,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import NotFoundError, StoreError, SynapseError
 from synapse.config import ConfigError
-from synapse.http.server import (
-    DirectServeResource,
-    respond_with_html,
-    wrap_html_request_handler,
-)
+from synapse.http.server import DirectServeHtmlResource, respond_with_html
 from synapse.http.servlet import parse_string
 from synapse.types import UserID
 
@@ -48,7 +44,7 @@ else:
         return a == b
 
 
-class ConsentResource(DirectServeResource):
+class ConsentResource(DirectServeHtmlResource):
     """A twisted Resource to display a privacy policy and gather consent to it
 
     When accessed via GET, returns the privacy policy via a template.
@@ -119,7 +115,6 @@ class ConsentResource(DirectServeResource):
 
         self._hmac_secret = hs.config.form_secret.encode("utf-8")
 
-    @wrap_html_request_handler
     async def _async_render_GET(self, request):
         """
         Args:
@@ -160,7 +155,6 @@ class ConsentResource(DirectServeResource):
         except TemplateNotFound:
             raise NotFoundError("Unknown policy version")
 
-    @wrap_html_request_handler
     async def _async_render_POST(self, request):
         """
         Args:

+ 4 - 8
synapse/rest/key/v2/remote_key_resource.py

@@ -20,17 +20,13 @@ from signedjson.sign import sign_json
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.crypto.keyring import ServerKeyFetcher
-from synapse.http.server import (
-    DirectServeResource,
-    respond_with_json_bytes,
-    wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
 
 logger = logging.getLogger(__name__)
 
 
-class RemoteKey(DirectServeResource):
+class RemoteKey(DirectServeJsonResource):
     """HTTP resource for retreiving the TLS certificate and NACL signature
     verification keys for a collection of servers. Checks that the reported
     X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@@ -92,13 +88,14 @@ class RemoteKey(DirectServeResource):
     isLeaf = True
 
     def __init__(self, hs):
+        super().__init__()
+
         self.fetcher = ServerKeyFetcher(hs)
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
         self.federation_domain_whitelist = hs.config.federation_domain_whitelist
         self.config = hs.config
 
-    @wrap_json_request_handler
     async def _async_render_GET(self, request):
         if len(request.postpath) == 1:
             (server,) = request.postpath
@@ -115,7 +112,6 @@ class RemoteKey(DirectServeResource):
 
         await self.query_keys(request, query, query_remote_on_cache_miss=True)
 
-    @wrap_json_request_handler
     async def _async_render_POST(self, request):
         content = parse_json_object_from_request(request)
 

+ 3 - 11
synapse/rest/media/v1/config_resource.py

@@ -14,16 +14,10 @@
 # limitations under the License.
 #
 
-from twisted.web.server import NOT_DONE_YET
+from synapse.http.server import DirectServeJsonResource, respond_with_json
 
-from synapse.http.server import (
-    DirectServeResource,
-    respond_with_json,
-    wrap_json_request_handler,
-)
 
-
-class MediaConfigResource(DirectServeResource):
+class MediaConfigResource(DirectServeJsonResource):
     isLeaf = True
 
     def __init__(self, hs):
@@ -33,11 +27,9 @@ class MediaConfigResource(DirectServeResource):
         self.auth = hs.get_auth()
         self.limits_dict = {"m.upload.size": config.max_upload_size}
 
-    @wrap_json_request_handler
     async def _async_render_GET(self, request):
         await self.auth.get_user_by_req(request)
         respond_with_json(request, 200, self.limits_dict, send_cors=True)
 
-    def render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request):
         respond_with_json(request, 200, {}, send_cors=True)
-        return NOT_DONE_YET

+ 2 - 10
synapse/rest/media/v1/download_resource.py

@@ -15,18 +15,14 @@
 import logging
 
 import synapse.http.servlet
-from synapse.http.server import (
-    DirectServeResource,
-    set_cors_headers,
-    wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, set_cors_headers
 
 from ._base import parse_media_id, respond_404
 
 logger = logging.getLogger(__name__)
 
 
-class DownloadResource(DirectServeResource):
+class DownloadResource(DirectServeJsonResource):
     isLeaf = True
 
     def __init__(self, hs, media_repo):
@@ -34,10 +30,6 @@ class DownloadResource(DirectServeResource):
         self.media_repo = media_repo
         self.server_name = hs.hostname
 
-        # this is expected by @wrap_json_request_handler
-        self.clock = hs.get_clock()
-
-    @wrap_json_request_handler
     async def _async_render_GET(self, request):
         set_cors_headers(request)
         request.setHeader(

+ 4 - 6
synapse/rest/media/v1/preview_url_resource.py

@@ -34,10 +34,9 @@ from twisted.internet.error import DNSLookupError
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.client import SimpleHttpClient
 from synapse.http.server import (
-    DirectServeResource,
+    DirectServeJsonResource,
     respond_with_json,
     respond_with_json_bytes,
-    wrap_json_request_handler,
 )
 from synapse.http.servlet import parse_integer, parse_string
 from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -58,7 +57,7 @@ OG_TAG_NAME_MAXLEN = 50
 OG_TAG_VALUE_MAXLEN = 1000
 
 
-class PreviewUrlResource(DirectServeResource):
+class PreviewUrlResource(DirectServeJsonResource):
     isLeaf = True
 
     def __init__(self, hs, media_repo, media_storage):
@@ -108,11 +107,10 @@ class PreviewUrlResource(DirectServeResource):
                 self._start_expire_url_cache_data, 10 * 1000
             )
 
-    def render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request):
         request.setHeader(b"Allow", b"OPTIONS, GET")
-        return respond_with_json(request, 200, {}, send_cors=True)
+        respond_with_json(request, 200, {}, send_cors=True)
 
-    @wrap_json_request_handler
     async def _async_render_GET(self, request):
 
         # XXX: if get_user_by_req fails, what should we do in an async render?

+ 2 - 8
synapse/rest/media/v1/thumbnail_resource.py

@@ -16,11 +16,7 @@
 
 import logging
 
-from synapse.http.server import (
-    DirectServeResource,
-    set_cors_headers,
-    wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_integer, parse_string
 
 from ._base import (
@@ -34,7 +30,7 @@ from ._base import (
 logger = logging.getLogger(__name__)
 
 
-class ThumbnailResource(DirectServeResource):
+class ThumbnailResource(DirectServeJsonResource):
     isLeaf = True
 
     def __init__(self, hs, media_repo, media_storage):
@@ -45,9 +41,7 @@ class ThumbnailResource(DirectServeResource):
         self.media_storage = media_storage
         self.dynamic_thumbnails = hs.config.dynamic_thumbnails
         self.server_name = hs.hostname
-        self.clock = hs.get_clock()
 
-    @wrap_json_request_handler
     async def _async_render_GET(self, request):
         set_cors_headers(request)
         server_name, media_id, _ = parse_media_id(request)

+ 3 - 11
synapse/rest/media/v1/upload_resource.py

@@ -15,20 +15,14 @@
 
 import logging
 
-from twisted.web.server import NOT_DONE_YET
-
 from synapse.api.errors import Codes, SynapseError
-from synapse.http.server import (
-    DirectServeResource,
-    respond_with_json,
-    wrap_json_request_handler,
-)
+from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_string
 
 logger = logging.getLogger(__name__)
 
 
-class UploadResource(DirectServeResource):
+class UploadResource(DirectServeJsonResource):
     isLeaf = True
 
     def __init__(self, hs, media_repo):
@@ -43,11 +37,9 @@ class UploadResource(DirectServeResource):
         self.max_upload_size = hs.config.max_upload_size
         self.clock = hs.get_clock()
 
-    def render_OPTIONS(self, request):
+    async def _async_render_OPTIONS(self, request):
         respond_with_json(request, 200, {}, send_cors=True)
-        return NOT_DONE_YET
 
-    @wrap_json_request_handler
     async def _async_render_POST(self, request):
         requester = await self.auth.get_user_by_req(request)
         # TODO: The checks here are a bit late. The content will have

+ 3 - 4
synapse/rest/oidc/callback_resource.py

@@ -14,18 +14,17 @@
 # limitations under the License.
 import logging
 
-from synapse.http.server import DirectServeResource, wrap_html_request_handler
+from synapse.http.server import DirectServeHtmlResource
 
 logger = logging.getLogger(__name__)
 
 
-class OIDCCallbackResource(DirectServeResource):
+class OIDCCallbackResource(DirectServeHtmlResource):
     isLeaf = 1
 
     def __init__(self, hs):
         super().__init__()
         self._oidc_handler = hs.get_oidc_handler()
 
-    @wrap_html_request_handler
     async def _async_render_GET(self, request):
-        return await self._oidc_handler.handle_oidc_callback(request)
+        await self._oidc_handler.handle_oidc_callback(request)

+ 2 - 2
synapse/rest/saml2/response_resource.py

@@ -16,10 +16,10 @@
 from twisted.python import failure
 
 from synapse.api.errors import SynapseError
-from synapse.http.server import DirectServeResource, return_html_error
+from synapse.http.server import DirectServeHtmlResource, return_html_error
 
 
-class SAML2ResponseResource(DirectServeResource):
+class SAML2ResponseResource(DirectServeHtmlResource):
     """A Twisted web resource which handles the SAML response"""
 
     isLeaf = 1

+ 62 - 0
tests/http/test_additional_resource.py

@@ -0,0 +1,62 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.http.additional_resource import AdditionalResource
+from synapse.http.server import respond_with_json
+
+from tests.unittest import HomeserverTestCase
+
+
+class _AsyncTestCustomEndpoint:
+    def __init__(self, config, module_api):
+        pass
+
+    async def handle_request(self, request):
+        respond_with_json(request, 200, {"some_key": "some_value_async"})
+
+
+class _SyncTestCustomEndpoint:
+    def __init__(self, config, module_api):
+        pass
+
+    async def handle_request(self, request):
+        respond_with_json(request, 200, {"some_key": "some_value_sync"})
+
+
+class AdditionalResourceTests(HomeserverTestCase):
+    """Very basic tests that `AdditionalResource` works correctly with sync
+    and async handlers.
+    """
+
+    def test_async(self):
+        handler = _AsyncTestCustomEndpoint({}, None).handle_request
+        self.resource = AdditionalResource(self.hs, handler)
+
+        request, channel = self.make_request("GET", "/")
+        self.render(request)
+
+        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.json_body, {"some_key": "some_value_async"})
+
+    def test_sync(self):
+        handler = _SyncTestCustomEndpoint({}, None).handle_request
+        self.resource = AdditionalResource(self.hs, handler)
+
+        request, channel = self.make_request("GET", "/")
+        self.render(request)
+
+        self.assertEqual(request.code, 200)
+        self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})

+ 3 - 9
tests/test_server.py

@@ -24,12 +24,7 @@ from twisted.web.server import NOT_DONE_YET
 
 from synapse.api.errors import Codes, RedirectException, SynapseError
 from synapse.config.server import parse_listener_def
-from synapse.http.server import (
-    DirectServeResource,
-    JsonResource,
-    OptionsResource,
-    wrap_html_request_handler,
-)
+from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
 from synapse.http.site import SynapseSite, logger
 from synapse.logging.context import make_deferred_yieldable
 from synapse.util import Clock
@@ -256,12 +251,11 @@ class OptionsResourceTests(unittest.TestCase):
 
 
 class WrapHtmlRequestHandlerTests(unittest.TestCase):
-    class TestResource(DirectServeResource):
+    class TestResource(DirectServeHtmlResource):
         callback = None
 
-        @wrap_html_request_handler
         async def _async_render_GET(self, request):
-            return await self.callback(request)
+            await self.callback(request)
 
     def setUp(self):
         self.reactor = ThreadedMemoryReactorClock()