Browse Source

ModuleAPI SSO auth callbacks (#15207)

Signed-off-by: Andrii Yasynyshyn yasinishyn.a.n@gmail.com
Andrew Yasinishyn 4 months ago
parent
commit
63d96bfc61

+ 1 - 0
changelog.d/15207.feature

@@ -0,0 +1 @@
+Adds on_user_login ModuleAPI callback allowing to execute custom code after (on) Auth.

+ 13 - 0
docs/modules/account_validity_callbacks.md

@@ -42,3 +42,16 @@ operations to keep track of them. (e.g. add them to a database table). The user
 represented by their Matrix user ID.
 
 If multiple modules implement this callback, Synapse runs them all in order.
+
+### `on_user_login`
+
+_First introduced in Synapse v1.98.0_
+
+```python
+async def on_user_login(user_id: str, auth_provider_type: str, auth_provider_id: str) -> None
+```
+
+Called after successfully login or registration of a user for cases when module needs to perform extra operations after auth.
+represented by their Matrix user ID.
+
+If multiple modules implement this callback, Synapse runs them all in order.

+ 1 - 2
rust/src/push/mod.rs

@@ -296,8 +296,7 @@ impl<'source> FromPyObject<'source> for JsonValue {
             match l.iter().map(SimpleJsonValue::extract).collect() {
                 Ok(a) => Ok(JsonValue::Array(a)),
                 Err(e) => Err(PyTypeError::new_err(format!(
-                    "Can't convert to JsonValue::Array: {}",
-                    e
+                    "Can't convert to JsonValue::Array: {e}"
                 ))),
             }
         } else if let Ok(v) = SimpleJsonValue::extract(ob) {

+ 16 - 0
synapse/handlers/account_validity.py

@@ -98,6 +98,22 @@ class AccountValidityHandler:
         for callback in self._module_api_callbacks.on_user_registration_callbacks:
             await callback(user_id)
 
+    async def on_user_login(
+        self,
+        user_id: str,
+        auth_provider_type: Optional[str],
+        auth_provider_id: Optional[str],
+    ) -> None:
+        """Tell third-party modules about a user logins.
+
+        Args:
+            user_id: The mxID of the user.
+            auth_provider_type: The type of login.
+            auth_provider_id: The ID of the auth provider.
+        """
+        for callback in self._module_api_callbacks.on_user_login_callbacks:
+            await callback(user_id, auth_provider_type, auth_provider_id)
+
     @wrap_as_background_process("send_renewals")
     async def _send_renewal_emails(self) -> None:
         """Gets the list of users whose account is expiring in the amount of time

+ 8 - 0
synapse/handlers/auth.py

@@ -212,6 +212,7 @@ class AuthHandler:
         self._password_enabled_for_reauth = hs.config.auth.password_enabled_for_reauth
         self._password_localdb_enabled = hs.config.auth.password_localdb_enabled
         self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
+        self._account_validity_handler = hs.get_account_validity_handler()
 
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
@@ -1783,6 +1784,13 @@ class AuthHandler:
             client_redirect_url, "loginToken", login_token
         )
 
+        # Run post-login module callback handlers
+        await self._account_validity_handler.on_user_login(
+            user_id=registered_user_id,
+            auth_provider_type=LoginType.SSO,
+            auth_provider_id=auth_provider_id,
+        )
+
         # if the client is whitelisted, we can redirect straight to it
         if client_redirect_url.startswith(self._whitelisted_sso_clients):
             request.redirect(redirect_url)

+ 3 - 0
synapse/module_api/__init__.py

@@ -80,6 +80,7 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
     ON_LEGACY_ADMIN_REQUEST,
     ON_LEGACY_RENEW_CALLBACK,
     ON_LEGACY_SEND_MAIL_CALLBACK,
+    ON_USER_LOGIN_CALLBACK,
     ON_USER_REGISTRATION_CALLBACK,
 )
 from synapse.module_api.callbacks.spamchecker_callbacks import (
@@ -334,6 +335,7 @@ class ModuleApi:
         *,
         is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
         on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
+        on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None,
         on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
         on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
         on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
@@ -345,6 +347,7 @@ class ModuleApi:
         return self._callbacks.account_validity.register_callbacks(
             is_user_expired=is_user_expired,
             on_user_registration=on_user_registration,
+            on_user_login=on_user_login,
             on_legacy_send_mail=on_legacy_send_mail,
             on_legacy_renew=on_legacy_renew,
             on_legacy_admin_request=on_legacy_admin_request,

+ 6 - 0
synapse/module_api/callbacks/account_validity_callbacks.py

@@ -22,6 +22,7 @@ logger = logging.getLogger(__name__)
 # Types for callbacks to be registered via the module api
 IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
 ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
+ON_USER_LOGIN_CALLBACK = Callable[[str, Optional[str], Optional[str]], Awaitable]
 # Temporary hooks to allow for a transition from `/_matrix/client` endpoints
 # to `/_synapse/client/account_validity`. See `register_callbacks` below.
 ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
@@ -33,6 +34,7 @@ class AccountValidityModuleApiCallbacks:
     def __init__(self) -> None:
         self.is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
         self.on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
+        self.on_user_login_callbacks: List[ON_USER_LOGIN_CALLBACK] = []
         self.on_legacy_send_mail_callback: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None
         self.on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
 
@@ -44,6 +46,7 @@ class AccountValidityModuleApiCallbacks:
         self,
         is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
         on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
+        on_user_login: Optional[ON_USER_LOGIN_CALLBACK] = None,
         on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
         on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
         on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
@@ -55,6 +58,9 @@ class AccountValidityModuleApiCallbacks:
         if on_user_registration is not None:
             self.on_user_registration_callbacks.append(on_user_registration)
 
+        if on_user_login is not None:
+            self.on_user_login_callbacks.append(on_user_login)
+
         # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
         # an admin one). As part of moving the feature into a module, we need to change
         # the path from /_matrix/client/unstable/account_validity/... to

+ 8 - 0
synapse/rest/client/login.py

@@ -115,6 +115,7 @@ class LoginRestServlet(RestServlet):
         self.registration_handler = hs.get_registration_handler()
         self._sso_handler = hs.get_sso_handler()
         self._spam_checker = hs.get_module_api_callbacks().spam_checker
+        self._account_validity_handler = hs.get_account_validity_handler()
 
         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
@@ -470,6 +471,13 @@ class LoginRestServlet(RestServlet):
             device_id=device_id,
         )
 
+        # execute the callback
+        await self._account_validity_handler.on_user_login(
+            user_id,
+            auth_provider_type=login_submission.get("type"),
+            auth_provider_id=auth_provider_id,
+        )
+
         if valid_until_ms is not None:
             expires_in_ms = valid_until_ms - self.clock.time_msec()
             result["expires_in_ms"] = expires_in_ms