# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import heapq
import itertools
import logging
from typing import (
    Any,
    Callable,
    Collection,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    overload,
)

from typing_extensions import Literal

import synapse.state
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap
from synapse.util import Clock

logger = logging.getLogger(__name__)

# We want to yield to the reactor occasionally during state res when dealing
# with large data sets, so that we don't block the reactor for too long. This
# is done by awaiting `clock.sleep(0)` every N iterations during loops.
_AWAIT_AFTER_ITERATIONS = 100


async def resolve_events_with_store(
    clock: Clock,
    room_id: str,
    room_version: RoomVersion,
    state_sets: Sequence[StateMap[str]],
    event_map: Optional[Dict[str, EventBase]],
    state_res_store: "synapse.state.StateResolutionStore",
) -> StateMap[str]:
    """Resolves the state using the v2 state resolution algorithm

    Args:
        clock
        room_id: the room we are working in
        room_version: The room version
        state_sets: List of dicts of (type, state_key) -> event_id,
            which are the different state groups to resolve.
        event_map:
            a dict from event_id to event, for any events that we happen to
            have in flight (eg, those currently being persisted). This will be
            used as a starting point for finding the state we need; any missing
            events will be requested via state_res_store.

            If None, all events will be fetched via state_res_store.
        state_res_store:

    Returns:
        A map from (type, state_key) to event_id.
    """

    logger.debug("Computing conflicted state")

    # We use event_map as a cache, so if it's None we need to initialize it
    if event_map is None:
        event_map = {}

    # First split up the un/conflicted state
    unconflicted_state, conflicted_state = _seperate(state_sets)

    if not conflicted_state:
        return unconflicted_state

    logger.debug("%d conflicted state entries", len(conflicted_state))
    logger.debug("Calculating auth chain difference")

    # Also fetch all auth events that appear in only some of the state sets'
    # auth chains.
    auth_diff = await _get_auth_chain_difference(
        room_id, state_sets, event_map, state_res_store
    )

    full_conflicted_set = set(
        itertools.chain(
            itertools.chain.from_iterable(conflicted_state.values()), auth_diff
        )
    )

    events = await state_res_store.get_events(
        [eid for eid in full_conflicted_set if eid not in event_map],
        allow_rejected=True,
    )
    event_map.update(events)

    # everything in the event map should be in the right room
    for event in event_map.values():
        if event.room_id != room_id:
            raise Exception(
                "Attempting to state-resolve for room %s with event %s which is in %s"
                % (
                    room_id,
                    event.event_id,
                    event.room_id,
                )
            )

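    # Drop any conflicted events that we failed to fetch from the store; we
    # cannot auth or sort events we don't actually have.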
    full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map}

    logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))

    # Get and sort all the power events (kicks/bans/etc)
    power_events = (
        eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
    )

    sorted_power_events = await _reverse_topological_power_sort(
        clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
    )

    logger.debug("sorted %d power events", len(sorted_power_events))

    # Now sequentially auth each one
    resolved_state = await _iterative_auth_checks(
        clock,
        room_id,
        room_version,
        sorted_power_events,
        unconflicted_state,
        event_map,
        state_res_store,
    )

    logger.debug("resolved power events")

    # OK, so we've now resolved the power events. Now sort the remaining
    # events using the mainline of the resolved power level.
    set_power_events = set(sorted_power_events)
    leftover_events = [
        ev_id for ev_id in full_conflicted_set if ev_id not in set_power_events
    ]

    logger.debug("sorting %d remaining events", len(leftover_events))

    pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
    leftover_events = await _mainline_sort(
        clock, room_id, leftover_events, pl, event_map, state_res_store
    )

    logger.debug("resolving remaining events")

    resolved_state = await _iterative_auth_checks(
        clock,
        room_id,
        room_version,
        leftover_events,
        resolved_state,
        event_map,
        state_res_store,
    )

    logger.debug("resolved")

    # We make sure that unconflicted state always still applies.
    resolved_state.update(unconflicted_state)

    logger.debug("done")

    return resolved_state


async def _get_power_level_for_sender(
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: "synapse.state.StateResolutionStore",
) -> int:
    """Return the power level of the sender of the given event according to
    their auth events.

    Args:
        room_id
        event_id
        event_map
        state_res_store

    Returns:
        The power level.
    """
    event = await _get_event(room_id, event_id, event_map, state_res_store)

    pl = None
    for aid in event.auth_event_ids():
        aev = await _get_event(
            room_id, aid, event_map, state_res_store, allow_none=True
        )
        if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
            pl = aev
            break

    if pl is None:
        # Couldn't find power level. Check if they're the creator of the room
        for aid in event.auth_event_ids():
            aev = await _get_event(
                room_id, aid, event_map, state_res_store, allow_none=True
            )
            if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
                if aev.content.get("creator") == event.sender:
                    return 100
                break
        return 0

    level = pl.content.get("users", {}).get(event.sender)
    if level is None:
        level = pl.content.get("users_default", 0)

    if level is None:
        return 0
    else:
        return int(level)


async def _get_auth_chain_difference(
    room_id: str,
    state_sets: Sequence[StateMap[str]],
    event_map: Dict[str, EventBase],
    state_res_store: "synapse.state.StateResolutionStore",
) -> Set[str]:
    """Compare the auth chains of each state set and return the set of events
    that only appear in some but not all of the auth chains.

    Args:
        state_sets
        event_map
        state_res_store

    Returns:
        Set of event IDs
    """

    # The `StateResolutionStore.get_auth_chain_difference` function assumes that
    # all events passed to it (and their auth chains) have been persisted
    # previously. This is not the case for any events in the `event_map`, and so
    # we need to manually handle those events.
    #
    # We do this by:
    #   1. calculating the auth chain difference for the state sets based on the
    #      events in `event_map` alone
    #   2. replacing any events in the state_sets that are also in `event_map`
    #      with their auth events (recursively), and then calling
    #      `store.get_auth_chain_difference` as normal
    #   3. adding the results of 1 and 2 together.

    # Map from event ID in `event_map` to their auth event IDs, and their auth
    # event IDs if they appear in the `event_map`. This is the intersection of
    # the event's auth chain with the events in the `event_map` *plus* their
    # auth event IDs.
    events_to_auth_chain: Dict[str, Set[str]] = {}
    for event in event_map.values():
        chain = {event.event_id}
        events_to_auth_chain[event.event_id] = chain

        to_search = [event]
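        # Walk the auth events of this event: record every auth event ID we
        # see, but only recurse into events that are themselves unpersisted
        # (i.e. present in `event_map`).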
        while to_search:
            for auth_id in to_search.pop().auth_event_ids():
                chain.add(auth_id)

                auth_event = event_map.get(auth_id)
                if auth_event:
                    to_search.append(auth_event)

    # We now a) calculate the auth chain difference for the unpersisted events
    # and b) work out the state sets to pass to the store.
    #
    # Note: If the `event_map` is empty (which is the common case), we can do a
    # much simpler calculation.
    if event_map:
        # The list of state sets to pass to the store, where each state set is a set
        # of the event ids making up the state. This is similar to `state_sets`,
        # except that (a) we only have event ids, not the complete
        # ((type, state_key)->event_id) mappings; and (b) we have stripped out
        # unpersisted events and replaced them with the persisted events in
        # their auth chain.
        state_sets_ids: List[Set[str]] = []

        # For each state set, the unpersisted event IDs reachable (by their auth
        # chain) from the events in that set.
        unpersisted_set_ids: List[Set[str]] = []

        for state_set in state_sets:
            set_ids: Set[str] = set()
            state_sets_ids.append(set_ids)

            unpersisted_ids: Set[str] = set()
            unpersisted_set_ids.append(unpersisted_ids)

            for event_id in state_set.values():
                event_chain = events_to_auth_chain.get(event_id)
                if event_chain is not None:
                    # We have an event in `event_map`. We add all the auth
                    # events that it references (that aren't also in `event_map`).
                    set_ids.update(e for e in event_chain if e not in event_map)

                    # We also add the full chain of unpersisted event IDs
                    # referenced by this state set, so that we can work out the
                    # auth chain difference of the unpersisted events.
                    unpersisted_ids.update(e for e in event_chain if e in event_map)
                else:
                    set_ids.add(event_id)

        # The auth chain difference of the unpersisted events of the state sets
        # is calculated by taking the difference between the union and
        # intersections.
        union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
        intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])

        difference_from_event_map: Collection[str] = union - intersection
    else:
        difference_from_event_map = ()
        state_sets_ids = [set(state_set.values()) for state_set in state_sets]

    difference = await state_res_store.get_auth_chain_difference(
        room_id, state_sets_ids
    )
    difference.update(difference_from_event_map)

    return difference


def _seperate(
    state_sets: Iterable[StateMap[str]],
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
    """Return the unconflicted and conflicted state. This is different than in
    the original algorithm, as this defines a key to be conflicted if one of
    the state sets doesn't have that key.

    Args:
        state_sets

    Returns:
        A tuple of unconflicted and conflicted state. The conflicted state dict
        is a map from type/state_key to set of event IDs
    """
    unconflicted_state = {}
    conflicted_state = {}

    for key in set(itertools.chain.from_iterable(state_sets)):
        event_ids = {state_set.get(key) for state_set in state_sets}
        if len(event_ids) == 1:
            unconflicted_state[key] = event_ids.pop()
        else:
            event_ids.discard(None)
            conflicted_state[key] = event_ids

    # mypy doesn't understand that discarding None above means that conflicted
    # state is StateMap[Set[str]], not StateMap[Set[Optional[str]]].
    return unconflicted_state, conflicted_state  # type: ignore


def _is_power_event(event: EventBase) -> bool:
    """Return whether or not the event is a "power event", as defined by the
    v2 state resolution algorithm

    Args:
        event

    Returns:
        True if the event is a power event.
    """
    if (event.type, event.state_key) in (
        (EventTypes.PowerLevels, ""),
        (EventTypes.JoinRules, ""),
        (EventTypes.Create, ""),
    ):
        return True

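    # Membership events only count as power events when they are kicks or
    # bans, i.e. a leave/ban whose sender is not the affected user.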
    if event.type == EventTypes.Member:
        if event.membership in ("leave", "ban"):
            return event.sender != event.state_key

    return False


async def _add_event_and_auth_chain_to_graph(
    graph: Dict[str, Set[str]],
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: "synapse.state.StateResolutionStore",
    auth_diff: Set[str],
) -> None:
    """Helper function for _reverse_topological_power_sort that adds the event
    and its auth chain (that is in the auth diff) to the graph

    Args:
        graph: A map from event ID to the event's auth event IDs
        room_id: the room we are working in
        event_id: Event to add to the graph
        event_map
        state_res_store
        auth_diff: Set of event IDs that are in the auth difference.
    """

    state = [event_id]
    while state:
        eid = state.pop()
        graph.setdefault(eid, set())

        event = await _get_event(room_id, eid, event_map, state_res_store)
        for aid in event.auth_event_ids():
            if aid in auth_diff:
                if aid not in graph:
                    state.append(aid)

                graph.setdefault(eid, set()).add(aid)


async def _reverse_topological_power_sort(
    clock: Clock,
    room_id: str,
    event_ids: Iterable[str],
    event_map: Dict[str, EventBase],
    state_res_store: "synapse.state.StateResolutionStore",
    auth_diff: Set[str],
) -> List[str]:
    """Returns a list of the event_ids sorted by reverse topological ordering,
    and then by power level and origin_server_ts

    Args:
        clock
        room_id: the room we are working in
        event_ids: The events to sort
        event_map
        state_res_store
        auth_diff: Set of event IDs that are in the auth difference.

    Returns:
        The sorted list
    """

    graph: Dict[str, Set[str]] = {}
    for idx, event_id in enumerate(event_ids, start=1):
        await _add_event_and_auth_chain_to_graph(
            graph, room_id, event_id, event_map, state_res_store, auth_diff
        )

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    event_to_pl = {}
    for idx, event_id in enumerate(graph, start=1):
        pl = await _get_power_level_for_sender(
            room_id, event_id, event_map, state_res_store
        )
        event_to_pl[event_id] = pl

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

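    # Ties in the topological sort are broken by sender power level (highest
    # first, hence the negation), then origin_server_ts (oldest first), then
    # event ID.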
    def _get_power_order(event_id: str) -> Tuple[int, int, str]:
        ev = event_map[event_id]
        pl = event_to_pl[event_id]

        return -pl, ev.origin_server_ts, event_id

    # Note: graph is modified during the sort
    it = lexicographical_topological_sort(graph, key=_get_power_order)
    sorted_events = list(it)

    return sorted_events


async def _iterative_auth_checks(
    clock: Clock,
    room_id: str,
    room_version: RoomVersion,
    event_ids: List[str],
    base_state: StateMap[str],
    event_map: Dict[str, EventBase],
    state_res_store: "synapse.state.StateResolutionStore",
) -> MutableStateMap[str]:
    """Sequentially apply auth checks to each event in the given list, updating
    the state as it goes along.

    Args:
        clock
        room_id
        room_version
        event_ids: Ordered list of events to apply auth checks to
        base_state: The set of state to start with
        event_map
        state_res_store

    Returns:
        Returns the final updated state
    """
    resolved_state = dict(base_state)

    for idx, event_id in enumerate(event_ids, start=1):
        event = event_map[event_id]

        auth_events = {}
        for aid in event.auth_event_ids():
            ev = await _get_event(
                room_id, aid, event_map, state_res_store, allow_none=True
            )

            if not ev:
                logger.warning(
                    "auth_event id %s for event %s is missing", aid, event_id
                )
            else:
                if ev.rejected_reason is None:
                    auth_events[(ev.type, ev.state_key)] = ev

        for key in event_auth.auth_types_for_event(room_version, event):
            if key in resolved_state:
                ev_id = resolved_state[key]
                ev = await _get_event(room_id, ev_id, event_map, state_res_store)

                if ev.rejected_reason is None:
                    auth_events[key] = event_map[ev_id]

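        # Events which fail the auth checks are simply dropped: they do not
        # overwrite the current entry in the resolved state.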
        try:
            event_auth.check_auth_rules_for_event(
                room_version,
                event,
                auth_events.values(),
            )

            resolved_state[(event.type, event.state_key)] = event_id
        except AuthError:
            pass

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    return resolved_state


async def _mainline_sort(
    clock: Clock,
    room_id: str,
    event_ids: List[str],
    resolved_power_event_id: Optional[str],
    event_map: Dict[str, EventBase],
    state_res_store: "synapse.state.StateResolutionStore",
) -> List[str]:
    """Returns a sorted list of event_ids sorted by mainline ordering based on
    the given event resolved_power_event_id

    Args:
        clock
        room_id: room we're working in
        event_ids: Events to sort
        resolved_power_event_id: The final resolved power level event ID
        event_map
        state_res_store

    Returns:
        The sorted list
    """
    if not event_ids:
        # It's possible for there to be no event IDs here to sort, so we can
        # skip calculating the mainline in that case.
        return []

    mainline = []
    pl = resolved_power_event_id
    idx = 0

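    # Build the mainline: walk backwards from the resolved power levels event,
    # at each step following the power levels event among its auth events.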
    while pl:
        mainline.append(pl)
        pl_ev = await _get_event(room_id, pl, event_map, state_res_store)
        auth_events = pl_ev.auth_event_ids()
        pl = None
        for aid in auth_events:
            ev = await _get_event(
                room_id, aid, event_map, state_res_store, allow_none=True
            )
            if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
                pl = aid
                break

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

        idx += 1

    mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}

    event_ids = list(event_ids)

    order_map = {}
    for idx, ev_id in enumerate(event_ids, start=1):
        depth = await _get_mainline_depth_for_event(
            event_map[ev_id], mainline_map, event_map, state_res_store
        )
        order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    event_ids.sort(key=lambda ev_id: order_map[ev_id])

    return event_ids


async def _get_mainline_depth_for_event(
    event: EventBase,
    mainline_map: Dict[str, int],
    event_map: Dict[str, EventBase],
    state_res_store: "synapse.state.StateResolutionStore",
) -> int:
    """Get the mainline depth for the given event based on the mainline map

    Args:
        event
        mainline_map: Map from event_id to mainline depth for events in the mainline.
        event_map
        state_res_store

    Returns:
        The mainline depth
    """

    room_id = event.room_id
    tmp_event: Optional[EventBase] = event

    # We do an iterative search, replacing `event` with the power level in its
    # auth events (if any)
    while tmp_event:
        depth = mainline_map.get(tmp_event.event_id)
        if depth is not None:
            return depth

        auth_events = tmp_event.auth_event_ids()
        tmp_event = None

        for aid in auth_events:
            aev = await _get_event(
                room_id, aid, event_map, state_res_store, allow_none=True
            )
            if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
                tmp_event = aev
                break

    # Didn't find a power level auth event, so we just return 0
    return 0


@overload
async def _get_event(
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: "synapse.state.StateResolutionStore",
    allow_none: Literal[False] = False,
) -> EventBase:
    ...


@overload
async def _get_event(
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: "synapse.state.StateResolutionStore",
    allow_none: Literal[True],
) -> Optional[EventBase]:
    ...


async def _get_event(
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: "synapse.state.StateResolutionStore",
    allow_none: bool = False,
) -> Optional[EventBase]:
    """Helper function to look up event in event_map, falling back to looking
    it up in the store

    Args:
        room_id
        event_id
        event_map
        state_res_store
        allow_none: if the event is not found, return None rather than raising
            an exception

    Returns:
        The event, or None if the event does not exist (and allow_none is True).
    """
    if event_id not in event_map:
        events = await state_res_store.get_events([event_id], allow_rejected=True)
        event_map.update(events)
    event = event_map.get(event_id)

    if event is None:
        if allow_none:
            return None
        raise Exception("Unknown event %s" % (event_id,))

    if event.room_id != room_id:
        raise Exception(
            "In state res for room %s, event %s is in %s"
            % (room_id, event_id, event.room_id)
        )
    return event


def lexicographical_topological_sort(
    graph: Dict[str, Set[str]], key: Callable[[str], Any]
) -> Generator[str, None, None]:
    """Performs a lexicographic reverse topological sort on the graph.

    This returns a reverse topological sort (i.e. if node A references B then B
    appears before A in the sort), with ties broken lexicographically based on
    the return value of the `key` function.

    NOTE: `graph` is modified during the sort.

    Args:
        graph: A representation of the graph where each node is a key in the
            dict and its value is the set of that node's edges.
        key: A function that takes a node and returns a value that is comparable
            and used to order nodes

    Yields:
        The next node in the topological sort
    """

    # Note, this is basically Kahn's algorithm except we look at nodes with no
    # outgoing edges, c.f.
    # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
    outdegree_map = graph
    reverse_graph: Dict[str, Set[str]] = {}

    # Lists of nodes with zero out degree. Is actually a tuple of
    # `(key(node), node)` so that sorting does the right thing
    zero_outdegree = []

    for node, edges in graph.items():
        if len(edges) == 0:
            zero_outdegree.append((key(node), node))

        reverse_graph.setdefault(node, set())
        for edge in edges:
            reverse_graph.setdefault(edge, set()).add(node)

    # heapq is a built-in implementation of a sorted queue.
    heapq.heapify(zero_outdegree)

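    # Repeatedly pop the smallest remaining node (by `key`) with no outgoing
    # edges; once a node is emitted, any parent all of whose outgoing edges
    # have now been emitted becomes a candidate in turn.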
    while zero_outdegree:
        _, node = heapq.heappop(zero_outdegree)

        for parent in reverse_graph[node]:
            out = outdegree_map[parent]
            out.discard(node)
            if len(out) == 0:
                heapq.heappush(zero_outdegree, (key(parent), parent))

        yield node