|
@@ -14,6 +14,7 @@
|
|
|
# limitations under the License.
|
|
|
|
|
|
import abc
|
|
|
+import asyncio
|
|
|
import collections
|
|
|
import inspect
|
|
|
import itertools
|
|
@@ -25,6 +26,7 @@ from typing import (
|
|
|
Awaitable,
|
|
|
Callable,
|
|
|
Collection,
|
|
|
+ Coroutine,
|
|
|
Dict,
|
|
|
Generic,
|
|
|
Hashable,
|
|
@@ -701,27 +703,57 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
|
|
return new_deferred
|
|
|
|
|
|
|
|
|
-def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
|
|
- """Delay cancellation of a `Deferred` until it resolves.
|
|
|
+@overload
|
|
|
+def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
|
|
+ ...
|
|
|
+
|
|
|
+
|
|
|
+@overload
|
|
|
+def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]":
|
|
|
+ ...
|
|
|
+
|
|
|
+
|
|
|
+@overload
|
|
|
+def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
|
|
|
+ ...
|
|
|
+
|
|
|
+
|
|
|
+def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
|
|
|
+ """Delay cancellation of a coroutine or `Deferred` awaitable until it resolves.
|
|
|
|
|
|
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
|
|
|
- resolve with a `CancelledError` until the original `Deferred` resolves.
|
|
|
+ resolve with a `CancelledError` until the original awaitable resolves.
|
|
|
|
|
|
Args:
|
|
|
- deferred: The `Deferred` to protect against cancellation. May optionally follow
|
|
|
- the Synapse logcontext rules.
|
|
|
+ deferred: The coroutine or `Deferred` to protect against cancellation. May
|
|
|
+ optionally follow the Synapse logcontext rules.
|
|
|
|
|
|
Returns:
|
|
|
- A new `Deferred`, which will contain the result of the original `Deferred`.
|
|
|
- The new `Deferred` will not propagate cancellation through to the original.
|
|
|
- When cancelled, the new `Deferred` will wait until the original `Deferred`
|
|
|
- resolves before failing with a `CancelledError`.
|
|
|
+ A new `Deferred`, which will contain the result of the original coroutine or
|
|
|
+ `Deferred`. The new `Deferred` will not propagate cancellation through to the
|
|
|
+ original coroutine or `Deferred`.
|
|
|
|
|
|
- The new `Deferred` will follow the Synapse logcontext rules if `deferred`
|
|
|
+ When cancelled, the new `Deferred` will wait until the original coroutine or
|
|
|
+ `Deferred` resolves before failing with a `CancelledError`.
|
|
|
+
|
|
|
+ The new `Deferred` will follow the Synapse logcontext rules if `awaitable`
|
|
|
follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
|
|
|
wrapped with `make_deferred_yieldable`.
|
|
|
"""
|
|
|
|
|
|
+ # First, convert the awaitable into a `Deferred`.
|
|
|
+ if isinstance(awaitable, defer.Deferred):
|
|
|
+ deferred = awaitable
|
|
|
+ elif asyncio.iscoroutine(awaitable):
|
|
|
+ # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
|
|
|
+ # type-checking, but we'd need Twisted >= 21.2.
|
|
|
+ deferred = defer.ensureDeferred(awaitable)
|
|
|
+ else:
|
|
|
+ # We have no idea what to do with this awaitable.
|
|
|
+ # We assume it's already resolved, such as `DoneAwaitable`s or `Future`s from
|
|
|
+ # `make_awaitable`, and let the caller `await` it normally.
|
|
|
+ return awaitable
|
|
|
+
|
|
|
def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
|
|
|
# before the new deferred is cancelled, we `pause` it to stop the cancellation
|
|
|
# propagating. we then `unpause` it once the wrapped deferred completes, to
|