server.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675
  1. # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import logging
  16. from collections import deque
  17. from io import SEEK_END, BytesIO
  18. from typing import (
  19. AnyStr,
  20. Callable,
  21. Dict,
  22. Iterable,
  23. MutableMapping,
  24. Optional,
  25. Tuple,
  26. Type,
  27. Union,
  28. )
  29. import attr
  30. from typing_extensions import Deque
  31. from zope.interface import implementer
  32. from twisted.internet import address, threads, udp
  33. from twisted.internet._resolver import SimpleResolverComplexifier
  34. from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
  35. from twisted.internet.error import DNSLookupError
  36. from twisted.internet.interfaces import (
  37. IAddress,
  38. IHostnameResolver,
  39. IProtocol,
  40. IPullProducer,
  41. IPushProducer,
  42. IReactorPluggableNameResolver,
  43. IReactorTime,
  44. IResolverSimple,
  45. ITransport,
  46. )
  47. from twisted.python.failure import Failure
  48. from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
  49. from twisted.web.http_headers import Headers
  50. from twisted.web.resource import IResource
  51. from twisted.web.server import Request, Site
  52. from synapse.http.site import SynapseRequest
  53. from synapse.types import JsonDict
  54. from synapse.util import Clock
  55. from tests.utils import setup_test_homeserver as _sth
  56. logger = logging.getLogger(__name__)
  57. class TimedOutException(Exception):
  58. """
  59. A web query timed out.
  60. """
  61. @attr.s
  62. class FakeChannel:
  63. """
  64. A fake Twisted Web Channel (the part that interfaces with the
  65. wire).
  66. """
  67. site = attr.ib(type=Union[Site, "FakeSite"])
  68. _reactor = attr.ib()
  69. result = attr.ib(type=dict, default=attr.Factory(dict))
  70. _ip = attr.ib(type=str, default="127.0.0.1")
  71. _producer: Optional[Union[IPullProducer, IPushProducer]] = None
  72. @property
  73. def json_body(self):
  74. return json.loads(self.text_body)
  75. @property
  76. def text_body(self) -> str:
  77. """The body of the result, utf-8-decoded.
  78. Raises an exception if the request has not yet completed.
  79. """
  80. if not self.is_finished:
  81. raise Exception("Request not yet completed")
  82. return self.result["body"].decode("utf8")
  83. def is_finished(self) -> bool:
  84. """check if the response has been completely received"""
  85. return self.result.get("done", False)
  86. @property
  87. def code(self):
  88. if not self.result:
  89. raise Exception("No result yet.")
  90. return int(self.result["code"])
  91. @property
  92. def headers(self) -> Headers:
  93. if not self.result:
  94. raise Exception("No result yet.")
  95. h = Headers()
  96. for i in self.result["headers"]:
  97. h.addRawHeader(*i)
  98. return h
  99. def writeHeaders(self, version, code, reason, headers):
  100. self.result["version"] = version
  101. self.result["code"] = code
  102. self.result["reason"] = reason
  103. self.result["headers"] = headers
  104. def write(self, content):
  105. assert isinstance(content, bytes), "Should be bytes! " + repr(content)
  106. if "body" not in self.result:
  107. self.result["body"] = b""
  108. self.result["body"] += content
  109. def registerProducer(self, producer, streaming):
  110. self._producer = producer
  111. self.producerStreaming = streaming
  112. def _produce():
  113. if self._producer:
  114. self._producer.resumeProducing()
  115. self._reactor.callLater(0.1, _produce)
  116. if not streaming:
  117. self._reactor.callLater(0.0, _produce)
  118. def unregisterProducer(self):
  119. if self._producer is None:
  120. return
  121. self._producer = None
  122. def requestDone(self, _self):
  123. self.result["done"] = True
  124. def getPeer(self):
  125. # We give an address so that getClientIP returns a non null entry,
  126. # causing us to record the MAU
  127. return address.IPv4Address("TCP", self._ip, 3423)
  128. def getHost(self):
  129. # this is called by Request.__init__ to configure Request.host.
  130. return address.IPv4Address("TCP", "127.0.0.1", 8888)
  131. def isSecure(self):
  132. return False
  133. @property
  134. def transport(self):
  135. return self
  136. def await_result(self, timeout_ms: int = 1000) -> None:
  137. """
  138. Wait until the request is finished.
  139. """
  140. end_time = self._reactor.seconds() + timeout_ms / 1000.0
  141. self._reactor.run()
  142. while not self.is_finished():
  143. # If there's a producer, tell it to resume producing so we get content
  144. if self._producer:
  145. self._producer.resumeProducing()
  146. if self._reactor.seconds() > end_time:
  147. raise TimedOutException("Timed out waiting for request to finish.")
  148. self._reactor.advance(0.1)
  149. def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
  150. """Process the contents of any Set-Cookie headers in the response
  151. Any cookines found are added to the given dict
  152. """
  153. headers = self.headers.getRawHeaders("Set-Cookie")
  154. if not headers:
  155. return
  156. for h in headers:
  157. parts = h.split(";")
  158. k, v = parts[0].split("=", maxsplit=1)
  159. cookies[k] = v
  160. class FakeSite:
  161. """
  162. A fake Twisted Web Site, with mocks of the extra things that
  163. Synapse adds.
  164. """
  165. server_version_string = b"1"
  166. site_tag = "test"
  167. access_logger = logging.getLogger("synapse.access.http.fake")
  168. def __init__(self, resource: IResource, reactor: IReactorTime):
  169. """
  170. Args:
  171. resource: the resource to be used for rendering all requests
  172. """
  173. self._resource = resource
  174. self.reactor = reactor
  175. def getResourceFor(self, request):
  176. return self._resource
  177. def make_request(
  178. reactor,
  179. site: Union[Site, FakeSite],
  180. method: Union[bytes, str],
  181. path: Union[bytes, str],
  182. content: Union[bytes, str, JsonDict] = b"",
  183. access_token: Optional[str] = None,
  184. request: Type[Request] = SynapseRequest,
  185. shorthand: bool = True,
  186. federation_auth_origin: Optional[bytes] = None,
  187. content_is_form: bool = False,
  188. await_result: bool = True,
  189. custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
  190. client_ip: str = "127.0.0.1",
  191. ) -> FakeChannel:
  192. """
  193. Make a web request using the given method, path and content, and render it
  194. Returns the fake Channel object which records the response to the request.
  195. Args:
  196. reactor:
  197. site: The twisted Site to use to render the request
  198. method: The HTTP request method ("verb").
  199. path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such).
  200. content: The body of the request. JSON-encoded, if a str of bytes.
  201. access_token: The access token to add as authorization for the request.
  202. request: The request class to create.
  203. shorthand: Whether to try and be helpful and prefix the given URL
  204. with the usual REST API path, if it doesn't contain it.
  205. federation_auth_origin: if set to not-None, we will add a fake
  206. Authorization header pretenting to be the given server name.
  207. content_is_form: Whether the content is URL encoded form data. Adds the
  208. 'Content-Type': 'application/x-www-form-urlencoded' header.
  209. await_result: whether to wait for the request to complete rendering. If true,
  210. will pump the reactor until the the renderer tells the channel the request
  211. is finished.
  212. custom_headers: (name, value) pairs to add as request headers
  213. client_ip: The IP to use as the requesting IP. Useful for testing
  214. ratelimiting.
  215. Returns:
  216. channel
  217. """
  218. if not isinstance(method, bytes):
  219. method = method.encode("ascii")
  220. if not isinstance(path, bytes):
  221. path = path.encode("ascii")
  222. # Decorate it to be the full path, if we're using shorthand
  223. if (
  224. shorthand
  225. and not path.startswith(b"/_matrix")
  226. and not path.startswith(b"/_synapse")
  227. ):
  228. if path.startswith(b"/"):
  229. path = path[1:]
  230. path = b"/_matrix/client/r0/" + path
  231. if not path.startswith(b"/"):
  232. path = b"/" + path
  233. if isinstance(content, dict):
  234. content = json.dumps(content).encode("utf8")
  235. if isinstance(content, str):
  236. content = content.encode("utf8")
  237. channel = FakeChannel(site, reactor, ip=client_ip)
  238. req = request(channel, site)
  239. req.content = BytesIO(content)
  240. # Twisted expects to be at the end of the content when parsing the request.
  241. req.content.seek(SEEK_END)
  242. if access_token:
  243. req.requestHeaders.addRawHeader(
  244. b"Authorization", b"Bearer " + access_token.encode("ascii")
  245. )
  246. if federation_auth_origin is not None:
  247. req.requestHeaders.addRawHeader(
  248. b"Authorization",
  249. b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
  250. )
  251. if content:
  252. if content_is_form:
  253. req.requestHeaders.addRawHeader(
  254. b"Content-Type", b"application/x-www-form-urlencoded"
  255. )
  256. else:
  257. # Assume the body is JSON
  258. req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
  259. if custom_headers:
  260. for k, v in custom_headers:
  261. req.requestHeaders.addRawHeader(k, v)
  262. req.parseCookies()
  263. req.requestReceived(method, path, b"1.1")
  264. if await_result:
  265. channel.await_result()
  266. return channel
  267. @implementer(IReactorPluggableNameResolver)
  268. class ThreadedMemoryReactorClock(MemoryReactorClock):
  269. """
  270. A MemoryReactorClock that supports callFromThread.
  271. """
  272. def __init__(self):
  273. self.threadpool = ThreadPool(self)
  274. self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
  275. self._udp = []
  276. self.lookups: Dict[str, str] = {}
  277. self._thread_callbacks: Deque[Callable[[], None]] = deque()
  278. lookups = self.lookups
  279. @implementer(IResolverSimple)
  280. class FakeResolver:
  281. def getHostByName(self, name, timeout=None):
  282. if name not in lookups:
  283. return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
  284. return succeed(lookups[name])
  285. self.nameResolver = SimpleResolverComplexifier(FakeResolver())
  286. super().__init__()
  287. def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
  288. raise NotImplementedError()
  289. def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
  290. p = udp.Port(port, protocol, interface, maxPacketSize, self)
  291. p.startListening()
  292. self._udp.append(p)
  293. return p
  294. def callFromThread(self, callback, *args, **kwargs):
  295. """
  296. Make the callback fire in the next reactor iteration.
  297. """
  298. cb = lambda: callback(*args, **kwargs)
  299. # it's not safe to call callLater() here, so we append the callback to a
  300. # separate queue.
  301. self._thread_callbacks.append(cb)
  302. def getThreadPool(self):
  303. return self.threadpool
  304. def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
  305. """Add a callback that will be invoked when we receive a connection
  306. attempt to the given IP/port using `connectTCP`.
  307. Note that the callback gets run before we return the connection to the
  308. client, which means callbacks cannot block while waiting for writes.
  309. """
  310. self._tcp_callbacks[(host, port)] = callback
  311. def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
  312. """Fake L{IReactorTCP.connectTCP}."""
  313. conn = super().connectTCP(
  314. host, port, factory, timeout=timeout, bindAddress=None
  315. )
  316. callback = self._tcp_callbacks.get((host, port))
  317. if callback:
  318. callback()
  319. return conn
  320. def advance(self, amount):
  321. # first advance our reactor's time, and run any "callLater" callbacks that
  322. # makes ready
  323. super().advance(amount)
  324. # now run any "callFromThread" callbacks
  325. while True:
  326. try:
  327. callback = self._thread_callbacks.popleft()
  328. except IndexError:
  329. break
  330. callback()
  331. # check for more "callLater" callbacks added by the thread callback
  332. # This isn't required in a regular reactor, but it ends up meaning that
  333. # our database queries can complete in a single call to `advance` [1] which
  334. # simplifies tests.
  335. #
  336. # [1]: we replace the threadpool backing the db connection pool with a
  337. # mock ThreadPool which doesn't really use threads; but we still use
  338. # reactor.callFromThread to feed results back from the db functions to the
  339. # main thread.
  340. super().advance(0)
  341. class ThreadPool:
  342. """
  343. Threadless thread pool.
  344. """
  345. def __init__(self, reactor):
  346. self._reactor = reactor
  347. def start(self):
  348. pass
  349. def stop(self):
  350. pass
  351. def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
  352. def _(res):
  353. if isinstance(res, Failure):
  354. onResult(False, res)
  355. else:
  356. onResult(True, res)
  357. d = Deferred()
  358. d.addCallback(lambda x: function(*args, **kwargs))
  359. d.addBoth(_)
  360. self._reactor.callLater(0, d.callback, True)
  361. return d
  362. def setup_test_homeserver(cleanup_func, *args, **kwargs):
  363. """
  364. Set up a synchronous test server, driven by the reactor used by
  365. the homeserver.
  366. """
  367. server = _sth(cleanup_func, *args, **kwargs)
  368. # Make the thread pool synchronous.
  369. clock = server.get_clock()
  370. for database in server.get_datastores().databases:
  371. pool = database._db_pool
  372. def runWithConnection(func, *args, **kwargs):
  373. return threads.deferToThreadPool(
  374. pool._reactor,
  375. pool.threadpool,
  376. pool._runWithConnection,
  377. func,
  378. *args,
  379. **kwargs,
  380. )
  381. def runInteraction(interaction, *args, **kwargs):
  382. return threads.deferToThreadPool(
  383. pool._reactor,
  384. pool.threadpool,
  385. pool._runInteraction,
  386. interaction,
  387. *args,
  388. **kwargs,
  389. )
  390. pool.runWithConnection = runWithConnection
  391. pool.runInteraction = runInteraction
  392. pool.threadpool = ThreadPool(clock._reactor)
  393. pool.running = True
  394. # We've just changed the Databases to run DB transactions on the same
  395. # thread, so we need to disable the dedicated thread behaviour.
  396. server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
  397. return server
  398. def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
  399. clock = ThreadedMemoryReactorClock()
  400. hs_clock = Clock(clock)
  401. return clock, hs_clock
  402. @implementer(ITransport)
  403. @attr.s(cmp=False)
  404. class FakeTransport:
  405. """
  406. A twisted.internet.interfaces.ITransport implementation which sends all its data
  407. straight into an IProtocol object: it exists to connect two IProtocols together.
  408. To use it, instantiate it with the receiving IProtocol, and then pass it to the
  409. sending IProtocol's makeConnection method:
  410. server = HTTPChannel()
  411. client.makeConnection(FakeTransport(server, self.reactor))
  412. If you want bidirectional communication, you'll need two instances.
  413. """
  414. other = attr.ib()
  415. """The Protocol object which will receive any data written to this transport.
  416. :type: twisted.internet.interfaces.IProtocol
  417. """
  418. _reactor = attr.ib()
  419. """Test reactor
  420. :type: twisted.internet.interfaces.IReactorTime
  421. """
  422. _protocol = attr.ib(default=None)
  423. """The Protocol which is producing data for this transport. Optional, but if set
  424. will get called back for connectionLost() notifications etc.
  425. """
  426. _peer_address: Optional[IAddress] = attr.ib(default=None)
  427. """The value to be returend by getPeer"""
  428. disconnecting = False
  429. disconnected = False
  430. connected = True
  431. buffer = attr.ib(default=b"")
  432. producer = attr.ib(default=None)
  433. autoflush = attr.ib(default=True)
  434. def getPeer(self):
  435. return self._peer_address
  436. def getHost(self):
  437. return None
  438. def loseConnection(self, reason=None):
  439. if not self.disconnecting:
  440. logger.info("FakeTransport: loseConnection(%s)", reason)
  441. self.disconnecting = True
  442. if self._protocol:
  443. self._protocol.connectionLost(reason)
  444. # if we still have data to write, delay until that is done
  445. if self.buffer:
  446. logger.info(
  447. "FakeTransport: Delaying disconnect until buffer is flushed"
  448. )
  449. else:
  450. self.connected = False
  451. self.disconnected = True
  452. def abortConnection(self):
  453. logger.info("FakeTransport: abortConnection()")
  454. if not self.disconnecting:
  455. self.disconnecting = True
  456. if self._protocol:
  457. self._protocol.connectionLost(None)
  458. self.disconnected = True
  459. def pauseProducing(self):
  460. if not self.producer:
  461. return
  462. self.producer.pauseProducing()
  463. def resumeProducing(self):
  464. if not self.producer:
  465. return
  466. self.producer.resumeProducing()
  467. def unregisterProducer(self):
  468. if not self.producer:
  469. return
  470. self.producer = None
  471. def registerProducer(self, producer, streaming):
  472. self.producer = producer
  473. self.producerStreaming = streaming
  474. def _produce():
  475. if not self.producer:
  476. # we've been unregistered
  477. return
  478. # some implementations of IProducer (for example, FileSender)
  479. # don't return a deferred.
  480. d = maybeDeferred(self.producer.resumeProducing)
  481. d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))
  482. if not streaming:
  483. self._reactor.callLater(0.0, _produce)
  484. def write(self, byt):
  485. if self.disconnecting:
  486. raise Exception("Writing to disconnecting FakeTransport")
  487. self.buffer = self.buffer + byt
  488. # always actually do the write asynchronously. Some protocols (notably the
  489. # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
  490. # still doing a write. Doing a callLater here breaks the cycle.
  491. if self.autoflush:
  492. self._reactor.callLater(0.0, self.flush)
  493. def writeSequence(self, seq):
  494. for x in seq:
  495. self.write(x)
  496. def flush(self, maxbytes=None):
  497. if not self.buffer:
  498. # nothing to do. Don't write empty buffers: it upsets the
  499. # TLSMemoryBIOProtocol
  500. return
  501. if self.disconnected:
  502. return
  503. if maxbytes is not None:
  504. to_write = self.buffer[:maxbytes]
  505. else:
  506. to_write = self.buffer
  507. logger.info("%s->%s: %s", self._protocol, self.other, to_write)
  508. try:
  509. self.other.dataReceived(to_write)
  510. except Exception as e:
  511. logger.exception("Exception writing to protocol: %s", e)
  512. return
  513. self.buffer = self.buffer[len(to_write) :]
  514. if self.buffer and self.autoflush:
  515. self._reactor.callLater(0.0, self.flush)
  516. if not self.buffer and self.disconnecting:
  517. logger.info("FakeTransport: Buffer now empty, completing disconnect")
  518. self.disconnected = True
  519. def connect_client(
  520. reactor: ThreadedMemoryReactorClock, client_id: int
  521. ) -> Tuple[IProtocol, AccumulatingProtocol]:
  522. """
  523. Connect a client to a fake TCP transport.
  524. Args:
  525. reactor
  526. factory: The connecting factory to build.
  527. """
  528. factory = reactor.tcpClients.pop(client_id)[2]
  529. client = factory.buildProtocol(None)
  530. server = AccumulatingProtocol()
  531. server.makeConnection(FakeTransport(client, reactor))
  532. client.makeConnection(FakeTransport(server, reactor))
  533. return client, server