oidc_handler.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2020 Quentin Gliech
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import inspect
  16. import logging
  17. from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
  18. from urllib.parse import urlencode
  19. import attr
  20. import pymacaroons
  21. from authlib.common.security import generate_token
  22. from authlib.jose import JsonWebToken
  23. from authlib.oauth2.auth import ClientAuth
  24. from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
  25. from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
  26. from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
  27. from jinja2 import Environment, Template
  28. from pymacaroons.exceptions import (
  29. MacaroonDeserializationException,
  30. MacaroonInvalidSignatureException,
  31. )
  32. from typing_extensions import TypedDict
  33. from twisted.web.client import readBody
  34. from synapse.config import ConfigError
  35. from synapse.config.oidc_config import OidcProviderConfig
  36. from synapse.handlers.sso import MappingException, UserAttributes
  37. from synapse.http.site import SynapseRequest
  38. from synapse.logging.context import make_deferred_yieldable
  39. from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
  40. from synapse.util import json_decoder
  41. if TYPE_CHECKING:
  42. from synapse.server import HomeServer
  43. logger = logging.getLogger(__name__)
  44. SESSION_COOKIE_NAME = b"oidc_session"
  45. #: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
  46. #: OpenID.Core sec 3.1.3.3.
  47. Token = TypedDict(
  48. "Token",
  49. {
  50. "access_token": str,
  51. "token_type": str,
  52. "id_token": Optional[str],
  53. "refresh_token": Optional[str],
  54. "expires_in": int,
  55. "scope": Optional[str],
  56. },
  57. )
  58. #: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
  59. #: there is no real point of doing this in our case.
  60. JWK = Dict[str, str]
  61. #: A JWK Set, as per RFC7517 sec 5.
  62. JWKS = TypedDict("JWKS", {"keys": List[JWK]})
  63. class OidcHandler:
  64. """Handles requests related to the OpenID Connect login flow.
  65. """
  66. def __init__(self, hs: "HomeServer"):
  67. self._sso_handler = hs.get_sso_handler()
  68. provider_conf = hs.config.oidc.oidc_provider
  69. # we should not have been instantiated if there is no configured provider.
  70. assert provider_conf is not None
  71. self._token_generator = OidcSessionTokenGenerator(hs)
  72. self._provider = OidcProvider(hs, self._token_generator, provider_conf)
  73. async def load_metadata(self) -> None:
  74. """Validate the config and load the metadata from the remote endpoint.
  75. Called at startup to ensure we have everything we need.
  76. """
  77. await self._provider.load_metadata()
  78. await self._provider.load_jwks()
  79. async def handle_oidc_callback(self, request: SynapseRequest) -> None:
  80. """Handle an incoming request to /_synapse/oidc/callback
  81. Since we might want to display OIDC-related errors in a user-friendly
  82. way, we don't raise SynapseError from here. Instead, we call
  83. ``self._sso_handler.render_error`` which displays an HTML page for the error.
  84. Most of the OpenID Connect logic happens here:
  85. - first, we check if there was any error returned by the provider and
  86. display it
  87. - then we fetch the session cookie, decode and verify it
  88. - the ``state`` query parameter should match with the one stored in the
  89. session cookie
  90. Once we know the session is legit, we then delegate to the OIDC Provider
  91. implementation, which will exchange the code with the provider and complete the
  92. login/authentication.
  93. Args:
  94. request: the incoming request from the browser.
  95. """
  96. # The provider might redirect with an error.
  97. # In that case, just display it as-is.
  98. if b"error" in request.args:
  99. # error response from the auth server. see:
  100. # https://tools.ietf.org/html/rfc6749#section-4.1.2.1
  101. # https://openid.net/specs/openid-connect-core-1_0.html#AuthError
  102. error = request.args[b"error"][0].decode()
  103. description = request.args.get(b"error_description", [b""])[0].decode()
  104. # Most of the errors returned by the provider could be due by
  105. # either the provider misbehaving or Synapse being misconfigured.
  106. # The only exception of that is "access_denied", where the user
  107. # probably cancelled the login flow. In other cases, log those errors.
  108. if error != "access_denied":
  109. logger.error("Error from the OIDC provider: %s %s", error, description)
  110. self._sso_handler.render_error(request, error, description)
  111. return
  112. # otherwise, it is presumably a successful response. see:
  113. # https://tools.ietf.org/html/rfc6749#section-4.1.2
  114. # Fetch the session cookie
  115. session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
  116. if session is None:
  117. logger.info("No session cookie found")
  118. self._sso_handler.render_error(
  119. request, "missing_session", "No session cookie found"
  120. )
  121. return
  122. # Remove the cookie. There is a good chance that if the callback failed
  123. # once, it will fail next time and the code will already be exchanged.
  124. # Removing it early avoids spamming the provider with token requests.
  125. request.addCookie(
  126. SESSION_COOKIE_NAME,
  127. b"",
  128. path="/_synapse/oidc",
  129. expires="Thu, Jan 01 1970 00:00:00 UTC",
  130. httpOnly=True,
  131. sameSite="lax",
  132. )
  133. # Check for the state query parameter
  134. if b"state" not in request.args:
  135. logger.info("State parameter is missing")
  136. self._sso_handler.render_error(
  137. request, "invalid_request", "State parameter is missing"
  138. )
  139. return
  140. state = request.args[b"state"][0].decode()
  141. # Deserialize the session token and verify it.
  142. try:
  143. session_data = self._token_generator.verify_oidc_session_token(
  144. session, state
  145. )
  146. except MacaroonDeserializationException as e:
  147. logger.exception("Invalid session")
  148. self._sso_handler.render_error(request, "invalid_session", str(e))
  149. return
  150. except MacaroonInvalidSignatureException as e:
  151. logger.exception("Could not verify session")
  152. self._sso_handler.render_error(request, "mismatching_session", str(e))
  153. return
  154. if b"code" not in request.args:
  155. logger.info("Code parameter is missing")
  156. self._sso_handler.render_error(
  157. request, "invalid_request", "Code parameter is missing"
  158. )
  159. return
  160. code = request.args[b"code"][0].decode()
  161. await self._provider.handle_oidc_callback(request, session_data, code)
  162. class OidcError(Exception):
  163. """Used to catch errors when calling the token_endpoint
  164. """
  165. def __init__(self, error, error_description=None):
  166. self.error = error
  167. self.error_description = error_description
  168. def __str__(self):
  169. if self.error_description:
  170. return "{}: {}".format(self.error, self.error_description)
  171. return self.error
  172. class OidcProvider:
  173. """Wraps the config for a single OIDC IdentityProvider
  174. Provides methods for handling redirect requests and callbacks via that particular
  175. IdP.
  176. """
  177. def __init__(
  178. self,
  179. hs: "HomeServer",
  180. token_generator: "OidcSessionTokenGenerator",
  181. provider: OidcProviderConfig,
  182. ):
  183. self._store = hs.get_datastore()
  184. self._token_generator = token_generator
  185. self._callback_url = hs.config.oidc_callback_url # type: str
  186. self._scopes = provider.scopes
  187. self._user_profile_method = provider.user_profile_method
  188. self._client_auth = ClientAuth(
  189. provider.client_id, provider.client_secret, provider.client_auth_method,
  190. ) # type: ClientAuth
  191. self._client_auth_method = provider.client_auth_method
  192. self._provider_metadata = OpenIDProviderMetadata(
  193. issuer=provider.issuer,
  194. authorization_endpoint=provider.authorization_endpoint,
  195. token_endpoint=provider.token_endpoint,
  196. userinfo_endpoint=provider.userinfo_endpoint,
  197. jwks_uri=provider.jwks_uri,
  198. ) # type: OpenIDProviderMetadata
  199. self._provider_needs_discovery = provider.discover
  200. self._user_mapping_provider = provider.user_mapping_provider_class(
  201. provider.user_mapping_provider_config
  202. )
  203. self._skip_verification = provider.skip_verification
  204. self._allow_existing_users = provider.allow_existing_users
  205. self._http_client = hs.get_proxied_http_client()
  206. self._server_name = hs.config.server_name # type: str
  207. # identifier for the external_ids table
  208. self.idp_id = "oidc"
  209. # user-facing name of this auth provider
  210. self.idp_name = "OIDC"
  211. self._sso_handler = hs.get_sso_handler()
  212. self._sso_handler.register_identity_provider(self)
  213. def _validate_metadata(self):
  214. """Verifies the provider metadata.
  215. This checks the validity of the currently loaded provider. Not
  216. everything is checked, only:
  217. - ``issuer``
  218. - ``authorization_endpoint``
  219. - ``token_endpoint``
  220. - ``response_types_supported`` (checks if "code" is in it)
  221. - ``jwks_uri``
  222. Raises:
  223. ValueError: if something in the provider is not valid
  224. """
  225. # Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin)
  226. if self._skip_verification is True:
  227. return
  228. m = self._provider_metadata
  229. m.validate_issuer()
  230. m.validate_authorization_endpoint()
  231. m.validate_token_endpoint()
  232. if m.get("token_endpoint_auth_methods_supported") is not None:
  233. m.validate_token_endpoint_auth_methods_supported()
  234. if (
  235. self._client_auth_method
  236. not in m["token_endpoint_auth_methods_supported"]
  237. ):
  238. raise ValueError(
  239. '"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format(
  240. auth_method=self._client_auth_method,
  241. supported=m["token_endpoint_auth_methods_supported"],
  242. )
  243. )
  244. if m.get("response_types_supported") is not None:
  245. m.validate_response_types_supported()
  246. if "code" not in m["response_types_supported"]:
  247. raise ValueError(
  248. '"code" not in "response_types_supported" (%r)'
  249. % (m["response_types_supported"],)
  250. )
  251. # Ensure there's a userinfo endpoint to fetch from if it is required.
  252. if self._uses_userinfo:
  253. if m.get("userinfo_endpoint") is None:
  254. raise ValueError(
  255. 'provider has no "userinfo_endpoint", even though it is required'
  256. )
  257. else:
  258. # If we're not using userinfo, we need a valid jwks to validate the ID token
  259. if m.get("jwks") is None:
  260. if m.get("jwks_uri") is not None:
  261. m.validate_jwks_uri()
  262. else:
  263. raise ValueError('"jwks_uri" must be set')
  264. @property
  265. def _uses_userinfo(self) -> bool:
  266. """Returns True if the ``userinfo_endpoint`` should be used.
  267. This is based on the requested scopes: if the scopes include
  268. ``openid``, the provider should give use an ID token containing the
  269. user information. If not, we should fetch them using the
  270. ``access_token`` with the ``userinfo_endpoint``.
  271. """
  272. return (
  273. "openid" not in self._scopes
  274. or self._user_profile_method == "userinfo_endpoint"
  275. )
  276. async def load_metadata(self) -> OpenIDProviderMetadata:
  277. """Load and validate the provider metadata.
  278. The values metadatas are discovered if ``oidc_config.discovery`` is
  279. ``True`` and then cached.
  280. Raises:
  281. ValueError: if something in the provider is not valid
  282. Returns:
  283. The provider's metadata.
  284. """
  285. # If we are using the OpenID Discovery documents, it needs to be loaded once
  286. # FIXME: should there be a lock here?
  287. if self._provider_needs_discovery:
  288. url = get_well_known_url(self._provider_metadata["issuer"], external=True)
  289. metadata_response = await self._http_client.get_json(url)
  290. # TODO: maybe update the other way around to let user override some values?
  291. self._provider_metadata.update(metadata_response)
  292. self._provider_needs_discovery = False
  293. self._validate_metadata()
  294. return self._provider_metadata
  295. async def load_jwks(self, force: bool = False) -> JWKS:
  296. """Load the JSON Web Key Set used to sign ID tokens.
  297. If we're not using the ``userinfo_endpoint``, user infos are extracted
  298. from the ID token, which is a JWT signed by keys given by the provider.
  299. The keys are then cached.
  300. Args:
  301. force: Force reloading the keys.
  302. Returns:
  303. The key set
  304. Looks like this::
  305. {
  306. 'keys': [
  307. {
  308. 'kid': 'abcdef',
  309. 'kty': 'RSA',
  310. 'alg': 'RS256',
  311. 'use': 'sig',
  312. 'e': 'XXXX',
  313. 'n': 'XXXX',
  314. }
  315. ]
  316. }
  317. """
  318. if self._uses_userinfo:
  319. # We're not using jwt signing, return an empty jwk set
  320. return {"keys": []}
  321. # First check if the JWKS are loaded in the provider metadata.
  322. # It can happen either if the provider gives its JWKS in the discovery
  323. # document directly or if it was already loaded once.
  324. metadata = await self.load_metadata()
  325. jwk_set = metadata.get("jwks")
  326. if jwk_set is not None and not force:
  327. return jwk_set
  328. # Loading the JWKS using the `jwks_uri` metadata
  329. uri = metadata.get("jwks_uri")
  330. if not uri:
  331. raise RuntimeError('Missing "jwks_uri" in metadata')
  332. jwk_set = await self._http_client.get_json(uri)
  333. # Caching the JWKS in the provider's metadata
  334. self._provider_metadata["jwks"] = jwk_set
  335. return jwk_set
  336. async def _exchange_code(self, code: str) -> Token:
  337. """Exchange an authorization code for a token.
  338. This calls the ``token_endpoint`` with the authorization code we
  339. received in the callback to exchange it for a token. The call uses the
  340. ``ClientAuth`` to authenticate with the client with its ID and secret.
  341. See:
  342. https://tools.ietf.org/html/rfc6749#section-3.2
  343. https://openid.net/specs/openid-connect-core-1_0.html#TokenEndpoint
  344. Args:
  345. code: The authorization code we got from the callback.
  346. Returns:
  347. A dict containing various tokens.
  348. May look like this::
  349. {
  350. 'token_type': 'bearer',
  351. 'access_token': 'abcdef',
  352. 'expires_in': 3599,
  353. 'id_token': 'ghijkl',
  354. 'refresh_token': 'mnopqr',
  355. }
  356. Raises:
  357. OidcError: when the ``token_endpoint`` returned an error.
  358. """
  359. metadata = await self.load_metadata()
  360. token_endpoint = metadata.get("token_endpoint")
  361. headers = {
  362. "Content-Type": "application/x-www-form-urlencoded",
  363. "User-Agent": self._http_client.user_agent,
  364. "Accept": "application/json",
  365. }
  366. args = {
  367. "grant_type": "authorization_code",
  368. "code": code,
  369. "redirect_uri": self._callback_url,
  370. }
  371. body = urlencode(args, True)
  372. # Fill the body/headers with credentials
  373. uri, headers, body = self._client_auth.prepare(
  374. method="POST", uri=token_endpoint, headers=headers, body=body
  375. )
  376. headers = {k: [v] for (k, v) in headers.items()}
  377. # Do the actual request
  378. # We're not using the SimpleHttpClient util methods as we don't want to
  379. # check the HTTP status code and we do the body encoding ourself.
  380. response = await self._http_client.request(
  381. method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
  382. )
  383. # This is used in multiple error messages below
  384. status = "{code} {phrase}".format(
  385. code=response.code, phrase=response.phrase.decode("utf-8")
  386. )
  387. resp_body = await make_deferred_yieldable(readBody(response))
  388. if response.code >= 500:
  389. # In case of a server error, we should first try to decode the body
  390. # and check for an error field. If not, we respond with a generic
  391. # error message.
  392. try:
  393. resp = json_decoder.decode(resp_body.decode("utf-8"))
  394. error = resp["error"]
  395. description = resp.get("error_description", error)
  396. except (ValueError, KeyError):
  397. # Catch ValueError for the JSON decoding and KeyError for the "error" field
  398. error = "server_error"
  399. description = (
  400. (
  401. 'Authorization server responded with a "{status}" error '
  402. "while exchanging the authorization code."
  403. ).format(status=status),
  404. )
  405. raise OidcError(error, description)
  406. # Since it is a not a 5xx code, body should be a valid JSON. It will
  407. # raise if not.
  408. resp = json_decoder.decode(resp_body.decode("utf-8"))
  409. if "error" in resp:
  410. error = resp["error"]
  411. # In case the authorization server responded with an error field,
  412. # it should be a 4xx code. If not, warn about it but don't do
  413. # anything special and report the original error message.
  414. if response.code < 400:
  415. logger.debug(
  416. "Invalid response from the authorization server: "
  417. 'responded with a "{status}" '
  418. "but body has an error field: {error!r}".format(
  419. status=status, error=resp["error"]
  420. )
  421. )
  422. description = resp.get("error_description", error)
  423. raise OidcError(error, description)
  424. # Now, this should not be an error. According to RFC6749 sec 5.1, it
  425. # should be a 200 code. We're a bit more flexible than that, and will
  426. # only throw on a 4xx code.
  427. if response.code >= 400:
  428. description = (
  429. 'Authorization server responded with a "{status}" error '
  430. 'but did not include an "error" field in its response.'.format(
  431. status=status
  432. )
  433. )
  434. logger.warning(description)
  435. # Body was still valid JSON. Might be useful to log it for debugging.
  436. logger.warning("Code exchange response: {resp!r}".format(resp=resp))
  437. raise OidcError("server_error", description)
  438. return resp
  439. async def _fetch_userinfo(self, token: Token) -> UserInfo:
  440. """Fetch user information from the ``userinfo_endpoint``.
  441. Args:
  442. token: the token given by the ``token_endpoint``.
  443. Must include an ``access_token`` field.
  444. Returns:
  445. UserInfo: an object representing the user.
  446. """
  447. metadata = await self.load_metadata()
  448. resp = await self._http_client.get_json(
  449. metadata["userinfo_endpoint"],
  450. headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
  451. )
  452. return UserInfo(resp)
  453. async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
  454. """Return an instance of UserInfo from token's ``id_token``.
  455. Args:
  456. token: the token given by the ``token_endpoint``.
  457. Must include an ``id_token`` field.
  458. nonce: the nonce value originally sent in the initial authorization
  459. request. This value should match the one inside the token.
  460. Returns:
  461. An object representing the user.
  462. """
  463. metadata = await self.load_metadata()
  464. claims_params = {
  465. "nonce": nonce,
  466. "client_id": self._client_auth.client_id,
  467. }
  468. if "access_token" in token:
  469. # If we got an `access_token`, there should be an `at_hash` claim
  470. # in the `id_token` that we can check against.
  471. claims_params["access_token"] = token["access_token"]
  472. claims_cls = CodeIDToken
  473. else:
  474. claims_cls = ImplicitIDToken
  475. alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
  476. jwt = JsonWebToken(alg_values)
  477. claim_options = {"iss": {"values": [metadata["issuer"]]}}
  478. # Try to decode the keys in cache first, then retry by forcing the keys
  479. # to be reloaded
  480. jwk_set = await self.load_jwks()
  481. try:
  482. claims = jwt.decode(
  483. token["id_token"],
  484. key=jwk_set,
  485. claims_cls=claims_cls,
  486. claims_options=claim_options,
  487. claims_params=claims_params,
  488. )
  489. except ValueError:
  490. logger.info("Reloading JWKS after decode error")
  491. jwk_set = await self.load_jwks(force=True) # try reloading the jwks
  492. claims = jwt.decode(
  493. token["id_token"],
  494. key=jwk_set,
  495. claims_cls=claims_cls,
  496. claims_options=claim_options,
  497. claims_params=claims_params,
  498. )
  499. claims.validate(leeway=120) # allows 2 min of clock skew
  500. return UserInfo(claims)
  501. async def handle_redirect_request(
  502. self,
  503. request: SynapseRequest,
  504. client_redirect_url: Optional[bytes],
  505. ui_auth_session_id: Optional[str] = None,
  506. ) -> str:
  507. """Handle an incoming request to /login/sso/redirect
  508. It returns a redirect to the authorization endpoint with a few
  509. parameters:
  510. - ``client_id``: the client ID set in ``oidc_config.client_id``
  511. - ``response_type``: ``code``
  512. - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
  513. - ``scope``: the list of scopes set in ``oidc_config.scopes``
  514. - ``state``: a random string
  515. - ``nonce``: a random string
  516. In addition generating a redirect URL, we are setting a cookie with
  517. a signed macaroon token containing the state, the nonce and the
  518. client_redirect_url params. Those are then checked when the client
  519. comes back from the provider.
  520. Args:
  521. request: the incoming request from the browser.
  522. We'll respond to it with a redirect and a cookie.
  523. client_redirect_url: the URL that we should redirect the client to
  524. when everything is done (or None for UI Auth)
  525. ui_auth_session_id: The session ID of the ongoing UI Auth (or
  526. None if this is a login).
  527. Returns:
  528. The redirect URL to the authorization endpoint.
  529. """
  530. state = generate_token()
  531. nonce = generate_token()
  532. if not client_redirect_url:
  533. client_redirect_url = b""
  534. cookie = self._token_generator.generate_oidc_session_token(
  535. state=state,
  536. session_data=OidcSessionData(
  537. nonce=nonce,
  538. client_redirect_url=client_redirect_url.decode(),
  539. ui_auth_session_id=ui_auth_session_id,
  540. ),
  541. )
  542. request.addCookie(
  543. SESSION_COOKIE_NAME,
  544. cookie,
  545. path="/_synapse/oidc",
  546. max_age="3600",
  547. httpOnly=True,
  548. sameSite="lax",
  549. )
  550. metadata = await self.load_metadata()
  551. authorization_endpoint = metadata.get("authorization_endpoint")
  552. return prepare_grant_uri(
  553. authorization_endpoint,
  554. client_id=self._client_auth.client_id,
  555. response_type="code",
  556. redirect_uri=self._callback_url,
  557. scope=self._scopes,
  558. state=state,
  559. nonce=nonce,
  560. )
  561. async def handle_oidc_callback(
  562. self, request: SynapseRequest, session_data: "OidcSessionData", code: str
  563. ) -> None:
  564. """Handle an incoming request to /_synapse/oidc/callback
  565. By this time we have already validated the session on the synapse side, and
  566. now need to do the provider-specific operations. This includes:
  567. - exchange the code with the provider using the ``token_endpoint`` (see
  568. ``_exchange_code``)
  569. - once we have the token, use it to either extract the UserInfo from
  570. the ``id_token`` (``_parse_id_token``), or use the ``access_token``
  571. to fetch UserInfo from the ``userinfo_endpoint``
  572. (``_fetch_userinfo``)
  573. - map those UserInfo to a Matrix user (``_map_userinfo_to_user``) and
  574. finish the login
  575. Args:
  576. request: the incoming request from the browser.
  577. session_data: the session data, extracted from our cookie
  578. code: The authorization code we got from the callback.
  579. """
  580. # Exchange the code with the provider
  581. try:
  582. logger.debug("Exchanging code")
  583. token = await self._exchange_code(code)
  584. except OidcError as e:
  585. logger.exception("Could not exchange code")
  586. self._sso_handler.render_error(request, e.error, e.error_description)
  587. return
  588. logger.debug("Successfully obtained OAuth2 access token")
  589. # Now that we have a token, get the userinfo, either by decoding the
  590. # `id_token` or by fetching the `userinfo_endpoint`.
  591. if self._uses_userinfo:
  592. logger.debug("Fetching userinfo")
  593. try:
  594. userinfo = await self._fetch_userinfo(token)
  595. except Exception as e:
  596. logger.exception("Could not fetch userinfo")
  597. self._sso_handler.render_error(request, "fetch_error", str(e))
  598. return
  599. else:
  600. logger.debug("Extracting userinfo from id_token")
  601. try:
  602. userinfo = await self._parse_id_token(token, nonce=session_data.nonce)
  603. except Exception as e:
  604. logger.exception("Invalid id_token")
  605. self._sso_handler.render_error(request, "invalid_token", str(e))
  606. return
  607. # first check if we're doing a UIA
  608. if session_data.ui_auth_session_id:
  609. try:
  610. remote_user_id = self._remote_id_from_userinfo(userinfo)
  611. except Exception as e:
  612. logger.exception("Could not extract remote user id")
  613. self._sso_handler.render_error(request, "mapping_error", str(e))
  614. return
  615. return await self._sso_handler.complete_sso_ui_auth_request(
  616. self.idp_id, remote_user_id, session_data.ui_auth_session_id, request
  617. )
  618. # otherwise, it's a login
  619. # Call the mapper to register/login the user
  620. try:
  621. await self._complete_oidc_login(
  622. userinfo, token, request, session_data.client_redirect_url
  623. )
  624. except MappingException as e:
  625. logger.exception("Could not map user")
  626. self._sso_handler.render_error(request, "mapping_error", str(e))
  627. async def _complete_oidc_login(
  628. self,
  629. userinfo: UserInfo,
  630. token: Token,
  631. request: SynapseRequest,
  632. client_redirect_url: str,
  633. ) -> None:
  634. """Given a UserInfo response, complete the login flow
  635. UserInfo should have a claim that uniquely identifies users. This claim
  636. is usually `sub`, but can be configured with `oidc_config.subject_claim`.
  637. It is then used as an `external_id`.
  638. If we don't find the user that way, we should register the user,
  639. mapping the localpart and the display name from the UserInfo.
  640. If a user already exists with the mxid we've mapped and allow_existing_users
  641. is disabled, raise an exception.
  642. Otherwise, render a redirect back to the client_redirect_url with a loginToken.
  643. Args:
  644. userinfo: an object representing the user
  645. token: a dict with the tokens obtained from the provider
  646. request: The request to respond to
  647. client_redirect_url: The redirect URL passed in by the client.
  648. Raises:
  649. MappingException: if there was an error while mapping some properties
  650. """
  651. try:
  652. remote_user_id = self._remote_id_from_userinfo(userinfo)
  653. except Exception as e:
  654. raise MappingException(
  655. "Failed to extract subject from OIDC response: %s" % (e,)
  656. )
  657. # Older mapping providers don't accept the `failures` argument, so we
  658. # try and detect support.
  659. mapper_signature = inspect.signature(
  660. self._user_mapping_provider.map_user_attributes
  661. )
  662. supports_failures = "failures" in mapper_signature.parameters
  663. async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
  664. """
  665. Call the mapping provider to map the OIDC userinfo and token to user attributes.
  666. This is backwards compatibility for abstraction for the SSO handler.
  667. """
  668. if supports_failures:
  669. attributes = await self._user_mapping_provider.map_user_attributes(
  670. userinfo, token, failures
  671. )
  672. else:
  673. # If the mapping provider does not support processing failures,
  674. # do not continually generate the same Matrix ID since it will
  675. # continue to already be in use. Note that the error raised is
  676. # arbitrary and will get turned into a MappingException.
  677. if failures:
  678. raise MappingException(
  679. "Mapping provider does not support de-duplicating Matrix IDs"
  680. )
  681. attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
  682. userinfo, token
  683. )
  684. return UserAttributes(**attributes)
  685. async def grandfather_existing_users() -> Optional[str]:
  686. if self._allow_existing_users:
  687. # If allowing existing users we want to generate a single localpart
  688. # and attempt to match it.
  689. attributes = await oidc_response_to_user_attributes(failures=0)
  690. user_id = UserID(attributes.localpart, self._server_name).to_string()
  691. users = await self._store.get_users_by_id_case_insensitive(user_id)
  692. if users:
  693. # If an existing matrix ID is returned, then use it.
  694. if len(users) == 1:
  695. previously_registered_user_id = next(iter(users))
  696. elif user_id in users:
  697. previously_registered_user_id = user_id
  698. else:
  699. # Do not attempt to continue generating Matrix IDs.
  700. raise MappingException(
  701. "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
  702. user_id, users
  703. )
  704. )
  705. return previously_registered_user_id
  706. return None
  707. # Mapping providers might not have get_extra_attributes: only call this
  708. # method if it exists.
  709. extra_attributes = None
  710. get_extra_attributes = getattr(
  711. self._user_mapping_provider, "get_extra_attributes", None
  712. )
  713. if get_extra_attributes:
  714. extra_attributes = await get_extra_attributes(userinfo, token)
  715. await self._sso_handler.complete_sso_login_request(
  716. self.idp_id,
  717. remote_user_id,
  718. request,
  719. client_redirect_url,
  720. oidc_response_to_user_attributes,
  721. grandfather_existing_users,
  722. extra_attributes,
  723. )
  724. def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
  725. """Extract the unique remote id from an OIDC UserInfo block
  726. Args:
  727. userinfo: An object representing the user given by the OIDC provider
  728. Returns:
  729. remote user id
  730. """
  731. remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
  732. # Some OIDC providers use integer IDs, but Synapse expects external IDs
  733. # to be strings.
  734. return str(remote_user_id)
  735. class OidcSessionTokenGenerator:
  736. """Methods for generating and checking OIDC Session cookies."""
  737. def __init__(self, hs: "HomeServer"):
  738. self._clock = hs.get_clock()
  739. self._server_name = hs.hostname
  740. self._macaroon_secret_key = hs.config.key.macaroon_secret_key
  741. def generate_oidc_session_token(
  742. self,
  743. state: str,
  744. session_data: "OidcSessionData",
  745. duration_in_ms: int = (60 * 60 * 1000),
  746. ) -> str:
  747. """Generates a signed token storing data about an OIDC session.
  748. When Synapse initiates an authorization flow, it creates a random state
  749. and a random nonce. Those parameters are given to the provider and
  750. should be verified when the client comes back from the provider.
  751. It is also used to store the client_redirect_url, which is used to
  752. complete the SSO login flow.
  753. Args:
  754. state: The ``state`` parameter passed to the OIDC provider.
  755. session_data: data to include in the session token.
  756. duration_in_ms: An optional duration for the token in milliseconds.
  757. Defaults to an hour.
  758. Returns:
  759. A signed macaroon token with the session information.
  760. """
  761. macaroon = pymacaroons.Macaroon(
  762. location=self._server_name, identifier="key", key=self._macaroon_secret_key,
  763. )
  764. macaroon.add_first_party_caveat("gen = 1")
  765. macaroon.add_first_party_caveat("type = session")
  766. macaroon.add_first_party_caveat("state = %s" % (state,))
  767. macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
  768. macaroon.add_first_party_caveat(
  769. "client_redirect_url = %s" % (session_data.client_redirect_url,)
  770. )
  771. if session_data.ui_auth_session_id:
  772. macaroon.add_first_party_caveat(
  773. "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
  774. )
  775. now = self._clock.time_msec()
  776. expiry = now + duration_in_ms
  777. macaroon.add_first_party_caveat("time < %d" % (expiry,))
  778. return macaroon.serialize()
  779. def verify_oidc_session_token(
  780. self, session: bytes, state: str
  781. ) -> "OidcSessionData":
  782. """Verifies and extract an OIDC session token.
  783. This verifies that a given session token was issued by this homeserver
  784. and extract the nonce and client_redirect_url caveats.
  785. Args:
  786. session: The session token to verify
  787. state: The state the OIDC provider gave back
  788. Returns:
  789. The data extracted from the session cookie
  790. """
  791. macaroon = pymacaroons.Macaroon.deserialize(session)
  792. v = pymacaroons.Verifier()
  793. v.satisfy_exact("gen = 1")
  794. v.satisfy_exact("type = session")
  795. v.satisfy_exact("state = %s" % (state,))
  796. v.satisfy_general(lambda c: c.startswith("nonce = "))
  797. v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
  798. # Sometimes there's a UI auth session ID, it seems to be OK to attempt
  799. # to always satisfy this.
  800. v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
  801. v.satisfy_general(self._verify_expiry)
  802. v.verify(macaroon, self._macaroon_secret_key)
  803. # Extract the `nonce`, `client_redirect_url`, and maybe the
  804. # `ui_auth_session_id` from the token.
  805. nonce = self._get_value_from_macaroon(macaroon, "nonce")
  806. client_redirect_url = self._get_value_from_macaroon(
  807. macaroon, "client_redirect_url"
  808. )
  809. try:
  810. ui_auth_session_id = self._get_value_from_macaroon(
  811. macaroon, "ui_auth_session_id"
  812. ) # type: Optional[str]
  813. except ValueError:
  814. ui_auth_session_id = None
  815. return OidcSessionData(
  816. nonce=nonce,
  817. client_redirect_url=client_redirect_url,
  818. ui_auth_session_id=ui_auth_session_id,
  819. )
  820. def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
  821. """Extracts a caveat value from a macaroon token.
  822. Args:
  823. macaroon: the token
  824. key: the key of the caveat to extract
  825. Returns:
  826. The extracted value
  827. Raises:
  828. Exception: if the caveat was not in the macaroon
  829. """
  830. prefix = key + " = "
  831. for caveat in macaroon.caveats:
  832. if caveat.caveat_id.startswith(prefix):
  833. return caveat.caveat_id[len(prefix) :]
  834. raise ValueError("No %s caveat in macaroon" % (key,))
  835. def _verify_expiry(self, caveat: str) -> bool:
  836. prefix = "time < "
  837. if not caveat.startswith(prefix):
  838. return False
  839. expiry = int(caveat[len(prefix) :])
  840. now = self._clock.time_msec()
  841. return now < expiry
  842. @attr.s(frozen=True, slots=True)
  843. class OidcSessionData:
  844. """The attributes which are stored in a OIDC session cookie"""
  845. # The `nonce` parameter passed to the OIDC provider.
  846. nonce = attr.ib(type=str)
  847. # The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
  848. client_redirect_url = attr.ib(type=str)
  849. # The session ID of the ongoing UI Auth (None if this is a login)
  850. ui_auth_session_id = attr.ib(type=Optional[str], default=None)
  851. UserAttributeDict = TypedDict(
  852. "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]}
  853. )
  854. C = TypeVar("C")
  855. class OidcMappingProvider(Generic[C]):
  856. """A mapping provider maps a UserInfo object to user attributes.
  857. It should provide the API described by this class.
  858. """
  859. def __init__(self, config: C):
  860. """
  861. Args:
  862. config: A custom config object from this module, parsed by ``parse_config()``
  863. """
  864. @staticmethod
  865. def parse_config(config: dict) -> C:
  866. """Parse the dict provided by the homeserver's config
  867. Args:
  868. config: A dictionary containing configuration options for this provider
  869. Returns:
  870. A custom config object for this module
  871. """
  872. raise NotImplementedError()
  873. def get_remote_user_id(self, userinfo: UserInfo) -> str:
  874. """Get a unique user ID for this user.
  875. Usually, in an OIDC-compliant scenario, it should be the ``sub`` claim from the UserInfo object.
  876. Args:
  877. userinfo: An object representing the user given by the OIDC provider
  878. Returns:
  879. A unique user ID
  880. """
  881. raise NotImplementedError()
  882. async def map_user_attributes(
  883. self, userinfo: UserInfo, token: Token, failures: int
  884. ) -> UserAttributeDict:
  885. """Map a `UserInfo` object into user attributes.
  886. Args:
  887. userinfo: An object representing the user given by the OIDC provider
  888. token: A dict with the tokens returned by the provider
  889. failures: How many times a call to this function with this
  890. UserInfo has resulted in a failure.
  891. Returns:
  892. A dict containing the ``localpart`` and (optionally) the ``display_name``
  893. """
  894. raise NotImplementedError()
  895. async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
  896. """Map a `UserInfo` object into additional attributes passed to the client during login.
  897. Args:
  898. userinfo: An object representing the user given by the OIDC provider
  899. token: A dict with the tokens returned by the provider
  900. Returns:
  901. A dict containing additional attributes. Must be JSON serializable.
  902. """
  903. return {}
  904. # Used to clear out "None" values in templates
  905. def jinja_finalize(thing):
  906. return thing if thing is not None else ""
# Module-level Jinja environment used to compile the mapping templates. Its
# `finalize` hook is `jinja_finalize`, which swaps `None` for an empty string.
env = Environment(finalize=jinja_finalize)
  908. @attr.s
  909. class JinjaOidcMappingConfig:
  910. subject_claim = attr.ib(type=str)
  911. localpart_template = attr.ib(type=Optional[Template])
  912. display_name_template = attr.ib(type=Optional[Template])
  913. extra_attributes = attr.ib(type=Dict[str, Template])
  914. class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
  915. """An implementation of a mapping provider based on Jinja templates.
  916. This is the default mapping provider.
  917. """
  918. def __init__(self, config: JinjaOidcMappingConfig):
  919. self._config = config
  920. @staticmethod
  921. def parse_config(config: dict) -> JinjaOidcMappingConfig:
  922. subject_claim = config.get("subject_claim", "sub")
  923. localpart_template = None # type: Optional[Template]
  924. if "localpart_template" in config:
  925. try:
  926. localpart_template = env.from_string(config["localpart_template"])
  927. except Exception as e:
  928. raise ConfigError(
  929. "invalid jinja template", path=["localpart_template"]
  930. ) from e
  931. display_name_template = None # type: Optional[Template]
  932. if "display_name_template" in config:
  933. try:
  934. display_name_template = env.from_string(config["display_name_template"])
  935. except Exception as e:
  936. raise ConfigError(
  937. "invalid jinja template", path=["display_name_template"]
  938. ) from e
  939. extra_attributes = {} # type Dict[str, Template]
  940. if "extra_attributes" in config:
  941. extra_attributes_config = config.get("extra_attributes") or {}
  942. if not isinstance(extra_attributes_config, dict):
  943. raise ConfigError("must be a dict", path=["extra_attributes"])
  944. for key, value in extra_attributes_config.items():
  945. try:
  946. extra_attributes[key] = env.from_string(value)
  947. except Exception as e:
  948. raise ConfigError(
  949. "invalid jinja template", path=["extra_attributes", key]
  950. ) from e
  951. return JinjaOidcMappingConfig(
  952. subject_claim=subject_claim,
  953. localpart_template=localpart_template,
  954. display_name_template=display_name_template,
  955. extra_attributes=extra_attributes,
  956. )
  957. def get_remote_user_id(self, userinfo: UserInfo) -> str:
  958. return userinfo[self._config.subject_claim]
  959. async def map_user_attributes(
  960. self, userinfo: UserInfo, token: Token, failures: int
  961. ) -> UserAttributeDict:
  962. localpart = None
  963. if self._config.localpart_template:
  964. localpart = self._config.localpart_template.render(user=userinfo).strip()
  965. # Ensure only valid characters are included in the MXID.
  966. localpart = map_username_to_mxid_localpart(localpart)
  967. # Append suffix integer if last call to this function failed to produce
  968. # a usable mxid.
  969. localpart += str(failures) if failures else ""
  970. display_name = None # type: Optional[str]
  971. if self._config.display_name_template is not None:
  972. display_name = self._config.display_name_template.render(
  973. user=userinfo
  974. ).strip()
  975. if display_name == "":
  976. display_name = None
  977. return UserAttributeDict(localpart=localpart, display_name=display_name)
  978. async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
  979. extras = {} # type: Dict[str, str]
  980. for key, template in self._config.extra_attributes.items():
  981. try:
  982. extras[key] = template.render(user=userinfo).strip()
  983. except Exception as e:
  984. # Log an error and skip this value (don't break login for this).
  985. logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e))
  986. return extras