1
0
Эх сурвалжийг харах

Split OidcProvider out of OidcHandler (#9107)

The idea here is that we will have an instance of OidcProvider for each
configured IdP, with OidcHandler just doing the marshalling of them.

For now it's still hardcoded with a single provider.
Richard van der Hoff 3 жил өмнө
parent
commit
21a296cd5a

+ 1 - 0
changelog.d/9107.feature

@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.

+ 0 - 1
synapse/app/homeserver.py

@@ -429,7 +429,6 @@ def setup(config_options):
             oidc = hs.get_oidc_handler()
             # Loading the provider metadata also ensures the provider config is valid.
             await oidc.load_metadata()
-            await oidc.load_jwks()
 
         await _base.start(hs, config.listeners)
 

+ 148 - 98
synapse/handlers/oidc_handler.py

@@ -35,6 +35,7 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 
 from synapse.config import ConfigError
+from synapse.config.oidc_config import OidcProviderConfig
 from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
@@ -70,6 +71,131 @@ JWK = Dict[str, str]
 JWKS = TypedDict("JWKS", {"keys": List[JWK]})
 
 
+class OidcHandler:
+    """Handles requests related to the OpenID Connect login flow.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        self._sso_handler = hs.get_sso_handler()
+
+        provider_conf = hs.config.oidc.oidc_provider
+        # we should not have been instantiated if there is no configured provider.
+        assert provider_conf is not None
+
+        self._token_generator = OidcSessionTokenGenerator(hs)
+
+        self._provider = OidcProvider(hs, self._token_generator, provider_conf)
+
+    async def load_metadata(self) -> None:
+        """Validate the config and load the metadata from the remote endpoint.
+
+        Called at startup to ensure we have everything we need.
+        """
+        await self._provider.load_metadata()
+        await self._provider.load_jwks()
+
+    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+        """Handle an incoming request to /_synapse/oidc/callback
+
+        Since we might want to display OIDC-related errors in a user-friendly
+        way, we don't raise SynapseError from here. Instead, we call
+        ``self._sso_handler.render_error`` which displays an HTML page for the error.
+
+        Most of the OpenID Connect logic happens here:
+
+          - first, we check if there was any error returned by the provider and
+            display it
+          - then we fetch the session cookie, decode and verify it
+          - the ``state`` query parameter should match with the one stored in the
+            session cookie
+
+        Once we know the session is legit, we then delegate to the OIDC Provider
+        implementation, which will exchange the code with the provider and complete the
+        login/authentication.
+
+        Args:
+            request: the incoming request from the browser.
+        """
+
+        # The provider might redirect with an error.
+        # In that case, just display it as-is.
+        if b"error" in request.args:
+            # error response from the auth server. see:
+            #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
+            #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
+            error = request.args[b"error"][0].decode()
+            description = request.args.get(b"error_description", [b""])[0].decode()
+
+            # Most of the errors returned by the provider could be due by
+            # either the provider misbehaving or Synapse being misconfigured.
+            # The only exception of that is "access_denied", where the user
+            # probably cancelled the login flow. In other cases, log those errors.
+            if error != "access_denied":
+                logger.error("Error from the OIDC provider: %s %s", error, description)
+
+            self._sso_handler.render_error(request, error, description)
+            return
+
+        # otherwise, it is presumably a successful response. see:
+        #   https://tools.ietf.org/html/rfc6749#section-4.1.2
+
+        # Fetch the session cookie
+        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
+        if session is None:
+            logger.info("No session cookie found")
+            self._sso_handler.render_error(
+                request, "missing_session", "No session cookie found"
+            )
+            return
+
+        # Remove the cookie. There is a good chance that if the callback failed
+        # once, it will fail next time and the code will already be exchanged.
+        # Removing it early avoids spamming the provider with token requests.
+        request.addCookie(
+            SESSION_COOKIE_NAME,
+            b"",
+            path="/_synapse/oidc",
+            expires="Thu, Jan 01 1970 00:00:00 UTC",
+            httpOnly=True,
+            sameSite="lax",
+        )
+
+        # Check for the state query parameter
+        if b"state" not in request.args:
+            logger.info("State parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "State parameter is missing"
+            )
+            return
+
+        state = request.args[b"state"][0].decode()
+
+        # Deserialize the session token and verify it.
+        try:
+            session_data = self._token_generator.verify_oidc_session_token(
+                session, state
+            )
+        except MacaroonDeserializationException as e:
+            logger.exception("Invalid session")
+            self._sso_handler.render_error(request, "invalid_session", str(e))
+            return
+        except MacaroonInvalidSignatureException as e:
+            logger.exception("Could not verify session")
+            self._sso_handler.render_error(request, "mismatching_session", str(e))
+            return
+
+        if b"code" not in request.args:
+            logger.info("Code parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "Code parameter is missing"
+            )
+            return
+
+        code = request.args[b"code"][0].decode()
+
+        await self._provider.handle_oidc_callback(request, session_data, code)
+
+
 class OidcError(Exception):
     """Used to catch errors when calling the token_endpoint
     """
@@ -84,21 +210,25 @@ class OidcError(Exception):
         return self.error
 
 
-class OidcHandler:
-    """Handles requests related to the OpenID Connect login flow.
+class OidcProvider:
+    """Wraps the config for a single OIDC IdentityProvider
+
+    Provides methods for handling redirect requests and callbacks via that particular
+    IdP.
     """
 
-    def __init__(self, hs: "HomeServer"):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        token_generator: "OidcSessionTokenGenerator",
+        provider: OidcProviderConfig,
+    ):
         self._store = hs.get_datastore()
 
-        self._token_generator = OidcSessionTokenGenerator(hs)
+        self._token_generator = token_generator
 
         self._callback_url = hs.config.oidc_callback_url  # type: str
 
-        provider = hs.config.oidc.oidc_provider
-        # we should not have been instantiated if there is no configured provider.
-        assert provider is not None
-
         self._scopes = provider.scopes
         self._user_profile_method = provider.user_profile_method
         self._client_auth = ClientAuth(
@@ -552,22 +682,16 @@ class OidcHandler:
             nonce=nonce,
         )
 
-    async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+    async def handle_oidc_callback(
+        self, request: SynapseRequest, session_data: "OidcSessionData", code: str
+    ) -> None:
         """Handle an incoming request to /_synapse/oidc/callback
 
-        Since we might want to display OIDC-related errors in a user-friendly
-        way, we don't raise SynapseError from here. Instead, we call
-        ``self._sso_handler.render_error`` which displays an HTML page for the error.
+        By this time we have already validated the session on the synapse side, and
+        now need to do the provider-specific operations. This includes:
 
-        Most of the OpenID Connect logic happens here:
-
-          - first, we check if there was any error returned by the provider and
-            display it
-          - then we fetch the session cookie, decode and verify it
-          - the ``state`` query parameter should match with the one stored in the
-            session cookie
-          - once we known this session is legit, exchange the code with the
-            provider using the ``token_endpoint`` (see ``_exchange_code``)
+          - exchange the code with the provider using the ``token_endpoint`` (see
+            ``_exchange_code``)
           - once we have the token, use it to either extract the UserInfo from
             the ``id_token`` (``_parse_id_token``), or use the ``access_token``
             to fetch UserInfo from the ``userinfo_endpoint``
@@ -577,86 +701,12 @@ class OidcHandler:
 
         Args:
             request: the incoming request from the browser.
+            session_data: the session data, extracted from our cookie
+            code: The authorization code we got from the callback.
         """
-
-        # The provider might redirect with an error.
-        # In that case, just display it as-is.
-        if b"error" in request.args:
-            # error response from the auth server. see:
-            #  https://tools.ietf.org/html/rfc6749#section-4.1.2.1
-            #  https://openid.net/specs/openid-connect-core-1_0.html#AuthError
-            error = request.args[b"error"][0].decode()
-            description = request.args.get(b"error_description", [b""])[0].decode()
-
-            # Most of the errors returned by the provider could be due by
-            # either the provider misbehaving or Synapse being misconfigured.
-            # The only exception of that is "access_denied", where the user
-            # probably cancelled the login flow. In other cases, log those errors.
-            if error != "access_denied":
-                logger.error("Error from the OIDC provider: %s %s", error, description)
-
-            self._sso_handler.render_error(request, error, description)
-            return
-
-        # otherwise, it is presumably a successful response. see:
-        #   https://tools.ietf.org/html/rfc6749#section-4.1.2
-
-        # Fetch the session cookie
-        session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
-        if session is None:
-            logger.info("No session cookie found")
-            self._sso_handler.render_error(
-                request, "missing_session", "No session cookie found"
-            )
-            return
-
-        # Remove the cookie. There is a good chance that if the callback failed
-        # once, it will fail next time and the code will already be exchanged.
-        # Removing it early avoids spamming the provider with token requests.
-        request.addCookie(
-            SESSION_COOKIE_NAME,
-            b"",
-            path="/_synapse/oidc",
-            expires="Thu, Jan 01 1970 00:00:00 UTC",
-            httpOnly=True,
-            sameSite="lax",
-        )
-
-        # Check for the state query parameter
-        if b"state" not in request.args:
-            logger.info("State parameter is missing")
-            self._sso_handler.render_error(
-                request, "invalid_request", "State parameter is missing"
-            )
-            return
-
-        state = request.args[b"state"][0].decode()
-
-        # Deserialize the session token and verify it.
-        try:
-            session_data = self._token_generator.verify_oidc_session_token(
-                session, state
-            )
-        except MacaroonDeserializationException as e:
-            logger.exception("Invalid session")
-            self._sso_handler.render_error(request, "invalid_session", str(e))
-            return
-        except MacaroonInvalidSignatureException as e:
-            logger.exception("Could not verify session")
-            self._sso_handler.render_error(request, "mismatching_session", str(e))
-            return
-
         # Exchange the code with the provider
-        if b"code" not in request.args:
-            logger.info("Code parameter is missing")
-            self._sso_handler.render_error(
-                request, "invalid_request", "Code parameter is missing"
-            )
-            return
-
-        logger.debug("Exchanging code")
-        code = request.args[b"code"][0].decode()
         try:
+            logger.debug("Exchanging code")
             token = await self._exchange_code(code)
         except OidcError as e:
             logger.exception("Could not exchange code")

+ 48 - 45
tests/handlers/test_oidc.py

@@ -151,6 +151,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
 
         self.handler = hs.get_oidc_handler()
+        self.provider = self.handler._provider
         sso_handler = hs.get_sso_handler()
         # Mock the render error method.
         self.render_error = Mock(return_value=None)
@@ -162,9 +163,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return hs
 
     def metadata_edit(self, values):
-        return patch.dict(self.handler._provider_metadata, values)
+        return patch.dict(self.provider._provider_metadata, values)
 
     def assertRenderedError(self, error, error_description=None):
+        self.render_error.assert_called_once()
         args = self.render_error.call_args[0]
         self.assertEqual(args[1], error)
         if error_description is not None:
@@ -175,15 +177,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
     def test_config(self):
         """Basic config correctly sets up the callback URL and client auth correctly."""
-        self.assertEqual(self.handler._callback_url, CALLBACK_URL)
-        self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
-        self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+        self.assertEqual(self.provider._callback_url, CALLBACK_URL)
+        self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
+        self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
 
     @override_config({"oidc_config": {"discover": True}})
     def test_discovery(self):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
-        metadata = self.get_success(self.handler.load_metadata())
+        metadata = self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
 
         self.assertEqual(metadata.issuer, ISSUER)
@@ -195,47 +197,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # subsequent calls should be cached
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_metadata())
+        self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_no_discovery(self):
         """When discovery is disabled, it should not try to load from discovery document."""
-        self.get_success(self.handler.load_metadata())
+        self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_load_jwks(self):
         """JWKS loading is done once (then cached) if used."""
-        jwks = self.get_success(self.handler.load_jwks())
+        jwks = self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
         self.assertEqual(jwks, {"keys": []})
 
         # subsequent calls should be cached…
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_jwks())
+        self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_not_called()
 
         # …unless forced
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_jwks(force=True))
+        self.get_success(self.provider.load_jwks(force=True))
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
 
         # Throw if the JWKS uri is missing
         with self.metadata_edit({"jwks_uri": None}):
-            self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
+            self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
         # Return empty key set if JWKS are not used
-        self.handler._scopes = []  # not asking the openid scope
+        self.provider._scopes = []  # not asking the openid scope
         self.http_client.get_json.reset_mock()
-        jwks = self.get_success(self.handler.load_jwks(force=True))
+        jwks = self.get_success(self.provider.load_jwks(force=True))
         self.http_client.get_json.assert_not_called()
         self.assertEqual(jwks, {"keys": []})
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_validate_config(self):
         """Provider metadatas are extensively validated."""
-        h = self.handler
+        h = self.provider
 
         # Default test config does not throw
         h._validate_metadata()
@@ -314,13 +316,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
-            self.handler._validate_metadata()
+            self.provider._validate_metadata()
 
     def test_redirect_request(self):
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["addCookie"])
         url = self.get_success(
-            self.handler.handle_redirect_request(req, b"http://client/redirect")
+            self.provider.handle_redirect_request(req, b"http://client/redirect")
         )
         url = urlparse(url)
         auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@@ -388,7 +390,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # ensure that we are correctly testing the fallback when "get_extra_attributes"
         # is not implemented.
-        mapping_provider = self.handler._user_mapping_provider
+        mapping_provider = self.provider._user_mapping_provider
         with self.assertRaises(AttributeError):
             _ = mapping_provider.get_extra_attributes
 
@@ -403,9 +405,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": username,
         }
         expected_user_id = "@%s:%s" % (username, self.hs.hostname)
-        self.handler._exchange_code = simple_async_mock(return_value=token)
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
@@ -425,14 +427,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         auth_handler.complete_sso_login.assert_called_once_with(
             expected_user_id, request, client_redirect_url, None,
         )
-        self.handler._exchange_code.assert_called_once_with(code)
-        self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.handler._fetch_userinfo.assert_not_called()
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
+        self.provider._fetch_userinfo.assert_not_called()
         self.render_error.assert_not_called()
 
         # Handle mapping errors
         with patch.object(
-            self.handler,
+            self.provider,
             "_remote_id_from_userinfo",
             new=Mock(side_effect=MappingException()),
         ):
@@ -440,36 +442,36 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
-        self.handler._parse_id_token = simple_async_mock(raises=Exception())
+        self.provider._parse_id_token = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
         auth_handler.complete_sso_login.reset_mock()
-        self.handler._exchange_code.reset_mock()
-        self.handler._parse_id_token.reset_mock()
-        self.handler._fetch_userinfo.reset_mock()
+        self.provider._exchange_code.reset_mock()
+        self.provider._parse_id_token.reset_mock()
+        self.provider._fetch_userinfo.reset_mock()
 
         # With userinfo fetching
-        self.handler._scopes = []  # do not ask the "openid" scope
+        self.provider._scopes = []  # do not ask the "openid" scope
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
             expected_user_id, request, client_redirect_url, None,
         )
-        self.handler._exchange_code.assert_called_once_with(code)
-        self.handler._parse_id_token.assert_not_called()
-        self.handler._fetch_userinfo.assert_called_once_with(token)
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_not_called()
+        self.provider._fetch_userinfo.assert_called_once_with(token)
         self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
-        self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
         # Handle code exchange failure
         from synapse.handlers.oidc_handler import OidcError
 
-        self.handler._exchange_code = simple_async_mock(
+        self.provider._exchange_code = simple_async_mock(
             raises=OidcError("invalid_request")
         )
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -524,7 +526,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
         )
         code = "code"
-        ret = self.get_success(self.handler._exchange_code(code))
+        ret = self.get_success(self.provider._exchange_code(code))
         kwargs = self.http_client.request.call_args[1]
 
         self.assertEqual(ret, token)
@@ -548,7 +550,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         from synapse.handlers.oidc_handler import OidcError
 
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "foo")
         self.assertEqual(exc.value.error_description, "bar")
 
@@ -558,7 +560,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=500, phrase=b"Internal Server Error", body=b"Not JSON",
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # Internal server error with JSON body
@@ -570,14 +572,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             )
         )
 
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "internal_server_error")
 
         # 4xx error without "error" field
         self.http_client.request = simple_async_mock(
             return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # 2xx error with "error" field
@@ -586,7 +588,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=200, phrase=b"OK", body=b'{"error": "some_error"}',
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "some_error")
 
     @override_config(
@@ -612,8 +614,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "foo",
             "phone": "1234567",
         }
-        self.handler._exchange_code = simple_async_mock(return_value=token)
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
@@ -979,9 +981,10 @@ async def _make_callback_with_userinfo(
     from synapse.handlers.oidc_handler import OidcSessionData
 
     handler = hs.get_oidc_handler()
-    handler._exchange_code = simple_async_mock(return_value={})
-    handler._parse_id_token = simple_async_mock(return_value=userinfo)
-    handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+    provider = handler._provider
+    provider._exchange_code = simple_async_mock(return_value={})
+    provider._parse_id_token = simple_async_mock(return_value=userinfo)
+    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
 
     state = "state"
     session = handler._token_generator.generate_oidc_session_token(