123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723 |
- # Copyright 2021 Callum Brown
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import random
- import string
- from typing import Optional
- from twisted.test.proto_helpers import MemoryReactor
- import synapse.rest.admin
- from synapse.api.errors import Codes
- from synapse.rest.client import login
- from synapse.server import HomeServer
- from synapse.util import Clock
- from tests import unittest
- class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- ]
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
- self.other_user = self.register_user("user", "pass")
- self.other_user_tok = self.login("user", "pass")
- self.url = "/_synapse/admin/v1/registration_tokens"
- def _new_token(
- self,
- token: Optional[str] = None,
- uses_allowed: Optional[int] = None,
- pending: int = 0,
- completed: int = 0,
- expiry_time: Optional[int] = None,
- ) -> str:
- """Helper function to create a token."""
- if token is None:
- token = "".join(random.choices(string.ascii_letters, k=8))
- self.get_success(
- self.store.db_pool.simple_insert(
- "registration_tokens",
- {
- "token": token,
- "uses_allowed": uses_allowed,
- "pending": pending,
- "completed": completed,
- "expiry_time": expiry_time,
- },
- )
- )
- return token
- # CREATION
- def test_create_no_auth(self) -> None:
- """Try to create a token without authentication."""
- channel = self.make_request("POST", self.url + "/new", {})
- self.assertEqual(401, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_create_requester_not_admin(self) -> None:
- """Try to create a token while not an admin."""
- channel = self.make_request(
- "POST",
- self.url + "/new",
- {},
- access_token=self.other_user_tok,
- )
- self.assertEqual(403, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_create_using_defaults(self) -> None:
- """Create a token using all the defaults."""
- channel = self.make_request(
- "POST",
- self.url + "/new",
- {},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(len(channel.json_body["token"]), 16)
- self.assertIsNone(channel.json_body["uses_allowed"])
- self.assertIsNone(channel.json_body["expiry_time"])
- self.assertEqual(channel.json_body["pending"], 0)
- self.assertEqual(channel.json_body["completed"], 0)
- def test_create_specifying_fields(self) -> None:
- """Create a token specifying the value of all fields."""
- # As many of the allowed characters as possible with length <= 64
- token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-"
- data = {
- "token": token,
- "uses_allowed": 1,
- "expiry_time": self.clock.time_msec() + 1000000,
- }
- channel = self.make_request(
- "POST",
- self.url + "/new",
- data,
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["token"], token)
- self.assertEqual(channel.json_body["uses_allowed"], 1)
- self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
- self.assertEqual(channel.json_body["pending"], 0)
- self.assertEqual(channel.json_body["completed"], 0)
- def test_create_with_null_value(self) -> None:
- """Create a token specifying unlimited uses and no expiry."""
- data = {
- "uses_allowed": None,
- "expiry_time": None,
- }
- channel = self.make_request(
- "POST",
- self.url + "/new",
- data,
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(len(channel.json_body["token"]), 16)
- self.assertIsNone(channel.json_body["uses_allowed"])
- self.assertIsNone(channel.json_body["expiry_time"])
- self.assertEqual(channel.json_body["pending"], 0)
- self.assertEqual(channel.json_body["completed"], 0)
- def test_create_token_too_long(self) -> None:
- """Check token longer than 64 chars is invalid."""
- data = {"token": "a" * 65}
- channel = self.make_request(
- "POST",
- self.url + "/new",
- data,
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- def test_create_token_invalid_chars(self) -> None:
- """Check you can't create token with invalid characters."""
- data = {
- "token": "abc/def",
- }
- channel = self.make_request(
- "POST",
- self.url + "/new",
- data,
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- def test_create_token_already_exists(self) -> None:
- """Check you can't create token that already exists."""
- data = {
- "token": "abcd",
- }
- channel1 = self.make_request(
- "POST",
- self.url + "/new",
- data,
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel1.code, msg=channel1.json_body)
- channel2 = self.make_request(
- "POST",
- self.url + "/new",
- data,
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel2.code, msg=channel2.json_body)
- self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
- def test_create_unable_to_generate_token(self) -> None:
- """Check right error is raised when server can't generate unique token."""
- # Create all possible single character tokens
- tokens = []
- for c in string.ascii_letters + string.digits + "._~-":
- tokens.append((c, None, 0, 0, None))
- self.get_success(
- self.store.db_pool.simple_insert_many(
- "registration_tokens",
- keys=("token", "uses_allowed", "pending", "completed", "expiry_time"),
- values=tokens,
- desc="create_all_registration_tokens",
- )
- )
- # Check creating a single character token fails with a 500 status code
- channel = self.make_request(
- "POST",
- self.url + "/new",
- {"length": 1},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(500, channel.code, msg=channel.json_body)
- def test_create_uses_allowed(self) -> None:
- """Check you can only create a token with good values for uses_allowed."""
- # Should work with 0 (token is invalid from the start)
- channel = self.make_request(
- "POST",
- self.url + "/new",
- {"uses_allowed": 0},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["uses_allowed"], 0)
- # Should fail with negative integer
- channel = self.make_request(
- "POST",
- self.url + "/new",
- {"uses_allowed": -5},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(
- 400,
- channel.code,
- msg=channel.json_body,
- )
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- # Should fail with float
- channel = self.make_request(
- "POST",
- self.url + "/new",
- {"uses_allowed": 1.5},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- def test_create_expiry_time(self) -> None:
- """Check you can't create a token with an invalid expiry_time."""
- # Should fail with a time in the past
- channel = self.make_request(
- "POST",
- self.url + "/new",
- {"expiry_time": self.clock.time_msec() - 10000},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- # Should fail with float
- channel = self.make_request(
- "POST",
- self.url + "/new",
- {"expiry_time": self.clock.time_msec() + 1000000.5},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- def test_create_length(self) -> None:
- """Check you can only generate a token with a valid length."""
- # Should work with 64
- channel = self.make_request(
- "POST",
- self.url + "/new",
- {"length": 64},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(len(channel.json_body["token"]), 64)
- # Should fail with 0
- channel = self.make_request(
- "POST",
- self.url + "/new",
- {"length": 0},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- # Should fail with a negative integer
- channel = self.make_request(
- "POST",
- self.url + "/new",
- {"length": -5},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- # Should fail with a float
- channel = self.make_request(
- "POST",
- self.url + "/new",
- {"length": 8.5},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- # Should fail with 65
- channel = self.make_request(
- "POST",
- self.url + "/new",
- {"length": 65},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- # UPDATING
- def test_update_no_auth(self) -> None:
- """Try to update a token without authentication."""
- channel = self.make_request(
- "PUT",
- self.url + "/1234", # Token doesn't exist but that doesn't matter
- {},
- )
- self.assertEqual(401, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_update_requester_not_admin(self) -> None:
- """Try to update a token while not an admin."""
- channel = self.make_request(
- "PUT",
- self.url + "/1234", # Token doesn't exist but that doesn't matter
- {},
- access_token=self.other_user_tok,
- )
- self.assertEqual(403, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_update_non_existent(self) -> None:
- """Try to update a token that doesn't exist."""
- channel = self.make_request(
- "PUT",
- self.url + "/1234",
- {"uses_allowed": 1},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
- def test_update_uses_allowed(self) -> None:
- """Test updating just uses_allowed."""
- # Create new token using default values
- token = self._new_token()
- # Should succeed with 1
- channel = self.make_request(
- "PUT",
- self.url + "/" + token,
- {"uses_allowed": 1},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["uses_allowed"], 1)
- self.assertIsNone(channel.json_body["expiry_time"])
- # Should succeed with 0 (makes token invalid)
- channel = self.make_request(
- "PUT",
- self.url + "/" + token,
- {"uses_allowed": 0},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["uses_allowed"], 0)
- self.assertIsNone(channel.json_body["expiry_time"])
- # Should succeed with null
- channel = self.make_request(
- "PUT",
- self.url + "/" + token,
- {"uses_allowed": None},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertIsNone(channel.json_body["uses_allowed"])
- self.assertIsNone(channel.json_body["expiry_time"])
- # Should fail with a float
- channel = self.make_request(
- "PUT",
- self.url + "/" + token,
- {"uses_allowed": 1.5},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- # Should fail with a negative integer
- channel = self.make_request(
- "PUT",
- self.url + "/" + token,
- {"uses_allowed": -5},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- def test_update_expiry_time(self) -> None:
- """Test updating just expiry_time."""
- # Create new token using default values
- token = self._new_token()
- new_expiry_time = self.clock.time_msec() + 1000000
- # Should succeed with a time in the future
- channel = self.make_request(
- "PUT",
- self.url + "/" + token,
- {"expiry_time": new_expiry_time},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
- self.assertIsNone(channel.json_body["uses_allowed"])
- # Should succeed with null
- channel = self.make_request(
- "PUT",
- self.url + "/" + token,
- {"expiry_time": None},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertIsNone(channel.json_body["expiry_time"])
- self.assertIsNone(channel.json_body["uses_allowed"])
- # Should fail with a time in the past
- past_time = self.clock.time_msec() - 10000
- channel = self.make_request(
- "PUT",
- self.url + "/" + token,
- {"expiry_time": past_time},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- # Should fail a float
- channel = self.make_request(
- "PUT",
- self.url + "/" + token,
- {"expiry_time": new_expiry_time + 0.5},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- def test_update_both(self) -> None:
- """Test updating both uses_allowed and expiry_time."""
- # Create new token using default values
- token = self._new_token()
- new_expiry_time = self.clock.time_msec() + 1000000
- data = {
- "uses_allowed": 1,
- "expiry_time": new_expiry_time,
- }
- channel = self.make_request(
- "PUT",
- self.url + "/" + token,
- data,
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["uses_allowed"], 1)
- self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
- def test_update_invalid_type(self) -> None:
- """Test using invalid types doesn't work."""
- # Create new token using default values
- token = self._new_token()
- data = {
- "uses_allowed": False,
- "expiry_time": "1626430124000",
- }
- channel = self.make_request(
- "PUT",
- self.url + "/" + token,
- data,
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
- # DELETING
- def test_delete_no_auth(self) -> None:
- """Try to delete a token without authentication."""
- channel = self.make_request(
- "DELETE",
- self.url + "/1234", # Token doesn't exist but that doesn't matter
- {},
- )
- self.assertEqual(401, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_delete_requester_not_admin(self) -> None:
- """Try to delete a token while not an admin."""
- channel = self.make_request(
- "DELETE",
- self.url + "/1234", # Token doesn't exist but that doesn't matter
- {},
- access_token=self.other_user_tok,
- )
- self.assertEqual(403, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_delete_non_existent(self) -> None:
- """Try to delete a token that doesn't exist."""
- channel = self.make_request(
- "DELETE",
- self.url + "/1234",
- {},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
- def test_delete(self) -> None:
- """Test deleting a token."""
- # Create new token using default values
- token = self._new_token()
- channel = self.make_request(
- "DELETE",
- self.url + "/" + token,
- {},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- # GETTING ONE
- def test_get_no_auth(self) -> None:
- """Try to get a token without authentication."""
- channel = self.make_request(
- "GET",
- self.url + "/1234", # Token doesn't exist but that doesn't matter
- {},
- )
- self.assertEqual(401, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_get_requester_not_admin(self) -> None:
- """Try to get a token while not an admin."""
- channel = self.make_request(
- "GET",
- self.url + "/1234", # Token doesn't exist but that doesn't matter
- {},
- access_token=self.other_user_tok,
- )
- self.assertEqual(
- 403,
- channel.code,
- msg=channel.json_body,
- )
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_get_non_existent(self) -> None:
- """Try to get a token that doesn't exist."""
- channel = self.make_request(
- "GET",
- self.url + "/1234",
- {},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
- def test_get(self) -> None:
- """Test getting a token."""
- # Create new token using default values
- token = self._new_token()
- channel = self.make_request(
- "GET",
- self.url + "/" + token,
- {},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(channel.json_body["token"], token)
- self.assertIsNone(channel.json_body["uses_allowed"])
- self.assertIsNone(channel.json_body["expiry_time"])
- self.assertEqual(channel.json_body["pending"], 0)
- self.assertEqual(channel.json_body["completed"], 0)
- # LISTING
- def test_list_no_auth(self) -> None:
- """Try to list tokens without authentication."""
- channel = self.make_request("GET", self.url, {})
- self.assertEqual(401, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_list_requester_not_admin(self) -> None:
- """Try to list tokens while not an admin."""
- channel = self.make_request(
- "GET",
- self.url,
- {},
- access_token=self.other_user_tok,
- )
- self.assertEqual(403, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_list_all(self) -> None:
- """Test listing all tokens."""
- # Create new token using default values
- token = self._new_token()
- channel = self.make_request(
- "GET",
- self.url,
- {},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
- token_info = channel.json_body["registration_tokens"][0]
- self.assertEqual(token_info["token"], token)
- self.assertIsNone(token_info["uses_allowed"])
- self.assertIsNone(token_info["expiry_time"])
- self.assertEqual(token_info["pending"], 0)
- self.assertEqual(token_info["completed"], 0)
- def test_list_invalid_query_parameter(self) -> None:
- """Test with `valid` query parameter not `true` or `false`."""
- channel = self.make_request(
- "GET",
- self.url + "?valid=x",
- {},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- def _test_list_query_parameter(self, valid: str) -> None:
- """Helper used to test both valid=true and valid=false."""
- # Create 2 valid and 2 invalid tokens.
- now = self.hs.get_clock().time_msec()
- # Create always valid token
- valid1 = self._new_token()
- # Create token that hasn't been used up
- valid2 = self._new_token(uses_allowed=1)
- # Create token that has expired
- invalid1 = self._new_token(expiry_time=now - 10000)
- # Create token that has been used up but hasn't expired
- invalid2 = self._new_token(
- uses_allowed=2,
- pending=1,
- completed=1,
- expiry_time=now + 1000000,
- )
- if valid == "true":
- tokens = [valid1, valid2]
- else:
- tokens = [invalid1, invalid2]
- channel = self.make_request(
- "GET",
- self.url + "?valid=" + valid,
- {},
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
- token_info_1 = channel.json_body["registration_tokens"][0]
- token_info_2 = channel.json_body["registration_tokens"][1]
- self.assertIn(token_info_1["token"], tokens)
- self.assertIn(token_info_2["token"], tokens)
- def test_list_valid(self) -> None:
- """Test listing just valid tokens."""
- self._test_list_query_parameter(valid="true")
- def test_list_invalid(self) -> None:
- """Test listing just invalid tokens."""
- self._test_list_query_parameter(valid="false")
|