|
@@ -12,52 +12,60 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
|
|
|
+import enum
|
|
|
import logging
|
|
|
import time
|
|
|
-from typing import Any, Tuple
|
|
|
+from typing import Callable, Dict, Generic, Tuple, TypeVar, Union
|
|
|
|
|
|
import attr
|
|
|
-from sortedcontainers import SortedList # type: ignore
|
|
|
+from sortedcontainers import SortedList
|
|
|
+from typing_extensions import Literal
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
-SENTINEL = object()
|
|
|
|
|
|
+class Sentinel(enum.Enum):
|
|
|
+ token = enum.auto()
|
|
|
|
|
|
-class TTLCache:
|
|
|
+
|
|
|
+K = TypeVar("K")
|
|
|
+V = TypeVar("V")
|
|
|
+
|
|
|
+
|
|
|
+class TTLCache(Generic[K, V]):
|
|
|
"""A key/value cache implementation where each entry has its own TTL"""
|
|
|
|
|
|
- def __init__(self, cache_name, timer=time.time):
|
|
|
- # map from key to _CacheEntry
|
|
|
- self._data = {}
|
|
|
+ def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
|
|
|
+ self._data: Dict[K, _CacheEntry[K, V]] = {}
|
|
|
|
|
|
# the _CacheEntries, sorted by expiry time
|
|
|
- self._expiry_list = SortedList()
|
|
|
+ self._expiry_list: SortedList[_CacheEntry] = SortedList()
|
|
|
|
|
|
self._timer = timer
|
|
|
|
|
|
- def set(self, key, value, ttl: float) -> None:
|
|
|
+ def set(self, key: K, value: V, ttl: float) -> None:
|
|
|
"""Add/update an entry in the cache
|
|
|
|
|
|
:param key: Key for this entry.
|
|
|
|
|
|
:param value: Value for this entry.
|
|
|
|
|
|
- :param paramttl: TTL for this entry, in seconds.
|
|
|
- :type paramttl: float
|
|
|
+ :param ttl: TTL for this entry, in seconds.
|
|
|
"""
|
|
|
expiry = self._timer() + ttl
|
|
|
|
|
|
self.expire()
|
|
|
- e = self._data.pop(key, SENTINEL)
|
|
|
- if e != SENTINEL:
|
|
|
+ e = self._data.pop(key, Sentinel.token)
|
|
|
+ if e != Sentinel.token:
|
|
|
self._expiry_list.remove(e)
|
|
|
|
|
|
entry = _CacheEntry(expiry_time=expiry, key=key, value=value)
|
|
|
self._data[key] = entry
|
|
|
self._expiry_list.add(entry)
|
|
|
|
|
|
- def get(self, key, default=SENTINEL):
|
|
|
+ def get(
|
|
|
+ self, key: K, default: Union[V, Literal[Sentinel.token]] = Sentinel.token
|
|
|
+ ) -> V:
|
|
|
"""Get a value from the cache
|
|
|
|
|
|
:param key: The key to look up.
|
|
@@ -67,14 +75,14 @@ class TTLCache:
|
|
|
:returns a value from the cache, or the default.
|
|
|
"""
|
|
|
self.expire()
|
|
|
- e = self._data.get(key, SENTINEL)
|
|
|
- if e == SENTINEL:
|
|
|
- if default == SENTINEL:
|
|
|
+ e = self._data.get(key, Sentinel.token)
|
|
|
+ if e is Sentinel.token:
|
|
|
+ if default is Sentinel.token:
|
|
|
raise KeyError(key)
|
|
|
return default
|
|
|
return e.value
|
|
|
|
|
|
- def get_with_expiry(self, key) -> Tuple[Any, float]:
|
|
|
+ def get_with_expiry(self, key: K) -> Tuple[V, float]:
|
|
|
"""Get a value, and its expiry time, from the cache
|
|
|
|
|
|
:param key: key to look up
|
|
@@ -92,7 +100,9 @@ class TTLCache:
|
|
|
raise
|
|
|
return e.value, e.expiry_time
|
|
|
|
|
|
- def pop(self, key, default=SENTINEL):
|
|
|
+ def pop(
|
|
|
+ self, key: K, default: Union[V, Literal[Sentinel.token]] = Sentinel.token
|
|
|
+ ) -> V:
|
|
|
"""Remove a value from the cache
|
|
|
|
|
|
If key is in the cache, remove it and return its value, else return default.
|
|
@@ -105,28 +115,28 @@ class TTLCache:
|
|
|
:returns a value from the cache, or the default
|
|
|
"""
|
|
|
self.expire()
|
|
|
- e = self._data.pop(key, SENTINEL)
|
|
|
- if e == SENTINEL:
|
|
|
- if default == SENTINEL:
|
|
|
+ e = self._data.pop(key, Sentinel.token)
|
|
|
+ if e is Sentinel.token:
|
|
|
+ if default == Sentinel.token:
|
|
|
raise KeyError(key)
|
|
|
return default
|
|
|
self._expiry_list.remove(e)
|
|
|
return e.value
|
|
|
|
|
|
- def __getitem__(self, key):
|
|
|
+ def __getitem__(self, key: K) -> V:
|
|
|
return self.get(key)
|
|
|
|
|
|
- def __delitem__(self, key):
|
|
|
+ def __delitem__(self, key: K) -> None:
|
|
|
self.pop(key)
|
|
|
|
|
|
- def __contains__(self, key):
|
|
|
+ def __contains__(self, key: K) -> bool:
|
|
|
return key in self._data
|
|
|
|
|
|
- def __len__(self):
|
|
|
+ def __len__(self) -> int:
|
|
|
self.expire()
|
|
|
return len(self._data)
|
|
|
|
|
|
- def expire(self):
|
|
|
+ def expire(self) -> None:
|
|
|
"""Run the expiry on the cache. Any entries whose expiry times are due will
|
|
|
be removed
|
|
|
"""
|
|
@@ -139,11 +149,11 @@ class TTLCache:
|
|
|
del self._expiry_list[0]
|
|
|
|
|
|
|
|
|
-@attr.s(frozen=True, slots=True)
|
|
|
-class _CacheEntry:
|
|
|
+@attr.s(frozen=True)
|
|
|
+class _CacheEntry(Generic[K, V]):
|
|
|
"""TTLCache entry"""
|
|
|
|
|
|
# expiry_time is the first attribute, so that entries are sorted by expiry.
|
|
|
- expiry_time = attr.ib()
|
|
|
- key = attr.ib()
|
|
|
- value = attr.ib()
|
|
|
+ expiry_time: float = attr.ib()
|
|
|
+ key: K = attr.ib()
|
|
|
+ value: V = attr.ib()
|