|
@@ -23,8 +23,12 @@ from twisted.web.client import PartialDownloadError
|
|
|
|
|
|
from synapse.api.errors import Codes, LoginError, SynapseError
|
|
|
from synapse.http.server import finish_request
|
|
|
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
|
|
-from synapse.types import UserID
|
|
|
+from synapse.http.servlet import (
|
|
|
+ RestServlet,
|
|
|
+ parse_json_object_from_request,
|
|
|
+ parse_string,
|
|
|
+)
|
|
|
+from synapse.types import UserID, map_username_to_mxid_localpart
|
|
|
from synapse.util.msisdn import phone_number_to_msisdn
|
|
|
|
|
|
from .base import ClientV1RestServlet, client_path_patterns
|
|
@@ -358,17 +362,15 @@ class CasTicketServlet(ClientV1RestServlet):
|
|
|
self.cas_server_url = hs.config.cas_server_url
|
|
|
self.cas_service_url = hs.config.cas_service_url
|
|
|
self.cas_required_attributes = hs.config.cas_required_attributes
|
|
|
- self.auth_handler = hs.get_auth_handler()
|
|
|
- self.handlers = hs.get_handlers()
|
|
|
- self.macaroon_gen = hs.get_macaroon_generator()
|
|
|
+ self._sso_auth_handler = SSOAuthHandler(hs)
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
def on_GET(self, request):
|
|
|
- client_redirect_url = request.args[b"redirectUrl"][0]
|
|
|
+ client_redirect_url = parse_string(request, "redirectUrl", required=True)
|
|
|
http_client = self.hs.get_simple_http_client()
|
|
|
uri = self.cas_server_url + "/proxyValidate"
|
|
|
args = {
|
|
|
- "ticket": request.args[b"ticket"][0].decode('ascii'),
|
|
|
+ "ticket": parse_string(request, "ticket", required=True),
|
|
|
"service": self.cas_service_url
|
|
|
}
|
|
|
try:
|
|
@@ -380,7 +382,6 @@ class CasTicketServlet(ClientV1RestServlet):
|
|
|
result = yield self.handle_cas_response(request, body, client_redirect_url)
|
|
|
defer.returnValue(result)
|
|
|
|
|
|
- @defer.inlineCallbacks
|
|
|
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
|
|
|
user, attributes = self.parse_cas_response(cas_response_body)
|
|
|
|
|
@@ -396,28 +397,9 @@ class CasTicketServlet(ClientV1RestServlet):
|
|
|
if required_value != actual_value:
|
|
|
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
|
|
|
|
|
- user_id = UserID(user, self.hs.hostname).to_string()
|
|
|
- auth_handler = self.auth_handler
|
|
|
- registered_user_id = yield auth_handler.check_user_exists(user_id)
|
|
|
- if not registered_user_id:
|
|
|
- registered_user_id, _ = (
|
|
|
- yield self.handlers.registration_handler.register(localpart=user)
|
|
|
- )
|
|
|
-
|
|
|
- login_token = self.macaroon_gen.generate_short_term_login_token(
|
|
|
- registered_user_id
|
|
|
+ return self._sso_auth_handler.on_successful_auth(
|
|
|
+ user, request, client_redirect_url,
|
|
|
)
|
|
|
- redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
|
|
|
- login_token)
|
|
|
- request.redirect(redirect_url)
|
|
|
- finish_request(request)
|
|
|
-
|
|
|
- def add_login_token_to_redirect_url(self, url, token):
|
|
|
- url_parts = list(urllib.parse.urlparse(url))
|
|
|
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
|
|
|
- query.update({"loginToken": token})
|
|
|
- url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
|
|
|
- return urllib.parse.urlunparse(url_parts)
|
|
|
|
|
|
def parse_cas_response(self, cas_response_body):
|
|
|
user = None
|
|
@@ -452,6 +434,71 @@ class CasTicketServlet(ClientV1RestServlet):
|
|
|
return user, attributes
|
|
|
|
|
|
|
|
|
+class SSOAuthHandler(object):
|
|
|
+ """
|
|
|
+ Utility class for Resources and Servlets which handle the response from a SSO
|
|
|
+ service
|
|
|
+
|
|
|
+ Args:
|
|
|
+ hs (synapse.server.HomeServer)
|
|
|
+ """
|
|
|
+ def __init__(self, hs):
|
|
|
+ self._hostname = hs.hostname
|
|
|
+ self._auth_handler = hs.get_auth_handler()
|
|
|
+ self._registration_handler = hs.get_handlers().registration_handler
|
|
|
+ self._macaroon_gen = hs.get_macaroon_generator()
|
|
|
+
|
|
|
+ @defer.inlineCallbacks
|
|
|
+ def on_successful_auth(
|
|
|
+ self, username, request, client_redirect_url,
|
|
|
+ ):
|
|
|
+ """Called once the user has successfully authenticated with the SSO.
|
|
|
+
|
|
|
+ Registers the user if necessary, and then returns a redirect (with
|
|
|
+ a login token) to the client.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ username (unicode|bytes): the remote user id. We'll map this onto
|
|
|
+ something sane for a MXID localpath.
|
|
|
+
|
|
|
+ request (SynapseRequest): the incoming request from the browser. We'll
|
|
|
+ respond to it with a redirect.
|
|
|
+
|
|
|
+ client_redirect_url (unicode): the redirect_url the client gave us when
|
|
|
+ it first started the process.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Deferred[none]: Completes once we have handled the request.
|
|
|
+ """
|
|
|
+ localpart = map_username_to_mxid_localpart(username)
|
|
|
+ user_id = UserID(localpart, self._hostname).to_string()
|
|
|
+ registered_user_id = yield self._auth_handler.check_user_exists(user_id)
|
|
|
+ if not registered_user_id:
|
|
|
+ registered_user_id, _ = (
|
|
|
+ yield self._registration_handler.register(
|
|
|
+ localpart=localpart,
|
|
|
+ generate_token=False,
|
|
|
+ )
|
|
|
+ )
|
|
|
+
|
|
|
+ login_token = self._macaroon_gen.generate_short_term_login_token(
|
|
|
+ registered_user_id
|
|
|
+ )
|
|
|
+ redirect_url = self._add_login_token_to_redirect_url(
|
|
|
+ client_redirect_url, login_token
|
|
|
+ )
|
|
|
+ request.redirect(redirect_url)
|
|
|
+ finish_request(request)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def _add_login_token_to_redirect_url(url, token):
|
|
|
+ url_parts = list(urllib.parse.urlparse(url))
|
|
|
+ query = dict(urllib.parse.parse_qsl(url_parts[4]))
|
|
|
+ query.update({"loginToken": token})
|
|
|
+ url_parts[4] = urllib.parse.urlencode(query)
|
|
|
+ return urllib.parse.urlunparse(url_parts)
|
|
|
+
|
|
|
+
|
|
|
def register_servlets(hs, http_server):
|
|
|
LoginRestServlet(hs).register(http_server)
|
|
|
if hs.config.cas_enabled:
|