test_async_helpers.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. # Copyright 2019 New Vector Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import traceback
  15. from twisted.internet import defer
  16. from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
  17. from twisted.internet.task import Clock
  18. from twisted.python.failure import Failure
  19. from synapse.logging.context import (
  20. SENTINEL_CONTEXT,
  21. LoggingContext,
  22. PreserveLoggingContext,
  23. current_context,
  24. )
  25. from synapse.util.async_helpers import (
  26. ObservableDeferred,
  27. concurrently_execute,
  28. stop_cancellation,
  29. timeout_deferred,
  30. )
  31. from tests.unittest import TestCase
  32. class ObservableDeferredTest(TestCase):
  33. def test_succeed(self):
  34. origin_d = Deferred()
  35. observable = ObservableDeferred(origin_d)
  36. observer1 = observable.observe()
  37. observer2 = observable.observe()
  38. self.assertFalse(observer1.called)
  39. self.assertFalse(observer2.called)
  40. # check the first observer is called first
  41. def check_called_first(res):
  42. self.assertFalse(observer2.called)
  43. return res
  44. observer1.addBoth(check_called_first)
  45. # store the results
  46. results = [None, None]
  47. def check_val(res, idx):
  48. results[idx] = res
  49. return res
  50. observer1.addCallback(check_val, 0)
  51. observer2.addCallback(check_val, 1)
  52. origin_d.callback(123)
  53. self.assertEqual(results[0], 123, "observer 1 callback result")
  54. self.assertEqual(results[1], 123, "observer 2 callback result")
  55. def test_failure(self):
  56. origin_d = Deferred()
  57. observable = ObservableDeferred(origin_d, consumeErrors=True)
  58. observer1 = observable.observe()
  59. observer2 = observable.observe()
  60. self.assertFalse(observer1.called)
  61. self.assertFalse(observer2.called)
  62. # check the first observer is called first
  63. def check_called_first(res):
  64. self.assertFalse(observer2.called)
  65. return res
  66. observer1.addBoth(check_called_first)
  67. # store the results
  68. results = [None, None]
  69. def check_val(res, idx):
  70. results[idx] = res
  71. return None
  72. observer1.addErrback(check_val, 0)
  73. observer2.addErrback(check_val, 1)
  74. try:
  75. raise Exception("gah!")
  76. except Exception as e:
  77. origin_d.errback(e)
  78. self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
  79. self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
  80. def test_cancellation(self):
  81. """Test that cancelling an observer does not affect other observers."""
  82. origin_d: "Deferred[int]" = Deferred()
  83. observable = ObservableDeferred(origin_d, consumeErrors=True)
  84. observer1 = observable.observe()
  85. observer2 = observable.observe()
  86. observer3 = observable.observe()
  87. self.assertFalse(observer1.called)
  88. self.assertFalse(observer2.called)
  89. self.assertFalse(observer3.called)
  90. # cancel the second observer
  91. observer2.cancel()
  92. self.assertFalse(observer1.called)
  93. self.failureResultOf(observer2, CancelledError)
  94. self.assertFalse(observer3.called)
  95. # other observers resolve as normal
  96. origin_d.callback(123)
  97. self.assertEqual(observer1.result, 123, "observer 1 callback result")
  98. self.assertEqual(observer3.result, 123, "observer 3 callback result")
  99. # additional observers resolve as normal
  100. observer4 = observable.observe()
  101. self.assertEqual(observer4.result, 123, "observer 4 callback result")
class TimeoutDeferredTest(TestCase):
    """Tests for `timeout_deferred`, using a fake `Clock` so time is advanced explicitly."""

    def setUp(self):
        # Fake clock handed to `timeout_deferred`; each test advances it with `pump`.
        self.clock = Clock()

    def test_times_out(self):
        """Basic test case that checks that the original deferred is cancelled and that
        the timing-out deferred is errbacked with a `TimeoutError`.
        """
        # Single-element list so the canceller closure can record that it ran.
        cancelled = [False]

        def canceller(_d):
            cancelled[0] = True

        non_completing_d = Deferred(canceller)
        timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)

        # Before the timeout elapses, nothing should have fired or been cancelled.
        self.assertNoResult(timing_out_d)
        self.assertFalse(cancelled[0], "deferred was cancelled prematurely")

        # Advance the fake clock by exactly the timeout.
        self.clock.pump((1.0,))

        self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
        self.failureResultOf(timing_out_d, defer.TimeoutError)

    def test_times_out_when_canceller_throws(self):
        """Test that we have successfully worked around
        https://twistedmatrix.com/trac/ticket/9534

        The wrapped deferred's canceller raises, but the timing-out deferred
        must still fail with `TimeoutError` once the timeout elapses.
        """

        def canceller(_d):
            raise Exception("can't cancel this deferred")

        non_completing_d = Deferred(canceller)
        timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)

        self.assertNoResult(timing_out_d)

        # Advance the fake clock by exactly the timeout.
        self.clock.pump((1.0,))

        self.failureResultOf(timing_out_d, defer.TimeoutError)

    def test_logcontext_is_preserved_on_cancellation(self):
        """Check logcontext handling when `timeout_deferred` cancels a deferred.

        Errbacks on both the original and the timing-out deferred must run in
        the caller's logcontext ("one"). While the wrapped work is pending, the
        current context is expected to be the sentinel. After the timeout fires,
        the caller's context must be current again.
        """
        # Set by `blocking` when its wait is interrupted by cancellation.
        blocking_was_cancelled = [False]

        @defer.inlineCallbacks
        def blocking():
            # Never resolved, so this generator can only finish via cancellation.
            non_completing_d = Deferred()
            # Wait under PreserveLoggingContext. The test below expects the
            # sentinel context to be current while this is suspended.
            with PreserveLoggingContext():
                try:
                    yield non_completing_d
                except CancelledError:
                    blocking_was_cancelled[0] = True
                    # Re-raise so the cancellation still propagates.
                    raise

        with LoggingContext("one") as context_one:
            # the errbacks should be run in the test logcontext
            def errback(res, deferred_name):
                self.assertIs(
                    current_context(),
                    context_one,
                    "errback %s run in unexpected logcontext %s"
                    % (deferred_name, current_context()),
                )
                # Pass the failure through unchanged so later checks still see it.
                return res

            original_deferred = blocking()
            original_deferred.addErrback(errback, "orig")
            timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock)
            self.assertNoResult(timing_out_d)
            # `blocking` is suspended, so we are expected to be in the sentinel context.
            self.assertIs(current_context(), SENTINEL_CONTEXT)
            timing_out_d.addErrback(errback, "timingout")

            # Fire the timeout, which should cancel `blocking`.
            self.clock.pump((1.0,))

            self.assertTrue(
                blocking_was_cancelled[0], "non-completing deferred was not cancelled"
            )
            self.failureResultOf(timing_out_d, defer.TimeoutError)
            # After the timeout has been processed, the caller's context is current again.
            self.assertIs(current_context(), context_one)
  162. class _TestException(Exception):
  163. pass
  164. class ConcurrentlyExecuteTest(TestCase):
  165. def test_limits_runners(self):
  166. """If we have more tasks than runners, we should get the limit of runners"""
  167. started = 0
  168. waiters = []
  169. processed = []
  170. async def callback(v):
  171. # when we first enter, bump the start count
  172. nonlocal started
  173. started += 1
  174. # record the fact we got an item
  175. processed.append(v)
  176. # wait for the goahead before returning
  177. d2 = Deferred()
  178. waiters.append(d2)
  179. await d2
  180. # set it going
  181. d2 = ensureDeferred(concurrently_execute(callback, [1, 2, 3, 4, 5], 3))
  182. # check we got exactly 3 processes
  183. self.assertEqual(started, 3)
  184. self.assertEqual(len(waiters), 3)
  185. # let one finish
  186. waiters.pop().callback(0)
  187. # ... which should start another
  188. self.assertEqual(started, 4)
  189. self.assertEqual(len(waiters), 3)
  190. # we still shouldn't be done
  191. self.assertNoResult(d2)
  192. # finish the job
  193. while waiters:
  194. waiters.pop().callback(0)
  195. # check everything got done
  196. self.assertEqual(started, 5)
  197. self.assertCountEqual(processed, [1, 2, 3, 4, 5])
  198. self.successResultOf(d2)
  199. def test_preserves_stacktraces(self):
  200. """Test that the stacktrace from an exception thrown in the callback is preserved"""
  201. d1 = Deferred()
  202. async def callback(v):
  203. # alas, this doesn't work at all without an await here
  204. await d1
  205. raise _TestException("bah")
  206. async def caller():
  207. try:
  208. await concurrently_execute(callback, [1], 2)
  209. except _TestException as e:
  210. tb = traceback.extract_tb(e.__traceback__)
  211. # we expect to see "caller", "concurrently_execute" and "callback".
  212. self.assertEqual(tb[0].name, "caller")
  213. self.assertEqual(tb[1].name, "concurrently_execute")
  214. self.assertEqual(tb[-1].name, "callback")
  215. else:
  216. self.fail("No exception thrown")
  217. d2 = ensureDeferred(caller())
  218. d1.callback(0)
  219. self.successResultOf(d2)
  220. def test_preserves_stacktraces_on_preformed_failure(self):
  221. """Test that the stacktrace on a Failure returned by the callback is preserved"""
  222. d1 = Deferred()
  223. f = Failure(_TestException("bah"))
  224. async def callback(v):
  225. # alas, this doesn't work at all without an await here
  226. await d1
  227. await defer.fail(f)
  228. async def caller():
  229. try:
  230. await concurrently_execute(callback, [1], 2)
  231. except _TestException as e:
  232. tb = traceback.extract_tb(e.__traceback__)
  233. # we expect to see "caller", "concurrently_execute", "callback",
  234. # and some magic from inside ensureDeferred that happens when .fail
  235. # is called.
  236. self.assertEqual(tb[0].name, "caller")
  237. self.assertEqual(tb[1].name, "concurrently_execute")
  238. self.assertEqual(tb[-2].name, "callback")
  239. else:
  240. self.fail("No exception thrown")
  241. d2 = ensureDeferred(caller())
  242. d1.callback(0)
  243. self.successResultOf(d2)
  244. class StopCancellationTests(TestCase):
  245. """Tests for the `stop_cancellation` function."""
  246. def test_succeed(self):
  247. """Test that the new `Deferred` receives the result."""
  248. deferred: "Deferred[str]" = Deferred()
  249. wrapper_deferred = stop_cancellation(deferred)
  250. # Success should propagate through.
  251. deferred.callback("success")
  252. self.assertTrue(wrapper_deferred.called)
  253. self.assertEqual("success", self.successResultOf(wrapper_deferred))
  254. def test_failure(self):
  255. """Test that the new `Deferred` receives the `Failure`."""
  256. deferred: "Deferred[str]" = Deferred()
  257. wrapper_deferred = stop_cancellation(deferred)
  258. # Failure should propagate through.
  259. deferred.errback(ValueError("abc"))
  260. self.assertTrue(wrapper_deferred.called)
  261. self.failureResultOf(wrapper_deferred, ValueError)
  262. self.assertIsNone(deferred.result, "`Failure` was not consumed")
  263. def test_cancellation(self):
  264. """Test that cancellation of the new `Deferred` leaves the original running."""
  265. deferred: "Deferred[str]" = Deferred()
  266. wrapper_deferred = stop_cancellation(deferred)
  267. # Cancel the new `Deferred`.
  268. wrapper_deferred.cancel()
  269. self.assertTrue(wrapper_deferred.called)
  270. self.failureResultOf(wrapper_deferred, CancelledError)
  271. self.assertFalse(
  272. deferred.called, "Original `Deferred` was unexpectedly cancelled."
  273. )
  274. # Now make the inner `Deferred` fail.
  275. # The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
  276. # in logs.
  277. deferred.errback(ValueError("abc"))
  278. self.assertIsNone(deferred.result, "`Failure` was not consumed")