cached_call.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. # Copyright 2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import enum
  15. from typing import Awaitable, Callable, Generic, Optional, TypeVar, Union
  16. from twisted.internet.defer import Deferred
  17. from twisted.python.failure import Failure
  18. from synapse.logging.context import make_deferred_yieldable, run_in_background
# Type of the value produced by the wrapped awaitable (and hence returned by get()).
TV = TypeVar("TV")
  20. class _Sentinel(enum.Enum):
  21. sentinel = object()
  22. class CachedCall(Generic[TV]):
  23. """A wrapper for asynchronous calls whose results should be shared
  24. This is useful for wrapping asynchronous functions, where there might be multiple
  25. callers, but we only want to call the underlying function once (and have the result
  26. returned to all callers).
  27. Similar results can be achieved via a lock of some form, but that typically requires
  28. more boilerplate (and ends up being less efficient).
  29. Correctly handles Synapse logcontexts (logs and resource usage for the underlying
  30. function are logged against the logcontext which is active when get() is first
  31. called).
  32. Example usage:
  33. _cached_val = CachedCall(_load_prop)
  34. async def handle_request() -> X:
  35. # We can call this multiple times, but it will result in a single call to
  36. # _load_prop().
  37. return await _cached_val.get()
  38. async def _load_prop() -> X:
  39. await difficult_operation()
  40. The implementation is deliberately single-shot (ie, once the call is initiated,
  41. there is no way to ask for it to be run). This keeps the implementation and
  42. semantics simple. If you want to make a new call, simply replace the whole
  43. CachedCall object.
  44. """
  45. __slots__ = ["_callable", "_deferred", "_result"]
  46. def __init__(self, f: Callable[[], Awaitable[TV]]):
  47. """
  48. Args:
  49. f: The underlying function. Only one call to this function will be alive
  50. at once (per instance of CachedCall)
  51. """
  52. self._callable: Optional[Callable[[], Awaitable[TV]]] = f
  53. self._deferred: Optional[Deferred] = None
  54. self._result: Union[_Sentinel, TV, Failure] = _Sentinel.sentinel
  55. async def get(self) -> TV:
  56. """Kick off the call if necessary, and return the result"""
  57. # Fire off the callable now if this is our first time
  58. if not self._deferred:
  59. assert self._callable is not None
  60. self._deferred = run_in_background(self._callable)
  61. # we will never need the callable again, so make sure it can be GCed
  62. self._callable = None
  63. # once the deferred completes, store the result. We cannot simply leave the
  64. # result in the deferred, since `awaiting` a deferred destroys its result.
  65. # (Also, if it's a Failure, GCing the deferred would log a critical error
  66. # about unhandled Failures)
  67. def got_result(r: Union[TV, Failure]) -> None:
  68. self._result = r
  69. self._deferred.addBoth(got_result)
  70. # TODO: consider cancellation semantics. Currently, if the call to get()
  71. # is cancelled, the underlying call will continue (and any future calls
  72. # will get the result/exception), which I think is *probably* ok, modulo
  73. # the fact the underlying call may be logged to a cancelled logcontext,
  74. # and any eventual exception may not be reported.
  75. # we can now await the deferred, and once it completes, return the result.
  76. if isinstance(self._result, _Sentinel):
  77. await make_deferred_yieldable(self._deferred)
  78. assert not isinstance(self._result, _Sentinel)
  79. if isinstance(self._result, Failure):
  80. self._result.raiseException()
  81. raise AssertionError("unexpected return from Failure.raiseException")
  82. return self._result
  83. class RetryOnExceptionCachedCall(Generic[TV]):
  84. """A wrapper around CachedCall which will retry the call if an exception is thrown
  85. This is used in much the same way as CachedCall, but adds some extra functionality
  86. so that if the underlying function throws an exception, then the next call to get()
  87. will initiate another call to the underlying function. (Any calls to get() which
  88. are already pending will raise the exception.)
  89. """
  90. slots = ["_cachedcall"]
  91. def __init__(self, f: Callable[[], Awaitable[TV]]):
  92. async def _wrapper() -> TV:
  93. try:
  94. return await f()
  95. except Exception:
  96. # the call raised an exception: replace the underlying CachedCall to
  97. # trigger another call next time get() is called
  98. self._cachedcall = CachedCall(_wrapper)
  99. raise
  100. self._cachedcall = CachedCall(_wrapper)
  101. async def get(self) -> TV:
  102. return await self._cachedcall.get()