Browse Source

Rate limiter: Pull out some small methods

David Robertson 1 year ago
parent
commit
c2e3025b33
1 changed files with 25 additions and 6 deletions
  1. 25 6
      synapse/api/ratelimiting.py

+ 25 - 6
synapse/api/ratelimiting.py

@@ -75,6 +75,29 @@ class Ratelimiter:
         #   * The rate_hz (leak rate) of this particular bucket.
         self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
 
+    def _get_key(
+        self, requester: Optional[Requester], key: Optional[Hashable]
+    ) -> Hashable:
+        """Use the requester's MXID as a fallback key if no key is provided.
+
+        Pulled out so that `can_do_action` and `record_action` are consistent.
+        """
+        if key is None:
+            if not requester:
+                raise ValueError("Must supply at least one of `requester` or `key`")
+
+            key = requester.user.to_string()
+        return key
+
+    def _get_action_counts(
+        self, key: Hashable, time_now_s: float
+    ) -> Tuple[float, float, float]:
+        """Retrieve the action counts, with a fallback representing an empty bucket.
+
+        Pulled out so that `can_do_action` and `record_action` are consistent.
+        """
+        return self.actions.get(key, (0.0, time_now_s, 0.0))
+
     async def can_do_action(
         self,
         requester: Optional[Requester],
@@ -114,11 +137,7 @@ class Ratelimiter:
                 * The reactor timestamp for when the action can be performed next.
                   -1 if rate_hz is less than or equal to zero
         """
-        if key is None:
-            if not requester:
-                raise ValueError("Must supply at least one of `requester` or `key`")
-
-            key = requester.user.to_string()
+        key = self._get_key(requester, key)
 
         if requester:
             # Disable rate limiting of users belonging to any AS that is configured
@@ -147,7 +166,7 @@ class Ratelimiter:
         self._prune_message_counts(time_now_s)
 
         # Check if there is an existing count entry for this key
-        action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, 0.0))
+        action_count, time_start, _ = self._get_action_counts(key, time_now_s)
 
         # Check whether performing another action is allowed
         time_delta = time_now_s - time_start