test_auth.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. # Copyright 2015, 2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. from unittest.mock import Mock
  16. import pymacaroons
  17. from twisted.test.proto_helpers import MemoryReactor
  18. from synapse.api.errors import AuthError, ResourceLimitError
  19. from synapse.rest import admin
  20. from synapse.rest.client import login
  21. from synapse.server import HomeServer
  22. from synapse.util import Clock
  23. from tests import unittest
  24. from tests.test_utils import make_awaitable
  25. class AuthTestCase(unittest.HomeserverTestCase):
  26. servlets = [
  27. admin.register_servlets,
  28. login.register_servlets,
  29. ]
  30. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  31. self.auth_handler = hs.get_auth_handler()
  32. self.macaroon_generator = hs.get_macaroon_generator()
  33. # MAU tests
  34. # AuthBlocking reads from the hs' config on initialization. We need to
  35. # modify its config instead of the hs'
  36. self.auth_blocking = hs.get_auth_blocking()
  37. self.auth_blocking._max_mau_value = 50
  38. self.small_number_of_users = 1
  39. self.large_number_of_users = 100
  40. self.user1 = self.register_user("a_user", "pass")
  41. def token_login(self, token: str) -> Optional[str]:
  42. body = {
  43. "type": "m.login.token",
  44. "token": token,
  45. }
  46. channel = self.make_request(
  47. "POST",
  48. "/_matrix/client/v3/login",
  49. body,
  50. )
  51. if channel.code == 200:
  52. return channel.json_body["user_id"]
  53. return None
  54. def test_macaroon_caveats(self) -> None:
  55. token = self.macaroon_generator.generate_guest_access_token("a_user")
  56. macaroon = pymacaroons.Macaroon.deserialize(token)
  57. def verify_gen(caveat: str) -> bool:
  58. return caveat == "gen = 1"
  59. def verify_user(caveat: str) -> bool:
  60. return caveat == "user_id = a_user"
  61. def verify_type(caveat: str) -> bool:
  62. return caveat == "type = access"
  63. def verify_nonce(caveat: str) -> bool:
  64. return caveat.startswith("nonce =")
  65. def verify_guest(caveat: str) -> bool:
  66. return caveat == "guest = true"
  67. v = pymacaroons.Verifier()
  68. v.satisfy_general(verify_gen)
  69. v.satisfy_general(verify_user)
  70. v.satisfy_general(verify_type)
  71. v.satisfy_general(verify_nonce)
  72. v.satisfy_general(verify_guest)
  73. v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
  74. def test_login_token_gives_user_id(self) -> None:
  75. token = self.get_success(
  76. self.auth_handler.create_login_token_for_user_id(
  77. self.user1,
  78. duration_ms=(5 * 1000),
  79. )
  80. )
  81. res = self.get_success(self.auth_handler.consume_login_token(token))
  82. self.assertEqual(self.user1, res.user_id)
  83. self.assertEqual(None, res.auth_provider_id)
  84. def test_login_token_reuse_fails(self) -> None:
  85. token = self.get_success(
  86. self.auth_handler.create_login_token_for_user_id(
  87. self.user1,
  88. duration_ms=(5 * 1000),
  89. )
  90. )
  91. self.get_success(self.auth_handler.consume_login_token(token))
  92. self.get_failure(
  93. self.auth_handler.consume_login_token(token),
  94. AuthError,
  95. )
  96. def test_login_token_expires(self) -> None:
  97. token = self.get_success(
  98. self.auth_handler.create_login_token_for_user_id(
  99. self.user1,
  100. duration_ms=(5 * 1000),
  101. )
  102. )
  103. # when we advance the clock, the token should be rejected
  104. self.reactor.advance(6)
  105. self.get_failure(
  106. self.auth_handler.consume_login_token(token),
  107. AuthError,
  108. )
  109. def test_login_token_gives_auth_provider(self) -> None:
  110. token = self.get_success(
  111. self.auth_handler.create_login_token_for_user_id(
  112. self.user1,
  113. auth_provider_id="my_idp",
  114. auth_provider_session_id="11-22-33-44",
  115. duration_ms=(5 * 1000),
  116. )
  117. )
  118. res = self.get_success(self.auth_handler.consume_login_token(token))
  119. self.assertEqual(self.user1, res.user_id)
  120. self.assertEqual("my_idp", res.auth_provider_id)
  121. self.assertEqual("11-22-33-44", res.auth_provider_session_id)
  122. def test_mau_limits_disabled(self) -> None:
  123. self.auth_blocking._limit_usage_by_mau = False
  124. # Ensure does not throw exception
  125. self.get_success(
  126. self.auth_handler.create_access_token_for_user_id(
  127. self.user1, device_id=None, valid_until_ms=None
  128. )
  129. )
  130. token = self.get_success(
  131. self.auth_handler.create_login_token_for_user_id(self.user1)
  132. )
  133. self.assertIsNotNone(self.token_login(token))
  134. def test_mau_limits_exceeded_large(self) -> None:
  135. self.auth_blocking._limit_usage_by_mau = True
  136. self.hs.get_datastores().main.get_monthly_active_count = Mock(
  137. return_value=make_awaitable(self.large_number_of_users)
  138. )
  139. self.get_failure(
  140. self.auth_handler.create_access_token_for_user_id(
  141. self.user1, device_id=None, valid_until_ms=None
  142. ),
  143. ResourceLimitError,
  144. )
  145. self.hs.get_datastores().main.get_monthly_active_count = Mock(
  146. return_value=make_awaitable(self.large_number_of_users)
  147. )
  148. token = self.get_success(
  149. self.auth_handler.create_login_token_for_user_id(self.user1)
  150. )
  151. self.assertIsNone(self.token_login(token))
  152. def test_mau_limits_parity(self) -> None:
  153. # Ensure we're not at the unix epoch.
  154. self.reactor.advance(1)
  155. self.auth_blocking._limit_usage_by_mau = True
  156. # Set the server to be at the edge of too many users.
  157. self.hs.get_datastores().main.get_monthly_active_count = Mock(
  158. return_value=make_awaitable(self.auth_blocking._max_mau_value)
  159. )
  160. # If not in monthly active cohort
  161. self.get_failure(
  162. self.auth_handler.create_access_token_for_user_id(
  163. self.user1, device_id=None, valid_until_ms=None
  164. ),
  165. ResourceLimitError,
  166. )
  167. token = self.get_success(
  168. self.auth_handler.create_login_token_for_user_id(self.user1)
  169. )
  170. self.assertIsNone(self.token_login(token))
  171. # If in monthly active cohort
  172. self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
  173. return_value=make_awaitable(self.clock.time_msec())
  174. )
  175. self.get_success(
  176. self.auth_handler.create_access_token_for_user_id(
  177. self.user1, device_id=None, valid_until_ms=None
  178. )
  179. )
  180. token = self.get_success(
  181. self.auth_handler.create_login_token_for_user_id(self.user1)
  182. )
  183. self.assertIsNotNone(self.token_login(token))
  184. def test_mau_limits_not_exceeded(self) -> None:
  185. self.auth_blocking._limit_usage_by_mau = True
  186. self.hs.get_datastores().main.get_monthly_active_count = Mock(
  187. return_value=make_awaitable(self.small_number_of_users)
  188. )
  189. # Ensure does not raise exception
  190. self.get_success(
  191. self.auth_handler.create_access_token_for_user_id(
  192. self.user1, device_id=None, valid_until_ms=None
  193. )
  194. )
  195. self.hs.get_datastores().main.get_monthly_active_count = Mock(
  196. return_value=make_awaitable(self.small_number_of_users)
  197. )
  198. token = self.get_success(
  199. self.auth_handler.create_login_token_for_user_id(self.user1)
  200. )
  201. self.assertIsNotNone(self.token_login(token))