test_async_helpers.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # Copyright 2019 New Vector Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from twisted.internet import defer
  15. from twisted.internet.defer import CancelledError, Deferred
  16. from twisted.internet.task import Clock
  17. from synapse.logging.context import (
  18. SENTINEL_CONTEXT,
  19. LoggingContext,
  20. PreserveLoggingContext,
  21. current_context,
  22. )
  23. from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
  24. from tests.unittest import TestCase
  25. class ObservableDeferredTest(TestCase):
  26. def test_succeed(self):
  27. origin_d = Deferred()
  28. observable = ObservableDeferred(origin_d)
  29. observer1 = observable.observe()
  30. observer2 = observable.observe()
  31. self.assertFalse(observer1.called)
  32. self.assertFalse(observer2.called)
  33. # check the first observer is called first
  34. def check_called_first(res):
  35. self.assertFalse(observer2.called)
  36. return res
  37. observer1.addBoth(check_called_first)
  38. # store the results
  39. results = [None, None]
  40. def check_val(res, idx):
  41. results[idx] = res
  42. return res
  43. observer1.addCallback(check_val, 0)
  44. observer2.addCallback(check_val, 1)
  45. origin_d.callback(123)
  46. self.assertEqual(results[0], 123, "observer 1 callback result")
  47. self.assertEqual(results[1], 123, "observer 2 callback result")
  48. def test_failure(self):
  49. origin_d = Deferred()
  50. observable = ObservableDeferred(origin_d, consumeErrors=True)
  51. observer1 = observable.observe()
  52. observer2 = observable.observe()
  53. self.assertFalse(observer1.called)
  54. self.assertFalse(observer2.called)
  55. # check the first observer is called first
  56. def check_called_first(res):
  57. self.assertFalse(observer2.called)
  58. return res
  59. observer1.addBoth(check_called_first)
  60. # store the results
  61. results = [None, None]
  62. def check_val(res, idx):
  63. results[idx] = res
  64. return None
  65. observer1.addErrback(check_val, 0)
  66. observer2.addErrback(check_val, 1)
  67. try:
  68. raise Exception("gah!")
  69. except Exception as e:
  70. origin_d.errback(e)
  71. self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result")
  72. self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
  73. class TimeoutDeferredTest(TestCase):
  74. def setUp(self):
  75. self.clock = Clock()
  76. def test_times_out(self):
  77. """Basic test case that checks that the original deferred is cancelled and that
  78. the timing-out deferred is errbacked
  79. """
  80. cancelled = [False]
  81. def canceller(_d):
  82. cancelled[0] = True
  83. non_completing_d = Deferred(canceller)
  84. timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
  85. self.assertNoResult(timing_out_d)
  86. self.assertFalse(cancelled[0], "deferred was cancelled prematurely")
  87. self.clock.pump((1.0,))
  88. self.assertTrue(cancelled[0], "deferred was not cancelled by timeout")
  89. self.failureResultOf(timing_out_d, defer.TimeoutError)
  90. def test_times_out_when_canceller_throws(self):
  91. """Test that we have successfully worked around
  92. https://twistedmatrix.com/trac/ticket/9534"""
  93. def canceller(_d):
  94. raise Exception("can't cancel this deferred")
  95. non_completing_d = Deferred(canceller)
  96. timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
  97. self.assertNoResult(timing_out_d)
  98. self.clock.pump((1.0,))
  99. self.failureResultOf(timing_out_d, defer.TimeoutError)
  100. def test_logcontext_is_preserved_on_cancellation(self):
  101. blocking_was_cancelled = [False]
  102. @defer.inlineCallbacks
  103. def blocking():
  104. non_completing_d = Deferred()
  105. with PreserveLoggingContext():
  106. try:
  107. yield non_completing_d
  108. except CancelledError:
  109. blocking_was_cancelled[0] = True
  110. raise
  111. with LoggingContext("one") as context_one:
  112. # the errbacks should be run in the test logcontext
  113. def errback(res, deferred_name):
  114. self.assertIs(
  115. current_context(),
  116. context_one,
  117. "errback %s run in unexpected logcontext %s"
  118. % (deferred_name, current_context()),
  119. )
  120. return res
  121. original_deferred = blocking()
  122. original_deferred.addErrback(errback, "orig")
  123. timing_out_d = timeout_deferred(original_deferred, 1.0, self.clock)
  124. self.assertNoResult(timing_out_d)
  125. self.assertIs(current_context(), SENTINEL_CONTEXT)
  126. timing_out_d.addErrback(errback, "timingout")
  127. self.clock.pump((1.0,))
  128. self.assertTrue(
  129. blocking_was_cancelled[0], "non-completing deferred was not cancelled"
  130. )
  131. self.failureResultOf(timing_out_d, defer.TimeoutError)
  132. self.assertIs(current_context(), context_one)