|
@@ -15,12 +15,30 @@
|
|
|
|
|
|
import threading
|
|
|
from functools import wraps
|
|
|
-from typing import Callable, Optional, Type, Union
|
|
|
+from typing import (
|
|
|
+ Any,
|
|
|
+ Callable,
|
|
|
+ Generic,
|
|
|
+ Iterable,
|
|
|
+ Optional,
|
|
|
+ Type,
|
|
|
+ TypeVar,
|
|
|
+ Union,
|
|
|
+ cast,
|
|
|
+ overload,
|
|
|
+)
|
|
|
+
|
|
|
+from typing_extensions import Literal
|
|
|
|
|
|
from synapse.config import cache as cache_config
|
|
|
from synapse.util.caches import CacheMetric, register_cache
|
|
|
from synapse.util.caches.treecache import TreeCache
|
|
|
|
|
|
+T = TypeVar("T")
|
|
|
+FT = TypeVar("FT", bound=Callable[..., Any])
|
|
|
+KT = TypeVar("KT")
|
|
|
+VT = TypeVar("VT")
|
|
|
+
|
|
|
|
|
|
def enumerate_leaves(node, depth):
|
|
|
if depth == 0:
|
|
@@ -42,7 +60,7 @@ class _Node:
|
|
|
self.callbacks = callbacks
|
|
|
|
|
|
|
|
|
-class LruCache:
|
|
|
+class LruCache(Generic[KT, VT]):
|
|
|
"""
|
|
|
Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
|
|
|
|
|
@@ -128,13 +146,13 @@ class LruCache:
|
|
|
if metrics:
|
|
|
metrics.inc_evictions(evicted_len)
|
|
|
|
|
|
- def synchronized(f):
|
|
|
+ def synchronized(f: FT) -> FT:
|
|
|
@wraps(f)
|
|
|
def inner(*args, **kwargs):
|
|
|
with lock:
|
|
|
return f(*args, **kwargs)
|
|
|
|
|
|
- return inner
|
|
|
+ return cast(FT, inner)
|
|
|
|
|
|
cached_cache_len = [0]
|
|
|
if size_callback is not None:
|
|
@@ -188,8 +206,31 @@ class LruCache:
|
|
|
node.callbacks.clear()
|
|
|
return deleted_len
|
|
|
|
|
|
+ @overload
|
|
|
+ def cache_get(
|
|
|
+ key: KT,
|
|
|
+ default: Literal[None] = None,
|
|
|
+ callbacks: Iterable[Callable[[], None]] = ...,
|
|
|
+ update_metrics: bool = ...,
|
|
|
+ ) -> Optional[VT]:
|
|
|
+ ...
|
|
|
+
|
|
|
+ @overload
|
|
|
+ def cache_get(
|
|
|
+ key: KT,
|
|
|
+ default: T,
|
|
|
+ callbacks: Iterable[Callable[[], None]] = ...,
|
|
|
+ update_metrics: bool = ...,
|
|
|
+ ) -> Union[T, VT]:
|
|
|
+ ...
|
|
|
+
|
|
|
@synchronized
|
|
|
- def cache_get(key, default=None, callbacks=[], update_metrics=True):
|
|
|
+ def cache_get(
|
|
|
+ key: KT,
|
|
|
+ default=None,
|
|
|
+ callbacks: Iterable[Callable[[], None]] = [],
|
|
|
+ update_metrics: bool = True,
|
|
|
+ ):
|
|
|
node = cache.get(key, None)
|
|
|
if node is not None:
|
|
|
move_node_to_front(node)
|
|
@@ -203,7 +244,7 @@ class LruCache:
|
|
|
return default
|
|
|
|
|
|
@synchronized
|
|
|
- def cache_set(key, value, callbacks=[]):
|
|
|
+ def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
|
|
|
node = cache.get(key, None)
|
|
|
if node is not None:
|
|
|
# We sometimes store large objects, e.g. dicts, which cause
|
|
@@ -232,7 +273,7 @@ class LruCache:
|
|
|
evict()
|
|
|
|
|
|
@synchronized
|
|
|
- def cache_set_default(key, value):
|
|
|
+ def cache_set_default(key: KT, value: VT) -> VT:
|
|
|
node = cache.get(key, None)
|
|
|
if node is not None:
|
|
|
return node.value
|
|
@@ -241,8 +282,16 @@ class LruCache:
|
|
|
evict()
|
|
|
return value
|
|
|
|
|
|
+ @overload
|
|
|
+ def cache_pop(key: KT, default: Literal[None] = None) -> Union[None, VT]:
|
|
|
+ ...
|
|
|
+
|
|
|
+ @overload
|
|
|
+ def cache_pop(key: KT, default: T) -> Union[T, VT]:
|
|
|
+ ...
|
|
|
+
|
|
|
@synchronized
|
|
|
- def cache_pop(key, default=None):
|
|
|
+ def cache_pop(key: KT, default=None):
|
|
|
node = cache.get(key, None)
|
|
|
if node:
|
|
|
delete_node(node)
|
|
@@ -252,18 +301,18 @@ class LruCache:
|
|
|
return default
|
|
|
|
|
|
@synchronized
|
|
|
- def cache_del_multi(key):
|
|
|
+ def cache_del_multi(key: KT) -> None:
|
|
|
"""
|
|
|
This will only work if constructed with cache_type=TreeCache
|
|
|
"""
|
|
|
popped = cache.pop(key)
|
|
|
if popped is None:
|
|
|
return
|
|
|
- for leaf in enumerate_leaves(popped, keylen - len(key)):
|
|
|
+ for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
|
|
|
delete_node(leaf)
|
|
|
|
|
|
@synchronized
|
|
|
- def cache_clear():
|
|
|
+ def cache_clear() -> None:
|
|
|
list_root.next_node = list_root
|
|
|
list_root.prev_node = list_root
|
|
|
for node in cache.values():
|
|
@@ -274,7 +323,7 @@ class LruCache:
|
|
|
cached_cache_len[0] = 0
|
|
|
|
|
|
@synchronized
|
|
|
- def cache_contains(key):
|
|
|
+ def cache_contains(key: KT) -> bool:
|
|
|
return key in cache
|
|
|
|
|
|
self.sentinel = object()
|