_base.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2019 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from typing import Any, Callable, List, Optional, Tuple
  17. import attr
  18. from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
  19. from twisted.internet.task import LoopingCall
  20. from twisted.web.http import HTTPChannel
  21. from synapse.app.generic_worker import (
  22. GenericWorkerReplicationHandler,
  23. GenericWorkerServer,
  24. )
  25. from synapse.http.server import JsonResource
  26. from synapse.http.site import SynapseRequest
  27. from synapse.replication.http import ReplicationRestResource, streams
  28. from synapse.replication.tcp.handler import ReplicationCommandHandler
  29. from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
  30. from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
  31. from synapse.server import HomeServer
  32. from synapse.util import Clock
  33. from tests import unittest
  34. from tests.server import FakeTransport, render
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  36. class BaseStreamTestCase(unittest.HomeserverTestCase):
  37. """Base class for tests of the replication streams"""
  38. servlets = [
  39. streams.register_servlets,
  40. ]
  41. def prepare(self, reactor, clock, hs):
  42. # build a replication server
  43. server_factory = ReplicationStreamProtocolFactory(hs)
  44. self.streamer = hs.get_replication_streamer()
  45. self.server = server_factory.buildProtocol(None)
  46. # Make a new HomeServer object for the worker
  47. self.reactor.lookups["testserv"] = "1.2.3.4"
  48. self.worker_hs = self.setup_test_homeserver(
  49. http_client=None,
  50. homeserverToUse=GenericWorkerServer,
  51. config=self._get_worker_hs_config(),
  52. reactor=self.reactor,
  53. )
  54. # Since we use sqlite in memory databases we need to make sure the
  55. # databases objects are the same.
  56. self.worker_hs.get_datastore().db = hs.get_datastore().db
  57. self.test_handler = self._build_replication_data_handler()
  58. self.worker_hs.replication_data_handler = self.test_handler
  59. repl_handler = ReplicationCommandHandler(self.worker_hs)
  60. self.client = ClientReplicationStreamProtocol(
  61. self.worker_hs, "client", "test", clock, repl_handler,
  62. )
  63. self._client_transport = None
  64. self._server_transport = None
  65. def _get_worker_hs_config(self) -> dict:
  66. config = self.default_config()
  67. config["worker_app"] = "synapse.app.generic_worker"
  68. config["worker_replication_host"] = "testserv"
  69. config["worker_replication_http_port"] = "8765"
  70. return config
  71. def _build_replication_data_handler(self):
  72. return TestReplicationDataHandler(self.worker_hs)
  73. def reconnect(self):
  74. if self._client_transport:
  75. self.client.close()
  76. if self._server_transport:
  77. self.server.close()
  78. self._client_transport = FakeTransport(self.server, self.reactor)
  79. self.client.makeConnection(self._client_transport)
  80. self._server_transport = FakeTransport(self.client, self.reactor)
  81. self.server.makeConnection(self._server_transport)
  82. def disconnect(self):
  83. if self._client_transport:
  84. self._client_transport = None
  85. self.client.close()
  86. if self._server_transport:
  87. self._server_transport = None
  88. self.server.close()
  89. def replicate(self):
  90. """Tell the master side of replication that something has happened, and then
  91. wait for the replication to occur.
  92. """
  93. self.streamer.on_notifier_poke()
  94. self.pump(0.1)
  95. def handle_http_replication_attempt(self) -> SynapseRequest:
  96. """Asserts that a connection attempt was made to the master HS on the
  97. HTTP replication port, then proxies it to the master HS object to be
  98. handled.
  99. Returns:
  100. The request object received by master HS.
  101. """
  102. # We should have an outbound connection attempt.
  103. clients = self.reactor.tcpClients
  104. self.assertEqual(len(clients), 1)
  105. (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
  106. self.assertEqual(host, "1.2.3.4")
  107. self.assertEqual(port, 8765)
  108. # Set up client side protocol
  109. client_protocol = client_factory.buildProtocol(None)
  110. request_factory = OneShotRequestFactory()
  111. # Set up the server side protocol
  112. channel = _PushHTTPChannel(self.reactor)
  113. channel.requestFactory = request_factory
  114. channel.site = self.site
  115. # Connect client to server and vice versa.
  116. client_to_server_transport = FakeTransport(
  117. channel, self.reactor, client_protocol
  118. )
  119. client_protocol.makeConnection(client_to_server_transport)
  120. server_to_client_transport = FakeTransport(
  121. client_protocol, self.reactor, channel
  122. )
  123. channel.makeConnection(server_to_client_transport)
  124. # The request will now be processed by `self.site` and the response
  125. # streamed back.
  126. self.reactor.advance(0)
  127. # We tear down the connection so it doesn't get reused without our
  128. # knowledge.
  129. server_to_client_transport.loseConnection()
  130. client_to_server_transport.loseConnection()
  131. return request_factory.request
  132. def assert_request_is_get_repl_stream_updates(
  133. self, request: SynapseRequest, stream_name: str
  134. ):
  135. """Asserts that the given request is a HTTP replication request for
  136. fetching updates for given stream.
  137. """
  138. self.assertRegex(
  139. request.path,
  140. br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
  141. % (stream_name.encode("ascii"),),
  142. )
  143. self.assertEqual(request.method, b"GET")
  144. class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
  145. """Base class for tests running multiple workers.
  146. Automatically handle HTTP replication requests from workers to master,
  147. unlike `BaseStreamTestCase`.
  148. """
  149. servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]
  150. def setUp(self):
  151. super().setUp()
  152. # build a replication server
  153. self.server_factory = ReplicationStreamProtocolFactory(self.hs)
  154. self.streamer = self.hs.get_replication_streamer()
  155. store = self.hs.get_datastore()
  156. self.database = store.db
  157. self.reactor.lookups["testserv"] = "1.2.3.4"
  158. self._worker_hs_to_resource = {}
  159. # When we see a connection attempt to the master replication listener we
  160. # automatically set up the connection. This is so that tests don't
  161. # manually have to go and explicitly set it up each time (plus sometimes
  162. # it is impossible to write the handling explicitly in the tests).
  163. self.reactor.add_tcp_client_callback(
  164. "1.2.3.4", 8765, self._handle_http_replication_attempt
  165. )
  166. def create_test_json_resource(self):
  167. """Overrides `HomeserverTestCase.create_test_json_resource`.
  168. """
  169. # We override this so that it automatically registers all the HTTP
  170. # replication servlets, without having to explicitly do that in all
  171. # subclassses.
  172. resource = ReplicationRestResource(self.hs)
  173. for servlet in self.servlets:
  174. servlet(self.hs, resource)
  175. return resource
  176. def make_worker_hs(
  177. self, worker_app: str, extra_config: dict = {}, **kwargs
  178. ) -> HomeServer:
  179. """Make a new worker HS instance, correctly connecting replcation
  180. stream to the master HS.
  181. Args:
  182. worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
  183. extra_config: Any extra config to use for this instances.
  184. **kwargs: Options that get passed to `self.setup_test_homeserver`,
  185. useful to e.g. pass some mocks for things like `http_client`
  186. Returns:
  187. The new worker HomeServer instance.
  188. """
  189. config = self._get_worker_hs_config()
  190. config["worker_app"] = worker_app
  191. config.update(extra_config)
  192. worker_hs = self.setup_test_homeserver(
  193. homeserverToUse=GenericWorkerServer,
  194. config=config,
  195. reactor=self.reactor,
  196. **kwargs
  197. )
  198. store = worker_hs.get_datastore()
  199. store.db._db_pool = self.database._db_pool
  200. repl_handler = ReplicationCommandHandler(worker_hs)
  201. client = ClientReplicationStreamProtocol(
  202. worker_hs, "client", "test", self.clock, repl_handler,
  203. )
  204. server = self.server_factory.buildProtocol(None)
  205. client_transport = FakeTransport(server, self.reactor)
  206. client.makeConnection(client_transport)
  207. server_transport = FakeTransport(client, self.reactor)
  208. server.makeConnection(server_transport)
  209. # Set up a resource for the worker
  210. resource = ReplicationRestResource(self.hs)
  211. for servlet in self.servlets:
  212. servlet(worker_hs, resource)
  213. self._worker_hs_to_resource[worker_hs] = resource
  214. return worker_hs
  215. def _get_worker_hs_config(self) -> dict:
  216. config = self.default_config()
  217. config["worker_replication_host"] = "testserv"
  218. config["worker_replication_http_port"] = "8765"
  219. return config
  220. def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
  221. render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
  222. def replicate(self):
  223. """Tell the master side of replication that something has happened, and then
  224. wait for the replication to occur.
  225. """
  226. self.streamer.on_notifier_poke()
  227. self.pump()
  228. def _handle_http_replication_attempt(self):
  229. """Handles a connection attempt to the master replication HTTP
  230. listener.
  231. """
  232. # We should have at least one outbound connection attempt, where the
  233. # last is one to the HTTP repication IP/port.
  234. clients = self.reactor.tcpClients
  235. self.assertGreaterEqual(len(clients), 1)
  236. (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
  237. self.assertEqual(host, "1.2.3.4")
  238. self.assertEqual(port, 8765)
  239. # Set up client side protocol
  240. client_protocol = client_factory.buildProtocol(None)
  241. request_factory = OneShotRequestFactory()
  242. # Set up the server side protocol
  243. channel = _PushHTTPChannel(self.reactor)
  244. channel.requestFactory = request_factory
  245. channel.site = self.site
  246. # Connect client to server and vice versa.
  247. client_to_server_transport = FakeTransport(
  248. channel, self.reactor, client_protocol
  249. )
  250. client_protocol.makeConnection(client_to_server_transport)
  251. server_to_client_transport = FakeTransport(
  252. client_protocol, self.reactor, channel
  253. )
  254. channel.makeConnection(server_to_client_transport)
  255. # Note: at this point we've wired everything up, but we need to return
  256. # before the data starts flowing over the connections as this is called
  257. # inside `connecTCP` before the connection has been passed back to the
  258. # code that requested the TCP connection.
  259. class TestReplicationDataHandler(GenericWorkerReplicationHandler):
  260. """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
  261. def __init__(self, hs: HomeServer):
  262. super().__init__(hs)
  263. # list of received (stream_name, token, row) tuples
  264. self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
  265. async def on_rdata(self, stream_name, instance_name, token, rows):
  266. await super().on_rdata(stream_name, instance_name, token, rows)
  267. for r in rows:
  268. self.received_rdata_rows.append((stream_name, token, r))
  269. @attr.s()
  270. class OneShotRequestFactory:
  271. """A simple request factory that generates a single `SynapseRequest` and
  272. stores it for future use. Can only be used once.
  273. """
  274. request = attr.ib(default=None)
  275. def __call__(self, *args, **kwargs):
  276. assert self.request is None
  277. self.request = SynapseRequest(*args, **kwargs)
  278. return self.request
  279. class _PushHTTPChannel(HTTPChannel):
  280. """A HTTPChannel that wraps pull producers to push producers.
  281. This is a hack to get around the fact that HTTPChannel transparently wraps a
  282. pull producer (which is what Synapse uses to reply to requests) with
  283. `_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
  284. uses the standard reactor rather than letting us use our test reactor, which
  285. makes it very hard to test.
  286. """
  287. def __init__(self, reactor: IReactorTime):
  288. super().__init__()
  289. self.reactor = reactor
  290. self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
  291. def registerProducer(self, producer, streaming):
  292. # Convert pull producers to push producer.
  293. if not streaming:
  294. self._pull_to_push_producer = _PullToPushProducer(
  295. self.reactor, producer, self
  296. )
  297. producer = self._pull_to_push_producer
  298. super().registerProducer(producer, True)
  299. def unregisterProducer(self):
  300. if self._pull_to_push_producer:
  301. # We need to manually stop the _PullToPushProducer.
  302. self._pull_to_push_producer.stop()
  303. def checkPersistence(self, request, version):
  304. """Check whether the connection can be re-used
  305. """
  306. # We hijack this to always say no for ease of wiring stuff up in
  307. # `handle_http_replication_attempt`.
  308. request.responseHeaders.setRawHeaders(b"connection", [b"close"])
  309. return False
  310. class _PullToPushProducer:
  311. """A push producer that wraps a pull producer.
  312. """
  313. def __init__(
  314. self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
  315. ):
  316. self._clock = Clock(reactor)
  317. self._producer = producer
  318. self._consumer = consumer
  319. # While running we use a looping call with a zero delay to call
  320. # resumeProducing on given producer.
  321. self._looping_call = None # type: Optional[LoopingCall]
  322. # We start writing next reactor tick.
  323. self._start_loop()
  324. def _start_loop(self):
  325. """Start the looping call to
  326. """
  327. if not self._looping_call:
  328. # Start a looping call which runs every tick.
  329. self._looping_call = self._clock.looping_call(self._run_once, 0)
  330. def stop(self):
  331. """Stops calling resumeProducing.
  332. """
  333. if self._looping_call:
  334. self._looping_call.stop()
  335. self._looping_call = None
  336. def pauseProducing(self):
  337. """Implements IPushProducer
  338. """
  339. self.stop()
  340. def resumeProducing(self):
  341. """Implements IPushProducer
  342. """
  343. self._start_loop()
  344. def stopProducing(self):
  345. """Implements IPushProducer
  346. """
  347. self.stop()
  348. self._producer.stopProducing()
  349. def _run_once(self):
  350. """Calls resumeProducing on producer once.
  351. """
  352. try:
  353. self._producer.resumeProducing()
  354. except Exception:
  355. logger.exception("Failed to call resumeProducing")
  356. try:
  357. self._consumer.unregisterProducer()
  358. except Exception:
  359. pass
  360. self.stopProducing()