1
0

test_devices.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from twisted.internet import defer
  16. import synapse.api.errors
  17. import tests.unittest
  18. import tests.utils
  19. class DeviceStoreTestCase(tests.unittest.TestCase):
  20. def __init__(self, *args, **kwargs):
  21. super().__init__(*args, **kwargs)
  22. self.store = None # type: synapse.storage.DataStore
  23. @defer.inlineCallbacks
  24. def setUp(self):
  25. hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
  26. self.store = hs.get_datastore()
  27. @defer.inlineCallbacks
  28. def test_store_new_device(self):
  29. yield defer.ensureDeferred(
  30. self.store.store_device("user_id", "device_id", "display_name")
  31. )
  32. res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
  33. self.assertDictContainsSubset(
  34. {
  35. "user_id": "user_id",
  36. "device_id": "device_id",
  37. "display_name": "display_name",
  38. },
  39. res,
  40. )
  41. @defer.inlineCallbacks
  42. def test_get_devices_by_user(self):
  43. yield defer.ensureDeferred(
  44. self.store.store_device("user_id", "device1", "display_name 1")
  45. )
  46. yield defer.ensureDeferred(
  47. self.store.store_device("user_id", "device2", "display_name 2")
  48. )
  49. yield defer.ensureDeferred(
  50. self.store.store_device("user_id2", "device3", "display_name 3")
  51. )
  52. res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
  53. self.assertEqual(2, len(res.keys()))
  54. self.assertDictContainsSubset(
  55. {
  56. "user_id": "user_id",
  57. "device_id": "device1",
  58. "display_name": "display_name 1",
  59. },
  60. res["device1"],
  61. )
  62. self.assertDictContainsSubset(
  63. {
  64. "user_id": "user_id",
  65. "device_id": "device2",
  66. "display_name": "display_name 2",
  67. },
  68. res["device2"],
  69. )
  70. @defer.inlineCallbacks
  71. def test_count_devices_by_users(self):
  72. yield defer.ensureDeferred(
  73. self.store.store_device("user_id", "device1", "display_name 1")
  74. )
  75. yield defer.ensureDeferred(
  76. self.store.store_device("user_id", "device2", "display_name 2")
  77. )
  78. yield defer.ensureDeferred(
  79. self.store.store_device("user_id2", "device3", "display_name 3")
  80. )
  81. res = yield defer.ensureDeferred(self.store.count_devices_by_users())
  82. self.assertEqual(0, res)
  83. res = yield defer.ensureDeferred(self.store.count_devices_by_users(["unknown"]))
  84. self.assertEqual(0, res)
  85. res = yield defer.ensureDeferred(self.store.count_devices_by_users(["user_id"]))
  86. self.assertEqual(2, res)
  87. res = yield defer.ensureDeferred(
  88. self.store.count_devices_by_users(["user_id", "user_id2"])
  89. )
  90. self.assertEqual(3, res)
  91. @defer.inlineCallbacks
  92. def test_get_device_updates_by_remote(self):
  93. device_ids = ["device_id1", "device_id2"]
  94. # Add two device updates with a single stream_id
  95. yield defer.ensureDeferred(
  96. self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
  97. )
  98. # Get all device updates ever meant for this remote
  99. now_stream_id, device_updates = yield defer.ensureDeferred(
  100. self.store.get_device_updates_by_remote("somehost", -1, limit=100)
  101. )
  102. # Check original device_ids are contained within these updates
  103. self._check_devices_in_updates(device_ids, device_updates)
  104. def _check_devices_in_updates(self, expected_device_ids, device_updates):
  105. """Check that an specific device ids exist in a list of device update EDUs"""
  106. self.assertEqual(len(device_updates), len(expected_device_ids))
  107. received_device_ids = {
  108. update["device_id"] for edu_type, update in device_updates
  109. }
  110. self.assertEqual(received_device_ids, set(expected_device_ids))
  111. @defer.inlineCallbacks
  112. def test_update_device(self):
  113. yield defer.ensureDeferred(
  114. self.store.store_device("user_id", "device_id", "display_name 1")
  115. )
  116. res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
  117. self.assertEqual("display_name 1", res["display_name"])
  118. # do a no-op first
  119. yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
  120. res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
  121. self.assertEqual("display_name 1", res["display_name"])
  122. # do the update
  123. yield defer.ensureDeferred(
  124. self.store.update_device(
  125. "user_id", "device_id", new_display_name="display_name 2"
  126. )
  127. )
  128. # check it worked
  129. res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
  130. self.assertEqual("display_name 2", res["display_name"])
  131. @defer.inlineCallbacks
  132. def test_update_unknown_device(self):
  133. with self.assertRaises(synapse.api.errors.StoreError) as cm:
  134. yield defer.ensureDeferred(
  135. self.store.update_device(
  136. "user_id", "unknown_device_id", new_display_name="display_name 2"
  137. )
  138. )
  139. self.assertEqual(404, cm.exception.code)