# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import contextlib
import logging

from twisted.internet import defer

from synapse.api.errors import LimitExceededError
from synapse.util.logcontext import (
    PreserveLoggingContext,
    make_deferred_yieldable,
    run_in_background,
)

logger = logging.getLogger(__name__)


class FederationRateLimiter(object):
    def __init__(self, clock, window_size, sleep_limit, sleep_msec,
                 reject_limit, concurrent_requests):
        """
        Args:
            clock (Clock)
            window_size (int): The window size in milliseconds.
            sleep_limit (int): The number of requests received in the last
                `window_size` milliseconds before we artificially start
                delaying processing of requests.
            sleep_msec (int): The number of milliseconds to delay processing
                of incoming requests by.
            reject_limit (int): The maximum number of requests that can be
                queued for processing before we start rejecting requests with
                a 429 Too Many Requests response.
            concurrent_requests (int): The number of concurrent requests to
                process.
        """
        self.clock = clock

        self.window_size = window_size
        self.sleep_limit = sleep_limit
        self.sleep_msec = sleep_msec
        self.reject_limit = reject_limit
        self.concurrent_requests = concurrent_requests

        self.ratelimiters = {}

    def ratelimit(self, host):
        """Used to ratelimit an incoming request from a given host.

        Example usage:

            with rate_limiter.ratelimit(origin) as wait_deferred:
                yield wait_deferred
                # Handle request ...

        Args:
            host (str): Origin of incoming request.

        Returns:
            _PerHostRatelimiter
        """
        return self.ratelimiters.setdefault(
            host,
            _PerHostRatelimiter(
                clock=self.clock,
                window_size=self.window_size,
                sleep_limit=self.sleep_limit,
                sleep_msec=self.sleep_msec,
                reject_limit=self.reject_limit,
                concurrent_requests=self.concurrent_requests,
            )
        ).ratelimit()


class _PerHostRatelimiter(object):
    def __init__(self, clock, window_size, sleep_limit, sleep_msec,
                 reject_limit, concurrent_requests):
        self.clock = clock

        self.window_size = window_size
        self.sleep_limit = sleep_limit
        self.sleep_msec = sleep_msec
        self.reject_limit = reject_limit
        self.concurrent_requests = concurrent_requests

        # request_ids of requests currently delayed by an artificial sleep
        self.sleeping_requests = set()
        # map of request_id -> Deferred, for requests waiting on a free
        # processing slot
        self.ready_request_queue = collections.OrderedDict()
        # request_ids of requests currently being processed
        self.current_processing = set()
        # timestamps (ms) of requests received within the current window
        self.request_times = []

    @contextlib.contextmanager
    def ratelimit(self):
        # `contextlib.contextmanager` takes a generator and turns it into a
        # context manager. The generator should only yield once with a value
        # to be returned by the manager.
        # Exceptions will be reraised at the yield.

        request_id = object()
        ret = self._on_enter(request_id)
        try:
            yield ret
        finally:
            self._on_exit(request_id)

    def _on_enter(self, request_id):
        time_now = self.clock.time_msec()

        # drop timestamps that have fallen outside the rate-limiting window
        self.request_times[:] = [
            r for r in self.request_times
            if time_now - r < self.window_size
        ]

        # reject straight away if this host already has too many requests
        # queued or sleeping
        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
        if queue_size > self.reject_limit:
            raise LimitExceededError(
                retry_after_ms=int(
                    self.window_size / self.sleep_limit
                ),
            )

        self.request_times.append(time_now)

        def queue_request():
            if len(self.current_processing) > self.concurrent_requests:
                logger.debug("Ratelimit [%s]: Queue req", id(request_id))
                queue_defer = defer.Deferred()
                self.ready_request_queue[request_id] = queue_defer
                return queue_defer
            else:
                return defer.succeed(None)

        logger.debug(
            "Ratelimit [%s]: len(self.request_times)=%d",
            id(request_id), len(self.request_times),
        )

        if len(self.request_times) > self.sleep_limit:
            logger.debug(
                "Ratelimit [%s]: sleeping req",
                id(request_id),
            )
            ret_defer = run_in_background(self.clock.sleep, self.sleep_msec / 1000.0)

            self.sleeping_requests.add(request_id)

            def on_wait_finished(_):
                logger.debug(
                    "Ratelimit [%s]: Finished sleeping",
                    id(request_id),
                )
                self.sleeping_requests.discard(request_id)
                queue_defer = queue_request()
                return queue_defer

            ret_defer.addBoth(on_wait_finished)
        else:
            ret_defer = queue_request()

        def on_start(r):
            logger.debug(
                "Ratelimit [%s]: Processing req",
                id(request_id),
            )
            self.current_processing.add(request_id)
            return r

        def on_err(r):
            # XXX: why is this necessary? this is called before we start
            # processing the request so why would the request be in
            # current_processing?
            self.current_processing.discard(request_id)
            return r

        def on_both(r):
            # Ensure that we've properly cleaned up.
            self.sleeping_requests.discard(request_id)
            self.ready_request_queue.pop(request_id, None)
            return r

        ret_defer.addCallbacks(on_start, on_err)
        ret_defer.addBoth(on_both)
        return make_deferred_yieldable(ret_defer)

    def _on_exit(self, request_id):
        logger.debug(
            "Ratelimit [%s]: Processed req",
            id(request_id),
        )
        self.current_processing.discard(request_id)
        try:
            request_id, deferred = self.ready_request_queue.popitem()

            # XXX: why do we do the following? the on_start callback above will
            # do it for us.
            self.current_processing.add(request_id)

            with PreserveLoggingContext():
                deferred.callback(None)
        except KeyError:
            pass
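

# ---------------------------------------------------------------------------
# Illustrative usage sketch: how a caller might drive FederationRateLimiter,
# following the example in ratelimit()'s docstring. `handle_request` is a
# hypothetical callable returning a Deferred; a real deployment would share a
# single limiter instance across all incoming federation requests.
@defer.inlineCallbacks
def _example_handle_with_ratelimit(rate_limiter, origin, handle_request):
    # Entering the context manager may raise LimitExceededError if too many
    # requests from `origin` are already queued; otherwise the deferred fires
    # immediately, after an artificial sleep, or once a processing slot frees
    # up, depending on how busy `origin` has been within the window.
    with rate_limiter.ratelimit(origin) as wait_deferred:
        yield wait_deferred
        # Do the real work while still inside the context manager, so that
        # _on_exit() runs afterwards and wakes the next queued request.
        result = yield handle_request()
    defer.returnValue(result)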