123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219 |
- import json
- from io import BytesIO
- from six import text_type
- import attr
- from twisted.internet import threads
- from twisted.internet.defer import Deferred
- from twisted.python.failure import Failure
- from twisted.test.proto_helpers import MemoryReactorClock
- from synapse.http.site import SynapseRequest
- from synapse.util import Clock
- from tests.utils import setup_test_homeserver as _sth
@attr.s
class FakeChannel(object):
    """
    A fake Twisted Web Channel (the part that interfaces with the
    wire). Nothing is sent anywhere; the response is recorded in
    ``result`` instead.
    """

    # Recorded response: writeHeaders() fills "version", "code", "reason"
    # and "headers"; write() accumulates "body"; requestDone() sets "done".
    result = attr.ib(default=attr.Factory(dict))

    @property
    def transport(self):
        # The channel doubles as its own transport.
        return self

    def getPeer(self):
        return None

    def getHost(self):
        return None

    @property
    def json_body(self):
        """
        The recorded body, decoded as UTF-8 and parsed as JSON.

        Raises:
            Exception: if nothing has been recorded yet.
        """
        if not self.result:
            raise Exception("No result yet.")
        raw_body = self.result["body"]
        return json.loads(raw_body.decode('utf8'))

    @property
    def code(self):
        """
        The recorded response status code, as an int.

        Raises:
            Exception: if nothing has been recorded yet.
        """
        if not self.result:
            raise Exception("No result yet.")
        status = self.result["code"]
        return int(status)

    def writeHeaders(self, version, code, reason, headers):
        self.result.update(
            version=version,
            code=code,
            reason=reason,
            headers=headers,
        )

    def write(self, content):
        # Start from an empty body on the first write, then append.
        self.result.setdefault("body", b"")
        self.result["body"] += content

    def requestDone(self, _self):
        self.result["done"] = True
class FakeSite:
    """
    A fake Twisted Web Site, stubbing out the extra attributes that
    Synapse expects its site to provide.
    """

    site_tag = "test"
    server_version_string = b"1"

    @property
    def access_logger(self):
        """A throwaway logger whose ``info`` discards every message."""

        class FakeLogger:
            def info(self, *args, **kwargs):
                return None

        return FakeLogger()
def make_request(method, path, content=b""):
    """
    Make a web request using the given method and path, feed it the
    content, and return the Request and the Channel underneath.

    Args:
        method (bytes|unicode): the HTTP method, e.g. "GET". Text is
            encoded as ASCII.
        path (bytes|unicode): the request target, optionally including a
            query string. Text is encoded as ASCII. A path that does not
            start with "/_matrix" is treated as shorthand and prefixed
            with "/_matrix/client/r0/".
        content (bytes|unicode): the request body. Text is encoded as
            UTF-8.

    Returns:
        tuple: (SynapseRequest, FakeChannel) -- the request after
        requestReceived() has been called, and the channel that records
        the response.
    """
    if not isinstance(method, bytes):
        method = method.encode('ascii')

    if not isinstance(path, bytes):
        path = path.encode('ascii')

    # Decorate it to be the full path
    if not path.startswith(b"/_matrix"):
        path = b"/_matrix/client/r0/" + path
        # Joining the prefix can produce "//" (e.g. for "/rooms/..."), so
        # collapse doubled slashes -- but only in the path component. The
        # query string may legitimately contain "//" (for instance a URL
        # passed as a parameter) and must be left untouched.
        path_part, sep, query = path.partition(b"?")
        path = path_part.replace(b"//", b"/") + sep + query

    if isinstance(content, text_type):
        content = content.encode('utf8')

    site = FakeSite()
    channel = FakeChannel()

    req = SynapseRequest(site, channel)
    # Stub out process() so that requestReceived() does not dispatch the
    # request itself; presumably callers render it explicitly (see
    # render() in this module).
    req.process = lambda: b""
    req.content = BytesIO(content)
    req.requestReceived(method, path, b"1.1")

    return req, channel
def wait_until_result(clock, channel, timeout=100):
    """
    Drive ``clock`` until ``channel`` has recorded a result.

    The clock is started with ``run()`` and then advanced in steps of 0.1
    (fake) seconds. Any non-empty ``channel.result`` ends the wait --
    including one where only the headers have been written so far.

    Args:
        clock: the fake reactor to drive; must provide ``run()`` and
            ``advance(seconds)``.
        channel: an object whose ``result`` becomes truthy once a
            response has been (at least partly) written.
        timeout (int): the maximum number of 0.1-second advances before
            giving up, so the default allows 10 seconds of fake time.

    Raises:
        Exception: if no result appears within ``timeout`` advances.
    """
    clock.run()
    ticks = 0
    while True:
        if channel.result:
            return
        ticks += 1
        if ticks > timeout:
            raise Exception("Timed out waiting for request to finish.")
        clock.advance(0.1)
def render(request, resource, clock):
    """
    Render ``request`` against ``resource``, then drive ``clock`` until
    the request's channel has recorded a result.
    """
    request.render(resource)
    channel = request._channel
    wait_until_result(clock, channel)
class ThreadedMemoryReactorClock(MemoryReactorClock):
    """
    A MemoryReactorClock that supports callFromThread.
    """

    def callFromThread(self, callback, *args, **kwargs):
        """
        Schedule ``callback(*args, **kwargs)`` for the next reactor
        iteration instead of running it on another thread. It is queued
        via ``callLater(0, ...)``, so with this fake clock it presumably
        runs when the clock is next advanced.

        Returns:
            Deferred: fires with the callback's return value once it
            has run.
        """
        def invoke(_ignored):
            return callback(*args, **kwargs)

        deferred = Deferred()
        deferred.addCallback(invoke)
        self.callLater(0, deferred.callback, True)
        return deferred
def setup_test_homeserver(cleanup_func, *args, **kwargs):
    """
    Set up a synchronous test server, driven by the reactor used by
    the homeserver.

    The homeserver is built with ``tests.utils.setup_test_homeserver``.
    Its database pool is then rewired so that work normally handed to a
    thread pool is scheduled on the clock's reactor via
    ``callLater(0, ...)``.

    Args:
        cleanup_func: forwarded, together with ``*args``/``**kwargs``, to
            ``tests.utils.setup_test_homeserver``.

    Returns:
        The homeserver produced by the underlying helper.

    Raises:
        Exception: the original error, if the underlying helper's Deferred
            resolved to a Failure.
    """
    # NOTE(review): reading ``.result`` assumes the Deferred has already
    # fired synchronously -- confirm this holds for all callers.
    setup_result = _sth(cleanup_func, *args, **kwargs).result
    if isinstance(setup_result, Failure):
        setup_result.raiseException()
    hs = setup_result

    clock = hs.get_clock()
    pool = hs.get_db_pool()

    class ThreadPool:
        """
        Threadless thread pool: "threaded" calls are queued on the clock's
        reactor rather than run on a worker thread.
        """

        def start(self):
            pass

        def stop(self):
            pass

        def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
            def report(outcome):
                # onResult(success_flag, result_or_failure)
                onResult(not isinstance(outcome, Failure), outcome)

            deferred = Deferred()
            deferred.addCallback(lambda _ignored: function(*args, **kwargs))
            deferred.addBoth(report)
            clock._reactor.callLater(0, deferred.callback, True)
            return deferred

    # Both wrappers read pool._reactor / pool.threadpool at call time, so
    # they pick up the fake ThreadPool installed below.
    def runWithConnection(func, *call_args, **call_kwargs):
        return threads.deferToThreadPool(
            pool._reactor,
            pool.threadpool,
            pool._runWithConnection,
            func,
            *call_args,
            **call_kwargs
        )

    def runInteraction(interaction, *call_args, **call_kwargs):
        return threads.deferToThreadPool(
            pool._reactor,
            pool.threadpool,
            pool._runInteraction,
            interaction,
            *call_args,
            **call_kwargs
        )

    pool.runWithConnection = runWithConnection
    pool.runInteraction = runInteraction

    clock.threadpool = ThreadPool()
    pool.threadpool = ThreadPool()

    return hs
def get_clock():
    """
    Create a fake reactor together with a Synapse Clock wrapping it.

    Returns:
        tuple: (ThreadedMemoryReactorClock, synapse.util.Clock)
    """
    reactor = ThreadedMemoryReactorClock()
    return (reactor, Clock(reactor))
|