فهرست منبع

Move callback-related code from the PasswordAuthProvider to its own class

Andrew Morgan 1 سال پیش
والد
کامیت
46c0ab559b

+ 20 - 120
synapse/handlers/auth.py

@@ -65,6 +65,10 @@ from synapse.http.server import finish_request, respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.module_api.callbacks.password_auth_provider_callbacks import (
+    CHECK_3PID_AUTH_CALLBACK,
+    ON_LOGGED_OUT_CALLBACK,
+)
 from synapse.storage.databases.main.registration import (
     LoginTokenExpired,
     LoginTokenLookupResult,
@@ -1096,7 +1100,7 @@ class AuthHandler:
         return self._password_enabled_for_login and self._password_localdb_enabled
 
     def get_supported_login_types(self) -> Iterable[str]:
-        """Get a the login types supported for the /login API
+        """Get the login types supported for the /login API
 
         By default this is just 'm.login.password' (unless password_enabled is
         False in the config file), but password auth providers can provide
@@ -1999,124 +2003,16 @@ def load_single_legacy_password_auth_provider(
     )
 
 
-CHECK_3PID_AUTH_CALLBACK = Callable[
-    [str, str, str],
-    Awaitable[
-        Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
-    ],
-]
-ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
-CHECK_AUTH_CALLBACK = Callable[
-    [str, str, JsonDict],
-    Awaitable[
-        Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
-    ],
-]
-GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
-    [JsonDict, JsonDict],
-    Awaitable[Optional[str]],
-]
-GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
-    [JsonDict, JsonDict],
-    Awaitable[Optional[str]],
-]
-IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
-
-
 class PasswordAuthProvider:
     """
     A class that the AuthHandler calls when authenticating users
     It allows modules to provide alternative methods for authentication
     """
 
-    def __init__(self) -> None:
-        # lists of callbacks
-        self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
-        self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
-        self.get_username_for_registration_callbacks: List[
-            GET_USERNAME_FOR_REGISTRATION_CALLBACK
-        ] = []
-        self.get_displayname_for_registration_callbacks: List[
-            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
-        ] = []
-        self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
-
-        # Mapping from login type to login parameters
-        self._supported_login_types: Dict[str, Tuple[str, ...]] = {}
-
-        # Mapping from login type to auth checker callbacks
-        self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
-
-    def register_password_auth_provider_callbacks(
-        self,
-        check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
-        on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
-        is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
-        auth_checkers: Optional[
-            Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
-        ] = None,
-        get_username_for_registration: Optional[
-            GET_USERNAME_FOR_REGISTRATION_CALLBACK
-        ] = None,
-        get_displayname_for_registration: Optional[
-            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
-        ] = None,
-    ) -> None:
-        # Register check_3pid_auth callback
-        if check_3pid_auth is not None:
-            self.check_3pid_auth_callbacks.append(check_3pid_auth)
-
-        # register on_logged_out callback
-        if on_logged_out is not None:
-            self.on_logged_out_callbacks.append(on_logged_out)
-
-        if auth_checkers is not None:
-            # register a new supported login_type
-            # Iterate through all of the types being registered
-            for (login_type, fields), callback in auth_checkers.items():
-                # Note: fields may be empty here. This would allow a modules auth checker to
-                # be called with just 'login_type' and no password or other secrets
-
-                # Need to check that all the field names are strings or may get nasty errors later
-                for f in fields:
-                    if not isinstance(f, str):
-                        raise RuntimeError(
-                            "A module tried to register support for login type: %s with parameters %s"
-                            " but all parameter names must be strings"
-                            % (login_type, fields)
-                        )
-
-                # 2 modules supporting the same login type must expect the same fields
-                # e.g. 1 can't expect "pass" if the other expects "password"
-                # so throw an exception if that happens
-                if login_type not in self._supported_login_types.get(login_type, []):
-                    self._supported_login_types[login_type] = fields
-                else:
-                    fields_currently_supported = self._supported_login_types.get(
-                        login_type
-                    )
-                    if fields_currently_supported != fields:
-                        raise RuntimeError(
-                            "A module tried to register support for login type: %s with parameters %s"
-                            " but another module had already registered support for that type with parameters %s"
-                            % (login_type, fields, fields_currently_supported)
-                        )
-
-                # Add the new method to the list of auth_checker_callbacks for this login type
-                self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
-
-        if get_username_for_registration is not None:
-            self.get_username_for_registration_callbacks.append(
-                get_username_for_registration,
-            )
-
-        if get_displayname_for_registration is not None:
-            self.get_displayname_for_registration_callbacks.append(
-                get_displayname_for_registration,
-            )
-
-        if is_3pid_allowed is not None:
-            self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
+    def __init__(self, hs: "HomeServer") -> None:
+        self._module_api_callbacks = (
+            hs.get_module_api_callbacks().password_auth_provider
+        )
 
     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
         """Get the login types supported by this password provider
@@ -2126,7 +2022,7 @@ class PasswordAuthProvider:
         to the /login API.
         """
 
-        return self._supported_login_types
+        return self._module_api_callbacks.supported_login_types
 
     async def check_auth(
         self, username: str, login_type: str, login_dict: JsonDict
@@ -2149,7 +2045,7 @@ class PasswordAuthProvider:
 
         # Go through all callbacks for the login type until one returns with a value
         # other than None (i.e. until a callback returns a success)
-        for callback in self.auth_checker_callbacks[login_type]:
+        for callback in self._module_api_callbacks.auth_checker_callbacks[login_type]:
             try:
                 result = await delay_cancellation(
                     callback(username, login_type, login_dict)
@@ -2214,7 +2110,7 @@ class PasswordAuthProvider:
         # (user_id, callback_func), where callback_func should be run
         # after we've finished everything else
 
-        for callback in self.check_3pid_auth_callbacks:
+        for callback in self._module_api_callbacks.check_3pid_auth_callbacks:
             try:
                 result = await delay_cancellation(callback(medium, address, password))
             except CancelledError:
@@ -2272,7 +2168,7 @@ class PasswordAuthProvider:
         self, user_id: str, device_id: Optional[str], access_token: str
     ) -> None:
         # call all of the on_logged_out callbacks
-        for callback in self.on_logged_out_callbacks:
+        for callback in self._module_api_callbacks.on_logged_out_callbacks:
             try:
                 await callback(user_id, device_id, access_token)
             except Exception as e:
@@ -2297,7 +2193,9 @@ class PasswordAuthProvider:
             The localpart to use when registering this user, or None if no module
             returned a localpart.
         """
-        for callback in self.get_username_for_registration_callbacks:
+        for (
+            callback
+        ) in self._module_api_callbacks.get_username_for_registration_callbacks:
             try:
                 res = await delay_cancellation(callback(uia_results, params))
 
@@ -2342,7 +2240,9 @@ class PasswordAuthProvider:
             A tuple which first element is the display name, and the second is an MXC URL
             to the user's avatar.
         """
-        for callback in self.get_displayname_for_registration_callbacks:
+        for (
+            callback
+        ) in self._module_api_callbacks.get_displayname_for_registration_callbacks:
             try:
                 res = await delay_cancellation(callback(uia_results, params))
 
@@ -2385,7 +2285,7 @@ class PasswordAuthProvider:
         Returns:
             Whether the 3PID is allowed to be bound on this homeserver
         """
-        for callback in self.is_3pid_allowed_callbacks:
+        for callback in self._module_api_callbacks.is_3pid_allowed_callbacks:
             try:
                 res = await delay_cancellation(callback(medium, address, registration))
 

+ 10 - 11
synapse/module_api/__init__.py

@@ -42,15 +42,7 @@ from synapse.events import EventBase
 from synapse.events.presence_router import PresenceRouter
 from synapse.events.spamcheck import SpamChecker
 from synapse.handlers.account_data import ON_ACCOUNT_DATA_UPDATED_CALLBACK
-from synapse.handlers.auth import (
-    CHECK_3PID_AUTH_CALLBACK,
-    CHECK_AUTH_CALLBACK,
-    GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
-    GET_USERNAME_FOR_REGISTRATION_CALLBACK,
-    IS_3PID_ALLOWED_CALLBACK,
-    ON_LOGGED_OUT_CALLBACK,
-    AuthHandler,
-)
+from synapse.handlers.auth import AuthHandler
 from synapse.handlers.device import DeviceHandler
 from synapse.handlers.push_rules import RuleSpec, check_actions
 from synapse.http.client import SimpleHttpClient
@@ -79,6 +71,14 @@ from synapse.module_api.callbacks.background_updater_callbacks import (
     MIN_BATCH_SIZE_CALLBACK,
     ON_UPDATE_CALLBACK,
 )
+from synapse.module_api.callbacks.password_auth_provider_callbacks import (
+    CHECK_3PID_AUTH_CALLBACK,
+    CHECK_AUTH_CALLBACK,
+    GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
+    GET_USERNAME_FOR_REGISTRATION_CALLBACK,
+    IS_3PID_ALLOWED_CALLBACK,
+    ON_LOGGED_OUT_CALLBACK,
+)
 from synapse.module_api.callbacks.presence_router_callbacks import (
     GET_INTERESTED_USERS_CALLBACK,
     GET_USERS_FOR_STATES_CALLBACK,
@@ -271,7 +271,6 @@ class ModuleApi:
         self._public_room_list_manager = PublicRoomListManager(hs)
         self._account_data_manager = AccountDataManager(hs)
 
-        self._password_auth_provider = hs.get_password_auth_provider()
         self._account_data_handler = hs.get_account_data_handler()
 
     #################################################################################
@@ -417,7 +416,7 @@ class ModuleApi:
 
         Added in Synapse v1.46.0.
         """
-        return self._password_auth_provider.register_password_auth_provider_callbacks(
+        return self._callbacks.password_auth_provider.register_callbacks(
             check_3pid_auth=check_3pid_auth,
             on_logged_out=on_logged_out,
             is_3pid_allowed=is_3pid_allowed,

+ 2 - 0
synapse/module_api/callbacks/__init__.py

@@ -14,6 +14,7 @@
 
 from .account_validity_callbacks import AccountValidityModuleApiCallbacks
 from .background_updater_callbacks import BackgroundUpdaterModuleApiCallbacks
+from .password_auth_provider_callbacks import PasswordAuthProviderModuleApiCallbacks
 from .presence_router_callbacks import PresenceRouterModuleApiCallbacks
 from .spam_checker_callbacks import SpamCheckerModuleApiCallbacks
 from .third_party_event_rules_callbacks import ThirdPartyEventRulesModuleApiCallbacks
@@ -27,6 +28,7 @@ class ModuleApiCallbacks:
     def __init__(self) -> None:
         self.account_validity = AccountValidityModuleApiCallbacks()
         self.background_updater = BackgroundUpdaterModuleApiCallbacks()
+        self.password_auth_provider = PasswordAuthProviderModuleApiCallbacks()
         self.presence_router = PresenceRouterModuleApiCallbacks()
         self.spam_checker = SpamCheckerModuleApiCallbacks()
         self.third_party_event_rules = ThirdPartyEventRulesModuleApiCallbacks()

+ 138 - 0
synapse/module_api/callbacks/password_auth_provider_callbacks.py

@@ -0,0 +1,138 @@
+# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2019 - 2020, 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional, Tuple
+
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.module_api import LoginResponse
+
+logger = logging.getLogger(__name__)
+
+
+CHECK_3PID_AUTH_CALLBACK = Callable[
+    [str, str, str],
+    Awaitable[
+        Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
+    ],
+]
+ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
+CHECK_AUTH_CALLBACK = Callable[
+    [str, str, JsonDict],
+    Awaitable[
+        Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
+    ],
+]
+GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
+    [JsonDict, JsonDict],
+    Awaitable[Optional[str]],
+]
+GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
+    [JsonDict, JsonDict],
+    Awaitable[Optional[str]],
+]
+IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
+
+
+class PasswordAuthProviderModuleApiCallbacks:
+    def __init__(self) -> None:
+        # Mapping from login type to login parameters
+        self.supported_login_types: Dict[str, Tuple[str, ...]] = {}
+
+        self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
+        self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
+        self.get_username_for_registration_callbacks: List[
+            GET_USERNAME_FOR_REGISTRATION_CALLBACK
+        ] = []
+        self.get_displayname_for_registration_callbacks: List[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = []
+        self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
+
+        # Mapping from login type to auth checker callbacks
+        self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
+
+    def register_callbacks(
+        self,
+        check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
+        on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
+        is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
+        auth_checkers: Optional[
+            Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
+        ] = None,
+        get_username_for_registration: Optional[
+            GET_USERNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
+        get_displayname_for_registration: Optional[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
+    ) -> None:
+        # Register check_3pid_auth callback
+        if check_3pid_auth is not None:
+            self.check_3pid_auth_callbacks.append(check_3pid_auth)
+
+        # register on_logged_out callback
+        if on_logged_out is not None:
+            self.on_logged_out_callbacks.append(on_logged_out)
+
+        if auth_checkers is not None:
+            # register a new supported login_type
+            # Iterate through all of the types being registered
+            for (login_type, fields), callback in auth_checkers.items():
+                # Note: fields may be empty here. This would allow a modules auth checker to
+                # be called with just 'login_type' and no password or other secrets
+
+                # Need to check that all the field names are strings or may get nasty errors later
+                for f in fields:
+                    if not isinstance(f, str):
+                        raise RuntimeError(
+                            "A module tried to register support for login type: %s with parameters %s"
+                            " but all parameter names must be strings"
+                            % (login_type, fields)
+                        )
+
+                # 2 modules supporting the same login type must expect the same fields
+                # e.g. 1 can't expect "pass" if the other expects "password"
+                # so throw an exception if that happens
+                if login_type not in self.supported_login_types.get(login_type, []):
+                    self.supported_login_types[login_type] = fields
+                else:
+                    fields_currently_supported = self.supported_login_types.get(
+                        login_type
+                    )
+                    if fields_currently_supported != fields:
+                        raise RuntimeError(
+                            "A module tried to register support for login type: %s with parameters %s"
+                            " but another module had already registered support for that type with parameters %s"
+                            % (login_type, fields, fields_currently_supported)
+                        )
+
+                # Add the new method to the list of auth_checker_callbacks for this login type
+                self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
+
+        if get_username_for_registration is not None:
+            self.get_username_for_registration_callbacks.append(
+                get_username_for_registration,
+            )
+
+        if get_displayname_for_registration is not None:
+            self.get_displayname_for_registration_callbacks.append(
+                get_displayname_for_registration,
+            )
+
+        if is_3pid_allowed is not None:
+            self.is_3pid_allowed_callbacks.append(is_3pid_allowed)

+ 1 - 1
synapse/server.py

@@ -674,7 +674,7 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_password_auth_provider(self) -> PasswordAuthProvider:
-        return PasswordAuthProvider()
+        return PasswordAuthProvider(self)
 
     @cache_in_self
     def get_room_member_handler(self) -> RoomMemberHandler:

+ 7 - 3
tests/handlers/test_password_providers.py

@@ -727,7 +727,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             self.called = True
 
         on_logged_out = Mock(side_effect=on_logged_out)
-        self.hs.get_password_auth_provider().on_logged_out_callbacks.append(
+        self.hs.get_module_api_callbacks().password_auth_provider.on_logged_out_callbacks.append(
             on_logged_out
         )
 
@@ -857,7 +857,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         )
 
         m = Mock(return_value=make_awaitable(False))
-        self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
+        self.hs.get_module_api_callbacks().password_auth_provider.is_3pid_allowed_callbacks = [
+            m
+        ]
 
         self.register_user(username, "password")
         tok = self.login(username, "password")
@@ -887,7 +889,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         m.assert_called_once_with("email", "foo@test.com", registration)
 
         m = Mock(return_value=make_awaitable(True))
-        self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
+        self.hs.get_module_api_callbacks().password_auth_provider.is_3pid_allowed_callbacks = [
+            m
+        ]
 
         channel = self.make_request(
             "POST",