test_devices.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. # Copyright 2016-2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import synapse.api.errors
  15. from tests.unittest import HomeserverTestCase
  16. class DeviceStoreTestCase(HomeserverTestCase):
  17. def prepare(self, reactor, clock, hs):
  18. self.store = hs.get_datastore()
  19. def test_store_new_device(self):
  20. self.get_success(
  21. self.store.store_device("user_id", "device_id", "display_name")
  22. )
  23. res = self.get_success(self.store.get_device("user_id", "device_id"))
  24. self.assertDictContainsSubset(
  25. {
  26. "user_id": "user_id",
  27. "device_id": "device_id",
  28. "display_name": "display_name",
  29. },
  30. res,
  31. )
  32. def test_get_devices_by_user(self):
  33. self.get_success(
  34. self.store.store_device("user_id", "device1", "display_name 1")
  35. )
  36. self.get_success(
  37. self.store.store_device("user_id", "device2", "display_name 2")
  38. )
  39. self.get_success(
  40. self.store.store_device("user_id2", "device3", "display_name 3")
  41. )
  42. res = self.get_success(self.store.get_devices_by_user("user_id"))
  43. self.assertEqual(2, len(res.keys()))
  44. self.assertDictContainsSubset(
  45. {
  46. "user_id": "user_id",
  47. "device_id": "device1",
  48. "display_name": "display_name 1",
  49. },
  50. res["device1"],
  51. )
  52. self.assertDictContainsSubset(
  53. {
  54. "user_id": "user_id",
  55. "device_id": "device2",
  56. "display_name": "display_name 2",
  57. },
  58. res["device2"],
  59. )
  60. def test_count_devices_by_users(self):
  61. self.get_success(
  62. self.store.store_device("user_id", "device1", "display_name 1")
  63. )
  64. self.get_success(
  65. self.store.store_device("user_id", "device2", "display_name 2")
  66. )
  67. self.get_success(
  68. self.store.store_device("user_id2", "device3", "display_name 3")
  69. )
  70. res = self.get_success(self.store.count_devices_by_users())
  71. self.assertEqual(0, res)
  72. res = self.get_success(self.store.count_devices_by_users(["unknown"]))
  73. self.assertEqual(0, res)
  74. res = self.get_success(self.store.count_devices_by_users(["user_id"]))
  75. self.assertEqual(2, res)
  76. res = self.get_success(
  77. self.store.count_devices_by_users(["user_id", "user_id2"])
  78. )
  79. self.assertEqual(3, res)
  80. def test_get_device_updates_by_remote(self):
  81. device_ids = ["device_id1", "device_id2"]
  82. # Add two device updates with sequential `stream_id`s
  83. self.get_success(
  84. self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
  85. )
  86. # Get all device updates ever meant for this remote
  87. now_stream_id, device_updates = self.get_success(
  88. self.store.get_device_updates_by_remote("somehost", -1, limit=100)
  89. )
  90. # Check original device_ids are contained within these updates
  91. self._check_devices_in_updates(device_ids, device_updates)
  92. def test_get_device_updates_by_remote_can_limit_properly(self):
  93. """
  94. Tests that `get_device_updates_by_remote` returns an appropriate
  95. stream_id to resume fetching from (without skipping any results).
  96. """
  97. # Add some device updates with sequential `stream_id`s
  98. device_ids = [
  99. "device_id1",
  100. "device_id2",
  101. "device_id3",
  102. "device_id4",
  103. "device_id5",
  104. ]
  105. self.get_success(
  106. self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
  107. )
  108. # Get device updates meant for this remote
  109. next_stream_id, device_updates = self.get_success(
  110. self.store.get_device_updates_by_remote("somehost", -1, limit=3)
  111. )
  112. # Check the first three original device_ids are contained within these updates
  113. self._check_devices_in_updates(device_ids[:3], device_updates)
  114. # Get the next batch of device updates
  115. next_stream_id, device_updates = self.get_success(
  116. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  117. )
  118. # Check the last two original device_ids are contained within these updates
  119. self._check_devices_in_updates(device_ids[3:], device_updates)
  120. # Add some more device updates to ensure it still resumes properly
  121. device_ids = ["device_id6", "device_id7"]
  122. self.get_success(
  123. self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
  124. )
  125. # Get the next batch of device updates
  126. next_stream_id, device_updates = self.get_success(
  127. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  128. )
  129. # Check the newly-added device_ids are contained within these updates
  130. self._check_devices_in_updates(device_ids, device_updates)
  131. # Check there are no more device updates left.
  132. _, device_updates = self.get_success(
  133. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  134. )
  135. self.assertEqual(device_updates, [])
  136. def test_get_device_updates_by_remote_cross_signing_key_updates(
  137. self,
  138. ) -> None:
  139. """
  140. Tests that `get_device_updates_by_remote` limits the length of the return value
  141. properly when cross-signing key updates are present.
  142. Current behaviour is that the cross-signing key updates will always come in pairs,
  143. even if that means leaving an earlier batch one EDU short of the limit.
  144. """
  145. assert self.hs.is_mine_id(
  146. "@user_id:test"
  147. ), "Test not valid: this MXID should be considered local"
  148. self.get_success(
  149. self.store.set_e2e_cross_signing_key(
  150. "@user_id:test",
  151. "master",
  152. {
  153. "keys": {
  154. "ed25519:fakeMaster": "aaafakefakefake1AAAAAAAAAAAAAAAAAAAAAAAAAAA="
  155. },
  156. "signatures": {
  157. "@user_id:test": {
  158. "ed25519:fake2": "aaafakefakefake2AAAAAAAAAAAAAAAAAAAAAAAAAAA="
  159. }
  160. },
  161. },
  162. )
  163. )
  164. self.get_success(
  165. self.store.set_e2e_cross_signing_key(
  166. "@user_id:test",
  167. "self_signing",
  168. {
  169. "keys": {
  170. "ed25519:fakeSelfSigning": "aaafakefakefake3AAAAAAAAAAAAAAAAAAAAAAAAAAA="
  171. },
  172. "signatures": {
  173. "@user_id:test": {
  174. "ed25519:fake4": "aaafakefakefake4AAAAAAAAAAAAAAAAAAAAAAAAAAA="
  175. }
  176. },
  177. },
  178. )
  179. )
  180. # Add some device updates with sequential `stream_id`s
  181. # Note that the public cross-signing keys occupy the same space as device IDs,
  182. # so also notify that those have updated.
  183. device_ids = [
  184. "device_id1",
  185. "device_id2",
  186. "fakeMaster",
  187. "fakeSelfSigning",
  188. ]
  189. self.get_success(
  190. self.store.add_device_change_to_streams(
  191. "@user_id:test", device_ids, ["somehost"]
  192. )
  193. )
  194. # Get device updates meant for this remote
  195. next_stream_id, device_updates = self.get_success(
  196. self.store.get_device_updates_by_remote("somehost", -1, limit=3)
  197. )
  198. # Here we expect the device updates for `device_id1` and `device_id2`.
  199. # That means we only receive 2 updates this time around.
  200. # If we had a higher limit, we would expect to see the pair of
  201. # (unstable-prefixed & unprefixed) signing key updates for the device
  202. # represented by `fakeMaster` and `fakeSelfSigning`.
  203. # Our implementation only sends these two variants together, so we get
  204. # a short batch.
  205. self.assertEqual(len(device_updates), 2, device_updates)
  206. # Check the first two devices (device_id1, device_id2) came out.
  207. self._check_devices_in_updates(device_ids[:2], device_updates)
  208. # Get more device updates meant for this remote
  209. next_stream_id, device_updates = self.get_success(
  210. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  211. )
  212. # The next 2 updates should be a cross-signing key update
  213. # (the master key update and the self-signing key update are combined into
  214. # one 'signing key update', but the cross-signing key update is emitted
  215. # twice, once with an unprefixed type and once again with an unstable-prefixed type)
  216. # (This is a temporary arrangement for backwards compatibility!)
  217. self.assertEqual(len(device_updates), 2, device_updates)
  218. self.assertEqual(
  219. device_updates[0][0], "m.signing_key_update", device_updates[0]
  220. )
  221. self.assertEqual(
  222. device_updates[1][0], "org.matrix.signing_key_update", device_updates[1]
  223. )
  224. # Check there are no more device updates left.
  225. _, device_updates = self.get_success(
  226. self.store.get_device_updates_by_remote("somehost", next_stream_id, limit=3)
  227. )
  228. self.assertEqual(device_updates, [])
  229. def _check_devices_in_updates(self, expected_device_ids, device_updates):
  230. """Check that an specific device ids exist in a list of device update EDUs"""
  231. self.assertEqual(len(device_updates), len(expected_device_ids))
  232. received_device_ids = {
  233. update["device_id"] for edu_type, update in device_updates
  234. }
  235. self.assertEqual(received_device_ids, set(expected_device_ids))
  236. def test_update_device(self):
  237. self.get_success(
  238. self.store.store_device("user_id", "device_id", "display_name 1")
  239. )
  240. res = self.get_success(self.store.get_device("user_id", "device_id"))
  241. self.assertEqual("display_name 1", res["display_name"])
  242. # do a no-op first
  243. self.get_success(self.store.update_device("user_id", "device_id"))
  244. res = self.get_success(self.store.get_device("user_id", "device_id"))
  245. self.assertEqual("display_name 1", res["display_name"])
  246. # do the update
  247. self.get_success(
  248. self.store.update_device(
  249. "user_id", "device_id", new_display_name="display_name 2"
  250. )
  251. )
  252. # check it worked
  253. res = self.get_success(self.store.get_device("user_id", "device_id"))
  254. self.assertEqual("display_name 2", res["display_name"])
  255. def test_update_unknown_device(self):
  256. exc = self.get_failure(
  257. self.store.update_device(
  258. "user_id", "unknown_device_id", new_display_name="display_name 2"
  259. ),
  260. synapse.api.errors.StoreError,
  261. )
  262. self.assertEqual(404, exc.value.code)