Kaynağa Gözat

Merge branch 'babolivier/sso_module_api' into 'release-v1.11.1'

Factor out complete_sso_login and expose it to the Module API

See merge request new-vector/synapse!4
Brendan Abolivier 4 yıl önce
ebeveyn
işleme
6f67a8b570

+ 1 - 1
synapse/config/sso.py

@@ -16,7 +16,7 @@ from typing import Any, Dict
 
 import pkg_resources
 
-from ._base import Config, ConfigError
+from ._base import Config
 
 
 class SSOConfig(Config):

+ 74 - 0
synapse/handlers/auth.py

@@ -17,6 +17,8 @@
 import logging
 import time
 import unicodedata
+import urllib.parse
+from typing import Any
 
 import attr
 import bcrypt
@@ -38,8 +40,11 @@ from synapse.api.errors import (
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http.server import finish_request
+from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.module_api import ModuleApi
+from synapse.push.mailer import load_jinja2_templates
 from synapse.types import UserID
 from synapse.util.caches.expiringcache import ExpiringCache
 
@@ -108,6 +113,16 @@ class AuthHandler(BaseHandler):
 
         self._clock = self.hs.get_clock()
 
+        # Load the SSO redirect confirmation page HTML template
+        self._sso_redirect_confirm_template = load_jinja2_templates(
+            hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
+        )[0]
+
+        self._server_name = hs.config.server_name
+
+        # cast to tuple for use with str.startswith
+        self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
+
     @defer.inlineCallbacks
     def validate_user_via_ui_auth(self, requester, request_body, clientip):
         """
@@ -927,6 +942,65 @@ class AuthHandler(BaseHandler):
         else:
             return defer.succeed(False)
 
+    def complete_sso_login(
+        self,
+        registered_user_id: str,
+        request: SynapseRequest,
+        client_redirect_url: str,
+    ):
+        """Having figured out a mxid for this user, complete the HTTP request
+
+        Args:
+            registered_user_id: The registered user ID to complete SSO login for.
+            request: The request to complete.
+            client_redirect_url: The URL to which to redirect the user at the end of the
+                process.
+        """
+        # Create a login token
+        login_token = self.macaroon_gen.generate_short_term_login_token(
+            registered_user_id
+        )
+
+        # Append the login token to the original redirect URL (i.e. with its query
+        # parameters kept intact) to build the URL to which the template needs to
+        # redirect the users once they have clicked on the confirmation link.
+        redirect_url = self.add_query_param_to_url(
+            client_redirect_url, "loginToken", login_token
+        )
+
+        # if the client is whitelisted, we can redirect straight to it
+        if client_redirect_url.startswith(self._whitelisted_sso_clients):
+            request.redirect(redirect_url)
+            finish_request(request)
+            return
+
+        # Otherwise, serve the redirect confirmation page.
+
+        # Remove the query parameters from the redirect URL to get a shorter version of
+        # it. This is only to display a human-readable URL in the template, but not the
+        # URL we redirect users to.
+        redirect_url_no_params = client_redirect_url.split("?")[0]
+
+        html = self._sso_redirect_confirm_template.render(
+            display_url=redirect_url_no_params,
+            redirect_url=redirect_url,
+            server_name=self._server_name,
+        ).encode("utf-8")
+
+        request.setResponseCode(200)
+        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+        request.setHeader(b"Content-Length", b"%d" % (len(html),))
+        request.write(html)
+        finish_request(request)
+
+    @staticmethod
+    def add_query_param_to_url(url: str, param_name: str, param: Any):
+        url_parts = list(urllib.parse.urlparse(url))
+        query = dict(urllib.parse.parse_qsl(url_parts[4]))
+        query.update({param_name: param})
+        url_parts[4] = urllib.parse.urlencode(query)
+        return urllib.parse.urlunparse(url_parts)
+
 
 @attr.s
 class MacaroonGenerator(object):

+ 19 - 0
synapse/module_api/__init__.py

@@ -17,6 +17,7 @@ import logging
 
 from twisted.internet import defer
 
+from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.types import UserID
 
@@ -211,3 +212,21 @@ class ModuleApi(object):
             Deferred[object]: result of func
         """
         return self._store.db.runInteraction(desc, func, *args, **kwargs)
+
+    def complete_sso_login(
+        self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
+    ):
+        """Complete a SSO login by redirecting the user to a page to confirm whether they
+        want their access token sent to `client_redirect_url`, or redirect them to that
+        URL with a token directly if the URL matches with one of the whitelisted clients.
+
+        Args:
+            registered_user_id: The MXID that has been registered as a previous step of
+                of this SSO login.
+            request: The request to respond to.
+            client_redirect_url: The URL to which to offer to redirect the user (or to
+                redirect them directly if whitelisted).
+        """
+        self._auth_handler.complete_sso_login(
+            registered_user_id, request, client_redirect_url,
+        )

+ 2 - 56
synapse/rest/client/v1/login.py

@@ -28,7 +28,6 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
-from synapse.http.site import SynapseRequest
 from synapse.push.mailer import load_jinja2_templates
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.well_known import WellKnownBuilder
@@ -591,63 +590,10 @@ class SSOAuthHandler(object):
                 localpart=localpart, default_display_name=user_display_name
             )
 
-        self.complete_sso_login(registered_user_id, request, client_redirect_url)
-
-    def complete_sso_login(
-        self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
-    ):
-        """Having figured out a mxid for this user, complete the HTTP request
-
-        Args:
-            registered_user_id:
-            request:
-            client_redirect_url:
-        """
-        # Create a login token
-        login_token = self._macaroon_gen.generate_short_term_login_token(
-            registered_user_id
+        self._auth_handler.complete_sso_login(
+            registered_user_id, request, client_redirect_url
         )
 
-        # Append the login token to the original redirect URL (i.e. with its query
-        # parameters kept intact) to build the URL to which the template needs to
-        # redirect the users once they have clicked on the confirmation link.
-        redirect_url = self._add_query_param_to_url(
-            client_redirect_url, "loginToken", login_token
-        )
-
-        # if the client is whitelisted, we can redirect straight to it
-        if client_redirect_url.startswith(self._whitelisted_sso_clients):
-            request.redirect(redirect_url)
-            finish_request(request)
-            return
-
-        # Otherwise, serve the redirect confirmation page.
-
-        # Remove the query parameters from the redirect URL to get a shorter version of
-        # it. This is only to display a human-readable URL in the template, but not the
-        # URL we redirect users to.
-        redirect_url_no_params = client_redirect_url.split("?")[0]
-
-        html = self._template.render(
-            display_url=redirect_url_no_params,
-            redirect_url=redirect_url,
-            server_name=self._server_name,
-        ).encode("utf-8")
-
-        request.setResponseCode(200)
-        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
-        request.setHeader(b"Content-Length", b"%d" % (len(html),))
-        request.write(html)
-        finish_request(request)
-
-    @staticmethod
-    def _add_query_param_to_url(url, param_name, param):
-        url_parts = list(urllib.parse.urlparse(url))
-        query = dict(urllib.parse.parse_qsl(url_parts[4]))
-        query.update({param_name: param})
-        url_parts[4] = urllib.parse.urlencode(query)
-        return urllib.parse.urlunparse(url_parts)
-
 
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)