|
@@ -19,12 +19,15 @@ import logging
|
|
|
from typing import (
|
|
|
Any,
|
|
|
Callable,
|
|
|
+ Dict,
|
|
|
Generic,
|
|
|
+ Hashable,
|
|
|
Iterable,
|
|
|
Mapping,
|
|
|
Optional,
|
|
|
Sequence,
|
|
|
Tuple,
|
|
|
+ Type,
|
|
|
TypeVar,
|
|
|
Union,
|
|
|
cast,
|
|
@@ -32,6 +35,7 @@ from typing import (
|
|
|
from weakref import WeakValueDictionary
|
|
|
|
|
|
from twisted.internet import defer
|
|
|
+from twisted.python.failure import Failure
|
|
|
|
|
|
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
|
|
from synapse.util import unwrapFirstError
|
|
@@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]):
|
|
|
|
|
|
|
|
|
class _CacheDescriptorBase:
|
|
|
- def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ orig: Callable[..., Any],
|
|
|
+ num_args: Optional[int],
|
|
|
+ cache_context: bool = False,
|
|
|
+ ):
|
|
|
self.orig = orig
|
|
|
|
|
|
arg_spec = inspect.getfullargspec(orig)
|
|
@@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
- orig,
|
|
|
+ orig: Callable[..., Any],
|
|
|
max_entries: int = 1000,
|
|
|
cache_context: bool = False,
|
|
|
):
|
|
|
super().__init__(orig, num_args=None, cache_context=cache_context)
|
|
|
self.max_entries = max_entries
|
|
|
|
|
|
- def __get__(self, obj, owner):
|
|
|
+ def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
|
|
|
cache: LruCache[CacheKey, Any] = LruCache(
|
|
|
cache_name=self.orig.__name__,
|
|
|
max_size=self.max_entries,
|
|
@@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
|
|
|
sentinel = LruCacheDescriptor._Sentinel.sentinel
|
|
|
|
|
|
@functools.wraps(self.orig)
|
|
|
- def _wrapped(*args, **kwargs):
|
|
|
+ def _wrapped(*args: Any, **kwargs: Any) -> Any:
|
|
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
|
|
callbacks = (invalidate_callback,) if invalidate_callback else ()
|
|
|
|
|
@@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
|
|
return r1 + r2
|
|
|
|
|
|
Args:
|
|
|
- num_args (int): number of positional arguments (excluding ``self`` and
|
|
|
+ num_args: number of positional arguments (excluding ``self`` and
|
|
|
``cache_context``) to use as cache keys. Defaults to all named
|
|
|
args of the function.
|
|
|
"""
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
- orig,
|
|
|
- max_entries=1000,
|
|
|
- num_args=None,
|
|
|
- tree=False,
|
|
|
- cache_context=False,
|
|
|
- iterable=False,
|
|
|
+ orig: Callable[..., Any],
|
|
|
+ max_entries: int = 1000,
|
|
|
+ num_args: Optional[int] = None,
|
|
|
+ tree: bool = False,
|
|
|
+ cache_context: bool = False,
|
|
|
+ iterable: bool = False,
|
|
|
prune_unread_entries: bool = True,
|
|
|
):
|
|
|
super().__init__(orig, num_args=num_args, cache_context=cache_context)
|
|
@@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
|
|
self.iterable = iterable
|
|
|
self.prune_unread_entries = prune_unread_entries
|
|
|
|
|
|
- def __get__(self, obj, owner):
|
|
|
+ def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
|
|
|
cache: DeferredCache[CacheKey, Any] = DeferredCache(
|
|
|
name=self.orig.__name__,
|
|
|
max_entries=self.max_entries,
|
|
@@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
|
|
get_cache_key = self.cache_key_builder
|
|
|
|
|
|
@functools.wraps(self.orig)
|
|
|
- def _wrapped(*args, **kwargs):
|
|
|
+ def _wrapped(*args: Any, **kwargs: Any) -> Any:
|
|
|
# If we're passed a cache_context then we'll want to call its invalidate()
|
|
|
# whenever we are invalidated
|
|
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
|
@@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|
|
of results.
|
|
|
"""
|
|
|
|
|
|
- def __init__(self, orig, cached_method_name, list_name, num_args=None):
|
|
|
+ def __init__(
|
|
|
+ self,
|
|
|
+ orig: Callable[..., Any],
|
|
|
+ cached_method_name: str,
|
|
|
+ list_name: str,
|
|
|
+ num_args: Optional[int] = None,
|
|
|
+ ):
|
|
|
"""
|
|
|
Args:
|
|
|
- orig (function)
|
|
|
- cached_method_name (str): The name of the cached method.
|
|
|
- list_name (str): Name of the argument which is the bulk lookup list
|
|
|
- num_args (int): number of positional arguments (excluding ``self``,
|
|
|
+ orig
|
|
|
+ cached_method_name: The name of the cached method.
|
|
|
+ list_name: Name of the argument which is the bulk lookup list
|
|
|
+ num_args: number of positional arguments (excluding ``self``,
|
|
|
but including list_name) to use as cache keys. Defaults to all
|
|
|
named args of the function.
|
|
|
"""
|
|
@@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|
|
% (self.list_name, cached_method_name)
|
|
|
)
|
|
|
|
|
|
- def __get__(self, obj, objtype=None):
|
|
|
+ def __get__(
|
|
|
+ self, obj: Optional[Any], objtype: Optional[Type] = None
|
|
|
+ ) -> Callable[..., Any]:
|
|
|
cached_method = getattr(obj, self.cached_method_name)
|
|
|
cache: DeferredCache[CacheKey, Any] = cached_method.cache
|
|
|
num_args = cached_method.num_args
|
|
|
|
|
|
@functools.wraps(self.orig)
|
|
|
- def wrapped(*args, **kwargs):
|
|
|
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
|
|
|
# If we're passed a cache_context then we'll want to call its
|
|
|
# invalidate() whenever we are invalidated
|
|
|
invalidate_callback = kwargs.pop("on_invalidate", None)
|
|
@@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|
|
|
|
|
results = {}
|
|
|
|
|
|
- def update_results_dict(res, arg):
|
|
|
+ def update_results_dict(res: Any, arg: Hashable) -> None:
|
|
|
results[arg] = res
|
|
|
|
|
|
# list of deferreds to wait for
|
|
@@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|
|
# otherwise a tuple is used.
|
|
|
if num_args == 1:
|
|
|
|
|
|
- def arg_to_cache_key(arg):
|
|
|
+ def arg_to_cache_key(arg: Hashable) -> Hashable:
|
|
|
return arg
|
|
|
|
|
|
else:
|
|
|
keylist = list(keyargs)
|
|
|
|
|
|
- def arg_to_cache_key(arg):
|
|
|
+ def arg_to_cache_key(arg: Hashable) -> Hashable:
|
|
|
keylist[self.list_pos] = arg
|
|
|
return tuple(keylist)
|
|
|
|
|
@@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|
|
key = arg_to_cache_key(arg)
|
|
|
cache.set(key, deferred, callback=invalidate_callback)
|
|
|
|
|
|
- def complete_all(res):
|
|
|
+ def complete_all(res: Dict[Hashable, Any]) -> None:
|
|
|
# the wrapped function has completed. It returns a
|
|
|
# a dict. We can now resolve the observable deferreds in
|
|
|
# the cache and update our own result map.
|
|
@@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
|
|
deferreds_map[e].callback(val)
|
|
|
results[e] = val
|
|
|
|
|
|
- def errback(f):
|
|
|
+ def errback(f: Failure) -> Failure:
|
|
|
# the wrapped function has failed. Invalidate any cache
|
|
|
# entries we're supposed to be populating, and fail
|
|
|
# their deferreds.
|