partial_state_events_tracker.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. # Copyright 2022 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from collections import defaultdict
  16. from typing import Collection, Dict, Set
  17. from twisted.internet import defer
  18. from twisted.internet.defer import Deferred
  19. from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
  20. from synapse.storage.databases.main.events_worker import EventsWorkerStore
  21. from synapse.storage.databases.main.room import RoomWorkerStore
  22. from synapse.util import unwrapFirstError
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  24. class PartialStateEventsTracker:
  25. """Keeps track of which events have partial state, after a partial-state join"""
  26. def __init__(self, store: EventsWorkerStore):
  27. self._store = store
  28. # a map from event id to a set of Deferreds which are waiting for that event to be
  29. # un-partial-stated.
  30. self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set)
  31. def notify_un_partial_stated(self, event_id: str) -> None:
  32. """Notify that we now have full state for a given event
  33. Called by the state-resynchronization loop whenever we resynchronize the state
  34. for a particular event. Unblocks any callers to await_full_state() for that
  35. event.
  36. Args:
  37. event_id: the event that now has full state.
  38. """
  39. observers = self._observers.pop(event_id, None)
  40. if not observers:
  41. return
  42. logger.info(
  43. "Notifying %i things waiting for un-partial-stating of event %s",
  44. len(observers),
  45. event_id,
  46. )
  47. with PreserveLoggingContext():
  48. for o in observers:
  49. o.callback(None)
  50. async def await_full_state(self, event_ids: Collection[str]) -> None:
  51. """Wait for all the given events to have full state.
  52. Args:
  53. event_ids: the list of event ids that we want full state for
  54. """
  55. # first try the happy path: if there are no partial-state events, we can return
  56. # quickly
  57. partial_state_event_ids = [
  58. ev
  59. for ev, p in (await self._store.get_partial_state_events(event_ids)).items()
  60. if p
  61. ]
  62. if not partial_state_event_ids:
  63. return
  64. logger.info(
  65. "Awaiting un-partial-stating of events %s",
  66. partial_state_event_ids,
  67. stack_info=True,
  68. )
  69. # create an observer for each lazy-joined event
  70. observers: Dict[str, Deferred[None]] = {
  71. event_id: Deferred() for event_id in partial_state_event_ids
  72. }
  73. for event_id, observer in observers.items():
  74. self._observers[event_id].add(observer)
  75. try:
  76. # some of them may have been un-lazy-joined between us checking the db and
  77. # registering the observer, in which case we'd wait forever for the
  78. # notification. Call back the observers now.
  79. for event_id, partial in (
  80. await self._store.get_partial_state_events(observers.keys())
  81. ).items():
  82. # there may have been a call to notify_un_partial_stated during the
  83. # db query, so the observers may already have been called.
  84. if not partial and not observers[event_id].called:
  85. observers[event_id].callback(None)
  86. await make_deferred_yieldable(
  87. defer.gatherResults(
  88. observers.values(),
  89. consumeErrors=True,
  90. )
  91. ).addErrback(unwrapFirstError)
  92. logger.info("Events %s all un-partial-stated", observers.keys())
  93. finally:
  94. # remove any observers we created. This should happen when the notification
  95. # is received, but that might not happen for two reasons:
  96. # (a) we're bailing out early on an exception (including us being
  97. # cancelled during the await)
  98. # (b) the event got de-lazy-joined before we set up the observer.
  99. for event_id, observer in observers.items():
  100. observer_set = self._observers.get(event_id)
  101. if observer_set:
  102. observer_set.discard(observer)
  103. if not observer_set:
  104. del self._observers[event_id]
  105. class PartialCurrentStateTracker:
  106. """Keeps track of which rooms have partial state, after partial-state joins"""
  107. def __init__(self, store: RoomWorkerStore):
  108. self._store = store
  109. # a map from room id to a set of Deferreds which are waiting for that room to be
  110. # un-partial-stated.
  111. self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set)
  112. def notify_un_partial_stated(self, room_id: str) -> None:
  113. """Notify that we now have full current state for a given room
  114. Unblocks any callers to await_full_state() for that room.
  115. Args:
  116. room_id: the room that now has full current state.
  117. """
  118. observers = self._observers.pop(room_id, None)
  119. if not observers:
  120. return
  121. logger.info(
  122. "Notifying %i things waiting for un-partial-stating of room %s",
  123. len(observers),
  124. room_id,
  125. )
  126. with PreserveLoggingContext():
  127. for o in observers:
  128. o.callback(None)
  129. async def await_full_state(self, room_id: str) -> None:
  130. # We add the deferred immediately so that the DB call to check for
  131. # partial state doesn't race when we unpartial the room.
  132. d: Deferred[None] = Deferred()
  133. self._observers.setdefault(room_id, set()).add(d)
  134. try:
  135. # Check if the room has partial current state or not.
  136. has_partial_state = await self._store.is_partial_state_room(room_id)
  137. if not has_partial_state:
  138. return
  139. logger.info(
  140. "Awaiting un-partial-stating of room %s",
  141. room_id,
  142. )
  143. await make_deferred_yieldable(d)
  144. logger.info("Room has un-partial-stated")
  145. finally:
  146. # Remove the added observer, and remove the room entry if its empty.
  147. ds = self._observers.get(room_id)
  148. if ds is not None:
  149. ds.discard(d)
  150. if not ds:
  151. self._observers.pop(room_id, None)