|
@@ -13,17 +13,19 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
import logging
|
|
|
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
|
|
|
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
|
|
|
|
|
|
import attr
|
|
|
+from typing_extensions import NoReturn
|
|
|
|
|
|
from twisted.web.http import Request
|
|
|
|
|
|
-from synapse.api.errors import RedirectException
|
|
|
+from synapse.api.errors import RedirectException, SynapseError
|
|
|
from synapse.http.server import respond_with_html
|
|
|
from synapse.http.site import SynapseRequest
|
|
|
from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
|
|
|
from synapse.util.async_helpers import Linearizer
|
|
|
+from synapse.util.stringutils import random_string
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
from synapse.server import HomeServer
|
|
@@ -40,16 +42,52 @@ class MappingException(Exception):
|
|
|
|
|
|
@attr.s
|
|
|
class UserAttributes:
|
|
|
- localpart = attr.ib(type=str)
|
|
|
+ # the localpart of the mxid that the mapper has assigned to the user.
|
|
|
+ # if `None`, the mapper has not picked a userid, and the user should be prompted to
|
|
|
+ # enter one.
|
|
|
+ localpart = attr.ib(type=Optional[str])
|
|
|
display_name = attr.ib(type=Optional[str], default=None)
|
|
|
emails = attr.ib(type=List[str], default=attr.Factory(list))
|
|
|
|
|
|
|
|
|
+@attr.s(slots=True)
|
|
|
+class UsernameMappingSession:
|
|
|
+ """Data we track about SSO sessions"""
|
|
|
+
|
|
|
+ # A unique identifier for this SSO provider, e.g. "oidc" or "saml".
|
|
|
+ auth_provider_id = attr.ib(type=str)
|
|
|
+
|
|
|
+ # user ID on the IdP server
|
|
|
+ remote_user_id = attr.ib(type=str)
|
|
|
+
|
|
|
+ # attributes returned by the ID mapper
|
|
|
+ display_name = attr.ib(type=Optional[str])
|
|
|
+ emails = attr.ib(type=List[str])
|
|
|
+
|
|
|
+ # An optional dictionary of extra attributes to be provided to the client in the
|
|
|
+ # login response.
|
|
|
+ extra_login_attributes = attr.ib(type=Optional[JsonDict])
|
|
|
+
|
|
|
+ # where to redirect the client back to
|
|
|
+ client_redirect_url = attr.ib(type=str)
|
|
|
+
|
|
|
+ # expiry time for the session, in milliseconds
|
|
|
+ expiry_time_ms = attr.ib(type=int)
|
|
|
+
|
|
|
+
|
|
|
+# the HTTP cookie used to track the mapping session id
|
|
|
+USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
|
|
|
+
|
|
|
+
|
|
|
class SsoHandler:
|
|
|
# The number of attempts to ask the mapping provider for when generating an MXID.
|
|
|
_MAP_USERNAME_RETRIES = 1000
|
|
|
|
|
|
+ # the time a UsernameMappingSession remains valid for
|
|
|
+ _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000
|
|
|
+
|
|
|
def __init__(self, hs: "HomeServer"):
|
|
|
+ self._clock = hs.get_clock()
|
|
|
self._store = hs.get_datastore()
|
|
|
self._server_name = hs.hostname
|
|
|
self._registration_handler = hs.get_registration_handler()
|
|
@@ -59,6 +97,9 @@ class SsoHandler:
|
|
|
# a lock on the mappings
|
|
|
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
|
|
|
|
|
|
+ # a map from session id to session data
|
|
|
+ self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession]
|
|
|
+
|
|
|
def render_error(
|
|
|
self, request, error: str, error_description: Optional[str] = None
|
|
|
) -> None:
|
|
@@ -206,6 +247,18 @@ class SsoHandler:
|
|
|
# Otherwise, generate a new user.
|
|
|
if not user_id:
|
|
|
attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
|
|
|
+
|
|
|
+ if attributes.localpart is None:
|
|
|
+ # the mapper doesn't return a username. bail out with a redirect to
|
|
|
+ # the username picker.
|
|
|
+ await self._redirect_to_username_picker(
|
|
|
+ auth_provider_id,
|
|
|
+ remote_user_id,
|
|
|
+ attributes,
|
|
|
+ client_redirect_url,
|
|
|
+ extra_login_attributes,
|
|
|
+ )
|
|
|
+
|
|
|
user_id = await self._register_mapped_user(
|
|
|
attributes,
|
|
|
auth_provider_id,
|
|
@@ -243,10 +296,8 @@ class SsoHandler:
|
|
|
)
|
|
|
|
|
|
if not attributes.localpart:
|
|
|
- raise MappingException(
|
|
|
- "Error parsing SSO response: SSO mapping provider plugin "
|
|
|
- "did not return a localpart value"
|
|
|
- )
|
|
|
+ # the mapper has not picked a localpart
|
|
|
+ return attributes
|
|
|
|
|
|
# Check if this mxid already exists
|
|
|
user_id = UserID(attributes.localpart, self._server_name).to_string()
|
|
@@ -261,6 +312,59 @@ class SsoHandler:
|
|
|
)
|
|
|
return attributes
|
|
|
|
|
|
+ async def _redirect_to_username_picker(
|
|
|
+ self,
|
|
|
+ auth_provider_id: str,
|
|
|
+ remote_user_id: str,
|
|
|
+ attributes: UserAttributes,
|
|
|
+ client_redirect_url: str,
|
|
|
+ extra_login_attributes: Optional[JsonDict],
|
|
|
+ ) -> NoReturn:
|
|
|
+ """Creates a UsernameMappingSession and redirects the browser
|
|
|
+
|
|
|
+ Called if the user mapping provider doesn't return a localpart for a new user.
|
|
|
+ Raises a RedirectException which redirects the browser to the username picker.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
|
|
|
+ "oidc" or "saml".
|
|
|
+
|
|
|
+ remote_user_id: The unique identifier from the SSO provider.
|
|
|
+
|
|
|
+ attributes: the user attributes returned by the user mapping provider.
|
|
|
+
|
|
|
+ client_redirect_url: The redirect URL passed in by the client, which we
|
|
|
+ will eventually redirect back to.
|
|
|
+
|
|
|
+ extra_login_attributes: An optional dictionary of extra
|
|
|
+ attributes to be provided to the client in the login response.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ RedirectException
|
|
|
+ """
|
|
|
+ session_id = random_string(16)
|
|
|
+ now = self._clock.time_msec()
|
|
|
+ session = UsernameMappingSession(
|
|
|
+ auth_provider_id=auth_provider_id,
|
|
|
+ remote_user_id=remote_user_id,
|
|
|
+ display_name=attributes.display_name,
|
|
|
+ emails=attributes.emails,
|
|
|
+ client_redirect_url=client_redirect_url,
|
|
|
+ expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS,
|
|
|
+ extra_login_attributes=extra_login_attributes,
|
|
|
+ )
|
|
|
+
|
|
|
+ self._username_mapping_sessions[session_id] = session
|
|
|
+ logger.info("Recorded registration session id %s", session_id)
|
|
|
+
|
|
|
+ # Set the cookie and redirect to the username picker
|
|
|
+ e = RedirectException(b"/_synapse/client/pick_username")
|
|
|
+ e.cookies.append(
|
|
|
+ b"%s=%s; path=/"
|
|
|
+ % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
|
|
|
+ )
|
|
|
+ raise e
|
|
|
+
|
|
|
async def _register_mapped_user(
|
|
|
self,
|
|
|
attributes: UserAttributes,
|
|
@@ -269,9 +373,38 @@ class SsoHandler:
|
|
|
user_agent: str,
|
|
|
ip_address: str,
|
|
|
) -> str:
|
|
|
+ """Register a new SSO user.
|
|
|
+
|
|
|
+ This is called once we have successfully mapped the remote user id onto a local
|
|
|
+ user id, one way or another.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ attributes: user attributes returned by the user mapping provider,
|
|
|
+ including a non-empty localpart.
|
|
|
+
|
|
|
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
|
|
|
+ "oidc" or "saml".
|
|
|
+
|
|
|
+ remote_user_id: The unique identifier from the SSO provider.
|
|
|
+
|
|
|
+ user_agent: The user-agent in the HTTP request (used for potential
|
|
|
+ shadow-banning.)
|
|
|
+
|
|
|
+ ip_address: The IP address of the requester (used for potential
|
|
|
+ shadow-banning.)
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ a MappingException if the localpart is invalid.
|
|
|
+
|
|
|
+ a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart
|
|
|
+ is already taken.
|
|
|
+ """
|
|
|
+
|
|
|
# Since the localpart is provided via a potentially untrusted module,
|
|
|
# ensure the MXID is valid before registering.
|
|
|
- if contains_invalid_mxid_characters(attributes.localpart):
|
|
|
+ if not attributes.localpart or contains_invalid_mxid_characters(
|
|
|
+ attributes.localpart
|
|
|
+ ):
|
|
|
raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
|
|
|
|
|
|
logger.debug("Mapped SSO user to local part %s", attributes.localpart)
|
|
@@ -326,3 +459,108 @@ class SsoHandler:
|
|
|
await self._auth_handler.complete_sso_ui_auth(
|
|
|
user_id, ui_auth_session_id, request
|
|
|
)
|
|
|
+
|
|
|
+ async def check_username_availability(
|
|
|
+ self, localpart: str, session_id: str,
|
|
|
+ ) -> bool:
|
|
|
+ """Handle an "is username available" callback check
|
|
|
+
|
|
|
+ Args:
|
|
|
+ localpart: desired localpart
|
|
|
+ session_id: the session id for the username picker
|
|
|
+ Returns:
|
|
|
+ True if the username is available
|
|
|
+ Raises:
|
|
|
+ SynapseError if the localpart is invalid or the session is unknown
|
|
|
+ """
|
|
|
+
|
|
|
+ # make sure that there is a valid mapping session, to stop people dictionary-
|
|
|
+ # scanning for accounts
|
|
|
+
|
|
|
+ self._expire_old_sessions()
|
|
|
+ session = self._username_mapping_sessions.get(session_id)
|
|
|
+ if not session:
|
|
|
+ logger.info("Couldn't find session id %s", session_id)
|
|
|
+ raise SynapseError(400, "unknown session")
|
|
|
+
|
|
|
+ logger.info(
|
|
|
+ "[session %s] Checking for availability of username %s",
|
|
|
+ session_id,
|
|
|
+ localpart,
|
|
|
+ )
|
|
|
+
|
|
|
+ if contains_invalid_mxid_characters(localpart):
|
|
|
+ raise SynapseError(400, "localpart is invalid: %s" % (localpart,))
|
|
|
+ user_id = UserID(localpart, self._server_name).to_string()
|
|
|
+ user_infos = await self._store.get_users_by_id_case_insensitive(user_id)
|
|
|
+
|
|
|
+ logger.info("[session %s] users: %s", session_id, user_infos)
|
|
|
+ return not user_infos
|
|
|
+
|
|
|
+ async def handle_submit_username_request(
|
|
|
+ self, request: SynapseRequest, localpart: str, session_id: str
|
|
|
+ ) -> None:
|
|
|
+ """Handle a request to the username-picker 'submit' endpoint
|
|
|
+
|
|
|
+ Will serve an HTTP response to the request.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ request: HTTP request
|
|
|
+ localpart: localpart requested by the user
|
|
|
+ session_id: ID of the username mapping session, extracted from a cookie
|
|
|
+ """
|
|
|
+ self._expire_old_sessions()
|
|
|
+ session = self._username_mapping_sessions.get(session_id)
|
|
|
+ if not session:
|
|
|
+ logger.info("Couldn't find session id %s", session_id)
|
|
|
+ raise SynapseError(400, "unknown session")
|
|
|
+
|
|
|
+ logger.info("[session %s] Registering localpart %s", session_id, localpart)
|
|
|
+
|
|
|
+ attributes = UserAttributes(
|
|
|
+ localpart=localpart,
|
|
|
+ display_name=session.display_name,
|
|
|
+ emails=session.emails,
|
|
|
+ )
|
|
|
+
|
|
|
+ # the following will raise a 400 error if the username has been taken in the
|
|
|
+ # meantime.
|
|
|
+ user_id = await self._register_mapped_user(
|
|
|
+ attributes,
|
|
|
+ session.auth_provider_id,
|
|
|
+ session.remote_user_id,
|
|
|
+ request.get_user_agent(""),
|
|
|
+ request.getClientIP(),
|
|
|
+ )
|
|
|
+
|
|
|
+ logger.info("[session %s] Registered userid %s", session_id, user_id)
|
|
|
+
|
|
|
+ # delete the mapping session and the cookie
|
|
|
+ del self._username_mapping_sessions[session_id]
|
|
|
+
|
|
|
+ # delete the cookie
|
|
|
+ request.addCookie(
|
|
|
+ USERNAME_MAPPING_SESSION_COOKIE_NAME,
|
|
|
+ b"",
|
|
|
+ expires=b"Thu, 01 Jan 1970 00:00:00 GMT",
|
|
|
+ path=b"/",
|
|
|
+ )
|
|
|
+
|
|
|
+ await self._auth_handler.complete_sso_login(
|
|
|
+ user_id,
|
|
|
+ request,
|
|
|
+ session.client_redirect_url,
|
|
|
+ session.extra_login_attributes,
|
|
|
+ )
|
|
|
+
|
|
|
+ def _expire_old_sessions(self):
|
|
|
+ to_expire = []
|
|
|
+ now = int(self._clock.time_msec())
|
|
|
+
|
|
|
+ for session_id, session in self._username_mapping_sessions.items():
|
|
|
+ if session.expiry_time_ms <= now:
|
|
|
+ to_expire.append(session_id)
|
|
|
+
|
|
|
+ for session_id in to_expire:
|
|
|
+ logger.info("Expiring mapping session %s", session_id)
|
|
|
+ del self._username_mapping_sessions[session_id]
|