Browse Source

Allow additional SSO properties to be passed to the client (#8413)

Patrick Cloke 3 years ago
parent
commit
8b40843392

+ 1 - 0
changelog.d/8413.feature

@@ -0,0 +1 @@
+Support passing additional single sign-on parameters to the client.

+ 8 - 0
docs/sample_config.yaml

@@ -1748,6 +1748,14 @@ oidc_config:
       #
       #display_name_template: "{{ user.given_name }} {{ user.last_name }}"
 
+      # Jinja2 templates for extra attributes to send back to the client during
+      # login.
+      #
+      # Note that these are non-standard and clients will ignore them without modifications.
+      #
+      #extra_attributes:
+        #birthdate: "{{ user.birthdate }}"
+
 
 
 # Enable CAS for registration and login.

+ 13 - 1
docs/sso_mapping_providers.md

@@ -57,7 +57,7 @@ A custom mapping provider must specify the following methods:
     - This method must return a string, which is the unique identifier for the
       user. Commonly the ``sub`` claim of the response.
 * `map_user_attributes(self, userinfo, token)`
-    - This method should be async.
+    - This method must be async.
     - Arguments:
       - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
                      information from.
@@ -66,6 +66,18 @@ A custom mapping provider must specify the following methods:
     - Returns a dictionary with two keys:
       - localpart: A required string, used to generate the Matrix ID.
       - displayname: An optional string, the display name for the user.
+* `get_extra_attributes(self, userinfo, token)`
+    - This method must be async.
+    - Arguments:
+      - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
+                     information from.
+      - `token` - A dictionary which includes information necessary to make
+                  further requests to the OpenID provider.
+    - Returns a dictionary that is suitable to be serialized to JSON. This
+      will be returned as part of the response during a successful login.
+
+      Note that care should be taken to not overwrite any of the parameters
+      usually returned as part of the [login response](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login).
 
 ### Default OpenID Mapping Provider
 

+ 16 - 0
docs/workers.md

@@ -243,6 +243,22 @@ for the room are in flight:
 
     ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$
 
+Additionally, the following endpoints should be included if Synapse is configured
+to use SSO (you only need to include the ones for whichever SSO provider you're
+using):
+
+    # OpenID Connect requests.
+    ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
+    ^/_synapse/oidc/callback$
+
+    # SAML requests.
+    ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$
+    ^/_matrix/saml2/authn_response$
+
+    # CAS requests.
+    ^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$
+    ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$
+
 Note that a HTTP listener with `client` and `federation` resources must be
 configured in the `worker_listeners` option in the worker config.
 

+ 8 - 0
synapse/config/oidc_config.py

@@ -204,6 +204,14 @@ class OIDCConfig(Config):
               # If unset, no displayname will be set.
               #
               #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}"
+
+              # Jinja2 templates for extra attributes to send back to the client during
+              # login.
+              #
+              # Note that these are non-standard and clients will ignore them without modifications.
+              #
+              #extra_attributes:
+                #birthdate: "{{{{ user.birthdate }}}}"
         """.format(
             mapping_provider=DEFAULT_USER_MAPPING_PROVIDER
         )

+ 59 - 1
synapse/handlers/auth.py

@@ -137,6 +137,15 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
     }
 
 
+@attr.s(slots=True)
+class SsoLoginExtraAttributes:
+    """Data we track about SAML2 sessions"""
+
+    # time the session was created, in milliseconds
+    creation_time = attr.ib(type=int)
+    extra_attributes = attr.ib(type=JsonDict)
+
+
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
@@ -239,6 +248,10 @@ class AuthHandler(BaseHandler):
         # cast to tuple for use with str.startswith
         self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
 
+        # A mapping of user ID to extra attributes to include in the login
+        # response.
+        self._extra_attributes = {}  # type: Dict[str, SsoLoginExtraAttributes]
+
     async def validate_user_via_ui_auth(
         self,
         requester: Requester,
@@ -1165,6 +1178,7 @@ class AuthHandler(BaseHandler):
         registered_user_id: str,
         request: SynapseRequest,
         client_redirect_url: str,
+        extra_attributes: Optional[JsonDict] = None,
     ):
         """Having figured out a mxid for this user, complete the HTTP request
 
@@ -1173,6 +1187,8 @@ class AuthHandler(BaseHandler):
             request: The request to complete.
             client_redirect_url: The URL to which to redirect the user at the end of the
                 process.
+            extra_attributes: Extra attributes which will be passed to the client
+                during successful login. Must be JSON serializable.
         """
         # If the account has been deactivated, do not proceed with the login
         # flow.
@@ -1181,19 +1197,30 @@ class AuthHandler(BaseHandler):
             respond_with_html(request, 403, self._sso_account_deactivated_template)
             return
 
-        self._complete_sso_login(registered_user_id, request, client_redirect_url)
+        self._complete_sso_login(
+            registered_user_id, request, client_redirect_url, extra_attributes
+        )
 
     def _complete_sso_login(
         self,
         registered_user_id: str,
         request: SynapseRequest,
         client_redirect_url: str,
+        extra_attributes: Optional[JsonDict] = None,
     ):
         """
         The synchronous portion of complete_sso_login.
 
         This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
         """
+        # Store any extra attributes which will be passed in the login response.
+        # Note that this is per-user so it may overwrite a previous value, this
+        # is considered OK since the newest SSO attributes should be most valid.
+        if extra_attributes:
+            self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
+                self._clock.time_msec(), extra_attributes,
+            )
+
         # Create a login token
         login_token = self.macaroon_gen.generate_short_term_login_token(
             registered_user_id
@@ -1226,6 +1253,37 @@ class AuthHandler(BaseHandler):
         )
         respond_with_html(request, 200, html)
 
+    async def _sso_login_callback(self, login_result: JsonDict) -> None:
+        """
+        A login callback which might add additional attributes to the login response.
+
+        Args:
+            login_result: The data to be sent to the client. Includes the user
+                ID and access token.
+        """
+        # Expire attributes before processing. Note that there shouldn't be any
+        # valid logins that still have extra attributes.
+        self._expire_sso_extra_attributes()
+
+        extra_attributes = self._extra_attributes.get(login_result["user_id"])
+        if extra_attributes:
+            login_result.update(extra_attributes.extra_attributes)
+
+    def _expire_sso_extra_attributes(self) -> None:
+        """
+        Iterate through the mapping of user IDs to extra attributes and remove any that are no longer valid.
+        """
+        # TODO This should match the amount of time the macaroon is valid for.
+        LOGIN_TOKEN_EXPIRATION_TIME = 2 * 60 * 1000
+        expire_before = self._clock.time_msec() - LOGIN_TOKEN_EXPIRATION_TIME
+        to_expire = set()
+        for user_id, data in self._extra_attributes.items():
+            if data.creation_time < expire_before:
+                to_expire.add(user_id)
+        for user_id in to_expire:
+            logger.debug("Expiring extra attributes for user %s", user_id)
+            del self._extra_attributes[user_id]
+
     @staticmethod
     def add_query_param_to_url(url: str, param_name: str, param: Any):
         url_parts = list(urllib.parse.urlparse(url))

+ 53 - 3
synapse/handlers/oidc_handler.py

@@ -37,7 +37,7 @@ from synapse.config import ConfigError
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -707,6 +707,15 @@ class OidcHandler:
             self._render_error(request, "mapping_error", str(e))
             return
 
+        # Mapping providers might not have get_extra_attributes: only call this
+        # method if it exists.
+        extra_attributes = None
+        get_extra_attributes = getattr(
+            self._user_mapping_provider, "get_extra_attributes", None
+        )
+        if get_extra_attributes:
+            extra_attributes = await get_extra_attributes(userinfo, token)
+
         # and finally complete the login
         if ui_auth_session_id:
             await self._auth_handler.complete_sso_ui_auth(
@@ -714,7 +723,7 @@ class OidcHandler:
             )
         else:
             await self._auth_handler.complete_sso_login(
-                user_id, request, client_redirect_url
+                user_id, request, client_redirect_url, extra_attributes
             )
 
     def _generate_oidc_session_token(
@@ -984,7 +993,7 @@ class OidcMappingProvider(Generic[C]):
     async def map_user_attributes(
         self, userinfo: UserInfo, token: Token
     ) -> UserAttribute:
-        """Map a ``UserInfo`` objects into user attributes.
+        """Map a `UserInfo` object into user attributes.
 
         Args:
             userinfo: An object representing the user given by the OIDC provider
@@ -995,6 +1004,18 @@ class OidcMappingProvider(Generic[C]):
         """
         raise NotImplementedError()
 
+    async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+        """Map a `UserInfo` object into additional attributes passed to the client during login.
+
+        Args:
+            userinfo: An object representing the user given by the OIDC provider
+            token: A dict with the tokens returned by the provider
+
+        Returns:
+            A dict containing additional attributes. Must be JSON serializable.
+        """
+        return {}
+
 
 # Used to clear out "None" values in templates
 def jinja_finalize(thing):
@@ -1009,6 +1030,7 @@ class JinjaOidcMappingConfig:
     subject_claim = attr.ib()  # type: str
     localpart_template = attr.ib()  # type: Template
     display_name_template = attr.ib()  # type: Optional[Template]
+    extra_attributes = attr.ib()  # type: Dict[str, Template]
 
 
 class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1047,10 +1069,28 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
                     % (e,)
                 )
 
+        extra_attributes = {}  # type Dict[str, Template]
+        if "extra_attributes" in config:
+            extra_attributes_config = config.get("extra_attributes") or {}
+            if not isinstance(extra_attributes_config, dict):
+                raise ConfigError(
+                    "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
+                )
+
+            for key, value in extra_attributes_config.items():
+                try:
+                    extra_attributes[key] = env.from_string(value)
+                except Exception as e:
+                    raise ConfigError(
+                        "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
+                        % (key, e)
+                    )
+
         return JinjaOidcMappingConfig(
             subject_claim=subject_claim,
             localpart_template=localpart_template,
             display_name_template=display_name_template,
+            extra_attributes=extra_attributes,
         )
 
     def get_remote_user_id(self, userinfo: UserInfo) -> str:
@@ -1071,3 +1111,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
                 display_name = None
 
         return UserAttribute(localpart=localpart, display_name=display_name)
+
+    async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
+        extras = {}  # type: Dict[str, str]
+        for key, template in self._config.extra_attributes.items():
+            try:
+                extras[key] = template.render(user=userinfo).strip()
+            except Exception as e:
+                # Log an error and skip this value (don't break login for this).
+                logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e))
+        return extras

+ 15 - 7
synapse/rest/client/v1/login.py

@@ -284,9 +284,7 @@ class LoginRestServlet(RestServlet):
         self,
         user_id: str,
         login_submission: JsonDict,
-        callback: Optional[
-            Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
-        ] = None,
+        callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
         create_non_existent_users: bool = False,
     ) -> Dict[str, str]:
         """Called when we've successfully authed the user and now need to
@@ -299,12 +297,12 @@ class LoginRestServlet(RestServlet):
         Args:
             user_id: ID of the user to register.
             login_submission: Dictionary of login information.
-            callback: Callback function to run after registration.
+            callback: Callback function to run after login.
             create_non_existent_users: Whether to create the user if they don't
                 exist. Defaults to False.
 
         Returns:
-            result: Dictionary of account information after successful registration.
+            result: Dictionary of account information after successful login.
         """
 
         # Before we actually log them in we check if they've already logged in
@@ -339,14 +337,24 @@ class LoginRestServlet(RestServlet):
         return result
 
     async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
+        """
+        Handle the final stage of SSO login.
+
+        Args:
+             login_submission: The JSON request body.
+
+        Returns:
+            The body of the JSON response.
+        """
         token = login_submission["token"]
         auth_handler = self.auth_handler
         user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
             token
         )
 
-        result = await self._complete_login(user_id, login_submission)
-        return result
+        return await self._complete_login(
+            user_id, login_submission, self.auth_handler._sso_login_callback
+        )
 
     async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
         token = login_submission.get("token", None)

+ 105 - 55
tests/handlers/test_oidc.py

@@ -21,7 +21,6 @@ from mock import Mock, patch
 import attr
 import pymacaroons
 
-from twisted.internet import defer
 from twisted.python.failure import Failure
 from twisted.web._newclient import ResponseDone
 
@@ -87,6 +86,13 @@ class TestMappingProvider(OidcMappingProvider):
     async def map_user_attributes(self, userinfo, token):
         return {"localpart": userinfo["username"], "display_name": None}
 
+    # Do not include get_extra_attributes to test backwards compatibility paths.
+
+
+class TestMappingProviderExtra(TestMappingProvider):
+    async def get_extra_attributes(self, userinfo, token):
+        return {"phone": userinfo["phone"]}
+
 
 def simple_async_mock(return_value=None, raises=None):
     # AsyncMock is not available in python3.5, this mimics part of its behaviour
@@ -126,7 +132,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         config = self.default_config()
         config["public_baseurl"] = BASE_URL
-        oidc_config = config.get("oidc_config", {})
+        oidc_config = {}
         oidc_config["enabled"] = True
         oidc_config["client_id"] = CLIENT_ID
         oidc_config["client_secret"] = CLIENT_SECRET
@@ -135,6 +141,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
         oidc_config["user_mapping_provider"] = {
             "module": __name__ + ".TestMappingProvider",
         }
+
+        # Update this config with what's in the default config so that
+        # override_config works as expected.
+        oidc_config.update(config.get("oidc_config", {}))
         config["oidc_config"] = oidc_config
 
         hs = self.setup_test_homeserver(
@@ -165,11 +175,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
 
     @override_config({"oidc_config": {"discover": True}})
-    @defer.inlineCallbacks
     def test_discovery(self):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
-        metadata = yield defer.ensureDeferred(self.handler.load_metadata())
+        metadata = self.get_success(self.handler.load_metadata())
         self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
 
         self.assertEqual(metadata.issuer, ISSUER)
@@ -181,43 +190,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # subsequent calls should be cached
         self.http_client.reset_mock()
-        yield defer.ensureDeferred(self.handler.load_metadata())
+        self.get_success(self.handler.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
-    @defer.inlineCallbacks
     def test_no_discovery(self):
         """When discovery is disabled, it should not try to load from discovery document."""
-        yield defer.ensureDeferred(self.handler.load_metadata())
+        self.get_success(self.handler.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
-    @defer.inlineCallbacks
     def test_load_jwks(self):
         """JWKS loading is done once (then cached) if used."""
-        jwks = yield defer.ensureDeferred(self.handler.load_jwks())
+        jwks = self.get_success(self.handler.load_jwks())
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
         self.assertEqual(jwks, {"keys": []})
 
         # subsequent calls should be cached…
         self.http_client.reset_mock()
-        yield defer.ensureDeferred(self.handler.load_jwks())
+        self.get_success(self.handler.load_jwks())
         self.http_client.get_json.assert_not_called()
 
         # …unless forced
         self.http_client.reset_mock()
-        yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+        self.get_success(self.handler.load_jwks(force=True))
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
 
         # Throw if the JWKS uri is missing
         with self.metadata_edit({"jwks_uri": None}):
-            with self.assertRaises(RuntimeError):
-                yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+            self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
 
         # Return empty key set if JWKS are not used
         self.handler._scopes = []  # not asking the openid scope
         self.http_client.get_json.reset_mock()
-        jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+        jwks = self.get_success(self.handler.load_jwks(force=True))
         self.http_client.get_json.assert_not_called()
         self.assertEqual(jwks, {"keys": []})
 
@@ -299,11 +305,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
             # This should not throw
             self.handler._validate_metadata()
 
-    @defer.inlineCallbacks
     def test_redirect_request(self):
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["addCookie"])
-        url = yield defer.ensureDeferred(
+        url = self.get_success(
             self.handler.handle_redirect_request(req, b"http://client/redirect")
         )
         url = urlparse(url)
@@ -343,20 +348,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(params["nonce"], [nonce])
         self.assertEqual(redirect, "http://client/redirect")
 
-    @defer.inlineCallbacks
     def test_callback_error(self):
         """Errors from the provider returned in the callback are displayed."""
         self.handler._render_error = Mock()
         request = Mock(args={})
         request.args[b"error"] = [b"invalid_client"]
-        yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_client", "")
 
         request.args[b"error_description"] = [b"some description"]
-        yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_client", "some description")
 
-    @defer.inlineCallbacks
     def test_callback(self):
         """Code callback works and display errors if something went wrong.
 
@@ -377,7 +380,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "foo",
             "preferred_username": "bar",
         }
-        user_id = UserID("foo", "domain.org")
+        user_id = "@foo:domain.org"
         self.handler._render_error = Mock(return_value=None)
         self.handler._exchange_code = simple_async_mock(return_value=token)
         self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
@@ -394,13 +397,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
         client_redirect_url = "http://client/redirect"
         user_agent = "Browser"
         ip_address = "10.0.0.1"
-        session = self.handler._generate_oidc_session_token(
+        request.getCookie.return_value = self.handler._generate_oidc_session_token(
             state=state,
             nonce=nonce,
             client_redirect_url=client_redirect_url,
             ui_auth_session_id=None,
         )
-        request.getCookie.return_value = session
 
         request.args = {}
         request.args[b"code"] = [code.encode("utf-8")]
@@ -410,10 +412,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
         request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
         request.getClientIP.return_value = ip_address
 
-        yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url,
+            user_id, request, client_redirect_url, {},
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -427,13 +429,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.handler._map_userinfo_to_user = simple_async_mock(
             raises=MappingException()
         )
-        yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mapping_error")
         self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
 
         # Handle ID token errors
         self.handler._parse_id_token = simple_async_mock(raises=Exception())
-        yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
         self.handler._auth_handler.complete_sso_login.reset_mock()
@@ -444,10 +446,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # With userinfo fetching
         self.handler._scopes = []  # do not ask the "openid" scope
-        yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url,
+            user_id, request, client_redirect_url, {},
         )
         self.handler._exchange_code.assert_called_once_with(code)
         self.handler._parse_id_token.assert_not_called()
@@ -459,17 +461,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # Handle userinfo fetching error
         self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
-        yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
         # Handle code exchange failure
         self.handler._exchange_code = simple_async_mock(
             raises=OidcError("invalid_request")
         )
-        yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_request")
 
-    @defer.inlineCallbacks
     def test_callback_session(self):
         """The callback verifies the session presence and validity"""
         self.handler._render_error = Mock(return_value=None)
@@ -478,20 +479,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
         # Missing cookie
         request.args = {}
         request.getCookie.return_value = None
-        yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("missing_session", "No session cookie found")
 
         # Missing session parameter
         request.args = {}
         request.getCookie.return_value = "session"
-        yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_request", "State parameter is missing")
 
         # Invalid cookie
         request.args = {}
         request.args[b"state"] = [b"state"]
         request.getCookie.return_value = "session"
-        yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_session")
 
         # Mismatching session
@@ -504,18 +505,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
         request.args = {}
         request.args[b"state"] = [b"mismatching state"]
         request.getCookie.return_value = session
-        yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mismatching_session")
 
         # Valid session
         request.args = {}
         request.args[b"state"] = [b"state"]
         request.getCookie.return_value = session
-        yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_request")
 
     @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
-    @defer.inlineCallbacks
     def test_exchange_code(self):
         """Code exchange behaves correctly and handles various error scenarios."""
         token = {"type": "bearer"}
@@ -524,7 +524,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
         )
         code = "code"
-        ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
+        ret = self.get_success(self.handler._exchange_code(code))
         kwargs = self.http_client.request.call_args[1]
 
         self.assertEqual(ret, token)
@@ -546,10 +546,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 body=b'{"error": "foo", "error_description": "bar"}',
             )
         )
-        with self.assertRaises(OidcError) as exc:
-            yield defer.ensureDeferred(self.handler._exchange_code(code))
-        self.assertEqual(exc.exception.error, "foo")
-        self.assertEqual(exc.exception.error_description, "bar")
+        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        self.assertEqual(exc.value.error, "foo")
+        self.assertEqual(exc.value.error_description, "bar")
 
         # Internal server error with no JSON body
         self.http_client.request = simple_async_mock(
@@ -557,9 +556,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=500, phrase=b"Internal Server Error", body=b"Not JSON",
             )
         )
-        with self.assertRaises(OidcError) as exc:
-            yield defer.ensureDeferred(self.handler._exchange_code(code))
-        self.assertEqual(exc.exception.error, "server_error")
+        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        self.assertEqual(exc.value.error, "server_error")
 
         # Internal server error with JSON body
         self.http_client.request = simple_async_mock(
@@ -569,17 +567,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 body=b'{"error": "internal_server_error"}',
             )
         )
-        with self.assertRaises(OidcError) as exc:
-            yield defer.ensureDeferred(self.handler._exchange_code(code))
-        self.assertEqual(exc.exception.error, "internal_server_error")
+
+        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        self.assertEqual(exc.value.error, "internal_server_error")
 
         # 4xx error without "error" field
         self.http_client.request = simple_async_mock(
             return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
         )
-        with self.assertRaises(OidcError) as exc:
-            yield defer.ensureDeferred(self.handler._exchange_code(code))
-        self.assertEqual(exc.exception.error, "server_error")
+        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        self.assertEqual(exc.value.error, "server_error")
 
         # 2xx error with "error" field
         self.http_client.request = simple_async_mock(
@@ -587,9 +584,62 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=200, phrase=b"OK", body=b'{"error": "some_error"}',
             )
         )
-        with self.assertRaises(OidcError) as exc:
-            yield defer.ensureDeferred(self.handler._exchange_code(code))
-        self.assertEqual(exc.exception.error, "some_error")
+        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        self.assertEqual(exc.value.error, "some_error")
+
+    @override_config(
+        {
+            "oidc_config": {
+                "user_mapping_provider": {
+                    "module": __name__ + ".TestMappingProviderExtra"
+                }
+            }
+        }
+    )
+    def test_extra_attributes(self):
+        """
+        Login while using a mapping provider that implements get_extra_attributes.
+        """
+        token = {
+            "type": "bearer",
+            "id_token": "id_token",
+            "access_token": "access_token",
+        }
+        userinfo = {
+            "sub": "foo",
+            "phone": "1234567",
+        }
+        user_id = "@foo:domain.org"
+        self.handler._exchange_code = simple_async_mock(return_value=token)
+        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+        self.handler._auth_handler.complete_sso_login = simple_async_mock()
+        request = Mock(
+            spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+        )
+
+        state = "state"
+        client_redirect_url = "http://client/redirect"
+        request.getCookie.return_value = self.handler._generate_oidc_session_token(
+            state=state,
+            nonce="nonce",
+            client_redirect_url=client_redirect_url,
+            ui_auth_session_id=None,
+        )
+
+        request.args = {}
+        request.args[b"code"] = [b"code"]
+        request.args[b"state"] = [state.encode("utf-8")]
+
+        request.requestHeaders = Mock(spec=["getRawHeaders"])
+        request.requestHeaders.getRawHeaders.return_value = [b"Browser"]
+        request.getClientIP.return_value = "10.0.0.1"
+
+        self.get_success(self.handler.handle_oidc_callback(request))
+
+        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+            user_id, request, client_redirect_url, {"phone": "1234567"},
+        )
 
     def test_map_userinfo_to_user(self):
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""