test_cas.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. # Copyright 2020 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict
  15. from unittest.mock import AsyncMock, Mock
  16. from twisted.test.proto_helpers import MemoryReactor
  17. from synapse.handlers.cas import CasResponse
  18. from synapse.server import HomeServer
  19. from synapse.util import Clock
  20. from tests.unittest import HomeserverTestCase, override_config
  21. # These are a few constants that are used as config parameters in the tests.
  22. BASE_URL = "https://synapse/"
  23. SERVER_URL = "https://issuer/"
  24. class CasHandlerTestCase(HomeserverTestCase):
  25. def default_config(self) -> Dict[str, Any]:
  26. config = super().default_config()
  27. config["public_baseurl"] = BASE_URL
  28. cas_config = {
  29. "enabled": True,
  30. "server_url": SERVER_URL,
  31. "service_url": BASE_URL,
  32. }
  33. # Update this config with what's in the default config so that
  34. # override_config works as expected.
  35. cas_config.update(config.get("cas_config", {}))
  36. config["cas_config"] = cas_config
  37. return config
  38. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  39. hs = self.setup_test_homeserver()
  40. self.handler = hs.get_cas_handler()
  41. # Reduce the number of attempts when generating MXIDs.
  42. sso_handler = hs.get_sso_handler()
  43. sso_handler._MAP_USERNAME_RETRIES = 3
  44. return hs
  45. def test_map_cas_user_to_user(self) -> None:
  46. """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
  47. # stub out the auth handler
  48. auth_handler = self.hs.get_auth_handler()
  49. auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
  50. cas_response = CasResponse("test_user", {})
  51. request = _mock_request()
  52. self.get_success(
  53. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  54. )
  55. # check that the auth handler got called as expected
  56. auth_handler.complete_sso_login.assert_called_once_with(
  57. "@test_user:test",
  58. "cas",
  59. request,
  60. "redirect_uri",
  61. None,
  62. new_user=True,
  63. auth_provider_session_id=None,
  64. )
  65. def test_map_cas_user_to_existing_user(self) -> None:
  66. """Existing users can log in with CAS account."""
  67. store = self.hs.get_datastores().main
  68. self.get_success(
  69. store.register_user(user_id="@test_user:test", password_hash=None)
  70. )
  71. # stub out the auth handler
  72. auth_handler = self.hs.get_auth_handler()
  73. auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
  74. # Map a user via SSO.
  75. cas_response = CasResponse("test_user", {})
  76. request = _mock_request()
  77. self.get_success(
  78. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  79. )
  80. # check that the auth handler got called as expected
  81. auth_handler.complete_sso_login.assert_called_once_with(
  82. "@test_user:test",
  83. "cas",
  84. request,
  85. "redirect_uri",
  86. None,
  87. new_user=False,
  88. auth_provider_session_id=None,
  89. )
  90. # Subsequent calls should map to the same mxid.
  91. auth_handler.complete_sso_login.reset_mock()
  92. self.get_success(
  93. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  94. )
  95. auth_handler.complete_sso_login.assert_called_once_with(
  96. "@test_user:test",
  97. "cas",
  98. request,
  99. "redirect_uri",
  100. None,
  101. new_user=False,
  102. auth_provider_session_id=None,
  103. )
  104. def test_map_cas_user_to_invalid_localpart(self) -> None:
  105. """CAS automaps invalid characters to base-64 encoding."""
  106. # stub out the auth handler
  107. auth_handler = self.hs.get_auth_handler()
  108. auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
  109. cas_response = CasResponse("föö", {})
  110. request = _mock_request()
  111. self.get_success(
  112. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  113. )
  114. # check that the auth handler got called as expected
  115. auth_handler.complete_sso_login.assert_called_once_with(
  116. "@f=c3=b6=c3=b6:test",
  117. "cas",
  118. request,
  119. "redirect_uri",
  120. None,
  121. new_user=True,
  122. auth_provider_session_id=None,
  123. )
  124. @override_config(
  125. {
  126. "cas_config": {
  127. "required_attributes": {"userGroup": "staff", "department": None}
  128. }
  129. }
  130. )
  131. def test_required_attributes(self) -> None:
  132. """The required attributes must be met from the CAS response."""
  133. # stub out the auth handler
  134. auth_handler = self.hs.get_auth_handler()
  135. auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
  136. # The response doesn't have the proper userGroup or department.
  137. cas_response = CasResponse("test_user", {})
  138. request = _mock_request()
  139. self.get_success(
  140. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  141. )
  142. auth_handler.complete_sso_login.assert_not_called()
  143. # The response doesn't have any department.
  144. cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
  145. request.reset_mock()
  146. self.get_success(
  147. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  148. )
  149. auth_handler.complete_sso_login.assert_not_called()
  150. # Add the proper attributes and it should succeed.
  151. cas_response = CasResponse(
  152. "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
  153. )
  154. request.reset_mock()
  155. self.get_success(
  156. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  157. )
  158. # check that the auth handler got called as expected
  159. auth_handler.complete_sso_login.assert_called_once_with(
  160. "@test_user:test",
  161. "cas",
  162. request,
  163. "redirect_uri",
  164. None,
  165. new_user=True,
  166. auth_provider_session_id=None,
  167. )
  168. @override_config({"cas_config": {"enable_registration": False}})
  169. def test_map_cas_user_does_not_register_new_user(self) -> None:
  170. """Ensures new users are not registered if the enabled registration flag is disabled."""
  171. # stub out the auth handler
  172. auth_handler = self.hs.get_auth_handler()
  173. auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
  174. cas_response = CasResponse("test_user", {})
  175. request = _mock_request()
  176. self.get_success(
  177. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  178. )
  179. # check that the auth handler was not called as expected
  180. auth_handler.complete_sso_login.assert_not_called()
  181. def _mock_request() -> Mock:
  182. """Returns a mock which will stand in as a SynapseRequest"""
  183. mock = Mock(
  184. spec=[
  185. "finish",
  186. "getClientAddress",
  187. "getHeader",
  188. "setHeader",
  189. "setResponseCode",
  190. "write",
  191. ]
  192. )
  193. # `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
  194. mock._disconnected = False
  195. return mock