|
@@ -75,6 +75,29 @@ class Ratelimiter:
|
|
|
# * The rate_hz (leak rate) of this particular bucket.
|
|
|
self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
|
|
|
|
|
|
+ def _get_key(
|
|
|
+ self, requester: Optional[Requester], key: Optional[Hashable]
|
|
|
+ ) -> Hashable:
|
|
|
+ """Use the requester's MXID as a fallback key if no key is provided.
|
|
|
+
|
|
|
+ Pulled out so that `can_do_action` and `record_action` are consistent.
|
|
|
+ """
|
|
|
+ if key is None:
|
|
|
+ if not requester:
|
|
|
+ raise ValueError("Must supply at least one of `requester` or `key`")
|
|
|
+
|
|
|
+ key = requester.user.to_string()
|
|
|
+ return key
|
|
|
+
|
|
|
+ def _get_action_counts(
|
|
|
+ self, key: Hashable, time_now_s: float
|
|
|
+ ) -> Tuple[float, float, float]:
|
|
|
+ """Retrieve the action counts, with a fallback representing an empty bucket.
|
|
|
+
|
|
|
+ Pulled out so that `can_do_action` and `record_action` are consistent.
|
|
|
+ """
|
|
|
+ return self.actions.get(key, (0.0, time_now_s, 0.0))
|
|
|
+
|
|
|
async def can_do_action(
|
|
|
self,
|
|
|
requester: Optional[Requester],
|
|
@@ -114,11 +137,7 @@ class Ratelimiter:
|
|
|
* The reactor timestamp for when the action can be performed next.
|
|
|
-1 if rate_hz is less than or equal to zero
|
|
|
"""
|
|
|
- if key is None:
|
|
|
- if not requester:
|
|
|
- raise ValueError("Must supply at least one of `requester` or `key`")
|
|
|
-
|
|
|
- key = requester.user.to_string()
|
|
|
+ key = self._get_key(requester, key)
|
|
|
|
|
|
if requester:
|
|
|
# Disable rate limiting of users belonging to any AS that is configured
|
|
@@ -147,7 +166,7 @@ class Ratelimiter:
|
|
|
self._prune_message_counts(time_now_s)
|
|
|
|
|
|
# Check if there is an existing count entry for this key
|
|
|
- action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
|
|
|
+ action_count, time_start, _ = self._get_action_counts(key, time_now_s)
|
|
|
|
|
|
# Check whether performing another action is allowed
|
|
|
time_delta = time_now_s - time_start
|