ratelimitutils.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2015, 2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from twisted.internet import defer
  16. from synapse.api.errors import LimitExceededError
  17. from synapse.util.async import sleep
  18. from synapse.util.logcontext import preserve_fn
  19. import collections
  20. import contextlib
  21. import logging
  22. logger = logging.getLogger(__name__)
  23. class FederationRateLimiter(object):
  24. def __init__(self, clock, window_size, sleep_limit, sleep_msec,
  25. reject_limit, concurrent_requests):
  26. """
  27. Args:
  28. clock (Clock)
  29. window_size (int): The window size in milliseconds.
  30. sleep_limit (int): The number of requests received in the last
  31. `window_size` milliseconds before we artificially start
  32. delaying processing of requests.
  33. sleep_msec (int): The number of milliseconds to delay processing
  34. of incoming requests by.
  35. reject_limit (int): The maximum number of requests that are can be
  36. queued for processing before we start rejecting requests with
  37. a 429 Too Many Requests response.
  38. concurrent_requests (int): The number of concurrent requests to
  39. process.
  40. """
  41. self.clock = clock
  42. self.window_size = window_size
  43. self.sleep_limit = sleep_limit
  44. self.sleep_msec = sleep_msec
  45. self.reject_limit = reject_limit
  46. self.concurrent_requests = concurrent_requests
  47. self.ratelimiters = {}
  48. def ratelimit(self, host):
  49. """Used to ratelimit an incoming request from given host
  50. Example usage:
  51. with rate_limiter.ratelimit(origin) as wait_deferred:
  52. yield wait_deferred
  53. # Handle request ...
  54. Args:
  55. host (str): Origin of incoming request.
  56. Returns:
  57. _PerHostRatelimiter
  58. """
  59. return self.ratelimiters.setdefault(
  60. host,
  61. _PerHostRatelimiter(
  62. clock=self.clock,
  63. window_size=self.window_size,
  64. sleep_limit=self.sleep_limit,
  65. sleep_msec=self.sleep_msec,
  66. reject_limit=self.reject_limit,
  67. concurrent_requests=self.concurrent_requests,
  68. )
  69. ).ratelimit()
  70. class _PerHostRatelimiter(object):
  71. def __init__(self, clock, window_size, sleep_limit, sleep_msec,
  72. reject_limit, concurrent_requests):
  73. self.clock = clock
  74. self.window_size = window_size
  75. self.sleep_limit = sleep_limit
  76. self.sleep_msec = sleep_msec
  77. self.reject_limit = reject_limit
  78. self.concurrent_requests = concurrent_requests
  79. self.sleeping_requests = set()
  80. self.ready_request_queue = collections.OrderedDict()
  81. self.current_processing = set()
  82. self.request_times = []
  83. @contextlib.contextmanager
  84. def ratelimit(self):
  85. # `contextlib.contextmanager` takes a generator and turns it into a
  86. # context manager. The generator should only yield once with a value
  87. # to be returned by manager.
  88. # Exceptions will be reraised at the yield.
  89. request_id = object()
  90. ret = self._on_enter(request_id)
  91. try:
  92. yield ret
  93. finally:
  94. self._on_exit(request_id)
  95. def _on_enter(self, request_id):
  96. time_now = self.clock.time_msec()
  97. self.request_times[:] = [
  98. r for r in self.request_times
  99. if time_now - r < self.window_size
  100. ]
  101. queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
  102. if queue_size > self.reject_limit:
  103. raise LimitExceededError(
  104. retry_after_ms=int(
  105. self.window_size / self.sleep_limit
  106. ),
  107. )
  108. self.request_times.append(time_now)
  109. def queue_request():
  110. if len(self.current_processing) > self.concurrent_requests:
  111. logger.debug("Ratelimit [%s]: Queue req", id(request_id))
  112. queue_defer = defer.Deferred()
  113. self.ready_request_queue[request_id] = queue_defer
  114. return queue_defer
  115. else:
  116. return defer.succeed(None)
  117. logger.debug(
  118. "Ratelimit [%s]: len(self.request_times)=%d",
  119. id(request_id), len(self.request_times),
  120. )
  121. if len(self.request_times) > self.sleep_limit:
  122. logger.debug(
  123. "Ratelimit [%s]: sleeping req",
  124. id(request_id),
  125. )
  126. ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
  127. self.sleeping_requests.add(request_id)
  128. def on_wait_finished(_):
  129. logger.debug(
  130. "Ratelimit [%s]: Finished sleeping",
  131. id(request_id),
  132. )
  133. self.sleeping_requests.discard(request_id)
  134. queue_defer = queue_request()
  135. return queue_defer
  136. ret_defer.addBoth(on_wait_finished)
  137. else:
  138. ret_defer = queue_request()
  139. def on_start(r):
  140. logger.debug(
  141. "Ratelimit [%s]: Processing req",
  142. id(request_id),
  143. )
  144. self.current_processing.add(request_id)
  145. return r
  146. def on_err(r):
  147. self.current_processing.discard(request_id)
  148. return r
  149. def on_both(r):
  150. # Ensure that we've properly cleaned up.
  151. self.sleeping_requests.discard(request_id)
  152. self.ready_request_queue.pop(request_id, None)
  153. return r
  154. ret_defer.addCallbacks(on_start, on_err)
  155. ret_defer.addBoth(on_both)
  156. return ret_defer
  157. def _on_exit(self, request_id):
  158. logger.debug(
  159. "Ratelimit [%s]: Processed req",
  160. id(request_id),
  161. )
  162. self.current_processing.discard(request_id)
  163. try:
  164. request_id, deferred = self.ready_request_queue.popitem()
  165. self.current_processing.add(request_id)
  166. deferred.callback(None)
  167. except KeyError:
  168. pass