Browse Source

Port the Password Auth Providers module interface to the new generic interface (#10548)

Co-authored-by: Azrenbeth <7782548+Azrenbeth@users.noreply.github.com>
Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
Azrenbeth 2 years ago
parent
commit
cdd308845b

+ 1 - 0
changelog.d/10548.feature

@@ -0,0 +1 @@
+Port the Password Auth Providers module interface to the new generic interface.

+ 1 - 0
docs/SUMMARY.md

@@ -43,6 +43,7 @@
         - [Third-party rules callbacks](modules/third_party_rules_callbacks.md)
         - [Presence router callbacks](modules/presence_router_callbacks.md)
         - [Account validity callbacks](modules/account_validity_callbacks.md)
+        - [Password auth provider callbacks](modules/password_auth_provider_callbacks.md)
         - [Porting a legacy module to the new interface](modules/porting_legacy_module.md)
     - [Workers](workers.md)
       - [Using `synctl` with Workers](synctl_workers.md)

+ 153 - 0
docs/modules/password_auth_provider_callbacks.md

@@ -0,0 +1,153 @@
+# Password auth provider callbacks
+
+Password auth providers offer a way for server administrators to integrate
+their Synapse installation with an external authentication system. The callbacks can be
+registered by using the Module API's `register_password_auth_provider_callbacks` method.
+
+## Callbacks
+
+### `auth_checkers`
+
+```
+ auth_checkers: Dict[Tuple[str,Tuple], Callable]
+```
+
+A dict mapping from tuples of a login type identifier (such as `m.login.password`) and a
+tuple of field names (such as `("password", "secret_thing")`) to authentication checking
+callbacks, which should be of the following form:
+
+```python
+async def check_auth(
+    user: str,
+    login_type: str,
+    login_dict: "synapse.module_api.JsonDict",
+) -> Optional[
+    Tuple[
+        str, 
+        Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
+    ]
+]
+```
+
+The login type and field names should be provided by the user in the
+request to the `/login` API. [The Matrix specification](https://matrix.org/docs/spec/client_server/latest#authentication-types)
+defines some types, however user defined ones are also allowed.
+
+The callback is passed the `user` field provided by the client (which might not be in
+`@username:server` form), the login type, and a dictionary of login secrets passed by
+the client.
+
+If the authentication is successful, the module must return the user's Matrix ID (e.g. 
+`@alice:example.com`) and optionally a callback to be called with the response to the
+`/login` request. If the module doesn't wish to return a callback, it must return `None`
+instead.
+
+If the authentication is unsuccessful, the module must return `None`.
+
+### `check_3pid_auth`
+
+```python
+async def check_3pid_auth(
+    medium: str, 
+    address: str,
+    password: str,
+) -> Optional[
+    Tuple[
+        str, 
+        Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
+    ]
+]
+```
+
+Called when a user attempts to register or log in with a third party identifier,
+such as email. It is passed the medium (eg. `email`), an address (eg. `jdoe@example.com`)
+and the user's password.
+
+If the authentication is successful, the module must return the user's Matrix ID (e.g. 
+`@alice:example.com`) and optionally a callback to be called with the response to the `/login` request.
+If the module doesn't wish to return a callback, it must return None instead.
+
+If the authentication is unsuccessful, the module must return None.
+
+### `on_logged_out`
+
+```python
+async def on_logged_out(
+    user_id: str,
+    device_id: Optional[str],
+    access_token: str
+) -> None
+``` 
+Called during a logout request for a user. It is passed the qualified user ID, the ID of the
+deactivated device (if any: access tokens are occasionally created without an associated
+device ID), and the (now deactivated) access token.
+
+## Example
+
+The example module below implements authentication checkers for two different login types: 
+-  `my.login.type` 
+    - Expects a `my_field` field to be sent to `/login`
+    - Is checked by the method: `self.check_my_login`
+- `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based))
+    - Expects a `password` field to be sent to `/login`
+    - Is checked by the method: `self.check_pass` 
+
+
+```python
+from typing import Awaitable, Callable, Optional, Tuple
+
+import synapse
+from synapse import module_api
+
+
+class MyAuthProvider:
+    def __init__(self, config: dict, api: module_api):
+
+        self.api = api
+
+        self.credentials = {
+            "bob": "building",
+            "@scoop:matrix.org": "digging",
+        }
+
+        api.register_password_auth_provider_callbacks(
+            auth_checkers={
+                ("my.login_type", ("my_field",)): self.check_my_login,
+                ("m.login.password", ("password",)): self.check_pass,
+            },
+        )
+
+    async def check_my_login(
+        self,
+        username: str,
+        login_type: str,
+        login_dict: "synapse.module_api.JsonDict",
+    ) -> Optional[
+        Tuple[
+            str,
+            Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
+        ]
+    ]:
+        if login_type != "my.login_type":
+            return None
+
+        if self.credentials.get(username) == login_dict.get("my_field"):
+            return self.api.get_qualified_user_id(username)
+
+    async def check_pass(
+        self,
+        username: str,
+        login_type: str,
+        login_dict: "synapse.module_api.JsonDict",
+    ) -> Optional[
+        Tuple[
+            str,
+            Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
+        ]
+    ]:
+        if login_type != "m.login.password":
+            return None
+
+        if self.credentials.get(username) == login_dict.get("password"):
+            return self.api.get_qualified_user_id(username)
+```

+ 3 - 0
docs/modules/porting_legacy_module.md

@@ -12,6 +12,9 @@ should register this resource in its `__init__` method using the `register_web_r
 method from the `ModuleApi` class (see [this section](writing_a_module.html#registering-a-web-resource) for
 more info).
 
+There is no longer a `get_db_schema_files` callback provided for password auth provider modules. Any
+changes to the database should now be made by the module using the module API class.
+
 The module's author should also update any example in the module's configuration to only
 use the new `modules` section in Synapse's configuration file (see [this section](index.html#using-modules)
 for more info).

+ 6 - 0
docs/password_auth_providers.md

@@ -1,3 +1,9 @@
+<h2 style="color:red">
+This page of the Synapse documentation is now deprecated. For up to date
+documentation on setting up or writing a password auth provider module, please see
+<a href="modules.md">this page</a>.
+</h2>
+
 # Password auth provider modules
 
 Password auth providers offer a way for server administrators to

+ 0 - 28
docs/sample_config.yaml

@@ -2260,34 +2260,6 @@ email:
     #email_validation: "[%(server_name)s] Validate your email"
 
 
-# Password providers allow homeserver administrators to integrate
-# their Synapse installation with existing authentication methods
-# ex. LDAP, external tokens, etc.
-#
-# For more information and known implementations, please see
-# https://matrix-org.github.io/synapse/latest/password_auth_providers.html
-#
-# Note: instances wishing to use SAML or CAS authentication should
-# instead use the `saml2_config` or `cas_config` options,
-# respectively.
-#
-password_providers:
-#    # Example config for an LDAP auth provider
-#    - module: "ldap_auth_provider.LdapAuthProvider"
-#      config:
-#        enabled: true
-#        uri: "ldap://ldap.example.com:389"
-#        start_tls: true
-#        base: "ou=users,dc=example,dc=com"
-#        attributes:
-#           uid: "cn"
-#           mail: "email"
-#           name: "givenName"
-#        #bind_dn:
-#        #bind_password:
-#        #filter: "(objectClass=posixAccount)"
-
-
 
 ## Push ##
 

+ 2 - 0
synapse/app/_base.py

@@ -42,6 +42,7 @@ from synapse.crypto import context_factory
 from synapse.events.presence_router import load_legacy_presence_router
 from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
+from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.logging.context import PreserveLoggingContext
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.jemalloc import setup_jemalloc_stats
@@ -379,6 +380,7 @@ async def start(hs: "HomeServer"):
     load_legacy_spam_checkers(hs)
     load_legacy_third_party_event_rules(hs)
     load_legacy_presence_router(hs)
+    load_legacy_password_auth_providers(hs)
 
     # If we've configured an expiry time for caches, start the background job now.
     setup_expire_lru_cache_entries(hs)

+ 23 - 30
synapse/config/password_auth_providers.py

@@ -25,6 +25,29 @@ class PasswordAuthProviderConfig(Config):
     section = "authproviders"
 
     def read_config(self, config, **kwargs):
+        """Parses the old password auth providers config. The config format looks like this:
+
+        password_providers:
+           # Example config for an LDAP auth provider
+           - module: "ldap_auth_provider.LdapAuthProvider"
+             config:
+               enabled: true
+               uri: "ldap://ldap.example.com:389"
+               start_tls: true
+               base: "ou=users,dc=example,dc=com"
+               attributes:
+                  uid: "cn"
+                  mail: "email"
+                  name: "givenName"
+               #bind_dn:
+               #bind_password:
+               #filter: "(objectClass=posixAccount)"
+
+        We expect admins to use modules for this feature (which is why it doesn't appear
+        in the sample config file), but we want to keep support for it around for a bit
+        for backwards compatibility.
+        """
+
         self.password_providers: List[Tuple[Type, Any]] = []
         providers = []
 
@@ -49,33 +72,3 @@ class PasswordAuthProviderConfig(Config):
             )
 
             self.password_providers.append((provider_class, provider_config))
-
-    def generate_config_section(self, **kwargs):
-        return """\
-        # Password providers allow homeserver administrators to integrate
-        # their Synapse installation with existing authentication methods
-        # ex. LDAP, external tokens, etc.
-        #
-        # For more information and known implementations, please see
-        # https://matrix-org.github.io/synapse/latest/password_auth_providers.html
-        #
-        # Note: instances wishing to use SAML or CAS authentication should
-        # instead use the `saml2_config` or `cas_config` options,
-        # respectively.
-        #
-        password_providers:
-        #    # Example config for an LDAP auth provider
-        #    - module: "ldap_auth_provider.LdapAuthProvider"
-        #      config:
-        #        enabled: true
-        #        uri: "ldap://ldap.example.com:389"
-        #        start_tls: true
-        #        base: "ou=users,dc=example,dc=com"
-        #        attributes:
-        #           uid: "cn"
-        #           mail: "email"
-        #           name: "givenName"
-        #        #bind_dn:
-        #        #bind_password:
-        #        #filter: "(objectClass=posixAccount)"
-        """

+ 388 - 140
synapse/handlers/auth.py

@@ -200,46 +200,13 @@ class AuthHandler:
 
         self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
 
-        # we can't use hs.get_module_api() here, because to do so will create an
-        # import loop.
-        #
-        # TODO: refactor this class to separate the lower-level stuff that
-        #   ModuleApi can use from the higher-level stuff that uses ModuleApi, as
-        #   better way to break the loop
-        account_handler = ModuleApi(hs, self)
-
-        self.password_providers = [
-            PasswordProvider.load(module, config, account_handler)
-            for module, config in hs.config.authproviders.password_providers
-        ]
-
-        logger.info("Extra password_providers: %s", self.password_providers)
+        self.password_auth_provider = hs.get_password_auth_provider()
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.auth.password_enabled
         self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
 
-        # start out by assuming PASSWORD is enabled; we will remove it later if not.
-        login_types = set()
-        if self._password_localdb_enabled:
-            login_types.add(LoginType.PASSWORD)
-
-        for provider in self.password_providers:
-            login_types.update(provider.get_supported_login_types().keys())
-
-        if not self._password_enabled:
-            login_types.discard(LoginType.PASSWORD)
-
-        # Some clients just pick the first type in the list. In this case, we want
-        # them to use PASSWORD (rather than token or whatever), so we want to make sure
-        # that comes first, where it's present.
-        self._supported_login_types = []
-        if LoginType.PASSWORD in login_types:
-            self._supported_login_types.append(LoginType.PASSWORD)
-            login_types.remove(LoginType.PASSWORD)
-        self._supported_login_types.extend(login_types)
-
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
         self._failed_uia_attempts_ratelimiter = Ratelimiter(
@@ -427,11 +394,10 @@ class AuthHandler:
                     ui_auth_types.add(LoginType.PASSWORD)
 
         # also allow auth from password providers
-        for provider in self.password_providers:
-            for t in provider.get_supported_login_types().keys():
-                if t == LoginType.PASSWORD and not self._password_enabled:
-                    continue
-                ui_auth_types.add(t)
+        for t in self.password_auth_provider.get_supported_login_types().keys():
+            if t == LoginType.PASSWORD and not self._password_enabled:
+                continue
+            ui_auth_types.add(t)
 
         # if sso is enabled, allow the user to log in via SSO iff they have a mapping
         # from sso to mxid.
@@ -1038,7 +1004,25 @@ class AuthHandler:
         Returns:
             login types
         """
-        return self._supported_login_types
+        # Load any login types registered by modules
+        # This is stored in the password_auth_provider so this doesn't trigger
+        # any callbacks
+        types = list(self.password_auth_provider.get_supported_login_types().keys())
+
+        # This list should include PASSWORD if (either _password_localdb_enabled is
+        # true or if one of the modules registered it) AND _password_enabled is true
+        # Also:
+        # Some clients just pick the first type in the list. In this case, we want
+        # them to use PASSWORD (rather than token or whatever), so we want to make sure
+        # that comes first, where it's present.
+        if LoginType.PASSWORD in types:
+            types.remove(LoginType.PASSWORD)
+            if self._password_enabled:
+                types.insert(0, LoginType.PASSWORD)
+        elif self._password_localdb_enabled and self._password_enabled:
+            types.insert(0, LoginType.PASSWORD)
+
+        return types
 
     async def validate_login(
         self,
@@ -1217,15 +1201,20 @@ class AuthHandler:
 
         known_login_type = False
 
-        for provider in self.password_providers:
-            supported_login_types = provider.get_supported_login_types()
-            if login_type not in supported_login_types:
-                # this password provider doesn't understand this login type
-                continue
-
+        # Check if login_type matches a type registered by one of the modules
+        # We don't need to remove LoginType.PASSWORD from the list if password login is
+        # disabled, since if that were the case then by this point we know that the
+        # login_type is not LoginType.PASSWORD
+        supported_login_types = self.password_auth_provider.get_supported_login_types()
+        # check if the login type being used is supported by a module
+        if login_type in supported_login_types:
+            # Make a note that this login type is supported by the server
             known_login_type = True
+            # Get all the fields expected for this login types
             login_fields = supported_login_types[login_type]
 
+            # go through the login submission and keep track of which required fields are
+            # provided/not provided
             missing_fields = []
             login_dict = {}
             for f in login_fields:
@@ -1233,6 +1222,7 @@ class AuthHandler:
                     missing_fields.append(f)
                 else:
                     login_dict[f] = login_submission[f]
+            # raise an error if any of the expected fields for that login type weren't provided
             if missing_fields:
                 raise SynapseError(
                     400,
@@ -1240,10 +1230,15 @@ class AuthHandler:
                     % (login_type, missing_fields),
                 )
 
-            result = await provider.check_auth(username, login_type, login_dict)
+            # call all of the check_auth hooks for that login_type
+            # it will return a result once the first success is found (or None otherwise)
+            result = await self.password_auth_provider.check_auth(
+                username, login_type, login_dict
+            )
             if result:
                 return result
 
+        # if no module managed to authenticate the user, then fallback to built in password based auth
         if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
             known_login_type = True
 
@@ -1282,11 +1277,16 @@ class AuthHandler:
             completed login/registration, or `None`. If authentication was
             unsuccessful, `user_id` and `callback` are both `None`.
         """
-        for provider in self.password_providers:
-            result = await provider.check_3pid_auth(medium, address, password)
-            if result:
-                return result
+        # call all of the check_3pid_auth callbacks
+        # Result will be from the first callback that returns something other than None
+        # If all the callbacks return None, then result is also set to None
+        result = await self.password_auth_provider.check_3pid_auth(
+            medium, address, password
+        )
+        if result:
+            return result
 
+        # if result is None then return (None, None)
         return None, None
 
     async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
@@ -1365,13 +1365,12 @@ class AuthHandler:
         user_info = await self.auth.get_user_by_access_token(access_token)
         await self.store.delete_access_token(access_token)
 
-        # see if any of our auth providers want to know about this
-        for provider in self.password_providers:
-            await provider.on_logged_out(
-                user_id=user_info.user_id,
-                device_id=user_info.device_id,
-                access_token=access_token,
-            )
+        # see if any modules want to know about this
+        await self.password_auth_provider.on_logged_out(
+            user_id=user_info.user_id,
+            device_id=user_info.device_id,
+            access_token=access_token,
+        )
 
         # delete pushers associated with this access token
         if user_info.token_id is not None:
@@ -1398,12 +1397,11 @@ class AuthHandler:
             user_id, except_token_id=except_token_id, device_id=device_id
         )
 
-        # see if any of our auth providers want to know about this
-        for provider in self.password_providers:
-            for token, _, device_id in tokens_and_devices:
-                await provider.on_logged_out(
-                    user_id=user_id, device_id=device_id, access_token=token
-                )
+        # see if any modules want to know about this
+        for token, _, device_id in tokens_and_devices:
+            await self.password_auth_provider.on_logged_out(
+                user_id=user_id, device_id=device_id, access_token=token
+            )
 
         # delete pushers associated with the access tokens
         await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1811,40 +1809,228 @@ class MacaroonGenerator:
         return macaroon
 
 
-class PasswordProvider:
-    """Wrapper for a password auth provider module
+def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
+    module_api = hs.get_module_api()
+    for module, config in hs.config.authproviders.password_providers:
+        load_single_legacy_password_auth_provider(
+            module=module, config=config, api=module_api
+        )
 
-    This class abstracts out all of the backwards-compatibility hacks for
-    password providers, to provide a consistent interface.
-    """
 
-    @classmethod
-    def load(
-        cls, module: Type, config: JsonDict, module_api: ModuleApi
-    ) -> "PasswordProvider":
-        try:
-            pp = module(config=config, account_handler=module_api)
-        except Exception as e:
-            logger.error("Error while initializing %r: %s", module, e)
-            raise
-        return cls(pp, module_api)
+def load_single_legacy_password_auth_provider(
+    module: Type, config: JsonDict, api: ModuleApi
+) -> None:
+    try:
+        provider = module(config=config, account_handler=api)
+    except Exception as e:
+        logger.error("Error while initializing %r: %s", module, e)
+        raise
+
+    # The known hooks. If a module implements a method who's name appears in this set
+    # we'll want to register it
+    password_auth_provider_methods = {
+        "check_3pid_auth",
+        "on_logged_out",
+    }
+
+    # All methods that the module provides should be async, but this wasn't enforced
+    # in the old module system, so we wrap them if needed
+    def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+        # f might be None if the callback isn't implemented by the module. In this
+        # case we don't want to register a callback at all so we return None.
+        if f is None:
+            return None
+
+        # We need to wrap check_password because its old form would return a boolean
+        # but we now want it to behave just like check_auth() and return the matrix id of
+        # the user if authentication succeeded or None otherwise
+        if f.__name__ == "check_password":
+
+            async def wrapped_check_password(
+                username: str, login_type: str, login_dict: JsonDict
+            ) -> Optional[Tuple[str, Optional[Callable]]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                matrix_user_id = api.get_qualified_user_id(username)
+                password = login_dict["password"]
+
+                is_valid = await f(matrix_user_id, password)
+
+                if is_valid:
+                    return matrix_user_id, None
+
+                return None
 
-    def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
-        self._pp = pp
-        self._module_api = module_api
+            return wrapped_check_password
+
+        # We need to wrap check_auth as in the old form it could return
+        # just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
+        if f.__name__ == "check_auth":
+
+            async def wrapped_check_auth(
+                username: str, login_type: str, login_dict: JsonDict
+            ) -> Optional[Tuple[str, Optional[Callable]]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                result = await f(username, login_type, login_dict)
+
+                if isinstance(result, str):
+                    return result, None
+
+                return result
+
+            return wrapped_check_auth
+
+        # We need to wrap check_3pid_auth as in the old form it could return
+        # just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
+        if f.__name__ == "check_3pid_auth":
+
+            async def wrapped_check_3pid_auth(
+                medium: str, address: str, password: str
+            ) -> Optional[Tuple[str, Optional[Callable]]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                result = await f(medium, address, password)
+
+                if isinstance(result, str):
+                    return result, None
+
+                return result
 
-        self._supported_login_types = {}
+            return wrapped_check_3pid_auth
 
-        # grandfather in check_password support
-        if hasattr(self._pp, "check_password"):
-            self._supported_login_types[LoginType.PASSWORD] = ("password",)
+        def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
+            # mypy doesn't do well across function boundaries so we need to tell it
+            # f is definitely not None.
+            assert f is not None
 
-        g = getattr(self._pp, "get_supported_login_types", None)
-        if g:
-            self._supported_login_types.update(g())
+            return maybe_awaitable(f(*args, **kwargs))
 
-    def __str__(self) -> str:
-        return str(self._pp)
+        return run
+
+    # populate hooks with the implemented methods, wrapped with async_wrapper
+    hooks = {
+        hook: async_wrapper(getattr(provider, hook, None))
+        for hook in password_auth_provider_methods
+    }
+
+    supported_login_types = {}
+    # call get_supported_login_types and add that to the dict
+    g = getattr(provider, "get_supported_login_types", None)
+    if g is not None:
+        # Note the old module style also called get_supported_login_types at loading time
+        # and it is synchronous
+        supported_login_types.update(g())
+
+    auth_checkers = {}
+    # Legacy modules have a check_auth method which expects to be called with one of
+    # the keys returned by get_supported_login_types. New style modules register a
+    # dictionary of login_type->check_auth_method mappings
+    check_auth = async_wrapper(getattr(provider, "check_auth", None))
+    if check_auth is not None:
+        for login_type, fields in supported_login_types.items():
+            # need tuple(fields) since fields can be any Iterable type (so may not be hashable)
+            auth_checkers[(login_type, tuple(fields))] = check_auth
+
+    # if it has a "check_password" method then it should handle all auth checks
+    # with login type of LoginType.PASSWORD
+    check_password = async_wrapper(getattr(provider, "check_password", None))
+    if check_password is not None:
+        # need to use a tuple here for ("password",) not a list since lists aren't hashable
+        auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
+
+    api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
+
+
+CHECK_3PID_AUTH_CALLBACK = Callable[
+    [str, str, str],
+    Awaitable[
+        Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
+    ],
+]
+ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
+CHECK_AUTH_CALLBACK = Callable[
+    [str, str, JsonDict],
+    Awaitable[
+        Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
+    ],
+]
+
+
+class PasswordAuthProvider:
+    """
+    A class that the AuthHandler calls when authenticating users
+    It allows modules to provide alternative methods for authentication
+    """
+
+    def __init__(self) -> None:
+        # lists of callbacks
+        self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
+        self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
+
+        # Mapping from login type to login parameters
+        self._supported_login_types: Dict[str, Iterable[str]] = {}
+
+        # Mapping from login type to auth checker callbacks
+        self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
+
+    def register_password_auth_provider_callbacks(
+        self,
+        check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
+        on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
+        auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
+    ) -> None:
+        # Register check_3pid_auth callback
+        if check_3pid_auth is not None:
+            self.check_3pid_auth_callbacks.append(check_3pid_auth)
+
+        # register on_logged_out callback
+        if on_logged_out is not None:
+            self.on_logged_out_callbacks.append(on_logged_out)
+
+        if auth_checkers is not None:
+            # register a new supported login_type
+            # Iterate through all of the types being registered
+            for (login_type, fields), callback in auth_checkers.items():
+                # Note: fields may be empty here. This would allow a modules auth checker to
+                # be called with just 'login_type' and no password or other secrets
+
+                # Need to check that all the field names are strings or may get nasty errors later
+                for f in fields:
+                    if not isinstance(f, str):
+                        raise RuntimeError(
+                            "A module tried to register support for login type: %s with parameters %s"
+                            " but all parameter names must be strings"
+                            % (login_type, fields)
+                        )
+
+                # 2 modules supporting the same login type must expect the same fields
+                # e.g. 1 can't expect "pass" if the other expects "password"
+                # so throw an exception if that happens
+                if login_type not in self._supported_login_types.get(login_type, []):
+                    self._supported_login_types[login_type] = fields
+                else:
+                    fields_currently_supported = self._supported_login_types.get(
+                        login_type
+                    )
+                    if fields_currently_supported != fields:
+                        raise RuntimeError(
+                            "A module tried to register support for login type: %s with parameters %s"
+                            " but another module had already registered support for that type with parameters %s"
+                            % (login_type, fields, fields_currently_supported)
+                        )
+
+                # Add the new method to the list of auth_checker_callbacks for this login type
+                self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
 
     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
         """Get the login types supported by this password provider
@@ -1852,20 +2038,15 @@ class PasswordProvider:
         Returns a map from a login type identifier (such as m.login.password) to an
         iterable giving the fields which must be provided by the user in the submission
         to the /login API.
-
-        This wrapper adds m.login.password to the list if the underlying password
-        provider supports the check_password() api.
         """
+
         return self._supported_login_types
 
     async def check_auth(
         self, username: str, login_type: str, login_dict: JsonDict
-    ) -> Optional[Tuple[str, Optional[Callable]]]:
+    ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
         """Check if the user has presented valid login credentials
 
-        This wrapper also calls check_password() if the underlying password provider
-        supports the check_password() api and the login type is m.login.password.
-
         Args:
             username: user id presented by the client. Either an MXID or an unqualified
                 username.
@@ -1879,63 +2060,130 @@ class PasswordProvider:
             user, and `callback` is an optional callback which will be called with the
             result from the /login call (including access_token, device_id, etc.)
         """
-        # first grandfather in a call to check_password
-        if login_type == LoginType.PASSWORD:
-            check_password = getattr(self._pp, "check_password", None)
-            if check_password:
-                qualified_user_id = self._module_api.get_qualified_user_id(username)
-                is_valid = await check_password(
-                    qualified_user_id, login_dict["password"]
-                )
-                if is_valid:
-                    return qualified_user_id, None
 
-        check_auth = getattr(self._pp, "check_auth", None)
-        if not check_auth:
-            return None
-        result = await check_auth(username, login_type, login_dict)
+        # Go through all callbacks for the login type until one returns with a value
+        # other than None (i.e. until a callback returns a success)
+        for callback in self.auth_checker_callbacks[login_type]:
+            try:
+                result = await callback(username, login_type, login_dict)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
 
-        # Check if the return value is a str or a tuple
-        if isinstance(result, str):
-            # If it's a str, set callback function to None
-            return result, None
+            if result is not None:
+                # Check that the callback returned a Tuple[str, Optional[Callable]]
+                # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
+                # result is always the right type, but as it is 3rd party code it might not be
+
+                if not isinstance(result, tuple) or len(result) != 2:
+                    logger.warning(
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
 
-        return result
+                # pull out the two parts of the tuple so we can do type checking
+                str_result, callback_result = result
+
+                # the 1st item in the tuple should be a str
+                if not isinstance(str_result, str):
+                    logger.warning(  # type: ignore[unreachable]
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
+
+                # the second should be Optional[Callable]
+                if callback_result is not None:
+                    if not callable(callback_result):
+                        logger.warning(  # type: ignore[unreachable]
+                            "Wrong type returned by module API callback %s: %s, expected"
+                            " Optional[Tuple[str, Optional[Callable]]]",
+                            callback,
+                            result,
+                        )
+                        continue
+
+                # The result is a (str, Optional[callback]) tuple so return the successful result
+                return result
+
+        # If this point has been reached then none of the callbacks successfully authenticated
+        # the user so return None
+        return None
 
     async def check_3pid_auth(
         self, medium: str, address: str, password: str
-    ) -> Optional[Tuple[str, Optional[Callable]]]:
-        g = getattr(self._pp, "check_3pid_auth", None)
-        if not g:
-            return None
-
+    ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
         # This function is able to return a deferred that either
         # resolves None, meaning authentication failure, or upon
         # success, to a str (which is the user_id) or a tuple of
         # (user_id, callback_func), where callback_func should be run
         # after we've finished everything else
-        result = await g(medium, address, password)
 
-        # Check if the return value is a str or a tuple
-        if isinstance(result, str):
-            # If it's a str, set callback function to None
-            return result, None
+        for callback in self.check_3pid_auth_callbacks:
+            try:
+                result = await callback(medium, address, password)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
 
-        return result
+            if result is not None:
+                # Check that the callback returned a Tuple[str, Optional[Callable]]
+                # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
+                # result is always the right type, but as it is 3rd party code it might not be
+
+                if not isinstance(result, tuple) or len(result) != 2:
+                    logger.warning(
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
+
+                # pull out the two parts of the tuple so we can do type checking
+                str_result, callback_result = result
+
+                # the 1st item in the tuple should be a str
+                if not isinstance(str_result, str):
+                    logger.warning(  # type: ignore[unreachable]
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
+
+                # the second should be Optional[Callable]
+                if callback_result is not None:
+                    if not callable(callback_result):
+                        logger.warning(  # type: ignore[unreachable]
+                            "Wrong type returned by module API callback %s: %s, expected"
+                            " Optional[Tuple[str, Optional[Callable]]]",
+                            callback,
+                            result,
+                        )
+                        continue
+
+                # The result is a (str, Optional[callback]) tuple so return the successful result
+                return result
+
+        # If this point has been reached then none of the callbacks successfully authenticated
+        # the user so return None
+        return None
 
     async def on_logged_out(
         self, user_id: str, device_id: Optional[str], access_token: str
     ) -> None:
-        g = getattr(self._pp, "on_logged_out", None)
-        if not g:
-            return
 
-        # This might return an awaitable, if it does block the log out
-        # until it completes.
-        await maybe_awaitable(
-            g(
-                user_id=user_id,
-                device_id=device_id,
-                access_token=access_token,
-            )
-        )
+        # call all of the on_logged_out callbacks
+        for callback in self.on_logged_out_callbacks:
+            try:
+                callback(user_id, device_id, access_token)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue

+ 9 - 0
synapse/module_api/__init__.py

@@ -45,6 +45,7 @@ from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.rest.client.login import LoginResponse
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
@@ -83,6 +84,8 @@ __all__ = [
     "DirectServeJsonResource",
     "ModuleApi",
     "PRESENCE_ALL_USERS",
+    "LoginResponse",
+    "JsonDict",
 ]
 
 logger = logging.getLogger(__name__)
@@ -139,6 +142,7 @@ class ModuleApi:
         self._spam_checker = hs.get_spam_checker()
         self._account_validity_handler = hs.get_account_validity_handler()
         self._third_party_event_rules = hs.get_third_party_event_rules()
+        self._password_auth_provider = hs.get_password_auth_provider()
         self._presence_router = hs.get_presence_router()
 
     #################################################################################
@@ -164,6 +168,11 @@ class ModuleApi:
         """Registers callbacks for presence router capabilities."""
         return self._presence_router.register_presence_router_callbacks
 
+    @property
+    def register_password_auth_provider_callbacks(self):
+        """Registers callbacks for password auth provider capabilities."""
+        return self._password_auth_provider.register_password_auth_provider_callbacks
+
     def register_web_resource(self, path: str, resource: IResource):
         """Registers a web resource to be served at the given path.
 

+ 5 - 1
synapse/server.py

@@ -65,7 +65,7 @@ from synapse.handlers.account_data import AccountDataHandler
 from synapse.handlers.account_validity import AccountValidityHandler
 from synapse.handlers.admin import AdminHandler
 from synapse.handlers.appservice import ApplicationServicesHandler
-from synapse.handlers.auth import AuthHandler, MacaroonGenerator
+from synapse.handlers.auth import AuthHandler, MacaroonGenerator, PasswordAuthProvider
 from synapse.handlers.cas import CasHandler
 from synapse.handlers.deactivate_account import DeactivateAccountHandler
 from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
@@ -687,6 +687,10 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_third_party_event_rules(self) -> ThirdPartyEventRules:
         return ThirdPartyEventRules(self)
 
+    @cache_in_self
+    def get_password_auth_provider(self) -> PasswordAuthProvider:
+        return PasswordAuthProvider()
+
     @cache_in_self
     def get_room_member_handler(self) -> RoomMemberHandler:
         if self.config.worker.worker_app:

+ 2 - 0
synapse/storage/prepare_database.py

@@ -549,6 +549,8 @@ def _apply_module_schemas(
         database_engine:
         config: application config
     """
+    # This is the old way for password_auth_provider modules to make changes
+    # to the database. This should instead be done using the module API
     for (mod, _config) in config.authproviders.password_providers:
         if not hasattr(mod, "get_db_schema_files"):
             continue

+ 197 - 26
tests/handlers/test_password_providers.py

@@ -20,6 +20,8 @@ from unittest.mock import Mock
 from twisted.internet import defer
 
 import synapse
+from synapse.handlers.auth import load_legacy_password_auth_providers
+from synapse.module_api import ModuleApi
 from synapse.rest.client import devices, login
 from synapse.types import JsonDict
 
@@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi
 mock_password_provider = Mock()
 
 
-class PasswordOnlyAuthProvider:
-    """A password_provider which only implements `check_password`."""
+class LegacyPasswordOnlyAuthProvider:
+    """A legacy password_provider which only implements `check_password`."""
 
     @staticmethod
     def parse_config(self):
@@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider:
         return mock_password_provider.check_password(*args)
 
 
-class CustomAuthProvider:
-    """A password_provider which implements a custom login type."""
+class LegacyCustomAuthProvider:
+    """A legacy password_provider which implements a custom login type."""
 
     @staticmethod
     def parse_config(self):
@@ -67,7 +69,23 @@ class CustomAuthProvider:
         return mock_password_provider.check_auth(*args)
 
 
-class PasswordCustomAuthProvider:
+class CustomAuthProvider:
+    """A module which registers password_auth_provider callbacks for a custom login type."""
+
+    @staticmethod
+    def parse_config(self):
+        pass
+
+    def __init__(self, config, api: ModuleApi):
+        api.register_password_auth_provider_callbacks(
+            auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
+        )
+
+    def check_auth(self, *args):
+        return mock_password_provider.check_auth(*args)
+
+
+class LegacyPasswordCustomAuthProvider:
     """A password_provider which implements password login via `check_auth`, as well
     as a custom type."""
 
@@ -85,8 +103,32 @@ class PasswordCustomAuthProvider:
         return mock_password_provider.check_auth(*args)
 
 
-def providers_config(*providers: Type[Any]) -> dict:
-    """Returns a config dict that will enable the given password auth providers"""
+class PasswordCustomAuthProvider:
+    """A module which registers password_auth_provider callbacks for a custom login type.
+    as well as a password login"""
+
+    @staticmethod
+    def parse_config(self):
+        pass
+
+    def __init__(self, config, api: ModuleApi):
+        api.register_password_auth_provider_callbacks(
+            auth_checkers={
+                ("test.login_type", ("test_field",)): self.check_auth,
+                ("m.login.password", ("password",)): self.check_auth,
+            },
+        )
+        pass
+
+    def check_auth(self, *args):
+        return mock_password_provider.check_auth(*args)
+
+    def check_pass(self, *args):
+        return mock_password_provider.check_password(*args)
+
+
+def legacy_providers_config(*providers: Type[Any]) -> dict:
+    """Returns a config dict that will enable the given legacy password auth providers"""
     return {
         "password_providers": [
             {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
@@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict:
     }
 
 
+def providers_config(*providers: Type[Any]) -> dict:
+    """Returns a config dict that will enable the given modules"""
+    return {
+        "modules": [
+            {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
+            for provider in providers
+        ]
+    }
+
+
 class PasswordAuthProviderTests(unittest.HomeserverTestCase):
     servlets = [
         synapse.rest.admin.register_servlets,
@@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
         super().setUp()
 
-    @override_config(providers_config(PasswordOnlyAuthProvider))
-    def test_password_only_auth_provider_login(self):
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+        # Load the modules into the homeserver
+        module_api = hs.get_module_api()
+        for module, config in hs.config.modules.loaded_modules:
+            module(config=config, api=module_api)
+        load_legacy_password_auth_providers(hs)
+
+        return hs
+
+    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+    def test_password_only_auth_progiver_login_legacy(self):
+        self.password_only_auth_provider_login_test_body()
+
+    def password_only_auth_provider_login_test_body(self):
         # login flows should only have m.login.password
         flows = self._get_login_flows()
         self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "@ USER🙂NAME :test", " pASS😢word "
         )
 
-    @override_config(providers_config(PasswordOnlyAuthProvider))
-    def test_password_only_auth_provider_ui_auth(self):
+    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+    def test_password_only_auth_provider_ui_auth_legacy(self):
+        self.password_only_auth_provider_ui_auth_test_body()
+
+    def password_only_auth_provider_ui_auth_test_body(self):
         """UI Auth should delegate correctly to the password provider"""
 
         # create the user, otherwise access doesn't work
@@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
 
-    @override_config(providers_config(PasswordOnlyAuthProvider))
-    def test_local_user_fallback_login(self):
+    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+    def test_local_user_fallback_login_legacy(self):
+        self.local_user_fallback_login_test_body()
+
+    def local_user_fallback_login_test_body(self):
         """rejected login should fall back to local db"""
         self.register_user("localuser", "localpass")
 
@@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual("@localuser:test", channel.json_body["user_id"])
 
-    @override_config(providers_config(PasswordOnlyAuthProvider))
-    def test_local_user_fallback_ui_auth(self):
+    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+    def test_local_user_fallback_ui_auth_legacy(self):
+        self.local_user_fallback_ui_auth_test_body()
+
+    def local_user_fallback_ui_auth_test_body(self):
         """rejected login should fall back to local db"""
         self.register_user("localuser", "localpass")
 
@@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     @override_config(
         {
-            **providers_config(PasswordOnlyAuthProvider),
+            **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
             "password_config": {"localdb_enabled": False},
         }
     )
-    def test_no_local_user_fallback_login(self):
+    def test_no_local_user_fallback_login_legacy(self):
+        self.no_local_user_fallback_login_test_body()
+
+    def no_local_user_fallback_login_test_body(self):
         """localdb_enabled can block login with the local password"""
         self.register_user("localuser", "localpass")
 
@@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     @override_config(
         {
-            **providers_config(PasswordOnlyAuthProvider),
+            **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
             "password_config": {"localdb_enabled": False},
         }
     )
-    def test_no_local_user_fallback_ui_auth(self):
+    def test_no_local_user_fallback_ui_auth_legacy(self):
+        self.no_local_user_fallback_ui_auth_test_body()
+
+    def no_local_user_fallback_ui_auth_test_body(self):
         """localdb_enabled can block ui auth with the local password"""
         self.register_user("localuser", "localpass")
 
@@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     @override_config(
         {
-            **providers_config(PasswordOnlyAuthProvider),
+            **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
             "password_config": {"enabled": False},
         }
     )
-    def test_password_auth_disabled(self):
+    def test_password_auth_disabled_legacy(self):
+        self.password_auth_disabled_test_body()
+
+    def password_auth_disabled_test_body(self):
         """password auth doesn't work if it's disabled across the board"""
         # login flows should be empty
         flows = self._get_login_flows()
@@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_password.assert_not_called()
 
+    @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+    def test_custom_auth_provider_login_legacy(self):
+        self.custom_auth_provider_login_test_body()
+
     @override_config(providers_config(CustomAuthProvider))
     def test_custom_auth_provider_login(self):
+        self.custom_auth_provider_login_test_body()
+
+    def custom_auth_provider_login_test_body(self):
         # login flows should have the custom flow and m.login.password, since we
         # haven't disabled local password lookup.
         # (password must come first, because reasons)
@@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_auth.assert_not_called()
 
-        mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+        mock_password_provider.check_auth.return_value = defer.succeed(
+            ("@user:bz", None)
+        )
         channel = self._send_login("test.login_type", "u", test_field="y")
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual("@user:bz", channel.json_body["user_id"])
@@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         # in these cases, but at least we can guard against the API changing
         # unexpectedly
         mock_password_provider.check_auth.return_value = defer.succeed(
-            "@ MALFORMED! :bz"
+            ("@ MALFORMED! :bz", None)
         )
         channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
         self.assertEqual(channel.code, 200, channel.result)
@@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
         )
 
+    @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+    def test_custom_auth_provider_ui_auth_legacy(self):
+        self.custom_auth_provider_ui_auth_test_body()
+
     @override_config(providers_config(CustomAuthProvider))
     def test_custom_auth_provider_ui_auth(self):
+        self.custom_auth_provider_ui_auth_test_body()
+
+    def custom_auth_provider_ui_auth_test_body(self):
         # register the user and log in twice, to get two devices
         self.register_user("localuser", "localpass")
         tok1 = self.login("localuser", "localpass")
@@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
 
         # right params, but authing as the wrong user
-        mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+        mock_password_provider.check_auth.return_value = defer.succeed(
+            ("@user:bz", None)
+        )
         body["auth"]["test_field"] = "foo"
         channel = self._delete_device(tok1, "dev2", body)
         self.assertEqual(channel.code, 403)
@@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
         # and finally, succeed
         mock_password_provider.check_auth.return_value = defer.succeed(
-            "@localuser:test"
+            ("@localuser:test", None)
         )
         channel = self._delete_device(tok1, "dev2", body)
         self.assertEqual(channel.code, 200)
@@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "localuser", "test.login_type", {"test_field": "foo"}
         )
 
+    @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+    def test_custom_auth_provider_callback_legacy(self):
+        self.custom_auth_provider_callback_test_body()
+
     @override_config(providers_config(CustomAuthProvider))
     def test_custom_auth_provider_callback(self):
+        self.custom_auth_provider_callback_test_body()
+
+    def custom_auth_provider_callback_test_body(self):
         callback = Mock(return_value=defer.succeed(None))
 
         mock_password_provider.check_auth.return_value = defer.succeed(
@@ -410,10 +518,22 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         for p in ["user_id", "access_token", "device_id", "home_server"]:
             self.assertIn(p, call_args[0])
 
+    @override_config(
+        {
+            **legacy_providers_config(LegacyCustomAuthProvider),
+            "password_config": {"enabled": False},
+        }
+    )
+    def test_custom_auth_password_disabled_legacy(self):
+        self.custom_auth_password_disabled_test_body()
+
     @override_config(
         {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
     )
     def test_custom_auth_password_disabled(self):
+        self.custom_auth_password_disabled_test_body()
+
+    def custom_auth_password_disabled_test_body(self):
         """Test login with a custom auth provider where password login is disabled"""
         self.register_user("localuser", "localpass")
 
@@ -425,6 +545,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_auth.assert_not_called()
 
+    @override_config(
+        {
+            **legacy_providers_config(LegacyCustomAuthProvider),
+            "password_config": {"enabled": False, "localdb_enabled": False},
+        }
+    )
+    def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
+        self.custom_auth_password_disabled_localdb_enabled_test_body()
+
     @override_config(
         {
             **providers_config(CustomAuthProvider),
@@ -432,6 +561,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
     )
     def test_custom_auth_password_disabled_localdb_enabled(self):
+        self.custom_auth_password_disabled_localdb_enabled_test_body()
+
+    def custom_auth_password_disabled_localdb_enabled_test_body(self):
         """Check the localdb_enabled == enabled == False
 
         Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@@ -448,6 +580,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_auth.assert_not_called()
 
+    @override_config(
+        {
+            **legacy_providers_config(LegacyPasswordCustomAuthProvider),
+            "password_config": {"enabled": False},
+        }
+    )
+    def test_password_custom_auth_password_disabled_login_legacy(self):
+        self.password_custom_auth_password_disabled_login_test_body()
+
     @override_config(
         {
             **providers_config(PasswordCustomAuthProvider),
@@ -455,6 +596,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
     )
     def test_password_custom_auth_password_disabled_login(self):
+        self.password_custom_auth_password_disabled_login_test_body()
+
+    def password_custom_auth_password_disabled_login_test_body(self):
         """log in with a custom auth provider which implements password, but password
         login is disabled"""
         self.register_user("localuser", "localpass")
@@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         channel = self._send_password_login("localuser", "localpass")
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_auth.assert_not_called()
+        mock_password_provider.check_password.assert_not_called()
+
+    @override_config(
+        {
+            **legacy_providers_config(LegacyPasswordCustomAuthProvider),
+            "password_config": {"enabled": False},
+        }
+    )
+    def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
+        self.password_custom_auth_password_disabled_ui_auth_test_body()
 
     @override_config(
         {
@@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
     )
     def test_password_custom_auth_password_disabled_ui_auth(self):
+        self.password_custom_auth_password_disabled_ui_auth_test_body()
+
+    def password_custom_auth_password_disabled_ui_auth_test_body(self):
         """UI Auth with a custom auth provider which implements password, but password
         login is disabled"""
         # register the user and log in twice via the test login type to get two devices,
         self.register_user("localuser", "localpass")
         mock_password_provider.check_auth.return_value = defer.succeed(
-            "@localuser:test"
+            ("@localuser:test", None)
         )
         channel = self._send_login("test.login_type", "localuser", test_field="")
         self.assertEqual(channel.code, 200, channel.result)
@@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "Password login has been disabled.", channel.json_body["error"]
         )
         mock_password_provider.check_auth.assert_not_called()
+        mock_password_provider.check_password.assert_not_called()
         mock_password_provider.reset_mock()
 
         # successful auth
@@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.check_auth.assert_called_once_with(
             "localuser", "test.login_type", {"test_field": "x"}
         )
+        mock_password_provider.check_password.assert_not_called()
+
+    @override_config(
+        {
+            **legacy_providers_config(LegacyCustomAuthProvider),
+            "password_config": {"localdb_enabled": False},
+        }
+    )
+    def test_custom_auth_no_local_user_fallback_legacy(self):
+        self.custom_auth_no_local_user_fallback_test_body()
 
     @override_config(
         {
@@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
     )
     def test_custom_auth_no_local_user_fallback(self):
+        self.custom_auth_no_local_user_fallback_test_body()
+
+    def custom_auth_no_local_user_fallback_test_body(self):
         """Test login with a custom auth provider where the local db is disabled"""
         self.register_user("localuser", "localpass")