server.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. import json
  2. from io import BytesIO
  3. from six import text_type
  4. import attr
  5. from twisted.internet import threads
  6. from twisted.internet.defer import Deferred
  7. from twisted.python.failure import Failure
  8. from twisted.test.proto_helpers import MemoryReactorClock
  9. from synapse.http.site import SynapseRequest
  10. from synapse.util import Clock
  11. from tests.utils import setup_test_homeserver as _sth
  12. @attr.s
  13. class FakeChannel(object):
  14. """
  15. A fake Twisted Web Channel (the part that interfaces with the
  16. wire).
  17. """
  18. result = attr.ib(default=attr.Factory(dict))
  19. @property
  20. def json_body(self):
  21. if not self.result:
  22. raise Exception("No result yet.")
  23. return json.loads(self.result["body"].decode('utf8'))
  24. @property
  25. def code(self):
  26. if not self.result:
  27. raise Exception("No result yet.")
  28. return int(self.result["code"])
  29. def writeHeaders(self, version, code, reason, headers):
  30. self.result["version"] = version
  31. self.result["code"] = code
  32. self.result["reason"] = reason
  33. self.result["headers"] = headers
  34. def write(self, content):
  35. if "body" not in self.result:
  36. self.result["body"] = b""
  37. self.result["body"] += content
  38. def requestDone(self, _self):
  39. self.result["done"] = True
  40. def getPeer(self):
  41. return None
  42. def getHost(self):
  43. return None
  44. @property
  45. def transport(self):
  46. return self
  47. class FakeSite:
  48. """
  49. A fake Twisted Web Site, with mocks of the extra things that
  50. Synapse adds.
  51. """
  52. server_version_string = b"1"
  53. site_tag = "test"
  54. @property
  55. def access_logger(self):
  56. class FakeLogger:
  57. def info(self, *args, **kwargs):
  58. pass
  59. return FakeLogger()
  60. def make_request(method, path, content=b""):
  61. """
  62. Make a web request using the given method and path, feed it the
  63. content, and return the Request and the Channel underneath.
  64. """
  65. if not isinstance(method, bytes):
  66. method = method.encode('ascii')
  67. if not isinstance(path, bytes):
  68. path = path.encode('ascii')
  69. # Decorate it to be the full path
  70. if not path.startswith(b"/_matrix"):
  71. path = b"/_matrix/client/r0/" + path
  72. path = path.replace(b"//", b"/")
  73. if isinstance(content, text_type):
  74. content = content.encode('utf8')
  75. site = FakeSite()
  76. channel = FakeChannel()
  77. req = SynapseRequest(site, channel)
  78. req.process = lambda: b""
  79. req.content = BytesIO(content)
  80. req.requestReceived(method, path, b"1.1")
  81. return req, channel
  82. def wait_until_result(clock, channel, timeout=100):
  83. """
  84. Wait until the channel has a result.
  85. """
  86. clock.run()
  87. x = 0
  88. while not channel.result:
  89. x += 1
  90. if x > timeout:
  91. raise Exception("Timed out waiting for request to finish.")
  92. clock.advance(0.1)
  93. def render(request, resource, clock):
  94. request.render(resource)
  95. wait_until_result(clock, request._channel)
  96. class ThreadedMemoryReactorClock(MemoryReactorClock):
  97. """
  98. A MemoryReactorClock that supports callFromThread.
  99. """
  100. def callFromThread(self, callback, *args, **kwargs):
  101. """
  102. Make the callback fire in the next reactor iteration.
  103. """
  104. d = Deferred()
  105. d.addCallback(lambda x: callback(*args, **kwargs))
  106. self.callLater(0, d.callback, True)
  107. return d
  108. def setup_test_homeserver(cleanup_func, *args, **kwargs):
  109. """
  110. Set up a synchronous test server, driven by the reactor used by
  111. the homeserver.
  112. """
  113. d = _sth(cleanup_func, *args, **kwargs).result
  114. if isinstance(d, Failure):
  115. d.raiseException()
  116. # Make the thread pool synchronous.
  117. clock = d.get_clock()
  118. pool = d.get_db_pool()
  119. def runWithConnection(func, *args, **kwargs):
  120. return threads.deferToThreadPool(
  121. pool._reactor,
  122. pool.threadpool,
  123. pool._runWithConnection,
  124. func,
  125. *args,
  126. **kwargs
  127. )
  128. def runInteraction(interaction, *args, **kwargs):
  129. return threads.deferToThreadPool(
  130. pool._reactor,
  131. pool.threadpool,
  132. pool._runInteraction,
  133. interaction,
  134. *args,
  135. **kwargs
  136. )
  137. pool.runWithConnection = runWithConnection
  138. pool.runInteraction = runInteraction
  139. class ThreadPool:
  140. """
  141. Threadless thread pool.
  142. """
  143. def start(self):
  144. pass
  145. def stop(self):
  146. pass
  147. def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
  148. def _(res):
  149. if isinstance(res, Failure):
  150. onResult(False, res)
  151. else:
  152. onResult(True, res)
  153. d = Deferred()
  154. d.addCallback(lambda x: function(*args, **kwargs))
  155. d.addBoth(_)
  156. clock._reactor.callLater(0, d.callback, True)
  157. return d
  158. clock.threadpool = ThreadPool()
  159. pool.threadpool = ThreadPool()
  160. return d
  161. def get_clock():
  162. clock = ThreadedMemoryReactorClock()
  163. hs_clock = Clock(clock)
  164. return (clock, hs_clock)