|
@@ -21,12 +21,13 @@ import attr
|
|
|
from typing_extensions import NoReturn, Protocol
|
|
|
|
|
|
from twisted.web.http import Request
|
|
|
+from twisted.web.iweb import IRequest
|
|
|
|
|
|
from synapse.api.constants import LoginType
|
|
|
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
|
|
|
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
|
|
|
from synapse.http import get_request_user_agent
|
|
|
-from synapse.http.server import respond_with_html
|
|
|
+from synapse.http.server import respond_with_html, respond_with_redirect
|
|
|
from synapse.http.site import SynapseRequest
|
|
|
from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
|
|
|
from synapse.util.async_helpers import Linearizer
|
|
@@ -141,6 +142,9 @@ class UsernameMappingSession:
|
|
|
# expiry time for the session, in milliseconds
|
|
|
expiry_time_ms = attr.ib(type=int)
|
|
|
|
|
|
+ # choices made by the user
|
|
|
+ chosen_localpart = attr.ib(type=Optional[str], default=None)
|
|
|
+
|
|
|
|
|
|
# the HTTP cookie used to track the mapping session id
|
|
|
USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session"
|
|
@@ -387,6 +391,8 @@ class SsoHandler:
|
|
|
to an additional page. (e.g. to prompt for more information)
|
|
|
|
|
|
"""
|
|
|
+ new_user = False
|
|
|
+
|
|
|
# grab a lock while we try to find a mapping for this user. This seems...
|
|
|
# optimistic, especially for implementations that end up redirecting to
|
|
|
# interstitial pages.
|
|
@@ -427,9 +433,14 @@ class SsoHandler:
|
|
|
get_request_user_agent(request),
|
|
|
request.getClientIP(),
|
|
|
)
|
|
|
+ new_user = True
|
|
|
|
|
|
await self._auth_handler.complete_sso_login(
|
|
|
- user_id, request, client_redirect_url, extra_login_attributes
|
|
|
+ user_id,
|
|
|
+ request,
|
|
|
+ client_redirect_url,
|
|
|
+ extra_login_attributes,
|
|
|
+ new_user=new_user,
|
|
|
)
|
|
|
|
|
|
async def _call_attribute_mapper(
|
|
@@ -519,7 +530,7 @@ class SsoHandler:
|
|
|
logger.info("Recorded registration session id %s", session_id)
|
|
|
|
|
|
# Set the cookie and redirect to the username picker
|
|
|
- e = RedirectException(b"/_synapse/client/pick_username")
|
|
|
+ e = RedirectException(b"/_synapse/client/pick_username/account_details")
|
|
|
e.cookies.append(
|
|
|
b"%s=%s; path=/"
|
|
|
% (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii"))
|
|
@@ -647,6 +658,25 @@ class SsoHandler:
|
|
|
)
|
|
|
respond_with_html(request, 200, html)
|
|
|
|
|
|
+ def get_mapping_session(self, session_id: str) -> UsernameMappingSession:
|
|
|
+ """Look up the given username mapping session
|
|
|
+
|
|
|
+ If it is not found, raises a SynapseError with an http code of 400
|
|
|
+
|
|
|
+ Args:
|
|
|
+ session_id: session to look up
|
|
|
+ Returns:
|
|
|
+ active mapping session
|
|
|
+ Raises:
|
|
|
+ SynapseError if the session is not found/has expired
|
|
|
+ """
|
|
|
+ self._expire_old_sessions()
|
|
|
+ session = self._username_mapping_sessions.get(session_id)
|
|
|
+ if session:
|
|
|
+ return session
|
|
|
+ logger.info("Couldn't find session id %s", session_id)
|
|
|
+ raise SynapseError(400, "unknown session")
|
|
|
+
|
|
|
async def check_username_availability(
|
|
|
self, localpart: str, session_id: str,
|
|
|
) -> bool:
|
|
@@ -663,12 +693,7 @@ class SsoHandler:
|
|
|
|
|
|
# make sure that there is a valid mapping session, to stop people dictionary-
|
|
|
# scanning for accounts
|
|
|
-
|
|
|
- self._expire_old_sessions()
|
|
|
- session = self._username_mapping_sessions.get(session_id)
|
|
|
- if not session:
|
|
|
- logger.info("Couldn't find session id %s", session_id)
|
|
|
- raise SynapseError(400, "unknown session")
|
|
|
+ self.get_mapping_session(session_id)
|
|
|
|
|
|
logger.info(
|
|
|
"[session %s] Checking for availability of username %s",
|
|
@@ -696,16 +721,33 @@ class SsoHandler:
|
|
|
localpart: localpart requested by the user
|
|
|
session_id: ID of the username mapping session, extracted from a cookie
|
|
|
"""
|
|
|
- self._expire_old_sessions()
|
|
|
- session = self._username_mapping_sessions.get(session_id)
|
|
|
- if not session:
|
|
|
- logger.info("Couldn't find session id %s", session_id)
|
|
|
- raise SynapseError(400, "unknown session")
|
|
|
+ session = self.get_mapping_session(session_id)
|
|
|
+
|
|
|
+ # update the session with the user's choices
|
|
|
+ session.chosen_localpart = localpart
|
|
|
+
|
|
|
+ # we're done; now we can register the user
|
|
|
+ respond_with_redirect(request, b"/_synapse/client/sso_register")
|
|
|
|
|
|
- logger.info("[session %s] Registering localpart %s", session_id, localpart)
|
|
|
+ async def register_sso_user(self, request: Request, session_id: str) -> None:
|
|
|
+ """Called once we have all the info we need to register a new user.
|
|
|
+
|
|
|
+ Does so and serves an HTTP response
|
|
|
+
|
|
|
+ Args:
|
|
|
+ request: HTTP request
|
|
|
+ session_id: ID of the username mapping session, extracted from a cookie
|
|
|
+ """
|
|
|
+ session = self.get_mapping_session(session_id)
|
|
|
+
|
|
|
+ logger.info(
|
|
|
+ "[session %s] Registering localpart %s",
|
|
|
+ session_id,
|
|
|
+ session.chosen_localpart,
|
|
|
+ )
|
|
|
|
|
|
attributes = UserAttributes(
|
|
|
- localpart=localpart,
|
|
|
+ localpart=session.chosen_localpart,
|
|
|
display_name=session.display_name,
|
|
|
emails=session.emails,
|
|
|
)
|
|
@@ -720,7 +762,12 @@ class SsoHandler:
|
|
|
request.getClientIP(),
|
|
|
)
|
|
|
|
|
|
- logger.info("[session %s] Registered userid %s", session_id, user_id)
|
|
|
+ logger.info(
|
|
|
+ "[session %s] Registered userid %s with attributes %s",
|
|
|
+ session_id,
|
|
|
+ user_id,
|
|
|
+ attributes,
|
|
|
+ )
|
|
|
|
|
|
# delete the mapping session and the cookie
|
|
|
del self._username_mapping_sessions[session_id]
|
|
@@ -738,6 +785,7 @@ class SsoHandler:
|
|
|
request,
|
|
|
session.client_redirect_url,
|
|
|
session.extra_login_attributes,
|
|
|
+ new_user=True,
|
|
|
)
|
|
|
|
|
|
def _expire_old_sessions(self):
|
|
@@ -751,3 +799,14 @@ class SsoHandler:
|
|
|
for session_id in to_expire:
|
|
|
logger.info("Expiring mapping session %s", session_id)
|
|
|
del self._username_mapping_sessions[session_id]
|
|
|
+
|
|
|
+
|
|
|
+def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
|
|
|
+ """Extract the session ID from the cookie
|
|
|
+
|
|
|
+ Raises a SynapseError if the cookie isn't found
|
|
|
+ """
|
|
|
+ session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
|
|
|
+ if not session_id:
|
|
|
+ raise SynapseError(code=400, msg="missing session_id")
|
|
|
+ return session_id.decode("ascii", errors="replace")
|