test_cas.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from unittest.mock import Mock
  15. from synapse.handlers.cas import CasResponse
  16. from tests.test_utils import simple_async_mock
  17. from tests.unittest import HomeserverTestCase, override_config
  18. # These are a few constants that are used as config parameters in the tests.
  19. BASE_URL = "https://synapse/"
  20. SERVER_URL = "https://issuer/"
  21. class CasHandlerTestCase(HomeserverTestCase):
  22. def default_config(self):
  23. config = super().default_config()
  24. config["public_baseurl"] = BASE_URL
  25. cas_config = {
  26. "enabled": True,
  27. "server_url": SERVER_URL,
  28. "service_url": BASE_URL,
  29. }
  30. # Update this config with what's in the default config so that
  31. # override_config works as expected.
  32. cas_config.update(config.get("cas_config", {}))
  33. config["cas_config"] = cas_config
  34. return config
  35. def make_homeserver(self, reactor, clock):
  36. hs = self.setup_test_homeserver()
  37. self.handler = hs.get_cas_handler()
  38. # Reduce the number of attempts when generating MXIDs.
  39. sso_handler = hs.get_sso_handler()
  40. sso_handler._MAP_USERNAME_RETRIES = 3
  41. return hs
  42. def test_map_cas_user_to_user(self):
  43. """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
  44. # stub out the auth handler
  45. auth_handler = self.hs.get_auth_handler()
  46. auth_handler.complete_sso_login = simple_async_mock()
  47. cas_response = CasResponse("test_user", {})
  48. request = _mock_request()
  49. self.get_success(
  50. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  51. )
  52. # check that the auth handler got called as expected
  53. auth_handler.complete_sso_login.assert_called_once_with(
  54. "@test_user:test",
  55. "cas",
  56. request,
  57. "redirect_uri",
  58. None,
  59. new_user=True,
  60. auth_provider_session_id=None,
  61. )
  62. def test_map_cas_user_to_existing_user(self):
  63. """Existing users can log in with CAS account."""
  64. store = self.hs.get_datastores().main
  65. self.get_success(
  66. store.register_user(user_id="@test_user:test", password_hash=None)
  67. )
  68. # stub out the auth handler
  69. auth_handler = self.hs.get_auth_handler()
  70. auth_handler.complete_sso_login = simple_async_mock()
  71. # Map a user via SSO.
  72. cas_response = CasResponse("test_user", {})
  73. request = _mock_request()
  74. self.get_success(
  75. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  76. )
  77. # check that the auth handler got called as expected
  78. auth_handler.complete_sso_login.assert_called_once_with(
  79. "@test_user:test",
  80. "cas",
  81. request,
  82. "redirect_uri",
  83. None,
  84. new_user=False,
  85. auth_provider_session_id=None,
  86. )
  87. # Subsequent calls should map to the same mxid.
  88. auth_handler.complete_sso_login.reset_mock()
  89. self.get_success(
  90. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  91. )
  92. auth_handler.complete_sso_login.assert_called_once_with(
  93. "@test_user:test",
  94. "cas",
  95. request,
  96. "redirect_uri",
  97. None,
  98. new_user=False,
  99. auth_provider_session_id=None,
  100. )
  101. def test_map_cas_user_to_invalid_localpart(self):
  102. """CAS automaps invalid characters to base-64 encoding."""
  103. # stub out the auth handler
  104. auth_handler = self.hs.get_auth_handler()
  105. auth_handler.complete_sso_login = simple_async_mock()
  106. cas_response = CasResponse("föö", {})
  107. request = _mock_request()
  108. self.get_success(
  109. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  110. )
  111. # check that the auth handler got called as expected
  112. auth_handler.complete_sso_login.assert_called_once_with(
  113. "@f=c3=b6=c3=b6:test",
  114. "cas",
  115. request,
  116. "redirect_uri",
  117. None,
  118. new_user=True,
  119. auth_provider_session_id=None,
  120. )
  121. @override_config(
  122. {
  123. "cas_config": {
  124. "required_attributes": {"userGroup": "staff", "department": None}
  125. }
  126. }
  127. )
  128. def test_required_attributes(self):
  129. """The required attributes must be met from the CAS response."""
  130. # stub out the auth handler
  131. auth_handler = self.hs.get_auth_handler()
  132. auth_handler.complete_sso_login = simple_async_mock()
  133. # The response doesn't have the proper userGroup or department.
  134. cas_response = CasResponse("test_user", {})
  135. request = _mock_request()
  136. self.get_success(
  137. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  138. )
  139. auth_handler.complete_sso_login.assert_not_called()
  140. # The response doesn't have any department.
  141. cas_response = CasResponse("test_user", {"userGroup": "staff"})
  142. request.reset_mock()
  143. self.get_success(
  144. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  145. )
  146. auth_handler.complete_sso_login.assert_not_called()
  147. # Add the proper attributes and it should succeed.
  148. cas_response = CasResponse(
  149. "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
  150. )
  151. request.reset_mock()
  152. self.get_success(
  153. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  154. )
  155. # check that the auth handler got called as expected
  156. auth_handler.complete_sso_login.assert_called_once_with(
  157. "@test_user:test",
  158. "cas",
  159. request,
  160. "redirect_uri",
  161. None,
  162. new_user=True,
  163. auth_provider_session_id=None,
  164. )
  165. def _mock_request():
  166. """Returns a mock which will stand in as a SynapseRequest"""
  167. return Mock(spec=["getClientIP", "getHeader", "_disconnected"])