test_partial_state_events_tracker.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import Dict
  15. from unittest import mock
  16. from twisted.internet.defer import CancelledError, ensureDeferred
  17. from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
  18. from tests.unittest import TestCase
  19. class PartialStateEventsTrackerTestCase(TestCase):
  20. def setUp(self) -> None:
  21. # the results to be returned by the mocked get_partial_state_events
  22. self._events_dict: Dict[str, bool] = {}
  23. async def get_partial_state_events(events):
  24. return {e: self._events_dict[e] for e in events}
  25. self.mock_store = mock.Mock(spec_set=["get_partial_state_events"])
  26. self.mock_store.get_partial_state_events.side_effect = get_partial_state_events
  27. self.tracker = PartialStateEventsTracker(self.mock_store)
  28. def test_does_not_block_for_full_state_events(self):
  29. self._events_dict = {"event1": False, "event2": False}
  30. self.successResultOf(
  31. ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
  32. )
  33. self.mock_store.get_partial_state_events.assert_called_once_with(
  34. ["event1", "event2"]
  35. )
  36. def test_blocks_for_partial_state_events(self):
  37. self._events_dict = {"event1": True, "event2": False}
  38. d = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
  39. # there should be no result yet
  40. self.assertNoResult(d)
  41. # notifying that the event has been de-partial-stated should unblock
  42. self.tracker.notify_un_partial_stated("event1")
  43. self.successResultOf(d)
  44. def test_un_partial_state_race(self):
  45. # if the event is un-partial-stated between the initial check and the
  46. # registration of the listener, it should not block.
  47. self._events_dict = {"event1": True, "event2": False}
  48. async def get_partial_state_events(events):
  49. res = {e: self._events_dict[e] for e in events}
  50. # change the result for next time
  51. self._events_dict = {"event1": False, "event2": False}
  52. return res
  53. self.mock_store.get_partial_state_events.side_effect = get_partial_state_events
  54. self.successResultOf(
  55. ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
  56. )
  57. def test_un_partial_state_during_get_partial_state_events(self):
  58. # we should correctly handle a call to notify_un_partial_stated during the
  59. # second call to get_partial_state_events.
  60. self._events_dict = {"event1": True, "event2": False}
  61. async def get_partial_state_events1(events):
  62. self.mock_store.get_partial_state_events.side_effect = (
  63. get_partial_state_events2
  64. )
  65. return {e: self._events_dict[e] for e in events}
  66. async def get_partial_state_events2(events):
  67. self.tracker.notify_un_partial_stated("event1")
  68. self._events_dict["event1"] = False
  69. return {e: self._events_dict[e] for e in events}
  70. self.mock_store.get_partial_state_events.side_effect = get_partial_state_events1
  71. self.successResultOf(
  72. ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
  73. )
  74. def test_cancellation(self):
  75. self._events_dict = {"event1": True, "event2": False}
  76. d1 = ensureDeferred(self.tracker.await_full_state(["event1", "event2"]))
  77. self.assertNoResult(d1)
  78. d2 = ensureDeferred(self.tracker.await_full_state(["event1"]))
  79. self.assertNoResult(d2)
  80. d1.cancel()
  81. self.assertFailure(d1, CancelledError)
  82. # d2 should still be waiting!
  83. self.assertNoResult(d2)
  84. self.tracker.notify_un_partial_stated("event1")
  85. self.successResultOf(d2)