ratelimitutils.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # Copyright 2015, 2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections
  15. import contextlib
  16. import logging
  17. import typing
  18. from typing import Any, DefaultDict, Iterator, List, Set
  19. from twisted.internet import defer
  20. from synapse.api.errors import LimitExceededError
  21. from synapse.config.ratelimiting import FederationRateLimitConfig
  22. from synapse.logging.context import (
  23. PreserveLoggingContext,
  24. make_deferred_yieldable,
  25. run_in_background,
  26. )
  27. from synapse.util import Clock
  28. if typing.TYPE_CHECKING:
  29. from contextlib import _GeneratorContextManager
  30. logger = logging.getLogger(__name__)
  31. class FederationRateLimiter:
  32. def __init__(self, clock: Clock, config: FederationRateLimitConfig):
  33. def new_limiter() -> "_PerHostRatelimiter":
  34. return _PerHostRatelimiter(clock=clock, config=config)
  35. self.ratelimiters: DefaultDict[
  36. str, "_PerHostRatelimiter"
  37. ] = collections.defaultdict(new_limiter)
  38. def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
  39. """Used to ratelimit an incoming request from a given host
  40. Example usage:
  41. with rate_limiter.ratelimit(origin) as wait_deferred:
  42. yield wait_deferred
  43. # Handle request ...
  44. Args:
  45. host (str): Origin of incoming request.
  46. Returns:
  47. context manager which returns a deferred.
  48. """
  49. return self.ratelimiters[host].ratelimit()
  50. class _PerHostRatelimiter:
  51. def __init__(self, clock: Clock, config: FederationRateLimitConfig):
  52. """
  53. Args:
  54. clock
  55. config
  56. """
  57. self.clock = clock
  58. self.window_size = config.window_size
  59. self.sleep_limit = config.sleep_limit
  60. self.sleep_sec = config.sleep_delay / 1000.0
  61. self.reject_limit = config.reject_limit
  62. self.concurrent_requests = config.concurrent
  63. # request_id objects for requests which have been slept
  64. self.sleeping_requests: Set[object] = set()
  65. # map from request_id object to Deferred for requests which are ready
  66. # for processing but have been queued
  67. self.ready_request_queue: collections.OrderedDict[
  68. object, defer.Deferred[None]
  69. ] = collections.OrderedDict()
  70. # request id objects for requests which are in progress
  71. self.current_processing: Set[object] = set()
  72. # times at which we have recently (within the last window_size ms)
  73. # received requests.
  74. self.request_times: List[int] = []
  75. @contextlib.contextmanager
  76. def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
  77. # `contextlib.contextmanager` takes a generator and turns it into a
  78. # context manager. The generator should only yield once with a value
  79. # to be returned by manager.
  80. # Exceptions will be reraised at the yield.
  81. request_id = object()
  82. ret = self._on_enter(request_id)
  83. try:
  84. yield ret
  85. finally:
  86. self._on_exit(request_id)
  87. def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
  88. time_now = self.clock.time_msec()
  89. # remove any entries from request_times which aren't within the window
  90. self.request_times[:] = [
  91. r for r in self.request_times if time_now - r < self.window_size
  92. ]
  93. # reject the request if we already have too many queued up (either
  94. # sleeping or in the ready queue).
  95. queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
  96. if queue_size > self.reject_limit:
  97. raise LimitExceededError(
  98. retry_after_ms=int(self.window_size / self.sleep_limit)
  99. )
  100. self.request_times.append(time_now)
  101. def queue_request() -> "defer.Deferred[None]":
  102. if len(self.current_processing) >= self.concurrent_requests:
  103. queue_defer: defer.Deferred[None] = defer.Deferred()
  104. self.ready_request_queue[request_id] = queue_defer
  105. logger.info(
  106. "Ratelimiter: queueing request (queue now %i items)",
  107. len(self.ready_request_queue),
  108. )
  109. return queue_defer
  110. else:
  111. return defer.succeed(None)
  112. logger.debug(
  113. "Ratelimit [%s]: len(self.request_times)=%d",
  114. id(request_id),
  115. len(self.request_times),
  116. )
  117. if len(self.request_times) > self.sleep_limit:
  118. logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
  119. ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
  120. self.sleeping_requests.add(request_id)
  121. def on_wait_finished(_: Any) -> "defer.Deferred[None]":
  122. logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
  123. self.sleeping_requests.discard(request_id)
  124. queue_defer = queue_request()
  125. return queue_defer
  126. ret_defer.addBoth(on_wait_finished)
  127. else:
  128. ret_defer = queue_request()
  129. def on_start(r: object) -> object:
  130. logger.debug("Ratelimit [%s]: Processing req", id(request_id))
  131. self.current_processing.add(request_id)
  132. return r
  133. def on_err(r: object) -> object:
  134. # XXX: why is this necessary? this is called before we start
  135. # processing the request so why would the request be in
  136. # current_processing?
  137. self.current_processing.discard(request_id)
  138. return r
  139. def on_both(r: object) -> object:
  140. # Ensure that we've properly cleaned up.
  141. self.sleeping_requests.discard(request_id)
  142. self.ready_request_queue.pop(request_id, None)
  143. return r
  144. ret_defer.addCallbacks(on_start, on_err)
  145. ret_defer.addBoth(on_both)
  146. return make_deferred_yieldable(ret_defer)
  147. def _on_exit(self, request_id: object) -> None:
  148. logger.debug("Ratelimit [%s]: Processed req", id(request_id))
  149. self.current_processing.discard(request_id)
  150. try:
  151. # start processing the next item on the queue.
  152. _, deferred = self.ready_request_queue.popitem(last=False)
  153. with PreserveLoggingContext():
  154. deferred.callback(None)
  155. except KeyError:
  156. pass