login.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
  17. from synapse.api.errors import Codes, LoginError, SynapseError
  18. from synapse.api.ratelimiting import Ratelimiter
  19. from synapse.appservice import ApplicationService
  20. from synapse.handlers.sso import SsoIdentityProvider
  21. from synapse.http import get_request_uri
  22. from synapse.http.server import HttpServer, finish_request
  23. from synapse.http.servlet import (
  24. RestServlet,
  25. parse_json_object_from_request,
  26. parse_string,
  27. )
  28. from synapse.http.site import SynapseRequest
  29. from synapse.rest.client.v2_alpha._base import client_patterns
  30. from synapse.rest.well_known import WellKnownBuilder
  31. from synapse.types import JsonDict, UserID
  32. if TYPE_CHECKING:
  33. from synapse.server import HomeServer
  34. logger = logging.getLogger(__name__)
  35. class LoginRestServlet(RestServlet):
  36. PATTERNS = client_patterns("/login$", v1=True)
  37. CAS_TYPE = "m.login.cas"
  38. SSO_TYPE = "m.login.sso"
  39. TOKEN_TYPE = "m.login.token"
  40. JWT_TYPE = "org.matrix.login.jwt"
  41. JWT_TYPE_DEPRECATED = "m.login.jwt"
  42. APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
  43. def __init__(self, hs: "HomeServer"):
  44. super().__init__()
  45. self.hs = hs
  46. # JWT configuration variables.
  47. self.jwt_enabled = hs.config.jwt_enabled
  48. self.jwt_secret = hs.config.jwt_secret
  49. self.jwt_algorithm = hs.config.jwt_algorithm
  50. self.jwt_issuer = hs.config.jwt_issuer
  51. self.jwt_audiences = hs.config.jwt_audiences
  52. # SSO configuration.
  53. self.saml2_enabled = hs.config.saml2_enabled
  54. self.cas_enabled = hs.config.cas_enabled
  55. self.oidc_enabled = hs.config.oidc_enabled
  56. self._msc2858_enabled = hs.config.experimental.msc2858_enabled
  57. self.auth = hs.get_auth()
  58. self.auth_handler = self.hs.get_auth_handler()
  59. self.registration_handler = hs.get_registration_handler()
  60. self._sso_handler = hs.get_sso_handler()
  61. self._well_known_builder = WellKnownBuilder(hs)
  62. self._address_ratelimiter = Ratelimiter(
  63. clock=hs.get_clock(),
  64. rate_hz=self.hs.config.rc_login_address.per_second,
  65. burst_count=self.hs.config.rc_login_address.burst_count,
  66. )
  67. self._account_ratelimiter = Ratelimiter(
  68. clock=hs.get_clock(),
  69. rate_hz=self.hs.config.rc_login_account.per_second,
  70. burst_count=self.hs.config.rc_login_account.burst_count,
  71. )
  72. def on_GET(self, request: SynapseRequest):
  73. flows = []
  74. if self.jwt_enabled:
  75. flows.append({"type": LoginRestServlet.JWT_TYPE})
  76. flows.append({"type": LoginRestServlet.JWT_TYPE_DEPRECATED})
  77. if self.cas_enabled:
  78. # we advertise CAS for backwards compat, though MSC1721 renamed it
  79. # to SSO.
  80. flows.append({"type": LoginRestServlet.CAS_TYPE})
  81. if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
  82. sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
  83. if self._msc2858_enabled:
  84. sso_flow["org.matrix.msc2858.identity_providers"] = [
  85. _get_auth_flow_dict_for_idp(idp)
  86. for idp in self._sso_handler.get_identity_providers().values()
  87. ]
  88. flows.append(sso_flow)
  89. # While it's valid for us to advertise this login type generally,
  90. # synapse currently only gives out these tokens as part of the
  91. # SSO login flow.
  92. # Generally we don't want to advertise login flows that clients
  93. # don't know how to implement, since they (currently) will always
  94. # fall back to the fallback API if they don't understand one of the
  95. # login flow types returned.
  96. flows.append({"type": LoginRestServlet.TOKEN_TYPE})
  97. flows.extend(
  98. ({"type": t} for t in self.auth_handler.get_supported_login_types())
  99. )
  100. flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
  101. return 200, {"flows": flows}
  102. async def on_POST(self, request: SynapseRequest):
  103. login_submission = parse_json_object_from_request(request)
  104. try:
  105. if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
  106. appservice = self.auth.get_appservice_by_req(request)
  107. if appservice.is_rate_limited():
  108. self._address_ratelimiter.ratelimit(request.getClientIP())
  109. result = await self._do_appservice_login(login_submission, appservice)
  110. elif self.jwt_enabled and (
  111. login_submission["type"] == LoginRestServlet.JWT_TYPE
  112. or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
  113. ):
  114. self._address_ratelimiter.ratelimit(request.getClientIP())
  115. result = await self._do_jwt_login(login_submission)
  116. elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
  117. self._address_ratelimiter.ratelimit(request.getClientIP())
  118. result = await self._do_token_login(login_submission)
  119. else:
  120. self._address_ratelimiter.ratelimit(request.getClientIP())
  121. result = await self._do_other_login(login_submission)
  122. except KeyError:
  123. raise SynapseError(400, "Missing JSON keys.")
  124. well_known_data = self._well_known_builder.get_well_known()
  125. if well_known_data:
  126. result["well_known"] = well_known_data
  127. return 200, result
  128. async def _do_appservice_login(
  129. self, login_submission: JsonDict, appservice: ApplicationService
  130. ):
  131. identifier = login_submission.get("identifier")
  132. logger.info("Got appservice login request with identifier: %r", identifier)
  133. if not isinstance(identifier, dict):
  134. raise SynapseError(
  135. 400, "Invalid identifier in login submission", Codes.INVALID_PARAM
  136. )
  137. # this login flow only supports identifiers of type "m.id.user".
  138. if identifier.get("type") != "m.id.user":
  139. raise SynapseError(
  140. 400, "Unknown login identifier type", Codes.INVALID_PARAM
  141. )
  142. user = identifier.get("user")
  143. if not isinstance(user, str):
  144. raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM)
  145. if user.startswith("@"):
  146. qualified_user_id = user
  147. else:
  148. qualified_user_id = UserID(user, self.hs.hostname).to_string()
  149. if not appservice.is_interested_in_user(qualified_user_id):
  150. raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
  151. return await self._complete_login(
  152. qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
  153. )
  154. async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
  155. """Handle non-token/saml/jwt logins
  156. Args:
  157. login_submission:
  158. Returns:
  159. HTTP response
  160. """
  161. # Log the request we got, but only certain fields to minimise the chance of
  162. # logging someone's password (even if they accidentally put it in the wrong
  163. # field)
  164. logger.info(
  165. "Got login request with identifier: %r, medium: %r, address: %r, user: %r",
  166. login_submission.get("identifier"),
  167. login_submission.get("medium"),
  168. login_submission.get("address"),
  169. login_submission.get("user"),
  170. )
  171. canonical_user_id, callback = await self.auth_handler.validate_login(
  172. login_submission, ratelimit=True
  173. )
  174. result = await self._complete_login(
  175. canonical_user_id, login_submission, callback
  176. )
  177. return result
  178. async def _complete_login(
  179. self,
  180. user_id: str,
  181. login_submission: JsonDict,
  182. callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
  183. create_non_existent_users: bool = False,
  184. ratelimit: bool = True,
  185. auth_provider_id: Optional[str] = None,
  186. ) -> Dict[str, str]:
  187. """Called when we've successfully authed the user and now need to
  188. actually login them in (e.g. create devices). This gets called on
  189. all successful logins.
  190. Applies the ratelimiting for successful login attempts against an
  191. account.
  192. Args:
  193. user_id: ID of the user to register.
  194. login_submission: Dictionary of login information.
  195. callback: Callback function to run after login.
  196. create_non_existent_users: Whether to create the user if they don't
  197. exist. Defaults to False.
  198. ratelimit: Whether to ratelimit the login request.
  199. auth_provider_id: The SSO IdP the user used, if any (just used for the
  200. prometheus metrics).
  201. Returns:
  202. result: Dictionary of account information after successful login.
  203. """
  204. # Before we actually log them in we check if they've already logged in
  205. # too often. This happens here rather than before as we don't
  206. # necessarily know the user before now.
  207. if ratelimit:
  208. self._account_ratelimiter.ratelimit(user_id.lower())
  209. if create_non_existent_users:
  210. canonical_uid = await self.auth_handler.check_user_exists(user_id)
  211. if not canonical_uid:
  212. canonical_uid = await self.registration_handler.register_user(
  213. localpart=UserID.from_string(user_id).localpart
  214. )
  215. user_id = canonical_uid
  216. device_id = login_submission.get("device_id")
  217. initial_display_name = login_submission.get("initial_device_display_name")
  218. device_id, access_token = await self.registration_handler.register_device(
  219. user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
  220. )
  221. result = {
  222. "user_id": user_id,
  223. "access_token": access_token,
  224. "home_server": self.hs.hostname,
  225. "device_id": device_id,
  226. }
  227. if callback is not None:
  228. await callback(result)
  229. return result
  230. async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
  231. """
  232. Handle the final stage of SSO login.
  233. Args:
  234. login_submission: The JSON request body.
  235. Returns:
  236. The body of the JSON response.
  237. """
  238. token = login_submission["token"]
  239. auth_handler = self.auth_handler
  240. res = await auth_handler.validate_short_term_login_token(token)
  241. return await self._complete_login(
  242. res.user_id,
  243. login_submission,
  244. self.auth_handler._sso_login_callback,
  245. auth_provider_id=res.auth_provider_id,
  246. )
  247. async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
  248. token = login_submission.get("token", None)
  249. if token is None:
  250. raise LoginError(
  251. 403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
  252. )
  253. import jwt
  254. try:
  255. payload = jwt.decode(
  256. token,
  257. self.jwt_secret,
  258. algorithms=[self.jwt_algorithm],
  259. issuer=self.jwt_issuer,
  260. audience=self.jwt_audiences,
  261. )
  262. except jwt.PyJWTError as e:
  263. # A JWT error occurred, return some info back to the client.
  264. raise LoginError(
  265. 403,
  266. "JWT validation failed: %s" % (str(e),),
  267. errcode=Codes.FORBIDDEN,
  268. )
  269. user = payload.get("sub", None)
  270. if user is None:
  271. raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
  272. user_id = UserID(user, self.hs.hostname).to_string()
  273. result = await self._complete_login(
  274. user_id, login_submission, create_non_existent_users=True
  275. )
  276. return result
  277. def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
  278. """Return an entry for the login flow dict
  279. Returns an entry suitable for inclusion in "identity_providers" in the
  280. response to GET /_matrix/client/r0/login
  281. """
  282. e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
  283. if idp.idp_icon:
  284. e["icon"] = idp.idp_icon
  285. if idp.idp_brand:
  286. e["brand"] = idp.idp_brand
  287. return e
  288. class SsoRedirectServlet(RestServlet):
  289. PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
  290. def __init__(self, hs: "HomeServer"):
  291. # make sure that the relevant handlers are instantiated, so that they
  292. # register themselves with the main SSOHandler.
  293. if hs.config.cas_enabled:
  294. hs.get_cas_handler()
  295. if hs.config.saml2_enabled:
  296. hs.get_saml_handler()
  297. if hs.config.oidc_enabled:
  298. hs.get_oidc_handler()
  299. self._sso_handler = hs.get_sso_handler()
  300. self._msc2858_enabled = hs.config.experimental.msc2858_enabled
  301. self._public_baseurl = hs.config.public_baseurl
  302. def register(self, http_server: HttpServer) -> None:
  303. super().register(http_server)
  304. if self._msc2858_enabled:
  305. # expose additional endpoint for MSC2858 support
  306. http_server.register_paths(
  307. "GET",
  308. client_patterns(
  309. "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
  310. releases=(),
  311. unstable=True,
  312. ),
  313. self.on_GET,
  314. self.__class__.__name__,
  315. )
  316. async def on_GET(
  317. self, request: SynapseRequest, idp_id: Optional[str] = None
  318. ) -> None:
  319. if not self._public_baseurl:
  320. raise SynapseError(400, "SSO requires a valid public_baseurl")
  321. # if this isn't the expected hostname, redirect to the right one, so that we
  322. # get our cookies back.
  323. requested_uri = get_request_uri(request)
  324. baseurl_bytes = self._public_baseurl.encode("utf-8")
  325. if not requested_uri.startswith(baseurl_bytes):
  326. # swap out the incorrect base URL for the right one.
  327. #
  328. # The idea here is to redirect from
  329. # https://foo.bar/whatever/_matrix/...
  330. # to
  331. # https://public.baseurl/_matrix/...
  332. #
  333. i = requested_uri.index(b"/_matrix")
  334. new_uri = baseurl_bytes[:-1] + requested_uri[i:]
  335. logger.info(
  336. "Requested URI %s is not canonical: redirecting to %s",
  337. requested_uri.decode("utf-8", errors="replace"),
  338. new_uri.decode("utf-8", errors="replace"),
  339. )
  340. request.redirect(new_uri)
  341. finish_request(request)
  342. return
  343. client_redirect_url = parse_string(
  344. request, "redirectUrl", required=True, encoding=None
  345. )
  346. sso_url = await self._sso_handler.handle_redirect_request(
  347. request,
  348. client_redirect_url,
  349. idp_id,
  350. )
  351. logger.info("Redirecting to %s", sso_url)
  352. request.redirect(sso_url)
  353. finish_request(request)
  354. class CasTicketServlet(RestServlet):
  355. PATTERNS = client_patterns("/login/cas/ticket", v1=True)
  356. def __init__(self, hs):
  357. super().__init__()
  358. self._cas_handler = hs.get_cas_handler()
  359. async def on_GET(self, request: SynapseRequest) -> None:
  360. client_redirect_url = parse_string(request, "redirectUrl")
  361. ticket = parse_string(request, "ticket", required=True)
  362. # Maybe get a session ID (if this ticket is from user interactive
  363. # authentication).
  364. session = parse_string(request, "session")
  365. # Either client_redirect_url or session must be provided.
  366. if not client_redirect_url and not session:
  367. message = "Missing string query parameter redirectUrl or session"
  368. raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
  369. await self._cas_handler.handle_ticket(
  370. request, ticket, client_redirect_url, session
  371. )
  372. def register_servlets(hs, http_server):
  373. LoginRestServlet(hs).register(http_server)
  374. SsoRedirectServlet(hs).register(http_server)
  375. if hs.config.cas_enabled:
  376. CasTicketServlet(hs).register(http_server)