ratelimiting.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. # Copyright 2014-2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections
  15. from synapse.api.errors import LimitExceededError
  16. class Ratelimiter(object):
  17. """
  18. Ratelimit message sending by user.
  19. """
  20. def __init__(self):
  21. self.message_counts = collections.OrderedDict()
  22. def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True):
  23. """Can the entity (e.g. user or IP address) perform the action?
  24. Args:
  25. key: The key we should use when rate limiting. Can be a user ID
  26. (when sending events), an IP address, etc.
  27. time_now_s: The time now.
  28. rate_hz: The long term number of messages a user can send in a
  29. second.
  30. burst_count: How many messages the user can send before being
  31. limited.
  32. update (bool): Whether to update the message rates or not. This is
  33. useful to check if a message would be allowed to be sent before
  34. its ready to be actually sent.
  35. Returns:
  36. A pair of a bool indicating if they can send a message now and a
  37. time in seconds of when they can next send a message.
  38. """
  39. self.prune_message_counts(time_now_s)
  40. message_count, time_start, _ignored = self.message_counts.get(
  41. key, (0.0, time_now_s, None)
  42. )
  43. time_delta = time_now_s - time_start
  44. sent_count = message_count - time_delta * rate_hz
  45. if sent_count < 0:
  46. allowed = True
  47. time_start = time_now_s
  48. message_count = 1.0
  49. elif sent_count > burst_count - 1.0:
  50. allowed = False
  51. else:
  52. allowed = True
  53. message_count += 1
  54. if update:
  55. self.message_counts[key] = (message_count, time_start, rate_hz)
  56. if rate_hz > 0:
  57. time_allowed = time_start + (message_count - burst_count + 1) / rate_hz
  58. if time_allowed < time_now_s:
  59. time_allowed = time_now_s
  60. else:
  61. time_allowed = -1
  62. return allowed, time_allowed
  63. def prune_message_counts(self, time_now_s):
  64. for key in list(self.message_counts.keys()):
  65. message_count, time_start, rate_hz = self.message_counts[key]
  66. time_delta = time_now_s - time_start
  67. if message_count - time_delta * rate_hz > 0:
  68. break
  69. else:
  70. del self.message_counts[key]
  71. def ratelimit(self, key, time_now_s, rate_hz, burst_count, update=True):
  72. allowed, time_allowed = self.can_do_action(
  73. key, time_now_s, rate_hz, burst_count, update
  74. )
  75. if not allowed:
  76. raise LimitExceededError(
  77. retry_after_ms=int(1000 * (time_allowed - time_now_s))
  78. )