test_password_providers.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009
  1. # Copyright 2020 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Tests for the password_auth_provider interface"""
  15. from http import HTTPStatus
  16. from typing import Any, Type, Union
  17. from unittest.mock import Mock
  18. import synapse
  19. from synapse.api.constants import LoginType
  20. from synapse.api.errors import Codes
  21. from synapse.handlers.auth import load_legacy_password_auth_providers
  22. from synapse.module_api import ModuleApi
  23. from synapse.rest.client import account, devices, login, logout, register
  24. from synapse.types import JsonDict, UserID
  25. from tests import unittest
  26. from tests.server import FakeChannel
  27. from tests.test_utils import make_awaitable
  28. from tests.unittest import override_config
  29. # Login flows we expect to appear in the list after the normal ones.
  30. ADDITIONAL_LOGIN_FLOWS = [
  31. {"type": "m.login.application_service"},
  32. ]
  33. # a mock instance which the dummy auth providers delegate to, so we can see what's going
  34. # on
  35. mock_password_provider = Mock()
  36. class LegacyPasswordOnlyAuthProvider:
  37. """A legacy password_provider which only implements `check_password`."""
  38. @staticmethod
  39. def parse_config(self):
  40. pass
  41. def __init__(self, config, account_handler):
  42. pass
  43. def check_password(self, *args):
  44. return mock_password_provider.check_password(*args)
  45. class LegacyCustomAuthProvider:
  46. """A legacy password_provider which implements a custom login type."""
  47. @staticmethod
  48. def parse_config(self):
  49. pass
  50. def __init__(self, config, account_handler):
  51. pass
  52. def get_supported_login_types(self):
  53. return {"test.login_type": ["test_field"]}
  54. def check_auth(self, *args):
  55. return mock_password_provider.check_auth(*args)
  56. class CustomAuthProvider:
  57. """A module which registers password_auth_provider callbacks for a custom login type."""
  58. @staticmethod
  59. def parse_config(self):
  60. pass
  61. def __init__(self, config, api: ModuleApi):
  62. api.register_password_auth_provider_callbacks(
  63. auth_checkers={("test.login_type", ("test_field",)): self.check_auth}
  64. )
  65. def check_auth(self, *args):
  66. return mock_password_provider.check_auth(*args)
  67. class LegacyPasswordCustomAuthProvider:
  68. """A password_provider which implements password login via `check_auth`, as well
  69. as a custom type."""
  70. @staticmethod
  71. def parse_config(self):
  72. pass
  73. def __init__(self, config, account_handler):
  74. pass
  75. def get_supported_login_types(self):
  76. return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
  77. def check_auth(self, *args):
  78. return mock_password_provider.check_auth(*args)
  79. class PasswordCustomAuthProvider:
  80. """A module which registers password_auth_provider callbacks for a custom login type.
  81. as well as a password login"""
  82. @staticmethod
  83. def parse_config(self):
  84. pass
  85. def __init__(self, config, api: ModuleApi):
  86. api.register_password_auth_provider_callbacks(
  87. auth_checkers={
  88. ("test.login_type", ("test_field",)): self.check_auth,
  89. ("m.login.password", ("password",)): self.check_auth,
  90. }
  91. )
  92. def check_auth(self, *args):
  93. return mock_password_provider.check_auth(*args)
  94. def check_pass(self, *args):
  95. return mock_password_provider.check_password(*args)
  96. def legacy_providers_config(*providers: Type[Any]) -> dict:
  97. """Returns a config dict that will enable the given legacy password auth providers"""
  98. return {
  99. "password_providers": [
  100. {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
  101. for provider in providers
  102. ]
  103. }
  104. def providers_config(*providers: Type[Any]) -> dict:
  105. """Returns a config dict that will enable the given modules"""
  106. return {
  107. "modules": [
  108. {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
  109. for provider in providers
  110. ]
  111. }
  112. class PasswordAuthProviderTests(unittest.HomeserverTestCase):
  113. servlets = [
  114. synapse.rest.admin.register_servlets,
  115. login.register_servlets,
  116. devices.register_servlets,
  117. logout.register_servlets,
  118. register.register_servlets,
  119. account.register_servlets,
  120. ]
  121. CALLBACK_USERNAME = "get_username_for_registration"
  122. CALLBACK_DISPLAYNAME = "get_displayname_for_registration"
  123. def setUp(self):
  124. # we use a global mock device, so make sure we are starting with a clean slate
  125. mock_password_provider.reset_mock()
  126. super().setUp()
  127. def make_homeserver(self, reactor, clock):
  128. hs = self.setup_test_homeserver()
  129. # Load the modules into the homeserver
  130. module_api = hs.get_module_api()
  131. for module, config in hs.config.modules.loaded_modules:
  132. module(config=config, api=module_api)
  133. load_legacy_password_auth_providers(hs)
  134. return hs
  135. @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
  136. def test_password_only_auth_progiver_login_legacy(self):
  137. self.password_only_auth_provider_login_test_body()
  138. def password_only_auth_provider_login_test_body(self):
  139. # login flows should only have m.login.password
  140. flows = self._get_login_flows()
  141. self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
  142. # check_password must return an awaitable
  143. mock_password_provider.check_password.return_value = make_awaitable(True)
  144. channel = self._send_password_login("u", "p")
  145. self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
  146. self.assertEqual("@u:test", channel.json_body["user_id"])
  147. mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
  148. mock_password_provider.reset_mock()
  149. # login with mxid should work too
  150. channel = self._send_password_login("@u:bz", "p")
  151. self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
  152. self.assertEqual("@u:bz", channel.json_body["user_id"])
  153. mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
  154. mock_password_provider.reset_mock()
  155. # try a weird username / pass. Honestly it's unclear what we *expect* to happen
  156. # in these cases, but at least we can guard against the API changing
  157. # unexpectedly
  158. channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
  159. self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
  160. self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
  161. mock_password_provider.check_password.assert_called_once_with(
  162. "@ USER🙂NAME :test", " pASS😢word "
  163. )
  164. @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
  165. def test_password_only_auth_provider_ui_auth_legacy(self):
  166. self.password_only_auth_provider_ui_auth_test_body()
  167. def password_only_auth_provider_ui_auth_test_body(self):
  168. """UI Auth should delegate correctly to the password provider"""
  169. # create the user, otherwise access doesn't work
  170. module_api = self.hs.get_module_api()
  171. self.get_success(module_api.register_user("u"))
  172. # log in twice, to get two devices
  173. mock_password_provider.check_password.return_value = make_awaitable(True)
  174. tok1 = self.login("u", "p")
  175. self.login("u", "p", device_id="dev2")
  176. mock_password_provider.reset_mock()
  177. # have the auth provider deny the request to start with
  178. mock_password_provider.check_password.return_value = make_awaitable(False)
  179. # make the initial request which returns a 401
  180. session = self._start_delete_device_session(tok1, "dev2")
  181. mock_password_provider.check_password.assert_not_called()
  182. # Make another request providing the UI auth flow.
  183. channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
  184. self.assertEqual(channel.code, 401) # XXX why not a 403?
  185. self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
  186. mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
  187. mock_password_provider.reset_mock()
  188. # Finally, check the request goes through when we allow it
  189. mock_password_provider.check_password.return_value = make_awaitable(True)
  190. channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
  191. self.assertEqual(channel.code, 200)
  192. mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
  193. @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
  194. def test_local_user_fallback_login_legacy(self):
  195. self.local_user_fallback_login_test_body()
  196. def local_user_fallback_login_test_body(self):
  197. """rejected login should fall back to local db"""
  198. self.register_user("localuser", "localpass")
  199. # check_password must return an awaitable
  200. mock_password_provider.check_password.return_value = make_awaitable(False)
  201. channel = self._send_password_login("u", "p")
  202. self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
  203. channel = self._send_password_login("localuser", "localpass")
  204. self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
  205. self.assertEqual("@localuser:test", channel.json_body["user_id"])
  206. @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
  207. def test_local_user_fallback_ui_auth_legacy(self):
  208. self.local_user_fallback_ui_auth_test_body()
  209. def local_user_fallback_ui_auth_test_body(self):
  210. """rejected login should fall back to local db"""
  211. self.register_user("localuser", "localpass")
  212. # have the auth provider deny the request
  213. mock_password_provider.check_password.return_value = make_awaitable(False)
  214. # log in twice, to get two devices
  215. tok1 = self.login("localuser", "localpass")
  216. self.login("localuser", "localpass", device_id="dev2")
  217. mock_password_provider.check_password.reset_mock()
  218. # first delete should give a 401
  219. session = self._start_delete_device_session(tok1, "dev2")
  220. mock_password_provider.check_password.assert_not_called()
  221. # Wrong password
  222. channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx")
  223. self.assertEqual(channel.code, 401) # XXX why not a 403?
  224. self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
  225. mock_password_provider.check_password.assert_called_once_with(
  226. "@localuser:test", "xxx"
  227. )
  228. mock_password_provider.reset_mock()
  229. # Right password
  230. channel = self._authed_delete_device(
  231. tok1, "dev2", session, "localuser", "localpass"
  232. )
  233. self.assertEqual(channel.code, 200)
  234. mock_password_provider.check_password.assert_called_once_with(
  235. "@localuser:test", "localpass"
  236. )
  237. @override_config(
  238. {
  239. **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
  240. "password_config": {"localdb_enabled": False},
  241. }
  242. )
  243. def test_no_local_user_fallback_login_legacy(self):
  244. self.no_local_user_fallback_login_test_body()
  245. def no_local_user_fallback_login_test_body(self):
  246. """localdb_enabled can block login with the local password"""
  247. self.register_user("localuser", "localpass")
  248. # check_password must return an awaitable
  249. mock_password_provider.check_password.return_value = make_awaitable(False)
  250. channel = self._send_password_login("localuser", "localpass")
  251. self.assertEqual(channel.code, 403)
  252. self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
  253. mock_password_provider.check_password.assert_called_once_with(
  254. "@localuser:test", "localpass"
  255. )
  256. @override_config(
  257. {
  258. **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
  259. "password_config": {"localdb_enabled": False},
  260. }
  261. )
  262. def test_no_local_user_fallback_ui_auth_legacy(self):
  263. self.no_local_user_fallback_ui_auth_test_body()
  264. def no_local_user_fallback_ui_auth_test_body(self):
  265. """localdb_enabled can block ui auth with the local password"""
  266. self.register_user("localuser", "localpass")
  267. # allow login via the auth provider
  268. mock_password_provider.check_password.return_value = make_awaitable(True)
  269. # log in twice, to get two devices
  270. tok1 = self.login("localuser", "p")
  271. self.login("localuser", "p", device_id="dev2")
  272. mock_password_provider.check_password.reset_mock()
  273. # first delete should give a 401
  274. channel = self._delete_device(tok1, "dev2")
  275. self.assertEqual(channel.code, 401)
  276. # m.login.password UIA is permitted because the auth provider allows it,
  277. # even though the localdb does not.
  278. self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
  279. session = channel.json_body["session"]
  280. mock_password_provider.check_password.assert_not_called()
  281. # now try deleting with the local password
  282. mock_password_provider.check_password.return_value = make_awaitable(False)
  283. channel = self._authed_delete_device(
  284. tok1, "dev2", session, "localuser", "localpass"
  285. )
  286. self.assertEqual(channel.code, 401) # XXX why not a 403?
  287. self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
  288. mock_password_provider.check_password.assert_called_once_with(
  289. "@localuser:test", "localpass"
  290. )
  291. @override_config(
  292. {
  293. **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
  294. "password_config": {"enabled": False},
  295. }
  296. )
  297. def test_password_auth_disabled_legacy(self):
  298. self.password_auth_disabled_test_body()
  299. def password_auth_disabled_test_body(self):
  300. """password auth doesn't work if it's disabled across the board"""
  301. # login flows should be empty
  302. flows = self._get_login_flows()
  303. self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS)
  304. # login shouldn't work and should be rejected with a 400 ("unknown login type")
  305. channel = self._send_password_login("u", "p")
  306. self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
  307. mock_password_provider.check_password.assert_not_called()
  308. @override_config(legacy_providers_config(LegacyCustomAuthProvider))
  309. def test_custom_auth_provider_login_legacy(self):
  310. self.custom_auth_provider_login_test_body()
  311. @override_config(providers_config(CustomAuthProvider))
  312. def test_custom_auth_provider_login(self):
  313. self.custom_auth_provider_login_test_body()
  314. def custom_auth_provider_login_test_body(self):
  315. # login flows should have the custom flow and m.login.password, since we
  316. # haven't disabled local password lookup.
  317. # (password must come first, because reasons)
  318. flows = self._get_login_flows()
  319. self.assertEqual(
  320. flows,
  321. [{"type": "m.login.password"}, {"type": "test.login_type"}]
  322. + ADDITIONAL_LOGIN_FLOWS,
  323. )
  324. # login with missing param should be rejected
  325. channel = self._send_login("test.login_type", "u")
  326. self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
  327. mock_password_provider.check_auth.assert_not_called()
  328. mock_password_provider.check_auth.return_value = make_awaitable(
  329. ("@user:bz", None)
  330. )
  331. channel = self._send_login("test.login_type", "u", test_field="y")
  332. self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
  333. self.assertEqual("@user:bz", channel.json_body["user_id"])
  334. mock_password_provider.check_auth.assert_called_once_with(
  335. "u", "test.login_type", {"test_field": "y"}
  336. )
  337. mock_password_provider.reset_mock()
  338. # try a weird username. Again, it's unclear what we *expect* to happen
  339. # in these cases, but at least we can guard against the API changing
  340. # unexpectedly
  341. mock_password_provider.check_auth.return_value = make_awaitable(
  342. ("@ MALFORMED! :bz", None)
  343. )
  344. channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
  345. self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
  346. self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
  347. mock_password_provider.check_auth.assert_called_once_with(
  348. " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
  349. )
  350. @override_config(legacy_providers_config(LegacyCustomAuthProvider))
  351. def test_custom_auth_provider_ui_auth_legacy(self):
  352. self.custom_auth_provider_ui_auth_test_body()
  353. @override_config(providers_config(CustomAuthProvider))
  354. def test_custom_auth_provider_ui_auth(self):
  355. self.custom_auth_provider_ui_auth_test_body()
  356. def custom_auth_provider_ui_auth_test_body(self):
  357. # register the user and log in twice, to get two devices
  358. self.register_user("localuser", "localpass")
  359. tok1 = self.login("localuser", "localpass")
  360. self.login("localuser", "localpass", device_id="dev2")
  361. # make the initial request which returns a 401
  362. channel = self._delete_device(tok1, "dev2")
  363. self.assertEqual(channel.code, 401)
  364. # Ensure that flows are what is expected.
  365. self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
  366. self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
  367. session = channel.json_body["session"]
  368. # missing param
  369. body = {
  370. "auth": {
  371. "type": "test.login_type",
  372. "identifier": {"type": "m.id.user", "user": "localuser"},
  373. "session": session,
  374. },
  375. }
  376. channel = self._delete_device(tok1, "dev2", body)
  377. self.assertEqual(channel.code, 400)
  378. # there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should
  379. # use it...
  380. self.assertIn("Missing parameters", channel.json_body["error"])
  381. mock_password_provider.check_auth.assert_not_called()
  382. mock_password_provider.reset_mock()
  383. # right params, but authing as the wrong user
  384. mock_password_provider.check_auth.return_value = make_awaitable(
  385. ("@user:bz", None)
  386. )
  387. body["auth"]["test_field"] = "foo"
  388. channel = self._delete_device(tok1, "dev2", body)
  389. self.assertEqual(channel.code, 403)
  390. self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
  391. mock_password_provider.check_auth.assert_called_once_with(
  392. "localuser", "test.login_type", {"test_field": "foo"}
  393. )
  394. mock_password_provider.reset_mock()
  395. # and finally, succeed
  396. mock_password_provider.check_auth.return_value = make_awaitable(
  397. ("@localuser:test", None)
  398. )
  399. channel = self._delete_device(tok1, "dev2", body)
  400. self.assertEqual(channel.code, 200)
  401. mock_password_provider.check_auth.assert_called_once_with(
  402. "localuser", "test.login_type", {"test_field": "foo"}
  403. )
  404. @override_config(legacy_providers_config(LegacyCustomAuthProvider))
  405. def test_custom_auth_provider_callback_legacy(self):
  406. self.custom_auth_provider_callback_test_body()
  407. @override_config(providers_config(CustomAuthProvider))
  408. def test_custom_auth_provider_callback(self):
  409. self.custom_auth_provider_callback_test_body()
  410. def custom_auth_provider_callback_test_body(self):
  411. callback = Mock(return_value=make_awaitable(None))
  412. mock_password_provider.check_auth.return_value = make_awaitable(
  413. ("@user:bz", callback)
  414. )
  415. channel = self._send_login("test.login_type", "u", test_field="y")
  416. self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
  417. self.assertEqual("@user:bz", channel.json_body["user_id"])
  418. mock_password_provider.check_auth.assert_called_once_with(
  419. "u", "test.login_type", {"test_field": "y"}
  420. )
  421. # check the args to the callback
  422. callback.assert_called_once()
  423. call_args, call_kwargs = callback.call_args
  424. # should be one positional arg
  425. self.assertEqual(len(call_args), 1)
  426. self.assertEqual(call_args[0]["user_id"], "@user:bz")
  427. for p in ["user_id", "access_token", "device_id", "home_server"]:
  428. self.assertIn(p, call_args[0])
  429. @override_config(
  430. {
  431. **legacy_providers_config(LegacyCustomAuthProvider),
  432. "password_config": {"enabled": False},
  433. }
  434. )
  435. def test_custom_auth_password_disabled_legacy(self):
  436. self.custom_auth_password_disabled_test_body()
  437. @override_config(
  438. {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
  439. )
  440. def test_custom_auth_password_disabled(self):
  441. self.custom_auth_password_disabled_test_body()
  442. def custom_auth_password_disabled_test_body(self):
  443. """Test login with a custom auth provider where password login is disabled"""
  444. self.register_user("localuser", "localpass")
  445. flows = self._get_login_flows()
  446. self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
  447. # login shouldn't work and should be rejected with a 400 ("unknown login type")
  448. channel = self._send_password_login("localuser", "localpass")
  449. self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
  450. mock_password_provider.check_auth.assert_not_called()
  451. @override_config(
  452. {
  453. **legacy_providers_config(LegacyCustomAuthProvider),
  454. "password_config": {"enabled": False, "localdb_enabled": False},
  455. }
  456. )
  457. def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
  458. self.custom_auth_password_disabled_localdb_enabled_test_body()
  459. @override_config(
  460. {
  461. **providers_config(CustomAuthProvider),
  462. "password_config": {"enabled": False, "localdb_enabled": False},
  463. }
  464. )
  465. def test_custom_auth_password_disabled_localdb_enabled(self):
  466. self.custom_auth_password_disabled_localdb_enabled_test_body()
  467. def custom_auth_password_disabled_localdb_enabled_test_body(self):
  468. """Check the localdb_enabled == enabled == False
  469. Regression test for https://github.com/matrix-org/synapse/issues/8914: check
  470. that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
  471. cause an exception.
  472. """
  473. self.register_user("localuser", "localpass")
  474. flows = self._get_login_flows()
  475. self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
  476. # login shouldn't work and should be rejected with a 400 ("unknown login type")
  477. channel = self._send_password_login("localuser", "localpass")
  478. self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
  479. mock_password_provider.check_auth.assert_not_called()
  480. @override_config(
  481. {
  482. **legacy_providers_config(LegacyPasswordCustomAuthProvider),
  483. "password_config": {"enabled": False},
  484. }
  485. )
  486. def test_password_custom_auth_password_disabled_login_legacy(self):
  487. self.password_custom_auth_password_disabled_login_test_body()
  488. @override_config(
  489. {
  490. **providers_config(PasswordCustomAuthProvider),
  491. "password_config": {"enabled": False},
  492. }
  493. )
  494. def test_password_custom_auth_password_disabled_login(self):
  495. self.password_custom_auth_password_disabled_login_test_body()
  496. def password_custom_auth_password_disabled_login_test_body(self):
  497. """log in with a custom auth provider which implements password, but password
  498. login is disabled"""
  499. self.register_user("localuser", "localpass")
  500. flows = self._get_login_flows()
  501. self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
  502. # login shouldn't work and should be rejected with a 400 ("unknown login type")
  503. channel = self._send_password_login("localuser", "localpass")
  504. self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
  505. mock_password_provider.check_auth.assert_not_called()
  506. mock_password_provider.check_password.assert_not_called()
  507. @override_config(
  508. {
  509. **legacy_providers_config(LegacyPasswordCustomAuthProvider),
  510. "password_config": {"enabled": False},
  511. }
  512. )
  513. def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
  514. self.password_custom_auth_password_disabled_ui_auth_test_body()
  515. @override_config(
  516. {
  517. **providers_config(PasswordCustomAuthProvider),
  518. "password_config": {"enabled": False},
  519. }
  520. )
  521. def test_password_custom_auth_password_disabled_ui_auth(self):
  522. self.password_custom_auth_password_disabled_ui_auth_test_body()
  523. def password_custom_auth_password_disabled_ui_auth_test_body(self):
  524. """UI Auth with a custom auth provider which implements password, but password
  525. login is disabled"""
  526. # register the user and log in twice via the test login type to get two devices,
  527. self.register_user("localuser", "localpass")
  528. mock_password_provider.check_auth.return_value = make_awaitable(
  529. ("@localuser:test", None)
  530. )
  531. channel = self._send_login("test.login_type", "localuser", test_field="")
  532. self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
  533. tok1 = channel.json_body["access_token"]
  534. channel = self._send_login(
  535. "test.login_type", "localuser", test_field="", device_id="dev2"
  536. )
  537. self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
  538. # make the initial request which returns a 401
  539. channel = self._delete_device(tok1, "dev2")
  540. self.assertEqual(channel.code, 401)
  541. # Ensure that flows are what is expected. In particular, "password" should *not*
  542. # be present.
  543. self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
  544. session = channel.json_body["session"]
  545. mock_password_provider.reset_mock()
  546. # check that auth with password is rejected
  547. body = {
  548. "auth": {
  549. "type": "m.login.password",
  550. "identifier": {"type": "m.id.user", "user": "localuser"},
  551. "password": "localpass",
  552. "session": session,
  553. },
  554. }
  555. channel = self._delete_device(tok1, "dev2", body)
  556. self.assertEqual(channel.code, 400)
  557. self.assertEqual(
  558. "Password login has been disabled.", channel.json_body["error"]
  559. )
  560. mock_password_provider.check_auth.assert_not_called()
  561. mock_password_provider.check_password.assert_not_called()
  562. mock_password_provider.reset_mock()
  563. # successful auth
  564. body["auth"]["type"] = "test.login_type"
  565. body["auth"]["test_field"] = "x"
  566. channel = self._delete_device(tok1, "dev2", body)
  567. self.assertEqual(channel.code, 200)
  568. mock_password_provider.check_auth.assert_called_once_with(
  569. "localuser", "test.login_type", {"test_field": "x"}
  570. )
  571. mock_password_provider.check_password.assert_not_called()
  572. @override_config(
  573. {
  574. **legacy_providers_config(LegacyCustomAuthProvider),
  575. "password_config": {"localdb_enabled": False},
  576. }
  577. )
  578. def test_custom_auth_no_local_user_fallback_legacy(self):
  579. self.custom_auth_no_local_user_fallback_test_body()
  580. @override_config(
  581. {
  582. **providers_config(CustomAuthProvider),
  583. "password_config": {"localdb_enabled": False},
  584. }
  585. )
  586. def test_custom_auth_no_local_user_fallback(self):
  587. self.custom_auth_no_local_user_fallback_test_body()
  588. def custom_auth_no_local_user_fallback_test_body(self):
  589. """Test login with a custom auth provider where the local db is disabled"""
  590. self.register_user("localuser", "localpass")
  591. flows = self._get_login_flows()
  592. self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
  593. # password login shouldn't work and should be rejected with a 400
  594. # ("unknown login type")
  595. channel = self._send_password_login("localuser", "localpass")
  596. self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
  def test_on_logged_out(self):
      """Check that logging out triggers a registered on_logged_out callback."""
      self.register_user("rin", "password")
      user_token = self.login("rin", "password")

      # Flipped to True by the callback below once it has actually run.
      self.called = False

      async def _record_logout(user_id, device_id, access_token):
          self.called = True

      # Wrap the coroutine function so the number of invocations can be asserted.
      logout_callback = Mock(side_effect=_record_logout)
      self.hs.get_password_auth_provider().on_logged_out_callbacks.append(
          logout_callback
      )

      response = self.make_request(
          "POST", "/_matrix/client/v3/logout", {}, access_token=user_token
      )
      self.assertEqual(response.code, 200)

      logout_callback.assert_called_once()
      self.assertTrue(self.called)
  def test_username(self):
      """Check that a get_username_for_registration callback decides the localpart
      of a newly registered user.
      """
      self._setup_get_name_for_registration(callback_name=self.CALLBACK_USERNAME)

      requested = "rin"
      registration_body = {
          "username": requested,
          "password": "bar",
          "auth": {"type": LoginType.DUMMY},
      }
      response = self.make_request("POST", "/register", registration_body)
      self.assertEqual(response.code, 200)

      # The installed callback appends "-foo" to the requested username, so that is
      # the localpart the new account should end up with.
      localpart = UserID.from_string(response.json_body["user_id"]).localpart
      self.assertEqual(localpart, requested + "-foo")
  def test_username_uia(self):
      """Check that get_username_for_registration runs only once the
      user-interactive auth flow has been completed.
      """
      callback_mock = self._setup_get_name_for_registration(
          callback_name=self.CALLBACK_USERNAME
      )
      requested = "rin"

      # The helper itself asserts the callback has not fired before UIA finishes.
      result = self._do_uia_assert_mock_not_called(requested, callback_mock)

      localpart = UserID.from_string(result["user_id"]).localpart
      self.assertEqual(localpart, requested + "-foo")

      # After completion it must have fired exactly once.
      callback_mock.assert_called_once()
  # Provide a minimal email configuration so the test doesn't fail merely because
  # email settings are absent.
  @override_config({"email": {"notif_from": "noreply@test"}})
  def test_3pid_allowed(self):
      """Check that is_3pid_allowed callbacks are honoured.

      A callback that forbids a 3PID makes Synapse refuse to bind it, and one that
      allows it makes Synapse accept the binding. The module must also be told
      whether the user the 3PID is being bound to is currently registering.
      """
      for username, registration in (("rin", False), ("kitay", True)):
          self._test_3pid_allowed(username, registration)
  def test_displayname(self):
      """Check that a get_displayname_for_registration callback decides the display
      name given to a newly registered user.
      """
      self._setup_get_name_for_registration(callback_name=self.CALLBACK_DISPLAYNAME)

      requested = "rin"
      registration_body = {
          "username": requested,
          "password": "bar",
          "auth": {"type": LoginType.DUMMY},
      }
      response = self.make_request("POST", "/register", registration_body)
      self.assertEqual(response.code, 200)

      # The installed callback appends "-foo" to the requested username; the stored
      # display name should reflect that.
      new_user = UserID.from_string(response.json_body["user_id"])
      profile_handler = self.hs.get_profile_handler()
      display_name = self.get_success(profile_handler.get_displayname(new_user))
      self.assertEqual(display_name, requested + "-foo")
  def test_displayname_uia(self):
      """Check that get_displayname_for_registration runs only once the
      user-interactive auth flow has been completed.
      """
      callback_mock = self._setup_get_name_for_registration(
          callback_name=self.CALLBACK_DISPLAYNAME
      )
      requested = "rin"

      # The helper itself asserts the callback has not fired before UIA finishes.
      result = self._do_uia_assert_mock_not_called(requested, callback_mock)

      new_user = UserID.from_string(result["user_id"])
      profile_handler = self.hs.get_profile_handler()
      display_name = self.get_success(profile_handler.get_displayname(new_user))
      self.assertEqual(display_name, requested + "-foo")

      # After completion it must have fired exactly once.
      callback_mock.assert_called_once()
  def _test_3pid_allowed(self, username: str, registration: bool):
      """Exercise the "is_3pid_allowed" module callback via an email requestToken URL.

      A callback returning False is installed first, and the request must fail with
      403 / THREEPID_DENIED. It is then swapped for one returning True, and the
      request must succeed and return a "sid". In both cases the callback must have
      been called once with ("email", <address>, registration).

      Args:
          username: The username to use for the test.
          registration: Whether to test with registration URLs
              (/register/...) instead of /account/3pid/... URLs.
      """
      # Stub out validation sending so the test does not depend on it.
      self.hs.get_identity_handler().send_threepid_validation = Mock(
          return_value=make_awaitable(0),
      )

      deny_callback = Mock(return_value=make_awaitable(False))
      self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [deny_callback]

      self.register_user(username, "password")
      user_token = self.login(username, "password")

      endpoint = (
          "/register/email/requestToken"
          if registration
          else "/account/3pid/email/requestToken"
      )

      def request_token(address: str):
          # Both attempts share everything except the email address.
          return self.make_request(
              "POST",
              endpoint,
              {"client_secret": "foo", "email": address, "send_attempt": 0},
              access_token=user_token,
          )

      # 1. The module forbids the 3PID.
      response = request_token("foo@test.com")
      self.assertEqual(response.code, HTTPStatus.FORBIDDEN, response.result)
      self.assertEqual(
          response.json_body["errcode"], Codes.THREEPID_DENIED, response.json_body
      )
      deny_callback.assert_called_once_with("email", "foo@test.com", registration)

      # 2. The module allows the 3PID.
      allow_callback = Mock(return_value=make_awaitable(True))
      self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [allow_callback]

      response = request_token("bar@test.com")
      self.assertEqual(response.code, HTTPStatus.OK, response.result)
      self.assertIn("sid", response.json_body)
      allow_callback.assert_called_once_with("email", "bar@test.com", registration)
  753. def _setup_get_name_for_registration(self, callback_name: str) -> Mock:
  754. """Registers either a get_username_for_registration callback or a
  755. get_displayname_for_registration callback that appends "-foo" to the username the
  756. client is trying to register.
  757. """
  758. async def callback(uia_results, params):
  759. self.assertIn(LoginType.DUMMY, uia_results)
  760. username = params["username"]
  761. return username + "-foo"
  762. m = Mock(side_effect=callback)
  763. password_auth_provider = self.hs.get_password_auth_provider()
  764. getattr(password_auth_provider, callback_name + "_callbacks").append(m)
  765. return m
  def _do_uia_assert_mock_not_called(self, username: str, m: Mock) -> JsonDict:
      """Register ``username`` through a two-step UIA flow and check that ``m`` has
      not been called while the flow is still pending.

      Args:
          username: The username to register.
          m: The registration-callback mock that must not fire before UIA completes.

      Returns:
          The JSON body of the final, successful /register response.
      """
      # Step 1: start the flow. The server should demand auth (401) and return a
      # session ID.
      initial = self.make_request(
          "POST",
          "register",
          {"username": username, "type": "m.login.password", "password": "bar"},
      )
      self.assertEqual(initial.code, 401)
      self.assertIn("session", initial.json_body)

      # Nothing has been authenticated yet, so the callback must not have run.
      m.assert_not_called()

      # Step 2: finish the flow by completing the dummy stage within the session.
      session_id = initial.json_body["session"]
      final = self.make_request(
          "POST",
          "register",
          {"auth": {"session": session_id, "type": LoginType.DUMMY}},
      )
      self.assertEqual(final.code, HTTPStatus.OK, final.json_body)
      return final.json_body
  def _get_login_flows(self) -> JsonDict:
      """GET the login endpoint, assert it succeeds and return its "flows" list."""
      response = self.make_request("GET", "/_matrix/client/r0/login")
      self.assertEqual(response.code, HTTPStatus.OK, response.result)
      flows = response.json_body["flows"]
      return flows
  def _send_password_login(self, user: str, password: str) -> FakeChannel:
      """Attempt an m.login.password login as ``user`` and return the response channel."""
      login_type = "m.login.password"
      return self._send_login(type=login_type, user=user, password=password)
  def _send_login(self, type, user, **params) -> FakeChannel:
      """POST a login request for ``user`` (as an m.id.user identifier).

      Extra keyword arguments become top-level fields of the request body. The
      "identifier" and "type" fields are always set by this helper and override
      any same-named keyword arguments.
      """
      body = dict(params)
      body["identifier"] = {"type": "m.id.user", "user": user}
      body["type"] = type
      return self.make_request("POST", "/_matrix/client/r0/login", body)
  796. def _start_delete_device_session(self, access_token, device_id) -> str:
  797. """Make an initial delete device request, and return the UI Auth session ID"""
  798. channel = self._delete_device(access_token, device_id)
  799. self.assertEqual(channel.code, 401)
  800. # Ensure that flows are what is expected.
  801. self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
  802. return channel.json_body["session"]
  def _authed_delete_device(
      self,
      access_token: str,
      device_id: str,
      session: str,
      user_id: str,
      password: str,
  ) -> FakeChannel:
      """Send a delete-device request for ``device_id`` that completes UI Auth
      ``session`` with an m.login.password stage for ``user_id``/``password``.
      """
      password_stage = {
          "type": "m.login.password",
          "identifier": {"type": "m.id.user", "user": user_id},
          "password": password,
          "session": session,
      }
      return self._delete_device(access_token, device_id, {"auth": password_stage})
  def _delete_device(
      self,
      access_token: str,
      device: str,
      body: Union[JsonDict, bytes] = b"",
  ) -> FakeChannel:
      """Issue a DELETE for a single device.

      Args:
          access_token: Access token to authenticate the request with.
          device: ID of the device, appended to "devices/" unmodified.
          body: Request body. Defaults to empty bytes; callers pass a dict when
              they need to include a UI Auth payload.

      Returns:
          The channel carrying the server's response.
      """
      path = "devices/" + device
      return self.make_request("DELETE", path, body, access_token=access_token)