|
@@ -151,6 +151,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
|
|
|
|
|
|
self.handler = hs.get_oidc_handler()
|
|
|
+ self.provider = self.handler._provider
|
|
|
sso_handler = hs.get_sso_handler()
|
|
|
# Mock the render error method.
|
|
|
self.render_error = Mock(return_value=None)
|
|
@@ -162,9 +163,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
return hs
|
|
|
|
|
|
def metadata_edit(self, values):
|
|
|
- return patch.dict(self.handler._provider_metadata, values)
|
|
|
+ return patch.dict(self.provider._provider_metadata, values)
|
|
|
|
|
|
def assertRenderedError(self, error, error_description=None):
|
|
|
+ self.render_error.assert_called_once()
|
|
|
args = self.render_error.call_args[0]
|
|
|
self.assertEqual(args[1], error)
|
|
|
if error_description is not None:
|
|
@@ -175,15 +177,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
|
|
|
def test_config(self):
|
|
|
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
|
|
- self.assertEqual(self.handler._callback_url, CALLBACK_URL)
|
|
|
- self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
|
|
|
- self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
|
|
|
+ self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
|
|
+ self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
|
|
+ self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
|
|
|
|
|
@override_config({"oidc_config": {"discover": True}})
|
|
|
def test_discovery(self):
|
|
|
"""The handler should discover the endpoints from OIDC discovery document."""
|
|
|
# This would throw if some metadata were invalid
|
|
|
- metadata = self.get_success(self.handler.load_metadata())
|
|
|
+ metadata = self.get_success(self.provider.load_metadata())
|
|
|
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
|
|
|
|
|
|
self.assertEqual(metadata.issuer, ISSUER)
|
|
@@ -195,47 +197,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
|
|
|
# subsequent calls should be cached
|
|
|
self.http_client.reset_mock()
|
|
|
- self.get_success(self.handler.load_metadata())
|
|
|
+ self.get_success(self.provider.load_metadata())
|
|
|
self.http_client.get_json.assert_not_called()
|
|
|
|
|
|
@override_config({"oidc_config": COMMON_CONFIG})
|
|
|
def test_no_discovery(self):
|
|
|
"""When discovery is disabled, it should not try to load from discovery document."""
|
|
|
- self.get_success(self.handler.load_metadata())
|
|
|
+ self.get_success(self.provider.load_metadata())
|
|
|
self.http_client.get_json.assert_not_called()
|
|
|
|
|
|
@override_config({"oidc_config": COMMON_CONFIG})
|
|
|
def test_load_jwks(self):
|
|
|
"""JWKS loading is done once (then cached) if used."""
|
|
|
- jwks = self.get_success(self.handler.load_jwks())
|
|
|
+ jwks = self.get_success(self.provider.load_jwks())
|
|
|
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
|
|
self.assertEqual(jwks, {"keys": []})
|
|
|
|
|
|
# subsequent calls should be cached…
|
|
|
self.http_client.reset_mock()
|
|
|
- self.get_success(self.handler.load_jwks())
|
|
|
+ self.get_success(self.provider.load_jwks())
|
|
|
self.http_client.get_json.assert_not_called()
|
|
|
|
|
|
# …unless forced
|
|
|
self.http_client.reset_mock()
|
|
|
- self.get_success(self.handler.load_jwks(force=True))
|
|
|
+ self.get_success(self.provider.load_jwks(force=True))
|
|
|
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
|
|
|
|
|
# Throw if the JWKS uri is missing
|
|
|
with self.metadata_edit({"jwks_uri": None}):
|
|
|
- self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
|
|
|
+ self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
|
|
|
|
|
|
# Return empty key set if JWKS are not used
|
|
|
- self.handler._scopes = [] # not asking the openid scope
|
|
|
+ self.provider._scopes = [] # not asking the openid scope
|
|
|
self.http_client.get_json.reset_mock()
|
|
|
- jwks = self.get_success(self.handler.load_jwks(force=True))
|
|
|
+ jwks = self.get_success(self.provider.load_jwks(force=True))
|
|
|
self.http_client.get_json.assert_not_called()
|
|
|
self.assertEqual(jwks, {"keys": []})
|
|
|
|
|
|
@override_config({"oidc_config": COMMON_CONFIG})
|
|
|
def test_validate_config(self):
|
|
|
"""Provider metadatas are extensively validated."""
|
|
|
- h = self.handler
|
|
|
+ h = self.provider
|
|
|
|
|
|
# Default test config does not throw
|
|
|
h._validate_metadata()
|
|
@@ -314,13 +316,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
"""Provider metadata validation can be disabled by config."""
|
|
|
with self.metadata_edit({"issuer": "http://insecure"}):
|
|
|
# This should not throw
|
|
|
- self.handler._validate_metadata()
|
|
|
+ self.provider._validate_metadata()
|
|
|
|
|
|
def test_redirect_request(self):
|
|
|
"""The redirect request has the right arguments & generates a valid session cookie."""
|
|
|
req = Mock(spec=["addCookie"])
|
|
|
url = self.get_success(
|
|
|
- self.handler.handle_redirect_request(req, b"http://client/redirect")
|
|
|
+ self.provider.handle_redirect_request(req, b"http://client/redirect")
|
|
|
)
|
|
|
url = urlparse(url)
|
|
|
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
|
|
@@ -388,7 +390,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
|
|
|
# ensure that we are correctly testing the fallback when "get_extra_attributes"
|
|
|
# is not implemented.
|
|
|
- mapping_provider = self.handler._user_mapping_provider
|
|
|
+ mapping_provider = self.provider._user_mapping_provider
|
|
|
with self.assertRaises(AttributeError):
|
|
|
_ = mapping_provider.get_extra_attributes
|
|
|
|
|
@@ -403,9 +405,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
"username": username,
|
|
|
}
|
|
|
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
|
|
|
- self.handler._exchange_code = simple_async_mock(return_value=token)
|
|
|
- self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
|
|
- self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
|
|
+ self.provider._exchange_code = simple_async_mock(return_value=token)
|
|
|
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
|
|
+ self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
|
|
auth_handler = self.hs.get_auth_handler()
|
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
|
|
|
@@ -425,14 +427,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
auth_handler.complete_sso_login.assert_called_once_with(
|
|
|
expected_user_id, request, client_redirect_url, None,
|
|
|
)
|
|
|
- self.handler._exchange_code.assert_called_once_with(code)
|
|
|
- self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
|
|
- self.handler._fetch_userinfo.assert_not_called()
|
|
|
+ self.provider._exchange_code.assert_called_once_with(code)
|
|
|
+ self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
|
|
+ self.provider._fetch_userinfo.assert_not_called()
|
|
|
self.render_error.assert_not_called()
|
|
|
|
|
|
# Handle mapping errors
|
|
|
with patch.object(
|
|
|
- self.handler,
|
|
|
+ self.provider,
|
|
|
"_remote_id_from_userinfo",
|
|
|
new=Mock(side_effect=MappingException()),
|
|
|
):
|
|
@@ -440,36 +442,36 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
self.assertRenderedError("mapping_error")
|
|
|
|
|
|
# Handle ID token errors
|
|
|
- self.handler._parse_id_token = simple_async_mock(raises=Exception())
|
|
|
+ self.provider._parse_id_token = simple_async_mock(raises=Exception())
|
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
|
self.assertRenderedError("invalid_token")
|
|
|
|
|
|
auth_handler.complete_sso_login.reset_mock()
|
|
|
- self.handler._exchange_code.reset_mock()
|
|
|
- self.handler._parse_id_token.reset_mock()
|
|
|
- self.handler._fetch_userinfo.reset_mock()
|
|
|
+ self.provider._exchange_code.reset_mock()
|
|
|
+ self.provider._parse_id_token.reset_mock()
|
|
|
+ self.provider._fetch_userinfo.reset_mock()
|
|
|
|
|
|
# With userinfo fetching
|
|
|
- self.handler._scopes = [] # do not ask the "openid" scope
|
|
|
+ self.provider._scopes = [] # do not ask the "openid" scope
|
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
|
|
|
|
auth_handler.complete_sso_login.assert_called_once_with(
|
|
|
expected_user_id, request, client_redirect_url, None,
|
|
|
)
|
|
|
- self.handler._exchange_code.assert_called_once_with(code)
|
|
|
- self.handler._parse_id_token.assert_not_called()
|
|
|
- self.handler._fetch_userinfo.assert_called_once_with(token)
|
|
|
+ self.provider._exchange_code.assert_called_once_with(code)
|
|
|
+ self.provider._parse_id_token.assert_not_called()
|
|
|
+ self.provider._fetch_userinfo.assert_called_once_with(token)
|
|
|
self.render_error.assert_not_called()
|
|
|
|
|
|
# Handle userinfo fetching error
|
|
|
- self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
|
|
|
+ self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
|
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
|
self.assertRenderedError("fetch_error")
|
|
|
|
|
|
# Handle code exchange failure
|
|
|
from synapse.handlers.oidc_handler import OidcError
|
|
|
|
|
|
- self.handler._exchange_code = simple_async_mock(
|
|
|
+ self.provider._exchange_code = simple_async_mock(
|
|
|
raises=OidcError("invalid_request")
|
|
|
)
|
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
@@ -524,7 +526,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
|
|
|
)
|
|
|
code = "code"
|
|
|
- ret = self.get_success(self.handler._exchange_code(code))
|
|
|
+ ret = self.get_success(self.provider._exchange_code(code))
|
|
|
kwargs = self.http_client.request.call_args[1]
|
|
|
|
|
|
self.assertEqual(ret, token)
|
|
@@ -548,7 +550,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
)
|
|
|
from synapse.handlers.oidc_handler import OidcError
|
|
|
|
|
|
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
|
|
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
|
|
self.assertEqual(exc.value.error, "foo")
|
|
|
self.assertEqual(exc.value.error_description, "bar")
|
|
|
|
|
@@ -558,7 +560,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
|
|
|
)
|
|
|
)
|
|
|
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
|
|
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
|
|
self.assertEqual(exc.value.error, "server_error")
|
|
|
|
|
|
# Internal server error with JSON body
|
|
@@ -570,14 +572,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
)
|
|
|
)
|
|
|
|
|
|
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
|
|
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
|
|
self.assertEqual(exc.value.error, "internal_server_error")
|
|
|
|
|
|
# 4xx error without "error" field
|
|
|
self.http_client.request = simple_async_mock(
|
|
|
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
|
|
|
)
|
|
|
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
|
|
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
|
|
self.assertEqual(exc.value.error, "server_error")
|
|
|
|
|
|
# 2xx error with "error" field
|
|
@@ -586,7 +588,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
|
|
|
)
|
|
|
)
|
|
|
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
|
|
|
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
|
|
|
self.assertEqual(exc.value.error, "some_error")
|
|
|
|
|
|
@override_config(
|
|
@@ -612,8 +614,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
|
"username": "foo",
|
|
|
"phone": "1234567",
|
|
|
}
|
|
|
- self.handler._exchange_code = simple_async_mock(return_value=token)
|
|
|
- self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
|
|
+ self.provider._exchange_code = simple_async_mock(return_value=token)
|
|
|
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
|
|
auth_handler = self.hs.get_auth_handler()
|
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
|
|
|
@@ -979,9 +981,10 @@ async def _make_callback_with_userinfo(
|
|
|
from synapse.handlers.oidc_handler import OidcSessionData
|
|
|
|
|
|
handler = hs.get_oidc_handler()
|
|
|
- handler._exchange_code = simple_async_mock(return_value={})
|
|
|
- handler._parse_id_token = simple_async_mock(return_value=userinfo)
|
|
|
- handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
|
|
+ provider = handler._provider
|
|
|
+ provider._exchange_code = simple_async_mock(return_value={})
|
|
|
+ provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
|
|
+ provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
|
|
|
|
|
state = "state"
|
|
|
session = handler._token_generator.generate_oidc_session_token(
|