test_async_helpers.py 19 KB


  1. # Copyright 2019 New Vector Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import traceback
  15. from parameterized import parameterized_class
  16. from twisted.internet import defer
  17. from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
  18. from twisted.internet.task import Clock
  19. from twisted.python.failure import Failure
  20. from synapse.logging.context import (
  21. SENTINEL_CONTEXT,
  22. LoggingContext,
  23. PreserveLoggingContext,
  24. current_context,
  25. make_deferred_yieldable,
  26. )
  27. from synapse.util.async_helpers import (
  28. AwakenableSleeper,
  29. ObservableDeferred,
  30. concurrently_execute,
  31. delay_cancellation,
  32. stop_cancellation,
  33. timeout_deferred,
  34. )
  35. from tests.server import get_clock
  36. from tests.unittest import TestCase
  37. class ObservableDeferredTest(TestCase):
  38. def test_succeed(self):
  39. origin_d = Deferred()
  40. observable = ObservableDeferred(origin_d)
  41. observer1 = observable.observe()
  42. observer2 = observable.observe()
  43. self.assertFalse(observer1.called)
  44. self.assertFalse(observer2.called)
  45. # check the first observer is called first
  46. def check_called_first(res):
  47. self.assertFalse(observer2.called)
  48. return res
  49. observer1.addBoth(check_called_first)
  50. # store the results
  51. results = [None, None]
  52. def check_val(res, idx):
  53. results[idx] = res
  54. return res
  55. observer1.addCallback(check_val, 0)
  56. observer2.addCallback(check_val, 1)
  57. origin_d.callback(123)
  58. self.assertEqual(results[0], 123, "observer 1 callback result")
  59. self.assertEqual(results[1], 123, "observer 2 callback result")
  60. def test_failure(self):
  61. origin_d = Deferred()
  62. observable = ObservableDeferred(origin_d, consumeErrors=True)
  63. observer1 = observable.observe()
  64. observer2 = observable.observe()
  65. self.assertFalse(observer1.called)
  66. self.assertFalse(observer2.called)
  67. # check the first observer is called first
  68. def check_called_first(res):
  69. self.assertFalse(observer2.called)
  70. return res
  71. observer1.addBoth(check_called_first)
  72. # store the results
  73. results = [None, None]
  74. def check_val(res, idx):
  75. results[idx] = res
  76. return None
  77. observer1.addErrback(check_val, 0)
  78. observer2.addErrback(check_val, 1)
  79. try:
  80. raise Exception("gah!")
  81. except Exception as e:
  82. origin_d.errback(e)
  83. self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
  84. self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
  85. def test_cancellation(self):
  86. """Test that cancelling an observer does not affect other observers."""
  87. origin_d: "Deferred[int]" = Deferred()
  88. observable = ObservableDeferred(origin_d, consumeErrors=True)
  89. observer1 = observable.observe()
  90. observer2 = observable.observe()
  91. observer3 = observable.observe()
  92. self.assertFalse(observer1.called)
  93. self.assertFalse(observer2.called)
  94. self.assertFalse(observer3.called)
  95. # cancel the second observer
  96. observer2.cancel()
  97. self.assertFalse(observer1.called)
  98. self.failureResultOf(observer2, CancelledError)
  99. self.assertFalse(observer3.called)
  100. # other observers resolve as normal
  101. origin_d.callback(123)
  102. self.assertEqual(observer1.result, 123, "observer 1 callback result")
  103. self.assertEqual(observer3.result, 123, "observer 3 callback result")
  104. # additional observers resolve as normal
  105. observer4 = observable.observe()
  106. self.assertEqual(observer4.result, 123, "observer 4 callback result")
class TimeoutDeferredTest(TestCase):
    """Tests for `timeout_deferred`."""

    def setUp(self):
        # A manually-driven clock: time only advances when the test calls `pump`,
        # so the timeout fires deterministically.
        self.clock = Clock()

    def test_times_out(self):
        """Check that, once the timeout elapses, the original deferred is cancelled
        and the timing-out deferred is errbacked with a `TimeoutError`.
        """
        # Set to True by the canceller so the test can observe cancellation.
        cancelled = [False]

        def canceller(_d):
            cancelled[0] = True

        non_completing_d = Deferred(canceller)
        timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)

        self.assertNoResult(timing_out_d)
        self.assertFalse(cancelled[0], "deferred was cancelled prematurely")

        # Advance the fake clock by the full 1.0s timeout.
        self.clock.pump((1.0,))

        self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
        self.failureResultOf(timing_out_d, defer.TimeoutError)

    def test_times_out_when_canceller_throws(self):
        """Test that we have successfully worked around
        https://twistedmatrix.com/trac/ticket/9534

        The wrapped deferred's canceller raises, yet the timing-out deferred must
        still fail with `TimeoutError`.
        """

        def canceller(_d):
            raise Exception("can't cancel this deferred")

        non_completing_d = Deferred(canceller)
        timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)

        self.assertNoResult(timing_out_d)

        self.clock.pump((1.0,))

        self.failureResultOf(timing_out_d, defer.TimeoutError)

    def test_logcontext_is_preserved_on_cancellation(self):
        """Check that when the timeout cancels a deferred which is blocked inside
        `PreserveLoggingContext`, the errbacks on both the original and the
        timing-out deferred run in the test's logcontext (`context_one`).
        """
        # Set to True when `blocking()` observes its cancellation.
        blocking_was_cancelled = [False]

        @defer.inlineCallbacks
        def blocking():
            non_completing_d = Deferred()
            with PreserveLoggingContext():
                try:
                    yield non_completing_d
                except CancelledError:
                    blocking_was_cancelled[0] = True
                    raise

        with LoggingContext("one") as context_one:
            # the errbacks should be run in the test logcontext
            def errback(res, deferred_name):
                self.assertIs(
                    current_context(),
                    context_one,
                    "errback %s run in unexpected logcontext %s"
                    % (deferred_name, current_context()),
                )
                return res

            original_deferred = blocking()
            original_deferred.addErrback(errback, "orig")
            timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock)
            self.assertNoResult(timing_out_d)
            # NOTE(review): we are in the sentinel context here, presumably because
            # `blocking()` is suspended inside `PreserveLoggingContext`.
            self.assertIs(current_context(), SENTINEL_CONTEXT)
            timing_out_d.addErrback(errback, "timingout")

            # Fire the timeout, which cancels `blocking()`.
            self.clock.pump((1.0,))

            self.assertTrue(
                blocking_was_cancelled[0], "non-completing deferred was not cancelled"
            )
            self.failureResultOf(timing_out_d, defer.TimeoutError)
            # After the timeout has fired, the test logcontext is current again.
            self.assertIs(current_context(), context_one)
  167. class _TestException(Exception):
  168. pass
  169. class ConcurrentlyExecuteTest(TestCase):
  170. def test_limits_runners(self):
  171. """If we have more tasks than runners, we should get the limit of runners"""
  172. started = 0
  173. waiters = []
  174. processed = []
  175. async def callback(v):
  176. # when we first enter, bump the start count
  177. nonlocal started
  178. started += 1
  179. # record the fact we got an item
  180. processed.append(v)
  181. # wait for the goahead before returning
  182. d2 = Deferred()
  183. waiters.append(d2)
  184. await d2
  185. # set it going
  186. d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3))
  187. # check we got exactly 3 processes
  188. self.assertEqual(started, 3)
  189. self.assertEqual(len(waiters), 3)
  190. # let one finish
  191. waiters.pop().callback(0)
  192. # ... which should start another
  193. self.assertEqual(started, 4)
  194. self.assertEqual(len(waiters), 3)
  195. # we still shouldn't be done
  196. self.assertNoResult(d2)
  197. # finish the job
  198. while waiters:
  199. waiters.pop().callback(0)
  200. # check everything got done
  201. self.assertEqual(started, 5)
  202. self.assertCountEqual(processed, [1, 2, 3, 4, 5])
  203. self.successResultOf(d2)
  204. def test_preserves_stacktraces(self):
  205. """Test that the stacktrace from an exception thrown in the callback is preserved"""
  206. d1 = Deferred()
  207. async def callback(v):
  208. # alas, this doesn't work at all without an await here
  209. await d1
  210. raise _TestException("bah")
  211. async def caller():
  212. try:
  213. await concurrently_execute(callback, [1], 2)
  214. except _TestException as e:
  215. tb = traceback.extract_tb(e.__traceback__)
  216. # we expect to see "caller", "concurrently_execute" and "callback".
  217. self.assertEqual(tb[0].name, "caller")
  218. self.assertEqual(tb[1].name, "concurrently_execute")
  219. self.assertEqual(tb[-1].name, "callback")
  220. else:
  221. self.fail("No exception thrown")
  222. d2 = ensureDeferred(caller())
  223. d1.callback(0)
  224. self.successResultOf(d2)
  225. def test_preserves_stacktraces_on_preformed_failure(self):
  226. """Test that the stacktrace on a Failure returned by the callback is preserved"""
  227. d1 = Deferred()
  228. f = Failure(_TestException("bah"))
  229. async def callback(v):
  230. # alas, this doesn't work at all without an await here
  231. await d1
  232. await defer.fail(f)
  233. async def caller():
  234. try:
  235. await concurrently_execute(callback, [1], 2)
  236. except _TestException as e:
  237. tb = traceback.extract_tb(e.__traceback__)
  238. # we expect to see "caller", "concurrently_execute", "callback",
  239. # and some magic from inside ensureDeferred that happens when .fail
  240. # is called.
  241. self.assertEqual(tb[0].name, "caller")
  242. self.assertEqual(tb[1].name, "concurrently_execute")
  243. self.assertEqual(tb[-2].name, "callback")
  244. else:
  245. self.fail("No exception thrown")
  246. d2 = ensureDeferred(caller())
  247. d1.callback(0)
  248. self.successResultOf(d2)
  249. @parameterized_class(
  250. ("wrapper",),
  251. [("stop_cancellation",), ("delay_cancellation",)],
  252. )
  253. class CancellationWrapperTests(TestCase):
  254. """Common tests for the `stop_cancellation` and `delay_cancellation` functions."""
  255. wrapper: str
  256. def wrap_deferred(self, deferred: "Deferred[str]") -> "Deferred[str]":
  257. if self.wrapper == "stop_cancellation":
  258. return stop_cancellation(deferred)
  259. elif self.wrapper == "delay_cancellation":
  260. return delay_cancellation(deferred)
  261. else:
  262. raise ValueError(f"Unsupported wrapper type: {self.wrapper}")
  263. def test_succeed(self):
  264. """Test that the new `Deferred` receives the result."""
  265. deferred: "Deferred[str]" = Deferred()
  266. wrapper_deferred = self.wrap_deferred(deferred)
  267. # Success should propagate through.
  268. deferred.callback("success")
  269. self.assertTrue(wrapper_deferred.called)
  270. self.assertEqual("success", self.successResultOf(wrapper_deferred))
  271. def test_failure(self):
  272. """Test that the new `Deferred` receives the `Failure`."""
  273. deferred: "Deferred[str]" = Deferred()
  274. wrapper_deferred = self.wrap_deferred(deferred)
  275. # Failure should propagate through.
  276. deferred.errback(ValueError("abc"))
  277. self.assertTrue(wrapper_deferred.called)
  278. self.failureResultOf(wrapper_deferred, ValueError)
  279. self.assertIsNone(deferred.result, "`Failure` was not consumed")
  280. class StopCancellationTests(TestCase):
  281. """Tests for the `stop_cancellation` function."""
  282. def test_cancellation(self):
  283. """Test that cancellation of the new `Deferred` leaves the original running."""
  284. deferred: "Deferred[str]" = Deferred()
  285. wrapper_deferred = stop_cancellation(deferred)
  286. # Cancel the new `Deferred`.
  287. wrapper_deferred.cancel()
  288. self.assertTrue(wrapper_deferred.called)
  289. self.failureResultOf(wrapper_deferred, CancelledError)
  290. self.assertFalse(
  291. deferred.called, "Original `Deferred` was unexpectedly cancelled"
  292. )
  293. # Now make the original `Deferred` fail.
  294. # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
  295. # in logs.
  296. deferred.errback(ValueError("abc"))
  297. self.assertIsNone(deferred.result, "`Failure` was not consumed")
  298. class DelayCancellationTests(TestCase):
  299. """Tests for the `delay_cancellation` function."""
  300. def test_deferred_cancellation(self):
  301. """Test that cancellation of the new `Deferred` waits for the original."""
  302. deferred: "Deferred[str]" = Deferred()
  303. wrapper_deferred = delay_cancellation(deferred)
  304. # Cancel the new `Deferred`.
  305. wrapper_deferred.cancel()
  306. self.assertNoResult(wrapper_deferred)
  307. self.assertFalse(
  308. deferred.called, "Original `Deferred` was unexpectedly cancelled"
  309. )
  310. # Now make the original `Deferred` fail.
  311. # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
  312. # in logs.
  313. deferred.errback(ValueError("abc"))
  314. self.assertIsNone(deferred.result, "`Failure` was not consumed")
  315. # Now that the original `Deferred` has failed, we should get a `CancelledError`.
  316. self.failureResultOf(wrapper_deferred, CancelledError)
  317. def test_coroutine_cancellation(self):
  318. """Test that cancellation of the new `Deferred` waits for the original."""
  319. blocking_deferred: "Deferred[None]" = Deferred()
  320. completion_deferred: "Deferred[None]" = Deferred()
  321. async def task():
  322. await blocking_deferred
  323. completion_deferred.callback(None)
  324. # Raise an exception. Twisted should consume it, otherwise unwanted
  325. # tracebacks will be printed in logs.
  326. raise ValueError("abc")
  327. wrapper_deferred = delay_cancellation(task())
  328. # Cancel the new `Deferred`.
  329. wrapper_deferred.cancel()
  330. self.assertNoResult(wrapper_deferred)
  331. self.assertFalse(
  332. blocking_deferred.called, "Cancellation was propagated too deep"
  333. )
  334. self.assertFalse(completion_deferred.called)
  335. # Unblock the task.
  336. blocking_deferred.callback(None)
  337. self.assertTrue(completion_deferred.called)
  338. # Now that the original coroutine has failed, we should get a `CancelledError`.
  339. self.failureResultOf(wrapper_deferred, CancelledError)
  340. def test_suppresses_second_cancellation(self):
  341. """Test that a second cancellation is suppressed.
  342. Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
  343. """
  344. deferred: "Deferred[str]" = Deferred()
  345. wrapper_deferred = delay_cancellation(deferred)
  346. # Cancel the new `Deferred`, twice.
  347. wrapper_deferred.cancel()
  348. wrapper_deferred.cancel()
  349. self.assertNoResult(wrapper_deferred)
  350. self.assertFalse(
  351. deferred.called, "Original `Deferred` was unexpectedly cancelled"
  352. )
  353. # Now make the original `Deferred` fail.
  354. # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
  355. # in logs.
  356. deferred.errback(ValueError("abc"))
  357. self.assertIsNone(deferred.result, "`Failure` was not consumed")
  358. # Now that the original `Deferred` has failed, we should get a `CancelledError`.
  359. self.failureResultOf(wrapper_deferred, CancelledError)
  360. def test_propagates_cancelled_error(self):
  361. """Test that a `CancelledError` from the original `Deferred` gets propagated."""
  362. deferred: "Deferred[str]" = Deferred()
  363. wrapper_deferred = delay_cancellation(deferred)
  364. # Fail the original `Deferred` with a `CancelledError`.
  365. cancelled_error = CancelledError()
  366. deferred.errback(cancelled_error)
  367. # The new `Deferred` should fail with exactly the same `CancelledError`.
  368. self.assertTrue(wrapper_deferred.called)
  369. self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)
  370. def test_preserves_logcontext(self):
  371. """Test that logging contexts are preserved."""
  372. blocking_d: "Deferred[None]" = Deferred()
  373. async def inner():
  374. await make_deferred_yieldable(blocking_d)
  375. async def outer():
  376. with LoggingContext("c") as c:
  377. try:
  378. await delay_cancellation(inner())
  379. self.fail("`CancelledError` was not raised")
  380. except CancelledError:
  381. self.assertEqual(c, current_context())
  382. # Succeed with no error, unless the logging context is wrong.
  383. # Run and block inside `inner()`.
  384. d = defer.ensureDeferred(outer())
  385. self.assertEqual(SENTINEL_CONTEXT, current_context())
  386. d.cancel()
  387. # Now unblock. `outer()` will consume the `CancelledError` and check the
  388. # logging context.
  389. blocking_d.callback(None)
  390. self.successResultOf(d)
  391. class AwakenableSleeperTests(TestCase):
  392. "Tests AwakenableSleeper"
  393. def test_sleep(self):
  394. reactor, _ = get_clock()
  395. sleeper = AwakenableSleeper(reactor)
  396. d = defer.ensureDeferred(sleeper.sleep("name", 1000))
  397. reactor.pump([0.0])
  398. self.assertFalse(d.called)
  399. reactor.advance(0.5)
  400. self.assertFalse(d.called)
  401. reactor.advance(0.6)
  402. self.assertTrue(d.called)
  403. def test_explicit_wake(self):
  404. reactor, _ = get_clock()
  405. sleeper = AwakenableSleeper(reactor)
  406. d = defer.ensureDeferred(sleeper.sleep("name", 1000))
  407. reactor.pump([0.0])
  408. self.assertFalse(d.called)
  409. reactor.advance(0.5)
  410. self.assertFalse(d.called)
  411. sleeper.wake("name")
  412. self.assertTrue(d.called)
  413. reactor.advance(0.6)
  414. def test_multiple_sleepers_timeout(self):
  415. reactor, _ = get_clock()
  416. sleeper = AwakenableSleeper(reactor)
  417. d1 = defer.ensureDeferred(sleeper.sleep("name", 1000))
  418. reactor.advance(0.6)
  419. self.assertFalse(d1.called)
  420. # Add another sleeper
  421. d2 = defer.ensureDeferred(sleeper.sleep("name", 1000))
  422. # Only the first sleep should time out now.
  423. reactor.advance(0.6)
  424. self.assertTrue(d1.called)
  425. self.assertFalse(d2.called)
  426. reactor.advance(0.6)
  427. self.assertTrue(d2.called)
  428. def test_multiple_sleepers_wake(self):
  429. reactor, _ = get_clock()
  430. sleeper = AwakenableSleeper(reactor)
  431. d1 = defer.ensureDeferred(sleeper.sleep("name", 1000))
  432. reactor.advance(0.5)
  433. self.assertFalse(d1.called)
  434. # Add another sleeper
  435. d2 = defer.ensureDeferred(sleeper.sleep("name", 1000))
  436. # Neither should fire yet
  437. reactor.advance(0.3)
  438. self.assertFalse(d1.called)
  439. self.assertFalse(d2.called)
  440. # Explicitly waking both up works
  441. sleeper.wake("name")
  442. self.assertTrue(d1.called)
  443. self.assertTrue(d2.called)