# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import heapq
import itertools
import logging
from typing import (
    Any,
    Awaitable,
    Callable,
    Collection,
    Dict,
    Generator,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    overload,
)

from typing_extensions import Literal, Protocol

from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.types import MutableStateMap, StateMap

logger = logging.getLogger(__name__)


class Clock(Protocol):
    # This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests.
    # We only ever sleep(0) though, so that other async functions can make forward
    # progress without waiting for stateres to complete.
    def sleep(self, duration_ms: float) -> Awaitable[None]:
        ...


class StateResolutionStore(Protocol):
    # This is usually synapse.state.StateResolutionStore, but it's replaced with a
    # TestStateResolutionStore in tests.
    def get_events(
        self, event_ids: Collection[str], allow_rejected: bool = False
    ) -> Awaitable[Dict[str, EventBase]]:
        ...

    def get_auth_chain_difference(
        self, room_id: str, state_sets: List[Set[str]]
    ) -> Awaitable[Set[str]]:
        ...


# We want to yield to the reactor occasionally during state res when dealing
# with large data sets, so that we don't exhaust the reactor. This is done by
# yielding to the reactor during loops every N iterations.
_AWAIT_AFTER_ITERATIONS = 100


__all__ = [
    "resolve_events_with_store",
]


async def resolve_events_with_store(
    clock: Clock,
    room_id: str,
    room_version: RoomVersion,
    state_sets: Sequence[StateMap[str]],
    event_map: Optional[Dict[str, EventBase]],
    state_res_store: StateResolutionStore,
) -> StateMap[str]:
    """Resolves the state using the v2 state resolution algorithm

    Args:
        clock
        room_id: the room we are working in
        room_version: The room version
        state_sets: List of dicts of (type, state_key) -> event_id,
            which are the different state groups to resolve.
        event_map:
            a dict from event_id to event, for any events that we happen to
            have in flight (eg, those currently being persisted). This will be
            used as a starting point for finding the state we need; any missing
            events will be requested via state_res_store.

            If None, all events will be fetched via state_res_store.
        state_res_store:

    Returns:
        A map from (type, state_key) to event_id.
    """
    logger.debug("Computing conflicted state")

    # We use event_map as a cache, so if it's None we need to initialize it
    if event_map is None:
        event_map = {}

    # First split up the un/conflicted state
    unconflicted_state, conflicted_state = _seperate(state_sets)

    if not conflicted_state:
        return unconflicted_state

    logger.debug("%d conflicted state entries", len(conflicted_state))
    logger.debug("Calculating auth chain difference")

    # Also fetch all auth events that appear in only some of the state sets'
    # auth chains.
    auth_diff = await _get_auth_chain_difference(
        room_id, state_sets, event_map, state_res_store
    )

    full_conflicted_set = set(
        itertools.chain(
            itertools.chain.from_iterable(conflicted_state.values()), auth_diff
        )
    )

    events = await state_res_store.get_events(
        [eid for eid in full_conflicted_set if eid not in event_map],
        allow_rejected=True,
    )
    event_map.update(events)

    # everything in the event map should be in the right room
    for event in event_map.values():
        if event.room_id != room_id:
            raise Exception(
                "Attempting to state-resolve for room %s with event %s which is in %s"
                % (
                    room_id,
                    event.event_id,
                    event.room_id,
                )
            )

    full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map}

    logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))

    # Get and sort all the power events (kicks/bans/etc)
    power_events = (
        eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
    )

    sorted_power_events = await _reverse_topological_power_sort(
        clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
    )

    logger.debug("sorted %d power events", len(sorted_power_events))

    # Now sequentially auth each one
    resolved_state = await _iterative_auth_checks(
        clock,
        room_id,
        room_version,
        sorted_power_events,
        unconflicted_state,
        event_map,
        state_res_store,
    )

    logger.debug("resolved power events")

    # OK, so we've now resolved the power events. Now sort the remaining
    # events using the mainline of the resolved power level.
    set_power_events = set(sorted_power_events)
    leftover_events = [
        ev_id for ev_id in full_conflicted_set if ev_id not in set_power_events
    ]

    logger.debug("sorting %d remaining events", len(leftover_events))

    pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
    leftover_events = await _mainline_sort(
        clock, room_id, leftover_events, pl, event_map, state_res_store
    )

    logger.debug("resolving remaining events")

    resolved_state = await _iterative_auth_checks(
        clock,
        room_id,
        room_version,
        leftover_events,
        resolved_state,
        event_map,
        state_res_store,
    )

    logger.debug("resolved")

    # We make sure that unconflicted state always still applies.
    resolved_state.update(unconflicted_state)

    logger.debug("done")

    return resolved_state


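# A hedged usage sketch of `resolve_events_with_store` (illustration only; the
# `clock`, `store`, room identifiers and the divergent state maps are assumed to
# be supplied by the caller, e.g. synapse.state.StateResolutionHandler):
#
#     resolved = await resolve_events_with_store(
#         clock,
#         room_id,
#         room_version,
#         [state_map_1, state_map_2],   # the conflicting branches of state
#         event_map=None,               # fetch all events via the store
#         state_res_store=store,
#     )
#     # `resolved` maps (event_type, state_key) -> event_id for the merged state.

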
async def _get_power_level_for_sender(
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
) -> int:
    """Return the power level of the sender of the given event according to
    their auth events.

    Args:
        room_id
        event_id
        event_map
        state_res_store

    Returns:
        The power level.
    """
    event = await _get_event(room_id, event_id, event_map, state_res_store)

    pl = None
    for aid in event.auth_event_ids():
        aev = await _get_event(
            room_id, aid, event_map, state_res_store, allow_none=True
        )
        if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
            pl = aev
            break

    if pl is None:
        # Couldn't find power level. Check if they're the creator of the room
        for aid in event.auth_event_ids():
            aev = await _get_event(
                room_id, aid, event_map, state_res_store, allow_none=True
            )
            if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
                if aev.content.get("creator") == event.sender:
                    return 100
                break
        return 0

    level = pl.content.get("users", {}).get(event.sender)
    if level is None:
        level = pl.content.get("users_default", 0)

    if level is None:
        return 0
    else:
        return int(level)


async def _get_auth_chain_difference(
    room_id: str,
    state_sets: Sequence[Mapping[Any, str]],
    unpersisted_events: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
) -> Set[str]:
    """Compare the auth chains of each state set and return the set of events
    that only appear in some, but not all of the auth chains.

    Args:
        state_sets: The input state sets we are trying to resolve across.
        unpersisted_events: A map from event ID to EventBase containing all unpersisted
            events involved in this resolution.
        state_res_store:

    Returns:
        The auth difference of the given state sets, as a set of event IDs.
    """

    # The `StateResolutionStore.get_auth_chain_difference` function assumes that
    # all events passed to it (and their auth chains) have been persisted
    # previously. We need to manually handle any other events that are yet to be
    # persisted.
    #
    # We do this in three steps:
    #    1. Compute the set of unpersisted events belonging to the auth difference.
    #    2. Replace any unpersisted events in the state_sets with their auth events,
    #       recursively, until the state_sets contain only persisted events.
    #       Then we call `store.get_auth_chain_difference` as normal, which computes
    #       the set of persisted events belonging to the auth difference.
    #    3. Add the results of 1 and 2 together.

    # Map from event ID in `unpersisted_events` to their auth event IDs, and their auth
    # event IDs if they appear in the `unpersisted_events`. This is the intersection of
    # the event's auth chain with the events in `unpersisted_events` *plus* their
    # auth event IDs.
    events_to_auth_chain: Dict[str, Set[str]] = {}
    for event in unpersisted_events.values():
        chain = {event.event_id}
        events_to_auth_chain[event.event_id] = chain

        to_search = [event]
        while to_search:
            for auth_id in to_search.pop().auth_event_ids():
                chain.add(auth_id)
                auth_event = unpersisted_events.get(auth_id)
                if auth_event:
                    to_search.append(auth_event)

    # We now 1) calculate the auth chain difference for the unpersisted events
    # and 2) work out the state sets to pass to the store.
    #
    # Note: If there are no `unpersisted_events` (which is the common case), we can do a
    # much simpler calculation.
    if unpersisted_events:
        # The list of state sets to pass to the store, where each state set is a set
        # of the event ids making up the state. This is similar to `state_sets`,
        # except that (a) we only have event ids, not the complete
        # ((type, state_key)->event_id) mappings; and (b) we have stripped out
        # unpersisted events and replaced them with the persisted events in
        # their auth chain.
        state_sets_ids: List[Set[str]] = []

        # For each state set, the unpersisted event IDs reachable (by their auth
        # chain) from the events in that set.
        unpersisted_set_ids: List[Set[str]] = []

        for state_set in state_sets:
            set_ids: Set[str] = set()
            state_sets_ids.append(set_ids)

            unpersisted_ids: Set[str] = set()
            unpersisted_set_ids.append(unpersisted_ids)

            for event_id in state_set.values():
                event_chain = events_to_auth_chain.get(event_id)
                if event_chain is not None:
                    # We have an unpersisted event. We add all the auth
                    # events that it references which are also unpersisted.
                    set_ids.update(
                        e for e in event_chain if e not in unpersisted_events
                    )

                    # We also add the full chain of unpersisted event IDs
                    # referenced by this state set, so that we can work out the
                    # auth chain difference of the unpersisted events.
                    unpersisted_ids.update(
                        e for e in event_chain if e in unpersisted_events
                    )
                else:
                    set_ids.add(event_id)

        # The auth chain difference of the unpersisted events of the state sets
        # is calculated by taking the difference between the union and the
        # intersection.
        union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
        intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])

        auth_difference_unpersisted_part: Collection[str] = union - intersection
    else:
        auth_difference_unpersisted_part = ()
        state_sets_ids = [set(state_set.values()) for state_set in state_sets]

    difference = await state_res_store.get_auth_chain_difference(
        room_id, state_sets_ids
    )
    difference.update(auth_difference_unpersisted_part)

    return difference


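# Illustration of the auth difference (hypothetical event IDs): if the auth chain
# of state set 1 is {$CREATE, $PL1, $JOIN} and the auth chain of state set 2 is
# {$CREATE, $PL2, $JOIN}, the auth difference is {$PL1, $PL2}, i.e. the events
# appearing in some, but not all, of the chains (the union minus the intersection).

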
def _seperate(
    state_sets: Iterable[StateMap[str]],
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
    """Return the unconflicted and conflicted state. This is different from the
    original algorithm, as this defines a key to be conflicted if one of
    the state sets doesn't have that key.

    Args:
        state_sets

    Returns:
        A tuple of unconflicted and conflicted state. The conflicted state dict
        is a map from type/state_key to set of event IDs
    """
    unconflicted_state = {}
    conflicted_state = {}

    for key in set(itertools.chain.from_iterable(state_sets)):
        event_ids = {state_set.get(key) for state_set in state_sets}
        if len(event_ids) == 1:
            unconflicted_state[key] = event_ids.pop()
        else:
            event_ids.discard(None)
            conflicted_state[key] = event_ids

    # mypy doesn't understand that discarding None above means that conflicted
    # state is StateMap[Set[str]], not StateMap[Set[Optional[str]]].
    return unconflicted_state, conflicted_state  # type: ignore


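# A minimal sketch of what `_seperate` produces, assuming the hypothetical event
# IDs below (runnable as-is, since `_seperate` only inspects the mappings):
#
#     state_sets = [
#         {("m.room.name", ""): "$N", ("m.room.topic", ""): "$A"},
#         {("m.room.name", ""): "$N", ("m.room.topic", ""): "$B"},
#     ]
#     unconflicted, conflicted = _seperate(state_sets)
#     # unconflicted == {("m.room.name", ""): "$N"}
#     # conflicted == {("m.room.topic", ""): {"$A", "$B"}}

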
def _is_power_event(event: EventBase) -> bool:
    """Return whether or not the event is a "power event", as defined by the
    v2 state resolution algorithm

    Args:
        event

    Returns:
        True if the event is a power event.
    """
    if (event.type, event.state_key) in (
        (EventTypes.PowerLevels, ""),
        (EventTypes.JoinRules, ""),
        (EventTypes.Create, ""),
    ):
        return True

    if event.type == EventTypes.Member:
        if event.membership in ("leave", "ban"):
            return event.sender != event.state_key

    return False


async def _add_event_and_auth_chain_to_graph(
    graph: Dict[str, Set[str]],
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
    full_conflicted_set: Set[str],
) -> None:
    """Helper function for _reverse_topological_power_sort that adds the event
    and its auth chain (that is in the auth diff) to the graph

    Args:
        graph: A map from event ID to the event's auth event IDs
        room_id: the room we are working in
        event_id: Event to add to the graph
        event_map
        state_res_store
        full_conflicted_set: Set of event IDs that are in the full conflicted set.
    """
    state = [event_id]
    while state:
        eid = state.pop()
        graph.setdefault(eid, set())

        event = await _get_event(room_id, eid, event_map, state_res_store)
        for aid in event.auth_event_ids():
            if aid in full_conflicted_set:
                if aid not in graph:
                    state.append(aid)

                graph.setdefault(eid, set()).add(aid)


async def _reverse_topological_power_sort(
    clock: Clock,
    room_id: str,
    event_ids: Iterable[str],
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
    full_conflicted_set: Set[str],
) -> List[str]:
    """Returns a list of the event_ids sorted by reverse topological ordering,
    and then by power level and origin_server_ts

    Args:
        clock
        room_id: the room we are working in
        event_ids: The events to sort
        event_map
        state_res_store
        full_conflicted_set: Set of event IDs that are in the full conflicted set.

    Returns:
        The sorted list
    """

    graph: Dict[str, Set[str]] = {}
    for idx, event_id in enumerate(event_ids, start=1):
        await _add_event_and_auth_chain_to_graph(
            graph, room_id, event_id, event_map, state_res_store, full_conflicted_set
        )

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    event_to_pl = {}
    for idx, event_id in enumerate(graph, start=1):
        pl = await _get_power_level_for_sender(
            room_id, event_id, event_map, state_res_store
        )
        event_to_pl[event_id] = pl

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    def _get_power_order(event_id: str) -> Tuple[int, int, str]:
        ev = event_map[event_id]
        pl = event_to_pl[event_id]

        return -pl, ev.origin_server_ts, event_id

    # Note: graph is modified during the sort
    it = lexicographical_topological_sort(graph, key=_get_power_order)
    sorted_events = list(it)

    return sorted_events


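# For illustration: the `_get_power_order` key `(-power_level, origin_server_ts,
# event_id)` sorts higher-powered senders first and breaks ties by age, e.g.
#     (-100, 1000, "$a") < (-50, 900, "$b") < (-50, 900, "$c")
# so an event from a power-100 sender orders before events from power-50 senders.

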
async def _iterative_auth_checks(
    clock: Clock,
    room_id: str,
    room_version: RoomVersion,
    event_ids: List[str],
    base_state: StateMap[str],
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
) -> MutableStateMap[str]:
    """Sequentially apply auth checks to each event in given list, updating the
    state as it goes along.

    Args:
        clock
        room_id
        room_version
        event_ids: Ordered list of events to apply auth checks to
        base_state: The set of state to start with
        event_map
        state_res_store

    Returns:
        Returns the final updated state
    """
    resolved_state = dict(base_state)

    for idx, event_id in enumerate(event_ids, start=1):
        event = event_map[event_id]

        auth_events = {}
        for aid in event.auth_event_ids():
            ev = await _get_event(
                room_id, aid, event_map, state_res_store, allow_none=True
            )

            if not ev:
                logger.warning(
                    "auth_event id %s for event %s is missing", aid, event_id
                )
            else:
                if ev.rejected_reason is None:
                    auth_events[(ev.type, ev.state_key)] = ev

        for key in event_auth.auth_types_for_event(room_version, event):
            if key in resolved_state:
                ev_id = resolved_state[key]
                ev = await _get_event(room_id, ev_id, event_map, state_res_store)

                if ev.rejected_reason is None:
                    auth_events[key] = event_map[ev_id]

        if event.rejected_reason is not None:
            # Do not admit previously rejected events into state.
            # TODO: This isn't spec compliant. Events that were previously rejected due
            #       to failing auth checks at their state, but pass auth checks during
            #       state resolution should be accepted. Synapse does not handle the
            #       change of rejection status well, so we preserve the previous
            #       rejection status for now.
            #
            #       Note that events rejected for non-state reasons, such as having the
            #       wrong auth events, should remain rejected.
            #
            #       https://spec.matrix.org/v1.2/rooms/v9/#rejected-events
            #       https://github.com/matrix-org/synapse/issues/13797
            continue

        try:
            event_auth.check_state_dependent_auth_rules(
                event,
                auth_events.values(),
            )

            resolved_state[(event.type, event.state_key)] = event_id
        except AuthError:
            pass

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    return resolved_state


async def _mainline_sort(
    clock: Clock,
    room_id: str,
    event_ids: List[str],
    resolved_power_event_id: Optional[str],
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
) -> List[str]:
    """Returns a list of the event_ids sorted by mainline ordering based on
    the given resolved_power_event_id event

    Args:
        clock
        room_id: room we're working in
        event_ids: Events to sort
        resolved_power_event_id: The final resolved power level event ID
        event_map
        state_res_store

    Returns:
        The sorted list
    """
    if not event_ids:
        # It's possible for there to be no event IDs here to sort, so we can
        # skip calculating the mainline in that case.
        return []

    mainline = []
    pl = resolved_power_event_id
    idx = 0

    while pl:
        mainline.append(pl)
        pl_ev = await _get_event(room_id, pl, event_map, state_res_store)
        auth_events = pl_ev.auth_event_ids()
        pl = None
        for aid in auth_events:
            ev = await _get_event(
                room_id, aid, event_map, state_res_store, allow_none=True
            )
            if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
                pl = aid
                break

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

        idx += 1

    mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}

    event_ids = list(event_ids)

    order_map = {}
    for idx, ev_id in enumerate(event_ids, start=1):
        depth = await _get_mainline_depth_for_event(
            event_map[ev_id], mainline_map, event_map, state_res_store
        )
        order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)

        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await clock.sleep(0)

    event_ids.sort(key=lambda ev_id: order_map[ev_id])

    return event_ids


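# A hedged example of the mainline bookkeeping above (hypothetical event IDs): if
# the resolved power level event $PL3 has $PL2 in its auth events, and $PL2 has
# $PL1, the mainline is [$PL3, $PL2, $PL1] and `mainline_map` becomes
# {"$PL1": 1, "$PL2": 2, "$PL3": 3}. Events whose closest power level ancestor is
# later in the mainline therefore get a larger depth and sort later.

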
async def _get_mainline_depth_for_event(
    event: EventBase,
    mainline_map: Dict[str, int],
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
) -> int:
    """Get the mainline depth for the given event based on the mainline map

    Args:
        event
        mainline_map: Map from event_id to mainline depth for events in the mainline.
        event_map
        state_res_store

    Returns:
        The mainline depth
    """

    room_id = event.room_id
    tmp_event: Optional[EventBase] = event

    # We do an iterative search, replacing `event` with the power level in its
    # auth events (if any)
    while tmp_event:
        depth = mainline_map.get(tmp_event.event_id)
        if depth is not None:
            return depth

        auth_events = tmp_event.auth_event_ids()
        tmp_event = None

        for aid in auth_events:
            aev = await _get_event(
                room_id, aid, event_map, state_res_store, allow_none=True
            )
            if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
                tmp_event = aev
                break

    # Didn't find a power level auth event, so we just return 0
    return 0


@overload
async def _get_event(
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
    allow_none: Literal[False] = False,
) -> EventBase:
    ...


@overload
async def _get_event(
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
    allow_none: Literal[True],
) -> Optional[EventBase]:
    ...


async def _get_event(
    room_id: str,
    event_id: str,
    event_map: Dict[str, EventBase],
    state_res_store: StateResolutionStore,
    allow_none: bool = False,
) -> Optional[EventBase]:
    """Helper function to look up event in event_map, falling back to looking
    it up in the store

    Args:
        room_id
        event_id
        event_map
        state_res_store
        allow_none: if the event is not found, return None rather than raising
            an exception

    Returns:
        The event, or none if the event does not exist (and allow_none is True).
    """
    if event_id not in event_map:
        events = await state_res_store.get_events([event_id], allow_rejected=True)
        event_map.update(events)
    event = event_map.get(event_id)

    if event is None:
        if allow_none:
            return None
        raise Exception("Unknown event %s" % (event_id,))

    if event.room_id != room_id:
        raise Exception(
            "In state res for room %s, event %s is in %s"
            % (room_id, event_id, event.room_id)
        )
    return event


def lexicographical_topological_sort(
    graph: Dict[str, Set[str]], key: Callable[[str], Any]
) -> Generator[str, None, None]:
    """Performs a lexicographic reverse topological sort on the graph.

    This returns a reverse topological sort (i.e. if node A references B then B
    appears before A in the sort), with ties broken lexicographically based on
    return value of the `key` function.

    NOTE: `graph` is modified during the sort.

    Args:
        graph: A representation of the graph where each node is a key in the
            dict and its value is the set of the node's edges.
        key: A function that takes a node and returns a value that is comparable
            and used to order nodes

    Yields:
        The next node in the topological sort
    """

    # Note, this is basically Kahn's algorithm except we look at nodes with no
    # outgoing edges, c.f.
    # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
    outdegree_map = graph
    reverse_graph: Dict[str, Set[str]] = {}

    # List of nodes with zero outdegree. Each entry is actually a tuple of
    # `(key(node), node)` so that sorting does the right thing.
    zero_outdegree = []

    for node, edges in graph.items():
        if len(edges) == 0:
            zero_outdegree.append((key(node), node))

        reverse_graph.setdefault(node, set())
        for edge in edges:
            reverse_graph.setdefault(edge, set()).add(node)

    # heapq is a built-in implementation of a sorted queue.
    heapq.heapify(zero_outdegree)

    while zero_outdegree:
        _, node = heapq.heappop(zero_outdegree)

        for parent in reverse_graph[node]:
            out = outdegree_map[parent]
            out.discard(node)
            if len(out) == 0:
                heapq.heappush(zero_outdegree, (key(parent), parent))

        yield node


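# A small, runnable sketch of the sort (hypothetical node names; note that the
# input graph is consumed by the generator):
#
#     graph = {"$A": set(), "$B": {"$A"}, "$C": {"$A"}, "$D": {"$B", "$C"}}
#     list(lexicographical_topological_sort(graph, key=lambda node: node))
#     # == ["$A", "$B", "$C", "$D"]  (referenced nodes come first; ties such as
#     # $B vs $C are broken by `key`)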