test_ratelimiting.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
  2. from synapse.appservice import ApplicationService
  3. from synapse.types import create_requester
  4. from tests import unittest
  5. class TestRatelimiter(unittest.HomeserverTestCase):
  6. def test_allowed_via_can_do_action(self):
  7. limiter = Ratelimiter(
  8. store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
  9. )
  10. allowed, time_allowed = self.get_success_or_raise(
  11. limiter.can_do_action(None, key="test_id", _time_now_s=0)
  12. )
  13. self.assertTrue(allowed)
  14. self.assertEquals(10.0, time_allowed)
  15. allowed, time_allowed = self.get_success_or_raise(
  16. limiter.can_do_action(None, key="test_id", _time_now_s=5)
  17. )
  18. self.assertFalse(allowed)
  19. self.assertEquals(10.0, time_allowed)
  20. allowed, time_allowed = self.get_success_or_raise(
  21. limiter.can_do_action(None, key="test_id", _time_now_s=10)
  22. )
  23. self.assertTrue(allowed)
  24. self.assertEquals(20.0, time_allowed)
  25. def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
  26. appservice = ApplicationService(
  27. None,
  28. "example.com",
  29. id="foo",
  30. rate_limited=True,
  31. sender="@as:example.com",
  32. )
  33. as_requester = create_requester("@user:example.com", app_service=appservice)
  34. limiter = Ratelimiter(
  35. store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
  36. )
  37. allowed, time_allowed = self.get_success_or_raise(
  38. limiter.can_do_action(as_requester, _time_now_s=0)
  39. )
  40. self.assertTrue(allowed)
  41. self.assertEquals(10.0, time_allowed)
  42. allowed, time_allowed = self.get_success_or_raise(
  43. limiter.can_do_action(as_requester, _time_now_s=5)
  44. )
  45. self.assertFalse(allowed)
  46. self.assertEquals(10.0, time_allowed)
  47. allowed, time_allowed = self.get_success_or_raise(
  48. limiter.can_do_action(as_requester, _time_now_s=10)
  49. )
  50. self.assertTrue(allowed)
  51. self.assertEquals(20.0, time_allowed)
  52. def test_allowed_appservice_via_can_requester_do_action(self):
  53. appservice = ApplicationService(
  54. None,
  55. "example.com",
  56. id="foo",
  57. rate_limited=False,
  58. sender="@as:example.com",
  59. )
  60. as_requester = create_requester("@user:example.com", app_service=appservice)
  61. limiter = Ratelimiter(
  62. store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
  63. )
  64. allowed, time_allowed = self.get_success_or_raise(
  65. limiter.can_do_action(as_requester, _time_now_s=0)
  66. )
  67. self.assertTrue(allowed)
  68. self.assertEquals(-1, time_allowed)
  69. allowed, time_allowed = self.get_success_or_raise(
  70. limiter.can_do_action(as_requester, _time_now_s=5)
  71. )
  72. self.assertTrue(allowed)
  73. self.assertEquals(-1, time_allowed)
  74. allowed, time_allowed = self.get_success_or_raise(
  75. limiter.can_do_action(as_requester, _time_now_s=10)
  76. )
  77. self.assertTrue(allowed)
  78. self.assertEquals(-1, time_allowed)
  79. def test_allowed_via_ratelimit(self):
  80. limiter = Ratelimiter(
  81. store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
  82. )
  83. # Shouldn't raise
  84. self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0))
  85. # Should raise
  86. with self.assertRaises(LimitExceededError) as context:
  87. self.get_success_or_raise(
  88. limiter.ratelimit(None, key="test_id", _time_now_s=5)
  89. )
  90. self.assertEqual(context.exception.retry_after_ms, 5000)
  91. # Shouldn't raise
  92. self.get_success_or_raise(
  93. limiter.ratelimit(None, key="test_id", _time_now_s=10)
  94. )
  95. def test_allowed_via_can_do_action_and_overriding_parameters(self):
  96. """Test that we can override options of can_do_action that would otherwise fail
  97. an action
  98. """
  99. # Create a Ratelimiter with a very low allowed rate_hz and burst_count
  100. limiter = Ratelimiter(
  101. store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
  102. )
  103. # First attempt should be allowed
  104. allowed, time_allowed = self.get_success_or_raise(
  105. limiter.can_do_action(
  106. None,
  107. ("test_id",),
  108. _time_now_s=0,
  109. )
  110. )
  111. self.assertTrue(allowed)
  112. self.assertEqual(10.0, time_allowed)
  113. # Second attempt, 1s later, will fail
  114. allowed, time_allowed = self.get_success_or_raise(
  115. limiter.can_do_action(
  116. None,
  117. ("test_id",),
  118. _time_now_s=1,
  119. )
  120. )
  121. self.assertFalse(allowed)
  122. self.assertEqual(10.0, time_allowed)
  123. # But, if we allow 10 actions/sec for this request, we should be allowed
  124. # to continue.
  125. allowed, time_allowed = self.get_success_or_raise(
  126. limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0)
  127. )
  128. self.assertTrue(allowed)
  129. self.assertEqual(1.1, time_allowed)
  130. # Similarly if we allow a burst of 10 actions
  131. allowed, time_allowed = self.get_success_or_raise(
  132. limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10)
  133. )
  134. self.assertTrue(allowed)
  135. self.assertEqual(1.0, time_allowed)
  136. def test_allowed_via_ratelimit_and_overriding_parameters(self):
  137. """Test that we can override options of the ratelimit method that would otherwise
  138. fail an action
  139. """
  140. # Create a Ratelimiter with a very low allowed rate_hz and burst_count
  141. limiter = Ratelimiter(
  142. store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
  143. )
  144. # First attempt should be allowed
  145. self.get_success_or_raise(
  146. limiter.ratelimit(None, key=("test_id",), _time_now_s=0)
  147. )
  148. # Second attempt, 1s later, will fail
  149. with self.assertRaises(LimitExceededError) as context:
  150. self.get_success_or_raise(
  151. limiter.ratelimit(None, key=("test_id",), _time_now_s=1)
  152. )
  153. self.assertEqual(context.exception.retry_after_ms, 9000)
  154. # But, if we allow 10 actions/sec for this request, we should be allowed
  155. # to continue.
  156. self.get_success_or_raise(
  157. limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0)
  158. )
  159. # Similarly if we allow a burst of 10 actions
  160. self.get_success_or_raise(
  161. limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
  162. )
  163. def test_pruning(self):
  164. limiter = Ratelimiter(
  165. store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
  166. )
  167. self.get_success_or_raise(
  168. limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
  169. )
  170. self.assertIn("test_id_1", limiter.actions)
  171. self.get_success_or_raise(
  172. limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
  173. )
  174. self.assertNotIn("test_id_1", limiter.actions)
  175. def test_db_user_override(self):
  176. """Test that users that have ratelimiting disabled in the DB aren't
  177. ratelimited.
  178. """
  179. store = self.hs.get_datastore()
  180. user_id = "@user:test"
  181. requester = create_requester(user_id)
  182. self.get_success(
  183. store.db_pool.simple_insert(
  184. table="ratelimit_override",
  185. values={
  186. "user_id": user_id,
  187. "messages_per_second": None,
  188. "burst_count": None,
  189. },
  190. desc="test_db_user_override",
  191. )
  192. )
  193. limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)
  194. # Shouldn't raise
  195. for _ in range(20):
  196. self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))