login.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from twisted.internet import defer
  16. from synapse.api.errors import SynapseError, LoginError, Codes
  17. from synapse.types import UserID
  18. from synapse.http.server import finish_request
  19. from synapse.http.servlet import parse_json_object_from_request
  20. from .base import ClientV1RestServlet, client_path_patterns
  21. import simplejson as json
  22. import urllib
  23. import urlparse
  24. import logging
  25. from saml2 import BINDING_HTTP_POST
  26. from saml2 import config
  27. from saml2.client import Saml2Client
  28. import xml.etree.ElementTree as ET
  # One logger per module, named after the module itself, so records emitted
  # from this file can be filtered and configured independently.
  logger = logging.getLogger(name=__name__)
  class LoginRestServlet(ClientV1RestServlet):
      """Client-server ``/login`` endpoint.

      ``GET`` lists the login flows enabled in the homeserver config and
      ``POST`` performs a login using one of them. Request handlers return
      ``(http_status, json_body)`` tuples, directly or via a Deferred.
      """

      PATTERNS = client_path_patterns("/login$")
      PASS_TYPE = "m.login.password"
      SAML2_TYPE = "m.login.saml2"
      CAS_TYPE = "m.login.cas"
      TOKEN_TYPE = "m.login.token"
      JWT_TYPE = "m.login.jwt"

      def __init__(self, hs):
          super(LoginRestServlet, self).__init__(hs)
          self.idp_redirect_url = hs.config.saml2_idp_redirect_url
          self.password_enabled = hs.config.password_enabled
          self.saml2_enabled = hs.config.saml2_enabled
          self.jwt_enabled = hs.config.jwt_enabled
          self.jwt_secret = hs.config.jwt_secret
          self.jwt_algorithm = hs.config.jwt_algorithm
          self.cas_enabled = hs.config.cas_enabled
          self.cas_server_url = hs.config.cas_server_url
          self.cas_required_attributes = hs.config.cas_required_attributes
          self.servername = hs.config.server_name
          self.http_client = hs.get_simple_http_client()
          self.auth_handler = self.hs.get_auth_handler()

      def on_GET(self, request):
          """Return the login flow types this server accepts."""
          flows = []
          if self.jwt_enabled:
              flows.append({"type": LoginRestServlet.JWT_TYPE})
          if self.saml2_enabled:
              flows.append({"type": LoginRestServlet.SAML2_TYPE})
          if self.cas_enabled:
              flows.append({"type": LoginRestServlet.CAS_TYPE})

              # While it's valid for us to advertise this login type generally,
              # synapse currently only gives out these tokens as part of the
              # CAS login flow.
              # Generally we don't want to advertise login flows that clients
              # don't know how to implement, since they (currently) will always
              # fall back to the fallback API if they don't understand one of
              # the login flow types returned.
              flows.append({"type": LoginRestServlet.TOKEN_TYPE})
          if self.password_enabled:
              flows.append({"type": LoginRestServlet.PASS_TYPE})

          return (200, {"flows": flows})

      def on_OPTIONS(self, request):
          return (200, {})

      @defer.inlineCallbacks
      def on_POST(self, request):
          """Log in using the flow named by the submission's ``type`` key.

          Returns (via Deferred) a ``(200, body)`` tuple.

          Raises:
              SynapseError: 400 if password login is disabled, the login type
                  is unknown or disabled, or a required key is missing.
          """
          login_submission = parse_json_object_from_request(request)
          try:
              login_type = login_submission["type"]
              if login_type == LoginRestServlet.PASS_TYPE:
                  if not self.password_enabled:
                      raise SynapseError(400, "Password login has been disabled.")
                  result = yield self.do_password_login(login_submission)
                  defer.returnValue(result)
              elif self.saml2_enabled and login_type == LoginRestServlet.SAML2_TYPE:
                  relay_state = ""
                  if "relay_state" in login_submission:
                      relay_state = "&RelayState=" + self._quote_utf8(
                          login_submission["relay_state"])
                  result = {
                      "uri": "%s%s" % (self.idp_redirect_url, relay_state)
                  }
                  defer.returnValue((200, result))
              elif self.jwt_enabled and login_type == LoginRestServlet.JWT_TYPE:
                  result = yield self.do_jwt_login(login_submission)
                  defer.returnValue(result)
              # TODO Delete this after all CAS clients switch to token login instead
              elif self.cas_enabled and login_type == LoginRestServlet.CAS_TYPE:
                  uri = "%s/proxyValidate" % (self.cas_server_url,)
                  args = {
                      "ticket": login_submission["ticket"],
                      "service": login_submission["service"]
                  }
                  body = yield self.http_client.get_raw(uri, args)
                  result = yield self.do_cas_login(body)
                  defer.returnValue(result)
              elif login_type == LoginRestServlet.TOKEN_TYPE:
                  result = yield self.do_token_login(login_submission)
                  defer.returnValue(result)
              else:
                  raise SynapseError(400, "Bad login type.")
          except KeyError:
              # NOTE: this also catches KeyErrors raised inside the do_*_login
              # helpers, e.g. a submission without "user"/"password"/"token".
              raise SynapseError(400, "Missing JSON keys.")

      @staticmethod
      def _quote_utf8(value):
          """URL-quote ``value``, UTF-8 encoding it first when it is unicode.

          Python 2's ``urllib.quote`` raises KeyError for non-ASCII unicode,
          which on_POST would otherwise misreport as "Missing JSON keys.".
          """
          if isinstance(value, unicode):
              value = value.encode("utf-8")
          return urllib.quote(value)

      @defer.inlineCallbacks
      def do_password_login(self, login_submission):
          """Log in with a password, identifying the user by 3PID or user id.

          If both ``medium`` and ``address`` are given, the user is looked up
          by third-party id; otherwise ``user`` is used, qualified with this
          server's hostname when it is a bare localpart.

          Raises:
              LoginError: 403 if no user is bound to the given 3PID.
          """
          if 'medium' in login_submission and 'address' in login_submission:
              user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
                  login_submission['medium'], login_submission['address']
              )
              if not user_id:
                  raise LoginError(403, "", errcode=Codes.FORBIDDEN)
          else:
              user_id = login_submission['user']

          if not user_id.startswith('@'):
              user_id = UserID.create(
                  user_id, self.hs.hostname
              ).to_string()

          auth_handler = self.auth_handler
          user_id, access_token, refresh_token = yield auth_handler.login_with_password(
              user_id=user_id,
              password=login_submission["password"])

          result = {
              "user_id": user_id,  # may have changed
              "access_token": access_token,
              "refresh_token": refresh_token,
              "home_server": self.hs.hostname,
          }

          defer.returnValue((200, result))

      @defer.inlineCallbacks
      def do_token_login(self, login_submission):
          """Exchange a short-term login ``token`` for an access token."""
          token = login_submission['token']
          auth_handler = self.auth_handler
          user_id = (
              yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
          )
          user_id, access_token, refresh_token = (
              yield auth_handler.get_login_tuple_for_user_id(user_id)
          )
          result = {
              "user_id": user_id,  # may have changed
              "access_token": access_token,
              "refresh_token": refresh_token,
              "home_server": self.hs.hostname,
          }

          defer.returnValue((200, result))

      # TODO Delete this after all CAS clients switch to token login instead
      @defer.inlineCallbacks
      def do_cas_login(self, cas_response_body):
          """Log in (registering if needed) the user named in a CAS response.

          Raises:
              LoginError: 401 if the response is invalid or the user lacks a
                  required CAS attribute.
          """
          user, attributes = self.parse_cas_response(cas_response_body)
          self._check_cas_required_attributes(attributes)
          result = yield self._login_or_register(user)
          defer.returnValue(result)

      def _check_cas_required_attributes(self, attributes):
          """Raise a 401 LoginError unless ``attributes`` meets the config.

          A required value of ``None`` only demands that the attribute be
          present; any other value must also match exactly.
          """
          for required_attribute, required_value in self.cas_required_attributes.items():
              # If required attribute was not in CAS Response - Forbidden
              if required_attribute not in attributes:
                  raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

              # If required attribute value does not match expected - Forbidden
              if (required_value is not None and
                      required_value != attributes[required_attribute]):
                  raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

      @defer.inlineCallbacks
      def do_jwt_login(self, login_submission):
          """Log in (registering if needed) the ``sub`` of a signed JWT.

          Raises:
              LoginError: 401 if the token is missing, expired, invalid or has
                  no ``sub`` claim.
          """
          token = login_submission.get("token", None)
          if token is None:
              raise LoginError(
                  401, "Token field for JWT is missing",
                  errcode=Codes.UNAUTHORIZED
              )

          # Imported lazily so the jwt package is only needed when JWT login
          # is actually used.
          import jwt
          from jwt.exceptions import InvalidTokenError

          try:
              payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
          except jwt.ExpiredSignatureError:
              raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
          except InvalidTokenError:
              raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

          user = payload.get("sub", None)
          if user is None:
              raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)

          result = yield self._login_or_register(user)
          defer.returnValue(result)

      @defer.inlineCallbacks
      def _login_or_register(self, localpart):
          """Log in as ``localpart`` on this server, registering it if needed.

          Shared by the CAS and JWT flows. Returns (via Deferred) a
          ``(200, body)`` tuple. ``body`` has no ``refresh_token`` for a newly
          registered user, because ``register()`` yields only
          ``(user_id, access_token)`` here.
          """
          user_id = UserID.create(localpart, self.hs.hostname).to_string()
          auth_handler = self.auth_handler
          user_exists = yield auth_handler.does_user_exist(user_id)
          if user_exists:
              user_id, access_token, refresh_token = (
                  yield auth_handler.get_login_tuple_for_user_id(user_id)
              )
              result = {
                  "user_id": user_id,  # may have changed
                  "access_token": access_token,
                  "refresh_token": refresh_token,
                  "home_server": self.hs.hostname,
              }
          else:
              user_id, access_token = (
                  yield self.handlers.registration_handler.register(localpart=localpart)
              )
              result = {
                  "user_id": user_id,  # may have changed
                  "access_token": access_token,
                  "home_server": self.hs.hostname,
              }

          defer.returnValue((200, result))

      # TODO Delete this after all CAS clients switch to token login instead
      def parse_cas_response(self, cas_response_body):
          """Extract ``(user, attributes)`` from a CAS ``serviceResponse``.

          Raises:
              LoginError: 401 if the body is not well-formed XML, is not a
                  successful CAS authentication response, or lacks either the
                  ``user`` or the ``attributes`` element.
          """
          # NOTE(review): the stdlib ElementTree parser is not hardened against
          # entity-expansion attacks (see defusedxml); acceptable only while the
          # configured CAS server is trusted.
          try:
              root = ET.fromstring(cas_response_body)
          except ET.ParseError:
              raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
          if not root.tag.endswith("serviceResponse"):
              raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
          # An empty <serviceResponse/> would otherwise raise IndexError below.
          if len(root) == 0:
              raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
          if not root[0].tag.endswith("authenticationSuccess"):
              raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)

          # Initialised up front so a response missing either element is
          # rejected by the check below instead of raising UnboundLocalError.
          user = None
          attributes = None
          for child in root[0]:
              if child.tag.endswith("user"):
                  user = child.text
              if child.tag.endswith("attributes"):
                  attributes = {}
                  for attribute in child:
                      # ElementTree library expands the namespace in attribute tags
                      # to the full URL of the namespace.
                      # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
                      # We don't care about namespace here and it will always be encased in
                      # curly braces, so we remove them.
                      if "}" in attribute.tag:
                          attributes[attribute.tag.split("}")[1]] = attribute.text
                      else:
                          attributes[attribute.tag] = attribute.text
          if user is None or attributes is None:
              raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)

          return (user, attributes)
  class SAML2RestServlet(ClientV1RestServlet):
      """``POST /login/saml2``: consume a SAML2 authentication response.

      On a successful, signed response the user is registered via the
      registration handler's ``register_saml2``. If the request carries a
      ``RelayState``, the client is redirected there with the outcome in the
      query string; otherwise the outcome is returned as JSON.
      """

      PATTERNS = client_path_patterns("/login/saml2", releases=())

      def __init__(self, hs):
          super(SAML2RestServlet, self).__init__(hs)
          self.sp_config = hs.config.saml2_config_path

      @defer.inlineCallbacks
      def on_POST(self, request):
          saml2_auth = None
          try:
              conf = config.SPConfig()
              conf.load_file(self.sp_config)
              SP = Saml2Client(conf)
              saml2_auth = SP.parse_authn_request_response(
                  request.args['SAMLResponse'][0], BINDING_HTTP_POST)
          except Exception:
              # Deliberately best-effort: any exception while loading the SP
              # config or parsing the response is treated as "not
              # authenticated" below. Log a message plus the traceback.
              logger.exception("Failed to process SAML2 authentication response")

          if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
              username = saml2_auth.name_id.text
              handler = self.handlers.registration_handler
              (user_id, token) = yield handler.register_saml2(username)
              # Forward to the RelayState callback along with ava
              if 'RelayState' in request.args:
                  self._redirect_to_relay_state(
                      request,
                      'status=authenticated&access_token=' + self._quote(token) +
                      '&user_id=' + self._quote(user_id) +
                      '&ava=' + urllib.quote(json.dumps(saml2_auth.ava)))
                  defer.returnValue(None)
              defer.returnValue((200, {"status": "authenticated",
                                       "user_id": user_id, "token": token,
                                       "ava": saml2_auth.ava}))
          elif 'RelayState' in request.args:
              self._redirect_to_relay_state(request, 'status=not_authenticated')
              defer.returnValue(None)
          defer.returnValue((200, {"status": "not_authenticated"}))

      @staticmethod
      def _quote(value):
          """Percent-encode one query value, UTF-8 encoding unicode first.

          User ids contain characters such as '@' and ':' that must be escaped
          inside a query string; Python 2's ``urllib.quote`` cannot handle
          non-ASCII unicode directly.
          """
          if isinstance(value, unicode):
              value = value.encode("utf-8")
          return urllib.quote(value, safe="")

      def _redirect_to_relay_state(self, request, query):
          """Redirect to the request's RelayState URL with ``query`` appended,
          then finish the request."""
          # NOTE(review): RelayState comes from the POSTed request and the
          # redirect may carry an access token; consider validating the target
          # against an allow-list to avoid an open redirect.
          target = urllib.unquote(request.args['RelayState'][0])
          # Join with '&' when the target already has a query string, so the
          # resulting URL never contains a second '?'.
          separator = '&' if '?' in target else '?'
          request.redirect(target + separator + query)
          finish_request(request)
  # TODO Delete this after all CAS clients switch to token login instead
  class CasRestServlet(ClientV1RestServlet):
      """Legacy ``GET /login/cas``: tell the client which CAS server to use."""

      PATTERNS = client_path_patterns("/login/cas", releases=())

      def __init__(self, hs):
          super(CasRestServlet, self).__init__(hs)
          self.cas_server_url = hs.config.cas_server_url

      def on_GET(self, request):
          # The body is just the configured CAS server URL.
          payload = {"serverUrl": self.cas_server_url}
          return (200, payload)
  class CasRedirectServlet(ClientV1RestServlet):
      """``GET /login/cas/redirect``: send the client to the CAS server.

      The CAS ``service`` is this server's ``/login/cas/ticket`` endpoint with
      the client's ``redirectUrl`` embedded, so the ticket endpoint can send
      the client back once CAS has authenticated it.
      """

      PATTERNS = client_path_patterns("/login/cas/redirect", releases=())

      def __init__(self, hs):
          super(CasRedirectServlet, self).__init__(hs)
          self.cas_server_url = hs.config.cas_server_url
          self.cas_service_url = hs.config.cas_service_url

      def on_GET(self, request):
          """Redirect the client to ``<cas_server_url>?service=...``.

          Raises:
              SynapseError: 400 if the ``redirectUrl`` parameter is missing.
          """
          args = request.args
          if "redirectUrl" not in args:
              # Raised (rather than returning a bare string body) so this 400
              # is reported through the same SynapseError path as the other
              # client errors in this module.
              raise SynapseError(400, "Redirect URL not specified for CAS auth")
          client_redirect_url_param = urllib.urlencode({
              "redirectUrl": args["redirectUrl"][0]
          })
          hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket"
          service_param = urllib.urlencode({
              "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
          })
          request.redirect("%s?%s" % (self.cas_server_url, service_param))
          finish_request(request)
  class CasTicketServlet(ClientV1RestServlet):
      """``GET /login/cas/ticket``: the CAS ``service`` callback.

      Validates the CAS ticket, logs in (registering if needed) the CAS user,
      and redirects the client to its ``redirectUrl`` with a short-term
      ``loginToken`` appended.
      """

      PATTERNS = client_path_patterns("/login/cas/ticket", releases=())

      def __init__(self, hs):
          super(CasTicketServlet, self).__init__(hs)
          self.cas_server_url = hs.config.cas_server_url
          self.cas_service_url = hs.config.cas_service_url
          self.cas_required_attributes = hs.config.cas_required_attributes
          # handle_cas_response() uses self.auth_handler, which this servlet
          # previously never assigned; obtain it as LoginRestServlet does.
          self.auth_handler = hs.get_auth_handler()

      @defer.inlineCallbacks
      def on_GET(self, request):
          """Validate the CAS ticket and redirect the client back.

          Raises:
              SynapseError: 400 if ``redirectUrl`` or ``ticket`` is missing.
              LoginError: 401 if the CAS response is invalid or unsuccessful,
                  or the user lacks a required CAS attribute.
          """
          if "redirectUrl" not in request.args:
              raise SynapseError(400, "Redirect URL not specified for CAS auth")
          if "ticket" not in request.args:
              raise SynapseError(400, "CAS ticket not specified")
          client_redirect_url = request.args["redirectUrl"][0]
          http_client = self.hs.get_simple_http_client()
          uri = self.cas_server_url + "/proxyValidate"
          args = {
              # request.args maps each name to a list; send the single value.
              "ticket": request.args["ticket"][0],
              # CAS validates a ticket only against the service it was issued
              # for, i.e. the URL CasRedirectServlet sent to the CAS server.
              "service": self._build_service_url(client_redirect_url),
          }
          body = yield http_client.get_raw(uri, args)
          result = yield self.handle_cas_response(request, body, client_redirect_url)
          defer.returnValue(result)

      def _build_service_url(self, client_redirect_url):
          """Rebuild the ``service`` URL used when redirecting to CAS.

          Must stay in sync with CasRedirectServlet.on_GET, which sends
          ``<cas_service_url>/_matrix/client/api/v1/login/cas/ticket?redirectUrl=...``.
          """
          hs_redirect_url = self.cas_service_url + "/_matrix/client/api/v1/login/cas/ticket"
          return "%s?%s" % (
              hs_redirect_url,
              urllib.urlencode({"redirectUrl": client_redirect_url}),
          )

      @defer.inlineCallbacks
      def handle_cas_response(self, request, cas_response_body, client_redirect_url):
          """Log the CAS user in and redirect to ``client_redirect_url``.

          The user is registered first if it does not exist yet. The redirect
          carries a short-term login token as ``loginToken``.
          """
          user, attributes = self.parse_cas_response(cas_response_body)
          self._check_cas_required_attributes(attributes)

          user_id = UserID.create(user, self.hs.hostname).to_string()
          auth_handler = self.auth_handler
          user_exists = yield auth_handler.does_user_exist(user_id)
          if not user_exists:
              user_id, _ = (
                  yield self.handlers.registration_handler.register(localpart=user)
              )

          login_token = auth_handler.generate_short_term_login_token(user_id)
          redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
                                                              login_token)
          request.redirect(redirect_url)
          finish_request(request)

      def _check_cas_required_attributes(self, attributes):
          """Raise a 401 LoginError unless ``attributes`` meets the config.

          A required value of ``None`` only demands that the attribute be
          present; any other value must also match exactly.
          """
          for required_attribute, required_value in self.cas_required_attributes.items():
              # If required attribute was not in CAS Response - Forbidden
              if required_attribute not in attributes:
                  raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

              # If required attribute value does not match expected - Forbidden
              if (required_value is not None and
                      required_value != attributes[required_attribute]):
                  raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)

      def add_login_token_to_redirect_url(self, url, token):
          """Return ``url`` with ``loginToken=<token>`` added to its query.

          Existing parameters, including repeated and blank-valued ones, are
          preserved in order; any stale ``loginToken`` is replaced.
          """
          url_parts = list(urlparse.urlparse(url))
          query = [
              (key, value)
              for key, value in urlparse.parse_qsl(url_parts[4], keep_blank_values=True)
              if key != "loginToken"
          ]
          query.append(("loginToken", token))
          url_parts[4] = urllib.urlencode(query)
          return urlparse.urlunparse(url_parts)

      def parse_cas_response(self, cas_response_body):
          """Extract ``(user, attributes)`` from a CAS ``serviceResponse``.

          Raises:
              LoginError: 401 if the body is not well-formed XML, is not a
                  successful CAS authentication response, or lacks either the
                  ``user`` or the ``attributes`` element.
          """
          # NOTE(review): the stdlib ElementTree parser is not hardened against
          # entity-expansion attacks (see defusedxml); acceptable only while the
          # configured CAS server is trusted.
          try:
              root = ET.fromstring(cas_response_body)
          except ET.ParseError:
              raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
          if not root.tag.endswith("serviceResponse"):
              raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
          # An empty <serviceResponse/> would otherwise raise IndexError below.
          if len(root) == 0:
              raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
          if not root[0].tag.endswith("authenticationSuccess"):
              raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)

          # Initialised up front so a response missing either element is
          # rejected by the check below instead of raising UnboundLocalError.
          user = None
          attributes = None
          for child in root[0]:
              if child.tag.endswith("user"):
                  user = child.text
              if child.tag.endswith("attributes"):
                  attributes = {}
                  for attribute in child:
                      # ElementTree library expands the namespace in attribute tags
                      # to the full URL of the namespace.
                      # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
                      # We don't care about namespace here and it will always be encased in
                      # curly braces, so we remove them.
                      if "}" in attribute.tag:
                          attributes[attribute.tag.split("}")[1]] = attribute.text
                      else:
                          attributes[attribute.tag] = attribute.text
          if user is None or attributes is None:
              raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)

          return (user, attributes)
  def register_servlets(hs, http_server):
      """Register the login servlets enabled in ``hs.config`` on ``http_server``.

      ``/login`` is always served; the SAML2 and CAS endpoints only when their
      respective login types are enabled.
      """
      servlet_classes = [LoginRestServlet]
      if hs.config.saml2_enabled:
          servlet_classes.append(SAML2RestServlet)
      if hs.config.cas_enabled:
          servlet_classes.extend(
              [CasRedirectServlet, CasTicketServlet, CasRestServlet]
          )
      # TODO PasswordResetRestServlet(hs).register(http_server)

      for servlet_class in servlet_classes:
          servlet_class(hs).register(http_server)