|
@@ -13,14 +13,18 @@
|
|
# limitations under the License.
|
|
# limitations under the License.
|
|
import json
|
|
import json
|
|
import os
|
|
import os
|
|
|
|
+from typing import Any, Dict
|
|
from unittest.mock import ANY, Mock, patch
|
|
from unittest.mock import ANY, Mock, patch
|
|
from urllib.parse import parse_qs, urlparse
|
|
from urllib.parse import parse_qs, urlparse
|
|
|
|
|
|
import pymacaroons
|
|
import pymacaroons
|
|
|
|
|
|
|
|
+from twisted.test.proto_helpers import MemoryReactor
|
|
|
|
+
|
|
from synapse.handlers.sso import MappingException
|
|
from synapse.handlers.sso import MappingException
|
|
from synapse.server import HomeServer
|
|
from synapse.server import HomeServer
|
|
-from synapse.types import UserID
|
|
|
|
|
|
+from synapse.types import JsonDict, UserID
|
|
|
|
+from synapse.util import Clock
|
|
from synapse.util.macaroons import get_value_from_macaroon
|
|
from synapse.util.macaroons import get_value_from_macaroon
|
|
|
|
|
|
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
|
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
|
|
@@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
-async def get_json(url):
|
|
|
|
|
|
+async def get_json(url: str) -> JsonDict:
|
|
# Mock get_json calls to handle jwks & oidc discovery endpoints
|
|
# Mock get_json calls to handle jwks & oidc discovery endpoints
|
|
if url == WELL_KNOWN:
|
|
if url == WELL_KNOWN:
|
|
# Minimal discovery document, as defined in OpenID.Discovery
|
|
# Minimal discovery document, as defined in OpenID.Discovery
|
|
@@ -116,6 +120,8 @@ async def get_json(url):
|
|
elif url == JWKS_URI:
|
|
elif url == JWKS_URI:
|
|
return {"keys": []}
|
|
return {"keys": []}
|
|
|
|
|
|
|
|
+ return {}
|
|
|
|
+
|
|
|
|
|
|
def _key_file_path() -> str:
|
|
def _key_file_path() -> str:
|
|
"""path to a file containing the private half of a test key"""
|
|
"""path to a file containing the private half of a test key"""
|
|
@@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
if not HAS_OIDC:
|
|
if not HAS_OIDC:
|
|
skip = "requires OIDC"
|
|
skip = "requires OIDC"
|
|
|
|
|
|
- def default_config(self):
|
|
|
|
|
|
+ def default_config(self) -> Dict[str, Any]:
|
|
config = super().default_config()
|
|
config = super().default_config()
|
|
config["public_baseurl"] = BASE_URL
|
|
config["public_baseurl"] = BASE_URL
|
|
return config
|
|
return config
|
|
|
|
|
|
- def make_homeserver(self, reactor, clock):
|
|
|
|
|
|
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
|
self.http_client = Mock(spec=["get_json"])
|
|
self.http_client = Mock(spec=["get_json"])
|
|
self.http_client.get_json.side_effect = get_json
|
|
self.http_client.get_json.side_effect = get_json
|
|
self.http_client.user_agent = b"Synapse Test"
|
|
self.http_client.user_agent = b"Synapse Test"
|
|
@@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
sso_handler = hs.get_sso_handler()
|
|
sso_handler = hs.get_sso_handler()
|
|
# Mock the render error method.
|
|
# Mock the render error method.
|
|
self.render_error = Mock(return_value=None)
|
|
self.render_error = Mock(return_value=None)
|
|
- sso_handler.render_error = self.render_error
|
|
|
|
|
|
+ sso_handler.render_error = self.render_error # type: ignore[assignment]
|
|
|
|
|
|
# Reduce the number of attempts when generating MXIDs.
|
|
# Reduce the number of attempts when generating MXIDs.
|
|
sso_handler._MAP_USERNAME_RETRIES = 3
|
|
sso_handler._MAP_USERNAME_RETRIES = 3
|
|
@@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
return args
|
|
return args
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
- def test_config(self):
|
|
|
|
|
|
+ def test_config(self) -> None:
|
|
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
|
"""Basic config correctly sets up the callback URL and client auth correctly."""
|
|
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
|
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
|
|
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
|
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
|
|
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
|
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
|
|
|
|
|
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
|
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
|
|
- def test_discovery(self):
|
|
|
|
|
|
+ def test_discovery(self) -> None:
|
|
"""The handler should discover the endpoints from OIDC discovery document."""
|
|
"""The handler should discover the endpoints from OIDC discovery document."""
|
|
# This would throw if some metadata were invalid
|
|
# This would throw if some metadata were invalid
|
|
metadata = self.get_success(self.provider.load_metadata())
|
|
metadata = self.get_success(self.provider.load_metadata())
|
|
@@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
self.http_client.get_json.assert_not_called()
|
|
self.http_client.get_json.assert_not_called()
|
|
|
|
|
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
|
- def test_no_discovery(self):
|
|
|
|
|
|
+ def test_no_discovery(self) -> None:
|
|
"""When discovery is disabled, it should not try to load from discovery document."""
|
|
"""When discovery is disabled, it should not try to load from discovery document."""
|
|
self.get_success(self.provider.load_metadata())
|
|
self.get_success(self.provider.load_metadata())
|
|
self.http_client.get_json.assert_not_called()
|
|
self.http_client.get_json.assert_not_called()
|
|
|
|
|
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
|
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
|
|
- def test_load_jwks(self):
|
|
|
|
|
|
+ def test_load_jwks(self) -> None:
|
|
"""JWKS loading is done once (then cached) if used."""
|
|
"""JWKS loading is done once (then cached) if used."""
|
|
jwks = self.get_success(self.provider.load_jwks())
|
|
jwks = self.get_success(self.provider.load_jwks())
|
|
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
|
self.http_client.get_json.assert_called_once_with(JWKS_URI)
|
|
@@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
|
|
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
- def test_validate_config(self):
|
|
|
|
|
|
+ def test_validate_config(self) -> None:
|
|
"""Provider metadatas are extensively validated."""
|
|
"""Provider metadatas are extensively validated."""
|
|
h = self.provider
|
|
h = self.provider
|
|
|
|
|
|
@@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
force_load_metadata()
|
|
force_load_metadata()
|
|
|
|
|
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
|
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
|
|
- def test_skip_verification(self):
|
|
|
|
|
|
+ def test_skip_verification(self) -> None:
|
|
"""Provider metadata validation can be disabled by config."""
|
|
"""Provider metadata validation can be disabled by config."""
|
|
with self.metadata_edit({"issuer": "http://insecure"}):
|
|
with self.metadata_edit({"issuer": "http://insecure"}):
|
|
# This should not throw
|
|
# This should not throw
|
|
get_awaitable_result(self.provider.load_metadata())
|
|
get_awaitable_result(self.provider.load_metadata())
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
- def test_redirect_request(self):
|
|
|
|
|
|
+ def test_redirect_request(self) -> None:
|
|
"""The redirect request has the right arguments & generates a valid session cookie."""
|
|
"""The redirect request has the right arguments & generates a valid session cookie."""
|
|
req = Mock(spec=["cookies"])
|
|
req = Mock(spec=["cookies"])
|
|
req.cookies = []
|
|
req.cookies = []
|
|
@@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
self.assertEqual(redirect, "http://client/redirect")
|
|
self.assertEqual(redirect, "http://client/redirect")
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
- def test_callback_error(self):
|
|
|
|
|
|
+ def test_callback_error(self) -> None:
|
|
"""Errors from the provider returned in the callback are displayed."""
|
|
"""Errors from the provider returned in the callback are displayed."""
|
|
request = Mock(args={})
|
|
request = Mock(args={})
|
|
request.args[b"error"] = [b"invalid_client"]
|
|
request.args[b"error"] = [b"invalid_client"]
|
|
@@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
self.assertRenderedError("invalid_client", "some description")
|
|
self.assertRenderedError("invalid_client", "some description")
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
- def test_callback(self):
|
|
|
|
|
|
+ def test_callback(self) -> None:
|
|
"""Code callback works and display errors if something went wrong.
|
|
"""Code callback works and display errors if something went wrong.
|
|
|
|
|
|
A lot of scenarios are tested here:
|
|
A lot of scenarios are tested here:
|
|
@@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
"username": username,
|
|
"username": username,
|
|
}
|
|
}
|
|
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
|
|
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
|
|
- self.provider._exchange_code = simple_async_mock(return_value=token)
|
|
|
|
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
|
|
|
- self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
|
|
|
|
|
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
|
|
|
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
|
|
|
+ self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
|
|
|
|
@@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
self.assertRenderedError("mapping_error")
|
|
self.assertRenderedError("mapping_error")
|
|
|
|
|
|
# Handle ID token errors
|
|
# Handle ID token errors
|
|
- self.provider._parse_id_token = simple_async_mock(raises=Exception())
|
|
|
|
|
|
+ self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
self.assertRenderedError("invalid_token")
|
|
self.assertRenderedError("invalid_token")
|
|
|
|
|
|
@@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
"type": "bearer",
|
|
"type": "bearer",
|
|
"access_token": "access_token",
|
|
"access_token": "access_token",
|
|
}
|
|
}
|
|
- self.provider._exchange_code = simple_async_mock(return_value=token)
|
|
|
|
|
|
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
|
|
|
|
auth_handler.complete_sso_login.assert_called_once_with(
|
|
auth_handler.complete_sso_login.assert_called_once_with(
|
|
@@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
id_token = {
|
|
id_token = {
|
|
"sid": "abcdefgh",
|
|
"sid": "abcdefgh",
|
|
}
|
|
}
|
|
- self.provider._parse_id_token = simple_async_mock(return_value=id_token)
|
|
|
|
- self.provider._exchange_code = simple_async_mock(return_value=token)
|
|
|
|
|
|
+ self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
|
|
|
|
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
|
auth_handler.complete_sso_login.reset_mock()
|
|
auth_handler.complete_sso_login.reset_mock()
|
|
self.provider._fetch_userinfo.reset_mock()
|
|
self.provider._fetch_userinfo.reset_mock()
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
@@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
self.render_error.assert_not_called()
|
|
self.render_error.assert_not_called()
|
|
|
|
|
|
# Handle userinfo fetching error
|
|
# Handle userinfo fetching error
|
|
- self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
|
|
|
|
|
|
+ self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
self.assertRenderedError("fetch_error")
|
|
self.assertRenderedError("fetch_error")
|
|
|
|
|
|
# Handle code exchange failure
|
|
# Handle code exchange failure
|
|
from synapse.handlers.oidc import OidcError
|
|
from synapse.handlers.oidc import OidcError
|
|
|
|
|
|
- self.provider._exchange_code = simple_async_mock(
|
|
|
|
|
|
+ self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
|
|
raises=OidcError("invalid_request")
|
|
raises=OidcError("invalid_request")
|
|
)
|
|
)
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
self.get_success(self.handler.handle_oidc_callback(request))
|
|
self.assertRenderedError("invalid_request")
|
|
self.assertRenderedError("invalid_request")
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
- def test_callback_session(self):
|
|
|
|
|
|
+ def test_callback_session(self) -> None:
|
|
"""The callback verifies the session presence and validity"""
|
|
"""The callback verifies the session presence and validity"""
|
|
request = Mock(spec=["args", "getCookie", "cookies"])
|
|
request = Mock(spec=["args", "getCookie", "cookies"])
|
|
|
|
|
|
@@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
@override_config(
|
|
@override_config(
|
|
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
|
|
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
|
|
)
|
|
)
|
|
- def test_exchange_code(self):
|
|
|
|
|
|
+ def test_exchange_code(self) -> None:
|
|
"""Code exchange behaves correctly and handles various error scenarios."""
|
|
"""Code exchange behaves correctly and handles various error scenarios."""
|
|
token = {"type": "bearer"}
|
|
token = {"type": "bearer"}
|
|
token_json = json.dumps(token).encode("utf-8")
|
|
token_json = json.dumps(token).encode("utf-8")
|
|
@@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
}
|
|
}
|
|
}
|
|
}
|
|
)
|
|
)
|
|
- def test_exchange_code_jwt_key(self):
|
|
|
|
|
|
+ def test_exchange_code_jwt_key(self) -> None:
|
|
"""Test that code exchange works with a JWK client secret."""
|
|
"""Test that code exchange works with a JWK client secret."""
|
|
from authlib.jose import jwt
|
|
from authlib.jose import jwt
|
|
|
|
|
|
@@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
}
|
|
}
|
|
}
|
|
}
|
|
)
|
|
)
|
|
- def test_exchange_code_no_auth(self):
|
|
|
|
|
|
+ def test_exchange_code_no_auth(self) -> None:
|
|
"""Test that code exchange works with no client secret."""
|
|
"""Test that code exchange works with no client secret."""
|
|
token = {"type": "bearer"}
|
|
token = {"type": "bearer"}
|
|
self.http_client.request = simple_async_mock(
|
|
self.http_client.request = simple_async_mock(
|
|
@@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
}
|
|
}
|
|
}
|
|
}
|
|
)
|
|
)
|
|
- def test_extra_attributes(self):
|
|
|
|
|
|
+ def test_extra_attributes(self) -> None:
|
|
"""
|
|
"""
|
|
Login while using a mapping provider that implements get_extra_attributes.
|
|
Login while using a mapping provider that implements get_extra_attributes.
|
|
"""
|
|
"""
|
|
@@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
"username": "foo",
|
|
"username": "foo",
|
|
"phone": "1234567",
|
|
"phone": "1234567",
|
|
}
|
|
}
|
|
- self.provider._exchange_code = simple_async_mock(return_value=token)
|
|
|
|
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
|
|
|
|
|
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
|
|
|
|
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
|
|
|
|
@@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
)
|
|
)
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
- def test_map_userinfo_to_user(self):
|
|
|
|
|
|
+ def test_map_userinfo_to_user(self) -> None:
|
|
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
|
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
|
|
|
|
- userinfo = {
|
|
|
|
|
|
+ userinfo: dict = {
|
|
"sub": "test_user",
|
|
"sub": "test_user",
|
|
"username": "test_user",
|
|
"username": "test_user",
|
|
}
|
|
}
|
|
@@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
)
|
|
)
|
|
|
|
|
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
|
|
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
|
|
- def test_map_userinfo_to_existing_user(self):
|
|
|
|
|
|
+ def test_map_userinfo_to_existing_user(self) -> None:
|
|
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
|
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
|
|
store = self.hs.get_datastores().main
|
|
store = self.hs.get_datastores().main
|
|
user = UserID.from_string("@test_user:test")
|
|
user = UserID.from_string("@test_user:test")
|
|
@@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
)
|
|
)
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
- def test_map_userinfo_to_invalid_localpart(self):
|
|
|
|
|
|
+ def test_map_userinfo_to_invalid_localpart(self) -> None:
|
|
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
|
"""If the mapping provider generates an invalid localpart it should be rejected."""
|
|
self.get_success(
|
|
self.get_success(
|
|
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
|
|
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
|
|
@@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
}
|
|
}
|
|
}
|
|
}
|
|
)
|
|
)
|
|
- def test_map_userinfo_to_user_retries(self):
|
|
|
|
|
|
+ def test_map_userinfo_to_user_retries(self) -> None:
|
|
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
|
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
@@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
)
|
|
)
|
|
|
|
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
@override_config({"oidc_config": DEFAULT_CONFIG})
|
|
- def test_empty_localpart(self):
|
|
|
|
|
|
+ def test_empty_localpart(self) -> None:
|
|
"""Attempts to map onto an empty localpart should be rejected."""
|
|
"""Attempts to map onto an empty localpart should be rejected."""
|
|
userinfo = {
|
|
userinfo = {
|
|
"sub": "tester",
|
|
"sub": "tester",
|
|
@@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
}
|
|
}
|
|
}
|
|
}
|
|
)
|
|
)
|
|
- def test_null_localpart(self):
|
|
|
|
|
|
+ def test_null_localpart(self) -> None:
|
|
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
|
|
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
|
|
userinfo = {
|
|
userinfo = {
|
|
"sub": "tester",
|
|
"sub": "tester",
|
|
@@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
}
|
|
}
|
|
}
|
|
}
|
|
)
|
|
)
|
|
- def test_attribute_requirements(self):
|
|
|
|
|
|
+ def test_attribute_requirements(self) -> None:
|
|
"""The required attributes must be met from the OIDC userinfo response."""
|
|
"""The required attributes must be met from the OIDC userinfo response."""
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
@@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
}
|
|
}
|
|
}
|
|
}
|
|
)
|
|
)
|
|
- def test_attribute_requirements_contains(self):
|
|
|
|
|
|
+ def test_attribute_requirements_contains(self) -> None:
|
|
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
|
|
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
@@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
}
|
|
}
|
|
}
|
|
}
|
|
)
|
|
)
|
|
- def test_attribute_requirements_mismatch(self):
|
|
|
|
|
|
+ def test_attribute_requirements_mismatch(self) -> None:
|
|
"""
|
|
"""
|
|
Test that auth fails if attributes exist but don't match,
|
|
Test that auth fails if attributes exist but don't match,
|
|
or are non-string values.
|
|
or are non-string values.
|
|
@@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler = self.hs.get_auth_handler()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
auth_handler.complete_sso_login = simple_async_mock()
|
|
# userinfo with "test": "not_foobar" attribute should fail
|
|
# userinfo with "test": "not_foobar" attribute should fail
|
|
- userinfo = {
|
|
|
|
|
|
+ userinfo: dict = {
|
|
"sub": "tester",
|
|
"sub": "tester",
|
|
"username": "tester",
|
|
"username": "tester",
|
|
"test": "not_foobar",
|
|
"test": "not_foobar",
|
|
@@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
|
|
|
|
|
|
handler = hs.get_oidc_handler()
|
|
handler = hs.get_oidc_handler()
|
|
provider = handler._providers["oidc"]
|
|
provider = handler._providers["oidc"]
|
|
- provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
|
|
|
|
- provider._parse_id_token = simple_async_mock(return_value=userinfo)
|
|
|
|
- provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
|
|
|
|
|
+ provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
|
|
|
|
+ provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
|
|
|
+ provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
|
|
|
|
|
|
state = "state"
|
|
state = "state"
|
|
session = handler._token_generator.generate_oidc_session_token(
|
|
session = handler._token_generator.generate_oidc_session_token(
|