test_devices.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338
  1. # Copyright 2016-2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Collection, List, Tuple
  15. from twisted.test.proto_helpers import MemoryReactor
  16. import synapse.api.errors
  17. from synapse.api.constants import EduTypes
  18. from synapse.server import HomeServer
  19. from synapse.types import JsonDict
  20. from synapse.util import Clock
  21. from tests.unittest import HomeserverTestCase
  22. class DeviceStoreTestCase(HomeserverTestCase):
  23. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  24. self.store = hs.get_datastores().main
  25. def add_device_change(self, user_id: str, device_ids: List[str], host: str) -> None:
  26. """Add a device list change for the given device to
  27. `device_lists_outbound_pokes` table.
  28. """
  29. for device_id in device_ids:
  30. self.get_success(
  31. self.store.add_device_change_to_streams(
  32. user_id, [device_id], ["!some:room"]
  33. )
  34. )
  35. self.get_success(
  36. self.store.add_device_list_outbound_pokes(
  37. user_id=user_id,
  38. device_id=device_id,
  39. room_id="!some:room",
  40. hosts=[host],
  41. context={},
  42. )
  43. )
  44. def test_store_new_device(self) -> None:
  45. self.get_success(
  46. self.store.store_device("user_id", "device_id", "display_name")
  47. )
  48. res = self.get_success(self.store.get_device("user_id", "device_id"))
  49. assert res is not None
  50. self.assertDictContainsSubset(
  51. {
  52. "user_id": "user_id",
  53. "device_id": "device_id",
  54. "display_name": "display_name",
  55. },
  56. res,
  57. )
  58. def test_get_devices_by_user(self) -> None:
  59. self.get_success(
  60. self.store.store_device("user_id", "device1", "display_name 1")
  61. )
  62. self.get_success(
  63. self.store.store_device("user_id", "device2", "display_name 2")
  64. )
  65. self.get_success(
  66. self.store.store_device("user_id2", "device3", "display_name 3")
  67. )
  68. res = self.get_success(self.store.get_devices_by_user("user_id"))
  69. self.assertEqual(2, len(res.keys()))
  70. self.assertDictContainsSubset(
  71. {
  72. "user_id": "user_id",
  73. "device_id": "device1",
  74. "display_name": "display_name 1",
  75. },
  76. res["device1"],
  77. )
  78. self.assertDictContainsSubset(
  79. {
  80. "user_id": "user_id",
  81. "device_id": "device2",
  82. "display_name": "display_name 2",
  83. },
  84. res["device2"],
  85. )
  86. def test_count_devices_by_users(self) -> None:
  87. self.get_success(
  88. self.store.store_device("user_id", "device1", "display_name 1")
  89. )
  90. self.get_success(
  91. self.store.store_device("user_id", "device2", "display_name 2")
  92. )
  93. self.get_success(
  94. self.store.store_device("user_id2", "device3", "display_name 3")
  95. )
  96. res = self.get_success(self.store.count_devices_by_users())
  97. self.assertEqual(0, res)
  98. res = self.get_success(self.store.count_devices_by_users(["unknown"]))
  99. self.assertEqual(0, res)
  100. res = self.get_success(self.store.count_devices_by_users(["user_id"]))
  101. self.assertEqual(2, res)
  102. res = self.get_success(
  103. self.store.count_devices_by_users(["user_id", "user_id2"])
  104. )
  105. self.assertEqual(3, res)
  106. def test_get_device_updates_by_remote(self) -> None:
  107. device_ids = ["device_id1", "device_id2"]
  108. # Add two device updates with sequential `stream_id`s
  109. self.add_device_change("@user_id:test", device_ids, "somehost")
  110. # Get all device updates ever meant for this remote
  111. now_stream_id, device_updates = self.get_success(
  112. self.store.get_device_updates_by_remote("somehost", -1, limit=100)
  113. )
  114. # Check original device_ids are contained within these updates
  115. self._check_devices_in_updates(device_ids, device_updates)
  116. def test_get_device_updates_by_remote_can_limit_properly(self) -> None:
  117. """
  118. Tests that `get_device_updates_by_remote` returns an appropriate
  119. stream_id to resume fetching from (without skipping any results).
  120. """
  121. # Add some device updates with sequential `stream_id`s
  122. device_ids = [
  123. "device_id1",
  124. "device_id2",
  125. "device_id3",
  126. "device_id4",
  127. "device_id5",
  128. ]
  129. self.add_device_change("@user_id:test", device_ids, "somehost")
  130. # Get device updates meant for this remote
  131. next_stream_id, device_updates = self.get_success(
  132. self.store.get_device_updates_by_remote("somehost", -1, limit=3)
  133. )
  134. # Check the first three original device_ids are contained within these updates
  135. self._check_devices_in_updates(device_ids[:3], device_updates)
  136. # Get the next batch of device updates
  137. next_stream_id, device_updates = self.get_success(
  138. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  139. )
  140. # Check the last two original device_ids are contained within these updates
  141. self._check_devices_in_updates(device_ids[3:], device_updates)
  142. # Add some more device updates to ensure it still resumes properly
  143. device_ids = ["device_id6", "device_id7"]
  144. self.add_device_change("@user_id:test", device_ids, "somehost")
  145. # Get the next batch of device updates
  146. next_stream_id, device_updates = self.get_success(
  147. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  148. )
  149. # Check the newly-added device_ids are contained within these updates
  150. self._check_devices_in_updates(device_ids, device_updates)
  151. # Check there are no more device updates left.
  152. _, device_updates = self.get_success(
  153. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  154. )
  155. self.assertEqual(device_updates, [])
  156. def test_get_device_updates_by_remote_cross_signing_key_updates(
  157. self,
  158. ) -> None:
  159. """
  160. Tests that `get_device_updates_by_remote` limits the length of the return value
  161. properly when cross-signing key updates are present.
  162. Current behaviour is that the cross-signing key updates will always come in pairs,
  163. even if that means leaving an earlier batch one EDU short of the limit.
  164. """
  165. assert self.hs.is_mine_id(
  166. "@user_id:test"
  167. ), "Test not valid: this MXID should be considered local"
  168. self.get_success(
  169. self.store.set_e2e_cross_signing_key(
  170. "@user_id:test",
  171. "master",
  172. {
  173. "keys": {
  174. "ed25519:fakeMaster": "aaafakefakefake1AAAAAAAAAAAAAAAAAAAAAAAAAAA="
  175. },
  176. "signatures": {
  177. "@user_id:test": {
  178. "ed25519:fake2": "aaafakefakefake2AAAAAAAAAAAAAAAAAAAAAAAAAAA="
  179. }
  180. },
  181. },
  182. )
  183. )
  184. self.get_success(
  185. self.store.set_e2e_cross_signing_key(
  186. "@user_id:test",
  187. "self_signing",
  188. {
  189. "keys": {
  190. "ed25519:fakeSelfSigning": "aaafakefakefake3AAAAAAAAAAAAAAAAAAAAAAAAAAA="
  191. },
  192. "signatures": {
  193. "@user_id:test": {
  194. "ed25519:fake4": "aaafakefakefake4AAAAAAAAAAAAAAAAAAAAAAAAAAA="
  195. }
  196. },
  197. },
  198. )
  199. )
  200. # Add some device updates with sequential `stream_id`s
  201. # Note that the public cross-signing keys occupy the same space as device IDs,
  202. # so also notify that those have updated.
  203. device_ids = [
  204. "device_id1",
  205. "device_id2",
  206. "fakeMaster",
  207. "fakeSelfSigning",
  208. ]
  209. self.add_device_change("@user_id:test", device_ids, "somehost")
  210. # Get device updates meant for this remote
  211. next_stream_id, device_updates = self.get_success(
  212. self.store.get_device_updates_by_remote("somehost", -1, limit=3)
  213. )
  214. # Here we expect the device updates for `device_id1` and `device_id2`.
  215. # That means we only receive 2 updates this time around.
  216. # If we had a higher limit, we would expect to see the pair of
  217. # (unstable-prefixed & unprefixed) signing key updates for the device
  218. # represented by `fakeMaster` and `fakeSelfSigning`.
  219. # Our implementation only sends these two variants together, so we get
  220. # a short batch.
  221. self.assertEqual(len(device_updates), 2, device_updates)
  222. # Check the first two devices (device_id1, device_id2) came out.
  223. self._check_devices_in_updates(device_ids[:2], device_updates)
  224. # Get more device updates meant for this remote
  225. next_stream_id, device_updates = self.get_success(
  226. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  227. )
  228. # The next 2 updates should be a cross-signing key update
  229. # (the master key update and the self-signing key update are combined into
  230. # one 'signing key update', but the cross-signing key update is emitted
  231. # twice, once with an unprefixed type and once again with an unstable-prefixed type)
  232. # (This is a temporary arrangement for backwards compatibility!)
  233. self.assertEqual(len(device_updates), 2, device_updates)
  234. self.assertEqual(
  235. device_updates[0][0], EduTypes.SIGNING_KEY_UPDATE, device_updates[0]
  236. )
  237. self.assertEqual(
  238. device_updates[1][0],
  239. EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
  240. device_updates[1],
  241. )
  242. # Check there are no more device updates left.
  243. _, device_updates = self.get_success(
  244. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  245. )
  246. self.assertEqual(device_updates, [])
  247. def _check_devices_in_updates(
  248. self,
  249. expected_device_ids: Collection[str],
  250. device_updates: List[Tuple[str, JsonDict]],
  251. ) -> None:
  252. """Check that an specific device ids exist in a list of device update EDUs"""
  253. self.assertEqual(len(device_updates), len(expected_device_ids))
  254. received_device_ids = {
  255. update["device_id"] for edu_type, update in device_updates
  256. }
  257. self.assertEqual(received_device_ids, set(expected_device_ids))
  258. def test_update_device(self) -> None:
  259. self.get_success(
  260. self.store.store_device("user_id", "device_id", "display_name 1")
  261. )
  262. res = self.get_success(self.store.get_device("user_id", "device_id"))
  263. assert res is not None
  264. self.assertEqual("display_name 1", res["display_name"])
  265. # do a no-op first
  266. self.get_success(self.store.update_device("user_id", "device_id"))
  267. res = self.get_success(self.store.get_device("user_id", "device_id"))
  268. assert res is not None
  269. self.assertEqual("display_name 1", res["display_name"])
  270. # do the update
  271. self.get_success(
  272. self.store.update_device(
  273. "user_id", "device_id", new_display_name="display_name 2"
  274. )
  275. )
  276. # check it worked
  277. res = self.get_success(self.store.get_device("user_id", "device_id"))
  278. assert res is not None
  279. self.assertEqual("display_name 2", res["display_name"])
  280. def test_update_unknown_device(self) -> None:
  281. exc = self.get_failure(
  282. self.store.update_device(
  283. "user_id", "unknown_device_id", new_display_name="display_name 2"
  284. ),
  285. synapse.api.errors.StoreError,
  286. )
  287. self.assertEqual(404, exc.value.code)