Browse Source

Update `delay_cancellation` to accept any awaitable (#12468)

This will mainly be useful when dealing with module callbacks, which are
all typed as returning `Awaitable`s instead of coroutines or
`Deferred`s.

Signed-off-by: Sean Quah <seanq@element.io>
Sean Quah 2 years ago
parent
commit
a50fb411b3

+ 1 - 0
changelog.d/12468.misc

@@ -0,0 +1 @@
+Update `delay_cancellation` to accept any awaitable, rather than just `Deferred`s.

+ 1 - 2
synapse/storage/database.py

@@ -41,7 +41,6 @@ from prometheus_client import Histogram
 from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
-from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -794,7 +793,7 @@ class DatabasePool:
         # We also wait until everything above is done before releasing the
         # `CancelledError`, so that logging contexts won't get used after they have been
         # finished.
-        return await delay_cancellation(defer.ensureDeferred(_runInteraction()))
+        return await delay_cancellation(_runInteraction())
 
     async def runWithConnection(
         self,

+ 42 - 10
synapse/util/async_helpers.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import abc
+import asyncio
 import collections
 import inspect
 import itertools
@@ -25,6 +26,7 @@ from typing import (
     Awaitable,
     Callable,
     Collection,
+    Coroutine,
     Dict,
     Generic,
     Hashable,
@@ -701,27 +703,57 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
     return new_deferred
 
 
-def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
-    """Delay cancellation of a `Deferred` until it resolves.
+@overload
+def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]":
+    ...
+
+
+@overload
+def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]":
+    ...
+
+
+@overload
+def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
+    ...
+
+
+def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
+    """Delay cancellation of a coroutine or `Deferred` awaitable until it resolves.
 
     Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
-    resolve with a `CancelledError` until the original `Deferred` resolves.
+    resolve with a `CancelledError` until the original awaitable resolves.
 
     Args:
-        deferred: The `Deferred` to protect against cancellation. May optionally follow
-            the Synapse logcontext rules.
+        deferred: The coroutine or `Deferred` to protect against cancellation. May
+            optionally follow the Synapse logcontext rules.
 
     Returns:
-        A new `Deferred`, which will contain the result of the original `Deferred`.
-        The new `Deferred` will not propagate cancellation through to the original.
-        When cancelled, the new `Deferred` will wait until the original `Deferred`
-        resolves before failing with a `CancelledError`.
+        A new `Deferred`, which will contain the result of the original coroutine or
+        `Deferred`. The new `Deferred` will not propagate cancellation through to the
+        original coroutine or `Deferred`.
 
-        The new `Deferred` will follow the Synapse logcontext rules if `deferred`
+        When cancelled, the new `Deferred` will wait until the original coroutine or
+        `Deferred` resolves before failing with a `CancelledError`.
+
+        The new `Deferred` will follow the Synapse logcontext rules if `awaitable`
         follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
         wrapped with `make_deferred_yieldable`.
     """
 
+    # First, convert the awaitable into a `Deferred`.
+    if isinstance(awaitable, defer.Deferred):
+        deferred = awaitable
+    elif asyncio.iscoroutine(awaitable):
+        # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
+        # type-checking, but we'd need Twisted >= 21.2.
+        deferred = defer.ensureDeferred(awaitable)
+    else:
+        # We have no idea what to do with this awaitable.
+        # We assume it's already resolved, such as `DoneAwaitable`s or `Future`s from
+        # `make_awaitable`, and let the caller `await` it normally.
+        return awaitable
+
     def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
         # before the new deferred is cancelled, we `pause` it to stop the cancellation
         # propagating. we then `unpause` it once the wrapped deferred completes, to

+ 31 - 2
tests/util/test_async_helpers.py

@@ -382,7 +382,7 @@ class StopCancellationTests(TestCase):
 class DelayCancellationTests(TestCase):
     """Tests for the `delay_cancellation` function."""
 
-    def test_cancellation(self):
+    def test_deferred_cancellation(self):
         """Test that cancellation of the new `Deferred` waits for the original."""
         deferred: "Deferred[str]" = Deferred()
         wrapper_deferred = delay_cancellation(deferred)
@@ -403,6 +403,35 @@ class DelayCancellationTests(TestCase):
         # Now that the original `Deferred` has failed, we should get a `CancelledError`.
         self.failureResultOf(wrapper_deferred, CancelledError)
 
+    def test_coroutine_cancellation(self):
+        """Test that cancellation of the new `Deferred` waits for the original."""
+        blocking_deferred: "Deferred[None]" = Deferred()
+        completion_deferred: "Deferred[None]" = Deferred()
+
+        async def task():
+            await blocking_deferred
+            completion_deferred.callback(None)
+            # Raise an exception. Twisted should consume it, otherwise unwanted
+            # tracebacks will be printed in logs.
+            raise ValueError("abc")
+
+        wrapper_deferred = delay_cancellation(task())
+
+        # Cancel the new `Deferred`.
+        wrapper_deferred.cancel()
+        self.assertNoResult(wrapper_deferred)
+        self.assertFalse(
+            blocking_deferred.called, "Cancellation was propagated too deep"
+        )
+        self.assertFalse(completion_deferred.called)
+
+        # Unblock the task.
+        blocking_deferred.callback(None)
+        self.assertTrue(completion_deferred.called)
+
+        # Now that the original coroutine has failed, we should get a `CancelledError`.
+        self.failureResultOf(wrapper_deferred, CancelledError)
+
     def test_suppresses_second_cancellation(self):
         """Test that a second cancellation is suppressed.
 
@@ -451,7 +480,7 @@ class DelayCancellationTests(TestCase):
         async def outer():
             with LoggingContext("c") as c:
                 try:
-                    await delay_cancellation(defer.ensureDeferred(inner()))
+                    await delay_cancellation(inner())
                     self.fail("`CancelledError` was not raised")
                 except CancelledError:
                     self.assertEqual(c, current_context())