|
@@ -169,7 +169,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
# which sets these values to 10000, but as we're overriding the entire
|
|
|
# rc_login dict here, we need to set this manually as well
|
|
|
"account": {"per_second": 10000, "burst_count": 10000},
|
|
|
- }
|
|
|
+ },
|
|
|
+ "experimental_features": {"msc4041_enabled": True},
|
|
|
}
|
|
|
)
|
|
|
def test_POST_ratelimiting_per_address(self) -> None:
|
|
@@ -189,12 +190,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
if i == 5:
|
|
|
self.assertEqual(channel.code, 429, msg=channel.result)
|
|
|
retry_after_ms = int(channel.json_body["retry_after_ms"])
|
|
|
+ retry_header = channel.headers.getRawHeaders("Retry-After")
|
|
|
else:
|
|
|
self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
|
|
|
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
|
|
|
# than 1min.
|
|
|
- self.assertTrue(retry_after_ms < 6000)
|
|
|
+ self.assertLess(retry_after_ms, 6000)
|
|
|
+ assert retry_header
|
|
|
+ self.assertLessEqual(int(retry_header[0]), 6)
|
|
|
|
|
|
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
|
|
|
|
|
@@ -217,7 +221,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
# which sets these values to 10000, but as we're overriding the entire
|
|
|
# rc_login dict here, we need to set this manually as well
|
|
|
"address": {"per_second": 10000, "burst_count": 10000},
|
|
|
- }
|
|
|
+ },
|
|
|
+ "experimental_features": {"msc4041_enabled": True},
|
|
|
}
|
|
|
)
|
|
|
def test_POST_ratelimiting_per_account(self) -> None:
|
|
@@ -234,12 +239,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
if i == 5:
|
|
|
self.assertEqual(channel.code, 429, msg=channel.result)
|
|
|
retry_after_ms = int(channel.json_body["retry_after_ms"])
|
|
|
+ retry_header = channel.headers.getRawHeaders("Retry-After")
|
|
|
else:
|
|
|
self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
|
|
|
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
|
|
|
# than 1min.
|
|
|
- self.assertTrue(retry_after_ms < 6000)
|
|
|
+ self.assertLess(retry_after_ms, 6000)
|
|
|
+ assert retry_header
|
|
|
+ self.assertLessEqual(int(retry_header[0]), 6)
|
|
|
|
|
|
self.reactor.advance(retry_after_ms / 1000.0)
|
|
|
|
|
@@ -262,7 +270,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
# rc_login dict here, we need to set this manually as well
|
|
|
"address": {"per_second": 10000, "burst_count": 10000},
|
|
|
"failed_attempts": {"per_second": 0.17, "burst_count": 5},
|
|
|
- }
|
|
|
+ },
|
|
|
+ "experimental_features": {"msc4041_enabled": True},
|
|
|
}
|
|
|
)
|
|
|
def test_POST_ratelimiting_per_account_failed_attempts(self) -> None:
|
|
@@ -279,12 +288,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
if i == 5:
|
|
|
self.assertEqual(channel.code, 429, msg=channel.result)
|
|
|
retry_after_ms = int(channel.json_body["retry_after_ms"])
|
|
|
+ retry_header = channel.headers.getRawHeaders("Retry-After")
|
|
|
else:
|
|
|
self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
|
|
|
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
|
|
|
# than 1min.
|
|
|
- self.assertTrue(retry_after_ms < 6000)
|
|
|
+ self.assertLess(retry_after_ms, 6000)
|
|
|
+ assert retry_header
|
|
|
+ self.assertLessEqual(int(retry_header[0]), 6)
|
|
|
|
|
|
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
|
|
|
|