ratelimiting.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright 2014-2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections
  15. class Ratelimiter(object):
  16. """
  17. Ratelimit message sending by user.
  18. """
  19. def __init__(self):
  20. self.message_counts = collections.OrderedDict()
  21. def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count):
  22. """Can the user send a message?
  23. Args:
  24. user_id: The user sending a message.
  25. time_now_s: The time now.
  26. msg_rate_hz: The long term number of messages a user can send in a
  27. second.
  28. burst_count: How many messages the user can send before being
  29. limited.
  30. Returns:
  31. A pair of a bool indicating if they can send a message now and a
  32. time in seconds of when they can next send a message.
  33. """
  34. self.prune_message_counts(time_now_s)
  35. message_count, time_start, _ignored = self.message_counts.pop(
  36. user_id, (0., time_now_s, None),
  37. )
  38. time_delta = time_now_s - time_start
  39. sent_count = message_count - time_delta * msg_rate_hz
  40. if sent_count < 0:
  41. allowed = True
  42. time_start = time_now_s
  43. message_count = 1.
  44. elif sent_count > burst_count - 1.:
  45. allowed = False
  46. else:
  47. allowed = True
  48. message_count += 1
  49. self.message_counts[user_id] = (
  50. message_count, time_start, msg_rate_hz
  51. )
  52. if msg_rate_hz > 0:
  53. time_allowed = (
  54. time_start + (message_count - burst_count + 1) / msg_rate_hz
  55. )
  56. if time_allowed < time_now_s:
  57. time_allowed = time_now_s
  58. else:
  59. time_allowed = -1
  60. return allowed, time_allowed
  61. def prune_message_counts(self, time_now_s):
  62. for user_id in self.message_counts.keys():
  63. message_count, time_start, msg_rate_hz = (
  64. self.message_counts[user_id]
  65. )
  66. time_delta = time_now_s - time_start
  67. if message_count - time_delta * msg_rate_hz > 0:
  68. break
  69. else:
  70. del self.message_counts[user_id]