login.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014, 2015 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from twisted.internet import defer
  16. from synapse.api.errors import SynapseError, LoginError, Codes
  17. from synapse.http.client import SimpleHttpClient
  18. from synapse.types import UserID
  19. from base import ClientV1RestServlet, client_path_pattern
  20. import simplejson as json
  21. import urllib
  22. import logging
  23. from saml2 import BINDING_HTTP_POST
  24. from saml2 import config
  25. from saml2.client import Saml2Client
  26. import xml.etree.ElementTree as ET
# Module-level logger for this file; SAML2RestServlet uses it to record
# failures while processing SAML2 responses.
logger = logging.getLogger(__name__)
  28. class LoginRestServlet(ClientV1RestServlet):
  29. PATTERN = client_path_pattern("/login$")
  30. PASS_TYPE = "m.login.password"
  31. SAML2_TYPE = "m.login.saml2"
  32. CAS_TYPE = "m.login.cas"
  33. def __init__(self, hs):
  34. super(LoginRestServlet, self).__init__(hs)
  35. self.idp_redirect_url = hs.config.saml2_idp_redirect_url
  36. self.password_enabled = hs.config.password_enabled
  37. self.saml2_enabled = hs.config.saml2_enabled
  38. self.cas_enabled = hs.config.cas_enabled
  39. self.cas_server_url = hs.config.cas_server_url
  40. self.cas_required_attributes = hs.config.cas_required_attributes
  41. self.servername = hs.config.server_name
  42. def on_GET(self, request):
  43. flows = []
  44. if self.saml2_enabled:
  45. flows.append({"type": LoginRestServlet.SAML2_TYPE})
  46. if self.cas_enabled:
  47. flows.append({"type": LoginRestServlet.CAS_TYPE})
  48. if self.password_enabled:
  49. flows.append({"type": LoginRestServlet.PASS_TYPE})
  50. return (200, {"flows": flows})
  51. def on_OPTIONS(self, request):
  52. return (200, {})
  53. @defer.inlineCallbacks
  54. def on_POST(self, request):
  55. login_submission = _parse_json(request)
  56. try:
  57. if login_submission["type"] == LoginRestServlet.PASS_TYPE:
  58. if not self.password_enabled:
  59. raise SynapseError(400, "Password login has been disabled.")
  60. result = yield self.do_password_login(login_submission)
  61. defer.returnValue(result)
  62. elif self.saml2_enabled and (login_submission["type"] ==
  63. LoginRestServlet.SAML2_TYPE):
  64. relay_state = ""
  65. if "relay_state" in login_submission:
  66. relay_state = "&RelayState="+urllib.quote(
  67. login_submission["relay_state"])
  68. result = {
  69. "uri": "%s%s" % (self.idp_redirect_url, relay_state)
  70. }
  71. defer.returnValue((200, result))
  72. elif self.cas_enabled and (login_submission["type"] ==
  73. LoginRestServlet.CAS_TYPE):
  74. # TODO: get this from the homeserver rather than creating a new one for
  75. # each request
  76. http_client = SimpleHttpClient(self.hs)
  77. uri = "%s/proxyValidate" % (self.cas_server_url,)
  78. args = {
  79. "ticket": login_submission["ticket"],
  80. "service": login_submission["service"]
  81. }
  82. body = yield http_client.get_raw(uri, args)
  83. result = yield self.do_cas_login(body)
  84. defer.returnValue(result)
  85. else:
  86. raise SynapseError(400, "Bad login type.")
  87. except KeyError:
  88. raise SynapseError(400, "Missing JSON keys.")
  89. @defer.inlineCallbacks
  90. def do_password_login(self, login_submission):
  91. if 'medium' in login_submission and 'address' in login_submission:
  92. user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
  93. login_submission['medium'], login_submission['address']
  94. )
  95. else:
  96. user_id = login_submission['user']
  97. if not user_id.startswith('@'):
  98. user_id = UserID.create(
  99. user_id, self.hs.hostname
  100. ).to_string()
  101. auth_handler = self.handlers.auth_handler
  102. user_id, access_token, refresh_token = yield auth_handler.login_with_password(
  103. user_id=user_id,
  104. password=login_submission["password"])
  105. result = {
  106. "user_id": user_id, # may have changed
  107. "access_token": access_token,
  108. "refresh_token": refresh_token,
  109. "home_server": self.hs.hostname,
  110. }
  111. defer.returnValue((200, result))
  112. @defer.inlineCallbacks
  113. def do_cas_login(self, cas_response_body):
  114. user, attributes = self.parse_cas_response(cas_response_body)
  115. for required_attribute, required_value in self.cas_required_attributes.items():
  116. # If required attribute was not in CAS Response - Forbidden
  117. if required_attribute not in attributes:
  118. raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
  119. # Also need to check value
  120. if required_value is not None:
  121. actual_value = attributes[required_attribute]
  122. # If required attribute value does not match expected - Forbidden
  123. if required_value != actual_value:
  124. raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
  125. user_id = UserID.create(user, self.hs.hostname).to_string()
  126. auth_handler = self.handlers.auth_handler
  127. user_exists = yield auth_handler.does_user_exist(user_id)
  128. if user_exists:
  129. user_id, access_token, refresh_token = (
  130. yield auth_handler.login_with_cas_user_id(user_id)
  131. )
  132. result = {
  133. "user_id": user_id, # may have changed
  134. "access_token": access_token,
  135. "refresh_token": refresh_token,
  136. "home_server": self.hs.hostname,
  137. }
  138. else:
  139. user_id, access_token = (
  140. yield self.handlers.registration_handler.register(localpart=user)
  141. )
  142. result = {
  143. "user_id": user_id, # may have changed
  144. "access_token": access_token,
  145. "home_server": self.hs.hostname,
  146. }
  147. defer.returnValue((200, result))
  148. def parse_cas_response(self, cas_response_body):
  149. root = ET.fromstring(cas_response_body)
  150. if not root.tag.endswith("serviceResponse"):
  151. raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
  152. if not root[0].tag.endswith("authenticationSuccess"):
  153. raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
  154. for child in root[0]:
  155. if child.tag.endswith("user"):
  156. user = child.text
  157. if child.tag.endswith("attributes"):
  158. attributes = {}
  159. for attribute in child:
  160. # ElementTree library expands the namespace in attribute tags
  161. # to the full URL of the namespace.
  162. # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
  163. # We don't care about namespace here and it will always be encased in
  164. # curly braces, so we remove them.
  165. if "}" in attribute.tag:
  166. attributes[attribute.tag.split("}")[1]] = attribute.text
  167. else:
  168. attributes[attribute.tag] = attribute.text
  169. if user is None or attributes is None:
  170. raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
  171. return (user, attributes)
  172. class LoginFallbackRestServlet(ClientV1RestServlet):
  173. PATTERN = client_path_pattern("/login/fallback$")
  174. def on_GET(self, request):
  175. # TODO(kegan): This should be returning some HTML which is capable of
  176. # hitting LoginRestServlet
  177. return (200, {})
  178. class PasswordResetRestServlet(ClientV1RestServlet):
  179. PATTERN = client_path_pattern("/login/reset")
  180. @defer.inlineCallbacks
  181. def on_POST(self, request):
  182. reset_info = _parse_json(request)
  183. try:
  184. email = reset_info["email"]
  185. user_id = reset_info["user_id"]
  186. handler = self.handlers.login_handler
  187. yield handler.reset_password(user_id, email)
  188. # purposefully give no feedback to avoid people hammering different
  189. # combinations.
  190. defer.returnValue((200, {}))
  191. except KeyError:
  192. raise SynapseError(
  193. 400,
  194. "Missing keys. Requires 'email' and 'user_id'."
  195. )
  196. class SAML2RestServlet(ClientV1RestServlet):
  197. PATTERN = client_path_pattern("/login/saml2")
  198. def __init__(self, hs):
  199. super(SAML2RestServlet, self).__init__(hs)
  200. self.sp_config = hs.config.saml2_config_path
  201. @defer.inlineCallbacks
  202. def on_POST(self, request):
  203. saml2_auth = None
  204. try:
  205. conf = config.SPConfig()
  206. conf.load_file(self.sp_config)
  207. SP = Saml2Client(conf)
  208. saml2_auth = SP.parse_authn_request_response(
  209. request.args['SAMLResponse'][0], BINDING_HTTP_POST)
  210. except Exception, e: # Not authenticated
  211. logger.exception(e)
  212. if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
  213. username = saml2_auth.name_id.text
  214. handler = self.handlers.registration_handler
  215. (user_id, token) = yield handler.register_saml2(username)
  216. # Forward to the RelayState callback along with ava
  217. if 'RelayState' in request.args:
  218. request.redirect(urllib.unquote(
  219. request.args['RelayState'][0]) +
  220. '?status=authenticated&access_token=' +
  221. token + '&user_id=' + user_id + '&ava=' +
  222. urllib.quote(json.dumps(saml2_auth.ava)))
  223. request.finish()
  224. defer.returnValue(None)
  225. defer.returnValue((200, {"status": "authenticated",
  226. "user_id": user_id, "token": token,
  227. "ava": saml2_auth.ava}))
  228. elif 'RelayState' in request.args:
  229. request.redirect(urllib.unquote(
  230. request.args['RelayState'][0]) +
  231. '?status=not_authenticated')
  232. request.finish()
  233. defer.returnValue(None)
  234. defer.returnValue((200, {"status": "not_authenticated"}))
  235. class CasRestServlet(ClientV1RestServlet):
  236. PATTERN = client_path_pattern("/login/cas")
  237. def __init__(self, hs):
  238. super(CasRestServlet, self).__init__(hs)
  239. self.cas_server_url = hs.config.cas_server_url
  240. def on_GET(self, request):
  241. return (200, {"serverUrl": self.cas_server_url})
  242. def _parse_json(request):
  243. try:
  244. content = json.loads(request.content.read())
  245. if type(content) != dict:
  246. raise SynapseError(400, "Content must be a JSON object.")
  247. return content
  248. except ValueError:
  249. raise SynapseError(400, "Content not JSON.")
  250. def register_servlets(hs, http_server):
  251. LoginRestServlet(hs).register(http_server)
  252. if hs.config.saml2_enabled:
  253. SAML2RestServlet(hs).register(http_server)
  254. if hs.config.cas_enabled:
  255. CasRestServlet(hs).register(http_server)
  256. # TODO PasswordResetRestServlet(hs).register(http_server)