test_register.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. # Copyright 2017-2018 New Vector Ltd
  4. # Copyright 2019 The Matrix.org Foundation C.I.C.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import datetime
  18. import json
  19. import os
  20. import pkg_resources
  21. import synapse.rest.admin
  22. from synapse.api.constants import LoginType
  23. from synapse.api.errors import Codes
  24. from synapse.appservice import ApplicationService
  25. from synapse.rest.client.v1 import login
  26. from synapse.rest.client.v2_alpha import account, account_validity, register, sync
  27. from tests import unittest
  28. class RegisterRestServletTestCase(unittest.HomeserverTestCase):
  29. servlets = [register.register_servlets]
  30. def make_homeserver(self, reactor, clock):
  31. self.url = b"/_matrix/client/r0/register"
  32. self.hs = self.setup_test_homeserver()
  33. self.hs.config.enable_registration = True
  34. self.hs.config.registrations_require_3pid = []
  35. self.hs.config.auto_join_rooms = []
  36. self.hs.config.enable_registration_captcha = False
  37. self.hs.config.allow_guest_access = True
  38. return self.hs
  39. def test_POST_appservice_registration_valid(self):
  40. user_id = "@as_user_kermit:test"
  41. as_token = "i_am_an_app_service"
  42. appservice = ApplicationService(
  43. as_token,
  44. self.hs.config.server_name,
  45. id="1234",
  46. namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
  47. )
  48. self.hs.get_datastore().services_cache.append(appservice)
  49. request_data = json.dumps({"username": "as_user_kermit"})
  50. request, channel = self.make_request(
  51. b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
  52. )
  53. self.render(request)
  54. self.assertEquals(channel.result["code"], b"200", channel.result)
  55. det_data = {"user_id": user_id, "home_server": self.hs.hostname}
  56. self.assertDictContainsSubset(det_data, channel.json_body)
  57. def test_POST_appservice_registration_invalid(self):
  58. self.appservice = None # no application service exists
  59. request_data = json.dumps({"username": "kermit"})
  60. request, channel = self.make_request(
  61. b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
  62. )
  63. self.render(request)
  64. self.assertEquals(channel.result["code"], b"401", channel.result)
  65. def test_POST_bad_password(self):
  66. request_data = json.dumps({"username": "kermit", "password": 666})
  67. request, channel = self.make_request(b"POST", self.url, request_data)
  68. self.render(request)
  69. self.assertEquals(channel.result["code"], b"400", channel.result)
  70. self.assertEquals(channel.json_body["error"], "Invalid password")
  71. def test_POST_bad_username(self):
  72. request_data = json.dumps({"username": 777, "password": "monkey"})
  73. request, channel = self.make_request(b"POST", self.url, request_data)
  74. self.render(request)
  75. self.assertEquals(channel.result["code"], b"400", channel.result)
  76. self.assertEquals(channel.json_body["error"], "Invalid username")
  77. def test_POST_user_valid(self):
  78. user_id = "@kermit:test"
  79. device_id = "frogfone"
  80. params = {
  81. "username": "kermit",
  82. "password": "monkey",
  83. "device_id": device_id,
  84. "auth": {"type": LoginType.DUMMY},
  85. }
  86. request_data = json.dumps(params)
  87. request, channel = self.make_request(b"POST", self.url, request_data)
  88. self.render(request)
  89. det_data = {
  90. "user_id": user_id,
  91. "home_server": self.hs.hostname,
  92. "device_id": device_id,
  93. }
  94. self.assertEquals(channel.result["code"], b"200", channel.result)
  95. self.assertDictContainsSubset(det_data, channel.json_body)
  96. def test_POST_disabled_registration(self):
  97. self.hs.config.enable_registration = False
  98. request_data = json.dumps({"username": "kermit", "password": "monkey"})
  99. self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
  100. request, channel = self.make_request(b"POST", self.url, request_data)
  101. self.render(request)
  102. self.assertEquals(channel.result["code"], b"403", channel.result)
  103. self.assertEquals(channel.json_body["error"], "Registration has been disabled")
  104. def test_POST_guest_registration(self):
  105. self.hs.config.macaroon_secret_key = "test"
  106. self.hs.config.allow_guest_access = True
  107. request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  108. self.render(request)
  109. det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
  110. self.assertEquals(channel.result["code"], b"200", channel.result)
  111. self.assertDictContainsSubset(det_data, channel.json_body)
  112. def test_POST_disabled_guest_registration(self):
  113. self.hs.config.allow_guest_access = False
  114. request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  115. self.render(request)
  116. self.assertEquals(channel.result["code"], b"403", channel.result)
  117. self.assertEquals(channel.json_body["error"], "Guest access is disabled")
  118. def test_POST_ratelimiting_guest(self):
  119. self.hs.config.rc_registration.burst_count = 5
  120. self.hs.config.rc_registration.per_second = 0.17
  121. for i in range(0, 6):
  122. url = self.url + b"?kind=guest"
  123. request, channel = self.make_request(b"POST", url, b"{}")
  124. self.render(request)
  125. if i == 5:
  126. self.assertEquals(channel.result["code"], b"429", channel.result)
  127. retry_after_ms = int(channel.json_body["retry_after_ms"])
  128. else:
  129. self.assertEquals(channel.result["code"], b"200", channel.result)
  130. self.reactor.advance(retry_after_ms / 1000.0)
  131. request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  132. self.render(request)
  133. self.assertEquals(channel.result["code"], b"200", channel.result)
  134. def test_POST_ratelimiting(self):
  135. self.hs.config.rc_registration.burst_count = 5
  136. self.hs.config.rc_registration.per_second = 0.17
  137. for i in range(0, 6):
  138. params = {
  139. "username": "kermit" + str(i),
  140. "password": "monkey",
  141. "device_id": "frogfone",
  142. "auth": {"type": LoginType.DUMMY},
  143. }
  144. request_data = json.dumps(params)
  145. request, channel = self.make_request(b"POST", self.url, request_data)
  146. self.render(request)
  147. if i == 5:
  148. self.assertEquals(channel.result["code"], b"429", channel.result)
  149. retry_after_ms = int(channel.json_body["retry_after_ms"])
  150. else:
  151. self.assertEquals(channel.result["code"], b"200", channel.result)
  152. self.reactor.advance(retry_after_ms / 1000.0)
  153. request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
  154. self.render(request)
  155. self.assertEquals(channel.result["code"], b"200", channel.result)
  156. class AccountValidityTestCase(unittest.HomeserverTestCase):
  157. servlets = [
  158. register.register_servlets,
  159. synapse.rest.admin.register_servlets_for_client_rest_resource,
  160. login.register_servlets,
  161. sync.register_servlets,
  162. account_validity.register_servlets,
  163. ]
  164. def make_homeserver(self, reactor, clock):
  165. config = self.default_config()
  166. # Test for account expiring after a week.
  167. config["enable_registration"] = True
  168. config["account_validity"] = {
  169. "enabled": True,
  170. "period": 604800000, # Time in ms for 1 week
  171. }
  172. self.hs = self.setup_test_homeserver(config=config)
  173. return self.hs
  174. def test_validity_period(self):
  175. self.register_user("kermit", "monkey")
  176. tok = self.login("kermit", "monkey")
  177. # The specific endpoint doesn't matter, all we need is an authenticated
  178. # endpoint.
  179. request, channel = self.make_request(b"GET", "/sync", access_token=tok)
  180. self.render(request)
  181. self.assertEquals(channel.result["code"], b"200", channel.result)
  182. self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
  183. request, channel = self.make_request(b"GET", "/sync", access_token=tok)
  184. self.render(request)
  185. self.assertEquals(channel.result["code"], b"403", channel.result)
  186. self.assertEquals(
  187. channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
  188. )
  189. def test_manual_renewal(self):
  190. user_id = self.register_user("kermit", "monkey")
  191. tok = self.login("kermit", "monkey")
  192. self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
  193. # If we register the admin user at the beginning of the test, it will
  194. # expire at the same time as the normal user and the renewal request
  195. # will be denied.
  196. self.register_user("admin", "adminpassword", admin=True)
  197. admin_tok = self.login("admin", "adminpassword")
  198. url = "/_matrix/client/unstable/admin/account_validity/validity"
  199. params = {"user_id": user_id}
  200. request_data = json.dumps(params)
  201. request, channel = self.make_request(
  202. b"POST", url, request_data, access_token=admin_tok
  203. )
  204. self.render(request)
  205. self.assertEquals(channel.result["code"], b"200", channel.result)
  206. # The specific endpoint doesn't matter, all we need is an authenticated
  207. # endpoint.
  208. request, channel = self.make_request(b"GET", "/sync", access_token=tok)
  209. self.render(request)
  210. self.assertEquals(channel.result["code"], b"200", channel.result)
  211. def test_manual_expire(self):
  212. user_id = self.register_user("kermit", "monkey")
  213. tok = self.login("kermit", "monkey")
  214. self.register_user("admin", "adminpassword", admin=True)
  215. admin_tok = self.login("admin", "adminpassword")
  216. url = "/_matrix/client/unstable/admin/account_validity/validity"
  217. params = {
  218. "user_id": user_id,
  219. "expiration_ts": 0,
  220. "enable_renewal_emails": False,
  221. }
  222. request_data = json.dumps(params)
  223. request, channel = self.make_request(
  224. b"POST", url, request_data, access_token=admin_tok
  225. )
  226. self.render(request)
  227. self.assertEquals(channel.result["code"], b"200", channel.result)
  228. # The specific endpoint doesn't matter, all we need is an authenticated
  229. # endpoint.
  230. request, channel = self.make_request(b"GET", "/sync", access_token=tok)
  231. self.render(request)
  232. self.assertEquals(channel.result["code"], b"403", channel.result)
  233. self.assertEquals(
  234. channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
  235. )
  236. class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
  237. servlets = [
  238. register.register_servlets,
  239. synapse.rest.admin.register_servlets_for_client_rest_resource,
  240. login.register_servlets,
  241. sync.register_servlets,
  242. account_validity.register_servlets,
  243. account.register_servlets,
  244. ]
  245. def make_homeserver(self, reactor, clock):
  246. config = self.default_config()
  247. # Test for account expiring after a week and renewal emails being sent 2
  248. # days before expiry.
  249. config["enable_registration"] = True
  250. config["account_validity"] = {
  251. "enabled": True,
  252. "period": 604800000, # Time in ms for 1 week
  253. "renew_at": 172800000, # Time in ms for 2 days
  254. "renew_by_email_enabled": True,
  255. "renew_email_subject": "Renew your account",
  256. "account_renewed_html_path": "account_renewed.html",
  257. "invalid_token_html_path": "invalid_token.html",
  258. }
  259. # Email config.
  260. self.email_attempts = []
  261. def sendmail(*args, **kwargs):
  262. self.email_attempts.append((args, kwargs))
  263. return
  264. config["email"] = {
  265. "enable_notifs": True,
  266. "template_dir": os.path.abspath(
  267. pkg_resources.resource_filename("synapse", "res/templates")
  268. ),
  269. "expiry_template_html": "notice_expiry.html",
  270. "expiry_template_text": "notice_expiry.txt",
  271. "notif_template_html": "notif_mail.html",
  272. "notif_template_text": "notif_mail.txt",
  273. "smtp_host": "127.0.0.1",
  274. "smtp_port": 20,
  275. "require_transport_security": False,
  276. "smtp_user": None,
  277. "smtp_pass": None,
  278. "notif_from": "test@example.com",
  279. }
  280. config["public_baseurl"] = "aaa"
  281. self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail)
  282. self.store = self.hs.get_datastore()
  283. return self.hs
  284. def test_renewal_email(self):
  285. self.email_attempts = []
  286. (user_id, tok) = self.create_user()
  287. # Move 6 days forward. This should trigger a renewal email to be sent.
  288. self.reactor.advance(datetime.timedelta(days=6).total_seconds())
  289. self.assertEqual(len(self.email_attempts), 1)
  290. # Retrieving the URL from the email is too much pain for now, so we
  291. # retrieve the token from the DB.
  292. renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
  293. url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
  294. request, channel = self.make_request(b"GET", url)
  295. self.render(request)
  296. self.assertEquals(channel.result["code"], b"200", channel.result)
  297. # Check that we're getting HTML back.
  298. content_type = None
  299. for header in channel.result.get("headers", []):
  300. if header[0] == b"Content-Type":
  301. content_type = header[1]
  302. self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
  303. # Check that the HTML we're getting is the one we expect on a successful renewal.
  304. expected_html = self.hs.config.account_validity.account_renewed_html_content
  305. self.assertEqual(
  306. channel.result["body"], expected_html.encode("utf8"), channel.result
  307. )
  308. # Move 3 days forward. If the renewal failed, every authed request with
  309. # our access token should be denied from now, otherwise they should
  310. # succeed.
  311. self.reactor.advance(datetime.timedelta(days=3).total_seconds())
  312. request, channel = self.make_request(b"GET", "/sync", access_token=tok)
  313. self.render(request)
  314. self.assertEquals(channel.result["code"], b"200", channel.result)
  315. def test_renewal_invalid_token(self):
  316. # Hit the renewal endpoint with an invalid token and check that it behaves as
  317. # expected, i.e. that it responds with 404 Not Found and the correct HTML.
  318. url = "/_matrix/client/unstable/account_validity/renew?token=123"
  319. request, channel = self.make_request(b"GET", url)
  320. self.render(request)
  321. self.assertEquals(channel.result["code"], b"404", channel.result)
  322. # Check that we're getting HTML back.
  323. content_type = None
  324. for header in channel.result.get("headers", []):
  325. if header[0] == b"Content-Type":
  326. content_type = header[1]
  327. self.assertEqual(content_type, b"text/html; charset=utf-8", channel.result)
  328. # Check that the HTML we're getting is the one we expect when using an
  329. # invalid/unknown token.
  330. expected_html = self.hs.config.account_validity.invalid_token_html_content
  331. self.assertEqual(
  332. channel.result["body"], expected_html.encode("utf8"), channel.result
  333. )
  334. def test_manual_email_send(self):
  335. self.email_attempts = []
  336. (user_id, tok) = self.create_user()
  337. request, channel = self.make_request(
  338. b"POST",
  339. "/_matrix/client/unstable/account_validity/send_mail",
  340. access_token=tok,
  341. )
  342. self.render(request)
  343. self.assertEquals(channel.result["code"], b"200", channel.result)
  344. self.assertEqual(len(self.email_attempts), 1)
  345. def test_deactivated_user(self):
  346. self.email_attempts = []
  347. (user_id, tok) = self.create_user()
  348. request_data = json.dumps(
  349. {
  350. "auth": {
  351. "type": "m.login.password",
  352. "user": user_id,
  353. "password": "monkey",
  354. },
  355. "erase": False,
  356. }
  357. )
  358. request, channel = self.make_request(
  359. "POST", "account/deactivate", request_data, access_token=tok
  360. )
  361. self.render(request)
  362. self.assertEqual(request.code, 200)
  363. self.reactor.advance(datetime.timedelta(days=8).total_seconds())
  364. self.assertEqual(len(self.email_attempts), 0)
  365. def create_user(self):
  366. user_id = self.register_user("kermit", "monkey")
  367. tok = self.login("kermit", "monkey")
  368. # We need to manually add an email address otherwise the handler will do
  369. # nothing.
  370. now = self.hs.clock.time_msec()
  371. self.get_success(
  372. self.store.user_add_threepid(
  373. user_id=user_id,
  374. medium="email",
  375. address="kermit@example.com",
  376. validated_at=now,
  377. added_at=now,
  378. )
  379. )
  380. return (user_id, tok)
  381. def test_manual_email_send_expired_account(self):
  382. user_id = self.register_user("kermit", "monkey")
  383. tok = self.login("kermit", "monkey")
  384. # We need to manually add an email address otherwise the handler will do
  385. # nothing.
  386. now = self.hs.clock.time_msec()
  387. self.get_success(
  388. self.store.user_add_threepid(
  389. user_id=user_id,
  390. medium="email",
  391. address="kermit@example.com",
  392. validated_at=now,
  393. added_at=now,
  394. )
  395. )
  396. # Make the account expire.
  397. self.reactor.advance(datetime.timedelta(days=8).total_seconds())
  398. # Ignore all emails sent by the automatic background task and only focus on the
  399. # ones sent manually.
  400. self.email_attempts = []
  401. # Test that we're still able to manually trigger a mail to be sent.
  402. request, channel = self.make_request(
  403. b"POST",
  404. "/_matrix/client/unstable/account_validity/send_mail",
  405. access_token=tok,
  406. )
  407. self.render(request)
  408. self.assertEquals(channel.result["code"], b"200", channel.result)
  409. self.assertEqual(len(self.email_attempts), 1)
  410. class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
  411. servlets = [synapse.rest.admin.register_servlets_for_client_rest_resource]
  412. def make_homeserver(self, reactor, clock):
  413. self.validity_period = 10
  414. self.max_delta = self.validity_period * 10.0 / 100.0
  415. config = self.default_config()
  416. config["enable_registration"] = True
  417. config["account_validity"] = {"enabled": False}
  418. self.hs = self.setup_test_homeserver(config=config)
  419. self.hs.config.account_validity.period = self.validity_period
  420. self.store = self.hs.get_datastore()
  421. return self.hs
  422. def test_background_job(self):
  423. """
  424. Tests the same thing as test_background_job, except that it sets the
  425. startup_job_max_delta parameter and checks that the expiration date is within the
  426. allowed range.
  427. """
  428. user_id = self.register_user("kermit_delta", "user")
  429. self.hs.config.account_validity.startup_job_max_delta = self.max_delta
  430. now_ms = self.hs.clock.time_msec()
  431. self.get_success(self.store._set_expiration_date_when_missing())
  432. res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
  433. self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
  434. self.assertLessEqual(res, now_ms + self.validity_period)