1
0

server.py 32 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027
  1. # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import hashlib
  15. import json
  16. import logging
  17. import os
  18. import os.path
  19. import time
  20. import uuid
  21. import warnings
  22. from collections import deque
  23. from io import SEEK_END, BytesIO
  24. from typing import (
  25. Any,
  26. Awaitable,
  27. Callable,
  28. Dict,
  29. Iterable,
  30. List,
  31. MutableMapping,
  32. Optional,
  33. Sequence,
  34. Tuple,
  35. Type,
  36. TypeVar,
  37. Union,
  38. cast,
  39. )
  40. from unittest.mock import Mock
  41. import attr
  42. from typing_extensions import Deque, ParamSpec
  43. from zope.interface import implementer
  44. from twisted.internet import address, threads, udp
  45. from twisted.internet._resolver import SimpleResolverComplexifier
  46. from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
  47. from twisted.internet.error import DNSLookupError
  48. from twisted.internet.interfaces import (
  49. IAddress,
  50. IConnector,
  51. IConsumer,
  52. IHostnameResolver,
  53. IProducer,
  54. IProtocol,
  55. IPullProducer,
  56. IPushProducer,
  57. IReactorPluggableNameResolver,
  58. IReactorTime,
  59. IResolverSimple,
  60. ITransport,
  61. )
  62. from twisted.internet.protocol import ClientFactory, DatagramProtocol
  63. from twisted.python import threadpool
  64. from twisted.python.failure import Failure
  65. from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
  66. from twisted.web.http_headers import Headers
  67. from twisted.web.resource import IResource
  68. from twisted.web.server import Request, Site
  69. from synapse.config.database import DatabaseConnectionConfig
  70. from synapse.config.homeserver import HomeServerConfig
  71. from synapse.events.presence_router import load_legacy_presence_router
  72. from synapse.events.spamcheck import load_legacy_spam_checkers
  73. from synapse.events.third_party_rules import load_legacy_third_party_event_rules
  74. from synapse.handlers.auth import load_legacy_password_auth_providers
  75. from synapse.http.site import SynapseRequest
  76. from synapse.logging.context import ContextResourceUsage
  77. from synapse.server import HomeServer
  78. from synapse.storage import DataStore
  79. from synapse.storage.engines import PostgresEngine, create_engine
  80. from synapse.types import ISynapseReactor, JsonDict
  81. from synapse.util import Clock
  82. from tests.utils import (
  83. LEAVE_DB,
  84. POSTGRES_BASE_DB,
  85. POSTGRES_HOST,
  86. POSTGRES_PASSWORD,
  87. POSTGRES_PORT,
  88. POSTGRES_USER,
  89. SQLITE_PERSIST_DB,
  90. USE_POSTGRES_FOR_TESTS,
  91. MockClock,
  92. default_config,
  93. )
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Generic return type and parameter specification, used below to annotate
# helpers that wrap arbitrary callables (see ThreadPool.callInThreadWithCallback).
R = TypeVar("R")
P = ParamSpec("P")

# the type of thing that can be passed into `make_request` in the headers list:
# a (name, value) pair, each element of which may be either str or bytes.
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
  99. class TimedOutException(Exception):
  100. """
  101. A web query timed out.
  102. """
  103. @implementer(ITransport, IPushProducer, IConsumer)
  104. @attr.s(auto_attribs=True)
  105. class FakeChannel:
  106. """
  107. A fake Twisted Web Channel (the part that interfaces with the
  108. wire).
  109. See twisted.web.http.HTTPChannel.
  110. """
  111. site: Union[Site, "FakeSite"]
  112. _reactor: MemoryReactorClock
  113. result: dict = attr.Factory(dict)
  114. _ip: str = "127.0.0.1"
  115. _producer: Optional[Union[IPullProducer, IPushProducer]] = None
  116. resource_usage: Optional[ContextResourceUsage] = None
  117. _request: Optional[Request] = None
  118. @property
  119. def request(self) -> Request:
  120. assert self._request is not None
  121. return self._request
  122. @request.setter
  123. def request(self, request: Request) -> None:
  124. assert self._request is None
  125. self._request = request
  126. @property
  127. def json_body(self) -> JsonDict:
  128. body = json.loads(self.text_body)
  129. assert isinstance(body, dict)
  130. return body
  131. @property
  132. def json_list(self) -> List[JsonDict]:
  133. body = json.loads(self.text_body)
  134. assert isinstance(body, list)
  135. return body
  136. @property
  137. def text_body(self) -> str:
  138. """The body of the result, utf-8-decoded.
  139. Raises an exception if the request has not yet completed.
  140. """
  141. if not self.is_finished():
  142. raise Exception("Request not yet completed")
  143. return self.result["body"].decode("utf8")
  144. def is_finished(self) -> bool:
  145. """check if the response has been completely received"""
  146. return self.result.get("done", False)
  147. @property
  148. def code(self) -> int:
  149. if not self.result:
  150. raise Exception("No result yet.")
  151. return int(self.result["code"])
  152. @property
  153. def headers(self) -> Headers:
  154. if not self.result:
  155. raise Exception("No result yet.")
  156. h = Headers()
  157. for i in self.result["headers"]:
  158. h.addRawHeader(*i)
  159. return h
  160. def writeHeaders(
  161. self, version: bytes, code: bytes, reason: bytes, headers: Headers
  162. ) -> None:
  163. self.result["version"] = version
  164. self.result["code"] = code
  165. self.result["reason"] = reason
  166. self.result["headers"] = headers
  167. def write(self, data: bytes) -> None:
  168. assert isinstance(data, bytes), "Should be bytes! " + repr(data)
  169. if "body" not in self.result:
  170. self.result["body"] = b""
  171. self.result["body"] += data
  172. def writeSequence(self, data: Iterable[bytes]) -> None:
  173. for x in data:
  174. self.write(x)
  175. def loseConnection(self) -> None:
  176. self.unregisterProducer()
  177. self.transport.loseConnection()
  178. # Type ignore: mypy doesn't like the fact that producer isn't an IProducer.
  179. def registerProducer(self, producer: IProducer, streaming: bool) -> None:
  180. # TODO This should ensure that the IProducer is an IPushProducer or
  181. # IPullProducer, unfortunately twisted.protocols.basic.FileSender does
  182. # implement those, but doesn't declare it.
  183. self._producer = cast(Union[IPushProducer, IPullProducer], producer)
  184. self.producerStreaming = streaming
  185. def _produce() -> None:
  186. if self._producer:
  187. self._producer.resumeProducing()
  188. self._reactor.callLater(0.1, _produce)
  189. if not streaming:
  190. self._reactor.callLater(0.0, _produce)
  191. def unregisterProducer(self) -> None:
  192. if self._producer is None:
  193. return
  194. self._producer = None
  195. def stopProducing(self) -> None:
  196. if self._producer is not None:
  197. self._producer.stopProducing()
  198. def pauseProducing(self) -> None:
  199. raise NotImplementedError()
  200. def resumeProducing(self) -> None:
  201. raise NotImplementedError()
  202. def requestDone(self, _self: Request) -> None:
  203. self.result["done"] = True
  204. if isinstance(_self, SynapseRequest):
  205. assert _self.logcontext is not None
  206. self.resource_usage = _self.logcontext.get_resource_usage()
  207. def getPeer(self) -> IAddress:
  208. # We give an address so that getClientAddress/getClientIP returns a non null entry,
  209. # causing us to record the MAU
  210. return address.IPv4Address("TCP", self._ip, 3423)
  211. def getHost(self) -> IAddress:
  212. # this is called by Request.__init__ to configure Request.host.
  213. return address.IPv4Address("TCP", "127.0.0.1", 8888)
  214. def isSecure(self) -> bool:
  215. return False
  216. @property
  217. def transport(self) -> "FakeChannel":
  218. return self
  219. def await_result(self, timeout_ms: int = 1000) -> None:
  220. """
  221. Wait until the request is finished.
  222. """
  223. end_time = self._reactor.seconds() + timeout_ms / 1000.0
  224. self._reactor.run()
  225. while not self.is_finished():
  226. # If there's a producer, tell it to resume producing so we get content
  227. if self._producer:
  228. self._producer.resumeProducing()
  229. if self._reactor.seconds() > end_time:
  230. raise TimedOutException("Timed out waiting for request to finish.")
  231. self._reactor.advance(0.1)
  232. def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
  233. """Process the contents of any Set-Cookie headers in the response
  234. Any cookines found are added to the given dict
  235. """
  236. headers = self.headers.getRawHeaders("Set-Cookie")
  237. if not headers:
  238. return
  239. for h in headers:
  240. parts = h.split(";")
  241. k, v = parts[0].split("=", maxsplit=1)
  242. cookies[k] = v
  243. class FakeSite:
  244. """
  245. A fake Twisted Web Site, with mocks of the extra things that
  246. Synapse adds.
  247. """
  248. server_version_string = b"1"
  249. site_tag = "test"
  250. access_logger = logging.getLogger("synapse.access.http.fake")
  251. def __init__(
  252. self,
  253. resource: IResource,
  254. reactor: IReactorTime,
  255. experimental_cors_msc3886: bool = False,
  256. ):
  257. """
  258. Args:
  259. resource: the resource to be used for rendering all requests
  260. """
  261. self._resource = resource
  262. self.reactor = reactor
  263. self.experimental_cors_msc3886 = experimental_cors_msc3886
  264. def getResourceFor(self, request: Request) -> IResource:
  265. return self._resource
  266. def make_request(
  267. reactor: MemoryReactorClock,
  268. site: Union[Site, FakeSite],
  269. method: Union[bytes, str],
  270. path: Union[bytes, str],
  271. content: Union[bytes, str, JsonDict] = b"",
  272. access_token: Optional[str] = None,
  273. request: Type[Request] = SynapseRequest,
  274. shorthand: bool = True,
  275. federation_auth_origin: Optional[bytes] = None,
  276. content_is_form: bool = False,
  277. await_result: bool = True,
  278. custom_headers: Optional[Iterable[CustomHeaderType]] = None,
  279. client_ip: str = "127.0.0.1",
  280. ) -> FakeChannel:
  281. """
  282. Make a web request using the given method, path and content, and render it
  283. Returns the fake Channel object which records the response to the request.
  284. Args:
  285. reactor:
  286. site: The twisted Site to use to render the request
  287. method: The HTTP request method ("verb").
  288. path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such).
  289. content: The body of the request. JSON-encoded, if a str of bytes.
  290. access_token: The access token to add as authorization for the request.
  291. request: The request class to create.
  292. shorthand: Whether to try and be helpful and prefix the given URL
  293. with the usual REST API path, if it doesn't contain it.
  294. federation_auth_origin: if set to not-None, we will add a fake
  295. Authorization header pretenting to be the given server name.
  296. content_is_form: Whether the content is URL encoded form data. Adds the
  297. 'Content-Type': 'application/x-www-form-urlencoded' header.
  298. await_result: whether to wait for the request to complete rendering. If true,
  299. will pump the reactor until the the renderer tells the channel the request
  300. is finished.
  301. custom_headers: (name, value) pairs to add as request headers
  302. client_ip: The IP to use as the requesting IP. Useful for testing
  303. ratelimiting.
  304. Returns:
  305. channel
  306. """
  307. if not isinstance(method, bytes):
  308. method = method.encode("ascii")
  309. if not isinstance(path, bytes):
  310. path = path.encode("ascii")
  311. # Decorate it to be the full path, if we're using shorthand
  312. if (
  313. shorthand
  314. and not path.startswith(b"/_matrix")
  315. and not path.startswith(b"/_synapse")
  316. ):
  317. if path.startswith(b"/"):
  318. path = path[1:]
  319. path = b"/_matrix/client/r0/" + path
  320. if not path.startswith(b"/"):
  321. path = b"/" + path
  322. if isinstance(content, dict):
  323. content = json.dumps(content).encode("utf8")
  324. if isinstance(content, str):
  325. content = content.encode("utf8")
  326. channel = FakeChannel(site, reactor, ip=client_ip)
  327. req = request(channel, site)
  328. channel.request = req
  329. req.content = BytesIO(content)
  330. # Twisted expects to be at the end of the content when parsing the request.
  331. req.content.seek(0, SEEK_END)
  332. # Old version of Twisted (<20.3.0) have issues with parsing x-www-form-urlencoded
  333. # bodies if the Content-Length header is missing
  334. req.requestHeaders.addRawHeader(
  335. b"Content-Length", str(len(content)).encode("ascii")
  336. )
  337. if access_token:
  338. req.requestHeaders.addRawHeader(
  339. b"Authorization", b"Bearer " + access_token.encode("ascii")
  340. )
  341. if federation_auth_origin is not None:
  342. req.requestHeaders.addRawHeader(
  343. b"Authorization",
  344. b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
  345. )
  346. if content:
  347. if content_is_form:
  348. req.requestHeaders.addRawHeader(
  349. b"Content-Type", b"application/x-www-form-urlencoded"
  350. )
  351. else:
  352. # Assume the body is JSON
  353. req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
  354. if custom_headers:
  355. for k, v in custom_headers:
  356. req.requestHeaders.addRawHeader(k, v)
  357. req.parseCookies()
  358. req.requestReceived(method, path, b"1.1")
  359. if await_result:
  360. channel.await_result()
  361. return channel
  362. # ISynapseReactor implies IReactorPluggableNameResolver, but explicitly
  363. # marking this as an implementer of the latter seems to keep mypy-zope happier.
  364. @implementer(IReactorPluggableNameResolver, ISynapseReactor)
  365. class ThreadedMemoryReactorClock(MemoryReactorClock):
  366. """
  367. A MemoryReactorClock that supports callFromThread.
  368. """
  369. def __init__(self) -> None:
  370. self.threadpool = ThreadPool(self)
  371. self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
  372. self._udp: List[udp.Port] = []
  373. self.lookups: Dict[str, str] = {}
  374. self._thread_callbacks: Deque[Callable[..., R]] = deque()
  375. lookups = self.lookups
  376. @implementer(IResolverSimple)
  377. class FakeResolver:
  378. def getHostByName(
  379. self, name: str, timeout: Optional[Sequence[int]] = None
  380. ) -> "Deferred[str]":
  381. if name not in lookups:
  382. return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
  383. return succeed(lookups[name])
  384. self.nameResolver = SimpleResolverComplexifier(FakeResolver())
  385. super().__init__()
  386. def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
  387. raise NotImplementedError()
  388. def listenUDP(
  389. self,
  390. port: int,
  391. protocol: DatagramProtocol,
  392. interface: str = "",
  393. maxPacketSize: int = 8196,
  394. ) -> udp.Port:
  395. p = udp.Port(port, protocol, interface, maxPacketSize, self)
  396. p.startListening()
  397. self._udp.append(p)
  398. return p
  399. def callFromThread(
  400. self, callable: Callable[..., Any], *args: object, **kwargs: object
  401. ) -> None:
  402. """
  403. Make the callback fire in the next reactor iteration.
  404. """
  405. cb = lambda: callable(*args, **kwargs)
  406. # it's not safe to call callLater() here, so we append the callback to a
  407. # separate queue.
  408. self._thread_callbacks.append(cb)
  409. def callInThread(
  410. self, callable: Callable[..., Any], *args: object, **kwargs: object
  411. ) -> None:
  412. raise NotImplementedError()
  413. def suggestThreadPoolSize(self, size: int) -> None:
  414. raise NotImplementedError()
  415. def getThreadPool(self) -> "threadpool.ThreadPool":
  416. # Cast to match super-class.
  417. return cast(threadpool.ThreadPool, self.threadpool)
  418. def add_tcp_client_callback(
  419. self, host: str, port: int, callback: Callable[[], None]
  420. ) -> None:
  421. """Add a callback that will be invoked when we receive a connection
  422. attempt to the given IP/port using `connectTCP`.
  423. Note that the callback gets run before we return the connection to the
  424. client, which means callbacks cannot block while waiting for writes.
  425. """
  426. self._tcp_callbacks[(host, port)] = callback
  427. def connectTCP(
  428. self,
  429. host: str,
  430. port: int,
  431. factory: ClientFactory,
  432. timeout: float = 30,
  433. bindAddress: Optional[Tuple[str, int]] = None,
  434. ) -> IConnector:
  435. """Fake L{IReactorTCP.connectTCP}."""
  436. conn = super().connectTCP(
  437. host, port, factory, timeout=timeout, bindAddress=None
  438. )
  439. callback = self._tcp_callbacks.get((host, port))
  440. if callback:
  441. callback()
  442. return conn
  443. def advance(self, amount: float) -> None:
  444. # first advance our reactor's time, and run any "callLater" callbacks that
  445. # makes ready
  446. super().advance(amount)
  447. # now run any "callFromThread" callbacks
  448. while True:
  449. try:
  450. callback = self._thread_callbacks.popleft()
  451. except IndexError:
  452. break
  453. callback()
  454. # check for more "callLater" callbacks added by the thread callback
  455. # This isn't required in a regular reactor, but it ends up meaning that
  456. # our database queries can complete in a single call to `advance` [1] which
  457. # simplifies tests.
  458. #
  459. # [1]: we replace the threadpool backing the db connection pool with a
  460. # mock ThreadPool which doesn't really use threads; but we still use
  461. # reactor.callFromThread to feed results back from the db functions to the
  462. # main thread.
  463. super().advance(0)
  464. class ThreadPool:
  465. """
  466. Threadless thread pool.
  467. See twisted.python.threadpool.ThreadPool
  468. """
  469. def __init__(self, reactor: IReactorTime):
  470. self._reactor = reactor
  471. def start(self) -> None:
  472. pass
  473. def stop(self) -> None:
  474. pass
  475. def callInThreadWithCallback(
  476. self,
  477. onResult: Callable[[bool, Union[Failure, R]], None],
  478. function: Callable[P, R],
  479. *args: P.args,
  480. **kwargs: P.kwargs,
  481. ) -> "Deferred[None]":
  482. def _(res: Any) -> None:
  483. if isinstance(res, Failure):
  484. onResult(False, res)
  485. else:
  486. onResult(True, res)
  487. d: "Deferred[None]" = Deferred()
  488. d.addCallback(lambda x: function(*args, **kwargs))
  489. d.addBoth(_)
  490. self._reactor.callLater(0, d.callback, True)
  491. return d
  492. def _make_test_homeserver_synchronous(server: HomeServer) -> None:
  493. """
  494. Make the given test homeserver's database interactions synchronous.
  495. """
  496. clock = server.get_clock()
  497. for database in server.get_datastores().databases:
  498. pool = database._db_pool
  499. def runWithConnection(
  500. func: Callable[..., R], *args: Any, **kwargs: Any
  501. ) -> Awaitable[R]:
  502. return threads.deferToThreadPool(
  503. pool._reactor,
  504. pool.threadpool,
  505. pool._runWithConnection,
  506. func,
  507. *args,
  508. **kwargs,
  509. )
  510. def runInteraction(
  511. desc: str, func: Callable[..., R], *args: Any, **kwargs: Any
  512. ) -> Awaitable[R]:
  513. return threads.deferToThreadPool(
  514. pool._reactor,
  515. pool.threadpool,
  516. pool._runInteraction,
  517. desc,
  518. func,
  519. *args,
  520. **kwargs,
  521. )
  522. pool.runWithConnection = runWithConnection # type: ignore[assignment]
  523. pool.runInteraction = runInteraction # type: ignore[assignment]
  524. # Replace the thread pool with a threadless 'thread' pool
  525. pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment]
  526. pool.running = True
  527. # We've just changed the Databases to run DB transactions on the same
  528. # thread, so we need to disable the dedicated thread behaviour.
  529. server.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
  530. def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
  531. clock = ThreadedMemoryReactorClock()
  532. hs_clock = Clock(clock)
  533. return clock, hs_clock
@implementer(ITransport)
@attr.s(cmp=False, auto_attribs=True)
class FakeTransport:
    """
    A twisted.internet.interfaces.ITransport implementation which sends all its data
    straight into an IProtocol object: it exists to connect two IProtocols together.

    To use it, instantiate it with the receiving IProtocol, and then pass it to the
    sending IProtocol's makeConnection method:

        server = HTTPChannel()
        client.makeConnection(FakeTransport(server, self.reactor))

    If you want bidirectional communication, you'll need two instances.
    """

    other: IProtocol
    """The Protocol object which will receive any data written to this transport.
    """

    _reactor: IReactorTime
    """Test reactor
    """

    _protocol: Optional[IProtocol] = None
    """The Protocol which is producing data for this transport. Optional, but if set
    will get called back for connectionLost() notifications etc.
    """

    _peer_address: IAddress = attr.Factory(
        lambda: address.IPv4Address("TCP", "127.0.0.1", 5678)
    )
    """The value to be returned by getPeer"""

    _host_address: IAddress = attr.Factory(
        lambda: address.IPv4Address("TCP", "127.0.0.1", 1234)
    )
    """The value to be returned by getHost"""

    # Connection state flags. `disconnecting` is set as soon as a disconnect is
    # requested; `disconnected` once it has completed, which is deferred while
    # `buffer` still holds unflushed data.
    disconnecting = False
    disconnected = False
    connected = True

    # Data which has been written to the transport, but not yet delivered to
    # `other`.
    buffer: bytes = b""
    producer: Optional[IPushProducer] = None
    # Whether `write` (and a partial `flush`) schedule a `flush` on the reactor.
    autoflush: bool = True

    def getPeer(self) -> IAddress:
        return self._peer_address

    def getHost(self) -> IAddress:
        return self._host_address

    def loseConnection(self) -> None:
        """Begin a disconnect; a no-op if one is already under way.

        Notifies `_protocol` (if set) straight away, but only marks the
        transport as disconnected if there is no buffered data left to flush.
        """
        if not self.disconnecting:
            logger.info("FakeTransport: loseConnection()")
            self.disconnecting = True
            if self._protocol:
                self._protocol.connectionLost(
                    Failure(RuntimeError("FakeTransport.loseConnection()"))
                )

            # if we still have data to write, delay until that is done
            if self.buffer:
                logger.info(
                    "FakeTransport: Delaying disconnect until buffer is flushed"
                )
            else:
                self.connected = False
                self.disconnected = True

    def abortConnection(self) -> None:
        """Disconnect immediately, without waiting for the buffer to be flushed."""
        logger.info("FakeTransport: abortConnection()")

        if not self.disconnecting:
            self.disconnecting = True
            if self._protocol:
                self._protocol.connectionLost(None)  # type: ignore[arg-type]

        # Set unconditionally, so any data still buffered will never be
        # delivered: `flush` returns early once `disconnected` is set.
        self.disconnected = True

    def pauseProducing(self) -> None:
        """Pass the request on to the registered producer, if there is one."""
        if not self.producer:
            return

        self.producer.pauseProducing()

    def resumeProducing(self) -> None:
        """Pass the request on to the registered producer, if there is one."""
        if not self.producer:
            return
        self.producer.resumeProducing()

    def unregisterProducer(self) -> None:
        """Forget the registered producer, if there is one."""
        if not self.producer:
            return

        self.producer = None

    def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
        """Register a producer of data for this transport.

        For a non-streaming producer, schedule a loop on the reactor which
        calls its `resumeProducing` until it is unregistered.
        """
        self.producer = producer
        self.producerStreaming = streaming

        def _produce() -> None:
            if not self.producer:
                # we've been unregistered
                return
            # some implementations of IProducer (for example, FileSender)
            # don't return a deferred.
            d = maybeDeferred(self.producer.resumeProducing)
            # Go round again 0.1s after this round of producing has succeeded.
            d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))

        if not streaming:
            self._reactor.callLater(0.0, _produce)

    def write(self, byt: bytes) -> None:
        """Buffer `byt` for delivery to `other`.

        Raises an exception if a disconnect has already been requested.
        """
        if self.disconnecting:
            raise Exception("Writing to disconnecting FakeTransport")

        self.buffer = self.buffer + byt

        # always actually do the write asynchronously. Some protocols (notably the
        # TLSMemoryBIOProtocol) get very confused if a read comes back while they are
        # still doing a write. Doing a callLater here breaks the cycle.
        if self.autoflush:
            self._reactor.callLater(0.0, self.flush)

    def writeSequence(self, seq: Iterable[bytes]) -> None:
        """Write each chunk of `seq`, in order."""
        for x in seq:
            self.write(x)

    def flush(self, maxbytes: Optional[int] = None) -> None:
        """Deliver buffered data to `other`.

        Args:
            maxbytes: if not None, deliver at most this many bytes and leave
                the rest in the buffer.
        """
        if not self.buffer:
            # nothing to do. Don't write empty buffers: it upsets the
            # TLSMemoryBIOProtocol
            return

        if self.disconnected:
            return

        if maxbytes is not None:
            to_write = self.buffer[:maxbytes]
        else:
            to_write = self.buffer

        logger.info("%s->%s: %s", self._protocol, self.other, to_write)

        try:
            self.other.dataReceived(to_write)
        except Exception as e:
            # The data stays in the buffer: we return before it is trimmed.
            logger.exception("Exception writing to protocol: %s", e)
            return

        self.buffer = self.buffer[len(to_write) :]
        if self.buffer and self.autoflush:
            self._reactor.callLater(0.0, self.flush)

        # A disconnect requested while data was still buffered completes here.
        if not self.buffer and self.disconnecting:
            logger.info("FakeTransport: Buffer now empty, completing disconnect")
            self.disconnected = True
  657. def connect_client(
  658. reactor: ThreadedMemoryReactorClock, client_id: int
  659. ) -> Tuple[IProtocol, AccumulatingProtocol]:
  660. """
  661. Connect a client to a fake TCP transport.
  662. Args:
  663. reactor
  664. factory: The connecting factory to build.
  665. """
  666. factory = reactor.tcpClients.pop(client_id)[2]
  667. client = factory.buildProtocol(None)
  668. server = AccumulatingProtocol()
  669. server.makeConnection(FakeTransport(client, reactor))
  670. client.makeConnection(FakeTransport(server, reactor))
  671. return client, server
class TestHomeServer(HomeServer):
    """The HomeServer subclass which `setup_test_homeserver` builds by default."""

    # Use the `DataStore` class imported from `synapse.storage`.
    DATASTORE_CLASS = DataStore  # type: ignore[assignment]
  674. def setup_test_homeserver(
  675. cleanup_func: Callable[[Callable[[], None]], None],
  676. name: str = "test",
  677. config: Optional[HomeServerConfig] = None,
  678. reactor: Optional[ISynapseReactor] = None,
  679. homeserver_to_use: Type[HomeServer] = TestHomeServer,
  680. **kwargs: Any,
  681. ) -> HomeServer:
  682. """
  683. Setup a homeserver suitable for running tests against. Keyword arguments
  684. are passed to the Homeserver constructor.
  685. If no datastore is supplied, one is created and given to the homeserver.
  686. Args:
  687. cleanup_func : The function used to register a cleanup routine for
  688. after the test.
  689. Calling this method directly is deprecated: you should instead derive from
  690. HomeserverTestCase.
  691. """
  692. if reactor is None:
  693. from twisted.internet import reactor as _reactor
  694. reactor = cast(ISynapseReactor, _reactor)
  695. if config is None:
  696. config = default_config(name, parse=True)
  697. config.caches.resize_all_caches()
  698. if "clock" not in kwargs:
  699. kwargs["clock"] = MockClock()
  700. if USE_POSTGRES_FOR_TESTS:
  701. test_db = "synapse_test_%s" % uuid.uuid4().hex
  702. database_config = {
  703. "name": "psycopg2",
  704. "args": {
  705. "database": test_db,
  706. "host": POSTGRES_HOST,
  707. "password": POSTGRES_PASSWORD,
  708. "user": POSTGRES_USER,
  709. "port": POSTGRES_PORT,
  710. "cp_min": 1,
  711. "cp_max": 5,
  712. },
  713. }
  714. else:
  715. if SQLITE_PERSIST_DB:
  716. # The current working directory is in _trial_temp, so this gets created within that directory.
  717. test_db_location = os.path.abspath("test.db")
  718. logger.debug("Will persist db to %s", test_db_location)
  719. # Ensure each test gets a clean database.
  720. try:
  721. os.remove(test_db_location)
  722. except FileNotFoundError:
  723. pass
  724. else:
  725. logger.debug("Removed existing DB at %s", test_db_location)
  726. else:
  727. test_db_location = ":memory:"
  728. database_config = {
  729. "name": "sqlite3",
  730. "args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
  731. }
  732. if "db_txn_limit" in kwargs:
  733. database_config["txn_limit"] = kwargs["db_txn_limit"]
  734. database = DatabaseConnectionConfig("master", database_config)
  735. config.database.databases = [database]
  736. db_engine = create_engine(database.config)
  737. # Create the database before we actually try and connect to it, based off
  738. # the template database we generate in setupdb()
  739. if isinstance(db_engine, PostgresEngine):
  740. import psycopg2.extensions
  741. db_conn = db_engine.module.connect(
  742. database=POSTGRES_BASE_DB,
  743. user=POSTGRES_USER,
  744. host=POSTGRES_HOST,
  745. port=POSTGRES_PORT,
  746. password=POSTGRES_PASSWORD,
  747. )
  748. assert isinstance(db_conn, psycopg2.extensions.connection)
  749. db_conn.autocommit = True
  750. cur = db_conn.cursor()
  751. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  752. cur.execute(
  753. "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
  754. )
  755. cur.close()
  756. db_conn.close()
  757. hs = homeserver_to_use(
  758. name,
  759. config=config,
  760. version_string="Synapse/tests",
  761. reactor=reactor,
  762. )
  763. # Install @cache_in_self attributes
  764. for key, val in kwargs.items():
  765. setattr(hs, "_" + key, val)
  766. # Mock TLS
  767. hs.tls_server_context_factory = Mock()
  768. hs.setup()
  769. if homeserver_to_use == TestHomeServer:
  770. hs.setup_background_tasks()
  771. if isinstance(db_engine, PostgresEngine):
  772. database_pool = hs.get_datastores().databases[0]
  773. # We need to do cleanup on PostgreSQL
  774. def cleanup() -> None:
  775. import psycopg2
  776. import psycopg2.extensions
  777. # Close all the db pools
  778. database_pool._db_pool.close()
  779. dropped = False
  780. # Drop the test database
  781. db_conn = db_engine.module.connect(
  782. database=POSTGRES_BASE_DB,
  783. user=POSTGRES_USER,
  784. host=POSTGRES_HOST,
  785. port=POSTGRES_PORT,
  786. password=POSTGRES_PASSWORD,
  787. )
  788. assert isinstance(db_conn, psycopg2.extensions.connection)
  789. db_conn.autocommit = True
  790. cur = db_conn.cursor()
  791. # Try a few times to drop the DB. Some things may hold on to the
  792. # database for a few more seconds due to flakiness, preventing
  793. # us from dropping it when the test is over. If we can't drop
  794. # it, warn and move on.
  795. for _ in range(5):
  796. try:
  797. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  798. db_conn.commit()
  799. dropped = True
  800. except psycopg2.OperationalError as e:
  801. warnings.warn(
  802. "Couldn't drop old db: " + str(e), category=UserWarning
  803. )
  804. time.sleep(0.5)
  805. cur.close()
  806. db_conn.close()
  807. if not dropped:
  808. warnings.warn("Failed to drop old DB.", category=UserWarning)
  809. if not LEAVE_DB:
  810. # Register the cleanup hook
  811. cleanup_func(cleanup)
  812. # bcrypt is far too slow to be doing in unit tests
  813. # Need to let the HS build an auth handler and then mess with it
  814. # because AuthHandler's constructor requires the HS, so we can't make one
  815. # beforehand and pass it in to the HS's constructor (chicken / egg)
  816. async def hash(p: str) -> str:
  817. return hashlib.md5(p.encode("utf8")).hexdigest()
  818. hs.get_auth_handler().hash = hash # type: ignore[assignment]
  819. async def validate_hash(p: str, h: str) -> bool:
  820. return hashlib.md5(p.encode("utf8")).hexdigest() == h
  821. hs.get_auth_handler().validate_hash = validate_hash # type: ignore[assignment]
  822. # Make the threadpool and database transactions synchronous for testing.
  823. _make_test_homeserver_synchronous(hs)
  824. # Load any configured modules into the homeserver
  825. module_api = hs.get_module_api()
  826. for module, module_config in hs.config.modules.loaded_modules:
  827. module(config=module_config, api=module_api)
  828. load_legacy_spam_checkers(hs)
  829. load_legacy_third_party_event_rules(hs)
  830. load_legacy_presence_router(hs)
  831. load_legacy_password_auth_providers(hs)
  832. return hs