|
@@ -65,6 +65,10 @@ from synapse.http.server import finish_request, respond_with_html
|
|
|
from synapse.http.site import SynapseRequest
|
|
|
from synapse.logging.context import defer_to_thread
|
|
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
|
|
+from synapse.module_api.callbacks.password_auth_provider_callbacks import (
|
|
|
+ CHECK_3PID_AUTH_CALLBACK,
|
|
|
+ ON_LOGGED_OUT_CALLBACK,
|
|
|
+)
|
|
|
from synapse.storage.databases.main.registration import (
|
|
|
LoginTokenExpired,
|
|
|
LoginTokenLookupResult,
|
|
@@ -1096,7 +1100,7 @@ class AuthHandler:
|
|
|
return self._password_enabled_for_login and self._password_localdb_enabled
|
|
|
|
|
|
def get_supported_login_types(self) -> Iterable[str]:
|
|
|
- """Get a the login types supported for the /login API
|
|
|
+ """Get the login types supported for the /login API
|
|
|
|
|
|
By default this is just 'm.login.password' (unless password_enabled is
|
|
|
False in the config file), but password auth providers can provide
|
|
@@ -1999,124 +2003,16 @@ def load_single_legacy_password_auth_provider(
|
|
|
)
|
|
|
|
|
|
|
|
|
-CHECK_3PID_AUTH_CALLBACK = Callable[
|
|
|
- [str, str, str],
|
|
|
- Awaitable[
|
|
|
- Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
|
|
|
- ],
|
|
|
-]
|
|
|
-ON_LOGGED_OUT_CALLBACK = Callable[[str, Optional[str], str], Awaitable]
|
|
|
-CHECK_AUTH_CALLBACK = Callable[
|
|
|
- [str, str, JsonDict],
|
|
|
- Awaitable[
|
|
|
- Optional[Tuple[str, Optional[Callable[["LoginResponse"], Awaitable[None]]]]]
|
|
|
- ],
|
|
|
-]
|
|
|
-GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
|
|
|
- [JsonDict, JsonDict],
|
|
|
- Awaitable[Optional[str]],
|
|
|
-]
|
|
|
-GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
|
|
|
- [JsonDict, JsonDict],
|
|
|
- Awaitable[Optional[str]],
|
|
|
-]
|
|
|
-IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
|
|
|
-
|
|
|
-
|
|
|
class PasswordAuthProvider:
|
|
|
"""
|
|
|
A class that the AuthHandler calls when authenticating users
|
|
|
It allows modules to provide alternative methods for authentication
|
|
|
"""
|
|
|
|
|
|
- def __init__(self) -> None:
|
|
|
- # lists of callbacks
|
|
|
- self.check_3pid_auth_callbacks: List[CHECK_3PID_AUTH_CALLBACK] = []
|
|
|
- self.on_logged_out_callbacks: List[ON_LOGGED_OUT_CALLBACK] = []
|
|
|
- self.get_username_for_registration_callbacks: List[
|
|
|
- GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
|
|
- ] = []
|
|
|
- self.get_displayname_for_registration_callbacks: List[
|
|
|
- GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
|
|
|
- ] = []
|
|
|
- self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
|
|
|
-
|
|
|
- # Mapping from login type to login parameters
|
|
|
- self._supported_login_types: Dict[str, Tuple[str, ...]] = {}
|
|
|
-
|
|
|
- # Mapping from login type to auth checker callbacks
|
|
|
- self.auth_checker_callbacks: Dict[str, List[CHECK_AUTH_CALLBACK]] = {}
|
|
|
-
|
|
|
- def register_password_auth_provider_callbacks(
|
|
|
- self,
|
|
|
- check_3pid_auth: Optional[CHECK_3PID_AUTH_CALLBACK] = None,
|
|
|
- on_logged_out: Optional[ON_LOGGED_OUT_CALLBACK] = None,
|
|
|
- is_3pid_allowed: Optional[IS_3PID_ALLOWED_CALLBACK] = None,
|
|
|
- auth_checkers: Optional[
|
|
|
- Dict[Tuple[str, Tuple[str, ...]], CHECK_AUTH_CALLBACK]
|
|
|
- ] = None,
|
|
|
- get_username_for_registration: Optional[
|
|
|
- GET_USERNAME_FOR_REGISTRATION_CALLBACK
|
|
|
- ] = None,
|
|
|
- get_displayname_for_registration: Optional[
|
|
|
- GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
|
|
|
- ] = None,
|
|
|
- ) -> None:
|
|
|
- # Register check_3pid_auth callback
|
|
|
- if check_3pid_auth is not None:
|
|
|
- self.check_3pid_auth_callbacks.append(check_3pid_auth)
|
|
|
-
|
|
|
- # register on_logged_out callback
|
|
|
- if on_logged_out is not None:
|
|
|
- self.on_logged_out_callbacks.append(on_logged_out)
|
|
|
-
|
|
|
- if auth_checkers is not None:
|
|
|
- # register a new supported login_type
|
|
|
- # Iterate through all of the types being registered
|
|
|
- for (login_type, fields), callback in auth_checkers.items():
|
|
|
- # Note: fields may be empty here. This would allow a modules auth checker to
|
|
|
- # be called with just 'login_type' and no password or other secrets
|
|
|
-
|
|
|
- # Need to check that all the field names are strings or may get nasty errors later
|
|
|
- for f in fields:
|
|
|
- if not isinstance(f, str):
|
|
|
- raise RuntimeError(
|
|
|
- "A module tried to register support for login type: %s with parameters %s"
|
|
|
- " but all parameter names must be strings"
|
|
|
- % (login_type, fields)
|
|
|
- )
|
|
|
-
|
|
|
- # 2 modules supporting the same login type must expect the same fields
|
|
|
- # e.g. 1 can't expect "pass" if the other expects "password"
|
|
|
- # so throw an exception if that happens
|
|
|
- if login_type not in self._supported_login_types.get(login_type, []):
|
|
|
- self._supported_login_types[login_type] = fields
|
|
|
- else:
|
|
|
- fields_currently_supported = self._supported_login_types.get(
|
|
|
- login_type
|
|
|
- )
|
|
|
- if fields_currently_supported != fields:
|
|
|
- raise RuntimeError(
|
|
|
- "A module tried to register support for login type: %s with parameters %s"
|
|
|
- " but another module had already registered support for that type with parameters %s"
|
|
|
- % (login_type, fields, fields_currently_supported)
|
|
|
- )
|
|
|
-
|
|
|
- # Add the new method to the list of auth_checker_callbacks for this login type
|
|
|
- self.auth_checker_callbacks.setdefault(login_type, []).append(callback)
|
|
|
-
|
|
|
- if get_username_for_registration is not None:
|
|
|
- self.get_username_for_registration_callbacks.append(
|
|
|
- get_username_for_registration,
|
|
|
- )
|
|
|
-
|
|
|
- if get_displayname_for_registration is not None:
|
|
|
- self.get_displayname_for_registration_callbacks.append(
|
|
|
- get_displayname_for_registration,
|
|
|
- )
|
|
|
-
|
|
|
- if is_3pid_allowed is not None:
|
|
|
- self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
|
|
|
+ def __init__(self, hs: "HomeServer") -> None:
|
|
|
+ self._module_api_callbacks = (
|
|
|
+ hs.get_module_api_callbacks().password_auth_provider
|
|
|
+ )
|
|
|
|
|
|
def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
|
|
|
"""Get the login types supported by this password provider
|
|
@@ -2126,7 +2022,7 @@ class PasswordAuthProvider:
|
|
|
to the /login API.
|
|
|
"""
|
|
|
|
|
|
- return self._supported_login_types
|
|
|
+ return self._module_api_callbacks.supported_login_types
|
|
|
|
|
|
async def check_auth(
|
|
|
self, username: str, login_type: str, login_dict: JsonDict
|
|
@@ -2149,7 +2045,7 @@ class PasswordAuthProvider:
|
|
|
|
|
|
# Go through all callbacks for the login type until one returns with a value
|
|
|
# other than None (i.e. until a callback returns a success)
|
|
|
- for callback in self.auth_checker_callbacks[login_type]:
|
|
|
+ for callback in self._module_api_callbacks.auth_checker_callbacks[login_type]:
|
|
|
try:
|
|
|
result = await delay_cancellation(
|
|
|
callback(username, login_type, login_dict)
|
|
@@ -2214,7 +2110,7 @@ class PasswordAuthProvider:
|
|
|
# (user_id, callback_func), where callback_func should be run
|
|
|
# after we've finished everything else
|
|
|
|
|
|
- for callback in self.check_3pid_auth_callbacks:
|
|
|
+ for callback in self._module_api_callbacks.check_3pid_auth_callbacks:
|
|
|
try:
|
|
|
result = await delay_cancellation(callback(medium, address, password))
|
|
|
except CancelledError:
|
|
@@ -2272,7 +2168,7 @@ class PasswordAuthProvider:
|
|
|
self, user_id: str, device_id: Optional[str], access_token: str
|
|
|
) -> None:
|
|
|
# call all of the on_logged_out callbacks
|
|
|
- for callback in self.on_logged_out_callbacks:
|
|
|
+ for callback in self._module_api_callbacks.on_logged_out_callbacks:
|
|
|
try:
|
|
|
await callback(user_id, device_id, access_token)
|
|
|
except Exception as e:
|
|
@@ -2297,7 +2193,9 @@ class PasswordAuthProvider:
|
|
|
The localpart to use when registering this user, or None if no module
|
|
|
returned a localpart.
|
|
|
"""
|
|
|
- for callback in self.get_username_for_registration_callbacks:
|
|
|
+ for (
|
|
|
+ callback
|
|
|
+ ) in self._module_api_callbacks.get_username_for_registration_callbacks:
|
|
|
try:
|
|
|
res = await delay_cancellation(callback(uia_results, params))
|
|
|
|
|
@@ -2342,7 +2240,9 @@ class PasswordAuthProvider:
|
|
|
A tuple which first element is the display name, and the second is an MXC URL
|
|
|
to the user's avatar.
|
|
|
"""
|
|
|
- for callback in self.get_displayname_for_registration_callbacks:
|
|
|
+ for (
|
|
|
+ callback
|
|
|
+ ) in self._module_api_callbacks.get_displayname_for_registration_callbacks:
|
|
|
try:
|
|
|
res = await delay_cancellation(callback(uia_results, params))
|
|
|
|
|
@@ -2385,7 +2285,7 @@ class PasswordAuthProvider:
|
|
|
Returns:
|
|
|
Whether the 3PID is allowed to be bound on this homeserver
|
|
|
"""
|
|
|
- for callback in self.is_3pid_allowed_callbacks:
|
|
|
+ for callback in self._module_api_callbacks.is_3pid_allowed_callbacks:
|
|
|
try:
|
|
|
res = await delay_cancellation(callback(medium, address, registration))
|
|
|
|