server.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. import json
  2. import logging
  3. from io import BytesIO
  4. from six import text_type
  5. import attr
  6. from zope.interface import implementer
  7. from twisted.internet import address, threads, udp
  8. from twisted.internet._resolver import HostResolution
  9. from twisted.internet.address import IPv4Address
  10. from twisted.internet.defer import Deferred
  11. from twisted.internet.error import DNSLookupError
  12. from twisted.internet.interfaces import IReactorPluggableNameResolver
  13. from twisted.python.failure import Failure
  14. from twisted.test.proto_helpers import MemoryReactorClock
  15. from twisted.web.http import unquote
  16. from twisted.web.http_headers import Headers
  17. from synapse.http.site import SynapseRequest
  18. from synapse.util import Clock
  19. from tests.utils import setup_test_homeserver as _sth
  20. logger = logging.getLogger(__name__)
  21. class TimedOutException(Exception):
  22. """
  23. A web query timed out.
  24. """
  25. @attr.s
  26. class FakeChannel(object):
  27. """
  28. A fake Twisted Web Channel (the part that interfaces with the
  29. wire).
  30. """
  31. _reactor = attr.ib()
  32. result = attr.ib(default=attr.Factory(dict))
  33. _producer = None
  34. @property
  35. def json_body(self):
  36. if not self.result:
  37. raise Exception("No result yet.")
  38. return json.loads(self.result["body"].decode('utf8'))
  39. @property
  40. def code(self):
  41. if not self.result:
  42. raise Exception("No result yet.")
  43. return int(self.result["code"])
  44. @property
  45. def headers(self):
  46. if not self.result:
  47. raise Exception("No result yet.")
  48. h = Headers()
  49. for i in self.result["headers"]:
  50. h.addRawHeader(*i)
  51. return h
  52. def writeHeaders(self, version, code, reason, headers):
  53. self.result["version"] = version
  54. self.result["code"] = code
  55. self.result["reason"] = reason
  56. self.result["headers"] = headers
  57. def write(self, content):
  58. assert isinstance(content, bytes), "Should be bytes! " + repr(content)
  59. if "body" not in self.result:
  60. self.result["body"] = b""
  61. self.result["body"] += content
  62. def registerProducer(self, producer, streaming):
  63. self._producer = producer
  64. self.producerStreaming = streaming
  65. def _produce():
  66. if self._producer:
  67. self._producer.resumeProducing()
  68. self._reactor.callLater(0.1, _produce)
  69. if not streaming:
  70. self._reactor.callLater(0.0, _produce)
  71. def unregisterProducer(self):
  72. if self._producer is None:
  73. return
  74. self._producer = None
  75. def requestDone(self, _self):
  76. self.result["done"] = True
  77. def getPeer(self):
  78. # We give an address so that getClientIP returns a non null entry,
  79. # causing us to record the MAU
  80. return address.IPv4Address("TCP", "127.0.0.1", 3423)
  81. def getHost(self):
  82. return None
  83. @property
  84. def transport(self):
  85. return self
  86. class FakeSite:
  87. """
  88. A fake Twisted Web Site, with mocks of the extra things that
  89. Synapse adds.
  90. """
  91. server_version_string = b"1"
  92. site_tag = "test"
  93. @property
  94. def access_logger(self):
  95. class FakeLogger:
  96. def info(self, *args, **kwargs):
  97. pass
  98. return FakeLogger()
  99. def make_request(
  100. reactor,
  101. method,
  102. path,
  103. content=b"",
  104. access_token=None,
  105. request=SynapseRequest,
  106. shorthand=True,
  107. ):
  108. """
  109. Make a web request using the given method and path, feed it the
  110. content, and return the Request and the Channel underneath.
  111. Args:
  112. method (bytes/unicode): The HTTP request method ("verb").
  113. path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
  114. escaped UTF-8 & spaces and such).
  115. content (bytes or dict): The body of the request. JSON-encoded, if
  116. a dict.
  117. shorthand: Whether to try and be helpful and prefix the given URL
  118. with the usual REST API path, if it doesn't contain it.
  119. Returns:
  120. A synapse.http.site.SynapseRequest.
  121. """
  122. if not isinstance(method, bytes):
  123. method = method.encode('ascii')
  124. if not isinstance(path, bytes):
  125. path = path.encode('ascii')
  126. # Decorate it to be the full path, if we're using shorthand
  127. if shorthand and not path.startswith(b"/_matrix"):
  128. path = b"/_matrix/client/r0/" + path
  129. path = path.replace(b"//", b"/")
  130. if not path.startswith(b"/"):
  131. path = b"/" + path
  132. if isinstance(content, text_type):
  133. content = content.encode('utf8')
  134. site = FakeSite()
  135. channel = FakeChannel(reactor)
  136. req = request(site, channel)
  137. req.process = lambda: b""
  138. req.content = BytesIO(content)
  139. req.postpath = list(map(unquote, path[1:].split(b'/')))
  140. if access_token:
  141. req.requestHeaders.addRawHeader(
  142. b"Authorization", b"Bearer " + access_token.encode('ascii')
  143. )
  144. if content:
  145. req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
  146. req.requestReceived(method, path, b"1.1")
  147. return req, channel
  148. def wait_until_result(clock, request, timeout=100):
  149. """
  150. Wait until the request is finished.
  151. """
  152. clock.run()
  153. x = 0
  154. while not request.finished:
  155. # If there's a producer, tell it to resume producing so we get content
  156. if request._channel._producer:
  157. request._channel._producer.resumeProducing()
  158. x += 1
  159. if x > timeout:
  160. raise TimedOutException("Timed out waiting for request to finish.")
  161. clock.advance(0.1)
  162. def render(request, resource, clock):
  163. request.render(resource)
  164. wait_until_result(clock, request)
  165. @implementer(IReactorPluggableNameResolver)
  166. class ThreadedMemoryReactorClock(MemoryReactorClock):
  167. """
  168. A MemoryReactorClock that supports callFromThread.
  169. """
  170. def __init__(self):
  171. self._udp = []
  172. self.lookups = {}
  173. class Resolver(object):
  174. def resolveHostName(
  175. _self,
  176. resolutionReceiver,
  177. hostName,
  178. portNumber=0,
  179. addressTypes=None,
  180. transportSemantics='TCP',
  181. ):
  182. resolution = HostResolution(hostName)
  183. resolutionReceiver.resolutionBegan(resolution)
  184. if hostName not in self.lookups:
  185. raise DNSLookupError("OH NO")
  186. resolutionReceiver.addressResolved(
  187. IPv4Address('TCP', self.lookups[hostName], portNumber)
  188. )
  189. resolutionReceiver.resolutionComplete()
  190. return resolution
  191. self.nameResolver = Resolver()
  192. super(ThreadedMemoryReactorClock, self).__init__()
  193. def listenUDP(self, port, protocol, interface='', maxPacketSize=8196):
  194. p = udp.Port(port, protocol, interface, maxPacketSize, self)
  195. p.startListening()
  196. self._udp.append(p)
  197. return p
  198. def callFromThread(self, callback, *args, **kwargs):
  199. """
  200. Make the callback fire in the next reactor iteration.
  201. """
  202. d = Deferred()
  203. d.addCallback(lambda x: callback(*args, **kwargs))
  204. self.callLater(0, d.callback, True)
  205. return d
  206. def setup_test_homeserver(cleanup_func, *args, **kwargs):
  207. """
  208. Set up a synchronous test server, driven by the reactor used by
  209. the homeserver.
  210. """
  211. d = _sth(cleanup_func, *args, **kwargs).result
  212. if isinstance(d, Failure):
  213. d.raiseException()
  214. # Make the thread pool synchronous.
  215. clock = d.get_clock()
  216. pool = d.get_db_pool()
  217. def runWithConnection(func, *args, **kwargs):
  218. return threads.deferToThreadPool(
  219. pool._reactor,
  220. pool.threadpool,
  221. pool._runWithConnection,
  222. func,
  223. *args,
  224. **kwargs
  225. )
  226. def runInteraction(interaction, *args, **kwargs):
  227. return threads.deferToThreadPool(
  228. pool._reactor,
  229. pool.threadpool,
  230. pool._runInteraction,
  231. interaction,
  232. *args,
  233. **kwargs
  234. )
  235. pool.runWithConnection = runWithConnection
  236. pool.runInteraction = runInteraction
  237. class ThreadPool:
  238. """
  239. Threadless thread pool.
  240. """
  241. def start(self):
  242. pass
  243. def stop(self):
  244. pass
  245. def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
  246. def _(res):
  247. if isinstance(res, Failure):
  248. onResult(False, res)
  249. else:
  250. onResult(True, res)
  251. d = Deferred()
  252. d.addCallback(lambda x: function(*args, **kwargs))
  253. d.addBoth(_)
  254. clock._reactor.callLater(0, d.callback, True)
  255. return d
  256. clock.threadpool = ThreadPool()
  257. pool.threadpool = ThreadPool()
  258. pool.running = True
  259. return d
  260. def get_clock():
  261. clock = ThreadedMemoryReactorClock()
  262. hs_clock = Clock(clock)
  263. return (clock, hs_clock)
  264. @attr.s(cmp=False)
  265. class FakeTransport(object):
  266. """
  267. A twisted.internet.interfaces.ITransport implementation which sends all its data
  268. straight into an IProtocol object: it exists to connect two IProtocols together.
  269. To use it, instantiate it with the receiving IProtocol, and then pass it to the
  270. sending IProtocol's makeConnection method:
  271. server = HTTPChannel()
  272. client.makeConnection(FakeTransport(server, self.reactor))
  273. If you want bidirectional communication, you'll need two instances.
  274. """
  275. other = attr.ib()
  276. """The Protocol object which will receive any data written to this transport.
  277. :type: twisted.internet.interfaces.IProtocol
  278. """
  279. _reactor = attr.ib()
  280. """Test reactor
  281. :type: twisted.internet.interfaces.IReactorTime
  282. """
  283. disconnecting = False
  284. buffer = attr.ib(default=b'')
  285. producer = attr.ib(default=None)
  286. def getPeer(self):
  287. return None
  288. def getHost(self):
  289. return None
  290. def loseConnection(self):
  291. self.disconnecting = True
  292. def abortConnection(self):
  293. self.disconnecting = True
  294. def pauseProducing(self):
  295. if not self.producer:
  296. return
  297. self.producer.pauseProducing()
  298. def resumeProducing(self):
  299. if not self.producer:
  300. return
  301. self.producer.resumeProducing()
  302. def unregisterProducer(self):
  303. if not self.producer:
  304. return
  305. self.producer = None
  306. def registerProducer(self, producer, streaming):
  307. self.producer = producer
  308. self.producerStreaming = streaming
  309. def _produce():
  310. d = self.producer.resumeProducing()
  311. d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))
  312. if not streaming:
  313. self._reactor.callLater(0.0, _produce)
  314. def write(self, byt):
  315. self.buffer = self.buffer + byt
  316. def _write():
  317. if not self.buffer:
  318. # nothing to do. Don't write empty buffers: it upsets the
  319. # TLSMemoryBIOProtocol
  320. return
  321. if getattr(self.other, "transport") is not None:
  322. self.other.dataReceived(self.buffer)
  323. self.buffer = b""
  324. return
  325. self._reactor.callLater(0.0, _write)
  326. # always actually do the write asynchronously. Some protocols (notably the
  327. # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
  328. # still doing a write. Doing a callLater here breaks the cycle.
  329. self._reactor.callLater(0.0, _write)
  330. def writeSequence(self, seq):
  331. for x in seq:
  332. self.write(x)