1
0

server.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916
  1. # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import hashlib
  15. import json
  16. import logging
  17. import os
  18. import os.path
  19. import time
  20. import uuid
  21. import warnings
  22. from collections import deque
  23. from io import SEEK_END, BytesIO
  24. from typing import (
  25. Callable,
  26. Dict,
  27. Iterable,
  28. List,
  29. MutableMapping,
  30. Optional,
  31. Tuple,
  32. Type,
  33. Union,
  34. )
  35. from unittest.mock import Mock
  36. import attr
  37. from typing_extensions import Deque
  38. from zope.interface import implementer
  39. from twisted.internet import address, threads, udp
  40. from twisted.internet._resolver import SimpleResolverComplexifier
  41. from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
  42. from twisted.internet.error import DNSLookupError
  43. from twisted.internet.interfaces import (
  44. IAddress,
  45. IConsumer,
  46. IHostnameResolver,
  47. IProtocol,
  48. IPullProducer,
  49. IPushProducer,
  50. IReactorPluggableNameResolver,
  51. IReactorTime,
  52. IResolverSimple,
  53. ITransport,
  54. )
  55. from twisted.python.failure import Failure
  56. from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
  57. from twisted.web.http_headers import Headers
  58. from twisted.web.resource import IResource
  59. from twisted.web.server import Request, Site
  60. from synapse.config.database import DatabaseConnectionConfig
  61. from synapse.http.site import SynapseRequest
  62. from synapse.logging.context import ContextResourceUsage
  63. from synapse.server import HomeServer
  64. from synapse.storage import DataStore
  65. from synapse.storage.engines import PostgresEngine, create_engine
  66. from synapse.types import JsonDict
  67. from synapse.util import Clock
  68. from tests.utils import (
  69. LEAVE_DB,
  70. POSTGRES_BASE_DB,
  71. POSTGRES_HOST,
  72. POSTGRES_PASSWORD,
  73. POSTGRES_PORT,
  74. POSTGRES_USER,
  75. SQLITE_PERSIST_DB,
  76. USE_POSTGRES_FOR_TESTS,
  77. MockClock,
  78. default_config,
  79. )
  logger = logging.getLogger(__name__)

  # A single custom request header as accepted by `make_request`'s
  # `custom_headers` argument: a (name, value) pair in which either element may
  # be given as str or bytes.
  _StrOrBytes = Union[str, bytes]
  CustomHeaderType = Tuple[_StrOrBytes, _StrOrBytes]
  class TimedOutException(Exception):
      """Raised when a fake web request does not finish within its time limit."""
  @implementer(IConsumer)
  @attr.s(auto_attribs=True)
  class FakeChannel:
      """
      A fake Twisted Web Channel (the part that interfaces with the
      wire).

      Instead of sending the response anywhere, everything the request writes
      is recorded in `result` so that tests can inspect it.
      """

      site: Union[Site, "FakeSite"]
      _reactor: MemoryReactorClock
      # The recorded response: writeHeaders() fills in "version", "code",
      # "reason" and "headers"; write() appends to "body"; requestDone() sets
      # "done".
      result: dict = attr.Factory(dict)
      # The address reported by getPeer() (passed to the constructor as `ip`).
      _ip: str = "127.0.0.1"
      _producer: Optional[Union[IPullProducer, IPushProducer]] = None
      # Populated by requestDone() when the request is a SynapseRequest.
      resource_usage: Optional[ContextResourceUsage] = None
      _request: Optional[Request] = None

      @property
      def request(self) -> Request:
          """The request served over this channel. It must already have been set."""
          assert self._request is not None
          return self._request

      @request.setter
      def request(self, request: Request) -> None:
          # A channel carries exactly one request, so it may only be set once.
          assert self._request is None
          self._request = request

      @property
      def json_body(self) -> JsonDict:
          """The response body parsed as a JSON object."""
          body = json.loads(self.text_body)
          assert isinstance(body, dict)
          return body

      @property
      def json_list(self) -> List[JsonDict]:
          """The response body parsed as a JSON array."""
          body = json.loads(self.text_body)
          assert isinstance(body, list)
          return body

      @property
      def text_body(self) -> str:
          """The body of the result, utf-8-decoded.

          An empty string is returned if the finished response wrote no body.

          Raises an exception if the request has not yet completed.
          """
          # `is_finished` is a method and must be called. The bound method
          # object itself is always truthy, so `not self.is_finished` would
          # never detect an unfinished request.
          if not self.is_finished():
              raise Exception("Request not yet completed")
          return self.result.get("body", b"").decode("utf8")

      def is_finished(self) -> bool:
          """check if the response has been completely received"""
          return self.result.get("done", False)

      @property
      def code(self) -> int:
          """The HTTP status code of the response."""
          if not self.result:
              raise Exception("No result yet.")
          return int(self.result["code"])

      @property
      def headers(self) -> Headers:
          """The response headers, rebuilt as a Twisted `Headers` object."""
          if not self.result:
              raise Exception("No result yet.")
          h = Headers()
          for i in self.result["headers"]:
              h.addRawHeader(*i)
          return h

      def writeHeaders(self, version, code, reason, headers):
          """Record the status line and headers of the response.

          `headers` is stored as-is. The `headers` property later unpacks each
          entry into `Headers.addRawHeader`, so it is presumably an iterable of
          (name, value) pairs.
          """
          self.result["version"] = version
          self.result["code"] = code
          self.result["reason"] = reason
          self.result["headers"] = headers

      def write(self, content: bytes) -> None:
          """Append `content` to the recorded response body."""
          assert isinstance(content, bytes), "Should be bytes! " + repr(content)

          if "body" not in self.result:
              self.result["body"] = b""

          self.result["body"] += content

      # Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
      def registerProducer(  # type: ignore[override]
          self,
          producer: Union[IPullProducer, IPushProducer],
          streaming: bool,
      ) -> None:
          """Register a producer of response data.

          A non-streaming (pull) producer is polled by calling resumeProducing()
          every 0.1s of reactor time until it is unregistered.
          """
          self._producer = producer
          self.producerStreaming = streaming

          def _produce() -> None:
              if self._producer:
                  self._producer.resumeProducing()
                  self._reactor.callLater(0.1, _produce)

          if not streaming:
              self._reactor.callLater(0.0, _produce)

      def unregisterProducer(self) -> None:
          """Forget the registered producer, which also stops any polling."""
          if self._producer is None:
              return

          self._producer = None

      def requestDone(self, _self: Request) -> None:
          """Called by the request when the response is complete."""
          self.result["done"] = True
          if isinstance(_self, SynapseRequest):
              assert _self.logcontext is not None
              self.resource_usage = _self.logcontext.get_resource_usage()

      def getPeer(self) -> IAddress:
          # We give an address so that getClientAddress/getClientIP returns a non null entry,
          # causing us to record the MAU
          return address.IPv4Address("TCP", self._ip, 3423)

      def getHost(self) -> IAddress:
          # this is called by Request.__init__ to configure Request.host.
          return address.IPv4Address("TCP", "127.0.0.1", 8888)

      def isSecure(self) -> bool:
          return False

      @property
      def transport(self) -> "FakeChannel":
          # The channel doubles as its own transport.
          return self

      def await_result(self, timeout_ms: int = 1000) -> None:
          """
          Wait until the request is finished.

          Advances the fake reactor in 0.1s steps, nudging any registered
          producer each time. Raises TimedOutException if the request has not
          finished after `timeout_ms` of reactor time.
          """
          end_time = self._reactor.seconds() + timeout_ms / 1000.0
          self._reactor.run()

          while not self.is_finished():
              # If there's a producer, tell it to resume producing so we get content
              if self._producer:
                  self._producer.resumeProducing()

              if self._reactor.seconds() > end_time:
                  raise TimedOutException("Timed out waiting for request to finish.")

              self._reactor.advance(0.1)

      def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
          """Process the contents of any Set-Cookie headers in the response

          Any cookies found are added to the given dict. Only the leading
          `name=value` part of each header is kept; attributes after the first
          ";" are ignored.
          """
          headers = self.headers.getRawHeaders("Set-Cookie")
          if not headers:
              return

          for h in headers:
              parts = h.split(";")
              k, v = parts[0].split("=", maxsplit=1)
              cookies[k] = v
  class FakeSite:
      """
      Stand-in for a Twisted Web `Site`, carrying the extra attributes that
      Synapse expects a site to have. Every request is rendered by the single
      resource given to the constructor.
      """

      server_version_string = b"1"
      site_tag = "test"
      access_logger = logging.getLogger("synapse.access.http.fake")

      def __init__(self, resource: IResource, reactor: IReactorTime):
          """
          Args:
              resource: the resource used to render every request
              reactor: the reactor, exposed as `self.reactor`
          """
          self.reactor = reactor
          self._resource = resource

      def getResourceFor(self, request):
          # Path-independent: whatever the request, hand back the one resource.
          return self._resource
  def make_request(
      reactor,
      site: Union[Site, FakeSite],
      method: Union[bytes, str],
      path: Union[bytes, str],
      content: Union[bytes, str, JsonDict] = b"",
      access_token: Optional[str] = None,
      request: Type[Request] = SynapseRequest,
      shorthand: bool = True,
      federation_auth_origin: Optional[bytes] = None,
      content_is_form: bool = False,
      await_result: bool = True,
      custom_headers: Optional[Iterable[CustomHeaderType]] = None,
      client_ip: str = "127.0.0.1",
  ) -> FakeChannel:
      """
      Make a web request using the given method, path and content, and render it.

      Returns the fake Channel object which records the response to the request.

      Args:
          reactor: the (fake) reactor that the returned channel uses when waiting
              for the result.
          site: The twisted Site to use to render the request
          method: The HTTP request method ("verb"). A str is ASCII-encoded.
          path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces
              and such). A str is ASCII-encoded.
          content: The body of the request. A dict is JSON-encoded, a str is
              UTF-8-encoded and bytes are sent unchanged. If the resulting body
              is non-empty, a Content-Type header is added.
          access_token: The access token to add as authorization for the request.
          request: The request class to create.
          shorthand: Whether to try and be helpful and prefix the given URL
              with the usual REST API path, if it doesn't contain it.
          federation_auth_origin: if set to not-None, we will add a fake
              Authorization header pretending to be the given server name.
          content_is_form: Whether the content is URL encoded form data. Adds the
              'Content-Type': 'application/x-www-form-urlencoded' header instead
              of 'application/json'.
          await_result: whether to wait for the request to complete rendering. If
              true, will pump the reactor until the renderer tells the channel the
              request is finished.
          custom_headers: (name, value) pairs to add as request headers
          client_ip: The IP to use as the requesting IP. Useful for testing
              ratelimiting.

      Returns:
          channel
      """
      method_bytes = method if isinstance(method, bytes) else method.encode("ascii")
      path_bytes = path if isinstance(path, bytes) else path.encode("ascii")

      # Expand to a full client-API path unless the caller already supplied a
      # /_matrix or /_synapse path (or opted out of shorthand).
      if shorthand and not path_bytes.startswith((b"/_matrix", b"/_synapse")):
          relative = path_bytes[1:] if path_bytes.startswith(b"/") else path_bytes
          path_bytes = b"/_matrix/client/r0/" + relative

      if not path_bytes.startswith(b"/"):
          path_bytes = b"/" + path_bytes

      body: bytes
      if isinstance(content, dict):
          body = json.dumps(content).encode("utf8")
      elif isinstance(content, str):
          body = content.encode("utf8")
      else:
          body = content

      channel = FakeChannel(site, reactor, ip=client_ip)

      req = request(channel, site)
      channel.request = req

      req.content = BytesIO(body)
      # Twisted expects to be at the end of the content when parsing the request.
      req.content.seek(0, SEEK_END)

      request_headers = req.requestHeaders

      if access_token:
          request_headers.addRawHeader(
              b"Authorization", b"Bearer " + access_token.encode("ascii")
          )

      if federation_auth_origin is not None:
          request_headers.addRawHeader(
              b"Authorization",
              b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
          )

      if body:
          # Without form encoding, assume the body is JSON.
          content_type = (
              b"application/x-www-form-urlencoded"
              if content_is_form
              else b"application/json"
          )
          request_headers.addRawHeader(b"Content-Type", content_type)

      for header_name, header_value in custom_headers or ():
          request_headers.addRawHeader(header_name, header_value)

      req.parseCookies()
      req.requestReceived(method_bytes, path_bytes, b"1.1")

      if await_result:
          channel.await_result()

      return channel
  @implementer(IReactorPluggableNameResolver)
  class ThreadedMemoryReactorClock(MemoryReactorClock):
      """
      A MemoryReactorClock that supports callFromThread.

      It also provides a threadless thread pool, a fake name resolver backed by
      the `lookups` dict, and hooks that run when `connectTCP` targets a given
      host and port.
      """

      def __init__(self):
          self.threadpool = ThreadPool(self)

          # (host, port) -> callback run when connectTCP is called for it.
          self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
          self._udp: List[udp.Port] = []
          # hostname -> address returned by the fake resolver.
          self.lookups: Dict[str, str] = {}
          # Callbacks queued by callFromThread; drained by advance().
          self._thread_callbacks: Deque[Callable[[], None]] = deque()

          lookups = self.lookups

          @implementer(IResolverSimple)
          class FakeResolver:
              def getHostByName(self, name, timeout=None):
                  if name not in lookups:
                      return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
                  return succeed(lookups[name])

          self.nameResolver = SimpleResolverComplexifier(FakeResolver())
          super().__init__()

      def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
          raise NotImplementedError()

      # NOTE(review): Twisted's reactors default maxPacketSize to 8192; 8196 looks
      # like a typo but is kept so that the public signature does not change.
      def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
          p = udp.Port(port, protocol, interface, maxPacketSize, self)
          p.startListening()
          self._udp.append(p)
          return p

      def callFromThread(self, callback, *args, **kwargs):
          """
          Make the callback fire in the next reactor iteration.
          """
          cb = lambda: callback(*args, **kwargs)
          # it's not safe to call callLater() here, so we append the callback to a
          # separate queue.
          self._thread_callbacks.append(cb)

      def getThreadPool(self):
          return self.threadpool

      def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
          """Add a callback that will be invoked when we receive a connection
          attempt to the given IP/port using `connectTCP`.

          Note that the callback gets run before we return the connection to the
          client, which means callbacks cannot block while waiting for writes.
          """
          self._tcp_callbacks[(host, port)] = callback

      def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
          """Fake L{IReactorTCP.connectTCP}.

          The attempt is recorded by the parent MemoryReactor. The caller's
          `bindAddress` is passed through rather than dropped, so it is recorded
          too. Any callback registered for (host, port) is then run.
          """
          conn = super().connectTCP(
              host, port, factory, timeout=timeout, bindAddress=bindAddress
          )

          callback = self._tcp_callbacks.get((host, port))
          if callback:
              callback()

          return conn

      def advance(self, amount):
          # first advance our reactor's time, and run any "callLater" callbacks that
          # makes ready
          super().advance(amount)

          # now run any "callFromThread" callbacks
          while True:
              try:
                  callback = self._thread_callbacks.popleft()
              except IndexError:
                  break
              callback()

              # check for more "callLater" callbacks added by the thread callback
              # This isn't required in a regular reactor, but it ends up meaning that
              # our database queries can complete in a single call to `advance` [1] which
              # simplifies tests.
              #
              # [1]: we replace the threadpool backing the db connection pool with a
              # mock ThreadPool which doesn't really use threads; but we still use
              # reactor.callFromThread to feed results back from the db functions to the
              # main thread.
              super().advance(0)
  class ThreadPool:
      """
      A thread pool look-alike that never starts any threads.

      Work submitted through callInThreadWithCallback is scheduled on the
      reactor with callLater(0, ...), so it runs on the reactor thread the next
      time the reactor is advanced.
      """

      def __init__(self, reactor):
          self._reactor = reactor

      def start(self):
          """No-op: there are no worker threads to start."""

      def stop(self):
          """No-op: there are no worker threads to stop."""

      def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
          def _report_outcome(outcome):
              # onResult(success, result_or_failure), as with Twisted's pool.
              succeeded = not isinstance(outcome, Failure)
              onResult(succeeded, outcome)

          deferred = Deferred()
          deferred.addCallback(lambda _ignored: function(*args, **kwargs))
          deferred.addBoth(_report_outcome)

          self._reactor.callLater(0, deferred.callback, True)
          return deferred
  def _make_test_homeserver_synchronous(server: HomeServer) -> None:
      """
      Make the given test homeserver's database interactions synchronous.

      Every database connection pool is switched to a threadless `ThreadPool`,
      so database work runs on the reactor thread whenever the test reactor is
      advanced.
      """
      clock = server.get_clock()

      for database in server.get_datastores().databases:
          _make_pool_synchronous(database._db_pool, clock._reactor)

      # We've just changed the Databases to run DB transactions on the same
      # thread, so we need to disable the dedicated thread behaviour.
      server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False


  def _make_pool_synchronous(pool, reactor) -> None:
      """Patch a single DB connection pool to run its work on a threadless pool.

      This is a separate function so that the replacement methods close over
      *this* pool. Closures defined directly inside a loop would all see the
      loop variable's final value, i.e. the last pool.
      """

      def runWithConnection(func, *args, **kwargs):
          return threads.deferToThreadPool(
              pool._reactor,
              pool.threadpool,
              pool._runWithConnection,
              func,
              *args,
              **kwargs,
          )

      def runInteraction(interaction, *args, **kwargs):
          return threads.deferToThreadPool(
              pool._reactor,
              pool.threadpool,
              pool._runInteraction,
              interaction,
              *args,
              **kwargs,
          )

      pool.runWithConnection = runWithConnection
      pool.runInteraction = runInteraction
      # Replace the thread pool with a threadless 'thread' pool
      pool.threadpool = ThreadPool(reactor)
      pool.running = True
  def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
      """Create a fresh fake reactor and a Synapse `Clock` driven by it.

      Returns:
          A (reactor, clock) pair.
      """
      fake_reactor = ThreadedMemoryReactorClock()
      return fake_reactor, Clock(fake_reactor)
  @implementer(ITransport)
  @attr.s(cmp=False)
  class FakeTransport:
      """
      A twisted.internet.interfaces.ITransport implementation which sends all its data
      straight into an IProtocol object: it exists to connect two IProtocols together.

      To use it, instantiate it with the receiving IProtocol, and then pass it to the
      sending IProtocol's makeConnection method:

          server = HTTPChannel()
          client.makeConnection(FakeTransport(server, self.reactor))

      If you want bidirectional communication, you'll need two instances.
      """

      # The Protocol object which will receive any data written to this
      # transport (a twisted.internet.interfaces.IProtocol).
      other = attr.ib()

      # Test reactor (a twisted.internet.interfaces.IReactorTime), used to
      # schedule flushes and producer polling.
      _reactor = attr.ib()

      # The Protocol which is producing data for this transport. Optional, but
      # if set will get called back for connectionLost() notifications etc.
      _protocol = attr.ib(default=None)

      # The value to be returned by getPeer.
      _peer_address: Optional[IAddress] = attr.ib(default=None)

      # The value to be returned by getHost.
      _host_address: Optional[IAddress] = attr.ib(default=None)

      # Connection state flags. These are plain class attributes (not attr.ib)
      # that are overridden per instance as the connection is torn down.
      disconnecting = False
      disconnected = False
      connected = True

      # Data written but not yet delivered to `other`.
      buffer = attr.ib(default=b"")
      producer = attr.ib(default=None)
      # If true, every write schedules a flush on the reactor.
      autoflush = attr.ib(default=True)

      def getPeer(self) -> Optional[IAddress]:
          return self._peer_address

      def getHost(self) -> Optional[IAddress]:
          return self._host_address

      def _finish_disconnect(self) -> None:
          """Mark the transport as fully disconnected."""
          # Keep `connected` and `disconnected` consistent on every path that
          # completes a disconnect, not only the immediate loseConnection one.
          self.connected = False
          self.disconnected = True

      def loseConnection(self, reason=None):
          """Disconnect once any buffered data has been flushed.

          The producing protocol (if any) is told about the lost connection
          immediately. If data is still buffered, the transport only becomes
          disconnected once flush() has delivered it.
          """
          if not self.disconnecting:
              logger.info("FakeTransport: loseConnection(%s)", reason)
              self.disconnecting = True
              if self._protocol:
                  self._protocol.connectionLost(reason)

              # if we still have data to write, delay until that is done
              if self.buffer:
                  logger.info(
                      "FakeTransport: Delaying disconnect until buffer is flushed"
                  )
              else:
                  self._finish_disconnect()

      def abortConnection(self):
          """Disconnect immediately; buffered data is no longer delivered."""
          logger.info("FakeTransport: abortConnection()")

          if not self.disconnecting:
              self.disconnecting = True
              if self._protocol:
                  self._protocol.connectionLost(None)

          self._finish_disconnect()

      def pauseProducing(self):
          if not self.producer:
              return

          self.producer.pauseProducing()

      def resumeProducing(self):
          if not self.producer:
              return
          self.producer.resumeProducing()

      def unregisterProducer(self):
          if not self.producer:
              return

          self.producer = None

      def registerProducer(self, producer, streaming):
          """Register a producer; a pull (non-streaming) one is polled every 0.1s."""
          self.producer = producer
          self.producerStreaming = streaming

          def _produce():
              if not self.producer:
                  # we've been unregistered
                  return
              # some implementations of IProducer (for example, FileSender)
              # don't return a deferred.
              d = maybeDeferred(self.producer.resumeProducing)
              d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))

          if not streaming:
              self._reactor.callLater(0.0, _produce)

      def write(self, byt):
          if self.disconnecting:
              raise Exception("Writing to disconnecting FakeTransport")

          self.buffer = self.buffer + byt

          # always actually do the write asynchronously. Some protocols (notably the
          # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
          # still doing a write. Doing a callLater here breaks the cycle.
          if self.autoflush:
              self._reactor.callLater(0.0, self.flush)

      def writeSequence(self, seq):
          for x in seq:
              self.write(x)

      def flush(self, maxbytes=None):
          """Deliver up to `maxbytes` (default: all) buffered bytes to `other`."""
          if not self.buffer:
              # nothing to do. Don't write empty buffers: it upsets the
              # TLSMemoryBIOProtocol
              return

          if self.disconnected:
              return

          if maxbytes is not None:
              to_write = self.buffer[:maxbytes]
          else:
              to_write = self.buffer

          logger.info("%s->%s: %s", self._protocol, self.other, to_write)

          try:
              self.other.dataReceived(to_write)
          except Exception as e:
              logger.exception("Exception writing to protocol: %s", e)
              return

          self.buffer = self.buffer[len(to_write) :]
          if self.buffer and self.autoflush:
              self._reactor.callLater(0.0, self.flush)

          if not self.buffer and self.disconnecting:
              logger.info("FakeTransport: Buffer now empty, completing disconnect")
              self._finish_disconnect()
  def connect_client(
      reactor: ThreadedMemoryReactorClock, client_id: int
  ) -> Tuple[IProtocol, AccumulatingProtocol]:
      """
      Connect a client to a fake TCP transport.

      Args:
          reactor: the fake reactor holding the pending connection attempt.
          client_id: index into `reactor.tcpClients` of the connection attempt
              to complete. That entry is removed from the list.

      Returns:
          (client, server): the protocol built by the attempt's factory, and an
          AccumulatingProtocol standing in for the remote end.
      """
      pending_attempt = reactor.tcpClients.pop(client_id)
      client_factory = pending_attempt[2]

      client = client_factory.buildProtocol(None)
      server = AccumulatingProtocol()

      # Each side gets a transport that delivers its writes to the other side.
      server.makeConnection(FakeTransport(client, reactor))
      client.makeConnection(FakeTransport(server, reactor))

      return client, server
  class TestHomeServer(HomeServer):
      """The HomeServer subclass used by default in tests, backed by `DataStore`."""

      DATASTORE_CLASS = DataStore
  def setup_test_homeserver(
      cleanup_func,
      name="test",
      config=None,
      reactor=None,
      homeserver_to_use: Type[HomeServer] = TestHomeServer,
      **kwargs,
  ):
      """
      Setup a homeserver suitable for running tests against.

      Each extra keyword argument `key=value` is installed on the homeserver as
      the attribute `_key` before `hs.setup()` runs; this is how
      `@cache_in_self` attributes are pre-populated. If no `clock` keyword is
      given, a `MockClock` is used. A `db_txn_limit` keyword is also copied into
      the database config as `txn_limit`.

      A fresh database is used for every call. With USE_POSTGRES_FOR_TESTS it is
      a copy of the template database; otherwise it is SQLite, either in memory
      or a `test.db` file when SQLITE_PERSIST_DB is set.

      Args:
          cleanup_func: The function used to register a cleanup routine for
              after the test. It is only called when running against PostgreSQL
              and LEAVE_DB is not set.
          name: the server name, also used to build the default config.
          config: a parsed config; defaults to `default_config(name, parse=True)`.
          reactor: the reactor to use; defaults to the global Twisted reactor.
          homeserver_to_use: the HomeServer class to instantiate.

      Returns:
          The set-up homeserver. Password hashing is replaced by fast MD5 and
          database interactions are made synchronous.

      Calling this method directly is deprecated: you should instead derive from
      HomeserverTestCase.
      """
      if reactor is None:
          from twisted.internet import reactor

      if config is None:
          config = default_config(name, parse=True)

      config.caches.resize_all_caches()
      config.ldap_enabled = False

      if "clock" not in kwargs:
          kwargs["clock"] = MockClock()

      if USE_POSTGRES_FOR_TESTS:
          test_db = "synapse_test_%s" % uuid.uuid4().hex

          database_config = {
              "name": "psycopg2",
              "args": {
                  "database": test_db,
                  "host": POSTGRES_HOST,
                  "password": POSTGRES_PASSWORD,
                  "user": POSTGRES_USER,
                  "port": POSTGRES_PORT,
                  "cp_min": 1,
                  "cp_max": 5,
              },
          }
      else:
          if SQLITE_PERSIST_DB:
              # The current working directory is in _trial_temp, so this gets created within that directory.
              test_db_location = os.path.abspath("test.db")
              logger.debug("Will persist db to %s", test_db_location)
              # Ensure each test gets a clean database.
              try:
                  os.remove(test_db_location)
              except FileNotFoundError:
                  pass
              else:
                  logger.debug("Removed existing DB at %s", test_db_location)
          else:
              test_db_location = ":memory:"

          database_config = {
              "name": "sqlite3",
              "args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
          }

      if "db_txn_limit" in kwargs:
          database_config["txn_limit"] = kwargs["db_txn_limit"]

      database = DatabaseConnectionConfig("master", database_config)
      config.database.databases = [database]

      db_engine = create_engine(database.config)

      # Create the database before we actually try and connect to it, based off
      # the template database we generate in setupdb()
      if isinstance(db_engine, PostgresEngine):
          # Database names cannot be bound as query parameters. `test_db` is
          # generated locally from a uuid, so interpolating it is safe.
          db_conn = _connect_to_postgres_base_db(db_engine)
          try:
              cur = db_conn.cursor()
              try:
                  cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
                  cur.execute(
                      "CREATE DATABASE %s WITH TEMPLATE %s;"
                      % (test_db, POSTGRES_BASE_DB)
                  )
              finally:
                  cur.close()
          finally:
              db_conn.close()

      hs = homeserver_to_use(
          name,
          config=config,
          version_string="Synapse/tests",
          reactor=reactor,
      )

      # Install @cache_in_self attributes
      for key, val in kwargs.items():
          setattr(hs, "_" + key, val)

      # Mock TLS
      hs.tls_server_context_factory = Mock()

      hs.setup()
      # Identity check: only the default class gets background tasks here.
      if homeserver_to_use is TestHomeServer:
          hs.setup_background_tasks()

      if isinstance(db_engine, PostgresEngine):
          main_database = hs.get_datastores().databases[0]

          # We need to do cleanup on PostgreSQL
          def cleanup():
              import psycopg2

              # Close all the db pools
              main_database._db_pool.close()

              dropped = False

              # Drop the test database
              db_conn = _connect_to_postgres_base_db(db_engine)
              cur = db_conn.cursor()
              try:
                  # Try a few times to drop the DB. Some things may hold on to the
                  # database for a few more seconds due to flakiness, preventing
                  # us from dropping it when the test is over. If we can't drop
                  # it, warn and move on.
                  for _ in range(5):
                      try:
                          cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
                          db_conn.commit()
                      except psycopg2.OperationalError as e:
                          warnings.warn(
                              "Couldn't drop old db: " + str(e), category=UserWarning
                          )
                          time.sleep(0.5)
                      else:
                          # Dropped: no need to keep retrying.
                          dropped = True
                          break
              finally:
                  cur.close()
                  db_conn.close()

              if not dropped:
                  warnings.warn("Failed to drop old DB.", category=UserWarning)

          if not LEAVE_DB:
              # Register the cleanup hook
              cleanup_func(cleanup)

      # bcrypt is far too slow to be doing in unit tests
      # Need to let the HS build an auth handler and then mess with it
      # because AuthHandler's constructor requires the HS, so we can't make one
      # beforehand and pass it in to the HS's constructor (chicken / egg)
      async def _md5_hash(p):
          return hashlib.md5(p.encode("utf8")).hexdigest()

      async def _md5_validate_hash(p, h):
          return hashlib.md5(p.encode("utf8")).hexdigest() == h

      auth_handler = hs.get_auth_handler()
      auth_handler.hash = _md5_hash
      auth_handler.validate_hash = _md5_validate_hash

      # Make the threadpool and database transactions synchronous for testing.
      _make_test_homeserver_synchronous(hs)

      return hs


  def _connect_to_postgres_base_db(db_engine):
      """Open an autocommit connection to the template (base) test database."""
      db_conn = db_engine.module.connect(
          database=POSTGRES_BASE_DB,
          user=POSTGRES_USER,
          host=POSTGRES_HOST,
          port=POSTGRES_PORT,
          password=POSTGRES_PASSWORD,
      )
      # CREATE/DROP DATABASE cannot run inside a transaction block.
      db_conn.autocommit = True
      return db_conn