- # Copyright 2014-2016 OpenMarket Ltd
- # Copyright 2018 New Vector Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import abc
- import asyncio
- import collections
- import inspect
- import itertools
- import logging
- from contextlib import asynccontextmanager
- from typing import (
- Any,
- AsyncIterator,
- Awaitable,
- Callable,
- Collection,
- Coroutine,
- Dict,
- Generic,
- Hashable,
- Iterable,
- List,
- Optional,
- Set,
- Tuple,
- TypeVar,
- Union,
- cast,
- overload,
- )
- import attr
- from typing_extensions import AsyncContextManager, Concatenate, Literal, ParamSpec
- from twisted.internet import defer
- from twisted.internet.defer import CancelledError
- from twisted.internet.interfaces import IReactorTime
- from twisted.python.failure import Failure
- from synapse.logging.context import (
- PreserveLoggingContext,
- make_deferred_yieldable,
- run_in_background,
- )
- from synapse.util import Clock
- logger = logging.getLogger(__name__)
- _T = TypeVar("_T")
- class AbstractObservableDeferred(Generic[_T], metaclass=abc.ABCMeta):
- """Abstract base class defining the consumer interface of ObservableDeferred"""
- __slots__ = ()
- @abc.abstractmethod
- def observe(self) -> "defer.Deferred[_T]":
- """Add a new observer for this ObservableDeferred
- This returns a brand new deferred that is resolved when the underlying
- deferred is resolved. Interacting with the returned deferred does not
- affect the underlying deferred.
- Note that the returned Deferred doesn't follow the Synapse logcontext rules -
- you will probably want to `make_deferred_yieldable` it.
- """
- ...
- class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
- """Wraps a deferred object so that we can add observer deferreds. These
- observer deferreds do not affect the callback chain of the original
- deferred.
- If consumeErrors is true, errors will be captured from the original deferred.
- Cancelling or otherwise resolving an observer will not affect the original
- ObservableDeferred.
- NB that it does not attempt to do anything with logcontexts; in general
- you should probably make_deferred_yieldable the deferreds
- returned by `observe`, and ensure that the original deferred runs its
- callbacks in the sentinel logcontext.
- """
- __slots__ = ["_deferred", "_observers", "_result"]
- _deferred: "defer.Deferred[_T]"
- _observers: Union[List["defer.Deferred[_T]"], Tuple[()]]
- _result: Union[None, Tuple[Literal[True], _T], Tuple[Literal[False], Failure]]
- def __init__(self, deferred: "defer.Deferred[_T]", consumeErrors: bool = False):
- object.__setattr__(self, "_deferred", deferred)
- object.__setattr__(self, "_result", None)
- object.__setattr__(self, "_observers", [])
- def callback(r: _T) -> _T:
- object.__setattr__(self, "_result", (True, r))
- # once we have set _result, no more entries will be added to _observers,
- # so it's safe to replace it with the empty tuple.
- observers = self._observers
- object.__setattr__(self, "_observers", ())
- for observer in observers:
- try:
- observer.callback(r)
- except Exception as e:
- logger.exception(
- "%r threw an exception on .callback(%r), ignoring...",
- observer,
- r,
- exc_info=e,
- )
- return r
- def errback(f: Failure) -> Optional[Failure]:
- object.__setattr__(self, "_result", (False, f))
- # once we have set _result, no more entries will be added to _observers,
- # so it's safe to replace it with the empty tuple.
- observers = self._observers
- object.__setattr__(self, "_observers", ())
- for observer in observers:
- # This is a little bit of magic to correctly propagate stack
- # traces when we `await` on one of the observer deferreds.
- f.value.__failure__ = f # type: ignore[union-attr]
- try:
- observer.errback(f)
- except Exception as e:
- logger.exception(
- "%r threw an exception on .errback(%r), ignoring...",
- observer,
- f,
- exc_info=e,
- )
- if consumeErrors:
- return None
- else:
- return f
- deferred.addCallbacks(callback, errback)
- def observe(self) -> "defer.Deferred[_T]":
- """Observe the underlying deferred.
- This returns a brand new deferred that is resolved when the underlying
- deferred is resolved. Interacting with the returned deferred does not
- affect the underlying deferred.
- """
- if not self._result:
- assert isinstance(self._observers, list)
- d: "defer.Deferred[_T]" = defer.Deferred()
- self._observers.append(d)
- return d
- elif self._result[0]:
- return defer.succeed(self._result[1])
- else:
- return defer.fail(self._result[1])
- def observers(self) -> "Collection[defer.Deferred[_T]]":
- return self._observers
- def has_called(self) -> bool:
- return self._result is not None
- def has_succeeded(self) -> bool:
- return self._result is not None and self._result[0] is True
- def get_result(self) -> Union[_T, Failure]:
- if self._result is None:
- raise ValueError(f"{self!r} has no result yet")
- return self._result[1]
- def __getattr__(self, name: str) -> Any:
- return getattr(self._deferred, name)
- def __setattr__(self, name: str, value: Any) -> None:
- setattr(self._deferred, name, value)
- def __repr__(self) -> str:
- return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
- id(self),
- self._result,
- self._deferred,
- )
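- # Example (an illustrative sketch, not part of the class API): each call to
- # `observe` returns an independent deferred, so multiple callers can wait on
- # one underlying computation without interfering with each other:
- #
- #     underlying: "defer.Deferred[int]" = defer.Deferred()
- #     observable = ObservableDeferred(underlying, consumeErrors=True)
- #     d1 = observable.observe()
- #     d2 = observable.observe()
- #     underlying.callback(42)  # resolves both d1 and d2 with 42
- #     # per the logcontext rules, callers should
- #     # `await make_deferred_yieldable(d1)` rather than awaiting d1 directly.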
- T = TypeVar("T")
- async def concurrently_execute(
- func: Callable[[T], Any],
- args: Iterable[T],
- limit: int,
- delay_cancellation: bool = False,
- ) -> None:
- """Executes the function with each argument concurrently while limiting
- the number of concurrent executions.
- Args:
- func: Function to execute, should return a deferred or coroutine.
- args: List of arguments to pass to func, each invocation of func
- gets a single argument.
- limit: Maximum number of concurrent executions.
- delay_cancellation: Whether to delay cancellation until after the invocations
- have finished.
- Returns:
- None, when all function invocations have finished. The return values
- from those functions are discarded.
- """
- it = iter(args)
- async def _concurrently_execute_inner(value: T) -> None:
- try:
- while True:
- await maybe_awaitable(func(value))
- value = next(it)
- except StopIteration:
- pass
- # We use `itertools.islice` to handle the case where the number of args is
- # less than the limit, avoiding spawning unnecessary background tasks.
- if delay_cancellation:
- await yieldable_gather_results_delaying_cancellation(
- _concurrently_execute_inner,
- (value for value in itertools.islice(it, limit)),
- )
- else:
- await yieldable_gather_results(
- _concurrently_execute_inner,
- (value for value in itertools.islice(it, limit)),
- )
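- # Example (a hedged sketch; `fetch_event` and `event_ids` are hypothetical):
- # run at most five invocations at a time, discarding the return values:
- #
- #     await concurrently_execute(fetch_event, event_ids, limit=5)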
- P = ParamSpec("P")
- R = TypeVar("R")
- async def yieldable_gather_results(
- func: Callable[Concatenate[T, P], Awaitable[R]],
- iter: Iterable[T],
- *args: P.args,
- **kwargs: P.kwargs,
- ) -> List[R]:
- """Executes the function with each argument concurrently.
- Args:
- func: Function to execute that returns a Deferred
- iter: An iterable that yields items that get passed as the first
- argument to the function
- *args: Arguments to be passed to each call to func
- **kwargs: Keyword arguments to be passed to each call to func
- Returns:
- A list containing the results of the function
- """
- try:
- return await make_deferred_yieldable(
- defer.gatherResults(
- # type-ignore: mypy reports two errors:
- # error: Argument 1 to "run_in_background" has incompatible type
- # "Callable[[T, **P], Awaitable[R]]"; expected
- # "Callable[[T, **P], Awaitable[R]]" [arg-type]
- # error: Argument 2 to "run_in_background" has incompatible type
- # "T"; expected "[T, **P.args]" [arg-type]
- # The former looks like a mypy bug, and the latter looks like a
- # false positive.
- [run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type]
- consumeErrors=True,
- )
- )
- except defer.FirstError as dfe:
- # unwrap the error from defer.gatherResults.
- # The raised exception's traceback only includes func() etc if
- # the 'await' happens before the exception is thrown - ie if the failure
- # happens *asynchronously* - otherwise Twisted throws away the traceback as it
- # could be large.
- #
- # We could maybe reconstruct a fake traceback from Failure.frames. Or maybe
- # we could throw Twisted into the fires of Mordor.
- # suppress exception chaining, because the FirstError doesn't tell us anything
- # very interesting.
- assert isinstance(dfe.subFailure.value, BaseException)
- raise dfe.subFailure.value from None
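- # Example (illustrative; `get_create_event` and `room_ids` are hypothetical):
- # run one call per room concurrently, passing an extra positional argument
- # through to each call; results come back in input order:
- #
- #     events = await yieldable_gather_results(
- #         get_create_event, room_ids, "m.room.create"
- #     )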
- async def yieldable_gather_results_delaying_cancellation(
- func: Callable[Concatenate[T, P], Awaitable[R]],
- iter: Iterable[T],
- *args: P.args,
- **kwargs: P.kwargs,
- ) -> List[R]:
- """Executes the function with each argument concurrently.
- Cancellation is delayed until after all the results have been gathered.
- See `yieldable_gather_results`.
- Args:
- func: Function to execute that returns a Deferred
- iter: An iterable that yields items that get passed as the first
- argument to the function
- *args: Arguments to be passed to each call to func
- **kwargs: Keyword arguments to be passed to each call to func
- Returns:
- A list containing the results of the function
- """
- try:
- return await make_deferred_yieldable(
- delay_cancellation(
- defer.gatherResults(
- [run_in_background(func, item, *args, **kwargs) for item in iter], # type: ignore[arg-type]
- consumeErrors=True,
- )
- )
- )
- except defer.FirstError as dfe:
- assert isinstance(dfe.subFailure.value, BaseException)
- raise dfe.subFailure.value from None
- T1 = TypeVar("T1")
- T2 = TypeVar("T2")
- T3 = TypeVar("T3")
- @overload
- def gather_results(
- deferredList: Tuple[()], consumeErrors: bool = ...
- ) -> "defer.Deferred[Tuple[()]]":
- ...
- @overload
- def gather_results(
- deferredList: Tuple["defer.Deferred[T1]"],
- consumeErrors: bool = ...,
- ) -> "defer.Deferred[Tuple[T1]]":
- ...
- @overload
- def gather_results(
- deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"],
- consumeErrors: bool = ...,
- ) -> "defer.Deferred[Tuple[T1, T2]]":
- ...
- @overload
- def gather_results(
- deferredList: Tuple[
- "defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]"
- ],
- consumeErrors: bool = ...,
- ) -> "defer.Deferred[Tuple[T1, T2, T3]]":
- ...
- def gather_results( # type: ignore[misc]
- deferredList: Tuple["defer.Deferred[T1]", ...],
- consumeErrors: bool = False,
- ) -> "defer.Deferred[Tuple[T1, ...]]":
- """Combines a tuple of `Deferred`s into a single `Deferred`.
- Wraps `defer.gatherResults` to provide type annotations that support heterogeneous
- lists of `Deferred`s.
- """
- # The `type: ignore[misc]` above suppresses
- # "Overloaded function implementation cannot produce return type of signature 1/2/3"
- deferred = defer.gatherResults(deferredList, consumeErrors=consumeErrors)
- return deferred.addCallback(tuple)
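- # Example (sketch): combine differently-typed `Deferred`s while keeping precise
- # per-element types, which a plain `defer.gatherResults` annotation cannot do:
- #
- #     d1: "defer.Deferred[int]" = defer.succeed(1)
- #     d2: "defer.Deferred[str]" = defer.succeed("two")
- #     pair = await make_deferred_yieldable(gather_results((d1, d2)))
- #     # `pair` is typed as Tuple[int, str]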
- @attr.s(slots=True, auto_attribs=True)
- class _LinearizerEntry:
- # The number of things executing.
- count: int
- # Deferreds for the things blocked from executing.
- deferreds: collections.OrderedDict
- class Linearizer:
- """Limits concurrent access to resources based on a key. Useful to ensure
- only a few things happen at a time on a given resource.
- Example:
- async with limiter.queue("test_key"):
- # do some work.
- """
- def __init__(
- self,
- name: Optional[str] = None,
- max_count: int = 1,
- clock: Optional[Clock] = None,
- ):
- """
- Args:
- max_count: The maximum number of concurrent accesses
- """
- if name is None:
- self.name: Union[str, int] = id(self)
- else:
- self.name = name
- if not clock:
- from twisted.internet import reactor
- clock = Clock(cast(IReactorTime, reactor))
- self._clock = clock
- self.max_count = max_count
- # key_to_defer is a map from the key to a _LinearizerEntry.
- self.key_to_defer: Dict[Hashable, _LinearizerEntry] = {}
- def is_queued(self, key: Hashable) -> bool:
- """Checks whether there is a process queued up waiting"""
- entry = self.key_to_defer.get(key)
- if not entry:
- # No entry so nothing is waiting.
- return False
- # There are waiting deferreds only if the OrderedDict of deferreds is
- # non-empty.
- return bool(entry.deferreds)
- def queue(self, key: Hashable) -> AsyncContextManager[None]:
- @asynccontextmanager
- async def _ctx_manager() -> AsyncIterator[None]:
- entry = await self._acquire_lock(key)
- try:
- yield
- finally:
- self._release_lock(key, entry)
- return _ctx_manager()
- async def _acquire_lock(self, key: Hashable) -> _LinearizerEntry:
- """Acquires a linearizer lock, waiting if necessary.
- Returns once we have secured the lock.
- """
- entry = self.key_to_defer.setdefault(
- key, _LinearizerEntry(0, collections.OrderedDict())
- )
- if entry.count < self.max_count:
- # The number of things executing is less than the maximum.
- logger.debug(
- "Acquired uncontended linearizer lock %r for key %r", self.name, key
- )
- entry.count += 1
- return entry
- # Otherwise, the number of things executing is at the maximum and we have to
- # add a deferred to the list of blocked items.
- # When one of the things currently executing finishes it will callback
- # this item so that it can continue executing.
- logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
- new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
- entry.deferreds[new_defer] = 1
- try:
- await new_defer
- except Exception as e:
- logger.info("defer %r got err %r", new_defer, e)
- if isinstance(e, CancelledError):
- logger.debug(
- "Cancelling wait for linearizer lock %r for key %r",
- self.name,
- key,
- )
- else:
- logger.warning(
- "Unexpected exception waiting for linearizer lock %r for key %r",
- self.name,
- key,
- )
- # we just have to take ourselves back out of the queue.
- del entry.deferreds[new_defer]
- raise
- logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
- entry.count += 1
- # if the code holding the lock completes synchronously, then it
- # will recursively run the next claimant on the list. That can
- # relatively rapidly lead to stack exhaustion. This is essentially
- # the same problem as http://twistedmatrix.com/trac/ticket/9304.
- #
- # In order to break the cycle, we add a cheeky sleep(0) here to
- # ensure that we fall back to the reactor between each iteration.
- #
- # This needs to happen while we hold the lock. We could put it on the
- # exit path, but that would slow down the uncontended case.
- try:
- await self._clock.sleep(0)
- except CancelledError:
- self._release_lock(key, entry)
- raise
- return entry
- def _release_lock(self, key: Hashable, entry: _LinearizerEntry) -> None:
- """Releases a held linearizer lock."""
- logger.debug("Releasing linearizer lock %r for key %r", self.name, key)
- # We've finished executing so check if there are any things
- # blocked waiting to execute and start one of them
- entry.count -= 1
- if entry.deferreds:
- (next_def, _) = entry.deferreds.popitem(last=False)
- # we need to run the next thing in the sentinel context.
- with PreserveLoggingContext():
- next_def.callback(None)
- elif entry.count == 0:
- # We were the last thing for this key: remove it from the
- # map.
- del self.key_to_defer[key]
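- # Example (a hedged sketch; `handle_room` and `room_id` are hypothetical):
- # allow at most two concurrent handlers per room:
- #
- #     linearizer = Linearizer(name="room_handler", max_count=2)
- #     async with linearizer.queue(room_id):
- #         await handle_room(room_id)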
- class ReadWriteLock:
- """An async read write lock.
- Example:
- async with read_write_lock.read("test_key"):
- # do some work
- """
- # IMPLEMENTATION NOTES
- #
- # We track the most recent queued reader and writer deferreds (which get
- # resolved when they release the lock).
- #
- # Read: We know it's safe to acquire a read lock when the latest writer has
- # been resolved. The new reader is appended to the list of latest readers.
- #
- # Write: We know it's safe to acquire the write lock when both the latest
- # writers and readers have been resolved. The new writer replaces the latest
- # writer.
- def __init__(self) -> None:
- # Latest readers queued
- self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}
- # Latest writer queued
- self.key_to_current_writer: Dict[str, defer.Deferred] = {}
- def read(self, key: str) -> AsyncContextManager:
- @asynccontextmanager
- async def _ctx_manager() -> AsyncIterator[None]:
- new_defer: "defer.Deferred[None]" = defer.Deferred()
- curr_readers = self.key_to_current_readers.setdefault(key, set())
- curr_writer = self.key_to_current_writer.get(key, None)
- curr_readers.add(new_defer)
- try:
- # We wait for the latest writer to finish writing. We can safely ignore
- # any existing readers... as they're readers.
- # May raise a `CancelledError` if the `Deferred` wrapping us is
- # cancelled. The `Deferred` we are waiting on must not be cancelled,
- # since we do not own it.
- if curr_writer:
- await make_deferred_yieldable(stop_cancellation(curr_writer))
- yield
- finally:
- with PreserveLoggingContext():
- new_defer.callback(None)
- self.key_to_current_readers.get(key, set()).discard(new_defer)
- return _ctx_manager()
- def write(self, key: str) -> AsyncContextManager:
- @asynccontextmanager
- async def _ctx_manager() -> AsyncIterator[None]:
- new_defer: "defer.Deferred[None]" = defer.Deferred()
- curr_readers = self.key_to_current_readers.get(key, set())
- curr_writer = self.key_to_current_writer.get(key, None)
- # We wait on all the latest readers and the latest writer.
- to_wait_on = list(curr_readers)
- if curr_writer:
- to_wait_on.append(curr_writer)
- # We can clear the list of current readers since `new_defer` waits
- # for them to finish.
- curr_readers.clear()
- self.key_to_current_writer[key] = new_defer
- to_wait_on_defer = defer.gatherResults(to_wait_on)
- try:
- # Wait for all current readers and the latest writer to finish.
- # May raise a `CancelledError` immediately after the wait if the
- # `Deferred` wrapping us is cancelled. We must only release the lock
- # once we have acquired it, hence the use of `delay_cancellation`
- # rather than `stop_cancellation`.
- await make_deferred_yieldable(delay_cancellation(to_wait_on_defer))
- yield
- finally:
- # Release the lock.
- with PreserveLoggingContext():
- new_defer.callback(None)
- # `self.key_to_current_writer[key]` may be missing if there was another
- # writer waiting for us and it completed entirely within the
- # `new_defer.callback()` call above.
- if self.key_to_current_writer.get(key) == new_defer:
- self.key_to_current_writer.pop(key)
- return _ctx_manager()
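- # Example (illustrative): readers holding the same key may overlap with each
- # other but never with a writer, and distinct keys do not contend:
- #
- #     lock = ReadWriteLock()
- #     async with lock.read("purge:room1"):
- #         ...  # many readers may hold this at once
- #     async with lock.write("purge:room1"):
- #         ...  # exclusive: waits for current readers and the previous writer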
- def timeout_deferred(
- deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime
- ) -> "defer.Deferred[_T]":
- """The in built twisted `Deferred.addTimeout` fails to time out deferreds
- that have a canceller that throws exceptions. This method creates a new
- deferred that wraps and times out the given deferred, correctly handling
- the case where the given deferred's canceller throws.
- (See https://twistedmatrix.com/trac/ticket/9534)
- NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred.
- NOTE: the TimeoutError raised by the resultant deferred is
- twisted.internet.defer.TimeoutError, which is *different* to the built-in
- TimeoutError, as well as various other TimeoutErrors you might have imported.
- Args:
- deferred: The Deferred to potentially timeout.
- timeout: Timeout in seconds
- reactor: The twisted reactor to use
- Returns:
- A new Deferred, which will errback with defer.TimeoutError on timeout.
- """
- new_d: "defer.Deferred[_T]" = defer.Deferred()
- # a one-element list, so the nested `time_it_out` can mutate the flag
- timed_out = [False]
- def time_it_out() -> None:
- timed_out[0] = True
- try:
- deferred.cancel()
- except Exception: # if we throw any exception it'll break time outs
- logger.exception("Canceller failed during timeout")
- # the cancel() call should have set off a chain of errbacks which
- # will have errbacked new_d, but in case it hasn't, errback it now.
- if not new_d.called:
- new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,)))
- delayed_call = reactor.callLater(timeout, time_it_out)
- def convert_cancelled(value: Failure) -> Failure:
- # if the original deferred was cancelled, and our timeout has fired, then
- # the reason it was cancelled was due to our timeout. Turn the CancelledError
- # into a TimeoutError.
- if timed_out[0] and value.check(CancelledError):
- raise defer.TimeoutError("Timed out after %gs" % (timeout,))
- return value
- deferred.addErrback(convert_cancelled)
- def cancel_timeout(result: _T) -> _T:
- # stop the pending call to cancel the deferred if it's been fired
- if delayed_call.active():
- delayed_call.cancel()
- return result
- deferred.addBoth(cancel_timeout)
- def success_cb(val: _T) -> None:
- if not new_d.called:
- new_d.callback(val)
- def failure_cb(val: Failure) -> None:
- if not new_d.called:
- new_d.errback(val)
- deferred.addCallbacks(success_cb, failure_cb)
- return new_d
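- # Example (a sketch; `do_lookup` is a hypothetical function returning a
- # `Deferred`): cap a slow lookup at ten seconds, remembering that the failure
- # raised is `defer.TimeoutError`, not the built-in `TimeoutError`:
- #
- #     d = timeout_deferred(do_lookup(), timeout=10.0, reactor=reactor)
- #     try:
- #         result = await make_deferred_yieldable(d)
- #     except defer.TimeoutError:
- #         ...  # the lookup took too long and has been cancelled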
- # This class can't be generic because it uses slots with attrs.
- # See: https://github.com/python-attrs/attrs/issues/313
- @attr.s(slots=True, frozen=True, auto_attribs=True)
- class DoneAwaitable: # should be: Generic[R]
- """Simple awaitable that returns the provided value."""
- value: Any # should be: R
- def __await__(self) -> Any:
- return self
- def __iter__(self) -> "DoneAwaitable":
- return self
- def __next__(self) -> None:
- raise StopIteration(self.value)
- def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
- """Convert a value to an awaitable if not already an awaitable."""
- if inspect.isawaitable(value):
- assert isinstance(value, Awaitable)
- return value
- return DoneAwaitable(value)
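- # Example (sketch; `callback` is hypothetical): lets a caller treat sync and
- # async callbacks uniformly:
- #
- #     result = await maybe_awaitable(callback(arg))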
- def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
- """Prevent a `Deferred` from being cancelled by wrapping it in another `Deferred`.
- Args:
- deferred: The `Deferred` to protect against cancellation. Must not follow the
- Synapse logcontext rules.
- Returns:
- A new `Deferred`, which will contain the result of the original `Deferred`.
- The new `Deferred` will not propagate cancellation through to the original.
- When cancelled, the new `Deferred` will fail with a `CancelledError`.
- The new `Deferred` will not follow the Synapse logcontext rules and should be
- wrapped with `make_deferred_yieldable`.
- """
- new_deferred: "defer.Deferred[T]" = defer.Deferred()
- deferred.chainDeferred(new_deferred)
- return new_deferred
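- # Example (illustrative; `fill_cache` is hypothetical): protect a `Deferred`
- # shared by many requests from being cancelled by one impatient caller:
- #
- #     shared = fill_cache()  # a `Deferred` shared by several requests
- #     result = await make_deferred_yieldable(stop_cancellation(shared))
- #     # cancelling this request's wait leaves `shared` running for the others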
- @overload
- def delay_cancellation(awaitable: "defer.Deferred[T]") -> "defer.Deferred[T]":
- ...
- @overload
- def delay_cancellation(awaitable: Coroutine[Any, Any, T]) -> "defer.Deferred[T]":
- ...
- @overload
- def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
- ...
- def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
- """Delay cancellation of a coroutine or `Deferred` awaitable until it resolves.
- Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
- resolve with a `CancelledError` until the original awaitable resolves.
- Args:
- awaitable: The coroutine or `Deferred` to protect against cancellation. May
- optionally follow the Synapse logcontext rules.
- Returns:
- A new `Deferred`, which will contain the result of the original coroutine or
- `Deferred`. The new `Deferred` will not propagate cancellation through to the
- original coroutine or `Deferred`.
- When cancelled, the new `Deferred` will wait until the original coroutine or
- `Deferred` resolves before failing with a `CancelledError`.
- The new `Deferred` will follow the Synapse logcontext rules if `awaitable`
- follows the Synapse logcontext rules. Otherwise the new `Deferred` should be
- wrapped with `make_deferred_yieldable`.
- """
- # First, convert the awaitable into a `Deferred`.
- if isinstance(awaitable, defer.Deferred):
- deferred = awaitable
- elif asyncio.iscoroutine(awaitable):
- # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
- # type-checking, but we'd need Twisted >= 21.2.
- deferred = defer.ensureDeferred(awaitable)
- else:
- # We have no idea what to do with this awaitable.
- # We assume it's already resolved, such as `DoneAwaitable`s or `Future`s from
- # `make_awaitable`, and let the caller `await` it normally.
- return awaitable
- def handle_cancel(new_deferred: "defer.Deferred[T]") -> None:
- # before the new deferred is cancelled, we `pause` it to stop the cancellation
- # propagating. we then `unpause` it once the wrapped deferred completes, to
- # propagate the exception.
- new_deferred.pause()
- new_deferred.errback(Failure(CancelledError()))
- deferred.addBoth(lambda _: new_deferred.unpause())
- new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel)
- deferred.chainDeferred(new_deferred)
- return new_deferred
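- # Example (a hedged sketch; `do_work` is a hypothetical coroutine following
- # the logcontext rules): ensure the work runs to completion even if the
- # caller is cancelled; only then is the `CancelledError` raised:
- #
- #     result = await delay_cancellation(do_work())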
- class AwakenableSleeper:
- """Allows explicitly waking up deferreds related to an entity that are
- currently sleeping.
- """
- def __init__(self, reactor: IReactorTime) -> None:
- self._streams: Dict[str, Set[defer.Deferred[None]]] = {}
- self._reactor = reactor
- def wake(self, name: str) -> None:
- """Wake everything related to `name` that is currently sleeping."""
- stream_set = self._streams.pop(name, set())
- for deferred in stream_set:
- try:
- with PreserveLoggingContext():
- deferred.callback(None)
- except Exception:
- pass
- async def sleep(self, name: str, delay_ms: int) -> None:
- """Sleep for the given number of milliseconds, or return if the given
- `name` is explicitly woken up.
- """
- # Create a deferred that gets fired after `delay_ms` milliseconds
- sleep_deferred: "defer.Deferred[None]" = defer.Deferred()
- call = self._reactor.callLater(delay_ms / 1000, sleep_deferred.callback, None)
- # Create a deferred that will get called if `wake` is called with
- # the same `name`.
- stream_set = self._streams.setdefault(name, set())
- notify_deferred: "defer.Deferred[None]" = defer.Deferred()
- stream_set.add(notify_deferred)
- try:
- # Wait for either the delay or for `wake` to be called.
- await make_deferred_yieldable(
- defer.DeferredList(
- [sleep_deferred, notify_deferred],
- fireOnOneCallback=True,
- fireOnOneErrback=True,
- consumeErrors=True,
- )
- )
- finally:
- # Clean up the state
- curr_stream_set = self._streams.get(name)
- if curr_stream_set is not None:
- curr_stream_set.discard(notify_deferred)
- if len(curr_stream_set) == 0:
- self._streams.pop(name)
- # Cancel the sleep if we were woken up
- if call.active():
- call.cancel()
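- # Example (an illustrative sketch): wait up to five seconds for new data on a
- # named stream, returning early if another task calls `wake`:
- #
- #     sleeper = AwakenableSleeper(reactor)
- #     await sleeper.sleep("events", delay_ms=5000)
- #     # elsewhere, when data arrives:  sleeper.wake("events")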
|