Browse Source

Add ratelimiting on login (#4821)

Add two ratelimiters on login (per-IP address and per-userID).
Brendan Abolivier 5 years ago
parent
commit
899e523d6d

+ 1 - 0
changelog.d/4821.feature

@@ -0,0 +1 @@
+Add configurable rate limiting to the /login endpoint.

+ 28 - 11
docs/sample_config.yaml

@@ -379,6 +379,34 @@ rc_messages_per_second: 0.2
 #
 rc_message_burst_count: 10.0
 
+# Ratelimiting settings for registration and login.
+#
+# Each ratelimiting configuration is made of two parameters:
+#   - per_second: number of requests a client can send per second.
+#   - burst_count: number of requests a client can send before being throttled.
+#
+# Synapse currently uses the following configurations:
+#   - one for registration that ratelimits registration requests based on the
+#     client's IP address.
+#   - one for login that ratelimits login requests based on the client's IP
+#     address.
+#   - one for login that ratelimits login requests based on the account the
+#     client is attempting to log into.
+#
+# The defaults are as shown below.
+#
+#rc_registration:
+#  per_second: 0.17
+#  burst_count: 3
+#
+#rc_login:
+#  address:
+#    per_second: 0.17
+#    burst_count: 3
+#  account:
+#    per_second: 0.17
+#    burst_count: 3
+
 # The federation window size in milliseconds
 #
 federation_rc_window_size: 1000
@@ -403,17 +431,6 @@ federation_rc_reject_limit: 50
 #
 federation_rc_concurrent: 3
 
-# Number of registration requests a client can send per second.
-# Defaults to 1/minute (0.17).
-#
-#rc_registration_requests_per_second: 0.17
-
-# Number of registration requests a client can send before being
-# throttled.
-# Defaults to 3.
-#
-#rc_registration_request_burst_count: 3.0
-
 
 
 # Directory where uploaded images and attachments are stored.

+ 12 - 0
synapse/api/ratelimiting.py

@@ -14,6 +14,8 @@
 
 import collections
 
+from synapse.api.errors import LimitExceededError
+
 
 class Ratelimiter(object):
     """
@@ -82,3 +84,13 @@ class Ratelimiter(object):
                 break
             else:
                 del self.message_counts[key]
+
+    def ratelimit(self, key, time_now_s, rate_hz, burst_count, update=True):
+        allowed, time_allowed = self.can_do_action(
+            key, time_now_s, rate_hz, burst_count, update
+        )
+
+        if not allowed:
+            raise LimitExceededError(
+                retry_after_ms=int(1000 * (time_allowed - time_now_s)),
+            )

+ 40 - 18
synapse/config/ratelimiting.py

@@ -15,25 +15,30 @@
 from ._base import Config
 
 
+class RateLimitConfig(object):
+    def __init__(self, config):
+        self.per_second = config.get("per_second", 0.17)
+        self.burst_count = config.get("burst_count", 3.0)
+
+
 class RatelimitConfig(Config):
 
     def read_config(self, config):
         self.rc_messages_per_second = config["rc_messages_per_second"]
         self.rc_message_burst_count = config["rc_message_burst_count"]
 
+        self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+
+        rc_login_config = config.get("rc_login", {})
+        self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
+        self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
+
         self.federation_rc_window_size = config["federation_rc_window_size"]
         self.federation_rc_sleep_limit = config["federation_rc_sleep_limit"]
         self.federation_rc_sleep_delay = config["federation_rc_sleep_delay"]
         self.federation_rc_reject_limit = config["federation_rc_reject_limit"]
         self.federation_rc_concurrent = config["federation_rc_concurrent"]
 
-        self.rc_registration_requests_per_second = config.get(
-            "rc_registration_requests_per_second", 0.17,
-        )
-        self.rc_registration_request_burst_count = config.get(
-            "rc_registration_request_burst_count", 3,
-        )
-
     def default_config(self, **kwargs):
         return """\
         ## Ratelimiting ##
@@ -46,6 +51,34 @@ class RatelimitConfig(Config):
         #
         rc_message_burst_count: 10.0
 
+        # Ratelimiting settings for registration and login.
+        #
+        # Each ratelimiting configuration is made of two parameters:
+        #   - per_second: number of requests a client can send per second.
+        #   - burst_count: number of requests a client can send before being throttled.
+        #
+        # Synapse currently uses the following configurations:
+        #   - one for registration that ratelimits registration requests based on the
+        #     client's IP address.
+        #   - one for login that ratelimits login requests based on the client's IP
+        #     address.
+        #   - one for login that ratelimits login requests based on the account the
+        #     client is attempting to log into.
+        #
+        # The defaults are as shown below.
+        #
+        #rc_registration:
+        #  per_second: 0.17
+        #  burst_count: 3
+        #
+        #rc_login:
+        #  address:
+        #    per_second: 0.17
+        #    burst_count: 3
+        #  account:
+        #    per_second: 0.17
+        #    burst_count: 3
+
         # The federation window size in milliseconds
         #
         federation_rc_window_size: 1000
@@ -69,15 +102,4 @@ class RatelimitConfig(Config):
         # single server
         #
         federation_rc_concurrent: 3
-
-        # Number of registration requests a client can send per second.
-        # Defaults to 1/minute (0.17).
-        #
-        #rc_registration_requests_per_second: 0.17
-
-        # Number of registration requests a client can send before being
-        # throttled.
-        # Defaults to 3.
-        #
-        #rc_registration_request_burst_count: 3.0
         """

+ 36 - 0
synapse/handlers/auth.py

@@ -35,6 +35,7 @@ from synapse.api.errors import (
     StoreError,
     SynapseError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.module_api import ModuleApi
 from synapse.types import UserID
 from synapse.util import logcontext
@@ -99,6 +100,10 @@ class AuthHandler(BaseHandler):
                         login_types.append(t)
         self._supported_login_types = login_types
 
+        self._account_ratelimiter = Ratelimiter()
+
+        self._clock = self.hs.get_clock()
+
     @defer.inlineCallbacks
     def validate_user_via_ui_auth(self, requester, request_body, clientip):
         """
@@ -568,7 +573,12 @@ class AuthHandler(BaseHandler):
         Returns:
             defer.Deferred: (unicode) canonical_user_id, or None if zero or
             multiple matches
+
+        Raises:
+            LimitExceededError if the ratelimiter's login requests count for this
+                user is too high too proceed.
         """
+        self.ratelimit_login_per_account(user_id)
         res = yield self._find_user_id_and_pwd_hash(user_id)
         if res is not None:
             defer.returnValue(res[0])
@@ -634,6 +644,8 @@ class AuthHandler(BaseHandler):
             StoreError if there was a problem accessing the database
             SynapseError if there was a problem with the request
             LoginError if there was an authentication problem.
+            LimitExceededError if the ratelimiter's login requests count for this
+                user is too high too proceed.
         """
 
         if username.startswith('@'):
@@ -643,6 +655,8 @@ class AuthHandler(BaseHandler):
                 username, self.hs.hostname
             ).to_string()
 
+        self.ratelimit_login_per_account(qualified_user_id)
+
         login_type = login_submission.get("type")
         known_login_type = False
 
@@ -735,6 +749,10 @@ class AuthHandler(BaseHandler):
             password (unicode): the provided password
         Returns:
             (unicode) the canonical_user_id, or None if unknown user / bad password
+
+        Raises:
+            LimitExceededError if the ratelimiter's login requests count for this
+                user is too high too proceed.
         """
         lookupres = yield self._find_user_id_and_pwd_hash(user_id)
         if not lookupres:
@@ -763,6 +781,7 @@ class AuthHandler(BaseHandler):
             auth_api.validate_macaroon(macaroon, "login", True, user_id)
         except Exception:
             raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
+        self.ratelimit_login_per_account(user_id)
         yield self.auth.check_auth_blocking(user_id)
         defer.returnValue(user_id)
 
@@ -934,6 +953,23 @@ class AuthHandler(BaseHandler):
         else:
             return defer.succeed(False)
 
+    def ratelimit_login_per_account(self, user_id):
+        """Checks whether the process must be stopped because of ratelimiting.
+
+        Args:
+            user_id (unicode): complete @user:id
+
+        Raises:
+            LimitExceededError if the ratelimiter's login requests count for this
+                user is too high too proceed.
+        """
+        self._account_ratelimiter.ratelimit(
+            user_id.lower(), time_now_s=self._clock.time(),
+            rate_hz=self.hs.config.rc_login_account.per_second,
+            burst_count=self.hs.config.rc_login_account.burst_count,
+            update=True,
+        )
+
 
 @attr.s
 class MacaroonGenerator(object):

+ 2 - 2
synapse/handlers/register.py

@@ -629,8 +629,8 @@ class RegistrationHandler(BaseHandler):
 
             allowed, time_allowed = self.ratelimiter.can_do_action(
                 address, time_now_s=time_now,
-                rate_hz=self.hs.config.rc_registration_requests_per_second,
-                burst_count=self.hs.config.rc_registration_request_burst_count,
+                rate_hz=self.hs.config.rc_registration.per_second,
+                burst_count=self.hs.config.rc_registration.burst_count,
             )
 
             if not allowed:

+ 10 - 0
synapse/rest/client/v1/login.py

@@ -22,6 +22,7 @@ from twisted.internet import defer
 from twisted.web.client import PartialDownloadError
 
 from synapse.api.errors import Codes, LoginError, SynapseError
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.http.server import finish_request
 from synapse.http.servlet import (
     RestServlet,
@@ -97,6 +98,7 @@ class LoginRestServlet(ClientV1RestServlet):
         self.registration_handler = hs.get_registration_handler()
         self.handlers = hs.get_handlers()
         self._well_known_builder = WellKnownBuilder(hs)
+        self._address_ratelimiter = Ratelimiter()
 
     def on_GET(self, request):
         flows = []
@@ -129,6 +131,13 @@ class LoginRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request):
+        self._address_ratelimiter.ratelimit(
+            request.getClientIP(), time_now_s=self.hs.clock.time(),
+            rate_hz=self.hs.config.rc_login_address.per_second,
+            burst_count=self.hs.config.rc_login_address.burst_count,
+            update=True,
+        )
+
         login_submission = parse_json_object_from_request(request)
         try:
             if self.jwt_enabled and (login_submission["type"] ==
@@ -285,6 +294,7 @@ class LoginRestServlet(ClientV1RestServlet):
             raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
 
         user_id = UserID(user, self.hs.hostname).to_string()
+
         auth_handler = self.auth_handler
         registered_user_id = yield auth_handler.check_user_exists(user_id)
         if registered_user_id:

+ 2 - 2
synapse/rest/client/v2_alpha/register.py

@@ -210,8 +210,8 @@ class RegisterRestServlet(RestServlet):
 
         allowed, time_allowed = self.ratelimiter.can_do_action(
             client_addr, time_now_s=time_now,
-            rate_hz=self.hs.config.rc_registration_requests_per_second,
-            burst_count=self.hs.config.rc_registration_request_burst_count,
+            rate_hz=self.hs.config.rc_registration.per_second,
+            burst_count=self.hs.config.rc_registration.burst_count,
             update=False,
         )
 

+ 118 - 0
tests/rest/client/v1/test_login.py

@@ -0,0 +1,118 @@
+import json
+
+from synapse.rest.client.v1 import admin, login
+
+from tests import unittest
+
+LOGIN_URL = b"/_matrix/client/r0/login"
+
+
+class LoginRestServletTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+
+        self.hs = self.setup_test_homeserver()
+        self.hs.config.enable_registration = True
+        self.hs.config.registrations_require_3pid = []
+        self.hs.config.auto_join_rooms = []
+        self.hs.config.enable_registration_captcha = False
+
+        return self.hs
+
+    def test_POST_ratelimiting_per_address(self):
+        self.hs.config.rc_login_address.burst_count = 5
+        self.hs.config.rc_login_address.per_second = 0.17
+
+        # Create different users so we're sure not to be bothered by the per-user
+        # ratelimiter.
+        for i in range(0, 6):
+            self.register_user("kermit" + str(i), "monkey")
+
+        for i in range(0, 6):
+            params = {
+                "type": "m.login.password",
+                "identifier": {
+                    "type": "m.id.user",
+                    "user": "kermit" + str(i),
+                },
+                "password": "monkey",
+            }
+            request_data = json.dumps(params)
+            request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+            self.render(request)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
+        # than 1min.
+        self.assertTrue(retry_after_ms < 6000)
+
+        self.reactor.advance(retry_after_ms / 1000.)
+
+        params = {
+            "type": "m.login.password",
+            "identifier": {
+                "type": "m.id.user",
+                "user": "kermit" + str(i),
+            },
+            "password": "monkey",
+        }
+        request_data = json.dumps(params)
+        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        self.render(request)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    def test_POST_ratelimiting_per_account(self):
+        self.hs.config.rc_login_account.burst_count = 5
+        self.hs.config.rc_login_account.per_second = 0.17
+
+        self.register_user("kermit", "monkey")
+
+        for i in range(0, 6):
+            params = {
+                "type": "m.login.password",
+                "identifier": {
+                    "type": "m.id.user",
+                    "user": "kermit",
+                },
+                "password": "monkey",
+            }
+            request_data = json.dumps(params)
+            request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+            self.render(request)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
+        # than 1min.
+        self.assertTrue(retry_after_ms < 6000)
+
+        self.reactor.advance(retry_after_ms / 1000.)
+
+        params = {
+            "type": "m.login.password",
+            "identifier": {
+                "type": "m.id.user",
+                "user": "kermit",
+            },
+            "password": "monkey",
+        }
+        request_data = json.dumps(params)
+        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        self.render(request)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)

+ 4 - 2
tests/rest/client/v2_alpha/test_register.py

@@ -132,7 +132,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.json_body["error"], "Guest access is disabled")
 
     def test_POST_ratelimiting_guest(self):
-        self.hs.config.rc_registration_request_burst_count = 5
+        self.hs.config.rc_registration.burst_count = 5
+        self.hs.config.rc_registration.per_second = 0.17
 
         for i in range(0, 6):
             url = self.url + b"?kind=guest"
@@ -153,7 +154,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_POST_ratelimiting(self):
-        self.hs.config.rc_registration_request_burst_count = 5
+        self.hs.config.rc_registration.burst_count = 5
+        self.hs.config.rc_registration.per_second = 0.17
 
         for i in range(0, 6):
             params = {

+ 6 - 2
tests/utils.py

@@ -151,8 +151,12 @@ def default_config(name):
     config.admin_contact = None
     config.rc_messages_per_second = 10000
     config.rc_message_burst_count = 10000
-    config.rc_registration_request_burst_count = 3.0
-    config.rc_registration_requests_per_second = 0.17
+    config.rc_registration.per_second = 10000
+    config.rc_registration.burst_count = 10000
+    config.rc_login_address.per_second = 10000
+    config.rc_login_address.burst_count = 10000
+    config.rc_login_account.per_second = 10000
+    config.rc_login_account.burst_count = 10000
     config.saml2_enabled = False
     config.public_baseurl = None
     config.default_identity_server = None