v2.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819
  1. # Copyright 2018 New Vector Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import heapq
  15. import itertools
  16. import logging
  17. from typing import (
  18. Any,
  19. Awaitable,
  20. Callable,
  21. Collection,
  22. Dict,
  23. Generator,
  24. Iterable,
  25. List,
  26. Mapping,
  27. Optional,
  28. Sequence,
  29. Set,
  30. Tuple,
  31. overload,
  32. )
  33. from typing_extensions import Literal, Protocol
  34. from synapse import event_auth
  35. from synapse.api.constants import EventTypes
  36. from synapse.api.errors import AuthError
  37. from synapse.api.room_versions import RoomVersion
  38. from synapse.events import EventBase
  39. from synapse.types import MutableStateMap, StateMap
  40. logger = logging.getLogger(__name__)
  41. class Clock(Protocol):
  42. # This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests.
  43. # We only ever sleep(0) though, so that other async functions can make forward
  44. # progress without waiting for stateres to complete.
  45. def sleep(self, duration_ms: float) -> Awaitable[None]:
  46. ...
  47. class StateResolutionStore(Protocol):
  48. # This is usually synapse.state.StateResolutionStore, but it's replaced with a
  49. # TestStateResolutionStore in tests.
  50. def get_events(
  51. self, event_ids: Collection[str], allow_rejected: bool = False
  52. ) -> Awaitable[Dict[str, EventBase]]:
  53. ...
  54. def get_auth_chain_difference(
  55. self, room_id: str, state_sets: List[Set[str]]
  56. ) -> Awaitable[Set[str]]:
  57. ...
  58. # State resolution over large data sets can keep the reactor busy for a
  59. # long time, so the loops in this module periodically yield control back to
  60. # it (via `clock.sleep(0)`) once every N iterations. This constant is N.
  61. _AWAIT_AFTER_ITERATIONS = 100
  62. __all__ = [
  63. "resolve_events_with_store",
  64. ]
  65. async def resolve_events_with_store(
  66. clock: Clock,
  67. room_id: str,
  68. room_version: RoomVersion,
  69. state_sets: Sequence[StateMap[str]],
  70. event_map: Optional[Dict[str, EventBase]],
  71. state_res_store: StateResolutionStore,
  72. ) -> StateMap[str]:
  73. """Resolves the state using the v2 state resolution algorithm
  74. Args:
  75. clock
  76. room_id: the room we are working in
  77. room_version: The room version
  78. state_sets: List of dicts of (type, state_key) -> event_id,
  79. which are the different state groups to resolve.
  80. event_map:
  81. a dict from event_id to event, for any events that we happen to
  82. have in flight (eg, those currently being persisted). This will be
  83. used as a starting point for finding the state we need; any missing
  84. events will be requested via state_res_store.
  85. If None, all events will be fetched via state_res_store.
  86. state_res_store:
  87. Returns:
  88. A map from (type, state_key) to event_id.
  89. """
  90. logger.debug("Computing conflicted state")
  91. # We use event_map as a cache, so if its None we need to initialize it
  92. if event_map is None:
  93. event_map = {}
  94. # First split up the un/conflicted state
  95. unconflicted_state, conflicted_state = _seperate(state_sets)
  96. if not conflicted_state:
  97. return unconflicted_state
  98. logger.debug("%d conflicted state entries", len(conflicted_state))
  99. logger.debug("Calculating auth chain difference")
  100. # Also fetch all auth events that appear in only some of the state sets'
  101. # auth chains.
  102. auth_diff = await _get_auth_chain_difference(
  103. room_id, state_sets, event_map, state_res_store
  104. )
  105. full_conflicted_set = set(
  106. itertools.chain(
  107. itertools.chain.from_iterable(conflicted_state.values()), auth_diff
  108. )
  109. )
  110. events = await state_res_store.get_events(
  111. [eid for eid in full_conflicted_set if eid not in event_map],
  112. allow_rejected=True,
  113. )
  114. event_map.update(events)
  115. # everything in the event map should be in the right room
  116. for event in event_map.values():
  117. if event.room_id != room_id:
  118. raise Exception(
  119. "Attempting to state-resolve for room %s with event %s which is in %s"
  120. % (
  121. room_id,
  122. event.event_id,
  123. event.room_id,
  124. )
  125. )
  126. full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map}
  127. logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
  128. # Get and sort all the power events (kicks/bans/etc)
  129. power_events = (
  130. eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
  131. )
  132. sorted_power_events = await _reverse_topological_power_sort(
  133. clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
  134. )
  135. logger.debug("sorted %d power events", len(sorted_power_events))
  136. # Now sequentially auth each one
  137. resolved_state = await _iterative_auth_checks(
  138. clock,
  139. room_id,
  140. room_version,
  141. sorted_power_events,
  142. unconflicted_state,
  143. event_map,
  144. state_res_store,
  145. )
  146. logger.debug("resolved power events")
  147. # OK, so we've now resolved the power events. Now sort the remaining
  148. # events using the mainline of the resolved power level.
  149. set_power_events = set(sorted_power_events)
  150. leftover_events = [
  151. ev_id for ev_id in full_conflicted_set if ev_id not in set_power_events
  152. ]
  153. logger.debug("sorting %d remaining events", len(leftover_events))
  154. pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
  155. leftover_events = await _mainline_sort(
  156. clock, room_id, leftover_events, pl, event_map, state_res_store
  157. )
  158. logger.debug("resolving remaining events")
  159. resolved_state = await _iterative_auth_checks(
  160. clock,
  161. room_id,
  162. room_version,
  163. leftover_events,
  164. resolved_state,
  165. event_map,
  166. state_res_store,
  167. )
  168. logger.debug("resolved")
  169. # We make sure that unconflicted state always still applies.
  170. resolved_state.update(unconflicted_state)
  171. logger.debug("done")
  172. return resolved_state
  173. async def _get_power_level_for_sender(
  174. room_id: str,
  175. event_id: str,
  176. event_map: Dict[str, EventBase],
  177. state_res_store: StateResolutionStore,
  178. ) -> int:
  179. """Return the power level of the sender of the given event according to
  180. their auth events.
  181. Args:
  182. room_id
  183. event_id
  184. event_map
  185. state_res_store
  186. Returns:
  187. The power level.
  188. """
  189. event = await _get_event(room_id, event_id, event_map, state_res_store)
  190. pl = None
  191. for aid in event.auth_event_ids():
  192. aev = await _get_event(
  193. room_id, aid, event_map, state_res_store, allow_none=True
  194. )
  195. if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
  196. pl = aev
  197. break
  198. if pl is None:
  199. # Couldn't find power level. Check if they're the creator of the room
  200. for aid in event.auth_event_ids():
  201. aev = await _get_event(
  202. room_id, aid, event_map, state_res_store, allow_none=True
  203. )
  204. if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
  205. if aev.content.get("creator") == event.sender:
  206. return 100
  207. break
  208. return 0
  209. level = pl.content.get("users", {}).get(event.sender)
  210. if level is None:
  211. level = pl.content.get("users_default", 0)
  212. if level is None:
  213. return 0
  214. else:
  215. return int(level)
  216. async def _get_auth_chain_difference(
  217. room_id: str,
  218. state_sets: Sequence[Mapping[Any, str]],
  219. event_map: Dict[str, EventBase],
  220. state_res_store: StateResolutionStore,
  221. ) -> Set[str]:
  222. """Compare the auth chains of each state set and return the set of events
  223. that only appear in some but not all of the auth chains.
  224. Args:
  225. state_sets
  226. event_map
  227. state_res_store
  228. Returns:
  229. Set of event IDs
  230. """
  231. # The `StateResolutionStore.get_auth_chain_difference` function assumes that
  232. # all events passed to it (and their auth chains) have been persisted
  233. # previously. This is not the case for any events in the `event_map`, and so
  234. # we need to manually handle those events.
  235. #
  236. # We do this by:
  237. # 1. calculating the auth chain difference for the state sets based on the
  238. # events in `event_map` alone
  239. # 2. replacing any events in the state_sets that are also in `event_map`
  240. # with their auth events (recursively), and then calling
  241. # `store.get_auth_chain_difference` as normal
  242. # 3. adding the results of 1 and 2 together.
  243. # Map from event ID in `event_map` to their auth event IDs, and their auth
  244. # event IDs if they appear in the `event_map`. This is the intersection of
  245. # the event's auth chain with the events in the `event_map` *plus* their
  246. # auth event IDs.
  247. events_to_auth_chain: Dict[str, Set[str]] = {}
  248. for event in event_map.values():
  249. chain = {event.event_id}
  250. events_to_auth_chain[event.event_id] = chain
  251. to_search = [event]
  252. while to_search:
  253. for auth_id in to_search.pop().auth_event_ids():
  254. chain.add(auth_id)
  255. auth_event = event_map.get(auth_id)
  256. if auth_event:
  257. to_search.append(auth_event)
  258. # We now a) calculate the auth chain difference for the unpersisted events
  259. # and b) work out the state sets to pass to the store.
  260. #
  261. # Note: If the `event_map` is empty (which is the common case), we can do a
  262. # much simpler calculation.
  263. if event_map:
  264. # The list of state sets to pass to the store, where each state set is a set
  265. # of the event ids making up the state. This is similar to `state_sets`,
  266. # except that (a) we only have event ids, not the complete
  267. # ((type, state_key)->event_id) mappings; and (b) we have stripped out
  268. # unpersisted events and replaced them with the persisted events in
  269. # their auth chain.
  270. state_sets_ids: List[Set[str]] = []
  271. # For each state set, the unpersisted event IDs reachable (by their auth
  272. # chain) from the events in that set.
  273. unpersisted_set_ids: List[Set[str]] = []
  274. for state_set in state_sets:
  275. set_ids: Set[str] = set()
  276. state_sets_ids.append(set_ids)
  277. unpersisted_ids: Set[str] = set()
  278. unpersisted_set_ids.append(unpersisted_ids)
  279. for event_id in state_set.values():
  280. event_chain = events_to_auth_chain.get(event_id)
  281. if event_chain is not None:
  282. # We have an event in `event_map`. We add all the auth
  283. # events that it references (that aren't also in `event_map`).
  284. set_ids.update(e for e in event_chain if e not in event_map)
  285. # We also add the full chain of unpersisted event IDs
  286. # referenced by this state set, so that we can work out the
  287. # auth chain difference of the unpersisted events.
  288. unpersisted_ids.update(e for e in event_chain if e in event_map)
  289. else:
  290. set_ids.add(event_id)
  291. # The auth chain difference of the unpersisted events of the state sets
  292. # is calculated by taking the difference between the union and
  293. # intersections.
  294. union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
  295. intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
  296. difference_from_event_map: Collection[str] = union - intersection
  297. else:
  298. difference_from_event_map = ()
  299. state_sets_ids = [set(state_set.values()) for state_set in state_sets]
  300. difference = await state_res_store.get_auth_chain_difference(
  301. room_id, state_sets_ids
  302. )
  303. difference.update(difference_from_event_map)
  304. return difference
  305. def _seperate(
  306. state_sets: Iterable[StateMap[str]],
  307. ) -> Tuple[StateMap[str], StateMap[Set[str]]]:
  308. """Return the unconflicted and conflicted state. This is different than in
  309. the original algorithm, as this defines a key to be conflicted if one of
  310. the state sets doesn't have that key.
  311. Args:
  312. state_sets
  313. Returns:
  314. A tuple of unconflicted and conflicted state. The conflicted state dict
  315. is a map from type/state_key to set of event IDs
  316. """
  317. unconflicted_state = {}
  318. conflicted_state = {}
  319. for key in set(itertools.chain.from_iterable(state_sets)):
  320. event_ids = {state_set.get(key) for state_set in state_sets}
  321. if len(event_ids) == 1:
  322. unconflicted_state[key] = event_ids.pop()
  323. else:
  324. event_ids.discard(None)
  325. conflicted_state[key] = event_ids
  326. # mypy doesn't understand that discarding None above means that conflicted
  327. # state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
  328. return unconflicted_state, conflicted_state # type: ignore
  329. def _is_power_event(event: EventBase) -> bool:
  330. """Return whether or not the event is a "power event", as defined by the
  331. v2 state resolution algorithm
  332. Args:
  333. event
  334. Returns:
  335. True if the event is a power event.
  336. """
  337. if (event.type, event.state_key) in (
  338. (EventTypes.PowerLevels, ""),
  339. (EventTypes.JoinRules, ""),
  340. (EventTypes.Create, ""),
  341. ):
  342. return True
  343. if event.type == EventTypes.Member:
  344. if event.membership in ("leave", "ban"):
  345. return event.sender != event.state_key
  346. return False
  347. async def _add_event_and_auth_chain_to_graph(
  348. graph: Dict[str, Set[str]],
  349. room_id: str,
  350. event_id: str,
  351. event_map: Dict[str, EventBase],
  352. state_res_store: StateResolutionStore,
  353. auth_diff: Set[str],
  354. ) -> None:
  355. """Helper function for _reverse_topological_power_sort that add the event
  356. and its auth chain (that is in the auth diff) to the graph
  357. Args:
  358. graph: A map from event ID to the events auth event IDs
  359. room_id: the room we are working in
  360. event_id: Event to add to the graph
  361. event_map
  362. state_res_store
  363. auth_diff: Set of event IDs that are in the auth difference.
  364. """
  365. state = [event_id]
  366. while state:
  367. eid = state.pop()
  368. graph.setdefault(eid, set())
  369. event = await _get_event(room_id, eid, event_map, state_res_store)
  370. for aid in event.auth_event_ids():
  371. if aid in auth_diff:
  372. if aid not in graph:
  373. state.append(aid)
  374. graph.setdefault(eid, set()).add(aid)
  375. async def _reverse_topological_power_sort(
  376. clock: Clock,
  377. room_id: str,
  378. event_ids: Iterable[str],
  379. event_map: Dict[str, EventBase],
  380. state_res_store: StateResolutionStore,
  381. auth_diff: Set[str],
  382. ) -> List[str]:
  383. """Returns a list of the event_ids sorted by reverse topological ordering,
  384. and then by power level and origin_server_ts
  385. Args:
  386. clock
  387. room_id: the room we are working in
  388. event_ids: The events to sort
  389. event_map
  390. state_res_store
  391. auth_diff: Set of event IDs that are in the auth difference.
  392. Returns:
  393. The sorted list
  394. """
  395. graph: Dict[str, Set[str]] = {}
  396. for idx, event_id in enumerate(event_ids, start=1):
  397. await _add_event_and_auth_chain_to_graph(
  398. graph, room_id, event_id, event_map, state_res_store, auth_diff
  399. )
  400. # We await occasionally when we're working with large data sets to
  401. # ensure that we don't block the reactor loop for too long.
  402. if idx % _AWAIT_AFTER_ITERATIONS == 0:
  403. await clock.sleep(0)
  404. event_to_pl = {}
  405. for idx, event_id in enumerate(graph, start=1):
  406. pl = await _get_power_level_for_sender(
  407. room_id, event_id, event_map, state_res_store
  408. )
  409. event_to_pl[event_id] = pl
  410. # We await occasionally when we're working with large data sets to
  411. # ensure that we don't block the reactor loop for too long.
  412. if idx % _AWAIT_AFTER_ITERATIONS == 0:
  413. await clock.sleep(0)
  414. def _get_power_order(event_id: str) -> Tuple[int, int, str]:
  415. ev = event_map[event_id]
  416. pl = event_to_pl[event_id]
  417. return -pl, ev.origin_server_ts, event_id
  418. # Note: graph is modified during the sort
  419. it = lexicographical_topological_sort(graph, key=_get_power_order)
  420. sorted_events = list(it)
  421. return sorted_events
  422. async def _iterative_auth_checks(
  423. clock: Clock,
  424. room_id: str,
  425. room_version: RoomVersion,
  426. event_ids: List[str],
  427. base_state: StateMap[str],
  428. event_map: Dict[str, EventBase],
  429. state_res_store: StateResolutionStore,
  430. ) -> MutableStateMap[str]:
  431. """Sequentially apply auth checks to each event in given list, updating the
  432. state as it goes along.
  433. Args:
  434. clock
  435. room_id
  436. room_version
  437. event_ids: Ordered list of events to apply auth checks to
  438. base_state: The set of state to start with
  439. event_map
  440. state_res_store
  441. Returns:
  442. Returns the final updated state
  443. """
  444. resolved_state = dict(base_state)
  445. for idx, event_id in enumerate(event_ids, start=1):
  446. event = event_map[event_id]
  447. auth_events = {}
  448. for aid in event.auth_event_ids():
  449. ev = await _get_event(
  450. room_id, aid, event_map, state_res_store, allow_none=True
  451. )
  452. if not ev:
  453. logger.warning(
  454. "auth_event id %s for event %s is missing", aid, event_id
  455. )
  456. else:
  457. if ev.rejected_reason is None:
  458. auth_events[(ev.type, ev.state_key)] = ev
  459. for key in event_auth.auth_types_for_event(room_version, event):
  460. if key in resolved_state:
  461. ev_id = resolved_state[key]
  462. ev = await _get_event(room_id, ev_id, event_map, state_res_store)
  463. if ev.rejected_reason is None:
  464. auth_events[key] = event_map[ev_id]
  465. try:
  466. event_auth.check_state_dependent_auth_rules(
  467. event,
  468. auth_events.values(),
  469. )
  470. resolved_state[(event.type, event.state_key)] = event_id
  471. except AuthError:
  472. pass
  473. # We await occasionally when we're working with large data sets to
  474. # ensure that we don't block the reactor loop for too long.
  475. if idx % _AWAIT_AFTER_ITERATIONS == 0:
  476. await clock.sleep(0)
  477. return resolved_state
  478. async def _mainline_sort(
  479. clock: Clock,
  480. room_id: str,
  481. event_ids: List[str],
  482. resolved_power_event_id: Optional[str],
  483. event_map: Dict[str, EventBase],
  484. state_res_store: StateResolutionStore,
  485. ) -> List[str]:
  486. """Returns a sorted list of event_ids sorted by mainline ordering based on
  487. the given event resolved_power_event_id
  488. Args:
  489. clock
  490. room_id: room we're working in
  491. event_ids: Events to sort
  492. resolved_power_event_id: The final resolved power level event ID
  493. event_map
  494. state_res_store
  495. Returns:
  496. The sorted list
  497. """
  498. if not event_ids:
  499. # It's possible for there to be no event IDs here to sort, so we can
  500. # skip calculating the mainline in that case.
  501. return []
  502. mainline = []
  503. pl = resolved_power_event_id
  504. idx = 0
  505. while pl:
  506. mainline.append(pl)
  507. pl_ev = await _get_event(room_id, pl, event_map, state_res_store)
  508. auth_events = pl_ev.auth_event_ids()
  509. pl = None
  510. for aid in auth_events:
  511. ev = await _get_event(
  512. room_id, aid, event_map, state_res_store, allow_none=True
  513. )
  514. if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
  515. pl = aid
  516. break
  517. # We await occasionally when we're working with large data sets to
  518. # ensure that we don't block the reactor loop for too long.
  519. if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0:
  520. await clock.sleep(0)
  521. idx += 1
  522. mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}
  523. event_ids = list(event_ids)
  524. order_map = {}
  525. for idx, ev_id in enumerate(event_ids, start=1):
  526. depth = await _get_mainline_depth_for_event(
  527. event_map[ev_id], mainline_map, event_map, state_res_store
  528. )
  529. order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
  530. # We await occasionally when we're working with large data sets to
  531. # ensure that we don't block the reactor loop for too long.
  532. if idx % _AWAIT_AFTER_ITERATIONS == 0:
  533. await clock.sleep(0)
  534. event_ids.sort(key=lambda ev_id: order_map[ev_id])
  535. return event_ids
  536. async def _get_mainline_depth_for_event(
  537. event: EventBase,
  538. mainline_map: Dict[str, int],
  539. event_map: Dict[str, EventBase],
  540. state_res_store: StateResolutionStore,
  541. ) -> int:
  542. """Get the mainline depths for the given event based on the mainline map
  543. Args:
  544. event
  545. mainline_map: Map from event_id to mainline depth for events in the mainline.
  546. event_map
  547. state_res_store
  548. Returns:
  549. The mainline depth
  550. """
  551. room_id = event.room_id
  552. tmp_event: Optional[EventBase] = event
  553. # We do an iterative search, replacing `event with the power level in its
  554. # auth events (if any)
  555. while tmp_event:
  556. depth = mainline_map.get(tmp_event.event_id)
  557. if depth is not None:
  558. return depth
  559. auth_events = tmp_event.auth_event_ids()
  560. tmp_event = None
  561. for aid in auth_events:
  562. aev = await _get_event(
  563. room_id, aid, event_map, state_res_store, allow_none=True
  564. )
  565. if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
  566. tmp_event = aev
  567. break
  568. # Didn't find a power level auth event, so we just return 0
  569. return 0
  570. @overload
  571. async def _get_event(
  572. room_id: str,
  573. event_id: str,
  574. event_map: Dict[str, EventBase],
  575. state_res_store: StateResolutionStore,
  576. allow_none: Literal[False] = False,
  577. ) -> EventBase:
  578. ...
  579. @overload
  580. async def _get_event(
  581. room_id: str,
  582. event_id: str,
  583. event_map: Dict[str, EventBase],
  584. state_res_store: StateResolutionStore,
  585. allow_none: Literal[True],
  586. ) -> Optional[EventBase]:
  587. ...
  588. async def _get_event(
  589. room_id: str,
  590. event_id: str,
  591. event_map: Dict[str, EventBase],
  592. state_res_store: StateResolutionStore,
  593. allow_none: bool = False,
  594. ) -> Optional[EventBase]:
  595. """Helper function to look up event in event_map, falling back to looking
  596. it up in the store
  597. Args:
  598. room_id
  599. event_id
  600. event_map
  601. state_res_store
  602. allow_none: if the event is not found, return None rather than raising
  603. an exception
  604. Returns:
  605. The event, or none if the event does not exist (and allow_none is True).
  606. """
  607. if event_id not in event_map:
  608. events = await state_res_store.get_events([event_id], allow_rejected=True)
  609. event_map.update(events)
  610. event = event_map.get(event_id)
  611. if event is None:
  612. if allow_none:
  613. return None
  614. raise Exception("Unknown event %s" % (event_id,))
  615. if event.room_id != room_id:
  616. raise Exception(
  617. "In state res for room %s, event %s is in %s"
  618. % (room_id, event_id, event.room_id)
  619. )
  620. return event
  621. def lexicographical_topological_sort(
  622. graph: Dict[str, Set[str]], key: Callable[[str], Any]
  623. ) -> Generator[str, None, None]:
  624. """Performs a lexicographic reverse topological sort on the graph.
  625. This returns a reverse topological sort (i.e. if node A references B then B
  626. appears before A in the sort), with ties broken lexicographically based on
  627. return value of the `key` function.
  628. NOTE: `graph` is modified during the sort.
  629. Args:
  630. graph: A representation of the graph where each node is a key in the
  631. dict and its value are the nodes edges.
  632. key: A function that takes a node and returns a value that is comparable
  633. and used to order nodes
  634. Yields:
  635. The next node in the topological sort
  636. """
  637. # Note, this is basically Kahn's algorithm except we look at nodes with no
  638. # outgoing edges, c.f.
  639. # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
  640. outdegree_map = graph
  641. reverse_graph: Dict[str, Set[str]] = {}
  642. # Lists of nodes with zero out degree. Is actually a tuple of
  643. # `(key(node), node)` so that sorting does the right thing
  644. zero_outdegree = []
  645. for node, edges in graph.items():
  646. if len(edges) == 0:
  647. zero_outdegree.append((key(node), node))
  648. reverse_graph.setdefault(node, set())
  649. for edge in edges:
  650. reverse_graph.setdefault(edge, set()).add(node)
  651. # heapq is a built in implementation of a sorted queue.
  652. heapq.heapify(zero_outdegree)
  653. while zero_outdegree:
  654. _, node = heapq.heappop(zero_outdegree)
  655. for parent in reverse_graph[node]:
  656. out = outdegree_map[parent]
  657. out.discard(node)
  658. if len(out) == 0:
  659. heapq.heappush(zero_outdegree, (key(parent), parent))
  660. yield node