server.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import json
  2. from io import BytesIO
  3. from six import text_type
  4. import attr
  5. from twisted.internet import threads
  6. from twisted.internet.defer import Deferred
  7. from twisted.python.failure import Failure
  8. from twisted.test.proto_helpers import MemoryReactorClock
  9. from synapse.http.site import SynapseRequest
  10. from tests.utils import setup_test_homeserver as _sth
  11. @attr.s
  12. class FakeChannel(object):
  13. """
  14. A fake Twisted Web Channel (the part that interfaces with the
  15. wire).
  16. """
  17. result = attr.ib(default=attr.Factory(dict))
  18. @property
  19. def json_body(self):
  20. if not self.result:
  21. raise Exception("No result yet.")
  22. return json.loads(self.result["body"])
  23. def writeHeaders(self, version, code, reason, headers):
  24. self.result["version"] = version
  25. self.result["code"] = code
  26. self.result["reason"] = reason
  27. self.result["headers"] = headers
  28. def write(self, content):
  29. if "body" not in self.result:
  30. self.result["body"] = b""
  31. self.result["body"] += content
  32. def requestDone(self, _self):
  33. self.result["done"] = True
  34. def getPeer(self):
  35. return None
  36. def getHost(self):
  37. return None
  38. @property
  39. def transport(self):
  40. return self
  41. class FakeSite:
  42. """
  43. A fake Twisted Web Site, with mocks of the extra things that
  44. Synapse adds.
  45. """
  46. server_version_string = b"1"
  47. site_tag = "test"
  48. @property
  49. def access_logger(self):
  50. class FakeLogger:
  51. def info(self, *args, **kwargs):
  52. pass
  53. return FakeLogger()
  54. def make_request(method, path, content=b""):
  55. """
  56. Make a web request using the given method and path, feed it the
  57. content, and return the Request and the Channel underneath.
  58. """
  59. # Decorate it to be the full path
  60. if not path.startswith(b"/_matrix"):
  61. path = b"/_matrix/client/r0/" + path
  62. path = path.replace("//", "/")
  63. if isinstance(content, text_type):
  64. content = content.encode('utf8')
  65. site = FakeSite()
  66. channel = FakeChannel()
  67. req = SynapseRequest(site, channel)
  68. req.process = lambda: b""
  69. req.content = BytesIO(content)
  70. req.requestReceived(method, path, b"1.1")
  71. return req, channel
  72. def wait_until_result(clock, channel, timeout=100):
  73. """
  74. Wait until the channel has a result.
  75. """
  76. clock.run()
  77. x = 0
  78. while not channel.result:
  79. x += 1
  80. if x > timeout:
  81. raise Exception("Timed out waiting for request to finish.")
  82. clock.advance(0.1)
  83. def render(request, resource, clock):
  84. request.render(resource)
  85. wait_until_result(clock, request._channel)
  86. class ThreadedMemoryReactorClock(MemoryReactorClock):
  87. """
  88. A MemoryReactorClock that supports callFromThread.
  89. """
  90. def callFromThread(self, callback, *args, **kwargs):
  91. """
  92. Make the callback fire in the next reactor iteration.
  93. """
  94. d = Deferred()
  95. d.addCallback(lambda x: callback(*args, **kwargs))
  96. self.callLater(0, d.callback, True)
  97. return d
  98. def setup_test_homeserver(*args, **kwargs):
  99. """
  100. Set up a synchronous test server, driven by the reactor used by
  101. the homeserver.
  102. """
  103. d = _sth(*args, **kwargs).result
  104. # Make the thread pool synchronous.
  105. clock = d.get_clock()
  106. pool = d.get_db_pool()
  107. def runWithConnection(func, *args, **kwargs):
  108. return threads.deferToThreadPool(
  109. pool._reactor,
  110. pool.threadpool,
  111. pool._runWithConnection,
  112. func,
  113. *args,
  114. **kwargs
  115. )
  116. def runInteraction(interaction, *args, **kwargs):
  117. return threads.deferToThreadPool(
  118. pool._reactor,
  119. pool.threadpool,
  120. pool._runInteraction,
  121. interaction,
  122. *args,
  123. **kwargs
  124. )
  125. pool.runWithConnection = runWithConnection
  126. pool.runInteraction = runInteraction
  127. class ThreadPool:
  128. """
  129. Threadless thread pool.
  130. """
  131. def start(self):
  132. pass
  133. def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
  134. def _(res):
  135. if isinstance(res, Failure):
  136. onResult(False, res)
  137. else:
  138. onResult(True, res)
  139. d = Deferred()
  140. d.addCallback(lambda x: function(*args, **kwargs))
  141. d.addBoth(_)
  142. clock._reactor.callLater(0, d.callback, True)
  143. return d
  144. clock.threadpool = ThreadPool()
  145. pool.threadpool = ThreadPool()
  146. return d