Browse Source

Port the Password Auth Providers module interface to the new generic interface (#10548)

Co-authored-by: Azrenbeth <7782548+Azrenbeth@users.noreply.github.com>
Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
Azrenbeth 2 years ago
parent
commit
cdd308845b

+ 1 - 0
changelog.d/10548.feature

@@ -0,0 +1 @@
+Port the Password Auth Providers module interface to the new generic interface.

+ 1 - 0
docs/SUMMARY.md

@@ -43,6 +43,7 @@
         - [Third-party rules callbacks](modules/third_party_rules_callbacks.md)
         - [Third-party rules callbacks](modules/third_party_rules_callbacks.md)
         - [Presence router callbacks](modules/presence_router_callbacks.md)
         - [Presence router callbacks](modules/presence_router_callbacks.md)
         - [Account validity callbacks](modules/account_validity_callbacks.md)
         - [Account validity callbacks](modules/account_validity_callbacks.md)
+        - [Password auth provider callbacks](modules/password_auth_provider_callbacks.md)
         - [Porting a legacy module to the new interface](modules/porting_legacy_module.md)
         - [Porting a legacy module to the new interface](modules/porting_legacy_module.md)
     - [Workers](workers.md)
     - [Workers](workers.md)
       - [Using `synctl` with Workers](synctl_workers.md)
       - [Using `synctl` with Workers](synctl_workers.md)

+ 153 - 0
docs/modules/password_auth_provider_callbacks.md

@@ -0,0 +1,153 @@
+# Password auth provider callbacks
+
+Password auth providers offer a way for server administrators to integrate
+their Synapse installation with an external authentication system. The callbacks can be
+registered by using the Module API's `register_password_auth_provider_callbacks` method.
+
+## Callbacks
+
+### `auth_checkers`
+
+```
+ auth_checkers: Dict[Tuple[str,Tuple], Callable]
+```
+
+A dict mapping from tuples of a login type identifier (such as `m.login.password`) and a
+tuple of field names (such as `("password", "secret_thing")`) to authentication checking
+callbacks, which should be of the following form:
+
+```python
+async def check_auth(
+    user: str,
+    login_type: str,
+    login_dict: "synapse.module_api.JsonDict",
+) -> Optional[
+    Tuple[
+        str, 
+        Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
+    ]
+]
+```
+
+The login type and field names should be provided by the user in the
+request to the `/login` API. [The Matrix specification](https://matrix.org/docs/spec/client_server/latest#authentication-types)
+defines some types, however user defined ones are also allowed.
+
+The callback is passed the `user` field provided by the client (which might not be in
+`@username:server` form), the login type, and a dictionary of login secrets passed by
+the client.
+
+If the authentication is successful, the module must return the user's Matrix ID (e.g. 
+`@alice:example.com`) and optionally a callback to be called with the response to the
+`/login` request. If the module doesn't wish to return a callback, it must return `None`
+instead.
+
+If the authentication is unsuccessful, the module must return `None`.
+
+### `check_3pid_auth`
+
+```python
+async def check_3pid_auth(
+    medium: str, 
+    address: str,
+    password: str,
+) -> Optional[
+    Tuple[
+        str, 
+        Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]]
+    ]
+]
+```
+
+Called when a user attempts to register or log in with a third party identifier,
+such as email. It is passed the medium (eg. `email`), an address (eg. `jdoe@example.com`)
+and the user's password.
+
+If the authentication is successful, the module must return the user's Matrix ID (e.g. 
+`@alice:example.com`) and optionally a callback to be called with the response to the `/login` request.
+If the module doesn't wish to return a callback, it must return None instead.
+
+If the authentication is unsuccessful, the module must return None.
+
+### `on_logged_out`
+
+```python
+async def on_logged_out(
+    user_id: str,
+    device_id: Optional[str],
+    access_token: str
+) -> None
+``` 
+Called during a logout request for a user. It is passed the qualified user ID, the ID of the
+deactivated device (if any: access tokens are occasionally created without an associated
+device ID), and the (now deactivated) access token.
+
+## Example
+
+The example module below implements authentication checkers for two different login types: 
+-  `my.login.type` 
+    - Expects a `my_field` field to be sent to `/login`
+    - Is checked by the method: `self.check_my_login`
+- `m.login.password` (defined in [the spec](https://matrix.org/docs/spec/client_server/latest#password-based))
+    - Expects a `password` field to be sent to `/login`
+    - Is checked by the method: `self.check_pass` 
+
+
+```python
+from typing import Awaitable, Callable, Optional, Tuple
+
+import synapse
+from synapse import module_api
+
+
+class MyAuthProvider:
+    def __init__(self, config: dict, api: module_api):
+
+        self.api = api
+
+        self.credentials = {
+            "bob": "building",
+            "@scoop:matrix.org": "digging",
+        }
+
+        api.register_password_auth_provider_callbacks(
+            auth_checkers={
+                ("my.login_type", ("my_field",)): self.check_my_login,
+                ("m.login.password", ("password",)): self.check_pass,
+            },
+        )
+
+    async def check_my_login(
+        self,
+        username: str,
+        login_type: str,
+        login_dict: "synapse.module_api.JsonDict",
+    ) -> Optional[
+        Tuple[
+            str,
+            Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
+        ]
+    ]:
+        if login_type != "my.login_type":
+            return None
+
+        if self.credentials.get(username) == login_dict.get("my_field"):
+            return self.api.get_qualified_user_id(username)
+
+    async def check_pass(
+        self,
+        username: str,
+        login_type: str,
+        login_dict: "synapse.module_api.JsonDict",
+    ) -> Optional[
+        Tuple[
+            str,
+            Optional[Callable[["synapse.module_api.LoginResponse"], Awaitable[None]]],
+        ]
+    ]:
+        if login_type != "m.login.password":
+            return None
+
+        if self.credentials.get(username) == login_dict.get("password"):
+            return self.api.get_qualified_user_id(username)
+```

+ 3 - 0
docs/modules/porting_legacy_module.md

@@ -12,6 +12,9 @@ should register this resource in its `__init__` method using the `register_web_r
 method from the `ModuleApi` class (see [this section](writing_a_module.html#registering-a-web-resource) for
 method from the `ModuleApi` class (see [this section](writing_a_module.html#registering-a-web-resource) for
 more info).
 more info).
 
 
+There is no longer a `get_db_schema_files` callback provided for password auth provider modules. Any
+changes to the database should now be made by the module using the module API class.
+
 The module's author should also update any example in the module's configuration to only
 The module's author should also update any example in the module's configuration to only
 use the new `modules` section in Synapse's configuration file (see [this section](index.html#using-modules)
 use the new `modules` section in Synapse's configuration file (see [this section](index.html#using-modules)
 for more info).
 for more info).

+ 6 - 0
docs/password_auth_providers.md

@@ -1,3 +1,9 @@
+<h2 style="color:red">
+This page of the Synapse documentation is now deprecated. For up to date
+documentation on setting up or writing a password auth provider module, please see
+<a href="modules.md">this page</a>.
+</h2>
+
 # Password auth provider modules
 # Password auth provider modules
 
 
 Password auth providers offer a way for server administrators to
 Password auth providers offer a way for server administrators to

+ 0 - 28
docs/sample_config.yaml

@@ -2260,34 +2260,6 @@ email:
     #email_validation: "[%(server_name)s] Validate your email"
     #email_validation: "[%(server_name)s] Validate your email"
 
 
 
 
-# Password providers allow homeserver administrators to integrate
-# their Synapse installation with existing authentication methods
-# ex. LDAP, external tokens, etc.
-#
-# For more information and known implementations, please see
-# https://matrix-org.github.io/synapse/latest/password_auth_providers.html
-#
-# Note: instances wishing to use SAML or CAS authentication should
-# instead use the `saml2_config` or `cas_config` options,
-# respectively.
-#
-password_providers:
-#    # Example config for an LDAP auth provider
-#    - module: "ldap_auth_provider.LdapAuthProvider"
-#      config:
-#        enabled: true
-#        uri: "ldap://ldap.example.com:389"
-#        start_tls: true
-#        base: "ou=users,dc=example,dc=com"
-#        attributes:
-#           uid: "cn"
-#           mail: "email"
-#           name: "givenName"
-#        #bind_dn:
-#        #bind_password:
-#        #filter: "(objectClass=posixAccount)"
-
-
 
 
 ## Push ##
 ## Push ##
 
 

+ 2 - 0
synapse/app/_base.py

@@ -42,6 +42,7 @@ from synapse.crypto import context_factory
 from synapse.events.presence_router import load_legacy_presence_router
 from synapse.events.presence_router import load_legacy_presence_router
 from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
+from synapse.handlers.auth import load_legacy_password_auth_providers
 from synapse.logging.context import PreserveLoggingContext
 from synapse.logging.context import PreserveLoggingContext
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.jemalloc import setup_jemalloc_stats
 from synapse.metrics.jemalloc import setup_jemalloc_stats
@@ -379,6 +380,7 @@ async def start(hs: "HomeServer"):
     load_legacy_spam_checkers(hs)
     load_legacy_spam_checkers(hs)
     load_legacy_third_party_event_rules(hs)
     load_legacy_third_party_event_rules(hs)
     load_legacy_presence_router(hs)
     load_legacy_presence_router(hs)
+    load_legacy_password_auth_providers(hs)
 
 
     # If we've configured an expiry time for caches, start the background job now.
     # If we've configured an expiry time for caches, start the background job now.
     setup_expire_lru_cache_entries(hs)
     setup_expire_lru_cache_entries(hs)

+ 23 - 30
synapse/config/password_auth_providers.py

@@ -25,6 +25,29 @@ class PasswordAuthProviderConfig(Config):
     section = "authproviders"
     section = "authproviders"
 
 
     def read_config(self, config, **kwargs):
     def read_config(self, config, **kwargs):
+        """Parses the old password auth providers config. The config format looks like this:
+
+        password_providers:
+           # Example config for an LDAP auth provider
+           - module: "ldap_auth_provider.LdapAuthProvider"
+             config:
+               enabled: true
+               uri: "ldap://ldap.example.com:389"
+               start_tls: true
+               base: "ou=users,dc=example,dc=com"
+               attributes:
+                  uid: "cn"
+                  mail: "email"
+                  name: "givenName"
+               #bind_dn:
+               #bind_password:
+               #filter: "(objectClass=posixAccount)"
+
+        We expect admins to use modules for this feature (which is why it doesn't appear
+        in the sample config file), but we want to keep support for it around for a bit
+        for backwards compatibility.
+        """
+
         self.password_providers: List[Tuple[Type, Any]] = []
         self.password_providers: List[Tuple[Type, Any]] = []
         providers = []
         providers = []
 
 
@@ -49,33 +72,3 @@ class PasswordAuthProviderConfig(Config):
             )
             )
 
 
             self.password_providers.append((provider_class, provider_config))
             self.password_providers.append((provider_class, provider_config))
-
-    def generate_config_section(self, **kwargs):
-        return """\
-        # Password providers allow homeserver administrators to integrate
-        # their Synapse installation with existing authentication methods
-        # ex. LDAP, external tokens, etc.
-        #
-        # For more information and known implementations, please see
-        # https://matrix-org.github.io/synapse/latest/password_auth_providers.html
-        #
-        # Note: instances wishing to use SAML or CAS authentication should
-        # instead use the `saml2_config` or `cas_config` options,
-        # respectively.
-        #
-        password_providers:
-        #    # Example config for an LDAP auth provider
-        #    - module: "ldap_auth_provider.LdapAuthProvider"
-        #      config:
-        #        enabled: true
-        #        uri: "ldap://ldap.example.com:389"
-        #        start_tls: true
-        #        base: "ou=users,dc=example,dc=com"
-        #        attributes:
-        #           uid: "cn"
-        #           mail: "email"
-        #           name: "givenName"
-        #        #bind_dn:
-        #        #bind_password:
-        #        #filter: "(objectClass=posixAccount)"
-        """

+ 388 - 140
synapse/handlers/auth.py

@@ -200,46 +200,13 @@ class AuthHandler:
 
 
         self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
         self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
 
 
-        # we can't use hs.get_module_api() here, because to do so will create an
-        # import loop.
-        #
-        # TODO: refactor this class to separate the lower-level stuff that
-        #   ModuleApi can use from the higher-level stuff that uses ModuleApi, as
-        #   better way to break the loop
-        account_handler = ModuleApi(hs, self)
-
-        self.password_providers = [
-            PasswordProvider.load(module, config, account_handler)
-            for module, config in hs.config.authproviders.password_providers
-        ]
-
-        logger.info("Extra password_providers: %s", self.password_providers)
+        self.password_auth_provider = hs.get_password_auth_provider()
 
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.auth.password_enabled
         self._password_enabled = hs.config.auth.password_enabled
         self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
         self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
 
 
-        # start out by assuming PASSWORD is enabled; we will remove it later if not.
-        login_types = set()
-        if self._password_localdb_enabled:
-            login_types.add(LoginType.PASSWORD)
-
-        for provider in self.password_providers:
-            login_types.update(provider.get_supported_login_types().keys())
-
-        if not self._password_enabled:
-            login_types.discard(LoginType.PASSWORD)
-
-        # Some clients just pick the first type in the list. In this case, we want
-        # them to use PASSWORD (rather than token or whatever), so we want to make sure
-        # that comes first, where it's present.
-        self._supported_login_types = []
-        if LoginType.PASSWORD in login_types:
-            self._supported_login_types.append(LoginType.PASSWORD)
-            login_types.remove(LoginType.PASSWORD)
-        self._supported_login_types.extend(login_types)
-
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
         # as per `rc_login.failed_attempts`.
         self._failed_uia_attempts_ratelimiter = Ratelimiter(
         self._failed_uia_attempts_ratelimiter = Ratelimiter(
@@ -427,11 +394,10 @@ class AuthHandler:
                     ui_auth_types.add(LoginType.PASSWORD)
                     ui_auth_types.add(LoginType.PASSWORD)
 
 
         # also allow auth from password providers
         # also allow auth from password providers
-        for provider in self.password_providers:
-            for t in provider.get_supported_login_types().keys():
-                if t == LoginType.PASSWORD and not self._password_enabled:
-                    continue
-                ui_auth_types.add(t)
+        for t in self.password_auth_provider.get_supported_login_types().keys():
+            if t == LoginType.PASSWORD and not self._password_enabled:
+                continue
+            ui_auth_types.add(t)
 
 
         # if sso is enabled, allow the user to log in via SSO iff they have a mapping
         # if sso is enabled, allow the user to log in via SSO iff they have a mapping
         # from sso to mxid.
         # from sso to mxid.
@@ -1038,7 +1004,25 @@ class AuthHandler:
         Returns:
         Returns:
             login types
             login types
         """
         """
-        return self._supported_login_types
+        # Load any login types registered by modules
+        # This is stored in the password_auth_provider so this doesn't trigger
+        # any callbacks
+        types = list(self.password_auth_provider.get_supported_login_types().keys())
+
+        # This list should include PASSWORD if (either _password_localdb_enabled is
+        # true or if one of the modules registered it) AND _password_enabled is true
+        # Also:
+        # Some clients just pick the first type in the list. In this case, we want
+        # them to use PASSWORD (rather than token or whatever), so we want to make sure
+        # that comes first, where it's present.
+        if LoginType.PASSWORD in types:
+            types.remove(LoginType.PASSWORD)
+            if self._password_enabled:
+                types.insert(0, LoginType.PASSWORD)
+        elif self._password_localdb_enabled and self._password_enabled:
+            types.insert(0, LoginType.PASSWORD)
+
+        return types
 
 
     async def validate_login(
     async def validate_login(
         self,
         self,
@@ -1217,15 +1201,20 @@ class AuthHandler:
 
 
         known_login_type = False
         known_login_type = False
 
 
-        for provider in self.password_providers:
-            supported_login_types = provider.get_supported_login_types()
-            if login_type not in supported_login_types:
-                # this password provider doesn't understand this login type
-                continue
-
+        # Check if login_type matches a type registered by one of the modules
+        # We don't need to remove LoginType.PASSWORD from the list if password login is
+        # disabled, since if that were the case then by this point we know that the
+        # login_type is not LoginType.PASSWORD
+        supported_login_types = self.password_auth_provider.get_supported_login_types()
+        # check if the login type being used is supported by a module
+        if login_type in supported_login_types:
+            # Make a note that this login type is supported by the server
             known_login_type = True
             known_login_type = True
+            # Get all the fields expected for this login types
             login_fields = supported_login_types[login_type]
             login_fields = supported_login_types[login_type]
 
 
+            # go through the login submission and keep track of which required fields are
+            # provided/not provided
             missing_fields = []
             missing_fields = []
             login_dict = {}
             login_dict = {}
             for f in login_fields:
             for f in login_fields:
@@ -1233,6 +1222,7 @@ class AuthHandler:
                     missing_fields.append(f)
                     missing_fields.append(f)
                 else:
                 else:
                     login_dict[f] = login_submission[f]
                     login_dict[f] = login_submission[f]
+            # raise an error if any of the expected fields for that login type weren't provided
             if missing_fields:
             if missing_fields:
                 raise SynapseError(
                 raise SynapseError(
                     400,
                     400,
@@ -1240,10 +1230,15 @@ class AuthHandler:
                     % (login_type, missing_fields),
                     % (login_type, missing_fields),
                 )
                 )
 
 
-            result = await provider.check_auth(username, login_type, login_dict)
+            # call all of the check_auth hooks for that login_type
+            # it will return a result once the first success is found (or None otherwise)
+            result = await self.password_auth_provider.check_auth(
+                username, login_type, login_dict
+            )
             if result:
             if result:
                 return result
                 return result
 
 
+        # if no module managed to authenticate the user, then fallback to built in password based auth
         if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
         if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
             known_login_type = True
             known_login_type = True
 
 
@@ -1282,11 +1277,16 @@ class AuthHandler:
             completed login/registration, or `None`. If authentication was
             completed login/registration, or `None`. If authentication was
             unsuccessful, `user_id` and `callback` are both `None`.
             unsuccessful, `user_id` and `callback` are both `None`.
         """
         """
-        for provider in self.password_providers:
-            result = await provider.check_3pid_auth(medium, address, password)
-            if result:
-                return result
+        # call all of the check_3pid_auth callbacks
+        # Result will be from the first callback that returns something other than None
+        # If all the callbacks return None, then result is also set to None
+        result = await self.password_auth_provider.check_3pid_auth(
+            medium, address, password
+        )
+        if result:
+            return result
 
 
+        # if result is None then return (None, None)
         return None, None
         return None, None
 
 
     async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
     async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
@@ -1365,13 +1365,12 @@ class AuthHandler:
         user_info = await self.auth.get_user_by_access_token(access_token)
         user_info = await self.auth.get_user_by_access_token(access_token)
         await self.store.delete_access_token(access_token)
         await self.store.delete_access_token(access_token)
 
 
-        # see if any of our auth providers want to know about this
-        for provider in self.password_providers:
-            await provider.on_logged_out(
-                user_id=user_info.user_id,
-                device_id=user_info.device_id,
-                access_token=access_token,
-            )
+        # see if any modules want to know about this
+        await self.password_auth_provider.on_logged_out(
+            user_id=user_info.user_id,
+            device_id=user_info.device_id,
+            access_token=access_token,
+        )
 
 
         # delete pushers associated with this access token
         # delete pushers associated with this access token
         if user_info.token_id is not None:
         if user_info.token_id is not None:
@@ -1398,12 +1397,11 @@ class AuthHandler:
             user_id, except_token_id=except_token_id, device_id=device_id
             user_id, except_token_id=except_token_id, device_id=device_id
         )
         )
 
 
-        # see if any of our auth providers want to know about this
-        for provider in self.password_providers:
-            for token, _, device_id in tokens_and_devices:
-                await provider.on_logged_out(
-                    user_id=user_id, device_id=device_id, access_token=token
-                )
+        # see if any modules want to know about this
+        for token, _, device_id in tokens_and_devices:
+            await self.password_auth_provider.on_logged_out(
+                user_id=user_id, device_id=device_id, access_token=token
+            )
 
 
         # delete pushers associated with the access tokens
         # delete pushers associated with the access tokens
         await self.hs.get_pusherpool().remove_pushers_by_access_token(
         await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1811,40 +1809,228 @@ class MacaroonGenerator:
         return macaroon
         return macaroon
 
 
 
 
-class PasswordProvider:
-    """Wrapper for a password auth provider module
+def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
+    module_api = hs.get_module_api()
+    for module, config in hs.config.authproviders.password_providers:
+        load_single_legacy_password_auth_provider(
+            module=module, config=config, api=module_api
+        )
 
 
-    This class abstracts out all of the backwards-compatibility hacks for
-    password providers, to provide a consistent interface.
-    """
 
 
-    @classmethod
-    def load(
-        cls, module: Type, config: JsonDict, module_api: ModuleApi
-    ) -> "PasswordProvider":
-        try:
-            pp = module(config=config, account_handler=module_api)
-        except Exception as e:
-            logger.error("Error while initializing %r: %s", module, e)
-            raise
-        return cls(pp, module_api)
+def load_single_legacy_password_auth_provider(
+    module: Type, config: JsonDict, api: ModuleApi
+) -> None:
+    try:
+        provider = module(config=config, account_handler=api)
+    except Exception as e:
+        logger.error("Error while initializing %r: %s", module, e)
+        raise
+
+    # The known hooks. If a module implements a method who's name appears in this set
+    # we'll want to register it
+    password_auth_provider_methods = {
+        "check_3pid_auth",
+        "on_logged_out",
+    }
+
+    # All methods that the module provides should be async, but this wasn't enforced
+    # in the old module system, so we wrap them if needed
+    def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+        # f might be None if the callback isn't implemented by the module. In this
+        # case we don't want to register a callback at all so we return None.
+        if f is None:
+            return None
+
+        # We need to wrap check_password because its old form would return a boolean
+        # but we now want it to behave just like check_auth() and return the matrix id of
+        # the user if authentication succeeded or None otherwise
+        if f.__name__ == "check_password":
+
+            async def wrapped_check_password(
+                username: str, login_type: str, login_dict: JsonDict
+            ) -> Optional[Tuple[str, Optional[Callable]]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                matrix_user_id = api.get_qualified_user_id(username)
+                password = login_dict["password"]
+
+                is_valid = await f(matrix_user_id, password)
+
+                if is_valid:
+                    return matrix_user_id, None
+
+                return None
 
 
-    def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
-        self._pp = pp
-        self._module_api = module_api
+            return wrapped_check_password
+
+        # We need to wrap check_auth as in the old form it could return
+        # just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
+        if f.__name__ == "check_auth":
+
+            async def wrapped_check_auth(
+                username: str, login_type: str, login_dict: JsonDict
+            ) -> Optional[Tuple[str, Optional[Callable]]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                result = await f(username, login_type, login_dict)
+
+                if isinstance(result, str):
+                    return result, None
+
+                return result
+
+            return wrapped_check_auth
+
+        # We need to wrap check_3pid_auth as in the old form it could return
+        # just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
+        if f.__name__ == "check_3pid_auth":
+
+            async def wrapped_check_3pid_auth(
+                medium: str, address: str, password: str
+            ) -> Optional[Tuple[str, Optional[Callable]]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                result = await f(medium, address, password)
+
+                if isinstance(result, str):
+                    return result, None
+
+                return result
 
 
-        self._supported_login_types = {}
+            return wrapped_check_3pid_auth
 
 
-        # grandfather in check_password support
-        if hasattr(self._pp, "check_password"):
-            self._supported_login_types[LoginType.PASSWORD] = ("password",)
+        def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
+            # mypy doesn't do well across function boundaries so we need to tell it
+            # f is definitely not None.
+            assert f is not None
 
 
-        g = getattr(self._pp, "get_supported_login_types", None)
-        if g:
-            self._supported_login_types.update(g())
+            return maybe_awaitable(f(*args, **kwargs))
 
 
-    def __str__(self) -> str:
-        return str(self._pp)
+        return run
+
+    # populate hooks with the implemented methods, wrapped with async_wrapper
+    hooks = {
+        hook: async_wrapper(getattr(provider, hook, None))
+        for hook in password_auth_provider_methods
+    }
+
+    supported_login_types = {}
+    # call get_supported_login_types and add that to the dict
+    g = getattr(provider, "get_supported_login_types", None)
+    if g is not None:
+        # Note the old module style also called get_supported_login_types at loading time
+        # and it is synchronous
+        supported_login_types.update(g())
+
+    auth_checkers = {}
+    # Legacy modules have a check_auth method which expects to be called with one of
+    # the keys returned by get_supported_login_types. New style modules register a
+    # dictionary of login_type->check_auth_method mappings
+    check_auth = async_wrapper(getattr(provider, "check_auth", None))
+    if check_auth is not None:
+        for login_type, fields in supported_login_types.items():
+            # need tuple(fields) since fields can be any Iterable type (so may not be hashable)
+            auth_checkers[(login_type, tuple(fields))] = check_auth
+
+    # if it has a "check_password" method then it should handle all auth checks
+    # with login type of LoginType.PASSWORD
+    check_password = async_wrapper(getattr(provider, "check_password", None))
+    if check_password is not None:
+        # need to use a tuple here for ("password",) not a list since lists aren't hashable
+        auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
+
+    api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
+
+
+CHECK_3PID_AUTH_CALLBACK = Callable[
+    [str, str, str],
+    Awaitable[
+        Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
+    ],
+]
+ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
+CHECK_AUTH_CALLBACK = Callable[
+    [str, str, JsonDict],
+    Awaitable[
+        Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
+    ],
+]
+
+
+class PasswordAuthProvider:
+    """
+    A class that the AuthHandler calls when authenticating users
+    It allows modules to provide alternative methods for authentication
+    """
+
+    def __init__(self) -> None:
+        # lists of callbacks
+        self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
+        self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
+
+        # Mapping from login type to login parameters
+        self._supported_login_types: Dict[str, Iterable[str]] = {}
+
+        # Mapping from login type to auth checker callbacks
+        self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
+
+    def register_password_auth_provider_callbacks(
+        self,
+        check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
+        on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
+        auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
+    ) -> None:
+        # Register check_3pid_auth callback
+        if check_3pid_auth is not None:
+            self.check_3pid_auth_callbacks.append(check_3pid_auth)
+
+        # register on_logged_out callback
+        if on_logged_out is not None:
+            self.on_logged_out_callbacks.append(on_logged_out)
+
+        if auth_checkers is not None:
+            # register a new supported login_type
+            # Iterate through all of the types being registered
+            for (login_type, fields), callback in auth_checkers.items():
+                # Note: fields may be empty here. This would allow a modules auth checker to
+                # be called with just 'login_type' and no password or other secrets
+
+                # Need to check that all the field names are strings or may get nasty errors later
+                for f in fields:
+                    if not isinstance(f, str):
+                        raise RuntimeError(
+                            "A module tried to register support for login type: %s with parameters %s"
+                            " but all parameter names must be strings"
+                            % (login_type, fields)
+                        )
+
+                # 2 modules supporting the same login type must expect the same fields
+                # e.g. 1 can't expect "pass" if the other expects "password"
+                # so throw an exception if that happens
+                if login_type not in self._supported_login_types.get(login_type, []):
+                    self._supported_login_types[login_type] = fields
+                else:
+                    fields_currently_supported = self._supported_login_types.get(
+                        login_type
+                    )
+                    if fields_currently_supported != fields:
+                        raise RuntimeError(
+                            "A module tried to register support for login type: %s with parameters %s"
+                            " but another module had already registered support for that type with parameters %s"
+                            % (login_type, fields, fields_currently_supported)
+                        )
+
+                # Add the new method to the list of auth_checker_callbacks for this login type
+                self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
 
 
     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
     def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
         """Get the login types supported by this password provider
         """Get the login types supported by this password provider
@@ -1852,20 +2038,15 @@ class PasswordProvider:
         Returns a map from a login type identifier (such as m.login.password) to an
         Returns a map from a login type identifier (such as m.login.password) to an
         iterable giving the fields which must be provided by the user in the submission
         iterable giving the fields which must be provided by the user in the submission
         to the /login API.
         to the /login API.
-
-        This wrapper adds m.login.password to the list if the underlying password
-        provider supports the check_password() api.
         """
         """
+
         return self._supported_login_types
         return self._supported_login_types
 
 
     async def check_auth(
     async def check_auth(
         self, username: str, login_type: str, login_dict: JsonDict
         self, username: str, login_type: str, login_dict: JsonDict
-    ) -> Optional[Tuple[str, Optional[Callable]]]:
+    ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
         """Check if the user has presented valid login credentials
         """Check if the user has presented valid login credentials
 
 
-        This wrapper also calls check_password() if the underlying password provider
-        supports the check_password() api and the login type is m.login.password.
-
         Args:
         Args:
             username: user id presented by the client. Either an MXID or an unqualified
             username: user id presented by the client. Either an MXID or an unqualified
                 username.
                 username.
@@ -1879,63 +2060,130 @@ class PasswordProvider:
             user, and `callback` is an optional callback which will be called with the
             user, and `callback` is an optional callback which will be called with the
             result from the /login call (including access_token, device_id, etc.)
             result from the /login call (including access_token, device_id, etc.)
         """
         """
-        # first grandfather in a call to check_password
-        if login_type == LoginType.PASSWORD:
-            check_password = getattr(self._pp, "check_password", None)
-            if check_password:
-                qualified_user_id = self._module_api.get_qualified_user_id(username)
-                is_valid = await check_password(
-                    qualified_user_id, login_dict["password"]
-                )
-                if is_valid:
-                    return qualified_user_id, None
 
 
-        check_auth = getattr(self._pp, "check_auth", None)
-        if not check_auth:
-            return None
-        result = await check_auth(username, login_type, login_dict)
+        # Go through all callbacks for the login type until one returns with a value
+        # other than None (i.e. until a callback returns a success)
+        for callback in self.auth_checker_callbacks[login_type]:
+            try:
+                result = await callback(username, login_type, login_dict)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
 
 
-        # Check if the return value is a str or a tuple
-        if isinstance(result, str):
-            # If it's a str, set callback function to None
-            return result, None
+            if result is not None:
+                # Check that the callback returned a Tuple[str, Optional[Callable]]
+                # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
+                # result is always the right type, but as it is 3rd party code it might not be
+
+                if not isinstance(result, tuple) or len(result) != 2:
+                    logger.warning(
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
 
 
-        return result
+                # pull out the two parts of the tuple so we can do type checking
+                str_result, callback_result = result
+
+                # the 1st item in the tuple should be a str
+                if not isinstance(str_result, str):
+                    logger.warning(  # type: ignore[unreachable]
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
+
+                # the second should be Optional[Callable]
+                if callback_result is not None:
+                    if not callable(callback_result):
+                        logger.warning(  # type: ignore[unreachable]
+                            "Wrong type returned by module API callback %s: %s, expected"
+                            " Optional[Tuple[str, Optional[Callable]]]",
+                            callback,
+                            result,
+                        )
+                        continue
+
+                # The result is a (str, Optional[callback]) tuple so return the successful result
+                return result
+
+        # If this point has been reached then none of the callbacks successfully authenticated
+        # the user so return None
+        return None
 
 
     async def check_3pid_auth(
     async def check_3pid_auth(
         self, medium: str, address: str, password: str
         self, medium: str, address: str, password: str
-    ) -> Optional[Tuple[str, Optional[Callable]]]:
-        g = getattr(self._pp, "check_3pid_auth", None)
-        if not g:
-            return None
-
+    ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
         # This function is able to return a deferred that either
         # This function is able to return a deferred that either
         # resolves None, meaning authentication failure, or upon
         # resolves None, meaning authentication failure, or upon
         # success, to a str (which is the user_id) or a tuple of
         # success, to a str (which is the user_id) or a tuple of
         # (user_id, callback_func), where callback_func should be run
         # (user_id, callback_func), where callback_func should be run
         # after we've finished everything else
         # after we've finished everything else
-        result = await g(medium, address, password)
 
 
-        # Check if the return value is a str or a tuple
-        if isinstance(result, str):
-            # If it's a str, set callback function to None
-            return result, None
+        for callback in self.check_3pid_auth_callbacks:
+            try:
+                result = await callback(medium, address, password)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
 
 
-        return result
+            if result is not None:
+                # Check that the callback returned a Tuple[str, Optional[Callable]]
+                # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
+                # result is always the right type, but as it is 3rd party code it might not be
+
+                if not isinstance(result, tuple) or len(result) != 2:
+                    logger.warning(
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
+
+                # pull out the two parts of the tuple so we can do type checking
+                str_result, callback_result = result
+
+                # the 1st item in the tuple should be a str
+                if not isinstance(str_result, str):
+                    logger.warning(  # type: ignore[unreachable]
+                        "Wrong type returned by module API callback %s: %s, expected"
+                        " Optional[Tuple[str, Optional[Callable]]]",
+                        callback,
+                        result,
+                    )
+                    continue
+
+                # the second should be Optional[Callable]
+                if callback_result is not None:
+                    if not callable(callback_result):
+                        logger.warning(  # type: ignore[unreachable]
+                            "Wrong type returned by module API callback %s: %s, expected"
+                            " Optional[Tuple[str, Optional[Callable]]]",
+                            callback,
+                            result,
+                        )
+                        continue
+
+                # The result is a (str, Optional[callback]) tuple so return the successful result
+                return result
+
+        # If this point has been reached then none of the callbacks successfully authenticated
+        # the user so return None
+        return None
 
 
     async def on_logged_out(
     async def on_logged_out(
         self, user_id: str, device_id: Optional[str], access_token: str
         self, user_id: str, device_id: Optional[str], access_token: str
     ) -> None:
     ) -> None:
-        g = getattr(self._pp, "on_logged_out", None)
-        if not g:
-            return
 
 
-        # This might return an awaitable, if it does block the log out
-        # until it completes.
-        await maybe_awaitable(
-            g(
-                user_id=user_id,
-                device_id=device_id,
-                access_token=access_token,
-            )
-        )
+        # call all of the on_logged_out callbacks
+        for callback in self.on_logged_out_callbacks:
+            try:
+                callback(user_id, device_id, access_token)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue

+ 9 - 0
synapse/module_api/__init__.py

@@ -45,6 +45,7 @@ from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.rest.client.login import LoginResponse
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.roommember import ProfileInfo
 from synapse.storage.databases.main.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
 from synapse.storage.state import StateFilter
@@ -83,6 +84,8 @@ __all__ = [
     "DirectServeJsonResource",
     "DirectServeJsonResource",
     "ModuleApi",
     "ModuleApi",
     "PRESENCE_ALL_USERS",
     "PRESENCE_ALL_USERS",
+    "LoginResponse",
+    "JsonDict",
 ]
 ]
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
@@ -139,6 +142,7 @@ class ModuleApi:
         self._spam_checker = hs.get_spam_checker()
         self._spam_checker = hs.get_spam_checker()
         self._account_validity_handler = hs.get_account_validity_handler()
         self._account_validity_handler = hs.get_account_validity_handler()
         self._third_party_event_rules = hs.get_third_party_event_rules()
         self._third_party_event_rules = hs.get_third_party_event_rules()
+        self._password_auth_provider = hs.get_password_auth_provider()
         self._presence_router = hs.get_presence_router()
         self._presence_router = hs.get_presence_router()
 
 
     #################################################################################
     #################################################################################
@@ -164,6 +168,11 @@ class ModuleApi:
         """Registers callbacks for presence router capabilities."""
         """Registers callbacks for presence router capabilities."""
         return self._presence_router.register_presence_router_callbacks
         return self._presence_router.register_presence_router_callbacks
 
 
+    @property
+    def register_password_auth_provider_callbacks(self):
+        """Registers callbacks for password auth provider capabilities."""
+        return self._password_auth_provider.register_password_auth_provider_callbacks
+
     def register_web_resource(self, path: str, resource: IResource):
     def register_web_resource(self, path: str, resource: IResource):
         """Registers a web resource to be served at the given path.
         """Registers a web resource to be served at the given path.
 
 

+ 5 - 1
synapse/server.py

@@ -65,7 +65,7 @@ from synapse.handlers.account_data import AccountDataHandler
 from synapse.handlers.account_validity import AccountValidityHandler
 from synapse.handlers.account_validity import AccountValidityHandler
 from synapse.handlers.admin import AdminHandler
 from synapse.handlers.admin import AdminHandler
 from synapse.handlers.appservice import ApplicationServicesHandler
 from synapse.handlers.appservice import ApplicationServicesHandler
-from synapse.handlers.auth import AuthHandler, MacaroonGenerator
+from synapse.handlers.auth import AuthHandler, MacaroonGenerator, PasswordAuthProvider
 from synapse.handlers.cas import CasHandler
 from synapse.handlers.cas import CasHandler
 from synapse.handlers.deactivate_account import DeactivateAccountHandler
 from synapse.handlers.deactivate_account import DeactivateAccountHandler
 from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
 from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
@@ -687,6 +687,10 @@ class HomeServer(metaclass=abc.ABCMeta):
     def get_third_party_event_rules(self) -> ThirdPartyEventRules:
     def get_third_party_event_rules(self) -> ThirdPartyEventRules:
         return ThirdPartyEventRules(self)
         return ThirdPartyEventRules(self)
 
 
+    @cache_in_self
+    def get_password_auth_provider(self) -> PasswordAuthProvider:
+        return PasswordAuthProvider()
+
     @cache_in_self
     @cache_in_self
     def get_room_member_handler(self) -> RoomMemberHandler:
     def get_room_member_handler(self) -> RoomMemberHandler:
         if self.config.worker.worker_app:
         if self.config.worker.worker_app:

+ 2 - 0
synapse/storage/prepare_database.py

@@ -549,6 +549,8 @@ def _apply_module_schemas(
         database_engine:
         database_engine:
         config: application config
         config: application config
     """
     """
+    # This is the old way for password_auth_provider modules to make changes
+    # to the database. This should instead be done using the module API
     for (mod, _config) in config.authproviders.password_providers:
     for (mod, _config) in config.authproviders.password_providers:
         if not hasattr(mod, "get_db_schema_files"):
         if not hasattr(mod, "get_db_schema_files"):
             continue
             continue

+ 197 - 26
tests/handlers/test_password_providers.py

@@ -20,6 +20,8 @@ from unittest.mock import Mock
 from twisted.internet import defer
 from twisted.internet import defer
 
 
 import synapse
 import synapse
+from synapse.handlers.auth import load_legacy_password_auth_providers
+from synapse.module_api import ModuleApi
 from synapse.rest.client import devices, login
 from synapse.rest.client import devices, login
 from synapse.types import JsonDict
 from synapse.types import JsonDict
 
 
@@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi
 mock_password_provider = Mock()
 mock_password_provider = Mock()
 
 
 
 
-class PasswordOnlyAuthProvider:
-    """A password_provider which only implements `check_password`."""
+class LegacyPasswordOnlyAuthProvider:
+    """A legacy password_provider which only implements `check_password`."""
 
 
     @staticmethod
     @staticmethod
     def parse_config(self):
     def parse_config(self):
@@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider:
         return mock_password_provider.check_password(*args)
         return mock_password_provider.check_password(*args)
 
 
 
 
-class CustomAuthProvider:
-    """A password_provider which implements a custom login type."""
+class LegacyCustomAuthProvider:
+    """A legacy password_provider which implements a custom login type."""
 
 
     @staticmethod
     @staticmethod
     def parse_config(self):
     def parse_config(self):
@@ -67,7 +69,23 @@ class CustomAuthProvider:
         return mock_password_provider.check_auth(*args)
         return mock_password_provider.check_auth(*args)
 
 
 
 
-class PasswordCustomAuthProvider:
+class CustomAuthProvider:
+    """A module which registers password_auth_provider callbacks for a custom login type."""
+
+    @staticmethod
+    def parse_config(self):
+        pass
+
+    def __init__(self, config, api: ModuleApi):
+        api.register_password_auth_provider_callbacks(
+            auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
+        )
+
+    def check_auth(self, *args):
+        return mock_password_provider.check_auth(*args)
+
+
+class LegacyPasswordCustomAuthProvider:
     """A password_provider which implements password login via `check_auth`, as well
     """A password_provider which implements password login via `check_auth`, as well
     as a custom type."""
     as a custom type."""
 
 
@@ -85,8 +103,32 @@ class PasswordCustomAuthProvider:
         return mock_password_provider.check_auth(*args)
         return mock_password_provider.check_auth(*args)
 
 
 
 
-def providers_config(*providers: Type[Any]) -> dict:
-    """Returns a config dict that will enable the given password auth providers"""
+class PasswordCustomAuthProvider:
+    """A module which registers password_auth_provider callbacks for a custom login type.
+    as well as a password login"""
+
+    @staticmethod
+    def parse_config(self):
+        pass
+
+    def __init__(self, config, api: ModuleApi):
+        api.register_password_auth_provider_callbacks(
+            auth_checkers={
+                ("test.login_type", ("test_field",)): self.check_auth,
+                ("m.login.password", ("password",)): self.check_auth,
+            },
+        )
+        pass
+
+    def check_auth(self, *args):
+        return mock_password_provider.check_auth(*args)
+
+    def check_pass(self, *args):
+        return mock_password_provider.check_password(*args)
+
+
+def legacy_providers_config(*providers: Type[Any]) -> dict:
+    """Returns a config dict that will enable the given legacy password auth providers"""
     return {
     return {
         "password_providers": [
         "password_providers": [
             {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
             {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
@@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict:
     }
     }
 
 
 
 
+def providers_config(*providers: Type[Any]) -> dict:
+    """Returns a config dict that will enable the given modules"""
+    return {
+        "modules": [
+            {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
+            for provider in providers
+        ]
+    }
+
+
 class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 class PasswordAuthProviderTests(unittest.HomeserverTestCase):
     servlets = [
     servlets = [
         synapse.rest.admin.register_servlets,
         synapse.rest.admin.register_servlets,
@@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
         mock_password_provider.reset_mock()
         super().setUp()
         super().setUp()
 
 
-    @override_config(providers_config(PasswordOnlyAuthProvider))
-    def test_password_only_auth_provider_login(self):
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+        # Load the modules into the homeserver
+        module_api = hs.get_module_api()
+        for module, config in hs.config.modules.loaded_modules:
+            module(config=config, api=module_api)
+        load_legacy_password_auth_providers(hs)
+
+        return hs
+
+    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+    def test_password_only_auth_progiver_login_legacy(self):
+        self.password_only_auth_provider_login_test_body()
+
+    def password_only_auth_provider_login_test_body(self):
         # login flows should only have m.login.password
         # login flows should only have m.login.password
         flows = self._get_login_flows()
         flows = self._get_login_flows()
         self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
         self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "@ USER🙂NAME :test", " pASS😢word "
             "@ USER🙂NAME :test", " pASS😢word "
         )
         )
 
 
-    @override_config(providers_config(PasswordOnlyAuthProvider))
-    def test_password_only_auth_provider_ui_auth(self):
+    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+    def test_password_only_auth_provider_ui_auth_legacy(self):
+        self.password_only_auth_provider_ui_auth_test_body()
+
+    def password_only_auth_provider_ui_auth_test_body(self):
         """UI Auth should delegate correctly to the password provider"""
         """UI Auth should delegate correctly to the password provider"""
 
 
         # create the user, otherwise access doesn't work
         # create the user, otherwise access doesn't work
@@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.code, 200)
         mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
         mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
 
 
-    @override_config(providers_config(PasswordOnlyAuthProvider))
-    def test_local_user_fallback_login(self):
+    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+    def test_local_user_fallback_login_legacy(self):
+        self.local_user_fallback_login_test_body()
+
+    def local_user_fallback_login_test_body(self):
         """rejected login should fall back to local db"""
         """rejected login should fall back to local db"""
         self.register_user("localuser", "localpass")
         self.register_user("localuser", "localpass")
 
 
@@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual("@localuser:test", channel.json_body["user_id"])
         self.assertEqual("@localuser:test", channel.json_body["user_id"])
 
 
-    @override_config(providers_config(PasswordOnlyAuthProvider))
-    def test_local_user_fallback_ui_auth(self):
+    @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+    def test_local_user_fallback_ui_auth_legacy(self):
+        self.local_user_fallback_ui_auth_test_body()
+
+    def local_user_fallback_ui_auth_test_body(self):
         """rejected login should fall back to local db"""
         """rejected login should fall back to local db"""
         self.register_user("localuser", "localpass")
         self.register_user("localuser", "localpass")
 
 
@@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
 
     @override_config(
     @override_config(
         {
         {
-            **providers_config(PasswordOnlyAuthProvider),
+            **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
             "password_config": {"localdb_enabled": False},
             "password_config": {"localdb_enabled": False},
         }
         }
     )
     )
-    def test_no_local_user_fallback_login(self):
+    def test_no_local_user_fallback_login_legacy(self):
+        self.no_local_user_fallback_login_test_body()
+
+    def no_local_user_fallback_login_test_body(self):
         """localdb_enabled can block login with the local password"""
         """localdb_enabled can block login with the local password"""
         self.register_user("localuser", "localpass")
         self.register_user("localuser", "localpass")
 
 
@@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
 
     @override_config(
     @override_config(
         {
         {
-            **providers_config(PasswordOnlyAuthProvider),
+            **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
             "password_config": {"localdb_enabled": False},
             "password_config": {"localdb_enabled": False},
         }
         }
     )
     )
-    def test_no_local_user_fallback_ui_auth(self):
+    def test_no_local_user_fallback_ui_auth_legacy(self):
+        self.no_local_user_fallback_ui_auth_test_body()
+
+    def no_local_user_fallback_ui_auth_test_body(self):
         """localdb_enabled can block ui auth with the local password"""
         """localdb_enabled can block ui auth with the local password"""
         self.register_user("localuser", "localpass")
         self.register_user("localuser", "localpass")
 
 
@@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
 
     @override_config(
     @override_config(
         {
         {
-            **providers_config(PasswordOnlyAuthProvider),
+            **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
             "password_config": {"enabled": False},
             "password_config": {"enabled": False},
         }
         }
     )
     )
-    def test_password_auth_disabled(self):
+    def test_password_auth_disabled_legacy(self):
+        self.password_auth_disabled_test_body()
+
+    def password_auth_disabled_test_body(self):
         """password auth doesn't work if it's disabled across the board"""
         """password auth doesn't work if it's disabled across the board"""
         # login flows should be empty
         # login flows should be empty
         flows = self._get_login_flows()
         flows = self._get_login_flows()
@@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_password.assert_not_called()
         mock_password_provider.check_password.assert_not_called()
 
 
+    @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+    def test_custom_auth_provider_login_legacy(self):
+        self.custom_auth_provider_login_test_body()
+
     @override_config(providers_config(CustomAuthProvider))
     @override_config(providers_config(CustomAuthProvider))
     def test_custom_auth_provider_login(self):
     def test_custom_auth_provider_login(self):
+        self.custom_auth_provider_login_test_body()
+
+    def custom_auth_provider_login_test_body(self):
         # login flows should have the custom flow and m.login.password, since we
         # login flows should have the custom flow and m.login.password, since we
         # haven't disabled local password lookup.
         # haven't disabled local password lookup.
         # (password must come first, because reasons)
         # (password must come first, because reasons)
@@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_auth.assert_not_called()
         mock_password_provider.check_auth.assert_not_called()
 
 
-        mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+        mock_password_provider.check_auth.return_value = defer.succeed(
+            ("@user:bz", None)
+        )
         channel = self._send_login("test.login_type", "u", test_field="y")
         channel = self._send_login("test.login_type", "u", test_field="y")
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual("@user:bz", channel.json_body["user_id"])
         self.assertEqual("@user:bz", channel.json_body["user_id"])
@@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         # in these cases, but at least we can guard against the API changing
         # in these cases, but at least we can guard against the API changing
         # unexpectedly
         # unexpectedly
         mock_password_provider.check_auth.return_value = defer.succeed(
         mock_password_provider.check_auth.return_value = defer.succeed(
-            "@ MALFORMED! :bz"
+            ("@ MALFORMED! :bz", None)
         )
         )
         channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
         channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual(channel.code, 200, channel.result)
@@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
             " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
         )
         )
 
 
+    @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+    def test_custom_auth_provider_ui_auth_legacy(self):
+        self.custom_auth_provider_ui_auth_test_body()
+
     @override_config(providers_config(CustomAuthProvider))
     @override_config(providers_config(CustomAuthProvider))
     def test_custom_auth_provider_ui_auth(self):
     def test_custom_auth_provider_ui_auth(self):
+        self.custom_auth_provider_ui_auth_test_body()
+
+    def custom_auth_provider_ui_auth_test_body(self):
         # register the user and log in twice, to get two devices
         # register the user and log in twice, to get two devices
         self.register_user("localuser", "localpass")
         self.register_user("localuser", "localpass")
         tok1 = self.login("localuser", "localpass")
         tok1 = self.login("localuser", "localpass")
@@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.reset_mock()
         mock_password_provider.reset_mock()
 
 
         # right params, but authing as the wrong user
         # right params, but authing as the wrong user
-        mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+        mock_password_provider.check_auth.return_value = defer.succeed(
+            ("@user:bz", None)
+        )
         body["auth"]["test_field"] = "foo"
         body["auth"]["test_field"] = "foo"
         channel = self._delete_device(tok1, "dev2", body)
         channel = self._delete_device(tok1, "dev2", body)
         self.assertEqual(channel.code, 403)
         self.assertEqual(channel.code, 403)
@@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
 
         # and finally, succeed
         # and finally, succeed
         mock_password_provider.check_auth.return_value = defer.succeed(
         mock_password_provider.check_auth.return_value = defer.succeed(
-            "@localuser:test"
+            ("@localuser:test", None)
         )
         )
         channel = self._delete_device(tok1, "dev2", body)
         channel = self._delete_device(tok1, "dev2", body)
         self.assertEqual(channel.code, 200)
         self.assertEqual(channel.code, 200)
@@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "localuser", "test.login_type", {"test_field": "foo"}
             "localuser", "test.login_type", {"test_field": "foo"}
         )
         )
 
 
+    @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+    def test_custom_auth_provider_callback_legacy(self):
+        self.custom_auth_provider_callback_test_body()
+
     @override_config(providers_config(CustomAuthProvider))
     @override_config(providers_config(CustomAuthProvider))
     def test_custom_auth_provider_callback(self):
     def test_custom_auth_provider_callback(self):
+        self.custom_auth_provider_callback_test_body()
+
+    def custom_auth_provider_callback_test_body(self):
         callback = Mock(return_value=defer.succeed(None))
         callback = Mock(return_value=defer.succeed(None))
 
 
         mock_password_provider.check_auth.return_value = defer.succeed(
         mock_password_provider.check_auth.return_value = defer.succeed(
@@ -410,10 +518,22 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         for p in ["user_id", "access_token", "device_id", "home_server"]:
         for p in ["user_id", "access_token", "device_id", "home_server"]:
             self.assertIn(p, call_args[0])
             self.assertIn(p, call_args[0])
 
 
+    @override_config(
+        {
+            **legacy_providers_config(LegacyCustomAuthProvider),
+            "password_config": {"enabled": False},
+        }
+    )
+    def test_custom_auth_password_disabled_legacy(self):
+        self.custom_auth_password_disabled_test_body()
+
     @override_config(
     @override_config(
         {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
         {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
     )
     )
     def test_custom_auth_password_disabled(self):
     def test_custom_auth_password_disabled(self):
+        self.custom_auth_password_disabled_test_body()
+
+    def custom_auth_password_disabled_test_body(self):
         """Test login with a custom auth provider where password login is disabled"""
         """Test login with a custom auth provider where password login is disabled"""
         self.register_user("localuser", "localpass")
         self.register_user("localuser", "localpass")
 
 
@@ -425,6 +545,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_auth.assert_not_called()
         mock_password_provider.check_auth.assert_not_called()
 
 
+    @override_config(
+        {
+            **legacy_providers_config(LegacyCustomAuthProvider),
+            "password_config": {"enabled": False, "localdb_enabled": False},
+        }
+    )
+    def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
+        self.custom_auth_password_disabled_localdb_enabled_test_body()
+
     @override_config(
     @override_config(
         {
         {
             **providers_config(CustomAuthProvider),
             **providers_config(CustomAuthProvider),
@@ -432,6 +561,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
         }
     )
     )
     def test_custom_auth_password_disabled_localdb_enabled(self):
     def test_custom_auth_password_disabled_localdb_enabled(self):
+        self.custom_auth_password_disabled_localdb_enabled_test_body()
+
+    def custom_auth_password_disabled_localdb_enabled_test_body(self):
         """Check the localdb_enabled == enabled == False
         """Check the localdb_enabled == enabled == False
 
 
         Regression test for https://github.com/matrix-org/synapse/issues/8914: check
         Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@@ -448,6 +580,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_auth.assert_not_called()
         mock_password_provider.check_auth.assert_not_called()
 
 
+    @override_config(
+        {
+            **legacy_providers_config(LegacyPasswordCustomAuthProvider),
+            "password_config": {"enabled": False},
+        }
+    )
+    def test_password_custom_auth_password_disabled_login_legacy(self):
+        self.password_custom_auth_password_disabled_login_test_body()
+
     @override_config(
     @override_config(
         {
         {
             **providers_config(PasswordCustomAuthProvider),
             **providers_config(PasswordCustomAuthProvider),
@@ -455,6 +596,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
         }
     )
     )
     def test_password_custom_auth_password_disabled_login(self):
     def test_password_custom_auth_password_disabled_login(self):
+        self.password_custom_auth_password_disabled_login_test_body()
+
+    def password_custom_auth_password_disabled_login_test_body(self):
         """log in with a custom auth provider which implements password, but password
         """log in with a custom auth provider which implements password, but password
         login is disabled"""
         login is disabled"""
         self.register_user("localuser", "localpass")
         self.register_user("localuser", "localpass")
@@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         channel = self._send_password_login("localuser", "localpass")
         channel = self._send_password_login("localuser", "localpass")
         self.assertEqual(channel.code, 400, channel.result)
         self.assertEqual(channel.code, 400, channel.result)
         mock_password_provider.check_auth.assert_not_called()
         mock_password_provider.check_auth.assert_not_called()
+        mock_password_provider.check_password.assert_not_called()
+
+    @override_config(
+        {
+            **legacy_providers_config(LegacyPasswordCustomAuthProvider),
+            "password_config": {"enabled": False},
+        }
+    )
+    def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
+        self.password_custom_auth_password_disabled_ui_auth_test_body()
 
 
     @override_config(
     @override_config(
         {
         {
@@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
         }
     )
     )
     def test_password_custom_auth_password_disabled_ui_auth(self):
     def test_password_custom_auth_password_disabled_ui_auth(self):
+        self.password_custom_auth_password_disabled_ui_auth_test_body()
+
+    def password_custom_auth_password_disabled_ui_auth_test_body(self):
         """UI Auth with a custom auth provider which implements password, but password
         """UI Auth with a custom auth provider which implements password, but password
         login is disabled"""
         login is disabled"""
         # register the user and log in twice via the test login type to get two devices,
         # register the user and log in twice via the test login type to get two devices,
         self.register_user("localuser", "localpass")
         self.register_user("localuser", "localpass")
         mock_password_provider.check_auth.return_value = defer.succeed(
         mock_password_provider.check_auth.return_value = defer.succeed(
-            "@localuser:test"
+            ("@localuser:test", None)
         )
         )
         channel = self._send_login("test.login_type", "localuser", test_field="")
         channel = self._send_login("test.login_type", "localuser", test_field="")
         self.assertEqual(channel.code, 200, channel.result)
         self.assertEqual(channel.code, 200, channel.result)
@@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
             "Password login has been disabled.", channel.json_body["error"]
             "Password login has been disabled.", channel.json_body["error"]
         )
         )
         mock_password_provider.check_auth.assert_not_called()
         mock_password_provider.check_auth.assert_not_called()
+        mock_password_provider.check_password.assert_not_called()
         mock_password_provider.reset_mock()
         mock_password_provider.reset_mock()
 
 
         # successful auth
         # successful auth
@@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         mock_password_provider.check_auth.assert_called_once_with(
         mock_password_provider.check_auth.assert_called_once_with(
             "localuser", "test.login_type", {"test_field": "x"}
             "localuser", "test.login_type", {"test_field": "x"}
         )
         )
+        mock_password_provider.check_password.assert_not_called()
+
+    @override_config(
+        {
+            **legacy_providers_config(LegacyCustomAuthProvider),
+            "password_config": {"localdb_enabled": False},
+        }
+    )
+    def test_custom_auth_no_local_user_fallback_legacy(self):
+        self.custom_auth_no_local_user_fallback_test_body()
 
 
     @override_config(
     @override_config(
         {
         {
@@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         }
         }
     )
     )
     def test_custom_auth_no_local_user_fallback(self):
     def test_custom_auth_no_local_user_fallback(self):
+        self.custom_auth_no_local_user_fallback_test_body()
+
+    def custom_auth_no_local_user_fallback_test_body(self):
         """Test login with a custom auth provider where the local db is disabled"""
         """Test login with a custom auth provider where the local db is disabled"""
         self.register_user("localuser", "localpass")
         self.register_user("localuser", "localpass")