Browse Source

Kill off redundant SynapseRequestFactory (#6619)

We already get the Site via the Channel, so there's no need for a dedicated
RequestFactory: we can just use the right constructor.
Richard van der Hoff 4 years ago
parent
commit
b6b57ecb4e
3 changed files with 8 additions and 17 deletions
  1. 1 0
      changelog.d/6619.misc
  2. 3 15
      synapse/http/site.py
  3. 4 2
      tests/server.py

+ 1 - 0
changelog.d/6619.misc

@@ -0,0 +1 @@
+Simplify http handling by removing redundant SynapseRequestFactory.

+ 3 - 15
synapse/http/site.py

@@ -47,9 +47,9 @@ class SynapseRequest(Request):
         logcontext(LoggingContext) : the log context for this request
     """
 
-    def __init__(self, site, channel, *args, **kw):
+    def __init__(self, channel, *args, **kw):
         Request.__init__(self, channel, *args, **kw)
-        self.site = site
+        self.site = channel.site
         self._channel = channel  # this is used by the tests
         self.authenticated_entity = None
         self.start_time = 0
@@ -331,18 +331,6 @@ class XForwardedForRequest(SynapseRequest):
         )
 
 
-class SynapseRequestFactory(object):
-    def __init__(self, site, x_forwarded_for):
-        self.site = site
-        self.x_forwarded_for = x_forwarded_for
-
-    def __call__(self, *args, **kwargs):
-        if self.x_forwarded_for:
-            return XForwardedForRequest(self.site, *args, **kwargs)
-        else:
-            return SynapseRequest(self.site, *args, **kwargs)
-
-
 class SynapseSite(Site):
     """
     Subclass of a twisted http Site that does access logging with python's
@@ -364,7 +352,7 @@ class SynapseSite(Site):
         self.site_tag = site_tag
 
         proxied = config.get("x_forwarded", False)
-        self.requestFactory = SynapseRequestFactory(self, proxied)
+        self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
         self.access_logger = logging.getLogger(logger_name)
         self.server_version_string = server_version_string.encode("ascii")
 

+ 4 - 2
tests/server.py

@@ -20,6 +20,7 @@ from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
 from twisted.web.http import unquote
 from twisted.web.http_headers import Headers
+from twisted.web.server import Site
 
 from synapse.http.site import SynapseRequest
 from synapse.util import Clock
@@ -42,6 +43,7 @@ class FakeChannel(object):
     wire).
     """
 
+    site = attr.ib(type=Site)
     _reactor = attr.ib()
     result = attr.ib(default=attr.Factory(dict))
     _producer = None
@@ -176,9 +178,9 @@ def make_request(
         content = content.encode("utf8")
 
     site = FakeSite()
-    channel = FakeChannel(reactor)
+    channel = FakeChannel(site, reactor)
 
-    req = request(site, channel)
+    req = request(channel)
     req.process = lambda: b""
     req.content = BytesIO(content)
     req.postpath = list(map(unquote, path[1:].split(b"/")))