test_ratelimiting.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
  2. from synapse.appservice import ApplicationService
  3. from synapse.types import create_requester
  4. from tests import unittest
  5. class TestRatelimiter(unittest.HomeserverTestCase):
  6. def test_allowed_via_can_do_action(self) -> None:
  7. limiter = Ratelimiter(
  8. store=self.hs.get_datastores().main,
  9. clock=self.clock,
  10. rate_hz=0.1,
  11. burst_count=1,
  12. )
  13. allowed, time_allowed = self.get_success_or_raise(
  14. limiter.can_do_action(None, key="test_id", _time_now_s=0)
  15. )
  16. self.assertTrue(allowed)
  17. self.assertEqual(10.0, time_allowed)
  18. allowed, time_allowed = self.get_success_or_raise(
  19. limiter.can_do_action(None, key="test_id", _time_now_s=5)
  20. )
  21. self.assertFalse(allowed)
  22. self.assertEqual(10.0, time_allowed)
  23. allowed, time_allowed = self.get_success_or_raise(
  24. limiter.can_do_action(None, key="test_id", _time_now_s=10)
  25. )
  26. self.assertTrue(allowed)
  27. self.assertEqual(20.0, time_allowed)
  28. def test_allowed_appservice_ratelimited_via_can_requester_do_action(self) -> None:
  29. appservice = ApplicationService(
  30. token="fake_token",
  31. id="foo",
  32. rate_limited=True,
  33. sender="@as:example.com",
  34. )
  35. as_requester = create_requester("@user:example.com", app_service=appservice)
  36. limiter = Ratelimiter(
  37. store=self.hs.get_datastores().main,
  38. clock=self.clock,
  39. rate_hz=0.1,
  40. burst_count=1,
  41. )
  42. allowed, time_allowed = self.get_success_or_raise(
  43. limiter.can_do_action(as_requester, _time_now_s=0)
  44. )
  45. self.assertTrue(allowed)
  46. self.assertEqual(10.0, time_allowed)
  47. allowed, time_allowed = self.get_success_or_raise(
  48. limiter.can_do_action(as_requester, _time_now_s=5)
  49. )
  50. self.assertFalse(allowed)
  51. self.assertEqual(10.0, time_allowed)
  52. allowed, time_allowed = self.get_success_or_raise(
  53. limiter.can_do_action(as_requester, _time_now_s=10)
  54. )
  55. self.assertTrue(allowed)
  56. self.assertEqual(20.0, time_allowed)
  57. def test_allowed_appservice_via_can_requester_do_action(self) -> None:
  58. appservice = ApplicationService(
  59. token="fake_token",
  60. id="foo",
  61. rate_limited=False,
  62. sender="@as:example.com",
  63. )
  64. as_requester = create_requester("@user:example.com", app_service=appservice)
  65. limiter = Ratelimiter(
  66. store=self.hs.get_datastores().main,
  67. clock=self.clock,
  68. rate_hz=0.1,
  69. burst_count=1,
  70. )
  71. allowed, time_allowed = self.get_success_or_raise(
  72. limiter.can_do_action(as_requester, _time_now_s=0)
  73. )
  74. self.assertTrue(allowed)
  75. self.assertEqual(-1, time_allowed)
  76. allowed, time_allowed = self.get_success_or_raise(
  77. limiter.can_do_action(as_requester, _time_now_s=5)
  78. )
  79. self.assertTrue(allowed)
  80. self.assertEqual(-1, time_allowed)
  81. allowed, time_allowed = self.get_success_or_raise(
  82. limiter.can_do_action(as_requester, _time_now_s=10)
  83. )
  84. self.assertTrue(allowed)
  85. self.assertEqual(-1, time_allowed)
  86. def test_allowed_via_ratelimit(self) -> None:
  87. limiter = Ratelimiter(
  88. store=self.hs.get_datastores().main,
  89. clock=self.clock,
  90. rate_hz=0.1,
  91. burst_count=1,
  92. )
  93. # Shouldn't raise
  94. self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0))
  95. # Should raise
  96. with self.assertRaises(LimitExceededError) as context:
  97. self.get_success_or_raise(
  98. limiter.ratelimit(None, key="test_id", _time_now_s=5)
  99. )
  100. self.assertEqual(context.exception.retry_after_ms, 5000)
  101. # Shouldn't raise
  102. self.get_success_or_raise(
  103. limiter.ratelimit(None, key="test_id", _time_now_s=10)
  104. )
  105. def test_allowed_via_can_do_action_and_overriding_parameters(self) -> None:
  106. """Test that we can override options of can_do_action that would otherwise fail
  107. an action
  108. """
  109. # Create a Ratelimiter with a very low allowed rate_hz and burst_count
  110. limiter = Ratelimiter(
  111. store=self.hs.get_datastores().main,
  112. clock=self.clock,
  113. rate_hz=0.1,
  114. burst_count=1,
  115. )
  116. # First attempt should be allowed
  117. allowed, time_allowed = self.get_success_or_raise(
  118. limiter.can_do_action(
  119. None,
  120. ("test_id",),
  121. _time_now_s=0,
  122. )
  123. )
  124. self.assertTrue(allowed)
  125. self.assertEqual(10.0, time_allowed)
  126. # Second attempt, 1s later, will fail
  127. allowed, time_allowed = self.get_success_or_raise(
  128. limiter.can_do_action(
  129. None,
  130. ("test_id",),
  131. _time_now_s=1,
  132. )
  133. )
  134. self.assertFalse(allowed)
  135. self.assertEqual(10.0, time_allowed)
  136. # But, if we allow 10 actions/sec for this request, we should be allowed
  137. # to continue.
  138. allowed, time_allowed = self.get_success_or_raise(
  139. limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0)
  140. )
  141. self.assertTrue(allowed)
  142. self.assertEqual(1.1, time_allowed)
  143. # Similarly if we allow a burst of 10 actions
  144. allowed, time_allowed = self.get_success_or_raise(
  145. limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10)
  146. )
  147. self.assertTrue(allowed)
  148. self.assertEqual(1.0, time_allowed)
  149. def test_allowed_via_ratelimit_and_overriding_parameters(self) -> None:
  150. """Test that we can override options of the ratelimit method that would otherwise
  151. fail an action
  152. """
  153. # Create a Ratelimiter with a very low allowed rate_hz and burst_count
  154. limiter = Ratelimiter(
  155. store=self.hs.get_datastores().main,
  156. clock=self.clock,
  157. rate_hz=0.1,
  158. burst_count=1,
  159. )
  160. # First attempt should be allowed
  161. self.get_success_or_raise(
  162. limiter.ratelimit(None, key=("test_id",), _time_now_s=0)
  163. )
  164. # Second attempt, 1s later, will fail
  165. with self.assertRaises(LimitExceededError) as context:
  166. self.get_success_or_raise(
  167. limiter.ratelimit(None, key=("test_id",), _time_now_s=1)
  168. )
  169. self.assertEqual(context.exception.retry_after_ms, 9000)
  170. # But, if we allow 10 actions/sec for this request, we should be allowed
  171. # to continue.
  172. self.get_success_or_raise(
  173. limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0)
  174. )
  175. # Similarly if we allow a burst of 10 actions
  176. self.get_success_or_raise(
  177. limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
  178. )
  179. def test_pruning(self) -> None:
  180. limiter = Ratelimiter(
  181. store=self.hs.get_datastores().main,
  182. clock=self.clock,
  183. rate_hz=0.1,
  184. burst_count=1,
  185. )
  186. self.get_success_or_raise(
  187. limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
  188. )
  189. self.assertIn("test_id_1", limiter.actions)
  190. self.get_success_or_raise(
  191. limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
  192. )
  193. self.assertNotIn("test_id_1", limiter.actions)
  194. def test_db_user_override(self) -> None:
  195. """Test that users that have ratelimiting disabled in the DB aren't
  196. ratelimited.
  197. """
  198. store = self.hs.get_datastores().main
  199. user_id = "@user:test"
  200. requester = create_requester(user_id)
  201. self.get_success(
  202. store.db_pool.simple_insert(
  203. table="ratelimit_override",
  204. values={
  205. "user_id": user_id,
  206. "messages_per_second": None,
  207. "burst_count": None,
  208. },
  209. desc="test_db_user_override",
  210. )
  211. )
  212. limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1)
  213. # Shouldn't raise
  214. for _ in range(20):
  215. self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
  216. def test_multiple_actions(self) -> None:
  217. limiter = Ratelimiter(
  218. store=self.hs.get_datastores().main,
  219. clock=self.clock,
  220. rate_hz=0.1,
  221. burst_count=3,
  222. )
  223. # Test that 4 actions aren't allowed with a maximum burst of 3.
  224. allowed, time_allowed = self.get_success_or_raise(
  225. limiter.can_do_action(None, key="test_id", n_actions=4, _time_now_s=0)
  226. )
  227. self.assertFalse(allowed)
  228. # Test that 3 actions are allowed with a maximum burst of 3.
  229. allowed, time_allowed = self.get_success_or_raise(
  230. limiter.can_do_action(None, key="test_id", n_actions=3, _time_now_s=0)
  231. )
  232. self.assertTrue(allowed)
  233. self.assertEqual(10.0, time_allowed)
  234. # Test that, after doing these 3 actions, we can't do any more actions without
  235. # waiting.
  236. allowed, time_allowed = self.get_success_or_raise(
  237. limiter.can_do_action(None, key="test_id", n_actions=1, _time_now_s=0)
  238. )
  239. self.assertFalse(allowed)
  240. self.assertEqual(10.0, time_allowed)
  241. # Test that after waiting we would be able to do only 1 action.
  242. # Note that we don't actually do it (update=False) here.
  243. allowed, time_allowed = self.get_success_or_raise(
  244. limiter.can_do_action(
  245. None,
  246. key="test_id",
  247. update=False,
  248. n_actions=1,
  249. _time_now_s=10,
  250. )
  251. )
  252. self.assertTrue(allowed)
  253. # We would be able to do the 5th action at t=20.
  254. self.assertEqual(20.0, time_allowed)
  255. # Attempt (but fail) to perform TWO actions at t=10.
  256. # Those would be the 4th and 5th actions.
  257. allowed, time_allowed = self.get_success_or_raise(
  258. limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=10)
  259. )
  260. self.assertFalse(allowed)
  261. # The returned time allowed for the next action is now even though we weren't
  262. # allowed to perform the action because whilst we don't allow 2 actions,
  263. # we could still do 1.
  264. self.assertEqual(10.0, time_allowed)
  265. # Test that after waiting until t=20, we can do perform 2 actions.
  266. # These are the 4th and 5th actions.
  267. allowed, time_allowed = self.get_success_or_raise(
  268. limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=20)
  269. )
  270. self.assertTrue(allowed)
  271. # We would be able to do the 6th action at t=30.
  272. self.assertEqual(30.0, time_allowed)
  273. def test_rate_limit_burst_only_given_once(self) -> None:
  274. """
  275. Regression test against a bug that meant that you could build up
  276. extra tokens by timing requests.
  277. """
  278. limiter = Ratelimiter(
  279. store=self.hs.get_datastores().main,
  280. clock=self.clock,
  281. rate_hz=0.1,
  282. burst_count=3,
  283. )
  284. def consume_at(time: float) -> bool:
  285. success, _ = self.get_success_or_raise(
  286. limiter.can_do_action(requester=None, key="a", _time_now_s=time)
  287. )
  288. return success
  289. # Use all our 3 burst tokens
  290. self.assertTrue(consume_at(0.0))
  291. self.assertTrue(consume_at(0.1))
  292. self.assertTrue(consume_at(0.2))
  293. # Wait to recover 1 token (10 seconds at 0.1 Hz).
  294. self.assertTrue(consume_at(10.1))
  295. # Check that we get rate limited after using that token.
  296. self.assertFalse(consume_at(11.1))
  297. def test_record_action_which_doesnt_fill_bucket(self) -> None:
  298. limiter = Ratelimiter(
  299. store=self.hs.get_datastores().main,
  300. clock=self.clock,
  301. rate_hz=0.1,
  302. burst_count=3,
  303. )
  304. # Observe two actions, leaving room in the bucket for one more.
  305. limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0)
  306. # We should be able to take a new action now.
  307. success, _ = self.get_success_or_raise(
  308. limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
  309. )
  310. self.assertTrue(success)
  311. # ... but not two.
  312. success, _ = self.get_success_or_raise(
  313. limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
  314. )
  315. self.assertFalse(success)
  316. def test_record_action_which_fills_bucket(self) -> None:
  317. limiter = Ratelimiter(
  318. store=self.hs.get_datastores().main,
  319. clock=self.clock,
  320. rate_hz=0.1,
  321. burst_count=3,
  322. )
  323. # Observe three actions, filling up the bucket.
  324. limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0)
  325. # We should be unable to take a new action now.
  326. success, _ = self.get_success_or_raise(
  327. limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
  328. )
  329. self.assertFalse(success)
  330. # If we wait 10 seconds to leak a token, we should be able to take one action...
  331. success, _ = self.get_success_or_raise(
  332. limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
  333. )
  334. self.assertTrue(success)
  335. # ... but not two.
  336. success, _ = self.get_success_or_raise(
  337. limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
  338. )
  339. self.assertFalse(success)
  340. def test_record_action_which_overfills_bucket(self) -> None:
  341. limiter = Ratelimiter(
  342. store=self.hs.get_datastores().main,
  343. clock=self.clock,
  344. rate_hz=0.1,
  345. burst_count=3,
  346. )
  347. # Observe four actions, exceeding the bucket.
  348. limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0)
  349. # We should be prevented from taking a new action now.
  350. success, _ = self.get_success_or_raise(
  351. limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
  352. )
  353. self.assertFalse(success)
  354. # If we wait 10 seconds to leak a token, we should be unable to take an action
  355. # because the bucket is still full.
  356. success, _ = self.get_success_or_raise(
  357. limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
  358. )
  359. self.assertFalse(success)
  360. # But after another 10 seconds we leak a second token, giving us room for
  361. # action.
  362. success, _ = self.get_success_or_raise(
  363. limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
  364. )
  365. self.assertTrue(success)