test_devices.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
# Copyright 2016-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.api.errors

from tests.unittest import HomeserverTestCase
  16. class DeviceStoreTestCase(HomeserverTestCase):
  17. def prepare(self, reactor, clock, hs):
  18. self.store = hs.get_datastores().main
  19. def add_device_change(self, user_id, device_ids, host):
  20. """Add a device list change for the given device to
  21. `device_lists_outbound_pokes` table.
  22. """
  23. for device_id in device_ids:
  24. stream_id = self.get_success(
  25. self.store.add_device_change_to_streams(
  26. user_id, [device_id], ["!some:room"]
  27. )
  28. )
  29. self.get_success(
  30. self.store.add_device_list_outbound_pokes(
  31. user_id=user_id,
  32. device_id=device_id,
  33. room_id="!some:room",
  34. stream_id=stream_id,
  35. hosts=[host],
  36. context={},
  37. )
  38. )
  39. def test_store_new_device(self):
  40. self.get_success(
  41. self.store.store_device("user_id", "device_id", "display_name")
  42. )
  43. res = self.get_success(self.store.get_device("user_id", "device_id"))
  44. self.assertDictContainsSubset(
  45. {
  46. "user_id": "user_id",
  47. "device_id": "device_id",
  48. "display_name": "display_name",
  49. },
  50. res,
  51. )
  52. def test_get_devices_by_user(self):
  53. self.get_success(
  54. self.store.store_device("user_id", "device1", "display_name 1")
  55. )
  56. self.get_success(
  57. self.store.store_device("user_id", "device2", "display_name 2")
  58. )
  59. self.get_success(
  60. self.store.store_device("user_id2", "device3", "display_name 3")
  61. )
  62. res = self.get_success(self.store.get_devices_by_user("user_id"))
  63. self.assertEqual(2, len(res.keys()))
  64. self.assertDictContainsSubset(
  65. {
  66. "user_id": "user_id",
  67. "device_id": "device1",
  68. "display_name": "display_name 1",
  69. },
  70. res["device1"],
  71. )
  72. self.assertDictContainsSubset(
  73. {
  74. "user_id": "user_id",
  75. "device_id": "device2",
  76. "display_name": "display_name 2",
  77. },
  78. res["device2"],
  79. )
  80. def test_count_devices_by_users(self):
  81. self.get_success(
  82. self.store.store_device("user_id", "device1", "display_name 1")
  83. )
  84. self.get_success(
  85. self.store.store_device("user_id", "device2", "display_name 2")
  86. )
  87. self.get_success(
  88. self.store.store_device("user_id2", "device3", "display_name 3")
  89. )
  90. res = self.get_success(self.store.count_devices_by_users())
  91. self.assertEqual(0, res)
  92. res = self.get_success(self.store.count_devices_by_users(["unknown"]))
  93. self.assertEqual(0, res)
  94. res = self.get_success(self.store.count_devices_by_users(["user_id"]))
  95. self.assertEqual(2, res)
  96. res = self.get_success(
  97. self.store.count_devices_by_users(["user_id", "user_id2"])
  98. )
  99. self.assertEqual(3, res)
  100. def test_get_device_updates_by_remote(self):
  101. device_ids = ["device_id1", "device_id2"]
  102. # Add two device updates with sequential `stream_id`s
  103. self.add_device_change("@user_id:test", device_ids, "somehost")
  104. # Get all device updates ever meant for this remote
  105. now_stream_id, device_updates = self.get_success(
  106. self.store.get_device_updates_by_remote("somehost", -1, limit=100)
  107. )
  108. # Check original device_ids are contained within these updates
  109. self._check_devices_in_updates(device_ids, device_updates)
  110. def test_get_device_updates_by_remote_can_limit_properly(self):
  111. """
  112. Tests that `get_device_updates_by_remote` returns an appropriate
  113. stream_id to resume fetching from (without skipping any results).
  114. """
  115. # Add some device updates with sequential `stream_id`s
  116. device_ids = [
  117. "device_id1",
  118. "device_id2",
  119. "device_id3",
  120. "device_id4",
  121. "device_id5",
  122. ]
  123. self.add_device_change("@user_id:test", device_ids, "somehost")
  124. # Get device updates meant for this remote
  125. next_stream_id, device_updates = self.get_success(
  126. self.store.get_device_updates_by_remote("somehost", -1, limit=3)
  127. )
  128. # Check the first three original device_ids are contained within these updates
  129. self._check_devices_in_updates(device_ids[:3], device_updates)
  130. # Get the next batch of device updates
  131. next_stream_id, device_updates = self.get_success(
  132. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  133. )
  134. # Check the last two original device_ids are contained within these updates
  135. self._check_devices_in_updates(device_ids[3:], device_updates)
  136. # Add some more device updates to ensure it still resumes properly
  137. device_ids = ["device_id6", "device_id7"]
  138. self.add_device_change("@user_id:test", device_ids, "somehost")
  139. # Get the next batch of device updates
  140. next_stream_id, device_updates = self.get_success(
  141. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  142. )
  143. # Check the newly-added device_ids are contained within these updates
  144. self._check_devices_in_updates(device_ids, device_updates)
  145. # Check there are no more device updates left.
  146. _, device_updates = self.get_success(
  147. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  148. )
  149. self.assertEqual(device_updates, [])
  150. def test_get_device_updates_by_remote_cross_signing_key_updates(
  151. self,
  152. ) -> None:
  153. """
  154. Tests that `get_device_updates_by_remote` limits the length of the return value
  155. properly when cross-signing key updates are present.
  156. Current behaviour is that the cross-signing key updates will always come in pairs,
  157. even if that means leaving an earlier batch one EDU short of the limit.
  158. """
  159. assert self.hs.is_mine_id(
  160. "@user_id:test"
  161. ), "Test not valid: this MXID should be considered local"
  162. self.get_success(
  163. self.store.set_e2e_cross_signing_key(
  164. "@user_id:test",
  165. "master",
  166. {
  167. "keys": {
  168. "ed25519:fakeMaster": "aaafakefakefake1AAAAAAAAAAAAAAAAAAAAAAAAAAA="
  169. },
  170. "signatures": {
  171. "@user_id:test": {
  172. "ed25519:fake2": "aaafakefakefake2AAAAAAAAAAAAAAAAAAAAAAAAAAA="
  173. }
  174. },
  175. },
  176. )
  177. )
  178. self.get_success(
  179. self.store.set_e2e_cross_signing_key(
  180. "@user_id:test",
  181. "self_signing",
  182. {
  183. "keys": {
  184. "ed25519:fakeSelfSigning": "aaafakefakefake3AAAAAAAAAAAAAAAAAAAAAAAAAAA="
  185. },
  186. "signatures": {
  187. "@user_id:test": {
  188. "ed25519:fake4": "aaafakefakefake4AAAAAAAAAAAAAAAAAAAAAAAAAAA="
  189. }
  190. },
  191. },
  192. )
  193. )
  194. # Add some device updates with sequential `stream_id`s
  195. # Note that the public cross-signing keys occupy the same space as device IDs,
  196. # so also notify that those have updated.
  197. device_ids = [
  198. "device_id1",
  199. "device_id2",
  200. "fakeMaster",
  201. "fakeSelfSigning",
  202. ]
  203. self.add_device_change("@user_id:test", device_ids, "somehost")
  204. # Get device updates meant for this remote
  205. next_stream_id, device_updates = self.get_success(
  206. self.store.get_device_updates_by_remote("somehost", -1, limit=3)
  207. )
  208. # Here we expect the device updates for `device_id1` and `device_id2`.
  209. # That means we only receive 2 updates this time around.
  210. # If we had a higher limit, we would expect to see the pair of
  211. # (unstable-prefixed & unprefixed) signing key updates for the device
  212. # represented by `fakeMaster` and `fakeSelfSigning`.
  213. # Our implementation only sends these two variants together, so we get
  214. # a short batch.
  215. self.assertEqual(len(device_updates), 2, device_updates)
  216. # Check the first two devices (device_id1, device_id2) came out.
  217. self._check_devices_in_updates(device_ids[:2], device_updates)
  218. # Get more device updates meant for this remote
  219. next_stream_id, device_updates = self.get_success(
  220. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  221. )
  222. # The next 2 updates should be a cross-signing key update
  223. # (the master key update and the self-signing key update are combined into
  224. # one 'signing key update', but the cross-signing key update is emitted
  225. # twice, once with an unprefixed type and once again with an unstable-prefixed type)
  226. # (This is a temporary arrangement for backwards compatibility!)
  227. self.assertEqual(len(device_updates), 2, device_updates)
  228. self.assertEqual(
  229. device_updates[0][0], "m.signing_key_update", device_updates[0]
  230. )
  231. self.assertEqual(
  232. device_updates[1][0], "org.matrix.signing_key_update", device_updates[1]
  233. )
  234. # Check there are no more device updates left.
  235. _, device_updates = self.get_success(
  236. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  237. )
  238. self.assertEqual(device_updates, [])
  239. def _check_devices_in_updates(self, expected_device_ids, device_updates):
  240. """Check that an specific device ids exist in a list of device update EDUs"""
  241. self.assertEqual(len(device_updates), len(expected_device_ids))
  242. received_device_ids = {
  243. update["device_id"] for edu_type, update in device_updates
  244. }
  245. self.assertEqual(received_device_ids, set(expected_device_ids))
  246. def test_update_device(self):
  247. self.get_success(
  248. self.store.store_device("user_id", "device_id", "display_name 1")
  249. )
  250. res = self.get_success(self.store.get_device("user_id", "device_id"))
  251. self.assertEqual("display_name 1", res["display_name"])
  252. # do a no-op first
  253. self.get_success(self.store.update_device("user_id", "device_id"))
  254. res = self.get_success(self.store.get_device("user_id", "device_id"))
  255. self.assertEqual("display_name 1", res["display_name"])
  256. # do the update
  257. self.get_success(
  258. self.store.update_device(
  259. "user_id", "device_id", new_display_name="display_name 2"
  260. )
  261. )
  262. # check it worked
  263. res = self.get_success(self.store.get_device("user_id", "device_id"))
  264. self.assertEqual("display_name 2", res["display_name"])
  265. def test_update_unknown_device(self):
  266. exc = self.get_failure(
  267. self.store.update_device(
  268. "user_id", "unknown_device_id", new_display_name="display_name 2"
  269. ),
  270. synapse.api.errors.StoreError,
  271. )
  272. self.assertEqual(404, exc.value.code)