test_register.py 44 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179
  1. # Copyright 2014-2016 OpenMarket Ltd
  2. # Copyright 2017-2018 New Vector Ltd
  3. # Copyright 2019 The Matrix.org Foundation C.I.C.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import datetime
  17. import json
  18. import os
  19. import pkg_resources
  20. import synapse.rest.admin
  21. from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
  22. from synapse.api.errors import Codes
  23. from synapse.appservice import ApplicationService
  24. from synapse.rest.client import account, account_validity, login, logout, register, sync
  25. from synapse.storage._base import db_to_json
  26. from tests import unittest
  27. from tests.unittest import override_config
  28. class RegisterRestServletTestCase(unittest.HomeserverTestCase):
  29. servlets = [
  30. login.register_servlets,
  31. register.register_servlets,
  32. synapse.rest.admin.register_servlets,
  33. ]
  34. url = b"/_matrix/client/r0/register"
  35. def default_config(self):
  36. config = super().default_config()
  37. config["allow_guest_access"] = True
  38. return config
  39. def test_POST_appservice_registration_valid(self):
  40. user_id = "@as_user_kermit:test"
  41. as_token = "i_am_an_app_service"
  42. appservice = ApplicationService(
  43. as_token,
  44. self.hs.config.server.server_name,
  45. id="1234",
  46. namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
  47. sender="@as:test",
  48. )
  49. self.hs.get_datastore().services_cache.append(appservice)
  50. request_data = json.dumps(
  51. {"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
  52. )
  53. channel = self.make_request(
  54. b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
  55. )
  56. self.assertEquals(channel.result["code"], b"200", channel.result)
  57. det_data = {"user_id": user_id, "home_server": self.hs.hostname}
  58. self.assertDictContainsSubset(det_data, channel.json_body)
  59. def test_POST_appservice_registration_no_type(self):
  60. as_token = "i_am_an_app_service"
  61. appservice = ApplicationService(
  62. as_token,
  63. self.hs.config.server.server_name,
  64. id="1234",
  65. namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
  66. sender="@as:test",
  67. )
  68. self.hs.get_datastore().services_cache.append(appservice)
  69. request_data = json.dumps({"username": "as_user_kermit"})
  70. channel = self.make_request(
  71. b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
  72. )
  73. self.assertEquals(channel.result["code"], b"400", channel.result)
  74. def test_POST_appservice_registration_invalid(self):
  75. self.appservice = None # no application service exists
  76. request_data = json.dumps(
  77. {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
  78. )
  79. channel = self.make_request(
  80. b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
  81. )
  82. self.assertEquals(channel.result["code"], b"401", channel.result)
  83. def test_POST_bad_password(self):
  84. request_data = json.dumps({"username": "kermit", "password": 666})
  85. channel = self.make_request(b"POST", self.url, request_data)
  86. self.assertEquals(channel.result["code"], b"400", channel.result)
  87. self.assertEquals(channel.json_body["error"], "Invalid password")
  88. def test_POST_bad_username(self):
  89. request_data = json.dumps({"username": 777, "password": "monkey"})
  90. channel = self.make_request(b"POST", self.url, request_data)
  91. self.assertEquals(channel.result["code"], b"400", channel.result)
  92. self.assertEquals(channel.json_body["error"], "Invalid username")
  93. def test_POST_user_valid(self):
  94. user_id = "@kermit:test"
  95. device_id = "frogfone"
  96. params = {
  97. "username": "kermit",
  98. "password": "monkey",
  99. "device_id": device_id,
  100. "auth": {"type": LoginType.DUMMY},
  101. }
  102. request_data = json.dumps(params)
  103. channel = self.make_request(b"POST", self.url, request_data)
  104. det_data = {
  105. "user_id": user_id,
  106. "home_server": self.hs.hostname,
  107. "device_id": device_id,
  108. }
  109. self.assertEquals(channel.result["code"], b"200", channel.result)
  110. self.assertDictContainsSubset(det_data, channel.json_body)
  111. @override_config({"enable_registration": False})
  112. def test_POST_disabled_registration(self):
  113. request_data = json.dumps({"username": "kermit", "password": "monkey"})
  114. self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
  115. channel = self.make_request(b"POST", self.url, request_data)
  116. self.assertEquals(channel.result["code"], b"403", channel.result)
  117. self.assertEquals(channel.json_body["error"], "Registration has been disabled")
  118. self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN")
  119. def test_POST_guest_registration(self):
  120. self.hs.config.key.macaroon_secret_key = "test"
  121. self.hs.config.registration.allow_guest_access = True
  122. channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  123. det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
  124. self.assertEquals(channel.result["code"], b"200", channel.result)
  125. self.assertDictContainsSubset(det_data, channel.json_body)
  126. def test_POST_disabled_guest_registration(self):
  127. self.hs.config.registration.allow_guest_access = False
  128. channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  129. self.assertEquals(channel.result["code"], b"403", channel.result)
  130. self.assertEquals(channel.json_body["error"], "Guest access is disabled")
  131. @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
  132. def test_POST_ratelimiting_guest(self):
  133. for i in range(0, 6):
  134. url = self.url + b"?kind=guest"
  135. channel = self.make_request(b"POST", url, b"{}")
  136. if i == 5:
  137. self.assertEquals(channel.result["code"], b"429", channel.result)
  138. retry_after_ms = int(channel.json_body["retry_after_ms"])
  139. else:
  140. self.assertEquals(channel.result["code"], b"200", channel.result)
  141. self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
  142. channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  143. self.assertEquals(channel.result["code"], b"200", channel.result)
  144. @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
  145. def test_POST_ratelimiting(self):
  146. for i in range(0, 6):
  147. params = {
  148. "username": "kermit" + str(i),
  149. "password": "monkey",
  150. "device_id": "frogfone",
  151. "auth": {"type": LoginType.DUMMY},
  152. }
  153. request_data = json.dumps(params)
  154. channel = self.make_request(b"POST", self.url, request_data)
  155. if i == 5:
  156. self.assertEquals(channel.result["code"], b"429", channel.result)
  157. retry_after_ms = int(channel.json_body["retry_after_ms"])
  158. else:
  159. self.assertEquals(channel.result["code"], b"200", channel.result)
  160. self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
  161. channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  162. self.assertEquals(channel.result["code"], b"200", channel.result)
  163. @override_config({"registration_requires_token": True})
  164. def test_POST_registration_requires_token(self):
  165. username = "kermit"
  166. device_id = "frogfone"
  167. token = "abcd"
  168. store = self.hs.get_datastore()
  169. self.get_success(
  170. store.db_pool.simple_insert(
  171. "registration_tokens",
  172. {
  173. "token": token,
  174. "uses_allowed": None,
  175. "pending": 0,
  176. "completed": 0,
  177. "expiry_time": None,
  178. },
  179. )
  180. )
  181. params = {
  182. "username": username,
  183. "password": "monkey",
  184. "device_id": device_id,
  185. }
  186. # Request without auth to get flows and session
  187. channel = self.make_request(b"POST", self.url, json.dumps(params))
  188. self.assertEquals(channel.result["code"], b"401", channel.result)
  189. flows = channel.json_body["flows"]
  190. # Synapse adds a dummy stage to differentiate flows where otherwise one
  191. # flow would be a subset of another flow.
  192. self.assertCountEqual(
  193. [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
  194. (f["stages"] for f in flows),
  195. )
  196. session = channel.json_body["session"]
  197. # Do the registration token stage and check it has completed
  198. params["auth"] = {
  199. "type": LoginType.REGISTRATION_TOKEN,
  200. "token": token,
  201. "session": session,
  202. }
  203. request_data = json.dumps(params)
  204. channel = self.make_request(b"POST", self.url, request_data)
  205. self.assertEquals(channel.result["code"], b"401", channel.result)
  206. completed = channel.json_body["completed"]
  207. self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
  208. # Do the m.login.dummy stage and check registration was successful
  209. params["auth"] = {
  210. "type": LoginType.DUMMY,
  211. "session": session,
  212. }
  213. request_data = json.dumps(params)
  214. channel = self.make_request(b"POST", self.url, request_data)
  215. det_data = {
  216. "user_id": f"@{username}:{self.hs.hostname}",
  217. "home_server": self.hs.hostname,
  218. "device_id": device_id,
  219. }
  220. self.assertEquals(channel.result["code"], b"200", channel.result)
  221. self.assertDictContainsSubset(det_data, channel.json_body)
  222. # Check the `completed` counter has been incremented and pending is 0
  223. res = self.get_success(
  224. store.db_pool.simple_select_one(
  225. "registration_tokens",
  226. keyvalues={"token": token},
  227. retcols=["pending", "completed"],
  228. )
  229. )
  230. self.assertEquals(res["completed"], 1)
  231. self.assertEquals(res["pending"], 0)
  232. @override_config({"registration_requires_token": True})
  233. def test_POST_registration_token_invalid(self):
  234. params = {
  235. "username": "kermit",
  236. "password": "monkey",
  237. }
  238. # Request without auth to get session
  239. channel = self.make_request(b"POST", self.url, json.dumps(params))
  240. session = channel.json_body["session"]
  241. # Test with token param missing (invalid)
  242. params["auth"] = {
  243. "type": LoginType.REGISTRATION_TOKEN,
  244. "session": session,
  245. }
  246. channel = self.make_request(b"POST", self.url, json.dumps(params))
  247. self.assertEquals(channel.result["code"], b"401", channel.result)
  248. self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
  249. self.assertEquals(channel.json_body["completed"], [])
  250. # Test with non-string (invalid)
  251. params["auth"]["token"] = 1234
  252. channel = self.make_request(b"POST", self.url, json.dumps(params))
  253. self.assertEquals(channel.result["code"], b"401", channel.result)
  254. self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
  255. self.assertEquals(channel.json_body["completed"], [])
  256. # Test with unknown token (invalid)
  257. params["auth"]["token"] = "1234"
  258. channel = self.make_request(b"POST", self.url, json.dumps(params))
  259. self.assertEquals(channel.result["code"], b"401", channel.result)
  260. self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
  261. self.assertEquals(channel.json_body["completed"], [])
  262. @override_config({"registration_requires_token": True})
  263. def test_POST_registration_token_limit_uses(self):
  264. token = "abcd"
  265. store = self.hs.get_datastore()
  266. # Create token that can be used once
  267. self.get_success(
  268. store.db_pool.simple_insert(
  269. "registration_tokens",
  270. {
  271. "token": token,
  272. "uses_allowed": 1,
  273. "pending": 0,
  274. "completed": 0,
  275. "expiry_time": None,
  276. },
  277. )
  278. )
  279. params1 = {"username": "bert", "password": "monkey"}
  280. params2 = {"username": "ernie", "password": "monkey"}
  281. # Do 2 requests without auth to get two session IDs
  282. channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
  283. session1 = channel1.json_body["session"]
  284. channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
  285. session2 = channel2.json_body["session"]
  286. # Use token with session1 and check `pending` is 1
  287. params1["auth"] = {
  288. "type": LoginType.REGISTRATION_TOKEN,
  289. "token": token,
  290. "session": session1,
  291. }
  292. self.make_request(b"POST", self.url, json.dumps(params1))
  293. # Repeat request to make sure pending isn't increased again
  294. self.make_request(b"POST", self.url, json.dumps(params1))
  295. pending = self.get_success(
  296. store.db_pool.simple_select_one_onecol(
  297. "registration_tokens",
  298. keyvalues={"token": token},
  299. retcol="pending",
  300. )
  301. )
  302. self.assertEquals(pending, 1)
  303. # Check auth fails when using token with session2
  304. params2["auth"] = {
  305. "type": LoginType.REGISTRATION_TOKEN,
  306. "token": token,
  307. "session": session2,
  308. }
  309. channel = self.make_request(b"POST", self.url, json.dumps(params2))
  310. self.assertEquals(channel.result["code"], b"401", channel.result)
  311. self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
  312. self.assertEquals(channel.json_body["completed"], [])
  313. # Complete registration with session1
  314. params1["auth"]["type"] = LoginType.DUMMY
  315. self.make_request(b"POST", self.url, json.dumps(params1))
  316. # Check pending=0 and completed=1
  317. res = self.get_success(
  318. store.db_pool.simple_select_one(
  319. "registration_tokens",
  320. keyvalues={"token": token},
  321. retcols=["pending", "completed"],
  322. )
  323. )
  324. self.assertEquals(res["pending"], 0)
  325. self.assertEquals(res["completed"], 1)
  326. # Check auth still fails when using token with session2
  327. channel = self.make_request(b"POST", self.url, json.dumps(params2))
  328. self.assertEquals(channel.result["code"], b"401", channel.result)
  329. self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
  330. self.assertEquals(channel.json_body["completed"], [])
  331. @override_config({"registration_requires_token": True})
  332. def test_POST_registration_token_expiry(self):
  333. token = "abcd"
  334. now = self.hs.get_clock().time_msec()
  335. store = self.hs.get_datastore()
  336. # Create token that expired yesterday
  337. self.get_success(
  338. store.db_pool.simple_insert(
  339. "registration_tokens",
  340. {
  341. "token": token,
  342. "uses_allowed": None,
  343. "pending": 0,
  344. "completed": 0,
  345. "expiry_time": now - 24 * 60 * 60 * 1000,
  346. },
  347. )
  348. )
  349. params = {"username": "kermit", "password": "monkey"}
  350. # Request without auth to get session
  351. channel = self.make_request(b"POST", self.url, json.dumps(params))
  352. session = channel.json_body["session"]
  353. # Check authentication fails with expired token
  354. params["auth"] = {
  355. "type": LoginType.REGISTRATION_TOKEN,
  356. "token": token,
  357. "session": session,
  358. }
  359. channel = self.make_request(b"POST", self.url, json.dumps(params))
  360. self.assertEquals(channel.result["code"], b"401", channel.result)
  361. self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
  362. self.assertEquals(channel.json_body["completed"], [])
  363. # Update token so it expires tomorrow
  364. self.get_success(
  365. store.db_pool.simple_update_one(
  366. "registration_tokens",
  367. keyvalues={"token": token},
  368. updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
  369. )
  370. )
  371. # Check authentication succeeds
  372. channel = self.make_request(b"POST", self.url, json.dumps(params))
  373. completed = channel.json_body["completed"]
  374. self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
  375. @override_config({"registration_requires_token": True})
  376. def test_POST_registration_token_session_expiry(self):
  377. """Test `pending` is decremented when an uncompleted session expires."""
  378. token = "abcd"
  379. store = self.hs.get_datastore()
  380. self.get_success(
  381. store.db_pool.simple_insert(
  382. "registration_tokens",
  383. {
  384. "token": token,
  385. "uses_allowed": None,
  386. "pending": 0,
  387. "completed": 0,
  388. "expiry_time": None,
  389. },
  390. )
  391. )
  392. # Do 2 requests without auth to get two session IDs
  393. params1 = {"username": "bert", "password": "monkey"}
  394. params2 = {"username": "ernie", "password": "monkey"}
  395. channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
  396. session1 = channel1.json_body["session"]
  397. channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
  398. session2 = channel2.json_body["session"]
  399. # Use token with both sessions
  400. params1["auth"] = {
  401. "type": LoginType.REGISTRATION_TOKEN,
  402. "token": token,
  403. "session": session1,
  404. }
  405. self.make_request(b"POST", self.url, json.dumps(params1))
  406. params2["auth"] = {
  407. "type": LoginType.REGISTRATION_TOKEN,
  408. "token": token,
  409. "session": session2,
  410. }
  411. self.make_request(b"POST", self.url, json.dumps(params2))
  412. # Complete registration with session1
  413. params1["auth"]["type"] = LoginType.DUMMY
  414. self.make_request(b"POST", self.url, json.dumps(params1))
  415. # Check `result` of registration token stage for session1 is `True`
  416. result1 = self.get_success(
  417. store.db_pool.simple_select_one_onecol(
  418. "ui_auth_sessions_credentials",
  419. keyvalues={
  420. "session_id": session1,
  421. "stage_type": LoginType.REGISTRATION_TOKEN,
  422. },
  423. retcol="result",
  424. )
  425. )
  426. self.assertTrue(db_to_json(result1))
  427. # Check `result` for session2 is the token used
  428. result2 = self.get_success(
  429. store.db_pool.simple_select_one_onecol(
  430. "ui_auth_sessions_credentials",
  431. keyvalues={
  432. "session_id": session2,
  433. "stage_type": LoginType.REGISTRATION_TOKEN,
  434. },
  435. retcol="result",
  436. )
  437. )
  438. self.assertEquals(db_to_json(result2), token)
  439. # Delete both sessions (mimics expiry)
  440. self.get_success(
  441. store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
  442. )
  443. # Check pending is now 0
  444. pending = self.get_success(
  445. store.db_pool.simple_select_one_onecol(
  446. "registration_tokens",
  447. keyvalues={"token": token},
  448. retcol="pending",
  449. )
  450. )
  451. self.assertEquals(pending, 0)
  452. @override_config({"registration_requires_token": True})
  453. def test_POST_registration_token_session_expiry_deleted_token(self):
  454. """Test session expiry doesn't break when the token is deleted.
  455. 1. Start but don't complete UIA with a registration token
  456. 2. Delete the token from the database
  457. 3. Expire the session
  458. """
  459. token = "abcd"
  460. store = self.hs.get_datastore()
  461. self.get_success(
  462. store.db_pool.simple_insert(
  463. "registration_tokens",
  464. {
  465. "token": token,
  466. "uses_allowed": None,
  467. "pending": 0,
  468. "completed": 0,
  469. "expiry_time": None,
  470. },
  471. )
  472. )
  473. # Do request without auth to get a session ID
  474. params = {"username": "kermit", "password": "monkey"}
  475. channel = self.make_request(b"POST", self.url, json.dumps(params))
  476. session = channel.json_body["session"]
  477. # Use token
  478. params["auth"] = {
  479. "type": LoginType.REGISTRATION_TOKEN,
  480. "token": token,
  481. "session": session,
  482. }
  483. self.make_request(b"POST", self.url, json.dumps(params))
  484. # Delete token
  485. self.get_success(
  486. store.db_pool.simple_delete_one(
  487. "registration_tokens",
  488. keyvalues={"token": token},
  489. )
  490. )
  491. # Delete session (mimics expiry)
  492. self.get_success(
  493. store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
  494. )
  495. def test_advertised_flows(self):
  496. channel = self.make_request(b"POST", self.url, b"{}")
  497. self.assertEquals(channel.result["code"], b"401", channel.result)
  498. flows = channel.json_body["flows"]
  499. # with the stock config, we only expect the dummy flow
  500. self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows))
  501. @unittest.override_config(
  502. {
  503. "public_baseurl": "https://test_server",
  504. "enable_registration_captcha": True,
  505. "user_consent": {
  506. "version": "1",
  507. "template_dir": "/",
  508. "require_at_registration": True,
  509. },
  510. "account_threepid_delegates": {
  511. "email": "https://id_server",
  512. "msisdn": "https://id_server",
  513. },
  514. }
  515. )
  516. def test_advertised_flows_captcha_and_terms_and_3pids(self):
  517. channel = self.make_request(b"POST", self.url, b"{}")
  518. self.assertEquals(channel.result["code"], b"401", channel.result)
  519. flows = channel.json_body["flows"]
  520. self.assertCountEqual(
  521. [
  522. ["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
  523. ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
  524. ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
  525. [
  526. "m.login.recaptcha",
  527. "m.login.terms",
  528. "m.login.msisdn",
  529. "m.login.email.identity",
  530. ],
  531. ],
  532. (f["stages"] for f in flows),
  533. )
  534. @unittest.override_config(
  535. {
  536. "public_baseurl": "https://test_server",
  537. "registrations_require_3pid": ["email"],
  538. "disable_msisdn_registration": True,
  539. "email": {
  540. "smtp_host": "mail_server",
  541. "smtp_port": 2525,
  542. "notif_from": "sender@host",
  543. },
  544. }
  545. )
  546. def test_advertised_flows_no_msisdn_email_required(self):
  547. channel = self.make_request(b"POST", self.url, b"{}")
  548. self.assertEquals(channel.result["code"], b"401", channel.result)
  549. flows = channel.json_body["flows"]
  550. # with the stock config, we expect all four combinations of 3pid
  551. self.assertCountEqual(
  552. [["m.login.email.identity"]], (f["stages"] for f in flows)
  553. )
  554. @unittest.override_config(
  555. {
  556. "request_token_inhibit_3pid_errors": True,
  557. "public_baseurl": "https://test_server",
  558. "email": {
  559. "smtp_host": "mail_server",
  560. "smtp_port": 2525,
  561. "notif_from": "sender@host",
  562. },
  563. }
  564. )
  565. def test_request_token_existing_email_inhibit_error(self):
  566. """Test that requesting a token via this endpoint doesn't leak existing
  567. associations if configured that way.
  568. """
  569. user_id = self.register_user("kermit", "monkey")
  570. self.login("kermit", "monkey")
  571. email = "test@example.com"
  572. # Add a threepid
  573. self.get_success(
  574. self.hs.get_datastore().user_add_threepid(
  575. user_id=user_id,
  576. medium="email",
  577. address=email,
  578. validated_at=0,
  579. added_at=0,
  580. )
  581. )
  582. channel = self.make_request(
  583. "POST",
  584. b"register/email/requestToken",
  585. {"client_secret": "foobar", "email": email, "send_attempt": 1},
  586. )
  587. self.assertEquals(200, channel.code, channel.result)
  588. self.assertIsNotNone(channel.json_body.get("sid"))
  589. @unittest.override_config(
  590. {
  591. "public_baseurl": "https://test_server",
  592. "email": {
  593. "smtp_host": "mail_server",
  594. "smtp_port": 2525,
  595. "notif_from": "sender@host",
  596. },
  597. }
  598. )
  599. def test_reject_invalid_email(self):
  600. """Check that bad emails are rejected"""
  601. # Test for email with multiple @
  602. channel = self.make_request(
  603. "POST",
  604. b"register/email/requestToken",
  605. {"client_secret": "foobar", "email": "email@@email", "send_attempt": 1},
  606. )
  607. self.assertEquals(400, channel.code, channel.result)
  608. # Check error to ensure that we're not erroring due to a bug in the test.
  609. self.assertEquals(
  610. channel.json_body,
  611. {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
  612. )
  613. # Test for email with no @
  614. channel = self.make_request(
  615. "POST",
  616. b"register/email/requestToken",
  617. {"client_secret": "foobar", "email": "email", "send_attempt": 1},
  618. )
  619. self.assertEquals(400, channel.code, channel.result)
  620. self.assertEquals(
  621. channel.json_body,
  622. {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
  623. )
  624. # Test for super long email
  625. email = "a@" + "a" * 1000
  626. channel = self.make_request(
  627. "POST",
  628. b"register/email/requestToken",
  629. {"client_secret": "foobar", "email": email, "send_attempt": 1},
  630. )
  631. self.assertEquals(400, channel.code, channel.result)
  632. self.assertEquals(
  633. channel.json_body,
  634. {"errcode": "M_UNKNOWN", "error": "Unable to parse email address"},
  635. )
  636. class AccountValidityTestCase(unittest.HomeserverTestCase):
  637. servlets = [
  638. register.register_servlets,
  639. synapse.rest.admin.register_servlets_for_client_rest_resource,
  640. login.register_servlets,
  641. sync.register_servlets,
  642. logout.register_servlets,
  643. account_validity.register_servlets,
  644. ]
  645. def make_homeserver(self, reactor, clock):
  646. config = self.default_config()
  647. # Test for account expiring after a week.
  648. config["enable_registration"] = True
  649. config["account_validity"] = {
  650. "enabled": True,
  651. "period": 604800000, # Time in ms for 1 week
  652. }
  653. self.hs = self.setup_test_homeserver(config=config)
  654. return self.hs
  655. def test_validity_period(self):
  656. self.register_user("kermit", "monkey")
  657. tok = self.login("kermit", "monkey")
  658. # The specific endpoint doesn't matter, all we need is an authenticated
  659. # endpoint.
  660. channel = self.make_request(b"GET", "/sync", access_token=tok)
  661. self.assertEquals(channel.result["code"], b"200", channel.result)
  662. self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
  663. channel = self.make_request(b"GET", "/sync", access_token=tok)
  664. self.assertEquals(channel.result["code"], b"403", channel.result)
  665. self.assertEquals(
  666. channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
  667. )
  668. def test_manual_renewal(self):
  669. user_id = self.register_user("kermit", "monkey")
  670. tok = self.login("kermit", "monkey")
  671. self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
  672. # If we register the admin user at the beginning of the test, it will
  673. # expire at the same time as the normal user and the renewal request
  674. # will be denied.
  675. self.register_user("admin", "adminpassword", admin=True)
  676. admin_tok = self.login("admin", "adminpassword")
  677. url = "/_synapse/admin/v1/account_validity/validity"
  678. params = {"user_id": user_id}
  679. request_data = json.dumps(params)
  680. channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
  681. self.assertEquals(channel.result["code"], b"200", channel.result)
  682. # The specific endpoint doesn't matter, all we need is an authenticated
  683. # endpoint.
  684. channel = self.make_request(b"GET", "/sync", access_token=tok)
  685. self.assertEquals(channel.result["code"], b"200", channel.result)
  686. def test_manual_expire(self):
  687. user_id = self.register_user("kermit", "monkey")
  688. tok = self.login("kermit", "monkey")
  689. self.register_user("admin", "adminpassword", admin=True)
  690. admin_tok = self.login("admin", "adminpassword")
  691. url = "/_synapse/admin/v1/account_validity/validity"
  692. params = {
  693. "user_id": user_id,
  694. "expiration_ts": 0,
  695. "enable_renewal_emails": False,
  696. }
  697. request_data = json.dumps(params)
  698. channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
  699. self.assertEquals(channel.result["code"], b"200", channel.result)
  700. # The specific endpoint doesn't matter, all we need is an authenticated
  701. # endpoint.
  702. channel = self.make_request(b"GET", "/sync", access_token=tok)
  703. self.assertEquals(channel.result["code"], b"403", channel.result)
  704. self.assertEquals(
  705. channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
  706. )
  707. def test_logging_out_expired_user(self):
  708. user_id = self.register_user("kermit", "monkey")
  709. tok = self.login("kermit", "monkey")
  710. self.register_user("admin", "adminpassword", admin=True)
  711. admin_tok = self.login("admin", "adminpassword")
  712. url = "/_synapse/admin/v1/account_validity/validity"
  713. params = {
  714. "user_id": user_id,
  715. "expiration_ts": 0,
  716. "enable_renewal_emails": False,
  717. }
  718. request_data = json.dumps(params)
  719. channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
  720. self.assertEquals(channel.result["code"], b"200", channel.result)
  721. # Try to log the user out
  722. channel = self.make_request(b"POST", "/logout", access_token=tok)
  723. self.assertEquals(channel.result["code"], b"200", channel.result)
  724. # Log the user in again (allowed for expired accounts)
  725. tok = self.login("kermit", "monkey")
  726. # Try to log out all of the user's sessions
  727. channel = self.make_request(b"POST", "/logout/all", access_token=tok)
  728. self.assertEquals(channel.result["code"], b"200", channel.result)
  729. class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
  730. servlets = [
  731. register.register_servlets,
  732. synapse.rest.admin.register_servlets_for_client_rest_resource,
  733. login.register_servlets,
  734. sync.register_servlets,
  735. account_validity.register_servlets,
  736. account.register_servlets,
  737. ]
  738. def make_homeserver(self, reactor, clock):
  739. config = self.default_config()
  740. # Test for account expiring after a week and renewal emails being sent 2
  741. # days before expiry.
  742. config["enable_registration"] = True
  743. config["account_validity"] = {
  744. "enabled": True,
  745. "period": 604800000, # Time in ms for 1 week
  746. "renew_at": 172800000, # Time in ms for 2 days
  747. "renew_by_email_enabled": True,
  748. "renew_email_subject": "Renew your account",
  749. "account_renewed_html_path": "account_renewed.html",
  750. "invalid_token_html_path": "invalid_token.html",
  751. }
  752. # Email config.
  753. config["email"] = {
  754. "enable_notifs": True,
  755. "template_dir": os.path.abspath(
  756. pkg_resources.resource_filename("synapse", "res/templates")
  757. ),
  758. "expiry_template_html": "notice_expiry.html",
  759. "expiry_template_text": "notice_expiry.txt",
  760. "notif_template_html": "notif_mail.html",
  761. "notif_template_text": "notif_mail.txt",
  762. "smtp_host": "127.0.0.1",
  763. "smtp_port": 20,
  764. "require_transport_security": False,
  765. "smtp_user": None,
  766. "smtp_pass": None,
  767. "notif_from": "test@example.com",
  768. }
  769. self.hs = self.setup_test_homeserver(config=config)
  770. async def sendmail(*args, **kwargs):
  771. self.email_attempts.append((args, kwargs))
  772. self.email_attempts = []
  773. self.hs.get_send_email_handler()._sendmail = sendmail
  774. self.store = self.hs.get_datastore()
  775. return self.hs
  776. def test_renewal_email(self):
  777. self.email_attempts = []
  778. (user_id, tok) = self.create_user()
  779. # Move 5 days forward. This should trigger a renewal email to be sent.
  780. self.reactor.advance(datetime.timedelta(days=5).total_seconds())
  781. self.assertEqual(len(self.email_attempts), 1)
  782. # Retrieving the URL from the email is too much pain for now, so we
  783. # retrieve the token from the DB.
  784. renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
  785. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
  786. channel = self.make_request(b"GET", url)
  787. self.assertEquals(channel.result["code"], b"200", channel.result)
  788. # Check that we're getting HTML back.
  789. content_type = channel.headers.getRawHeaders(b"Content-Type")
  790. self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
  791. # Check that the HTML we're getting is the one we expect on a successful renewal.
  792. expiration_ts = self.get_success(self.store.get_expiration_ts_for_user(user_id))
  793. expected_html = self.hs.config.account_validity.account_validity_account_renewed_template.render(
  794. expiration_ts=expiration_ts
  795. )
  796. self.assertEqual(
  797. channel.result["body"], expected_html.encode("utf8"), channel.result
  798. )
  799. # Move 1 day forward. Try to renew with the same token again.
  800. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
  801. channel = self.make_request(b"GET", url)
  802. self.assertEquals(channel.result["code"], b"200", channel.result)
  803. # Check that we're getting HTML back.
  804. content_type = channel.headers.getRawHeaders(b"Content-Type")
  805. self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
  806. # Check that the HTML we're getting is the one we expect when reusing a
  807. # token. The account expiration date should not have changed.
  808. expected_html = self.hs.config.account_validity.account_validity_account_previously_renewed_template.render(
  809. expiration_ts=expiration_ts
  810. )
  811. self.assertEqual(
  812. channel.result["body"], expected_html.encode("utf8"), channel.result
  813. )
  814. # Move 3 days forward. If the renewal failed, every authed request with
  815. # our access token should be denied from now, otherwise they should
  816. # succeed.
  817. self.reactor.advance(datetime.timedelta(days=3).total_seconds())
  818. channel = self.make_request(b"GET", "/sync", access_token=tok)
  819. self.assertEquals(channel.result["code"], b"200", channel.result)
  820. def test_renewal_invalid_token(self):
  821. # Hit the renewal endpoint with an invalid token and check that it behaves as
  822. # expected, i.e. that it responds with 404 Not Found and the correct HTML.
  823. url = "/_matrix/client/unstable/account_validity/renew?token=123"
  824. channel = self.make_request(b"GET", url)
  825. self.assertEquals(channel.result["code"], b"404", channel.result)
  826. # Check that we're getting HTML back.
  827. content_type = channel.headers.getRawHeaders(b"Content-Type")
  828. self.assertEqual(content_type, [b"text/html; charset=utf-8"], channel.result)
  829. # Check that the HTML we're getting is the one we expect when using an
  830. # invalid/unknown token.
  831. expected_html = (
  832. self.hs.config.account_validity.account_validity_invalid_token_template.render()
  833. )
  834. self.assertEqual(
  835. channel.result["body"], expected_html.encode("utf8"), channel.result
  836. )
  837. def test_manual_email_send(self):
  838. self.email_attempts = []
  839. (user_id, tok) = self.create_user()
  840. channel = self.make_request(
  841. b"POST",
  842. "/_matrix/client/unstable/account_validity/send_mail",
  843. access_token=tok,
  844. )
  845. self.assertEquals(channel.result["code"], b"200", channel.result)
  846. self.assertEqual(len(self.email_attempts), 1)
  847. def test_deactivated_user(self):
  848. self.email_attempts = []
  849. (user_id, tok) = self.create_user()
  850. request_data = json.dumps(
  851. {
  852. "auth": {
  853. "type": "m.login.password",
  854. "user": user_id,
  855. "password": "monkey",
  856. },
  857. "erase": False,
  858. }
  859. )
  860. channel = self.make_request(
  861. "POST", "account/deactivate", request_data, access_token=tok
  862. )
  863. self.assertEqual(channel.code, 200)
  864. self.reactor.advance(datetime.timedelta(days=8).total_seconds())
  865. self.assertEqual(len(self.email_attempts), 0)
  866. def create_user(self):
  867. user_id = self.register_user("kermit", "monkey")
  868. tok = self.login("kermit", "monkey")
  869. # We need to manually add an email address otherwise the handler will do
  870. # nothing.
  871. now = self.hs.get_clock().time_msec()
  872. self.get_success(
  873. self.store.user_add_threepid(
  874. user_id=user_id,
  875. medium="email",
  876. address="kermit@example.com",
  877. validated_at=now,
  878. added_at=now,
  879. )
  880. )
  881. return user_id, tok
  882. def test_manual_email_send_expired_account(self):
  883. user_id = self.register_user("kermit", "monkey")
  884. tok = self.login("kermit", "monkey")
  885. # We need to manually add an email address otherwise the handler will do
  886. # nothing.
  887. now = self.hs.get_clock().time_msec()
  888. self.get_success(
  889. self.store.user_add_threepid(
  890. user_id=user_id,
  891. medium="email",
  892. address="kermit@example.com",
  893. validated_at=now,
  894. added_at=now,
  895. )
  896. )
  897. # Make the account expire.
  898. self.reactor.advance(datetime.timedelta(days=8).total_seconds())
  899. # Ignore all emails sent by the automatic background task and only focus on the
  900. # ones sent manually.
  901. self.email_attempts = []
  902. # Test that we're still able to manually trigger a mail to be sent.
  903. channel = self.make_request(
  904. b"POST",
  905. "/_matrix/client/unstable/account_validity/send_mail",
  906. access_token=tok,
  907. )
  908. self.assertEquals(channel.result["code"], b"200", channel.result)
  909. self.assertEqual(len(self.email_attempts), 1)
  910. class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
  911. servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
  912. def make_homeserver(self, reactor, clock):
  913. self.validity_period = 10
  914. self.max_delta = self.validity_period * 10.0 / 100.0
  915. config = self.default_config()
  916. config["enable_registration"] = True
  917. config["account_validity"] = {"enabled": False}
  918. self.hs = self.setup_test_homeserver(config=config)
  919. # We need to set these directly, instead of in the homeserver config dict above.
  920. # This is due to account validity-related config options not being read by
  921. # Synapse when account_validity.enabled is False.
  922. self.hs.get_datastore()._account_validity_period = self.validity_period
  923. self.hs.get_datastore()._account_validity_startup_job_max_delta = self.max_delta
  924. self.store = self.hs.get_datastore()
  925. return self.hs
  926. def test_background_job(self):
  927. """
  928. Tests the same thing as test_background_job, except that it sets the
  929. startup_job_max_delta parameter and checks that the expiration date is within the
  930. allowed range.
  931. """
  932. user_id = self.register_user("kermit_delta", "user")
  933. self.hs.config.account_validity.startup_job_max_delta = self.max_delta
  934. now_ms = self.hs.get_clock().time_msec()
  935. self.get_success(self.store._set_expiration_date_when_missing())
  936. res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
  937. self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
  938. self.assertLessEqual(res, now_ms + self.validity_period)
  939. class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
  940. servlets = [register.register_servlets]
  941. url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
  942. def default_config(self):
  943. config = super().default_config()
  944. config["registration_requires_token"] = True
  945. return config
  946. def test_GET_token_valid(self):
  947. token = "abcd"
  948. store = self.hs.get_datastore()
  949. self.get_success(
  950. store.db_pool.simple_insert(
  951. "registration_tokens",
  952. {
  953. "token": token,
  954. "uses_allowed": None,
  955. "pending": 0,
  956. "completed": 0,
  957. "expiry_time": None,
  958. },
  959. )
  960. )
  961. channel = self.make_request(
  962. b"GET",
  963. f"{self.url}?token={token}",
  964. )
  965. self.assertEquals(channel.result["code"], b"200", channel.result)
  966. self.assertEquals(channel.json_body["valid"], True)
  967. def test_GET_token_invalid(self):
  968. token = "1234"
  969. channel = self.make_request(
  970. b"GET",
  971. f"{self.url}?token={token}",
  972. )
  973. self.assertEquals(channel.result["code"], b"200", channel.result)
  974. self.assertEquals(channel.json_body["valid"], False)
  975. @override_config(
  976. {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
  977. )
  978. def test_GET_ratelimiting(self):
  979. token = "1234"
  980. for i in range(0, 6):
  981. channel = self.make_request(
  982. b"GET",
  983. f"{self.url}?token={token}",
  984. )
  985. if i == 5:
  986. self.assertEquals(channel.result["code"], b"429", channel.result)
  987. retry_after_ms = int(channel.json_body["retry_after_ms"])
  988. else:
  989. self.assertEquals(channel.result["code"], b"200", channel.result)
  990. self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
  991. channel = self.make_request(
  992. b"GET",
  993. f"{self.url}?token={token}",
  994. )
  995. self.assertEquals(channel.result["code"], b"200", channel.result)