test_ratelimiting.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
  2. from synapse.appservice import ApplicationService
  3. from synapse.config.ratelimiting import RatelimitSettings
  4. from synapse.types import create_requester
  5. from tests import unittest
  6. class TestRatelimiter(unittest.HomeserverTestCase):
  7. def test_allowed_via_can_do_action(self) -> None:
  8. limiter = Ratelimiter(
  9. store=self.hs.get_datastores().main,
  10. clock=self.clock,
  11. cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
  12. )
  13. allowed, time_allowed = self.get_success_or_raise(
  14. limiter.can_do_action(None, key="test_id", _time_now_s=0)
  15. )
  16. self.assertTrue(allowed)
  17. self.assertEqual(10.0, time_allowed)
  18. allowed, time_allowed = self.get_success_or_raise(
  19. limiter.can_do_action(None, key="test_id", _time_now_s=5)
  20. )
  21. self.assertFalse(allowed)
  22. self.assertEqual(10.0, time_allowed)
  23. allowed, time_allowed = self.get_success_or_raise(
  24. limiter.can_do_action(None, key="test_id", _time_now_s=10)
  25. )
  26. self.assertTrue(allowed)
  27. self.assertEqual(20.0, time_allowed)
  28. def test_allowed_appservice_ratelimited_via_can_requester_do_action(self) -> None:
  29. appservice = ApplicationService(
  30. token="fake_token",
  31. id="foo",
  32. rate_limited=True,
  33. sender="@as:example.com",
  34. )
  35. as_requester = create_requester("@user:example.com", app_service=appservice)
  36. limiter = Ratelimiter(
  37. store=self.hs.get_datastores().main,
  38. clock=self.clock,
  39. cfg=RatelimitSettings(
  40. key="",
  41. per_second=0.1,
  42. burst_count=1,
  43. ),
  44. )
  45. allowed, time_allowed = self.get_success_or_raise(
  46. limiter.can_do_action(as_requester, _time_now_s=0)
  47. )
  48. self.assertTrue(allowed)
  49. self.assertEqual(10.0, time_allowed)
  50. allowed, time_allowed = self.get_success_or_raise(
  51. limiter.can_do_action(as_requester, _time_now_s=5)
  52. )
  53. self.assertFalse(allowed)
  54. self.assertEqual(10.0, time_allowed)
  55. allowed, time_allowed = self.get_success_or_raise(
  56. limiter.can_do_action(as_requester, _time_now_s=10)
  57. )
  58. self.assertTrue(allowed)
  59. self.assertEqual(20.0, time_allowed)
  60. def test_allowed_appservice_via_can_requester_do_action(self) -> None:
  61. appservice = ApplicationService(
  62. token="fake_token",
  63. id="foo",
  64. rate_limited=False,
  65. sender="@as:example.com",
  66. )
  67. as_requester = create_requester("@user:example.com", app_service=appservice)
  68. limiter = Ratelimiter(
  69. store=self.hs.get_datastores().main,
  70. clock=self.clock,
  71. cfg=RatelimitSettings(
  72. key="",
  73. per_second=0.1,
  74. burst_count=1,
  75. ),
  76. )
  77. allowed, time_allowed = self.get_success_or_raise(
  78. limiter.can_do_action(as_requester, _time_now_s=0)
  79. )
  80. self.assertTrue(allowed)
  81. self.assertEqual(-1, time_allowed)
  82. allowed, time_allowed = self.get_success_or_raise(
  83. limiter.can_do_action(as_requester, _time_now_s=5)
  84. )
  85. self.assertTrue(allowed)
  86. self.assertEqual(-1, time_allowed)
  87. allowed, time_allowed = self.get_success_or_raise(
  88. limiter.can_do_action(as_requester, _time_now_s=10)
  89. )
  90. self.assertTrue(allowed)
  91. self.assertEqual(-1, time_allowed)
  92. def test_allowed_via_ratelimit(self) -> None:
  93. limiter = Ratelimiter(
  94. store=self.hs.get_datastores().main,
  95. clock=self.clock,
  96. cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
  97. )
  98. # Shouldn't raise
  99. self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0))
  100. # Should raise
  101. with self.assertRaises(LimitExceededError) as context:
  102. self.get_success_or_raise(
  103. limiter.ratelimit(None, key="test_id", _time_now_s=5)
  104. )
  105. self.assertEqual(context.exception.retry_after_ms, 5000)
  106. # Shouldn't raise
  107. self.get_success_or_raise(
  108. limiter.ratelimit(None, key="test_id", _time_now_s=10)
  109. )
  110. def test_allowed_via_can_do_action_and_overriding_parameters(self) -> None:
  111. """Test that we can override options of can_do_action that would otherwise fail
  112. an action
  113. """
  114. # Create a Ratelimiter with a very low allowed rate_hz and burst_count
  115. limiter = Ratelimiter(
  116. store=self.hs.get_datastores().main,
  117. clock=self.clock,
  118. cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
  119. )
  120. # First attempt should be allowed
  121. allowed, time_allowed = self.get_success_or_raise(
  122. limiter.can_do_action(
  123. None,
  124. ("test_id",),
  125. _time_now_s=0,
  126. )
  127. )
  128. self.assertTrue(allowed)
  129. self.assertEqual(10.0, time_allowed)
  130. # Second attempt, 1s later, will fail
  131. allowed, time_allowed = self.get_success_or_raise(
  132. limiter.can_do_action(
  133. None,
  134. ("test_id",),
  135. _time_now_s=1,
  136. )
  137. )
  138. self.assertFalse(allowed)
  139. self.assertEqual(10.0, time_allowed)
  140. # But, if we allow 10 actions/sec for this request, we should be allowed
  141. # to continue.
  142. allowed, time_allowed = self.get_success_or_raise(
  143. limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0)
  144. )
  145. self.assertTrue(allowed)
  146. self.assertEqual(1.1, time_allowed)
  147. # Similarly if we allow a burst of 10 actions
  148. allowed, time_allowed = self.get_success_or_raise(
  149. limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10)
  150. )
  151. self.assertTrue(allowed)
  152. self.assertEqual(1.0, time_allowed)
  153. def test_allowed_via_ratelimit_and_overriding_parameters(self) -> None:
  154. """Test that we can override options of the ratelimit method that would otherwise
  155. fail an action
  156. """
  157. # Create a Ratelimiter with a very low allowed rate_hz and burst_count
  158. limiter = Ratelimiter(
  159. store=self.hs.get_datastores().main,
  160. clock=self.clock,
  161. cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
  162. )
  163. # First attempt should be allowed
  164. self.get_success_or_raise(
  165. limiter.ratelimit(None, key=("test_id",), _time_now_s=0)
  166. )
  167. # Second attempt, 1s later, will fail
  168. with self.assertRaises(LimitExceededError) as context:
  169. self.get_success_or_raise(
  170. limiter.ratelimit(None, key=("test_id",), _time_now_s=1)
  171. )
  172. self.assertEqual(context.exception.retry_after_ms, 9000)
  173. # But, if we allow 10 actions/sec for this request, we should be allowed
  174. # to continue.
  175. self.get_success_or_raise(
  176. limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0)
  177. )
  178. # Similarly if we allow a burst of 10 actions
  179. self.get_success_or_raise(
  180. limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
  181. )
  182. def test_pruning(self) -> None:
  183. limiter = Ratelimiter(
  184. store=self.hs.get_datastores().main,
  185. clock=self.clock,
  186. cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
  187. )
  188. self.get_success_or_raise(
  189. limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
  190. )
  191. self.assertIn("test_id_1", limiter.actions)
  192. self.get_success_or_raise(
  193. limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
  194. )
  195. self.assertNotIn("test_id_1", limiter.actions)
  196. def test_db_user_override(self) -> None:
  197. """Test that users that have ratelimiting disabled in the DB aren't
  198. ratelimited.
  199. """
  200. store = self.hs.get_datastores().main
  201. user_id = "@user:test"
  202. requester = create_requester(user_id)
  203. self.get_success(
  204. store.db_pool.simple_insert(
  205. table="ratelimit_override",
  206. values={
  207. "user_id": user_id,
  208. "messages_per_second": None,
  209. "burst_count": None,
  210. },
  211. desc="test_db_user_override",
  212. )
  213. )
  214. limiter = Ratelimiter(
  215. store=store,
  216. clock=self.clock,
  217. cfg=RatelimitSettings("", per_second=0.1, burst_count=1),
  218. )
  219. # Shouldn't raise
  220. for _ in range(20):
  221. self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))
  222. def test_multiple_actions(self) -> None:
  223. limiter = Ratelimiter(
  224. store=self.hs.get_datastores().main,
  225. clock=self.clock,
  226. cfg=RatelimitSettings(
  227. key="",
  228. per_second=0.1,
  229. burst_count=3,
  230. ),
  231. )
  232. # Test that 4 actions aren't allowed with a maximum burst of 3.
  233. allowed, time_allowed = self.get_success_or_raise(
  234. limiter.can_do_action(None, key="test_id", n_actions=4, _time_now_s=0)
  235. )
  236. self.assertFalse(allowed)
  237. # Test that 3 actions are allowed with a maximum burst of 3.
  238. allowed, time_allowed = self.get_success_or_raise(
  239. limiter.can_do_action(None, key="test_id", n_actions=3, _time_now_s=0)
  240. )
  241. self.assertTrue(allowed)
  242. self.assertEqual(10.0, time_allowed)
  243. # Test that, after doing these 3 actions, we can't do any more actions without
  244. # waiting.
  245. allowed, time_allowed = self.get_success_or_raise(
  246. limiter.can_do_action(None, key="test_id", n_actions=1, _time_now_s=0)
  247. )
  248. self.assertFalse(allowed)
  249. self.assertEqual(10.0, time_allowed)
  250. # Test that after waiting we would be able to do only 1 action.
  251. # Note that we don't actually do it (update=False) here.
  252. allowed, time_allowed = self.get_success_or_raise(
  253. limiter.can_do_action(
  254. None,
  255. key="test_id",
  256. update=False,
  257. n_actions=1,
  258. _time_now_s=10,
  259. )
  260. )
  261. self.assertTrue(allowed)
  262. # We would be able to do the 5th action at t=20.
  263. self.assertEqual(20.0, time_allowed)
  264. # Attempt (but fail) to perform TWO actions at t=10.
  265. # Those would be the 4th and 5th actions.
  266. allowed, time_allowed = self.get_success_or_raise(
  267. limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=10)
  268. )
  269. self.assertFalse(allowed)
  270. # The returned time allowed for the next action is now even though we weren't
  271. # allowed to perform the action because whilst we don't allow 2 actions,
  272. # we could still do 1.
  273. self.assertEqual(10.0, time_allowed)
  274. # Test that after waiting until t=20, we can do perform 2 actions.
  275. # These are the 4th and 5th actions.
  276. allowed, time_allowed = self.get_success_or_raise(
  277. limiter.can_do_action(None, key="test_id", n_actions=2, _time_now_s=20)
  278. )
  279. self.assertTrue(allowed)
  280. # We would be able to do the 6th action at t=30.
  281. self.assertEqual(30.0, time_allowed)
  282. def test_rate_limit_burst_only_given_once(self) -> None:
  283. """
  284. Regression test against a bug that meant that you could build up
  285. extra tokens by timing requests.
  286. """
  287. limiter = Ratelimiter(
  288. store=self.hs.get_datastores().main,
  289. clock=self.clock,
  290. cfg=RatelimitSettings("", per_second=0.1, burst_count=3),
  291. )
  292. def consume_at(time: float) -> bool:
  293. success, _ = self.get_success_or_raise(
  294. limiter.can_do_action(requester=None, key="a", _time_now_s=time)
  295. )
  296. return success
  297. # Use all our 3 burst tokens
  298. self.assertTrue(consume_at(0.0))
  299. self.assertTrue(consume_at(0.1))
  300. self.assertTrue(consume_at(0.2))
  301. # Wait to recover 1 token (10 seconds at 0.1 Hz).
  302. self.assertTrue(consume_at(10.1))
  303. # Check that we get rate limited after using that token.
  304. self.assertFalse(consume_at(11.1))
  305. def test_record_action_which_doesnt_fill_bucket(self) -> None:
  306. limiter = Ratelimiter(
  307. store=self.hs.get_datastores().main,
  308. clock=self.clock,
  309. cfg=RatelimitSettings(
  310. "",
  311. per_second=0.1,
  312. burst_count=3,
  313. ),
  314. )
  315. # Observe two actions, leaving room in the bucket for one more.
  316. limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0)
  317. # We should be able to take a new action now.
  318. success, _ = self.get_success_or_raise(
  319. limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
  320. )
  321. self.assertTrue(success)
  322. # ... but not two.
  323. success, _ = self.get_success_or_raise(
  324. limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
  325. )
  326. self.assertFalse(success)
  327. def test_record_action_which_fills_bucket(self) -> None:
  328. limiter = Ratelimiter(
  329. store=self.hs.get_datastores().main,
  330. clock=self.clock,
  331. cfg=RatelimitSettings(
  332. "",
  333. per_second=0.1,
  334. burst_count=3,
  335. ),
  336. )
  337. # Observe three actions, filling up the bucket.
  338. limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0)
  339. # We should be unable to take a new action now.
  340. success, _ = self.get_success_or_raise(
  341. limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
  342. )
  343. self.assertFalse(success)
  344. # If we wait 10 seconds to leak a token, we should be able to take one action...
  345. success, _ = self.get_success_or_raise(
  346. limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
  347. )
  348. self.assertTrue(success)
  349. # ... but not two.
  350. success, _ = self.get_success_or_raise(
  351. limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
  352. )
  353. self.assertFalse(success)
  354. def test_record_action_which_overfills_bucket(self) -> None:
  355. limiter = Ratelimiter(
  356. store=self.hs.get_datastores().main,
  357. clock=self.clock,
  358. cfg=RatelimitSettings(
  359. "",
  360. per_second=0.1,
  361. burst_count=3,
  362. ),
  363. )
  364. # Observe four actions, exceeding the bucket.
  365. limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0)
  366. # We should be prevented from taking a new action now.
  367. success, _ = self.get_success_or_raise(
  368. limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
  369. )
  370. self.assertFalse(success)
  371. # If we wait 10 seconds to leak a token, we should be unable to take an action
  372. # because the bucket is still full.
  373. success, _ = self.get_success_or_raise(
  374. limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
  375. )
  376. self.assertFalse(success)
  377. # But after another 10 seconds we leak a second token, giving us room for
  378. # action.
  379. success, _ = self.get_success_or_raise(
  380. limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
  381. )
  382. self.assertTrue(success)