123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471 |
- # -*- coding: utf-8 -*-
- # Copyright 2014-2016 OpenMarket Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from twisted.internet import defer
- from synapse.api.errors import SynapseError, LoginError, Codes
- from synapse.types import UserID
- from synapse.http.server import finish_request
- from synapse.http.servlet import parse_json_object_from_request
- from .base import ClientV1RestServlet, client_path_patterns
- import simplejson as json
- import urllib
- import urlparse
- import logging
- from saml2 import BINDING_HTTP_POST
- from saml2 import config
- from saml2.client import Saml2Client
- import xml.etree.ElementTree as ET
- logger = logging.getLogger(__name__)
class LoginRestServlet(ClientV1RestServlet):
    """Serve the client ``/login`` endpoint.

    ``GET`` lists the login flows enabled in the homeserver config; ``POST``
    dispatches a login submission to the handler matching its ``type``
    (password, SAML2 redirect, JWT, CAS ticket or short-term login token).
    """

    PATTERNS = client_path_patterns("/login$")

    # Login type identifiers compared against the "type" field of a submission.
    PASS_TYPE = "m.login.password"
    SAML2_TYPE = "m.login.saml2"
    CAS_TYPE = "m.login.cas"
    TOKEN_TYPE = "m.login.token"
    JWT_TYPE = "m.login.jwt"

    def __init__(self, hs):
        """Cache the login-related config values, HTTP client and auth handler.

        Args:
            hs: the homeserver object; its ``config`` supplies the SAML2, JWT,
                CAS and password settings read below.
        """
        super(LoginRestServlet, self).__init__(hs)
        self.idp_redirect_url = hs.config.saml2_idp_redirect_url
        self.password_enabled = hs.config.password_enabled
        self.saml2_enabled = hs.config.saml2_enabled
        self.jwt_enabled = hs.config.jwt_enabled
        self.jwt_secret = hs.config.jwt_secret
        self.jwt_algorithm = hs.config.jwt_algorithm
        self.cas_enabled = hs.config.cas_enabled
        self.cas_server_url = hs.config.cas_server_url
        self.cas_required_attributes = hs.config.cas_required_attributes
        self.servername = hs.config.server_name
        self.http_client = hs.get_simple_http_client()
        self.auth_handler = self.hs.get_auth_handler()

    def on_GET(self, request):
        """Return the list of enabled login flows as ``(200, {"flows": [...]})``."""
        flows = []
        if self.jwt_enabled:
            flows.append({"type": LoginRestServlet.JWT_TYPE})
        if self.saml2_enabled:
            flows.append({"type": LoginRestServlet.SAML2_TYPE})
        if self.cas_enabled:
            flows.append({"type": LoginRestServlet.CAS_TYPE})

            # While it's valid for us to advertise this login type generally,
            # synapse currently only gives out these tokens as part of the
            # CAS login flow.
            # Generally we don't want to advertise login flows that clients
            # don't know how to implement, since they (currently) will always
            # fall back to the fallback API if they don't understand one of the
            # login flow types returned.
            flows.append({"type": LoginRestServlet.TOKEN_TYPE})
        if self.password_enabled:
            flows.append({"type": LoginRestServlet.PASS_TYPE})

        return (200, {"flows": flows})

    def on_OPTIONS(self, request):
        """Answer an OPTIONS request with an empty 200 response."""
        return (200, {})

    @defer.inlineCallbacks
    def on_POST(self, request):
        """Dispatch a login submission to the handler for its ``type``.

        The result is a ``(code, body)`` tuple delivered through
        ``defer.returnValue``.

        Raises:
            SynapseError: 400 if password login is requested while disabled,
                if the type is unknown (or its backend is not enabled), or if
                a ``KeyError`` escapes while handling the submission.
        """
        login_submission = parse_json_object_from_request(request)
        try:
            if login_submission["type"] == LoginRestServlet.PASS_TYPE:
                if not self.password_enabled:
                    raise SynapseError(400, "Password login has been disabled.")

                result = yield self.do_password_login(login_submission)
                defer.returnValue(result)
            elif self.saml2_enabled and (login_submission["type"] ==
                                         LoginRestServlet.SAML2_TYPE):
                # SAML2 login is a redirect: hand the client the IdP URL,
                # carrying the optional relay state as a query parameter.
                relay_state = ""
                if "relay_state" in login_submission:
                    relay_state = "&RelayState=" + urllib.quote(
                        login_submission["relay_state"])
                result = {
                    "uri": "%s%s" % (self.idp_redirect_url, relay_state)
                }
                defer.returnValue((200, result))
            elif self.jwt_enabled and (login_submission["type"] ==
                                       LoginRestServlet.JWT_TYPE):
                result = yield self.do_jwt_login(login_submission)
                defer.returnValue(result)
            # TODO Delete this after all CAS clients switch to token login instead
            elif self.cas_enabled and (login_submission["type"] ==
                                       LoginRestServlet.CAS_TYPE):
                uri = "%s/proxyValidate" % (self.cas_server_url,)
                args = {
                    "ticket": login_submission["ticket"],
                    "service": login_submission["service"]
                }
                body = yield self.http_client.get_raw(uri, args)
                result = yield self.do_cas_login(body)
                defer.returnValue(result)
            elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
                result = yield self.do_token_login(login_submission)
                defer.returnValue(result)
            else:
                raise SynapseError(400, "Bad login type.")
        except KeyError:
            raise SynapseError(400, "Missing JSON keys.")

    @defer.inlineCallbacks
    def do_password_login(self, login_submission):
        """Log a user in with a password.

        The user is identified either by a third-party identifier (both
        ``medium`` and ``address`` present), looked up in the datastore, or by
        the ``user`` field. A ``user`` that does not start with ``@`` is
        expanded through ``UserID.create`` with this server's hostname.

        Raises:
            LoginError: 403 if no user is found for the third-party identifier.
            KeyError: if ``user`` or ``password`` is missing (``on_POST``
                converts this into a 400).
        """
        if 'medium' in login_submission and 'address' in login_submission:
            user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
                login_submission['medium'], login_submission['address']
            )
            if not user_id:
                raise LoginError(403, "", errcode=Codes.FORBIDDEN)
        else:
            user_id = login_submission['user']

        if not user_id.startswith('@'):
            user_id = UserID.create(
                user_id, self.hs.hostname
            ).to_string()

        auth_handler = self.auth_handler
        user_id, access_token, refresh_token = yield auth_handler.login_with_password(
            user_id=user_id,
            password=login_submission["password"])

        result = {
            "user_id": user_id,  # may have changed
            "access_token": access_token,
            "refresh_token": refresh_token,
            "home_server": self.hs.hostname,
        }

        defer.returnValue((200, result))

    @defer.inlineCallbacks
    def do_token_login(self, login_submission):
        """Log a user in with a short-term login token.

        The ``token`` field is validated by the auth handler, which yields the
        user ID the login tuple is then issued for.
        """
        token = login_submission['token']
        auth_handler = self.auth_handler
        user_id = (
            yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
        )
        user_id, access_token, refresh_token = (
            yield auth_handler.get_login_tuple_for_user_id(user_id)
        )
        result = {
            "user_id": user_id,  # may have changed
            "access_token": access_token,
            "refresh_token": refresh_token,
            "home_server": self.hs.hostname,
        }

        defer.returnValue((200, result))

    @defer.inlineCallbacks
    def _login_or_register(self, localpart):
        """Log in the local user ``localpart``, registering them if unknown.

        Shared tail of the CAS and JWT logins. An existing user gets a full
        login tuple (including ``refresh_token``); a newly registered user's
        response carries no ``refresh_token`` key.

        Returns (via ``defer.returnValue``) a ``(200, body)`` tuple.
        """
        user_id = UserID.create(localpart, self.hs.hostname).to_string()
        auth_handler = self.auth_handler
        user_exists = yield auth_handler.does_user_exist(user_id)
        if user_exists:
            user_id, access_token, refresh_token = (
                yield auth_handler.get_login_tuple_for_user_id(user_id)
            )
            result = {
                "user_id": user_id,  # may have changed
                "access_token": access_token,
                "refresh_token": refresh_token,
                "home_server": self.hs.hostname,
            }
        else:
            user_id, access_token = (
                yield self.handlers.registration_handler.register(localpart=localpart)
            )
            result = {
                "user_id": user_id,  # may have changed
                "access_token": access_token,
                "home_server": self.hs.hostname,
            }

        defer.returnValue((200, result))

    # TODO Delete this after all CAS clients switch to token login instead
    @defer.inlineCallbacks
    def do_cas_login(self, cas_response_body):
        """Complete a CAS login from the body of a ``proxyValidate`` response.

        Raises:
            LoginError: 401 if the response is not a successful CAS response,
                or if an attribute listed in ``cas_required_attributes`` is
                missing or (when a value is configured) has a different value.
        """
        user, attributes = self.parse_cas_response(cas_response_body)

        for required_attribute, required_value in self.cas_required_attributes.items():
            # If required attribute was not in CAS Response - Forbidden
            if required_attribute not in attributes:
                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

            # Also need to check value
            if required_value is not None:
                actual_value = attributes[required_attribute]
                # If required attribute value does not match expected - Forbidden
                if required_value != actual_value:
                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

        result = yield self._login_or_register(user)
        defer.returnValue(result)

    @defer.inlineCallbacks
    def do_jwt_login(self, login_submission):
        """Log a user in with a JSON Web Token.

        The ``token`` field is decoded with the configured secret and
        algorithm; its ``sub`` claim is used as the user's localpart.

        Raises:
            LoginError: 401 if the token is missing, expired or invalid, or if
                it has no ``sub`` claim.
        """
        token = login_submission.get("token", None)
        if token is None:
            raise LoginError(
                401, "Token field for JWT is missing",
                errcode=Codes.UNAUTHORIZED
            )

        # Imported here rather than at module level, so the jwt package is
        # only loaded when a JWT login is actually attempted.
        import jwt
        from jwt.exceptions import InvalidTokenError

        try:
            payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
        except jwt.ExpiredSignatureError:
            raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
        except InvalidTokenError:
            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

        user = payload.get("sub", None)
        if user is None:
            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

        result = yield self._login_or_register(user)
        defer.returnValue(result)

    # TODO Delete this after all CAS clients switch to token login instead
    def parse_cas_response(self, cas_response_body):
        """Extract the user and attributes from a CAS validation response.

        Args:
            cas_response_body: the XML body returned by the CAS server.

        Returns:
            A ``(user, attributes)`` tuple, where ``attributes`` maps each
            attribute tag (namespace removed) to its text.

        Raises:
            LoginError: 401 if the root element is not a ``serviceResponse``,
                if it is empty, if its first child is not an
                ``authenticationSuccess``, or if the user or the attributes
                element is missing.
        """
        root = ET.fromstring(cas_response_body)
        if not root.tag.endswith("serviceResponse"):
            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
        # An empty <serviceResponse> has no first child to inspect; reject it
        # here rather than letting root[0] raise IndexError.
        if len(root) == 0:
            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
        if not root[0].tag.endswith("authenticationSuccess"):
            raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)

        # Start from None so that a response lacking <user> or <attributes>
        # reaches the validity check below instead of raising UnboundLocalError.
        user = None
        attributes = None
        for child in root[0]:
            if child.tag.endswith("user"):
                user = child.text
            if child.tag.endswith("attributes"):
                attributes = {}
                for attribute in child:
                    # ElementTree library expands the namespace in attribute tags
                    # to the full URL of the namespace.
                    # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
                    # We don't care about namespace here and it will always be encased in
                    # curly braces, so we remove them.
                    if "}" in attribute.tag:
                        attributes[attribute.tag.split("}")[1]] = attribute.text
                    else:
                        attributes[attribute.tag] = attribute.text

        if user is None or attributes is None:
            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)

        return (user, attributes)
class SAML2RestServlet(ClientV1RestServlet):
    """Receive the SAML2 response posted back after identity-provider login.

    On success the user is registered through ``register_saml2``; the outcome
    is reported either as a redirect to the posted ``RelayState`` URL or, when
    no ``RelayState`` was posted, as a JSON body.
    """

    PATTERNS = client_path_patterns("/login/saml2", releases=())

    def __init__(self, hs):
        super(SAML2RestServlet, self).__init__(hs)
        # Path of the service-provider config file loaded on every POST.
        self.sp_config = hs.config.saml2_config_path

    @defer.inlineCallbacks
    def on_POST(self, request):
        """Validate the posted ``SAMLResponse`` and report the outcome.

        Any exception raised while loading the config or parsing the response
        is logged and treated as "not authenticated".
        """
        auth_response = None
        try:
            sp_conf = config.SPConfig()
            sp_conf.load_file(self.sp_config)
            saml_client = Saml2Client(sp_conf)
            auth_response = saml_client.parse_authn_request_response(
                request.args['SAMLResponse'][0], BINDING_HTTP_POST)
        except Exception as e:  # Not authenticated
            logger.exception(e)

        authenticated = (
            auth_response
            and auth_response.status_ok()
            and not auth_response.not_signed
        )

        if not authenticated:
            if 'RelayState' in request.args:
                relay_url = urllib.unquote(request.args['RelayState'][0])
                request.redirect(relay_url + '?status=not_authenticated')
                finish_request(request)
                defer.returnValue(None)
            defer.returnValue((200, {"status": "not_authenticated"}))

        username = auth_response.name_id.text
        registration_handler = self.handlers.registration_handler
        (user_id, token) = yield registration_handler.register_saml2(username)

        # Forward to the RelayState callback along with ava
        if 'RelayState' in request.args:
            relay_url = urllib.unquote(request.args['RelayState'][0])
            ava_param = urllib.quote(json.dumps(auth_response.ava))
            request.redirect("".join([
                relay_url,
                '?status=authenticated&access_token=',
                token,
                '&user_id=',
                user_id,
                '&ava=',
                ava_param,
            ]))
            finish_request(request)
            defer.returnValue(None)

        response_body = {
            "status": "authenticated",
            "user_id": user_id,
            "token": token,
            "ava": auth_response.ava,
        }
        defer.returnValue((200, response_body))
# TODO Delete this after all CAS clients switch to token login instead
class CasRestServlet(ClientV1RestServlet):
    """Tell clients which CAS server this homeserver is configured to use."""

    PATTERNS = client_path_patterns("/login/cas", releases=())

    def __init__(self, hs):
        super(CasRestServlet, self).__init__(hs)
        self.cas_server_url = hs.config.cas_server_url

    def on_GET(self, request):
        """Return ``(200, {"serverUrl": <configured CAS server URL>})``."""
        body = {"serverUrl": self.cas_server_url}
        return (200, body)
class CasRedirectServlet(ClientV1RestServlet):
    """Redirect the client to the CAS server to start a CAS login.

    The CAS ``service`` parameter is this homeserver's ``/login/cas/ticket``
    URL, which itself carries the client's ``redirectUrl`` as a query
    parameter.
    """

    PATTERNS = client_path_patterns("/login/cas/redirect", releases=())

    def __init__(self, hs):
        super(CasRedirectServlet, self).__init__(hs)
        self.cas_server_url = hs.config.cas_server_url
        self.cas_service_url = hs.config.cas_service_url

    def on_GET(self, request):
        """Redirect to the CAS server, or return a 400 without ``redirectUrl``."""
        try:
            redirect_values = request.args["redirectUrl"]
        except KeyError:
            return (400, "Redirect URL not specified for CAS auth")

        # Inner query string: where the client wants to end up after login.
        client_query = urllib.urlencode([("redirectUrl", redirect_values[0])])

        # The ticket endpoint the CAS server sends the user back to.
        ticket_endpoint = (
            self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket"
        )
        service_url = ticket_endpoint + "?" + client_query

        # Outer query string: the service URL, encoded for the CAS server.
        cas_query = urllib.urlencode([("service", service_url)])
        request.redirect("%s?%s" % (self.cas_server_url, cas_query))
        finish_request(request)
class CasTicketServlet(ClientV1RestServlet):
    """Handle the callback from the CAS server carrying a login ticket.

    The ticket is validated against the CAS server's ``proxyValidate``
    endpoint. On success the user is registered if they do not exist yet, and
    the client is redirected to its ``redirectUrl`` with a short-term
    ``loginToken`` added to the query string.
    """

    PATTERNS = client_path_patterns("/login/cas/ticket", releases=())

    def __init__(self, hs):
        super(CasTicketServlet, self).__init__(hs)
        self.cas_server_url = hs.config.cas_server_url
        self.cas_service_url = hs.config.cas_service_url
        self.cas_required_attributes = hs.config.cas_required_attributes
        # handle_cas_response reads self.auth_handler, and nothing in this
        # class set it; fetch it the same way LoginRestServlet does.
        # NOTE(review): redundant but harmless if the base class already
        # provides this attribute - confirm against ClientV1RestServlet.
        self.auth_handler = hs.get_auth_handler()

    @defer.inlineCallbacks
    def on_GET(self, request):
        """Validate the ticket with the CAS server and finish the login.

        Reads the ``redirectUrl`` and ``ticket`` query arguments; a missing
        argument raises ``KeyError``.
        """
        client_redirect_url = request.args["redirectUrl"][0]
        http_client = self.hs.get_simple_http_client()
        uri = self.cas_server_url + "/proxyValidate"
        args = {
            # NOTE(review): this passes the whole list of "ticket" values,
            # whereas redirectUrl takes the first one - confirm that get_raw
            # is meant to receive a list here.
            "ticket": request.args["ticket"],
            "service": self.cas_service_url
        }
        body = yield http_client.get_raw(uri, args)
        result = yield self.handle_cas_response(request, body, client_redirect_url)
        defer.returnValue(result)

    @defer.inlineCallbacks
    def handle_cas_response(self, request, cas_response_body, client_redirect_url):
        """Check the CAS response and redirect the client with a login token.

        Raises:
            LoginError: 401 if the response is not a successful CAS response,
                or if an attribute listed in ``cas_required_attributes`` is
                missing or (when a value is configured) has a different value.
        """
        user, attributes = self.parse_cas_response(cas_response_body)

        for required_attribute, required_value in self.cas_required_attributes.items():
            # If required attribute was not in CAS Response - Forbidden
            if required_attribute not in attributes:
                raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

            # Also need to check value
            if required_value is not None:
                actual_value = attributes[required_attribute]
                # If required attribute value does not match expected - Forbidden
                if required_value != actual_value:
                    raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

        user_id = UserID.create(user, self.hs.hostname).to_string()
        auth_handler = self.auth_handler
        user_exists = yield auth_handler.does_user_exist(user_id)
        if not user_exists:
            user_id, _ = (
                yield self.handlers.registration_handler.register(localpart=user)
            )

        login_token = auth_handler.generate_short_term_login_token(user_id)
        redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
                                                            login_token)
        request.redirect(redirect_url)
        finish_request(request)

    def add_login_token_to_redirect_url(self, url, token):
        """Return ``url`` with ``loginToken=<token>`` set in its query string.

        Existing query parameters are kept, except that an existing
        ``loginToken`` is overwritten.
        """
        url_parts = list(urlparse.urlparse(url))
        # Index 4 of the parsed URL is its query component.
        query = dict(urlparse.parse_qsl(url_parts[4]))
        query.update({"loginToken": token})
        url_parts[4] = urllib.urlencode(query)
        return urlparse.urlunparse(url_parts)

    def parse_cas_response(self, cas_response_body):
        """Extract the user and attributes from a CAS validation response.

        Args:
            cas_response_body: the XML body returned by the CAS server.

        Returns:
            A ``(user, attributes)`` tuple, where ``attributes`` maps each
            attribute tag (namespace removed) to its text.

        Raises:
            LoginError: 401 if the root element is not a ``serviceResponse``,
                if it is empty, if its first child is not an
                ``authenticationSuccess``, or if the user or the attributes
                element is missing.
        """
        root = ET.fromstring(cas_response_body)
        if not root.tag.endswith("serviceResponse"):
            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
        # An empty <serviceResponse> has no first child to inspect; reject it
        # here rather than letting root[0] raise IndexError.
        if len(root) == 0:
            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
        if not root[0].tag.endswith("authenticationSuccess"):
            raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)

        # Start from None so that a response lacking <user> or <attributes>
        # reaches the validity check below instead of raising UnboundLocalError.
        user = None
        attributes = None
        for child in root[0]:
            if child.tag.endswith("user"):
                user = child.text
            if child.tag.endswith("attributes"):
                attributes = {}
                for attribute in child:
                    # ElementTree library expands the namespace in attribute tags
                    # to the full URL of the namespace.
                    # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
                    # We don't care about namespace here and it will always be encased in
                    # curly braces, so we remove them.
                    if "}" in attribute.tag:
                        attributes[attribute.tag.split("}")[1]] = attribute.text
                    else:
                        attributes[attribute.tag] = attribute.text

        if user is None or attributes is None:
            raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)

        return (user, attributes)
def register_servlets(hs, http_server):
    """Register this module's login servlets on ``http_server``.

    The ``/login`` servlet is always registered; the SAML2 and CAS servlets
    are registered only when enabled in ``hs.config``.
    """
    servlet_classes = [LoginRestServlet]
    if hs.config.saml2_enabled:
        servlet_classes.append(SAML2RestServlet)
    if hs.config.cas_enabled:
        servlet_classes.extend((
            CasRedirectServlet,
            CasTicketServlet,
            CasRestServlet,
        ))

    # Build and register one servlet at a time, in the order listed above.
    for servlet_class in servlet_classes:
        servlet_class(hs).register(http_server)
    # TODO PasswordResetRestServlet(hs).register(http_server)
|