test_device.py 12 KB


  1. # Copyright 2016 OpenMarket Ltd
  2. # Copyright 2018 New Vector Ltd
  3. # Copyright 2020 The Matrix.org Foundation C.I.C.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
from typing import Any, Mapping, Optional

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.errors import NotFoundError, SynapseError
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.server import HomeServer
from synapse.util import Clock

from tests import unittest
  23. user1 = "@boris:aaa"
  24. user2 = "@theresa:bbb"
  25. class DeviceTestCase(unittest.HomeserverTestCase):
  26. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  27. hs = self.setup_test_homeserver("server", federation_http_client=None)
  28. handler = hs.get_device_handler()
  29. assert isinstance(handler, DeviceHandler)
  30. self.handler = handler
  31. self.store = hs.get_datastores().main
  32. return hs
  33. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  34. # These tests assume that it starts 1000 seconds in.
  35. self.reactor.advance(1000)
  36. def test_device_is_created_with_invalid_name(self) -> None:
  37. self.get_failure(
  38. self.handler.check_device_registered(
  39. user_id="@boris:foo",
  40. device_id="foo",
  41. initial_device_display_name="a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1),
  42. ),
  43. SynapseError,
  44. )
  45. def test_device_is_created_if_doesnt_exist(self) -> None:
  46. res = self.get_success(
  47. self.handler.check_device_registered(
  48. user_id="@boris:foo",
  49. device_id="fco",
  50. initial_device_display_name="display name",
  51. )
  52. )
  53. self.assertEqual(res, "fco")
  54. dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
  55. assert dev is not None
  56. self.assertEqual(dev["display_name"], "display name")
  57. def test_device_is_preserved_if_exists(self) -> None:
  58. res1 = self.get_success(
  59. self.handler.check_device_registered(
  60. user_id="@boris:foo",
  61. device_id="fco",
  62. initial_device_display_name="display name",
  63. )
  64. )
  65. self.assertEqual(res1, "fco")
  66. res2 = self.get_success(
  67. self.handler.check_device_registered(
  68. user_id="@boris:foo",
  69. device_id="fco",
  70. initial_device_display_name="new display name",
  71. )
  72. )
  73. self.assertEqual(res2, "fco")
  74. dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
  75. assert dev is not None
  76. self.assertEqual(dev["display_name"], "display name")
  77. def test_device_id_is_made_up_if_unspecified(self) -> None:
  78. device_id = self.get_success(
  79. self.handler.check_device_registered(
  80. user_id="@theresa:foo",
  81. device_id=None,
  82. initial_device_display_name="display",
  83. )
  84. )
  85. dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
  86. assert dev is not None
  87. self.assertEqual(dev["display_name"], "display")
  88. def test_get_devices_by_user(self) -> None:
  89. self._record_users()
  90. res = self.get_success(self.handler.get_devices_by_user(user1))
  91. self.assertEqual(3, len(res))
  92. device_map = {d["device_id"]: d for d in res}
  93. self.assertDictContainsSubset(
  94. {
  95. "user_id": user1,
  96. "device_id": "xyz",
  97. "display_name": "display 0",
  98. "last_seen_ip": None,
  99. "last_seen_ts": None,
  100. },
  101. device_map["xyz"],
  102. )
  103. self.assertDictContainsSubset(
  104. {
  105. "user_id": user1,
  106. "device_id": "fco",
  107. "display_name": "display 1",
  108. "last_seen_ip": "ip1",
  109. "last_seen_ts": 1000000,
  110. },
  111. device_map["fco"],
  112. )
  113. self.assertDictContainsSubset(
  114. {
  115. "user_id": user1,
  116. "device_id": "abc",
  117. "display_name": "display 2",
  118. "last_seen_ip": "ip3",
  119. "last_seen_ts": 3000000,
  120. },
  121. device_map["abc"],
  122. )
  123. def test_get_device(self) -> None:
  124. self._record_users()
  125. res = self.get_success(self.handler.get_device(user1, "abc"))
  126. self.assertDictContainsSubset(
  127. {
  128. "user_id": user1,
  129. "device_id": "abc",
  130. "display_name": "display 2",
  131. "last_seen_ip": "ip3",
  132. "last_seen_ts": 3000000,
  133. },
  134. res,
  135. )
  136. def test_delete_device(self) -> None:
  137. self._record_users()
  138. # delete the device
  139. self.get_success(self.handler.delete_devices(user1, ["abc"]))
  140. # check the device was deleted
  141. self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError)
  142. # we'd like to check the access token was invalidated, but that's a
  143. # bit of a PITA.
  144. def test_delete_device_and_device_inbox(self) -> None:
  145. self._record_users()
  146. # add an device_inbox
  147. self.get_success(
  148. self.store.db_pool.simple_insert(
  149. "device_inbox",
  150. {
  151. "user_id": user1,
  152. "device_id": "abc",
  153. "stream_id": 1,
  154. "message_json": "{}",
  155. },
  156. )
  157. )
  158. # delete the device
  159. self.get_success(self.handler.delete_devices(user1, ["abc"]))
  160. # check that the device_inbox was deleted
  161. res = self.get_success(
  162. self.store.db_pool.simple_select_one(
  163. table="device_inbox",
  164. keyvalues={"user_id": user1, "device_id": "abc"},
  165. retcols=("user_id", "device_id"),
  166. allow_none=True,
  167. desc="get_device_id_from_device_inbox",
  168. )
  169. )
  170. self.assertIsNone(res)
  171. def test_update_device(self) -> None:
  172. self._record_users()
  173. update = {"display_name": "new display"}
  174. self.get_success(self.handler.update_device(user1, "abc", update))
  175. res = self.get_success(self.handler.get_device(user1, "abc"))
  176. self.assertEqual(res["display_name"], "new display")
  177. def test_update_device_too_long_display_name(self) -> None:
  178. """Update a device with a display name that is invalid (too long)."""
  179. self._record_users()
  180. # Request to update a device display name with a new value that is longer than allowed.
  181. update = {"display_name": "a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1)}
  182. self.get_failure(
  183. self.handler.update_device(user1, "abc", update),
  184. SynapseError,
  185. )
  186. # Ensure the display name was not updated.
  187. res = self.get_success(self.handler.get_device(user1, "abc"))
  188. self.assertEqual(res["display_name"], "display 2")
  189. def test_update_unknown_device(self) -> None:
  190. update = {"display_name": "new_display"}
  191. self.get_failure(
  192. self.handler.update_device("user_id", "unknown_device_id", update),
  193. NotFoundError,
  194. )
  195. def _record_users(self) -> None:
  196. # check this works for both devices which have a recorded client_ip,
  197. # and those which don't.
  198. self._record_user(user1, "xyz", "display 0")
  199. self._record_user(user1, "fco", "display 1", "token1", "ip1")
  200. self._record_user(user1, "abc", "display 2", "token2", "ip2")
  201. self._record_user(user1, "abc", "display 2", "token3", "ip3")
  202. self._record_user(user2, "def", "dispkay", "token4", "ip4")
  203. self.reactor.advance(10000)
  204. def _record_user(
  205. self,
  206. user_id: str,
  207. device_id: str,
  208. display_name: str,
  209. access_token: Optional[str] = None,
  210. ip: Optional[str] = None,
  211. ) -> None:
  212. device_id = self.get_success(
  213. self.handler.check_device_registered(
  214. user_id=user_id,
  215. device_id=device_id,
  216. initial_device_display_name=display_name,
  217. )
  218. )
  219. if access_token is not None and ip is not None:
  220. self.get_success(
  221. self.store.insert_client_ip(
  222. user_id, access_token, ip, "user_agent", device_id
  223. )
  224. )
  225. self.reactor.advance(1000)
  226. class DehydrationTestCase(unittest.HomeserverTestCase):
  227. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  228. hs = self.setup_test_homeserver("server", federation_http_client=None)
  229. handler = hs.get_device_handler()
  230. assert isinstance(handler, DeviceHandler)
  231. self.handler = handler
  232. self.registration = hs.get_registration_handler()
  233. self.auth = hs.get_auth()
  234. self.store = hs.get_datastores().main
  235. return hs
  236. def test_dehydrate_and_rehydrate_device(self) -> None:
  237. user_id = "@boris:dehydration"
  238. self.get_success(self.store.register_user(user_id, "foobar"))
  239. # First check if we can store and fetch a dehydrated device
  240. stored_dehydrated_device_id = self.get_success(
  241. self.handler.store_dehydrated_device(
  242. user_id=user_id,
  243. device_data={"device_data": {"foo": "bar"}},
  244. initial_device_display_name="dehydrated device",
  245. )
  246. )
  247. result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
  248. assert result is not None
  249. retrieved_device_id, device_data = result
  250. self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
  251. self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
  252. # Create a new login for the user and dehydrated the device
  253. device_id, access_token, _expiration_time, _refresh_token = self.get_success(
  254. self.registration.register_device(
  255. user_id=user_id,
  256. device_id=None,
  257. initial_display_name="new device",
  258. )
  259. )
  260. # Trying to claim a nonexistent device should throw an error
  261. self.get_failure(
  262. self.handler.rehydrate_device(
  263. user_id=user_id,
  264. access_token=access_token,
  265. device_id="not the right device ID",
  266. ),
  267. NotFoundError,
  268. )
  269. # dehydrating the right devices should succeed and change our device ID
  270. # to the dehydrated device's ID
  271. res = self.get_success(
  272. self.handler.rehydrate_device(
  273. user_id=user_id,
  274. access_token=access_token,
  275. device_id=retrieved_device_id,
  276. )
  277. )
  278. self.assertEqual(res, {"success": True})
  279. # make sure that our device ID has changed
  280. user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
  281. self.assertEqual(user_info.device_id, retrieved_device_id)
  282. # make sure the device has the display name that was set from the login
  283. res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
  284. self.assertEqual(res["display_name"], "new device")
  285. # make sure that the device ID that we were initially assigned no longer exists
  286. self.get_failure(
  287. self.handler.get_device(user_id, device_id),
  288. NotFoundError,
  289. )
  290. # make sure that there's no device available for dehydrating now
  291. ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
  292. self.assertIsNone(ret)