Browse Source

Support trying multiple localparts for OpenID Connect. (#8801)

Abstracts the SAML and OpenID Connect code which attempts to regenerate
the localpart of a matrix ID if it is already in use.
Patrick Cloke 3 years ago
parent
commit
4fd222ad70

+ 1 - 0
changelog.d/8801.feature

@@ -0,0 +1 @@
+Add support for re-trying generation of a localpart for OpenID Connect mapping providers.

+ 10 - 1
docs/sso_mapping_providers.md

@@ -63,13 +63,22 @@ A custom mapping provider must specify the following methods:
                      information from.
     - This method must return a string, which is the unique identifier for the
       user. Commonly the ``sub`` claim of the response.
-* `map_user_attributes(self, userinfo, token)`
+* `map_user_attributes(self, userinfo, token, failures)`
     - This method must be async.
     - Arguments:
       - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
                      information from.
       - `token` - A dictionary which includes information necessary to make
                   further requests to the OpenID provider.
+      - `failures` - An `int` that represents the amount of times the returned
+                     mxid localpart mapping has failed.  This should be used
+                     to create a deduplicated mxid localpart which should be
+                     returned instead. For example, if this method returns
+                     `john.doe` as the value of `localpart` in the returned
+                     dict, and that is already taken on the homeserver, this
+                     method will be called again with the same parameters but
+                     with failures=1. The method should then return a different
+                     `localpart` value, such as `john.doe1`.
     - Returns a dictionary with two keys:
       - localpart: A required string, used to generate the Matrix ID.
       - displayname: An optional string, the display name for the user.

+ 50 - 70
synapse/handlers/oidc_handler.py

@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import logging
 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
 from urllib.parse import urlencode
@@ -35,15 +36,10 @@ from twisted.web.client import readBody
 
 from synapse.config import ConfigError
 from synapse.handlers._base import BaseHandler
-from synapse.handlers.sso import MappingException
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
-from synapse.types import (
-    JsonDict,
-    UserID,
-    contains_invalid_mxid_characters,
-    map_username_to_mxid_localpart,
-)
+from synapse.types import JsonDict, map_username_to_mxid_localpart
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -869,73 +865,51 @@ class OidcHandler(BaseHandler):
         # to be strings.
         remote_user_id = str(remote_user_id)
 
-        # first of all, check if we already have a mapping for this user
-        previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
-            self._auth_provider_id, remote_user_id,
+        # Older mapping providers don't accept the `failures` argument, so we
+        # try and detect support.
+        mapper_signature = inspect.signature(
+            self._user_mapping_provider.map_user_attributes
         )
-        if previously_registered_user_id:
-            return previously_registered_user_id
+        supports_failures = "failures" in mapper_signature.parameters
 
-        # Otherwise, generate a new user.
-        try:
-            attributes = await self._user_mapping_provider.map_user_attributes(
-                userinfo, token
-            )
-        except Exception as e:
-            raise MappingException(
-                "Could not extract user attributes from OIDC response: " + str(e)
-            )
+        async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+            """
+            Call the mapping provider to map the OIDC userinfo and token to user attributes.
 
-        logger.debug(
-            "Retrieved user attributes from user mapping provider: %r", attributes
-        )
+            This is backwards compatibility for abstraction for the SSO handler.
+            """
+            if supports_failures:
+                attributes = await self._user_mapping_provider.map_user_attributes(
+                    userinfo, token, failures
+                )
+            else:
+                # If the mapping provider does not support processing failures,
+                # do not continually generate the same Matrix ID since it will
+                # continue to already be in use. Note that the error raised is
+                # arbitrary and will get turned into a MappingException.
+                if failures:
+                    raise RuntimeError(
+                        "Mapping provider does not support de-duplicating Matrix IDs"
+                    )
 
-        localpart = attributes["localpart"]
-        if not localpart:
-            raise MappingException(
-                "Error parsing OIDC response: OIDC mapping provider plugin "
-                "did not return a localpart value"
-            )
+                attributes = await self._user_mapping_provider.map_user_attributes(  # type: ignore
+                    userinfo, token
+                )
 
-        user_id = UserID(localpart, self.server_name).to_string()
-        users = await self.store.get_users_by_id_case_insensitive(user_id)
-        if users:
-            if self._allow_existing_users:
-                if len(users) == 1:
-                    registered_user_id = next(iter(users))
-                elif user_id in users:
-                    registered_user_id = user_id
-                else:
-                    raise MappingException(
-                        "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
-                            user_id, list(users.keys())
-                        )
-                    )
-            else:
-                # This mxid is taken
-                raise MappingException("mxid '{}' is already taken".format(user_id))
-        else:
-            # Since the localpart is provided via a potentially untrusted module,
-            # ensure the MXID is valid before registering.
-            if contains_invalid_mxid_characters(localpart):
-                raise MappingException("localpart is invalid: %s" % (localpart,))
-
-            # It's the first time this user is logging in and the mapped mxid was
-            # not taken, register the user
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=attributes["display_name"],
-                user_agent_ips=[(user_agent, ip_address)],
-            )
+            return UserAttributes(**attributes)
 
-        await self.store.record_user_external_id(
-            self._auth_provider_id, remote_user_id, registered_user_id,
+        return await self._sso_handler.get_mxid_from_sso(
+            self._auth_provider_id,
+            remote_user_id,
+            user_agent,
+            ip_address,
+            oidc_response_to_user_attributes,
+            self._allow_existing_users,
         )
-        return registered_user_id
 
 
-UserAttribute = TypedDict(
-    "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+UserAttributeDict = TypedDict(
+    "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
 )
 C = TypeVar("C")
 
@@ -978,13 +952,15 @@ class OidcMappingProvider(Generic[C]):
         raise NotImplementedError()
 
     async def map_user_attributes(
-        self, userinfo: UserInfo, token: Token
-    ) -> UserAttribute:
+        self, userinfo: UserInfo, token: Token, failures: int
+    ) -> UserAttributeDict:
         """Map a `UserInfo` object into user attributes.
 
         Args:
             userinfo: An object representing the user given by the OIDC provider
             token: A dict with the tokens returned by the provider
+            failures: How many times a call to this function with this
+                UserInfo has resulted in a failure.
 
         Returns:
             A dict containing the ``localpart`` and (optionally) the ``display_name``
@@ -1084,13 +1060,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
         return userinfo[self._config.subject_claim]
 
     async def map_user_attributes(
-        self, userinfo: UserInfo, token: Token
-    ) -> UserAttribute:
+        self, userinfo: UserInfo, token: Token, failures: int
+    ) -> UserAttributeDict:
         localpart = self._config.localpart_template.render(user=userinfo).strip()
 
         # Ensure only valid characters are included in the MXID.
         localpart = map_username_to_mxid_localpart(localpart)
 
+        # Append suffix integer if last call to this function failed to produce
+        # a usable mxid.
+        localpart += str(failures) if failures else ""
+
         display_name = None  # type: Optional[str]
         if self._config.display_name_template is not None:
             display_name = self._config.display_name_template.render(
@@ -1100,7 +1080,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             if display_name == "":
                 display_name = None
 
-        return UserAttribute(localpart=localpart, display_name=display_name)
+        return UserAttributeDict(localpart=localpart, display_name=display_name)
 
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
         extras = {}  # type: Dict[str, str]

+ 28 - 63
synapse/handlers/saml_handler.py

@@ -25,13 +25,12 @@ from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.saml2_config import SamlAttributeRequirement
 from synapse.handlers._base import BaseHandler
-from synapse.handlers.sso import MappingException
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
 from synapse.types import (
     UserID,
-    contains_invalid_mxid_characters,
     map_username_to_mxid_localpart,
     mxid_localpart_allowed_characters,
 )
@@ -250,14 +249,26 @@ class SamlHandler(BaseHandler):
                 "Failed to extract remote user id from SAML response"
             )
 
-        with (await self._mapping_lock.queue(self._auth_provider_id)):
-            # first of all, check if we already have a mapping for this user
-            previously_registered_user_id = await self._sso_handler.get_sso_user_by_remote_user_id(
-                self._auth_provider_id, remote_user_id,
+        async def saml_response_to_remapped_user_attributes(
+            failures: int,
+        ) -> UserAttributes:
+            """
+            Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
+
+            This is backwards compatibility for abstraction for the SSO handler.
+            """
+            # Call the mapping provider.
+            result = self._user_mapping_provider.saml_response_to_user_attributes(
+                saml2_auth, failures, client_redirect_url
+            )
+            # Remap some of the results.
+            return UserAttributes(
+                localpart=result.get("mxid_localpart"),
+                display_name=result.get("displayname"),
+                emails=result.get("emails"),
             )
-            if previously_registered_user_id:
-                return previously_registered_user_id
 
+        with (await self._mapping_lock.queue(self._auth_provider_id)):
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
             if (
@@ -284,59 +295,13 @@ class SamlHandler(BaseHandler):
                     )
                     return registered_user_id
 
-            # Map saml response to user attributes using the configured mapping provider
-            for i in range(1000):
-                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
-                    saml2_auth, i, client_redirect_url=client_redirect_url,
-                )
-
-                logger.debug(
-                    "Retrieved SAML attributes from user mapping provider: %s "
-                    "(attempt %d)",
-                    attribute_dict,
-                    i,
-                )
-
-                localpart = attribute_dict.get("mxid_localpart")
-                if not localpart:
-                    raise MappingException(
-                        "Error parsing SAML2 response: SAML mapping provider plugin "
-                        "did not return a mxid_localpart value"
-                    )
-
-                displayname = attribute_dict.get("displayname")
-                emails = attribute_dict.get("emails", [])
-
-                # Check if this mxid already exists
-                if not await self.store.get_users_by_id_case_insensitive(
-                    UserID(localpart, self.server_name).to_string()
-                ):
-                    # This mxid is free
-                    break
-            else:
-                # Unable to generate a username in 1000 iterations
-                # Break and return error to the user
-                raise MappingException(
-                    "Unable to generate a Matrix ID from the SAML response"
-                )
-
-            # Since the localpart is provided via a potentially untrusted module,
-            # ensure the MXID is valid before registering.
-            if contains_invalid_mxid_characters(localpart):
-                raise MappingException("localpart is invalid: %s" % (localpart,))
-
-            logger.debug("Mapped SAML user to local part %s", localpart)
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=displayname,
-                bind_emails=emails,
-                user_agent_ips=[(user_agent, ip_address)],
-            )
-
-            await self.store.record_user_external_id(
-                self._auth_provider_id, remote_user_id, registered_user_id
+            return await self._sso_handler.get_mxid_from_sso(
+                self._auth_provider_id,
+                remote_user_id,
+                user_agent,
+                ip_address,
+                saml_response_to_remapped_user_attributes,
             )
-            return registered_user_id
 
     def expire_sessions(self):
         expire_before = self.clock.time_msec() - self._saml2_session_lifetime
@@ -451,11 +416,11 @@ class DefaultSamlMappingProvider:
             )
 
         # Use the configured mapper for this mxid_source
-        base_mxid_localpart = self._mxid_mapper(mxid_source)
+        localpart = self._mxid_mapper(mxid_source)
 
         # Append suffix integer if last call to this function failed to produce
-        # a usable mxid
-        localpart = base_mxid_localpart + (str(failures) if failures else "")
+        # a usable mxid.
+        localpart += str(failures) if failures else ""
 
         # Retrieve the display name from the saml response
         # If displayname is None, the mxid_localpart will be used instead

+ 154 - 1
synapse/handlers/sso.py

@@ -13,10 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+
+import attr
 
 from synapse.handlers._base import BaseHandler
 from synapse.http.server import respond_with_html
+from synapse.types import UserID, contains_invalid_mxid_characters
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -29,9 +32,20 @@ class MappingException(Exception):
     """
 
 
+@attr.s
+class UserAttributes:
+    localpart = attr.ib(type=str)
+    display_name = attr.ib(type=Optional[str], default=None)
+    emails = attr.ib(type=List[str], default=attr.Factory(list))
+
+
 class SsoHandler(BaseHandler):
+    # The number of attempts to ask the mapping provider for when generating an MXID.
+    _MAP_USERNAME_RETRIES = 1000
+
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
+        self._registration_handler = hs.get_registration_handler()
         self._error_template = hs.config.sso_error_template
 
     def render_error(
@@ -94,3 +108,142 @@ class SsoHandler(BaseHandler):
 
         # No match.
         return None
+
+    async def get_mxid_from_sso(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        user_agent: str,
+        ip_address: str,
+        sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+        allow_existing_users: bool = False,
+    ) -> str:
+        """
+        Given an SSO ID, retrieve the user ID for it and possibly register the user.
+
+        This first checks if the SSO ID has previously been linked to a matrix ID,
+        if it has that matrix ID is returned regardless of the current mapping
+        logic.
+
+        The mapping function is called (potentially multiple times) to generate
+        a localpart for the user.
+
+        If an unused localpart is generated, the user is registered from the
+        given user-agent and IP address and the SSO ID is linked to this matrix
+        ID for subsequent calls.
+
+        If allow_existing_users is true the mapping function is only called once
+        and results in:
+
+            1. The use of a previously registered matrix ID. In this case, the
+               SSO ID is linked to the matrix ID. (Note it is possible that
+               other SSO IDs are linked to the same matrix ID.)
+            2. An unused localpart, in which case the user is registered (as
+               discussed above).
+            3. An error if the generated localpart matches multiple pre-existing
+               matrix IDs. Generally this should not happen.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            remote_user_id: The unique identifier from the SSO provider.
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
+            sso_to_matrix_id_mapper: A callable to generate the user attributes.
+                The only parameter is an integer which represents the amount of
+                times the returned mxid localpart mapping has failed.
+            allow_existing_users: True if the localpart returned from the
+                mapping provider can be linked to an existing matrix ID.
+
+        Returns:
+             The user ID associated with the SSO response.
+
+        Raises:
+            MappingException if there was a problem mapping the response to a user.
+            RedirectException: some mapping providers may raise this if they need
+                to redirect to an interstitial page.
+
+        """
+        # first of all, check if we already have a mapping for this user
+        previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+            auth_provider_id, remote_user_id,
+        )
+        if previously_registered_user_id:
+            return previously_registered_user_id
+
+        # Otherwise, generate a new user.
+        for i in range(self._MAP_USERNAME_RETRIES):
+            try:
+                attributes = await sso_to_matrix_id_mapper(i)
+            except Exception as e:
+                raise MappingException(
+                    "Could not extract user attributes from SSO response: " + str(e)
+                )
+
+            logger.debug(
+                "Retrieved user attributes from user mapping provider: %r (attempt %d)",
+                attributes,
+                i,
+            )
+
+            if not attributes.localpart:
+                raise MappingException(
+                    "Error parsing SSO response: SSO mapping provider plugin "
+                    "did not return a localpart value"
+                )
+
+            # Check if this mxid already exists
+            user_id = UserID(attributes.localpart, self.server_name).to_string()
+            users = await self.store.get_users_by_id_case_insensitive(user_id)
+            # Note, if allow_existing_users is true then the loop is guaranteed
+            # to end on the first iteration: either by matching an existing user,
+            # raising an error, or registering a new user. See the docstring for
+            # more in-depth an explanation.
+            if users and allow_existing_users:
+                # If an existing matrix ID is returned, then use it.
+                if len(users) == 1:
+                    previously_registered_user_id = next(iter(users))
+                elif user_id in users:
+                    previously_registered_user_id = user_id
+                else:
+                    # Do not attempt to continue generating Matrix IDs.
+                    raise MappingException(
+                        "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+                            user_id, users
+                        )
+                    )
+
+                # Future logins should also match this user ID.
+                await self.store.record_user_external_id(
+                    auth_provider_id, remote_user_id, previously_registered_user_id
+                )
+
+                return previously_registered_user_id
+
+            elif not users:
+                # This mxid is free
+                break
+        else:
+            # Unable to generate a username in 1000 iterations
+            # Break and return error to the user
+            raise MappingException(
+                "Unable to generate a Matrix ID from the SSO response"
+            )
+
+        # Since the localpart is provided via a potentially untrusted module,
+        # ensure the MXID is valid before registering.
+        if contains_invalid_mxid_characters(attributes.localpart):
+            raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
+
+        logger.debug("Mapped SSO user to local part %s", attributes.localpart)
+        registered_user_id = await self._registration_handler.register_user(
+            localpart=attributes.localpart,
+            default_display_name=attributes.display_name,
+            bind_emails=attributes.emails,
+            user_agent_ips=[(user_agent, ip_address)],
+        )
+
+        await self.store.record_user_external_id(
+            auth_provider_id, remote_user_id, registered_user_id
+        )
+        return registered_user_id

+ 87 - 1
tests/handlers/test_oidc.py

@@ -89,6 +89,14 @@ class TestMappingProviderExtra(TestMappingProvider):
         return {"phone": userinfo["phone"]}
 
 
+class TestMappingProviderFailures(TestMappingProvider):
+    async def map_user_attributes(self, userinfo, token, failures):
+        return {
+            "localpart": userinfo["username"] + (str(failures) if failures else ""),
+            "display_name": None,
+        }
+
+
 def simple_async_mock(return_value=None, raises=None):
     # AsyncMock is not available in python3.5, this mimics part of its behaviour
     async def cb(*args, **kwargs):
@@ -152,6 +160,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.render_error = Mock(return_value=None)
         self.handler._sso_handler.render_error = self.render_error
 
+        # Reduce the number of attempts when generating MXIDs.
+        self.handler._sso_handler._MAP_USERNAME_RETRIES = 3
+
         return hs
 
     def metadata_edit(self, values):
@@ -693,7 +704,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
             ),
             MappingException,
         )
-        self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
+        self.assertEqual(
+            str(e.value),
+            "Could not extract user attributes from SSO response: Mapping provider does not support de-duplicating Matrix IDs",
+        )
 
     @override_config({"oidc_config": {"allow_existing_users": True}})
     def test_map_userinfo_to_existing_user(self):
@@ -703,6 +717,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(
             store.register_user(user_id=user.to_string(), password_hash=None)
         )
+
+        # Map a user via SSO.
         userinfo = {
             "sub": "test",
             "username": "test_user",
@@ -715,6 +731,23 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         self.assertEqual(mxid, "@test_user:test")
 
+        # Note that a second SSO user can be mapped to the same Matrix ID. (This
+        # requires a unique sub, but something that maps to the same matrix ID,
+        # in this case we'll just use the same username. A more realistic example
+        # would be subs which are email addresses, and mapping from the localpart
+        # of the email, e.g. bob@foo.com and bob@bar.com -> @bob:test.)
+        userinfo = {
+            "sub": "test1",
+            "username": "test_user",
+        }
+        token = {}
+        mxid = self.get_success(
+            self.handler._map_userinfo_to_user(
+                userinfo, token, "user-agent", "10.10.10.10"
+            )
+        )
+        self.assertEqual(mxid, "@test_user:test")
+
         # Register some non-exact matching cases.
         user2 = UserID.from_string("@TEST_user_2:test")
         self.get_success(
@@ -762,6 +795,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "föö",
         }
         token = {}
+
         e = self.get_failure(
             self.handler._map_userinfo_to_user(
                 userinfo, token, "user-agent", "10.10.10.10"
@@ -769,3 +803,55 @@ class OidcHandlerTestCase(HomeserverTestCase):
             MappingException,
         )
         self.assertEqual(str(e.value), "localpart is invalid: föö")
+
+    @override_config(
+        {
+            "oidc_config": {
+                "user_mapping_provider": {
+                    "module": __name__ + ".TestMappingProviderFailures"
+                }
+            }
+        }
+    )
+    def test_map_userinfo_to_user_retries(self):
+        """The mapping provider can retry generating an MXID if the MXID is already in use."""
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.register_user(user_id="@test_user:test", password_hash=None)
+        )
+        userinfo = {
+            "sub": "test",
+            "username": "test_user",
+        }
+        token = {}
+        mxid = self.get_success(
+            self.handler._map_userinfo_to_user(
+                userinfo, token, "user-agent", "10.10.10.10"
+            )
+        )
+        # test_user is already taken, so test_user1 gets registered instead.
+        self.assertEqual(mxid, "@test_user1:test")
+
+        # Register all of the potential users for a particular username.
+        self.get_success(
+            store.register_user(user_id="@tester:test", password_hash=None)
+        )
+        for i in range(1, 3):
+            self.get_success(
+                store.register_user(user_id="@tester%d:test" % i, password_hash=None)
+            )
+
+        # Now attempt to map to a username, this will fail since all potential usernames are taken.
+        userinfo = {
+            "sub": "tester",
+            "username": "tester",
+        }
+        e = self.get_failure(
+            self.handler._map_userinfo_to_user(
+                userinfo, token, "user-agent", "10.10.10.10"
+            ),
+            MappingException,
+        )
+        self.assertEqual(
+            str(e.value), "Unable to generate a Matrix ID from the SSO response"
+        )