test_register.py 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  16. import datetime
  17. import os
  18. from typing import Any, Dict, List, Tuple
  19. import pkg_resources
  20. from twisted.test.proto_helpers import MemoryReactor
  21. import synapse.rest.admin
  22. from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
  23. from synapse.api.errors import Codes
  24. from synapse.appservice import ApplicationService
  25. from synapse.rest.client import account, account_validity, login, logout, register, sync
  26. from synapse.server import HomeServer
  27. from synapse.storage._base import db_to_json
  28. from synapse.types import JsonDict
  29. from synapse.util import Clock
  30. from tests import unittest
  31. from tests.unittest import override_config
  32. class RegisterRestServletTestCase(unittest.HomeserverTestCase):
  33. servlets = [
  34. login.register_servlets,
  35. register.register_servlets,
  36. synapse.rest.admin.register_servlets,
  37. ]
  38. url = b"/_matrix/client/r0/register"
  39. def default_config(self) -> Dict[str, Any]:
  40. config = super().default_config()
  41. config["allow_guest_access"] = True
  42. return config
  43. def test_POST_appservice_registration_valid(self) -> None:
  44. user_id = "@as_user_kermit:test"
  45. as_token = "i_am_an_app_service"
  46. appservice = ApplicationService(
  47. as_token,
  48. id="1234",
  49. namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
  50. sender="@as:test",
  51. )
  52. self.hs.get_datastores().main.services_cache.append(appservice)
  53. request_data = {
  54. "username": "as_user_kermit",
  55. "type": APP_SERVICE_REGISTRATION_TYPE,
  56. }
  57. channel = self.make_request(
  58. b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
  59. )
  60. self.assertEqual(channel.code, 200, msg=channel.result)
  61. det_data = {"user_id": user_id, "home_server": self.hs.hostname}
  62. self.assertDictContainsSubset(det_data, channel.json_body)
  63. def test_POST_appservice_registration_no_type(self) -> None:
  64. as_token = "i_am_an_app_service"
  65. appservice = ApplicationService(
  66. as_token,
  67. id="1234",
  68. namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
  69. sender="@as:test",
  70. )
  71. self.hs.get_datastores().main.services_cache.append(appservice)
  72. request_data = {"username": "as_user_kermit"}
  73. channel = self.make_request(
  74. b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
  75. )
  76. self.assertEqual(channel.code, 400, msg=channel.result)
  77. def test_POST_appservice_registration_invalid(self) -> None:
  78. self.appservice = None # no application service exists
  79. request_data = {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
  80. channel = self.make_request(
  81. b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
  82. )
  83. self.assertEqual(channel.code, 401, msg=channel.result)
  84. def test_POST_bad_password(self) -> None:
  85. request_data = {"username": "kermit", "password": 666}
  86. channel = self.make_request(b"POST", self.url, request_data)
  87. self.assertEqual(channel.code, 400, msg=channel.result)
  88. self.assertEqual(channel.json_body["error"], "Invalid password")
  89. def test_POST_bad_username(self) -> None:
  90. request_data = {"username": 777, "password": "monkey"}
  91. channel = self.make_request(b"POST", self.url, request_data)
  92. self.assertEqual(channel.code, 400, msg=channel.result)
  93. self.assertEqual(channel.json_body["error"], "Invalid username")
  94. def test_POST_user_valid(self) -> None:
  95. user_id = "@kermit:test"
  96. device_id = "frogfone"
  97. request_data = {
  98. "username": "kermit",
  99. "password": "monkey",
  100. "device_id": device_id,
  101. "auth": {"type": LoginType.DUMMY},
  102. }
  103. channel = self.make_request(b"POST", self.url, request_data)
  104. det_data = {
  105. "user_id": user_id,
  106. "home_server": self.hs.hostname,
  107. "device_id": device_id,
  108. }
  109. self.assertEqual(channel.code, 200, msg=channel.result)
  110. self.assertDictContainsSubset(det_data, channel.json_body)
  111. @override_config({"enable_registration": False})
  112. def test_POST_disabled_registration(self) -> None:
  113. request_data = {"username": "kermit", "password": "monkey"}
  114. self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
  115. channel = self.make_request(b"POST", self.url, request_data)
  116. self.assertEqual(channel.code, 403, msg=channel.result)
  117. self.assertEqual(channel.json_body["error"], "Registration has been disabled")
  118. self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
  119. def test_POST_guest_registration(self) -> None:
  120. self.hs.config.key.macaroon_secret_key = "test"
  121. self.hs.config.registration.allow_guest_access = True
  122. channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  123. det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
  124. self.assertEqual(channel.code, 200, msg=channel.result)
  125. self.assertDictContainsSubset(det_data, channel.json_body)
  126. def test_POST_disabled_guest_registration(self) -> None:
  127. self.hs.config.registration.allow_guest_access = False
  128. channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  129. self.assertEqual(channel.code, 403, msg=channel.result)
  130. self.assertEqual(channel.json_body["error"], "Guest access is disabled")
  131. @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
  132. def test_POST_ratelimiting_guest(self) -> None:
  133. for i in range(0, 6):
  134. url = self.url + b"?kind=guest"
  135. channel = self.make_request(b"POST", url, b"{}")
  136. if i == 5:
  137. self.assertEqual(channel.code, 429, msg=channel.result)
  138. retry_after_ms = int(channel.json_body["retry_after_ms"])
  139. else:
  140. self.assertEqual(channel.code, 200, msg=channel.result)
  141. self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
  142. channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  143. self.assertEqual(channel.code, 200, msg=channel.result)
  144. @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
  145. def test_POST_ratelimiting(self) -> None:
  146. for i in range(0, 6):
  147. request_data = {
  148. "username": "kermit" + str(i),
  149. "password": "monkey",
  150. "device_id": "frogfone",
  151. "auth": {"type": LoginType.DUMMY},
  152. }
  153. channel = self.make_request(b"POST", self.url, request_data)
  154. if i == 5:
  155. self.assertEqual(channel.code, 429, msg=channel.result)
  156. retry_after_ms = int(channel.json_body["retry_after_ms"])
  157. else:
  158. self.assertEqual(channel.code, 200, msg=channel.result)
  159. self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
  160. channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  161. self.assertEqual(channel.code, 200, msg=channel.result)
  162. @override_config({"registration_requires_token": True})
  163. def test_POST_registration_requires_token(self) -> None:
  164. username = "kermit"
  165. device_id = "frogfone"
  166. token = "abcd"
  167. store = self.hs.get_datastores().main
  168. self.get_success(
  169. store.db_pool.simple_insert(
  170. "registration_tokens",
  171. {
  172. "token": token,
  173. "uses_allowed": None,
  174. "pending": 0,
  175. "completed": 0,
  176. "expiry_time": None,
  177. },
  178. )
  179. )
  180. params: JsonDict = {
  181. "username": username,
  182. "password": "monkey",
  183. "device_id": device_id,
  184. }
  185. # Request without auth to get flows and session
  186. channel = self.make_request(b"POST", self.url, params)
  187. self.assertEqual(channel.code, 401, msg=channel.result)
  188. flows = channel.json_body["flows"]
  189. # Synapse adds a dummy stage to differentiate flows where otherwise one
  190. # flow would be a subset of another flow.
  191. self.assertCountEqual(
  192. [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
  193. (f["stages"] for f in flows),
  194. )
  195. session = channel.json_body["session"]
  196. # Do the registration token stage and check it has completed
  197. params["auth"] = {
  198. "type": LoginType.REGISTRATION_TOKEN,
  199. "token": token,
  200. "session": session,
  201. }
  202. channel = self.make_request(b"POST", self.url, params)
  203. self.assertEqual(channel.code, 401, msg=channel.result)
  204. completed = channel.json_body["completed"]
  205. self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
  206. # Do the m.login.dummy stage and check registration was successful
  207. params["auth"] = {
  208. "type": LoginType.DUMMY,
  209. "session": session,
  210. }
  211. channel = self.make_request(b"POST", self.url, params)
  212. det_data = {
  213. "user_id": f"@{username}:{self.hs.hostname}",
  214. "home_server": self.hs.hostname,
  215. "device_id": device_id,
  216. }
  217. self.assertEqual(channel.code, 200, msg=channel.result)
  218. self.assertDictContainsSubset(det_data, channel.json_body)
  219. # Check the `completed` counter has been incremented and pending is 0
  220. res = self.get_success(
  221. store.db_pool.simple_select_one(
  222. "registration_tokens",
  223. keyvalues={"token": token},
  224. retcols=["pending", "completed"],
  225. )
  226. )
  227. self.assertEqual(res["completed"], 1)
  228. self.assertEqual(res["pending"], 0)
  229. @override_config({"registration_requires_token": True})
  230. def test_POST_registration_token_invalid(self) -> None:
  231. params: JsonDict = {
  232. "username": "kermit",
  233. "password": "monkey",
  234. }
  235. # Request without auth to get session
  236. channel = self.make_request(b"POST", self.url, params)
  237. session = channel.json_body["session"]
  238. # Test with token param missing (invalid)
  239. params["auth"] = {
  240. "type": LoginType.REGISTRATION_TOKEN,
  241. "session": session,
  242. }
  243. channel = self.make_request(b"POST", self.url, params)
  244. self.assertEqual(channel.code, 401, msg=channel.result)
  245. self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
  246. self.assertEqual(channel.json_body["completed"], [])
  247. # Test with non-string (invalid)
  248. params["auth"]["token"] = 1234
  249. channel = self.make_request(b"POST", self.url, params)
  250. self.assertEqual(channel.code, 401, msg=channel.result)
  251. self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
  252. self.assertEqual(channel.json_body["completed"], [])
  253. # Test with unknown token (invalid)
  254. params["auth"]["token"] = "1234"
  255. channel = self.make_request(b"POST", self.url, params)
  256. self.assertEqual(channel.code, 401, msg=channel.result)
  257. self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
  258. self.assertEqual(channel.json_body["completed"], [])
  259. @override_config({"registration_requires_token": True})
  260. def test_POST_registration_token_limit_uses(self) -> None:
  261. token = "abcd"
  262. store = self.hs.get_datastores().main
  263. # Create token that can be used once
  264. self.get_success(
  265. store.db_pool.simple_insert(
  266. "registration_tokens",
  267. {
  268. "token": token,
  269. "uses_allowed": 1,
  270. "pending": 0,
  271. "completed": 0,
  272. "expiry_time": None,
  273. },
  274. )
  275. )
  276. params1: JsonDict = {"username": "bert", "password": "monkey"}
  277. params2: JsonDict = {"username": "ernie", "password": "monkey"}
  278. # Do 2 requests without auth to get two session IDs
  279. channel1 = self.make_request(b"POST", self.url, params1)
  280. session1 = channel1.json_body["session"]
  281. channel2 = self.make_request(b"POST", self.url, params2)
  282. session2 = channel2.json_body["session"]
  283. # Use token with session1 and check `pending` is 1
  284. params1["auth"] = {
  285. "type": LoginType.REGISTRATION_TOKEN,
  286. "token": token,
  287. "session": session1,
  288. }
  289. self.make_request(b"POST", self.url, params1)
  290. # Repeat request to make sure pending isn't increased again
  291. self.make_request(b"POST", self.url, params1)
  292. pending = self.get_success(
  293. store.db_pool.simple_select_one_onecol(
  294. "registration_tokens",
  295. keyvalues={"token": token},
  296. retcol="pending",
  297. )
  298. )
  299. self.assertEqual(pending, 1)
  300. # Check auth fails when using token with session2
  301. params2["auth"] = {
  302. "type": LoginType.REGISTRATION_TOKEN,
  303. "token": token,
  304. "session": session2,
  305. }
  306. channel = self.make_request(b"POST", self.url, params2)
  307. self.assertEqual(channel.code, 401, msg=channel.result)
  308. self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
  309. self.assertEqual(channel.json_body["completed"], [])
  310. # Complete registration with session1
  311. params1["auth"]["type"] = LoginType.DUMMY
  312. self.make_request(b"POST", self.url, params1)
  313. # Check pending=0 and completed=1
  314. res = self.get_success(
  315. store.db_pool.simple_select_one(
  316. "registration_tokens",
  317. keyvalues={"token": token},
  318. retcols=["pending", "completed"],
  319. )
  320. )
  321. self.assertEqual(res["pending"], 0)
  322. self.assertEqual(res["completed"], 1)
  323. # Check auth still fails when using token with session2
  324. channel = self.make_request(b"POST", self.url, params2)
  325. self.assertEqual(channel.code, 401, msg=channel.result)
  326. self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
  327. self.assertEqual(channel.json_body["completed"], [])
  328. @override_config({"registration_requires_token": True})
  329. def test_POST_registration_token_expiry(self) -> None:
  330. token = "abcd"
  331. now = self.hs.get_clock().time_msec()
  332. store = self.hs.get_datastores().main
  333. # Create token that expired yesterday
  334. self.get_success(
  335. store.db_pool.simple_insert(
  336. "registration_tokens",
  337. {
  338. "token": token,
  339. "uses_allowed": None,
  340. "pending": 0,
  341. "completed": 0,
  342. "expiry_time": now - 24 * 60 * 60 * 1000,
  343. },
  344. )
  345. )
  346. params: JsonDict = {"username": "kermit", "password": "monkey"}
  347. # Request without auth to get session
  348. channel = self.make_request(b"POST", self.url, params)
  349. session = channel.json_body["session"]
  350. # Check authentication fails with expired token
  351. params["auth"] = {
  352. "type": LoginType.REGISTRATION_TOKEN,
  353. "token": token,
  354. "session": session,
  355. }
  356. channel = self.make_request(b"POST", self.url, params)
  357. self.assertEqual(channel.code, 401, msg=channel.result)
  358. self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
  359. self.assertEqual(channel.json_body["completed"], [])
  360. # Update token so it expires tomorrow
  361. self.get_success(
  362. store.db_pool.simple_update_one(
  363. "registration_tokens",
  364. keyvalues={"token": token},
  365. updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
  366. )
  367. )
  368. # Check authentication succeeds
  369. channel = self.make_request(b"POST", self.url, params)
  370. completed = channel.json_body["completed"]
  371. self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
  372. @override_config({"registration_requires_token": True})
  373. def test_POST_registration_token_session_expiry(self) -> None:
  374. """Test `pending` is decremented when an uncompleted session expires."""
  375. token = "abcd"
  376. store = self.hs.get_datastores().main
  377. self.get_success(
  378. store.db_pool.simple_insert(
  379. "registration_tokens",
  380. {
  381. "token": token,
  382. "uses_allowed": None,
  383. "pending": 0,
  384. "completed": 0,
  385. "expiry_time": None,
  386. },
  387. )
  388. )
  389. # Do 2 requests without auth to get two session IDs
  390. params1: JsonDict = {"username": "bert", "password": "monkey"}
  391. params2: JsonDict = {"username": "ernie", "password": "monkey"}
  392. channel1 = self.make_request(b"POST", self.url, params1)
  393. session1 = channel1.json_body["session"]
  394. channel2 = self.make_request(b"POST", self.url, params2)
  395. session2 = channel2.json_body["session"]
  396. # Use token with both sessions
  397. params1["auth"] = {
  398. "type": LoginType.REGISTRATION_TOKEN,
  399. "token": token,
  400. "session": session1,
  401. }
  402. self.make_request(b"POST", self.url, params1)
  403. params2["auth"] = {
  404. "type": LoginType.REGISTRATION_TOKEN,
  405. "token": token,
  406. "session": session2,
  407. }
  408. self.make_request(b"POST", self.url, params2)
  409. # Complete registration with session1
  410. params1["auth"]["type"] = LoginType.DUMMY
  411. self.make_request(b"POST", self.url, params1)
  412. # Check `result` of registration token stage for session1 is `True`
  413. result1 = self.get_success(
  414. store.db_pool.simple_select_one_onecol(
  415. "ui_auth_sessions_credentials",
  416. keyvalues={
  417. "session_id": session1,
  418. "stage_type": LoginType.REGISTRATION_TOKEN,
  419. },
  420. retcol="result",
  421. )
  422. )
  423. self.assertTrue(db_to_json(result1))
  424. # Check `result` for session2 is the token used
  425. result2 = self.get_success(
  426. store.db_pool.simple_select_one_onecol(
  427. "ui_auth_sessions_credentials",
  428. keyvalues={
  429. "session_id": session2,
  430. "stage_type": LoginType.REGISTRATION_TOKEN,
  431. },
  432. retcol="result",
  433. )
  434. )
  435. self.assertEqual(db_to_json(result2), token)
  436. # Delete both sessions (mimics expiry)
  437. self.get_success(
  438. store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
  439. )
  440. # Check pending is now 0
  441. pending = self.get_success(
  442. store.db_pool.simple_select_one_onecol(
  443. "registration_tokens",
  444. keyvalues={"token": token},
  445. retcol="pending",
  446. )
  447. )
  448. self.assertEqual(pending, 0)
  449. @override_config({"registration_requires_token": True})
  450. def test_POST_registration_token_session_expiry_deleted_token(self) -> None:
  451. """Test session expiry doesn't break when the token is deleted.
  452. 1. Start but don't complete UIA with a registration token
  453. 2. Delete the token from the database
  454. 3. Expire the session
  455. """
  456. token = "abcd"
  457. store = self.hs.get_datastores().main
  458. self.get_success(
  459. store.db_pool.simple_insert(
  460. "registration_tokens",
  461. {
  462. "token": token,
  463. "uses_allowed": None,
  464. "pending": 0,
  465. "completed": 0,
  466. "expiry_time": None,
  467. },
  468. )
  469. )
  470. # Do request without auth to get a session ID
  471. params: JsonDict = {"username": "kermit", "password": "monkey"}
  472. channel = self.make_request(b"POST", self.url, params)
  473. session = channel.json_body["session"]
  474. # Use token
  475. params["auth"] = {
  476. "type": LoginType.REGISTRATION_TOKEN,
  477. "token": token,
  478. "session": session,
  479. }
  480. self.make_request(b"POST", self.url, params)
  481. # Delete token
  482. self.get_success(
  483. store.db_pool.simple_delete_one(
  484. "registration_tokens",
  485. keyvalues={"token": token},
  486. )
  487. )
  488. # Delete session (mimics expiry)
  489. self.get_success(
  490. store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
  491. )
  492. def test_advertised_flows(self) -> None:
  493. channel = self.make_request(b"POST", self.url, b"{}")
  494. self.assertEqual(channel.code, 401, msg=channel.result)
  495. flows = channel.json_body["flows"]
  496. # with the stock config, we only expect the dummy flow
  497. self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows))
  498. @unittest.override_config(
  499. {
  500. "public_baseurl": "https://test_server",
  501. "enable_registration_captcha": True,
  502. "user_consent": {
  503. "version": "1",
  504. "template_dir": "/",
  505. "require_at_registration": True,
  506. },
  507. "account_threepid_delegates": {
  508. "email": "https://id_server",
  509. "msisdn": "https://id_server",
  510. },
  511. }
  512. )
  513. def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
  514. channel = self.make_request(b"POST", self.url, b"{}")
  515. self.assertEqual(channel.code, 401, msg=channel.result)
  516. flows = channel.json_body["flows"]
  517. self.assertCountEqual(
  518. [
  519. ["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
  520. ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
  521. ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
  522. [
  523. "m.login.recaptcha",
  524. "m.login.terms",
  525. "m.login.msisdn",
  526. "m.login.email.identity",
  527. ],
  528. ],
  529. (f["stages"] for f in flows),
  530. )
  531. @unittest.override_config(
  532. {
  533. "public_baseurl": "https://test_server",
  534. "registrations_require_3pid": ["email"],
  535. "disable_msisdn_registration": True,
  536. "email": {
  537. "smtp_host": "mail_server",
  538. "smtp_port": 2525,
  539. "notif_from": "sender@host",
  540. },
  541. }
  542. )
  543. def test_advertised_flows_no_msisdn_email_required(self) -> None:
  544. channel = self.make_request(b"POST", self.url, b"{}")
  545. self.assertEqual(channel.code, 401, msg=channel.result)
  546. flows = channel.json_body["flows"]
  547. # with the stock config, we expect all four combinations of 3pid
  548. self.assertCountEqual(
  549. [["m.login.email.identity"]], (f["stages"] for f in flows)
  550. )
  551. @unittest.override_config(
  552. {
  553. "request_token_inhibit_3pid_errors": True,
  554. "public_baseurl": "https://test_server",
  555. "email": {
  556. "smtp_host": "mail_server",
  557. "smtp_port": 2525,
  558. "notif_from": "sender@host",
  559. },
  560. }
  561. )
  562. def test_request_token_existing_email_inhibit_error(self) -> None:
  563. """Test that requesting a token via this endpoint doesn't leak existing
  564. associations if configured that way.
  565. """
  566. user_id = self.register_user("kermit", "monkey")
  567. self.login("kermit", "monkey")
  568. email = "test@example.com"
  569. # Add a threepid
  570. self.get_success(
  571. self.hs.get_datastores().main.user_add_threepid(
  572. user_id=user_id,
  573. medium="email",
  574. address=email,
  575. validated_at=0,
  576. added_at=0,
  577. )
  578. )
  579. channel = self.make_request(
  580. "POST",
  581. b"register/email/requestToken",
  582. {"client_secret": "foobar", "email": email, "send_attempt": 1},
  583. )
  584. self.assertEqual(200, channel.code, channel.result)
  585. self.assertIsNotNone(channel.json_body.get("sid"))
  586. @unittest.override_config(
  587. {
  588. "public_baseurl": "https://test_server",
  589. "email": {
  590. "smtp_host": "mail_server",
  591. "smtp_port": 2525,
  592. "notif_from": "sender@host",
  593. },
  594. }
  595. )
  596. def test_reject_invalid_email(self) -> None:
  597. """Check that bad emails are rejected"""
  598. # Test for email with multiple @
  599. channel = self.make_request(
  600. "POST",
  601. b"register/email/requestToken",
  602. {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1},
  603. )
  604. self.assertEqual(400, channel.code, channel.result)
  605. # Check error to ensure that we're not erroring due to a bug in the test.
  606. self.assertEqual(
  607. channel.json_body,
  608. {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
  609. )
  610. # Test for email with no @
  611. channel = self.make_request(
  612. "POST",
  613. b"register/email/requestToken",
  614. {"client_secret": "foobar", "email": "email", "send_attempt": 1},
  615. )
  616. self.assertEqual(400, channel.code, channel.result)
  617. self.assertEqual(
  618. channel.json_body,
  619. {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
  620. )
  621. # Test for super long email
  622. email = "a@" + "a" * 1000
  623. channel = self.make_request(
  624. "POST",
  625. b"register/email/requestToken",
  626. {"client_secret": "foobar", "email": email, "send_attempt": 1},
  627. )
  628. self.assertEqual(400, channel.code, channel.result)
  629. self.assertEqual(
  630. channel.json_body,
  631. {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
  632. )
  633. @override_config(
  634. {
  635. "inhibit_user_in_use_error": True,
  636. }
  637. )
  638. def test_inhibit_user_in_use_error(self) -> None:
  639. """Tests that the 'inhibit_user_in_use_error' configuration flag behaves
  640. correctly.
  641. """
  642. username = "arthur"
  643. # Manually register the user, so we know the test isn't passing because of a lack
  644. # of clashing.
  645. reg_handler = self.hs.get_registration_handler()
  646. self.get_success(reg_handler.register_user(username))
  647. # Check that /available correctly ignores the username provided despite the
  648. # username being already registered.
  649. channel = self.make_request("GET", "register/available?username=" + username)
  650. self.assertEqual(200, channel.code, channel.result)
  651. # Test that when starting a UIA registration flow the request doesn't fail because
  652. # of a conflicting username
  653. channel = self.make_request(
  654. "POST",
  655. "register",
  656. {"username": username, "type": "m.login.password", "password": "foo"},
  657. )
  658. self.assertEqual(channel.code, 401)
  659. self.assertIn("session", channel.json_body)
  660. # Test that finishing the registration fails because of a conflicting username.
  661. session = channel.json_body["session"]
  662. channel = self.make_request(
  663. "POST",
  664. "register",
  665. {"auth": {"session": session, "type": LoginType.DUMMY}},
  666. )
  667. self.assertEqual(channel.code, 400, channel.json_body)
  668. self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE)
  669. class AccountValidityTestCase(unittest.HomeserverTestCase):
  670. servlets = [
  671. register.register_servlets,
  672. synapse.rest.admin.register_servlets_for_client_rest_resource,
  673. login.register_servlets,
  674. sync.register_servlets,
  675. logout.register_servlets,
  676. account_validity.register_servlets,
  677. ]
  678. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  679. config = self.default_config()
  680. # Test for account expiring after a week.
  681. config["enable_registration"] = True
  682. config["account_validity"] = {
  683. "enabled": True,
  684. "period": 604800000, # Time in ms for 1 week
  685. }
  686. self.hs = self.setup_test_homeserver(config=config)
  687. return self.hs
  688. def test_validity_period(self) -> None:
  689. self.register_user("kermit", "monkey")
  690. tok = self.login("kermit", "monkey")
  691. # The specific endpoint doesn't matter, all we need is an authenticated
  692. # endpoint.
  693. channel = self.make_request(b"GET", "/sync", access_token=tok)
  694. self.assertEqual(channel.code, 200, msg=channel.result)
  695. self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
  696. channel = self.make_request(b"GET", "/sync", access_token=tok)
  697. self.assertEqual(channel.code, 403, msg=channel.result)
  698. self.assertEqual(
  699. channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
  700. )
  701. def test_manual_renewal(self) -> None:
  702. user_id = self.register_user("kermit", "monkey")
  703. tok = self.login("kermit", "monkey")
  704. self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
  705. # If we register the admin user at the beginning of the test, it will
  706. # expire at the same time as the normal user and the renewal request
  707. # will be denied.
  708. self.register_user("admin", "adminpassword", admin=True)
  709. admin_tok = self.login("admin", "adminpassword")
  710. url = "/_synapse/admin/v1/account_validity/validity"
  711. request_data = {"user_id": user_id}
  712. channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
  713. self.assertEqual(channel.code, 200, msg=channel.result)
  714. # The specific endpoint doesn't matter, all we need is an authenticated
  715. # endpoint.
  716. channel = self.make_request(b"GET", "/sync", access_token=tok)
  717. self.assertEqual(channel.code, 200, msg=channel.result)
  718. def test_manual_expire(self) -> None:
  719. user_id = self.register_user("kermit", "monkey")
  720. tok = self.login("kermit", "monkey")
  721. self.register_user("admin", "adminpassword", admin=True)
  722. admin_tok = self.login("admin", "adminpassword")
  723. url = "/_synapse/admin/v1/account_validity/validity"
  724. request_data = {
  725. "user_id": user_id,
  726. "expiration_ts": 0,
  727. "enable_renewal_emails": False,
  728. }
  729. channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
  730. self.assertEqual(channel.code, 200, msg=channel.result)
  731. # The specific endpoint doesn't matter, all we need is an authenticated
  732. # endpoint.
  733. channel = self.make_request(b"GET", "/sync", access_token=tok)
  734. self.assertEqual(channel.code, 403, msg=channel.result)
  735. self.assertEqual(
  736. channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
  737. )
  738. def test_logging_out_expired_user(self) -> None:
  739. user_id = self.register_user("kermit", "monkey")
  740. tok = self.login("kermit", "monkey")
  741. self.register_user("admin", "adminpassword", admin=True)
  742. admin_tok = self.login("admin", "adminpassword")
  743. url = "/_synapse/admin/v1/account_validity/validity"
  744. request_data = {
  745. "user_id": user_id,
  746. "expiration_ts": 0,
  747. "enable_renewal_emails": False,
  748. }
  749. channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
  750. self.assertEqual(channel.code, 200, msg=channel.result)
  751. # Try to log the user out
  752. channel = self.make_request(b"POST", "/logout", access_token=tok)
  753. self.assertEqual(channel.code, 200, msg=channel.result)
  754. # Log the user in again (allowed for expired accounts)
  755. tok = self.login("kermit", "monkey")
  756. # Try to log out all of the user's sessions
  757. channel = self.make_request(b"POST", "/logout/all", access_token=tok)
  758. self.assertEqual(channel.code, 200, msg=channel.result)
  759. class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
  760. servlets = [
  761. register.register_servlets,
  762. synapse.rest.admin.register_servlets_for_client_rest_resource,
  763. login.register_servlets,
  764. sync.register_servlets,
  765. account_validity.register_servlets,
  766. account.register_servlets,
  767. ]
  768. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  769. config = self.default_config()
  770. # Test for account expiring after a week and renewal emails being sent 2
  771. # days before expiry.
  772. config["enable_registration"] = True
  773. config["account_validity"] = {
  774. "enabled": True,
  775. "period": 604800000, # Time in ms for 1 week
  776. "renew_at": 172800000, # Time in ms for 2 days
  777. "renew_by_email_enabled": True,
  778. "renew_email_subject": "Renew your account",
  779. "account_renewed_html_path": "account_renewed.html",
  780. "invalid_token_html_path": "invalid_token.html",
  781. }
  782. # Email config.
  783. config["email"] = {
  784. "enable_notifs": True,
  785. "template_dir": os.path.abspath(
  786. pkg_resources.resource_filename("synapse", "res/templates")
  787. ),
  788. "expiry_template_html": "notice_expiry.html",
  789. "expiry_template_text": "notice_expiry.txt",
  790. "notif_template_html": "notif_mail.html",
  791. "notif_template_text": "notif_mail.txt",
  792. "smtp_host": "127.0.0.1",
  793. "smtp_port": 20,
  794. "require_transport_security": False,
  795. "smtp_user": None,
  796. "smtp_pass": None,
  797. "notif_from": "test@example.com",
  798. }
  799. self.hs = self.setup_test_homeserver(config=config)
  800. async def sendmail(*args: Any, **kwargs: Any) -> None:
  801. self.email_attempts.append((args, kwargs))
  802. self.email_attempts: List[Tuple[Any, Any]] = []
  803. self.hs.get_send_email_handler()._sendmail = sendmail
  804. self.store = self.hs.get_datastores().main
  805. return self.hs
  806. def test_renewal_email(self) -> None:
  807. self.email_attempts = []
  808. (user_id, tok) = self.create_user()
  809. # Move 5 days forward. This should trigger a renewal email to be sent.
  810. self.reactor.advance(datetime.timedelta(days=5).total_seconds())
  811. self.assertEqual(len(self.email_attempts), 1)
  812. # Retrieving the URL from the email is too much pain for now, so we
  813. # retrieve the token from the DB.
  814. renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
  815. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
  816. channel = self.make_request(b"GET", url)
  817. self.assertEqual(channel.code, 200, msg=channel.result)
  818. # Check that we're getting HTML back.
  819. content_type = channel.headers.getRawHeaders(b"Content-Type")
  820. self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
  821. # Check that the HTML we're getting is the one we expect on a successful renewal.
  822. expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id))
  823. expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render(
  824. expiration_ts=expiration_ts
  825. )
  826. self.assertEqual(
  827. channel.result["body"], expected_html.encode("utf8"), channel.result
  828. )
  829. # Move 1 day forward. Try to renew with the same token again.
  830. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
  831. channel = self.make_request(b"GET", url)
  832. self.assertEqual(channel.code, 200, msg=channel.result)
  833. # Check that we're getting HTML back.
  834. content_type = channel.headers.getRawHeaders(b"Content-Type")
  835. self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
  836. # Check that the HTML we're getting is the one we expect when reusing a
  837. # token. The account expiration date should not have changed.
  838. expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render(
  839. expiration_ts=expiration_ts
  840. )
  841. self.assertEqual(
  842. channel.result["body"], expected_html.encode("utf8"), channel.result
  843. )
  844. # Move 3 days forward. If the renewal failed, every authed request with
  845. # our access token should be denied from now, otherwise they should
  846. # succeed.
  847. self.reactor.advance(datetime.timedelta(days=3).total_seconds())
  848. channel = self.make_request(b"GET", "/sync", access_token=tok)
  849. self.assertEqual(channel.code, 200, msg=channel.result)
  850. def test_renewal_invalid_token(self) -> None:
  851. # Hit the renewal endpoint with an invalid token and check that it behaves as
  852. # expected, i.e. that it responds with 404 Not Found and the correct HTML.
  853. url = "/_matrix/client/unstable/account_validity/renew?token=123"
  854. channel = self.make_request(b"GET", url)
  855. self.assertEqual(channel.code, 404, msg=channel.result)
  856. # Check that we're getting HTML back.
  857. content_type = channel.headers.getRawHeaders(b"Content-Type")
  858. self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
  859. # Check that the HTML we're getting is the one we expect when using an
  860. # invalid/unknown token.
  861. expected_html = (
  862. self.hs.config.account_validity.account_validity_invalid_token_template.render()
  863. )
  864. self.assertEqual(
  865. channel.result["body"], expected_html.encode("utf8"), channel.result
  866. )
  867. def test_manual_email_send(self) -> None:
  868. self.email_attempts = []
  869. (user_id, tok) = self.create_user()
  870. channel = self.make_request(
  871. b"POST",
  872. "/_matrix/client/unstable/account_validity/send_mail",
  873. access_token=tok,
  874. )
  875. self.assertEqual(channel.code, 200, msg=channel.result)
  876. self.assertEqual(len(self.email_attempts), 1)
  877. def test_deactivated_user(self) -> None:
  878. self.email_attempts = []
  879. (user_id, tok) = self.create_user()
  880. request_data = {
  881. "auth": {
  882. "type": "m.login.password",
  883. "user": user_id,
  884. "password": "monkey",
  885. },
  886. "erase": False,
  887. }
  888. channel = self.make_request(
  889. "POST", "account/deactivate", request_data, access_token=tok
  890. )
  891. self.assertEqual(channel.code, 200)
  892. self.reactor.advance(datetime.timedelta(days=8).total_seconds())
  893. self.assertEqual(len(self.email_attempts), 0)
  894. def create_user(self) -> Tuple[str, str]:
  895. user_id = self.register_user("kermit", "monkey")
  896. tok = self.login("kermit", "monkey")
  897. # We need to manually add an email address otherwise the handler will do
  898. # nothing.
  899. now = self.hs.get_clock().time_msec()
  900. self.get_success(
  901. self.store.user_add_threepid(
  902. user_id=user_id,
  903. medium="email",
  904. address="kermit@example.com",
  905. validated_at=now,
  906. added_at=now,
  907. )
  908. )
  909. return user_id, tok
  910. def test_manual_email_send_expired_account(self) -> None:
  911. user_id = self.register_user("kermit", "monkey")
  912. tok = self.login("kermit", "monkey")
  913. # We need to manually add an email address otherwise the handler will do
  914. # nothing.
  915. now = self.hs.get_clock().time_msec()
  916. self.get_success(
  917. self.store.user_add_threepid(
  918. user_id=user_id,
  919. medium="email",
  920. address="kermit@example.com",
  921. validated_at=now,
  922. added_at=now,
  923. )
  924. )
  925. # Make the account expire.
  926. self.reactor.advance(datetime.timedelta(days=8).total_seconds())
  927. # Ignore all emails sent by the automatic background task and only focus on the
  928. # ones sent manually.
  929. self.email_attempts = []
  930. # Test that we're still able to manually trigger a mail to be sent.
  931. channel = self.make_request(
  932. b"POST",
  933. "/_matrix/client/unstable/account_validity/send_mail",
  934. access_token=tok,
  935. )
  936. self.assertEqual(channel.code, 200, msg=channel.result)
  937. self.assertEqual(len(self.email_attempts), 1)
  938. class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
  939. servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
  940. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  941. self.validity_period = 10
  942. self.max_delta = self.validity_period * 10.0 / 100.0
  943. config = self.default_config()
  944. config["enable_registration"] = True
  945. config["account_validity"] = {"enabled": False}
  946. self.hs = self.setup_test_homeserver(config=config)
  947. # We need to set these directly, instead of in the homeserver config dict above.
  948. # This is due to account validity-related config options not being read by
  949. # Synapse when account_validity.enabled is False.
  950. self.hs.get_datastores().main._account_validity_period = self.validity_period
  951. self.hs.get_datastores().main._account_validity_startup_job_max_delta = (
  952. self.max_delta
  953. )
  954. self.store = self.hs.get_datastores().main
  955. return self.hs
  956. def test_background_job(self) -> None:
  957. """
  958. Tests the same thing as test_background_job, except that it sets the
  959. startup_job_max_delta parameter and checks that the expiration date is within the
  960. allowed range.
  961. """
  962. user_id = self.register_user("kermit_delta", "user")
  963. self.hs.config.account_validity.startup_job_max_delta = self.max_delta
  964. now_ms = self.hs.get_clock().time_msec()
  965. self.get_success(self.store._set_expiration_date_when_missing())
  966. res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
  967. self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
  968. self.assertLessEqual(res, now_ms + self.validity_period)
  969. class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
  970. servlets = [register.register_servlets]
  971. url = "/_matrix/client/v1/register/m.login.registration_token/validity"
  972. def default_config(self) -> Dict[str, Any]:
  973. config = super().default_config()
  974. config["registration_requires_token"] = True
  975. return config
  976. def test_GET_token_valid(self) -> None:
  977. token = "abcd"
  978. store = self.hs.get_datastores().main
  979. self.get_success(
  980. store.db_pool.simple_insert(
  981. "registration_tokens",
  982. {
  983. "token": token,
  984. "uses_allowed": None,
  985. "pending": 0,
  986. "completed": 0,
  987. "expiry_time": None,
  988. },
  989. )
  990. )
  991. channel = self.make_request(
  992. b"GET",
  993. f"{self.url}?token={token}",
  994. )
  995. self.assertEqual(channel.code, 200, msg=channel.result)
  996. self.assertEqual(channel.json_body["valid"], True)
  997. def test_GET_token_invalid(self) -> None:
  998. token = "1234"
  999. channel = self.make_request(
  1000. b"GET",
  1001. f"{self.url}?token={token}",
  1002. )
  1003. self.assertEqual(channel.code, 200, msg=channel.result)
  1004. self.assertEqual(channel.json_body["valid"], False)
  1005. @override_config(
  1006. {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
  1007. )
  1008. def test_GET_ratelimiting(self) -> None:
  1009. token = "1234"
  1010. for i in range(0, 6):
  1011. channel = self.make_request(
  1012. b"GET",
  1013. f"{self.url}?token={token}",
  1014. )
  1015. if i == 5:
  1016. self.assertEqual(channel.code, 429, msg=channel.result)
  1017. retry_after_ms = int(channel.json_body["retry_after_ms"])
  1018. else:
  1019. self.assertEqual(channel.code, 200, msg=channel.result)
  1020. self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
  1021. channel = self.make_request(
  1022. b"GET",
  1023. f"{self.url}?token={token}",
  1024. )
  1025. self.assertEqual(channel.code, 200, msg=channel.result)