server.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. import json
  2. from io import BytesIO
  3. from six import text_type
  4. import attr
  5. from zope.interface import implementer
  6. from twisted.internet import address, threads, udp
  7. from twisted.internet._resolver import HostResolution
  8. from twisted.internet.address import IPv4Address
  9. from twisted.internet.defer import Deferred
  10. from twisted.internet.error import DNSLookupError
  11. from twisted.internet.interfaces import IReactorPluggableNameResolver
  12. from twisted.python.failure import Failure
  13. from twisted.test.proto_helpers import MemoryReactorClock
  14. from twisted.web.http import unquote
  15. from twisted.web.http_headers import Headers
  16. from synapse.http.site import SynapseRequest
  17. from synapse.util import Clock
  18. from tests.utils import setup_test_homeserver as _sth
  class TimedOutException(Exception):
      """Raised when a (fake) web query fails to finish before its timeout."""
  @attr.s
  class FakeChannel(object):
      """
      Stand-in for a Twisted Web channel (the part of the stack that talks to
      the wire). Everything a request writes to it is captured in ``result``
      so that tests can inspect the response afterwards.
      """

      _reactor = attr.ib()
      result = attr.ib(default=attr.Factory(dict))
      _producer = None

      def _require_result(self):
          # Common guard for the accessors below: refuse to read a response
          # before anything has been written.
          if not self.result:
              raise Exception("No result yet.")
          return self.result

      @property
      def json_body(self):
          """The captured body, decoded as UTF-8 and parsed as JSON."""
          raw_body = self._require_result()["body"]
          return json.loads(raw_body.decode('utf8'))

      @property
      def code(self):
          """The captured response status code, as an int."""
          return int(self._require_result()["code"])

      @property
      def headers(self):
          """The captured raw headers, rebuilt into a twisted ``Headers``."""
          raw_headers = self._require_result()["headers"]
          parsed = Headers()
          for header in raw_headers:
              parsed.addRawHeader(*header)
          return parsed

      def writeHeaders(self, version, code, reason, headers):
          """Record the status line and the raw header list of the response."""
          self.result.update(
              version=version, code=code, reason=reason, headers=headers
          )

      def write(self, content):
          """Append ``content`` (which must be bytes) to the captured body."""
          assert isinstance(content, bytes), "Should be bytes! " + repr(content)
          self.result["body"] = self.result.get("body", b"") + content

      def registerProducer(self, producer, streaming):
          """
          Attach ``producer``. A non-streaming (pull) producer is polled via the
          reactor every 0.1s for as long as it stays registered.
          """
          self._producer = producer
          self.producerStreaming = streaming

          def _poll():
              if self._producer:
                  self._producer.resumeProducing()
                  self._reactor.callLater(0.1, _poll)

          if not streaming:
              self._reactor.callLater(0.0, _poll)

      def unregisterProducer(self):
          """Detach the current producer, which also stops any polling loop."""
          if self._producer is not None:
              self._producer = None

      def requestDone(self, _self):
          """Mark the response as complete."""
          self.result["done"] = True

      def getPeer(self):
          # A concrete peer address makes getClientIP return a non-null value,
          # which in turn causes the MAU to be recorded.
          return address.IPv4Address("TCP", "127.0.0.1", 3423)

      def getHost(self):
          return None

      @property
      def transport(self):
          """The channel doubles as its own transport."""
          return self
  class FakeSite:
      """
      Stand-in for a Twisted Web Site, stubbing out the extra attributes that
      Synapse adds to its own Site.
      """

      class _DiscardingLogger(object):
          """Logger whose ``info`` accepts any arguments and records nothing."""

          def info(self, *args, **kwargs):
              pass

      server_version_string = b"1"
      site_tag = "test"

      @property
      def access_logger(self):
          """A logger that silently drops every access-log line."""
          return self._DiscardingLogger()
  def make_request(
      reactor,
      method,
      path,
      content=b"",
      access_token=None,
      request=SynapseRequest,
      shorthand=True,
  ):
      """
      Make a web request using the given method and path, feed it the
      content, and return the Request and the Channel underneath.

      Args:
          reactor: The (fake) reactor handed to the FakeChannel, which uses it
              to schedule producer polling.
          method (bytes/unicode): The HTTP request method ("verb").
          path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
              escaped UTF-8 & spaces and such).
          content (bytes, unicode or dict): The body of the request. Unicode is
              UTF-8 encoded; a dict is JSON-encoded first.
          access_token (bytes/unicode or None): If set, sent as an
              "Authorization: Bearer <token>" header.
          request (callable): Factory for the request object, called as
              ``request(site, channel)``. Defaults to SynapseRequest.
          shorthand: Whether to try and be helpful and prefix the given URL
              with the usual REST API path, if it doesn't contain it.

      Returns:
          tuple: ``(req, channel)`` - the request object built by ``request``,
          and the FakeChannel that captures its response.
      """
      if not isinstance(method, bytes):
          method = method.encode('ascii')

      if not isinstance(path, bytes):
          path = path.encode('ascii')

      # Decorate it to be the full path, if we're using shorthand
      if shorthand and not path.startswith(b"/_matrix"):
          path = b"/_matrix/client/r0/" + path
          # Joining the prefix with a path that already starts with "/"
          # produces "//"; collapse it.
          path = path.replace(b"//", b"/")

      if not path.startswith(b"/"):
          path = b"/" + path

      # Honour the documented contract: dict bodies are JSON-encoded.
      if isinstance(content, dict):
          content = json.dumps(content)
      if isinstance(content, text_type):
          content = content.encode('utf8')

      site = FakeSite()
      channel = FakeChannel(reactor)

      req = request(site, channel)
      # NOTE(review): presumably stubbed so requestReceived() does not dispatch
      # the request itself; callers render it against a resource separately.
      req.process = lambda: b""
      req.content = BytesIO(content)
      req.postpath = list(map(unquote, path[1:].split(b'/')))

      if access_token:
          # Accept both text and bytes tokens.
          if not isinstance(access_token, bytes):
              access_token = access_token.encode('ascii')
          req.requestHeaders.addRawHeader(
              b"Authorization", b"Bearer " + access_token
          )

      if content:
          req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")

      req.requestReceived(method, path, b"1.1")

      return req, channel
  def wait_until_result(clock, request, timeout=100):
      """
      Pump the fake clock until ``request`` reports that it has finished.

      Each round nudges the channel's producer (if any) and then advances the
      clock by 0.1s.

      Args:
          clock: Fake reactor clock exposing ``run()`` and ``advance(seconds)``.
          request: The request being waited on.
          timeout (int): Maximum number of 0.1s rounds before giving up.

      Raises:
          TimedOutException: if the request is still unfinished after
              ``timeout`` rounds.
      """
      clock.run()
      rounds = 0
      while True:
          if request.finished:
              return

          # If there's a producer, tell it to resume producing so we get content
          pending_producer = request._channel._producer
          if pending_producer:
              pending_producer.resumeProducing()

          rounds += 1
          if rounds > timeout:
              raise TimedOutException("Timed out waiting for request to finish.")

          clock.advance(0.1)
  def render(request, resource, clock):
      """
      Render ``resource`` into ``request``, then drive ``clock`` until the
      request has finished.
      """
      request.render(resource)
      wait_until_result(clock=clock, request=request)
  @implementer(IReactorPluggableNameResolver)
  class ThreadedMemoryReactorClock(MemoryReactorClock):
      """
      A MemoryReactorClock that supports callFromThread.

      It also provides a name resolver that answers from the ``lookups`` dict
      (hostname -> IPv4 address string) and records UDP ports created through
      ``listenUDP``.
      """

      def __init__(self):
          self._udp = []
          self.lookups = {}
          reactor = self

          class _LookupTableResolver(object):
              def resolveHostName(
                  _self,
                  resolutionReceiver,
                  hostName,
                  portNumber=0,
                  addressTypes=None,
                  transportSemantics='TCP',
              ):
                  resolution = HostResolution(hostName)
                  resolutionReceiver.resolutionBegan(resolution)
                  # ``lookups`` is read at call time, so tests can populate
                  # it after the reactor has been built.
                  if hostName not in reactor.lookups:
                      raise DNSLookupError("OH NO")

                  resolved_ip = reactor.lookups[hostName]
                  resolutionReceiver.addressResolved(
                      IPv4Address('TCP', resolved_ip, portNumber)
                  )
                  resolutionReceiver.resolutionComplete()
                  return resolution

          self.nameResolver = _LookupTableResolver()
          super(ThreadedMemoryReactorClock, self).__init__()

      # NOTE(review): Twisted's own listenUDP default for maxPacketSize is
      # 8192; 8196 here looks like a typo, but is kept for compatibility.
      def listenUDP(self, port, protocol, interface='', maxPacketSize=8196):
          """Create a real ``udp.Port`` bound to this reactor and remember it."""
          udp_port = udp.Port(port, protocol, interface, maxPacketSize, self)
          udp_port.startListening()
          self._udp.append(udp_port)
          return udp_port

      def callFromThread(self, callback, *args, **kwargs):
          """
          Make the callback fire in the next reactor iteration.

          Returns a Deferred that fires with the callback's result.
          """
          def _invoke(_ignored):
              return callback(*args, **kwargs)

          deferred = Deferred()
          deferred.addCallback(_invoke)
          self.callLater(0, deferred.callback, True)
          return deferred
  def setup_test_homeserver(cleanup_func, *args, **kwargs):
      """
      Set up a synchronous test server, driven by the reactor used by
      the homeserver.

      The database pool's thread pool and the clock's thread pool are both
      replaced with a threadless stand-in. Work is then scheduled on the
      homeserver's reactor instead of on real threads.

      Raises:
          Whatever exception the underlying homeserver setup failed with.
      """
      server = _sth(cleanup_func, *args, **kwargs).result
      if isinstance(server, Failure):
          server.raiseException()

      clock = server.get_clock()
      pool = server.get_db_pool()

      def _defer_via_pool(pool_method_name, target, call_args, call_kwargs):
          # Attributes are looked up on ``pool`` at call time on purpose:
          # ``pool.threadpool`` is swapped for the threadless pool below.
          return threads.deferToThreadPool(
              pool._reactor,
              pool.threadpool,
              getattr(pool, pool_method_name),
              target,
              *call_args,
              **call_kwargs
          )

      def runWithConnection(func, *args, **kwargs):
          return _defer_via_pool("_runWithConnection", func, args, kwargs)

      def runInteraction(interaction, *args, **kwargs):
          return _defer_via_pool("_runInteraction", interaction, args, kwargs)

      pool.runWithConnection = runWithConnection
      pool.runInteraction = runInteraction

      class ThreadPool:
          """
          Threadless thread pool: "threaded" calls run on the next reactor tick.
          """

          def start(self):
              pass

          def stop(self):
              pass

          def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
              def _report(outcome):
                  # onResult(success, result_or_failure)
                  onResult(not isinstance(outcome, Failure), outcome)

              deferred = Deferred()
              deferred.addCallback(lambda _ignored: function(*args, **kwargs))
              deferred.addBoth(_report)
              clock._reactor.callLater(0, deferred.callback, True)
              return deferred

      clock.threadpool = ThreadPool()
      pool.threadpool = ThreadPool()
      pool.running = True
      return server
  def get_clock():
      """
      Build a fresh fake reactor together with a synapse ``Clock`` wrapping it.

      Returns:
          tuple: ``(reactor, hs_clock)``.
      """
      reactor = ThreadedMemoryReactorClock()
      return reactor, Clock(reactor)
  @attr.s
  class FakeTransport(object):
      """
      A twisted.internet.interfaces.ITransport implementation which sends all its data
      straight into an IProtocol object: it exists to connect two IProtocols together.

      To use it, instantiate it with the receiving IProtocol, and then pass it to the
      sending IProtocol's makeConnection method:

          server = HTTPChannel()
          client.makeConnection(FakeTransport(server, self.reactor))

      If you want bidirectional communication, you'll need two instances.
      """

      # The Protocol object which will receive any data written to this transport
      # (a twisted.internet.interfaces.IProtocol).
      other = attr.ib()

      # Test reactor (a twisted.internet.interfaces.IReactorTime), used to
      # schedule delayed writes and producer polling.
      _reactor = attr.ib()

      # Set by loseConnection() / abortConnection().
      disconnecting = False

      # Bytes written but not yet delivered to ``other``.
      buffer = attr.ib(default=b'')

      # The currently registered producer, if any.
      producer = attr.ib(default=None)

      def getPeer(self):
          return None

      def getHost(self):
          return None

      def loseConnection(self):
          self.disconnecting = True

      def abortConnection(self):
          self.disconnecting = True

      def pauseProducing(self):
          if not self.producer:
              return

          self.producer.pauseProducing()

      def resumeProducing(self):
          if not self.producer:
              return
          self.producer.resumeProducing()

      def unregisterProducer(self):
          if not self.producer:
              return

          self.producer = None

      def registerProducer(self, producer, streaming):
          """
          Attach ``producer``. A non-streaming (pull) producer is polled through
          the reactor until it is unregistered.
          """
          self.producer = producer
          self.producerStreaming = streaming

          def _produce():
              # Stop polling once the producer has been unregistered instead of
              # crashing on None.resumeProducing().
              if not self.producer:
                  return
              d = self.producer.resumeProducing()
              if d is None:
                  # resumeProducing() did not return a Deferred: poll again later.
                  self._reactor.callLater(0.1, _produce)
              else:
                  d.addCallback(lambda _: self._reactor.callLater(0.1, _produce))

          if not streaming:
              self._reactor.callLater(0.0, _produce)

      def write(self, byt):
          """
          Buffer ``byt`` and deliver the whole buffer to ``other``. If ``other``
          has no transport yet, delivery is retried on the next reactor tick.
          """
          self.buffer = self.buffer + byt

          def _write():
              if not self.buffer:
                  # Nothing left to send (e.g. an earlier scheduled retry already
                  # flushed it): don't hand the peer an empty dataReceived().
                  return

              if getattr(self.other, "transport", None) is not None:
                  self.other.dataReceived(self.buffer)
                  self.buffer = b""
                  return

              self._reactor.callLater(0.0, _write)

          _write()

      def writeSequence(self, seq):
          for x in seq:
              self.write(x)