server.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. import json
  2. from io import BytesIO
  3. from six import text_type
  4. import attr
  5. from twisted.internet import threads
  6. from twisted.internet.defer import Deferred
  7. from twisted.python.failure import Failure
  8. from twisted.test.proto_helpers import MemoryReactorClock
  9. from synapse.http.site import SynapseRequest
  10. from synapse.util import Clock
  11. from tests.utils import setup_test_homeserver as _sth
  12. @attr.s
  13. class FakeChannel(object):
  14. """
  15. A fake Twisted Web Channel (the part that interfaces with the
  16. wire).
  17. """
  18. result = attr.ib(default=attr.Factory(dict))
  19. _producer = None
  20. @property
  21. def json_body(self):
  22. if not self.result:
  23. raise Exception("No result yet.")
  24. return json.loads(self.result["body"].decode('utf8'))
  25. @property
  26. def code(self):
  27. if not self.result:
  28. raise Exception("No result yet.")
  29. return int(self.result["code"])
  30. def writeHeaders(self, version, code, reason, headers):
  31. self.result["version"] = version
  32. self.result["code"] = code
  33. self.result["reason"] = reason
  34. self.result["headers"] = headers
  35. def write(self, content):
  36. if "body" not in self.result:
  37. self.result["body"] = b""
  38. self.result["body"] += content
  39. def registerProducer(self, producer, streaming):
  40. self._producer = producer
  41. def unregisterProducer(self):
  42. if self._producer is None:
  43. return
  44. self._producer = None
  45. def requestDone(self, _self):
  46. self.result["done"] = True
  47. def getPeer(self):
  48. return None
  49. def getHost(self):
  50. return None
  51. @property
  52. def transport(self):
  53. return self
  54. class FakeSite:
  55. """
  56. A fake Twisted Web Site, with mocks of the extra things that
  57. Synapse adds.
  58. """
  59. server_version_string = b"1"
  60. site_tag = "test"
  61. @property
  62. def access_logger(self):
  63. class FakeLogger:
  64. def info(self, *args, **kwargs):
  65. pass
  66. return FakeLogger()
  67. def make_request(method, path, content=b""):
  68. """
  69. Make a web request using the given method and path, feed it the
  70. content, and return the Request and the Channel underneath.
  71. """
  72. if not isinstance(method, bytes):
  73. method = method.encode('ascii')
  74. if not isinstance(path, bytes):
  75. path = path.encode('ascii')
  76. # Decorate it to be the full path
  77. if not path.startswith(b"/_matrix"):
  78. path = b"/_matrix/client/r0/" + path
  79. path = path.replace(b"//", b"/")
  80. if isinstance(content, text_type):
  81. content = content.encode('utf8')
  82. site = FakeSite()
  83. channel = FakeChannel()
  84. req = SynapseRequest(site, channel)
  85. req.process = lambda: b""
  86. req.content = BytesIO(content)
  87. req.requestReceived(method, path, b"1.1")
  88. return req, channel
  89. def wait_until_result(clock, request, timeout=100):
  90. """
  91. Wait until the request is finished.
  92. """
  93. clock.run()
  94. x = 0
  95. while not request.finished:
  96. # If there's a producer, tell it to resume producing so we get content
  97. if request._channel._producer:
  98. request._channel._producer.resumeProducing()
  99. x += 1
  100. if x > timeout:
  101. raise Exception("Timed out waiting for request to finish.")
  102. clock.advance(0.1)
  103. def render(request, resource, clock):
  104. request.render(resource)
  105. wait_until_result(clock, request)
  106. class ThreadedMemoryReactorClock(MemoryReactorClock):
  107. """
  108. A MemoryReactorClock that supports callFromThread.
  109. """
  110. def callFromThread(self, callback, *args, **kwargs):
  111. """
  112. Make the callback fire in the next reactor iteration.
  113. """
  114. d = Deferred()
  115. d.addCallback(lambda x: callback(*args, **kwargs))
  116. self.callLater(0, d.callback, True)
  117. return d
  118. def setup_test_homeserver(cleanup_func, *args, **kwargs):
  119. """
  120. Set up a synchronous test server, driven by the reactor used by
  121. the homeserver.
  122. """
  123. d = _sth(cleanup_func, *args, **kwargs).result
  124. if isinstance(d, Failure):
  125. d.raiseException()
  126. # Make the thread pool synchronous.
  127. clock = d.get_clock()
  128. pool = d.get_db_pool()
  129. def runWithConnection(func, *args, **kwargs):
  130. return threads.deferToThreadPool(
  131. pool._reactor,
  132. pool.threadpool,
  133. pool._runWithConnection,
  134. func,
  135. *args,
  136. **kwargs
  137. )
  138. def runInteraction(interaction, *args, **kwargs):
  139. return threads.deferToThreadPool(
  140. pool._reactor,
  141. pool.threadpool,
  142. pool._runInteraction,
  143. interaction,
  144. *args,
  145. **kwargs
  146. )
  147. pool.runWithConnection = runWithConnection
  148. pool.runInteraction = runInteraction
  149. class ThreadPool:
  150. """
  151. Threadless thread pool.
  152. """
  153. def start(self):
  154. pass
  155. def stop(self):
  156. pass
  157. def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
  158. def _(res):
  159. if isinstance(res, Failure):
  160. onResult(False, res)
  161. else:
  162. onResult(True, res)
  163. d = Deferred()
  164. d.addCallback(lambda x: function(*args, **kwargs))
  165. d.addBoth(_)
  166. clock._reactor.callLater(0, d.callback, True)
  167. return d
  168. clock.threadpool = ThreadPool()
  169. pool.threadpool = ThreadPool()
  170. return d
  171. def get_clock():
  172. clock = ThreadedMemoryReactorClock()
  173. hs_clock = Clock(clock)
  174. return (clock, hs_clock)