|
@@ -12,14 +12,28 @@
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
+import binascii
|
|
|
import inspect
|
|
|
+import json
|
|
|
import logging
|
|
|
-from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
|
|
|
+from typing import (
|
|
|
+ TYPE_CHECKING,
|
|
|
+ Any,
|
|
|
+ Dict,
|
|
|
+ Generic,
|
|
|
+ List,
|
|
|
+ Optional,
|
|
|
+ Type,
|
|
|
+ TypeVar,
|
|
|
+ Union,
|
|
|
+)
|
|
|
from urllib.parse import urlencode, urlparse
|
|
|
|
|
|
import attr
|
|
|
+import unpaddedbase64
|
|
|
from authlib.common.security import generate_token
|
|
|
-from authlib.jose import JsonWebToken, jwt
|
|
|
+from authlib.jose import JsonWebToken, JWTClaims
|
|
|
+from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
|
|
|
from authlib.oauth2.auth import ClientAuth
|
|
|
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
|
|
|
from authlib.oidc.core import CodeIDToken, UserInfo
|
|
@@ -35,9 +49,12 @@ from typing_extensions import TypedDict
|
|
|
from twisted.web.client import readBody
|
|
|
from twisted.web.http_headers import Headers
|
|
|
|
|
|
+from synapse.api.errors import SynapseError
|
|
|
from synapse.config import ConfigError
|
|
|
from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
|
|
|
from synapse.handlers.sso import MappingException, UserAttributes
|
|
|
+from synapse.http.server import finish_request
|
|
|
+from synapse.http.servlet import parse_string
|
|
|
from synapse.http.site import SynapseRequest
|
|
|
from synapse.logging.context import make_deferred_yieldable
|
|
|
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
|
|
@@ -88,6 +105,8 @@ class Token(TypedDict):
|
|
|
#: there is no real point of doing this in our case.
|
|
|
JWK = Dict[str, str]
|
|
|
|
|
|
+C = TypeVar("C")
|
|
|
+
|
|
|
|
|
|
#: A JWK Set, as per RFC7517 sec 5.
|
|
|
class JWKS(TypedDict):
|
|
@@ -247,6 +266,80 @@ class OidcHandler:
|
|
|
|
|
|
await oidc_provider.handle_oidc_callback(request, session_data, code)
|
|
|
|
|
|
+ async def handle_backchannel_logout(self, request: SynapseRequest) -> None:
|
|
|
+ """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
|
|
|
+
|
|
|
+ This extracts the logout_token from the request and tries to figure out
|
|
|
+ which OpenID Provider it is comming from. This works by matching the iss claim
|
|
|
+ with the issuer and the aud claim with the client_id.
|
|
|
+
|
|
|
+ Since at this point we don't know who signed the JWT, we can't just
|
|
|
+ decode it using authlib since it will always verifies the signature. We
|
|
|
+ have to decode it manually without validating the signature. The actual JWT
|
|
|
+ verification is done in the `OidcProvider.handler_backchannel_logout` method,
|
|
|
+ once we figured out which provider sent the request.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ request: the incoming request from the browser.
|
|
|
+ """
|
|
|
+ logout_token = parse_string(request, "logout_token")
|
|
|
+ if logout_token is None:
|
|
|
+ raise SynapseError(400, "Missing logout_token in request")
|
|
|
+
|
|
|
+ # A JWT looks like this:
|
|
|
+ # header.payload.signature
|
|
|
+ # where all parts are encoded with urlsafe base64.
|
|
|
+ # The aud and iss claims we care about are in the payload part, which
|
|
|
+ # is a JSON object.
|
|
|
+ try:
|
|
|
+ # By destructuring the list after splitting, we ensure that we have
|
|
|
+ # exactly 3 segments
|
|
|
+ _, payload, _ = logout_token.split(".")
|
|
|
+ except ValueError:
|
|
|
+ raise SynapseError(400, "Invalid logout_token in request")
|
|
|
+
|
|
|
+ try:
|
|
|
+ payload_bytes = unpaddedbase64.decode_base64(payload)
|
|
|
+ claims = json_decoder.decode(payload_bytes.decode("utf-8"))
|
|
|
+ except (json.JSONDecodeError, binascii.Error, UnicodeError):
|
|
|
+ raise SynapseError(400, "Invalid logout_token payload in request")
|
|
|
+
|
|
|
+ try:
|
|
|
+ # Let's extract the iss and aud claims
|
|
|
+ iss = claims["iss"]
|
|
|
+ aud = claims["aud"]
|
|
|
+ # The aud claim can be either a string or a list of string. Here we
|
|
|
+ # normalize it as a list of strings.
|
|
|
+ if isinstance(aud, str):
|
|
|
+ aud = [aud]
|
|
|
+
|
|
|
+ # Check that we have the right types for the aud and the iss claims
|
|
|
+ if not isinstance(iss, str) or not isinstance(aud, list):
|
|
|
+ raise TypeError()
|
|
|
+ for a in aud:
|
|
|
+ if not isinstance(a, str):
|
|
|
+ raise TypeError()
|
|
|
+
|
|
|
+ # At this point we properly checked both claims types
|
|
|
+ issuer: str = iss
|
|
|
+ audience: List[str] = aud
|
|
|
+ except (TypeError, KeyError):
|
|
|
+ raise SynapseError(400, "Invalid issuer/audience in logout_token")
|
|
|
+
|
|
|
+ # Now that we know the audience and the issuer, we can figure out from
|
|
|
+ # what provider it is coming from
|
|
|
+ oidc_provider: Optional[OidcProvider] = None
|
|
|
+ for provider in self._providers.values():
|
|
|
+ if provider.issuer == issuer and provider.client_id in audience:
|
|
|
+ oidc_provider = provider
|
|
|
+ break
|
|
|
+
|
|
|
+ if oidc_provider is None:
|
|
|
+ raise SynapseError(400, "Could not find the OP that issued this event")
|
|
|
+
|
|
|
+ # Ask the provider to handle the logout request.
|
|
|
+ await oidc_provider.handle_backchannel_logout(request, logout_token)
|
|
|
+
|
|
|
|
|
|
class OidcError(Exception):
|
|
|
"""Used to catch errors when calling the token_endpoint"""
|
|
@@ -342,6 +435,7 @@ class OidcProvider:
|
|
|
self.idp_brand = provider.idp_brand
|
|
|
|
|
|
self._sso_handler = hs.get_sso_handler()
|
|
|
+ self._device_handler = hs.get_device_handler()
|
|
|
|
|
|
self._sso_handler.register_identity_provider(self)
|
|
|
|
|
@@ -400,6 +494,41 @@ class OidcProvider:
|
|
|
# If we're not using userinfo, we need a valid jwks to validate the ID token
|
|
|
m.validate_jwks_uri()
|
|
|
|
|
|
+ if self._config.backchannel_logout_enabled:
|
|
|
+ if not m.get("backchannel_logout_supported", False):
|
|
|
+ logger.warning(
|
|
|
+ "OIDC Back-Channel Logout is enabled for issuer %r"
|
|
|
+ "but it does not advertise support for it",
|
|
|
+ self.issuer,
|
|
|
+ )
|
|
|
+
|
|
|
+ elif not m.get("backchannel_logout_session_supported", False):
|
|
|
+ logger.warning(
|
|
|
+ "OIDC Back-Channel Logout is enabled and supported "
|
|
|
+ "by issuer %r but it might not send a session ID with "
|
|
|
+ "logout tokens, which is required for the logouts to work",
|
|
|
+ self.issuer,
|
|
|
+ )
|
|
|
+
|
|
|
+ if not self._config.backchannel_logout_ignore_sub:
|
|
|
+ # If OIDC backchannel logouts are enabled, the provider mapping provider
|
|
|
+ # should use the `sub` claim. We verify that by mapping a dumb user and
|
|
|
+ # see if we get back the sub claim
|
|
|
+ user = UserInfo({"sub": "thisisasubject"})
|
|
|
+ try:
|
|
|
+ subject = self._user_mapping_provider.get_remote_user_id(user)
|
|
|
+ if subject != user["sub"]:
|
|
|
+ raise ValueError("Unexpected subject")
|
|
|
+ except Exception:
|
|
|
+ logger.warning(
|
|
|
+ f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} "
|
|
|
+ "but it looks like the configured `user_mapping_provider` "
|
|
|
+ "does not use the `sub` claim as subject. If it is the case, "
|
|
|
+ "and you want Synapse to ignore the `sub` claim in OIDC "
|
|
|
+ "Back-Channel Logouts, set `backchannel_logout_ignore_sub` "
|
|
|
+ "to `true` in the issuer config."
|
|
|
+ )
|
|
|
+
|
|
|
@property
|
|
|
def _uses_userinfo(self) -> bool:
|
|
|
"""Returns True if the ``userinfo_endpoint`` should be used.
|
|
@@ -415,6 +544,16 @@ class OidcProvider:
|
|
|
or self._user_profile_method == "userinfo_endpoint"
|
|
|
)
|
|
|
|
|
|
+ @property
|
|
|
+ def issuer(self) -> str:
|
|
|
+ """The issuer identifying this provider."""
|
|
|
+ return self._config.issuer
|
|
|
+
|
|
|
+ @property
|
|
|
+ def client_id(self) -> str:
|
|
|
+ """The client_id used when interacting with this provider."""
|
|
|
+ return self._config.client_id
|
|
|
+
|
|
|
async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
|
|
|
"""Return the provider metadata.
|
|
|
|
|
@@ -662,6 +801,59 @@ class OidcProvider:
|
|
|
|
|
|
return UserInfo(resp)
|
|
|
|
|
|
+ async def _verify_jwt(
|
|
|
+ self,
|
|
|
+ alg_values: List[str],
|
|
|
+ token: str,
|
|
|
+ claims_cls: Type[C],
|
|
|
+ claims_options: Optional[dict] = None,
|
|
|
+ claims_params: Optional[dict] = None,
|
|
|
+ ) -> C:
|
|
|
+ """Decode and validate a JWT, re-fetching the JWKS as needed.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ alg_values: list of `alg` values allowed when verifying the JWT.
|
|
|
+ token: the JWT.
|
|
|
+ claims_cls: the JWTClaims class to use to validate the claims.
|
|
|
+ claims_options: dict of options passed to the `claims_cls` constructor.
|
|
|
+ claims_params: dict of params passed to the `claims_cls` constructor.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The decoded claims in the JWT.
|
|
|
+ """
|
|
|
+ jwt = JsonWebToken(alg_values)
|
|
|
+
|
|
|
+ logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token)
|
|
|
+
|
|
|
+ # Try to decode the keys in cache first, then retry by forcing the keys
|
|
|
+ # to be reloaded
|
|
|
+ jwk_set = await self.load_jwks()
|
|
|
+ try:
|
|
|
+ claims = jwt.decode(
|
|
|
+ token,
|
|
|
+ key=jwk_set,
|
|
|
+ claims_cls=claims_cls,
|
|
|
+ claims_options=claims_options,
|
|
|
+ claims_params=claims_params,
|
|
|
+ )
|
|
|
+ except ValueError:
|
|
|
+ logger.info("Reloading JWKS after decode error")
|
|
|
+ jwk_set = await self.load_jwks(force=True) # try reloading the jwks
|
|
|
+ claims = jwt.decode(
|
|
|
+ token,
|
|
|
+ key=jwk_set,
|
|
|
+ claims_cls=claims_cls,
|
|
|
+ claims_options=claims_options,
|
|
|
+ claims_params=claims_params,
|
|
|
+ )
|
|
|
+
|
|
|
+ logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims)
|
|
|
+
|
|
|
+ claims.validate(
|
|
|
+ now=self._clock.time(), leeway=120
|
|
|
+ ) # allows 2 min of clock skew
|
|
|
+ return claims
|
|
|
+
|
|
|
async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
|
|
|
"""Return an instance of UserInfo from token's ``id_token``.
|
|
|
|
|
@@ -675,13 +867,13 @@ class OidcProvider:
|
|
|
The decoded claims in the ID token.
|
|
|
"""
|
|
|
id_token = token.get("id_token")
|
|
|
- logger.debug("Attempting to decode JWT id_token %r", id_token)
|
|
|
|
|
|
# That has been theoritically been checked by the caller, so even though
|
|
|
# assertion are not enabled in production, it is mainly here to appease mypy
|
|
|
assert id_token is not None
|
|
|
|
|
|
metadata = await self.load_metadata()
|
|
|
+
|
|
|
claims_params = {
|
|
|
"nonce": nonce,
|
|
|
"client_id": self._client_auth.client_id,
|
|
@@ -691,38 +883,17 @@ class OidcProvider:
|
|
|
# in the `id_token` that we can check against.
|
|
|
claims_params["access_token"] = token["access_token"]
|
|
|
|
|
|
- alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
|
|
|
- jwt = JsonWebToken(alg_values)
|
|
|
-
|
|
|
- claim_options = {"iss": {"values": [metadata["issuer"]]}}
|
|
|
+ claims_options = {"iss": {"values": [metadata["issuer"]]}}
|
|
|
|
|
|
- # Try to decode the keys in cache first, then retry by forcing the keys
|
|
|
- # to be reloaded
|
|
|
- jwk_set = await self.load_jwks()
|
|
|
- try:
|
|
|
- claims = jwt.decode(
|
|
|
- id_token,
|
|
|
- key=jwk_set,
|
|
|
- claims_cls=CodeIDToken,
|
|
|
- claims_options=claim_options,
|
|
|
- claims_params=claims_params,
|
|
|
- )
|
|
|
- except ValueError:
|
|
|
- logger.info("Reloading JWKS after decode error")
|
|
|
- jwk_set = await self.load_jwks(force=True) # try reloading the jwks
|
|
|
- claims = jwt.decode(
|
|
|
- id_token,
|
|
|
- key=jwk_set,
|
|
|
- claims_cls=CodeIDToken,
|
|
|
- claims_options=claim_options,
|
|
|
- claims_params=claims_params,
|
|
|
- )
|
|
|
-
|
|
|
- logger.debug("Decoded id_token JWT %r; validating", claims)
|
|
|
+ alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
|
|
|
|
|
|
- claims.validate(
|
|
|
- now=self._clock.time(), leeway=120
|
|
|
- ) # allows 2 min of clock skew
|
|
|
+ claims = await self._verify_jwt(
|
|
|
+ alg_values=alg_values,
|
|
|
+ token=id_token,
|
|
|
+ claims_cls=CodeIDToken,
|
|
|
+ claims_options=claims_options,
|
|
|
+ claims_params=claims_params,
|
|
|
+ )
|
|
|
|
|
|
return claims
|
|
|
|
|
@@ -1043,6 +1214,146 @@ class OidcProvider:
|
|
|
# to be strings.
|
|
|
return str(remote_user_id)
|
|
|
|
|
|
+ async def handle_backchannel_logout(
|
|
|
+ self, request: SynapseRequest, logout_token: str
|
|
|
+ ) -> None:
|
|
|
+ """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
|
|
|
+
|
|
|
+ The OIDC Provider posts a logout token to this endpoint when a user
|
|
|
+ session ends. That token is a JWT signed with the same keys as
|
|
|
+ ID tokens. The OpenID Connect Back-Channel Logout draft explains how to
|
|
|
+ validate the JWT and figure out what session to end.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ request: The request to respond to
|
|
|
+ logout_token: The logout token (a JWT) extracted from the request body
|
|
|
+ """
|
|
|
+ # Back-Channel Logout can be disabled in the config, hence this check.
|
|
|
+ # This is not that important for now since Synapse is registered
|
|
|
+ # manually to the OP, so not specifying the backchannel-logout URI is
|
|
|
+ # as effective than disabling it here. It might make more sense if we
|
|
|
+ # support dynamic registration in Synapse at some point.
|
|
|
+ if not self._config.backchannel_logout_enabled:
|
|
|
+ logger.warning(
|
|
|
+ f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config"
|
|
|
+ )
|
|
|
+
|
|
|
+ # TODO: this responds with a 400 status code, which is what the OIDC
|
|
|
+ # Back-Channel Logout spec expects, but spec also suggests answering with
|
|
|
+ # a JSON object, with the `error` and `error_description` fields set, which
|
|
|
+ # we are not doing here.
|
|
|
+ # See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse
|
|
|
+ raise SynapseError(
|
|
|
+ 400, "OpenID Connect Back-Channel Logout is disabled for this provider"
|
|
|
+ )
|
|
|
+
|
|
|
+ metadata = await self.load_metadata()
|
|
|
+
|
|
|
+ # As per OIDC Back-Channel Logout 1.0 sec. 2.4:
|
|
|
+ # A Logout Token MUST be signed and MAY also be encrypted. The same
|
|
|
+ # keys are used to sign and encrypt Logout Tokens as are used for ID
|
|
|
+ # Tokens. If the Logout Token is encrypted, it SHOULD replicate the
|
|
|
+ # iss (issuer) claim in the JWT Header Parameters, as specified in
|
|
|
+ # Section 5.3 of [JWT].
|
|
|
+ alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
|
|
|
+
|
|
|
+ # As per sec. 2.6:
|
|
|
+ # 3. Validate the iss, aud, and iat Claims in the same way they are
|
|
|
+ # validated in ID Tokens.
|
|
|
+ # Which means the audience should contain Synapse's client_id and the
|
|
|
+ # issuer should be the IdP issuer
|
|
|
+ claims_options = {
|
|
|
+ "iss": {"values": [metadata["issuer"]]},
|
|
|
+ "aud": {"values": [self.client_id]},
|
|
|
+ }
|
|
|
+
|
|
|
+ try:
|
|
|
+ claims = await self._verify_jwt(
|
|
|
+ alg_values=alg_values,
|
|
|
+ token=logout_token,
|
|
|
+ claims_cls=LogoutToken,
|
|
|
+ claims_options=claims_options,
|
|
|
+ )
|
|
|
+ except JoseError:
|
|
|
+ logger.exception("Invalid logout_token")
|
|
|
+ raise SynapseError(400, "Invalid logout_token")
|
|
|
+
|
|
|
+ # As per sec. 2.6:
|
|
|
+ # 4. Verify that the Logout Token contains a sub Claim, a sid Claim,
|
|
|
+ # or both.
|
|
|
+ # 5. Verify that the Logout Token contains an events Claim whose
|
|
|
+ # value is JSON object containing the member name
|
|
|
+ # http://schemas.openid.net/event/backchannel-logout.
|
|
|
+ # 6. Verify that the Logout Token does not contain a nonce Claim.
|
|
|
+ # This is all verified by the LogoutToken claims class, so at this
|
|
|
+ # point the `sid` claim exists and is a string.
|
|
|
+ sid: str = claims.get("sid")
|
|
|
+
|
|
|
+ # If the `sub` claim was included in the logout token, we check that it matches
|
|
|
+ # that it matches the right user. We can have cases where the `sub` claim is not
|
|
|
+ # the ID saved in database, so we let admins disable this check in config.
|
|
|
+ sub: Optional[str] = claims.get("sub")
|
|
|
+ expected_user_id: Optional[str] = None
|
|
|
+ if sub is not None and not self._config.backchannel_logout_ignore_sub:
|
|
|
+ expected_user_id = await self._store.get_user_by_external_id(
|
|
|
+ self.idp_id, sub
|
|
|
+ )
|
|
|
+
|
|
|
+ # Invalidate any running user-mapping sessions, in-flight login tokens and
|
|
|
+ # active devices
|
|
|
+ await self._sso_handler.revoke_sessions_for_provider_session_id(
|
|
|
+ auth_provider_id=self.idp_id,
|
|
|
+ auth_provider_session_id=sid,
|
|
|
+ expected_user_id=expected_user_id,
|
|
|
+ )
|
|
|
+
|
|
|
+ request.setResponseCode(200)
|
|
|
+ request.setHeader(b"Cache-Control", b"no-cache, no-store")
|
|
|
+ request.setHeader(b"Pragma", b"no-cache")
|
|
|
+ finish_request(request)
|
|
|
+
|
|
|
+
|
|
|
+class LogoutToken(JWTClaims):
|
|
|
+ """
|
|
|
+ Holds and verify claims of a logout token, as per
|
|
|
+ https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
|
|
|
+ """
|
|
|
+
|
|
|
+ REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"]
|
|
|
+
|
|
|
+ def validate(self, now: Optional[int] = None, leeway: int = 0) -> None:
|
|
|
+ """Validate everything in claims payload."""
|
|
|
+ super().validate(now, leeway)
|
|
|
+ self.validate_sid()
|
|
|
+ self.validate_events()
|
|
|
+ self.validate_nonce()
|
|
|
+
|
|
|
+ def validate_sid(self) -> None:
|
|
|
+ """Ensure the sid claim is present"""
|
|
|
+ sid = self.get("sid")
|
|
|
+ if not sid:
|
|
|
+ raise MissingClaimError("sid")
|
|
|
+
|
|
|
+ if not isinstance(sid, str):
|
|
|
+ raise InvalidClaimError("sid")
|
|
|
+
|
|
|
+ def validate_nonce(self) -> None:
|
|
|
+ """Ensure the nonce claim is absent"""
|
|
|
+ if "nonce" in self:
|
|
|
+ raise InvalidClaimError("nonce")
|
|
|
+
|
|
|
+ def validate_events(self) -> None:
|
|
|
+ """Ensure the events claim is present and with the right value"""
|
|
|
+ events = self.get("events")
|
|
|
+ if not events:
|
|
|
+ raise MissingClaimError("events")
|
|
|
+
|
|
|
+ if not isinstance(events, dict):
|
|
|
+ raise InvalidClaimError("events")
|
|
|
+
|
|
|
+ if "http://schemas.openid.net/event/backchannel-logout" not in events:
|
|
|
+ raise InvalidClaimError("events")
|
|
|
+
|
|
|
|
|
|
# number of seconds a newly-generated client secret should be valid for
|
|
|
CLIENT_SECRET_VALIDITY_SECONDS = 3600
|
|
@@ -1112,6 +1423,7 @@ class JwtClientSecret:
|
|
|
logger.info(
|
|
|
"Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
|
|
|
)
|
|
|
+ jwt = JsonWebToken(header["alg"])
|
|
|
self._cached_secret = jwt.encode(header, payload, self._key.key)
|
|
|
self._cached_secret_replacement_time = (
|
|
|
expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
|
|
@@ -1126,9 +1438,6 @@ class UserAttributeDict(TypedDict):
|
|
|
emails: List[str]
|
|
|
|
|
|
|
|
|
-C = TypeVar("C")
|
|
|
-
|
|
|
-
|
|
|
class OidcMappingProvider(Generic[C]):
|
|
|
"""A mapping provider maps a UserInfo object to user attributes.
|
|
|
|