|
@@ -200,46 +200,13 @@ class AuthHandler:
|
|
|
|
|
|
self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
|
|
|
|
|
|
- # we can't use hs.get_module_api() here, because to do so will create an
|
|
|
- # import loop.
|
|
|
- #
|
|
|
- # TODO: refactor this class to separate the lower-level stuff that
|
|
|
- # ModuleApi can use from the higher-level stuff that uses ModuleApi, as
|
|
|
- # better way to break the loop
|
|
|
- account_handler = ModuleApi(hs, self)
|
|
|
-
|
|
|
- self.password_providers = [
|
|
|
- PasswordProvider.load(module, config, account_handler)
|
|
|
- for module, config in hs.config.authproviders.password_providers
|
|
|
- ]
|
|
|
-
|
|
|
- logger.info("Extra password_providers: %s", self.password_providers)
|
|
|
+ self.password_auth_provider = hs.get_password_auth_provider()
|
|
|
|
|
|
self.hs = hs # FIXME better possibility to access registrationHandler later?
|
|
|
self.macaroon_gen = hs.get_macaroon_generator()
|
|
|
self._password_enabled = hs.config.auth.password_enabled
|
|
|
self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
|
|
|
|
|
|
- # start out by assuming PASSWORD is enabled; we will remove it later if not.
|
|
|
- login_types = set()
|
|
|
- if self._password_localdb_enabled:
|
|
|
- login_types.add(LoginType.PASSWORD)
|
|
|
-
|
|
|
- for provider in self.password_providers:
|
|
|
- login_types.update(provider.get_supported_login_types().keys())
|
|
|
-
|
|
|
- if not self._password_enabled:
|
|
|
- login_types.discard(LoginType.PASSWORD)
|
|
|
-
|
|
|
- # Some clients just pick the first type in the list. In this case, we want
|
|
|
- # them to use PASSWORD (rather than token or whatever), so we want to make sure
|
|
|
- # that comes first, where it's present.
|
|
|
- self._supported_login_types = []
|
|
|
- if LoginType.PASSWORD in login_types:
|
|
|
- self._supported_login_types.append(LoginType.PASSWORD)
|
|
|
- login_types.remove(LoginType.PASSWORD)
|
|
|
- self._supported_login_types.extend(login_types)
|
|
|
-
|
|
|
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
|
|
# as per `rc_login.failed_attempts`.
|
|
|
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
|
@@ -427,11 +394,10 @@ class AuthHandler:
|
|
|
ui_auth_types.add(LoginType.PASSWORD)
|
|
|
|
|
|
# also allow auth from password providers
|
|
|
- for provider in self.password_providers:
|
|
|
- for t in provider.get_supported_login_types().keys():
|
|
|
- if t == LoginType.PASSWORD and not self._password_enabled:
|
|
|
- continue
|
|
|
- ui_auth_types.add(t)
|
|
|
+ for t in self.password_auth_provider.get_supported_login_types().keys():
|
|
|
+ if t == LoginType.PASSWORD and not self._password_enabled:
|
|
|
+ continue
|
|
|
+ ui_auth_types.add(t)
|
|
|
|
|
|
# if sso is enabled, allow the user to log in via SSO iff they have a mapping
|
|
|
# from sso to mxid.
|
|
@@ -1038,7 +1004,25 @@ class AuthHandler:
|
|
|
Returns:
|
|
|
login types
|
|
|
"""
|
|
|
- return self._supported_login_types
|
|
|
+ # Load any login types registered by modules
|
|
|
+ # This is stored in the password_auth_provider so this doesn't trigger
|
|
|
+ # any callbacks
|
|
|
+ types = list(self.password_auth_provider.get_supported_login_types().keys())
|
|
|
+
|
|
|
+ # This list should include PASSWORD if (either _password_localdb_enabled is
|
|
|
+ # true or if one of the modules registered it) AND _password_enabled is true
|
|
|
+ # Also:
|
|
|
+ # Some clients just pick the first type in the list. In this case, we want
|
|
|
+ # them to use PASSWORD (rather than token or whatever), so we want to make sure
|
|
|
+ # that comes first, where it's present.
|
|
|
+ if LoginType.PASSWORD in types:
|
|
|
+ types.remove(LoginType.PASSWORD)
|
|
|
+ if self._password_enabled:
|
|
|
+ types.insert(0, LoginType.PASSWORD)
|
|
|
+ elif self._password_localdb_enabled and self._password_enabled:
|
|
|
+ types.insert(0, LoginType.PASSWORD)
|
|
|
+
|
|
|
+ return types
|
|
|
|
|
|
async def validate_login(
|
|
|
self,
|
|
@@ -1217,15 +1201,20 @@ class AuthHandler:
|
|
|
|
|
|
known_login_type = False
|
|
|
|
|
|
- for provider in self.password_providers:
|
|
|
- supported_login_types = provider.get_supported_login_types()
|
|
|
- if login_type not in supported_login_types:
|
|
|
- # this password provider doesn't understand this login type
|
|
|
- continue
|
|
|
-
|
|
|
+ # Check if login_type matches a type registered by one of the modules
|
|
|
+ # We don't need to remove LoginType.PASSWORD from the list if password login is
|
|
|
+ # disabled, since if that were the case then by this point we know that the
|
|
|
+ # login_type is not LoginType.PASSWORD
|
|
|
+ supported_login_types = self.password_auth_provider.get_supported_login_types()
|
|
|
+ # check if the login type being used is supported by a module
|
|
|
+ if login_type in supported_login_types:
|
|
|
+ # Make a note that this login type is supported by the server
|
|
|
known_login_type = True
|
|
|
+ # Get all the fields expected for this login types
|
|
|
login_fields = supported_login_types[login_type]
|
|
|
|
|
|
+ # go through the login submission and keep track of which required fields are
|
|
|
+ # provided/not provided
|
|
|
missing_fields = []
|
|
|
login_dict = {}
|
|
|
for f in login_fields:
|
|
@@ -1233,6 +1222,7 @@ class AuthHandler:
|
|
|
missing_fields.append(f)
|
|
|
else:
|
|
|
login_dict[f] = login_submission[f]
|
|
|
+ # raise an error if any of the expected fields for that login type weren't provided
|
|
|
if missing_fields:
|
|
|
raise SynapseError(
|
|
|
400,
|
|
@@ -1240,10 +1230,15 @@ class AuthHandler:
|
|
|
% (login_type, missing_fields),
|
|
|
)
|
|
|
|
|
|
- result = await provider.check_auth(username, login_type, login_dict)
|
|
|
+ # call all of the check_auth hooks for that login_type
|
|
|
+ # it will return a result once the first success is found (or None otherwise)
|
|
|
+ result = await self.password_auth_provider.check_auth(
|
|
|
+ username, login_type, login_dict
|
|
|
+ )
|
|
|
if result:
|
|
|
return result
|
|
|
|
|
|
+ # if no module managed to authenticate the user, then fallback to built in password based auth
|
|
|
if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
|
|
|
known_login_type = True
|
|
|
|
|
@@ -1282,11 +1277,16 @@ class AuthHandler:
|
|
|
completed login/registration, or `None`. If authentication was
|
|
|
unsuccessful, `user_id` and `callback` are both `None`.
|
|
|
"""
|
|
|
- for provider in self.password_providers:
|
|
|
- result = await provider.check_3pid_auth(medium, address, password)
|
|
|
- if result:
|
|
|
- return result
|
|
|
+ # call all of the check_3pid_auth callbacks
|
|
|
+ # Result will be from the first callback that returns something other than None
|
|
|
+ # If all the callbacks return None, then result is also set to None
|
|
|
+ result = await self.password_auth_provider.check_3pid_auth(
|
|
|
+ medium, address, password
|
|
|
+ )
|
|
|
+ if result:
|
|
|
+ return result
|
|
|
|
|
|
+ # if result is None then return (None, None)
|
|
|
return None, None
|
|
|
|
|
|
async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
|
|
@@ -1365,13 +1365,12 @@ class AuthHandler:
|
|
|
user_info = await self.auth.get_user_by_access_token(access_token)
|
|
|
await self.store.delete_access_token(access_token)
|
|
|
|
|
|
- # see if any of our auth providers want to know about this
|
|
|
- for provider in self.password_providers:
|
|
|
- await provider.on_logged_out(
|
|
|
- user_id=user_info.user_id,
|
|
|
- device_id=user_info.device_id,
|
|
|
- access_token=access_token,
|
|
|
- )
|
|
|
+ # see if any modules want to know about this
|
|
|
+ await self.password_auth_provider.on_logged_out(
|
|
|
+ user_id=user_info.user_id,
|
|
|
+ device_id=user_info.device_id,
|
|
|
+ access_token=access_token,
|
|
|
+ )
|
|
|
|
|
|
# delete pushers associated with this access token
|
|
|
if user_info.token_id is not None:
|
|
@@ -1398,12 +1397,11 @@ class AuthHandler:
|
|
|
user_id, except_token_id=except_token_id, device_id=device_id
|
|
|
)
|
|
|
|
|
|
- # see if any of our auth providers want to know about this
|
|
|
- for provider in self.password_providers:
|
|
|
- for token, _, device_id in tokens_and_devices:
|
|
|
- await provider.on_logged_out(
|
|
|
- user_id=user_id, device_id=device_id, access_token=token
|
|
|
- )
|
|
|
+ # see if any modules want to know about this
|
|
|
+ for token, _, device_id in tokens_and_devices:
|
|
|
+ await self.password_auth_provider.on_logged_out(
|
|
|
+ user_id=user_id, device_id=device_id, access_token=token
|
|
|
+ )
|
|
|
|
|
|
# delete pushers associated with the access tokens
|
|
|
await self.hs.get_pusherpool().remove_pushers_by_access_token(
|
|
@@ -1811,40 +1809,228 @@ class MacaroonGenerator:
|
|
|
return macaroon
|
|
|
|
|
|
|
|
|
-class PasswordProvider:
|
|
|
- """Wrapper for a password auth provider module
|
|
|
+def load_legacy_password_auth_providers(hs: "HomeServer") -> None:
|
|
|
+ module_api = hs.get_module_api()
|
|
|
+ for module, config in hs.config.authproviders.password_providers:
|
|
|
+ load_single_legacy_password_auth_provider(
|
|
|
+ module=module, config=config, api=module_api
|
|
|
+ )
|
|
|
|
|
|
- This class abstracts out all of the backwards-compatibility hacks for
|
|
|
- password providers, to provide a consistent interface.
|
|
|
- """
|
|
|
|
|
|
- @classmethod
|
|
|
- def load(
|
|
|
- cls, module: Type, config: JsonDict, module_api: ModuleApi
|
|
|
- ) -> "PasswordProvider":
|
|
|
- try:
|
|
|
- pp = module(config=config, account_handler=module_api)
|
|
|
- except Exception as e:
|
|
|
- logger.error("Error while initializing %r: %s", module, e)
|
|
|
- raise
|
|
|
- return cls(pp, module_api)
|
|
|
+def load_single_legacy_password_auth_provider(
|
|
|
+ module: Type, config: JsonDict, api: ModuleApi
|
|
|
+) -> None:
|
|
|
+ try:
|
|
|
+ provider = module(config=config, account_handler=api)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error("Error while initializing %r: %s", module, e)
|
|
|
+ raise
|
|
|
+
|
|
|
+ # The known hooks. If a module implements a method who's name appears in this set
|
|
|
+ # we'll want to register it
|
|
|
+ password_auth_provider_methods = {
|
|
|
+ "check_3pid_auth",
|
|
|
+ "on_logged_out",
|
|
|
+ }
|
|
|
+
|
|
|
+ # All methods that the module provides should be async, but this wasn't enforced
|
|
|
+ # in the old module system, so we wrap them if needed
|
|
|
+ def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
|
|
|
+ # f might be None if the callback isn't implemented by the module. In this
|
|
|
+ # case we don't want to register a callback at all so we return None.
|
|
|
+ if f is None:
|
|
|
+ return None
|
|
|
+
|
|
|
+ # We need to wrap check_password because its old form would return a boolean
|
|
|
+ # but we now want it to behave just like check_auth() and return the matrix id of
|
|
|
+ # the user if authentication succeeded or None otherwise
|
|
|
+ if f.__name__ == "check_password":
|
|
|
+
|
|
|
+ async def wrapped_check_password(
|
|
|
+ username: str, login_type: str, login_dict: JsonDict
|
|
|
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
|
|
|
+ # We've already made sure f is not None above, but mypy doesn't do well
|
|
|
+ # across function boundaries so we need to tell it f is definitely not
|
|
|
+ # None.
|
|
|
+ assert f is not None
|
|
|
+
|
|
|
+ matrix_user_id = api.get_qualified_user_id(username)
|
|
|
+ password = login_dict["password"]
|
|
|
+
|
|
|
+ is_valid = await f(matrix_user_id, password)
|
|
|
+
|
|
|
+ if is_valid:
|
|
|
+ return matrix_user_id, None
|
|
|
+
|
|
|
+ return None
|
|
|
|
|
|
- def __init__(self, pp: "PasswordProvider", module_api: ModuleApi):
|
|
|
- self._pp = pp
|
|
|
- self._module_api = module_api
|
|
|
+ return wrapped_check_password
|
|
|
+
|
|
|
+ # We need to wrap check_auth as in the old form it could return
|
|
|
+ # just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
|
|
|
+ if f.__name__ == "check_auth":
|
|
|
+
|
|
|
+ async def wrapped_check_auth(
|
|
|
+ username: str, login_type: str, login_dict: JsonDict
|
|
|
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
|
|
|
+ # We've already made sure f is not None above, but mypy doesn't do well
|
|
|
+ # across function boundaries so we need to tell it f is definitely not
|
|
|
+ # None.
|
|
|
+ assert f is not None
|
|
|
+
|
|
|
+ result = await f(username, login_type, login_dict)
|
|
|
+
|
|
|
+ if isinstance(result, str):
|
|
|
+ return result, None
|
|
|
+
|
|
|
+ return result
|
|
|
+
|
|
|
+ return wrapped_check_auth
|
|
|
+
|
|
|
+ # We need to wrap check_3pid_auth as in the old form it could return
|
|
|
+ # just a str, but now it must return Optional[Tuple[str, Optional[Callable]]
|
|
|
+ if f.__name__ == "check_3pid_auth":
|
|
|
+
|
|
|
+ async def wrapped_check_3pid_auth(
|
|
|
+ medium: str, address: str, password: str
|
|
|
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
|
|
|
+ # We've already made sure f is not None above, but mypy doesn't do well
|
|
|
+ # across function boundaries so we need to tell it f is definitely not
|
|
|
+ # None.
|
|
|
+ assert f is not None
|
|
|
+
|
|
|
+ result = await f(medium, address, password)
|
|
|
+
|
|
|
+ if isinstance(result, str):
|
|
|
+ return result, None
|
|
|
+
|
|
|
+ return result
|
|
|
|
|
|
- self._supported_login_types = {}
|
|
|
+ return wrapped_check_3pid_auth
|
|
|
|
|
|
- # grandfather in check_password support
|
|
|
- if hasattr(self._pp, "check_password"):
|
|
|
- self._supported_login_types[LoginType.PASSWORD] = ("password",)
|
|
|
+ def run(*args: Tuple, **kwargs: Dict) -> Awaitable:
|
|
|
+ # mypy doesn't do well across function boundaries so we need to tell it
|
|
|
+ # f is definitely not None.
|
|
|
+ assert f is not None
|
|
|
|
|
|
- g = getattr(self._pp, "get_supported_login_types", None)
|
|
|
- if g:
|
|
|
- self._supported_login_types.update(g())
|
|
|
+ return maybe_awaitable(f(*args, **kwargs))
|
|
|
|
|
|
- def __str__(self) -> str:
|
|
|
- return str(self._pp)
|
|
|
+ return run
|
|
|
+
|
|
|
+ # populate hooks with the implemented methods, wrapped with async_wrapper
|
|
|
+ hooks = {
|
|
|
+ hook: async_wrapper(getattr(provider, hook, None))
|
|
|
+ for hook in password_auth_provider_methods
|
|
|
+ }
|
|
|
+
|
|
|
+ supported_login_types = {}
|
|
|
+ # call get_supported_login_types and add that to the dict
|
|
|
+ g = getattr(provider, "get_supported_login_types", None)
|
|
|
+ if g is not None:
|
|
|
+ # Note the old module style also called get_supported_login_types at loading time
|
|
|
+ # and it is synchronous
|
|
|
+ supported_login_types.update(g())
|
|
|
+
|
|
|
+ auth_checkers = {}
|
|
|
+ # Legacy modules have a check_auth method which expects to be called with one of
|
|
|
+ # the keys returned by get_supported_login_types. New style modules register a
|
|
|
+ # dictionary of login_type->check_auth_method mappings
|
|
|
+ check_auth = async_wrapper(getattr(provider, "check_auth", None))
|
|
|
+ if check_auth is not None:
|
|
|
+ for login_type, fields in supported_login_types.items():
|
|
|
+ # need tuple(fields) since fields can be any Iterable type (so may not be hashable)
|
|
|
+ auth_checkers[(login_type, tuple(fields))] = check_auth
|
|
|
+
|
|
|
+ # if it has a "check_password" method then it should handle all auth checks
|
|
|
+ # with login type of LoginType.PASSWORD
|
|
|
+ check_password = async_wrapper(getattr(provider, "check_password", None))
|
|
|
+ if check_password is not None:
|
|
|
+ # need to use a tuple here for ("password",) not a list since lists aren't hashable
|
|
|
+ auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
|
|
|
+
|
|
|
+ api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
|
|
|
+
|
|
|
+
|
|
|
+CHECK_3PID_AUTH_CALLBACK = Callable[
|
|
|
+ [str, str, str],
|
|
|
+ Awaitable[
|
|
|
+ Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
|
|
|
+ ],
|
|
|
+]
|
|
|
+ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
|
|
|
+CHECK_AUTH_CALLBACK = Callable[
|
|
|
+ [str, str, JsonDict],
|
|
|
+ Awaitable[
|
|
|
+ Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
|
|
|
+ ],
|
|
|
+]
|
|
|
+
|
|
|
+
|
|
|
+class PasswordAuthProvider:
|
|
|
+ """
|
|
|
+ A class that the AuthHandler calls when authenticating users
|
|
|
+ It allows modules to provide alternative methods for authentication
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self) -> None:
|
|
|
+ # lists of callbacks
|
|
|
+ self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
|
|
|
+ self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
|
|
|
+
|
|
|
+ # Mapping from login type to login parameters
|
|
|
+ self._supported_login_types: Dict[str, Iterable[str]] = {}
|
|
|
+
|
|
|
+ # Mapping from login type to auth checker callbacks
|
|
|
+ self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
|
|
|
+
|
|
|
+ def register_password_auth_provider_callbacks(
|
|
|
+ self,
|
|
|
+ check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
|
|
|
+ on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
|
|
|
+ auth_checkers: Optional[Dict[Tuple[str, Tuple], CHECK_AUTH_CALLBACK]] = None,
|
|
|
+ ) -> None:
|
|
|
+ # Register check_3pid_auth callback
|
|
|
+ if check_3pid_auth is not None:
|
|
|
+ self.check_3pid_auth_callbacks.append(check_3pid_auth)
|
|
|
+
|
|
|
+ # register on_logged_out callback
|
|
|
+ if on_logged_out is not None:
|
|
|
+ self.on_logged_out_callbacks.append(on_logged_out)
|
|
|
+
|
|
|
+ if auth_checkers is not None:
|
|
|
+ # register a new supported login_type
|
|
|
+ # Iterate through all of the types being registered
|
|
|
+ for (login_type, fields), callback in auth_checkers.items():
|
|
|
+ # Note: fields may be empty here. This would allow a modules auth checker to
|
|
|
+ # be called with just 'login_type' and no password or other secrets
|
|
|
+
|
|
|
+ # Need to check that all the field names are strings or may get nasty errors later
|
|
|
+ for f in fields:
|
|
|
+ if not isinstance(f, str):
|
|
|
+ raise RuntimeError(
|
|
|
+ "A module tried to register support for login type: %s with parameters %s"
|
|
|
+ " but all parameter names must be strings"
|
|
|
+ % (login_type, fields)
|
|
|
+ )
|
|
|
+
|
|
|
+ # 2 modules supporting the same login type must expect the same fields
|
|
|
+ # e.g. 1 can't expect "pass" if the other expects "password"
|
|
|
+ # so throw an exception if that happens
|
|
|
+ if login_type not in self._supported_login_types.get(login_type, []):
|
|
|
+ self._supported_login_types[login_type] = fields
|
|
|
+ else:
|
|
|
+ fields_currently_supported = self._supported_login_types.get(
|
|
|
+ login_type
|
|
|
+ )
|
|
|
+ if fields_currently_supported != fields:
|
|
|
+ raise RuntimeError(
|
|
|
+ "A module tried to register support for login type: %s with parameters %s"
|
|
|
+ " but another module had already registered support for that type with parameters %s"
|
|
|
+ % (login_type, fields, fields_currently_supported)
|
|
|
+ )
|
|
|
+
|
|
|
+ # Add the new method to the list of auth_checker_callbacks for this login type
|
|
|
+ self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
|
|
|
|
|
|
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
|
|
"""Get the login types supported by this password provider
|
|
@@ -1852,20 +2038,15 @@ class PasswordProvider:
|
|
|
Returns a map from a login type identifier (such as m.login.password) to an
|
|
|
iterable giving the fields which must be provided by the user in the submission
|
|
|
to the /login API.
|
|
|
-
|
|
|
- This wrapper adds m.login.password to the list if the underlying password
|
|
|
- provider supports the check_password() api.
|
|
|
"""
|
|
|
+
|
|
|
return self._supported_login_types
|
|
|
|
|
|
async def check_auth(
|
|
|
self, username: str, login_type: str, login_dict: JsonDict
|
|
|
- ) -> Optional[Tuple[str, Optional[Callable]]]:
|
|
|
+ ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
|
|
|
"""Check if the user has presented valid login credentials
|
|
|
|
|
|
- This wrapper also calls check_password() if the underlying password provider
|
|
|
- supports the check_password() api and the login type is m.login.password.
|
|
|
-
|
|
|
Args:
|
|
|
username: user id presented by the client. Either an MXID or an unqualified
|
|
|
username.
|
|
@@ -1879,63 +2060,130 @@ class PasswordProvider:
|
|
|
user, and `callback` is an optional callback which will be called with the
|
|
|
result from the /login call (including access_token, device_id, etc.)
|
|
|
"""
|
|
|
- # first grandfather in a call to check_password
|
|
|
- if login_type == LoginType.PASSWORD:
|
|
|
- check_password = getattr(self._pp, "check_password", None)
|
|
|
- if check_password:
|
|
|
- qualified_user_id = self._module_api.get_qualified_user_id(username)
|
|
|
- is_valid = await check_password(
|
|
|
- qualified_user_id, login_dict["password"]
|
|
|
- )
|
|
|
- if is_valid:
|
|
|
- return qualified_user_id, None
|
|
|
|
|
|
- check_auth = getattr(self._pp, "check_auth", None)
|
|
|
- if not check_auth:
|
|
|
- return None
|
|
|
- result = await check_auth(username, login_type, login_dict)
|
|
|
+ # Go through all callbacks for the login type until one returns with a value
|
|
|
+ # other than None (i.e. until a callback returns a success)
|
|
|
+ for callback in self.auth_checker_callbacks[login_type]:
|
|
|
+ try:
|
|
|
+ result = await callback(username, login_type, login_dict)
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning("Failed to run module API callback %s: %s", callback, e)
|
|
|
+ continue
|
|
|
|
|
|
- # Check if the return value is a str or a tuple
|
|
|
- if isinstance(result, str):
|
|
|
- # If it's a str, set callback function to None
|
|
|
- return result, None
|
|
|
+ if result is not None:
|
|
|
+ # Check that the callback returned a Tuple[str, Optional[Callable]]
|
|
|
+ # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
|
|
|
+ # result is always the right type, but as it is 3rd party code it might not be
|
|
|
+
|
|
|
+ if not isinstance(result, tuple) or len(result) != 2:
|
|
|
+ logger.warning(
|
|
|
+ "Wrong type returned by module API callback %s: %s, expected"
|
|
|
+ " Optional[Tuple[str, Optional[Callable]]]",
|
|
|
+ callback,
|
|
|
+ result,
|
|
|
+ )
|
|
|
+ continue
|
|
|
|
|
|
- return result
|
|
|
+ # pull out the two parts of the tuple so we can do type checking
|
|
|
+ str_result, callback_result = result
|
|
|
+
|
|
|
+ # the 1st item in the tuple should be a str
|
|
|
+ if not isinstance(str_result, str):
|
|
|
+ logger.warning( # type: ignore[unreachable]
|
|
|
+ "Wrong type returned by module API callback %s: %s, expected"
|
|
|
+ " Optional[Tuple[str, Optional[Callable]]]",
|
|
|
+ callback,
|
|
|
+ result,
|
|
|
+ )
|
|
|
+ continue
|
|
|
+
|
|
|
+ # the second should be Optional[Callable]
|
|
|
+ if callback_result is not None:
|
|
|
+ if not callable(callback_result):
|
|
|
+ logger.warning( # type: ignore[unreachable]
|
|
|
+ "Wrong type returned by module API callback %s: %s, expected"
|
|
|
+ " Optional[Tuple[str, Optional[Callable]]]",
|
|
|
+ callback,
|
|
|
+ result,
|
|
|
+ )
|
|
|
+ continue
|
|
|
+
|
|
|
+ # The result is a (str, Optional[callback]) tuple so return the successful result
|
|
|
+ return result
|
|
|
+
|
|
|
+ # If this point has been reached then none of the callbacks successfully authenticated
|
|
|
+ # the user so return None
|
|
|
+ return None
|
|
|
|
|
|
async def check_3pid_auth(
|
|
|
self, medium: str, address: str, password: str
|
|
|
- ) -> Optional[Tuple[str, Optional[Callable]]]:
|
|
|
- g = getattr(self._pp, "check_3pid_auth", None)
|
|
|
- if not g:
|
|
|
- return None
|
|
|
-
|
|
|
+ ) -> Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]:
|
|
|
# This function is able to return a deferred that either
|
|
|
# resolves None, meaning authentication failure, or upon
|
|
|
# success, to a str (which is the user_id) or a tuple of
|
|
|
# (user_id, callback_func), where callback_func should be run
|
|
|
# after we've finished everything else
|
|
|
- result = await g(medium, address, password)
|
|
|
|
|
|
- # Check if the return value is a str or a tuple
|
|
|
- if isinstance(result, str):
|
|
|
- # If it's a str, set callback function to None
|
|
|
- return result, None
|
|
|
+ for callback in self.check_3pid_auth_callbacks:
|
|
|
+ try:
|
|
|
+ result = await callback(medium, address, password)
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning("Failed to run module API callback %s: %s", callback, e)
|
|
|
+ continue
|
|
|
|
|
|
- return result
|
|
|
+ if result is not None:
|
|
|
+ # Check that the callback returned a Tuple[str, Optional[Callable]]
|
|
|
+ # "type: ignore[unreachable]" is used after some isinstance checks because mypy thinks
|
|
|
+ # result is always the right type, but as it is 3rd party code it might not be
|
|
|
+
|
|
|
+ if not isinstance(result, tuple) or len(result) != 2:
|
|
|
+ logger.warning(
|
|
|
+ "Wrong type returned by module API callback %s: %s, expected"
|
|
|
+ " Optional[Tuple[str, Optional[Callable]]]",
|
|
|
+ callback,
|
|
|
+ result,
|
|
|
+ )
|
|
|
+ continue
|
|
|
+
|
|
|
+ # pull out the two parts of the tuple so we can do type checking
|
|
|
+ str_result, callback_result = result
|
|
|
+
|
|
|
+ # the 1st item in the tuple should be a str
|
|
|
+ if not isinstance(str_result, str):
|
|
|
+ logger.warning( # type: ignore[unreachable]
|
|
|
+ "Wrong type returned by module API callback %s: %s, expected"
|
|
|
+ " Optional[Tuple[str, Optional[Callable]]]",
|
|
|
+ callback,
|
|
|
+ result,
|
|
|
+ )
|
|
|
+ continue
|
|
|
+
|
|
|
+ # the second should be Optional[Callable]
|
|
|
+ if callback_result is not None:
|
|
|
+ if not callable(callback_result):
|
|
|
+ logger.warning( # type: ignore[unreachable]
|
|
|
+ "Wrong type returned by module API callback %s: %s, expected"
|
|
|
+ " Optional[Tuple[str, Optional[Callable]]]",
|
|
|
+ callback,
|
|
|
+ result,
|
|
|
+ )
|
|
|
+ continue
|
|
|
+
|
|
|
+ # The result is a (str, Optional[callback]) tuple so return the successful result
|
|
|
+ return result
|
|
|
+
|
|
|
+ # If this point has been reached then none of the callbacks successfully authenticated
|
|
|
+ # the user so return None
|
|
|
+ return None
|
|
|
|
|
|
async def on_logged_out(
|
|
|
self, user_id: str, device_id: Optional[str], access_token: str
|
|
|
) -> None:
|
|
|
- g = getattr(self._pp, "on_logged_out", None)
|
|
|
- if not g:
|
|
|
- return
|
|
|
|
|
|
- # This might return an awaitable, if it does block the log out
|
|
|
- # until it completes.
|
|
|
- await maybe_awaitable(
|
|
|
- g(
|
|
|
- user_id=user_id,
|
|
|
- device_id=device_id,
|
|
|
- access_token=access_token,
|
|
|
- )
|
|
|
- )
|
|
|
+ # call all of the on_logged_out callbacks
|
|
|
+ for callback in self.on_logged_out_callbacks:
|
|
|
+ try:
|
|
|
+ callback(user_id, device_id, access_token)
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning("Failed to run module API callback %s: %s", callback, e)
|
|
|
+ continue
|