state.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837
  1. # Copyright 2014-2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from typing import (
  16. TYPE_CHECKING,
  17. Awaitable,
  18. Collection,
  19. Dict,
  20. Iterable,
  21. List,
  22. Mapping,
  23. Optional,
  24. Set,
  25. Tuple,
  26. TypeVar,
  27. )
  28. import attr
  29. from frozendict import frozendict
  30. from synapse.api.constants import EventTypes
  31. from synapse.events import EventBase
  32. from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
  33. from synapse.types import MutableStateMap, StateKey, StateMap
  34. if TYPE_CHECKING:
  35. from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
  36. from synapse.server import HomeServer
  37. from synapse.storage.databases import Databases
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Type variable for the value type of state maps; used for generic functions
# below (e.g. `StateFilter.filter_state`).
T = TypeVar("T")
  41. @attr.s(slots=True, frozen=True, auto_attribs=True)
  42. class StateFilter:
  43. """A filter used when querying for state.
  44. Attributes:
  45. types: Map from type to set of state keys (or None). This specifies
  46. which state_keys for the given type to fetch from the DB. If None
  47. then all events with that type are fetched. If the set is empty
  48. then no events with that type are fetched.
  49. include_others: Whether to fetch events with types that do not
  50. appear in `types`.
  51. """
  52. types: "frozendict[str, Optional[FrozenSet[str]]]"
  53. include_others: bool = False
  54. def __attrs_post_init__(self) -> None:
  55. # If `include_others` is set we canonicalise the filter by removing
  56. # wildcards from the types dictionary
  57. if self.include_others:
  58. # this is needed to work around the fact that StateFilter is frozen
  59. object.__setattr__(
  60. self,
  61. "types",
  62. frozendict({k: v for k, v in self.types.items() if v is not None}),
  63. )
  64. @staticmethod
  65. def all() -> "StateFilter":
  66. """Returns a filter that fetches everything.
  67. Returns:
  68. The state filter.
  69. """
  70. return _ALL_STATE_FILTER
  71. @staticmethod
  72. def none() -> "StateFilter":
  73. """Returns a filter that fetches nothing.
  74. Returns:
  75. The new state filter.
  76. """
  77. return _NONE_STATE_FILTER
  78. @staticmethod
  79. def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
  80. """Creates a filter that only fetches the given types
  81. Args:
  82. types: A list of type and state keys to fetch. A state_key of None
  83. fetches everything for that type
  84. Returns:
  85. The new state filter.
  86. """
  87. type_dict: Dict[str, Optional[Set[str]]] = {}
  88. for typ, s in types:
  89. if typ in type_dict:
  90. if type_dict[typ] is None:
  91. continue
  92. if s is None:
  93. type_dict[typ] = None
  94. continue
  95. type_dict.setdefault(typ, set()).add(s) # type: ignore
  96. return StateFilter(
  97. types=frozendict(
  98. (k, frozenset(v) if v is not None else None)
  99. for k, v in type_dict.items()
  100. )
  101. )
  102. @staticmethod
  103. def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
  104. """Creates a filter that returns all non-member events, plus the member
  105. events for the given users
  106. Args:
  107. members: Set of user IDs
  108. Returns:
  109. The new state filter
  110. """
  111. return StateFilter(
  112. types=frozendict({EventTypes.Member: frozenset(members)}),
  113. include_others=True,
  114. )
  115. @staticmethod
  116. def freeze(
  117. types: Mapping[str, Optional[Collection[str]]], include_others: bool
  118. ) -> "StateFilter":
  119. """
  120. Returns a (frozen) StateFilter with the same contents as the parameters
  121. specified here, which can be made of mutable types.
  122. """
  123. types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {}
  124. for state_types, state_keys in types.items():
  125. if state_keys is not None:
  126. types_with_frozen_values[state_types] = frozenset(state_keys)
  127. else:
  128. types_with_frozen_values[state_types] = None
  129. return StateFilter(
  130. frozendict(types_with_frozen_values), include_others=include_others
  131. )
  132. def return_expanded(self) -> "StateFilter":
  133. """Creates a new StateFilter where type wild cards have been removed
  134. (except for memberships). The returned filter is a superset of the
  135. current one, i.e. anything that passes the current filter will pass
  136. the returned filter.
  137. This helps the caching as the DictionaryCache knows if it has *all* the
  138. state, but does not know if it has all of the keys of a particular type,
  139. which makes wildcard lookups expensive unless we have a complete cache.
  140. Hence, if we are doing a wildcard lookup, populate the cache fully so
  141. that we can do an efficient lookup next time.
  142. Note that since we have two caches, one for membership events and one for
  143. other events, we can be a bit more clever than simply returning
  144. `StateFilter.all()` if `has_wildcards()` is True.
  145. We return a StateFilter where:
  146. 1. the list of membership events to return is the same
  147. 2. if there is a wildcard that matches non-member events we
  148. return all non-member events
  149. Returns:
  150. The new state filter.
  151. """
  152. if self.is_full():
  153. # If we're going to return everything then there's nothing to do
  154. return self
  155. if not self.has_wildcards():
  156. # If there are no wild cards, there's nothing to do
  157. return self
  158. if EventTypes.Member in self.types:
  159. get_all_members = self.types[EventTypes.Member] is None
  160. else:
  161. get_all_members = self.include_others
  162. has_non_member_wildcard = self.include_others or any(
  163. state_keys is None
  164. for t, state_keys in self.types.items()
  165. if t != EventTypes.Member
  166. )
  167. if not has_non_member_wildcard:
  168. # If there are no non-member wild cards we can just return ourselves
  169. return self
  170. if get_all_members:
  171. # We want to return everything.
  172. return StateFilter.all()
  173. elif EventTypes.Member in self.types:
  174. # We want to return all non-members, but only particular
  175. # memberships
  176. return StateFilter(
  177. types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
  178. include_others=True,
  179. )
  180. else:
  181. # We want to return all non-members
  182. return _ALL_NON_MEMBER_STATE_FILTER
  183. def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
  184. """Converts the filter to an SQL clause.
  185. For example:
  186. f = StateFilter.from_types([("m.room.create", "")])
  187. clause, args = f.make_sql_filter_clause()
  188. clause == "(type = ? AND state_key = ?)"
  189. args == ['m.room.create', '']
  190. Returns:
  191. The SQL string (may be empty) and arguments. An empty SQL string is
  192. returned when the filter matches everything (i.e. is "full").
  193. """
  194. where_clause = ""
  195. where_args: List[str] = []
  196. if self.is_full():
  197. return where_clause, where_args
  198. if not self.include_others and not self.types:
  199. # i.e. this is an empty filter, so we need to return a clause that
  200. # will match nothing
  201. return "1 = 2", []
  202. # First we build up a lost of clauses for each type/state_key combo
  203. clauses = []
  204. for etype, state_keys in self.types.items():
  205. if state_keys is None:
  206. clauses.append("(type = ?)")
  207. where_args.append(etype)
  208. continue
  209. for state_key in state_keys:
  210. clauses.append("(type = ? AND state_key = ?)")
  211. where_args.extend((etype, state_key))
  212. # This will match anything that appears in `self.types`
  213. where_clause = " OR ".join(clauses)
  214. # If we want to include stuff that's not in the types dict then we add
  215. # a `OR type NOT IN (...)` clause to the end.
  216. if self.include_others:
  217. if where_clause:
  218. where_clause += " OR "
  219. where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
  220. where_args.extend(self.types)
  221. return where_clause, where_args
  222. def max_entries_returned(self) -> Optional[int]:
  223. """Returns the maximum number of entries this filter will return if
  224. known, otherwise returns None.
  225. For example a simple state filter asking for `("m.room.create", "")`
  226. will return 1, whereas the default state filter will return None.
  227. This is used to bail out early if the right number of entries have been
  228. fetched.
  229. """
  230. if self.has_wildcards():
  231. return None
  232. return len(self.concrete_types())
  233. def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
  234. """Returns the state filtered with by this StateFilter.
  235. Args:
  236. state: The state map to filter
  237. Returns:
  238. The filtered state map.
  239. This is a copy, so it's safe to mutate.
  240. """
  241. if self.is_full():
  242. return dict(state_dict)
  243. filtered_state = {}
  244. for k, v in state_dict.items():
  245. typ, state_key = k
  246. if typ in self.types:
  247. state_keys = self.types[typ]
  248. if state_keys is None or state_key in state_keys:
  249. filtered_state[k] = v
  250. elif self.include_others:
  251. filtered_state[k] = v
  252. return filtered_state
  253. def is_full(self) -> bool:
  254. """Whether this filter fetches everything or not
  255. Returns:
  256. True if the filter fetches everything.
  257. """
  258. return self.include_others and not self.types
  259. def has_wildcards(self) -> bool:
  260. """Whether the filter includes wildcards or is attempting to fetch
  261. specific state.
  262. Returns:
  263. True if the filter includes wildcards.
  264. """
  265. return self.include_others or any(
  266. state_keys is None for state_keys in self.types.values()
  267. )
  268. def concrete_types(self) -> List[Tuple[str, str]]:
  269. """Returns a list of concrete type/state_keys (i.e. not None) that
  270. will be fetched. This will be a complete list if `has_wildcards`
  271. returns False, but otherwise will be a subset (or even empty).
  272. Returns:
  273. A list of type/state_keys tuples.
  274. """
  275. return [
  276. (t, s)
  277. for t, state_keys in self.types.items()
  278. if state_keys is not None
  279. for s in state_keys
  280. ]
  281. def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
  282. """Return the filter split into two: one which assumes it's exclusively
  283. matching against member state, and one which assumes it's matching
  284. against non member state.
  285. This is useful due to the returned filters giving correct results for
  286. `is_full()`, `has_wildcards()`, etc, when operating against maps that
  287. either exclusively contain member events or only contain non-member
  288. events. (Which is the case when dealing with the member vs non-member
  289. state caches).
  290. Returns:
  291. The member and non member filters
  292. """
  293. if EventTypes.Member in self.types:
  294. state_keys = self.types[EventTypes.Member]
  295. if state_keys is None:
  296. member_filter = StateFilter.all()
  297. else:
  298. member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
  299. elif self.include_others:
  300. member_filter = StateFilter.all()
  301. else:
  302. member_filter = StateFilter.none()
  303. non_member_filter = StateFilter(
  304. types=frozendict(
  305. {k: v for k, v in self.types.items() if k != EventTypes.Member}
  306. ),
  307. include_others=self.include_others,
  308. )
  309. return member_filter, non_member_filter
  310. def _decompose_into_four_parts(
  311. self,
  312. ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]:
  313. """
  314. Decomposes this state filter into 4 constituent parts, which can be
  315. thought of as this:
  316. all? - minus_wildcards + plus_wildcards + plus_state_keys
  317. where
  318. * all represents ALL state
  319. * minus_wildcards represents entire state types to remove
  320. * plus_wildcards represents entire state types to add
  321. * plus_state_keys represents individual state keys to add
  322. See `recompose_from_four_parts` for the other direction of this
  323. correspondence.
  324. """
  325. is_all = self.include_others
  326. excluded_types: Set[str] = {t for t in self.types if is_all}
  327. wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None}
  328. concrete_keys: Set[StateKey] = set(self.concrete_types())
  329. return (is_all, excluded_types), (wildcard_types, concrete_keys)
  330. @staticmethod
  331. def _recompose_from_four_parts(
  332. all_part: bool,
  333. minus_wildcards: Set[str],
  334. plus_wildcards: Set[str],
  335. plus_state_keys: Set[StateKey],
  336. ) -> "StateFilter":
  337. """
  338. Recomposes a state filter from 4 parts.
  339. See `decompose_into_four_parts` (the other direction of this
  340. correspondence) for descriptions on each of the parts.
  341. """
  342. # {state type -> set of state keys OR None for wildcard}
  343. # (The same structure as that of a StateFilter.)
  344. new_types: Dict[str, Optional[Set[str]]] = {}
  345. # if we start with all, insert the excluded statetypes as empty sets
  346. # to prevent them from being included
  347. if all_part:
  348. new_types.update({state_type: set() for state_type in minus_wildcards})
  349. # insert the plus wildcards
  350. new_types.update({state_type: None for state_type in plus_wildcards})
  351. # insert the specific state keys
  352. for state_type, state_key in plus_state_keys:
  353. if state_type in new_types:
  354. entry = new_types[state_type]
  355. if entry is not None:
  356. entry.add(state_key)
  357. elif not all_part:
  358. # don't insert if the entire type is already included by
  359. # include_others as this would actually shrink the state allowed
  360. # by this filter.
  361. new_types[state_type] = {state_key}
  362. return StateFilter.freeze(new_types, include_others=all_part)
  363. def approx_difference(self, other: "StateFilter") -> "StateFilter":
  364. """
  365. Returns a state filter which represents `self - other`.
  366. This is useful for determining what state remains to be pulled out of the
  367. database if we want the state included by `self` but already have the state
  368. included by `other`.
  369. The returned state filter
  370. - MUST include all state events that are included by this filter (`self`)
  371. unless they are included by `other`;
  372. - MUST NOT include state events not included by this filter (`self`); and
  373. - MAY be an over-approximation: the returned state filter
  374. MAY additionally include some state events from `other`.
  375. This implementation attempts to return the narrowest such state filter.
  376. In the case that `self` contains wildcards for state types where
  377. `other` contains specific state keys, an approximation must be made:
  378. the returned state filter keeps the wildcard, as state filters are not
  379. able to express 'all state keys except some given examples'.
  380. e.g.
  381. StateFilter(m.room.member -> None (wildcard))
  382. minus
  383. StateFilter(m.room.member -> {'@wombat:example.org'})
  384. is approximated as
  385. StateFilter(m.room.member -> None (wildcard))
  386. """
  387. # We first transform self and other into an alternative representation:
  388. # - whether or not they include all events to begin with ('all')
  389. # - if so, which event types are excluded? ('excludes')
  390. # - which entire event types to include ('wildcards')
  391. # - which concrete state keys to include ('concrete state keys')
  392. (self_all, self_excludes), (
  393. self_wildcards,
  394. self_concrete_keys,
  395. ) = self._decompose_into_four_parts()
  396. (other_all, other_excludes), (
  397. other_wildcards,
  398. other_concrete_keys,
  399. ) = other._decompose_into_four_parts()
  400. # Start with an estimate of the difference based on self
  401. new_all = self_all
  402. # Wildcards from the other can be added to the exclusion filter
  403. new_excludes = self_excludes | other_wildcards
  404. # We remove wildcards that appeared as wildcards in the other
  405. new_wildcards = self_wildcards - other_wildcards
  406. # We filter out the concrete state keys that appear in the other
  407. # as wildcards or concrete state keys.
  408. new_concrete_keys = {
  409. (state_type, state_key)
  410. for (state_type, state_key) in self_concrete_keys
  411. if state_type not in other_wildcards
  412. } - other_concrete_keys
  413. if other_all:
  414. if self_all:
  415. # If self starts with all, then we add as wildcards any
  416. # types which appear in the other's exclusion filter (but
  417. # aren't in the self exclusion filter). This is as the other
  418. # filter will return everything BUT the types in its exclusion, so
  419. # we need to add those excluded types that also match the self
  420. # filter as wildcard types in the new filter.
  421. new_wildcards |= other_excludes.difference(self_excludes)
  422. # If other is an `include_others` then the difference isn't.
  423. new_all = False
  424. # (We have no need for excludes when we don't start with all, as there
  425. # is nothing to exclude.)
  426. new_excludes = set()
  427. # We also filter out all state types that aren't in the exclusion
  428. # list of the other.
  429. new_wildcards &= other_excludes
  430. new_concrete_keys = {
  431. (state_type, state_key)
  432. for (state_type, state_key) in new_concrete_keys
  433. if state_type in other_excludes
  434. }
  435. # Transform our newly-constructed state filter from the alternative
  436. # representation back into the normal StateFilter representation.
  437. return StateFilter._recompose_from_four_parts(
  438. new_all, new_excludes, new_wildcards, new_concrete_keys
  439. )
  440. _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
  441. _ALL_NON_MEMBER_STATE_FILTER = StateFilter(
  442. types=frozendict({EventTypes.Member: frozenset()}), include_others=True
  443. )
  444. _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
  445. class StateGroupStorage:
  446. """High level interface to fetching state for event."""
  447. def __init__(self, hs: "HomeServer", stores: "Databases"):
  448. self.stores = stores
  449. self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
  450. def notify_event_un_partial_stated(self, event_id: str) -> None:
  451. self._partial_state_events_tracker.notify_un_partial_stated(event_id)
  452. async def get_state_group_delta(
  453. self, state_group: int
  454. ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
  455. """Given a state group try to return a previous group and a delta between
  456. the old and the new.
  457. Args:
  458. state_group: The state group used to retrieve state deltas.
  459. Returns:
  460. A tuple of the previous group and a state map of the event IDs which
  461. make up the delta between the old and new state groups.
  462. """
  463. state_group_delta = await self.stores.state.get_state_group_delta(state_group)
  464. return state_group_delta.prev_group, state_group_delta.delta_ids
  465. async def get_state_groups_ids(
  466. self, _room_id: str, event_ids: Collection[str]
  467. ) -> Dict[int, MutableStateMap[str]]:
  468. """Get the event IDs of all the state for the state groups for the given events
  469. Args:
  470. _room_id: id of the room for these events
  471. event_ids: ids of the events
  472. Returns:
  473. dict of state_group_id -> (dict of (type, state_key) -> event id)
  474. Raises:
  475. RuntimeError if we don't have a state group for one or more of the events
  476. (ie they are outliers or unknown)
  477. """
  478. if not event_ids:
  479. return {}
  480. event_to_groups = await self.get_state_group_for_events(event_ids)
  481. groups = set(event_to_groups.values())
  482. group_to_state = await self.stores.state._get_state_for_groups(groups)
  483. return group_to_state
  484. async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
  485. """Get the event IDs of all the state in the given state group
  486. Args:
  487. state_group: A state group for which we want to get the state IDs.
  488. Returns:
  489. Resolves to a map of (type, state_key) -> event_id
  490. """
  491. group_to_state = await self.get_state_for_groups((state_group,))
  492. return group_to_state[state_group]
  493. async def get_state_groups(
  494. self, room_id: str, event_ids: Collection[str]
  495. ) -> Dict[int, List[EventBase]]:
  496. """Get the state groups for the given list of event_ids
  497. Args:
  498. room_id: ID of the room for these events.
  499. event_ids: The event IDs to retrieve state for.
  500. Returns:
  501. dict of state_group_id -> list of state events.
  502. """
  503. if not event_ids:
  504. return {}
  505. group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
  506. state_event_map = await self.stores.main.get_events(
  507. [
  508. ev_id
  509. for group_ids in group_to_ids.values()
  510. for ev_id in group_ids.values()
  511. ],
  512. get_prev_content=False,
  513. )
  514. return {
  515. group: [
  516. state_event_map[v]
  517. for v in event_id_map.values()
  518. if v in state_event_map
  519. ]
  520. for group, event_id_map in group_to_ids.items()
  521. }
  522. def _get_state_groups_from_groups(
  523. self, groups: List[int], state_filter: StateFilter
  524. ) -> Awaitable[Dict[int, StateMap[str]]]:
  525. """Returns the state groups for a given set of groups, filtering on
  526. types of state events.
  527. Args:
  528. groups: list of state group IDs to query
  529. state_filter: The state filter used to fetch state
  530. from the database.
  531. Returns:
  532. Dict of state group to state map.
  533. """
  534. return self.stores.state._get_state_groups_from_groups(groups, state_filter)
  535. async def get_state_for_events(
  536. self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
  537. ) -> Dict[str, StateMap[EventBase]]:
  538. """Given a list of event_ids and type tuples, return a list of state
  539. dicts for each event.
  540. Args:
  541. event_ids: The events to fetch the state of.
  542. state_filter: The state filter used to fetch state.
  543. Returns:
  544. A dict of (event_id) -> (type, state_key) -> [state_events]
  545. Raises:
  546. RuntimeError if we don't have a state group for one or more of the events
  547. (ie they are outliers or unknown)
  548. """
  549. event_to_groups = await self.get_state_group_for_events(event_ids)
  550. groups = set(event_to_groups.values())
  551. group_to_state = await self.stores.state._get_state_for_groups(
  552. groups, state_filter or StateFilter.all()
  553. )
  554. state_event_map = await self.stores.main.get_events(
  555. [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
  556. get_prev_content=False,
  557. )
  558. event_to_state = {
  559. event_id: {
  560. k: state_event_map[v]
  561. for k, v in group_to_state[group].items()
  562. if v in state_event_map
  563. }
  564. for event_id, group in event_to_groups.items()
  565. }
  566. return {event: event_to_state[event] for event in event_ids}
  567. async def get_state_ids_for_events(
  568. self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
  569. ) -> Dict[str, StateMap[str]]:
  570. """
  571. Get the state dicts corresponding to a list of events, containing the event_ids
  572. of the state events (as opposed to the events themselves)
  573. Args:
  574. event_ids: events whose state should be returned
  575. state_filter: The state filter used to fetch state from the database.
  576. Returns:
  577. A dict from event_id -> (type, state_key) -> event_id
  578. Raises:
  579. RuntimeError if we don't have a state group for one or more of the events
  580. (ie they are outliers or unknown)
  581. """
  582. event_to_groups = await self.get_state_group_for_events(event_ids)
  583. groups = set(event_to_groups.values())
  584. group_to_state = await self.stores.state._get_state_for_groups(
  585. groups, state_filter or StateFilter.all()
  586. )
  587. event_to_state = {
  588. event_id: group_to_state[group]
  589. for event_id, group in event_to_groups.items()
  590. }
  591. return {event: event_to_state[event] for event in event_ids}
  592. async def get_state_for_event(
  593. self, event_id: str, state_filter: Optional[StateFilter] = None
  594. ) -> StateMap[EventBase]:
  595. """
  596. Get the state dict corresponding to a particular event
  597. Args:
  598. event_id: event whose state should be returned
  599. state_filter: The state filter used to fetch state from the database.
  600. Returns:
  601. A dict from (type, state_key) -> state_event
  602. Raises:
  603. RuntimeError if we don't have a state group for the event (ie it is an
  604. outlier or is unknown)
  605. """
  606. state_map = await self.get_state_for_events(
  607. [event_id], state_filter or StateFilter.all()
  608. )
  609. return state_map[event_id]
  610. async def get_state_ids_for_event(
  611. self, event_id: str, state_filter: Optional[StateFilter] = None
  612. ) -> StateMap[str]:
  613. """
  614. Get the state dict corresponding to a particular event
  615. Args:
  616. event_id: event whose state should be returned
  617. state_filter: The state filter used to fetch state from the database.
  618. Returns:
  619. A dict from (type, state_key) -> state_event_id
  620. Raises:
  621. RuntimeError if we don't have a state group for the event (ie it is an
  622. outlier or is unknown)
  623. """
  624. state_map = await self.get_state_ids_for_events(
  625. [event_id], state_filter or StateFilter.all()
  626. )
  627. return state_map[event_id]
  628. def get_state_for_groups(
  629. self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
  630. ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
  631. """Gets the state at each of a list of state groups, optionally
  632. filtering by type/state_key
  633. Args:
  634. groups: list of state groups for which we want to get the state.
  635. state_filter: The state filter used to fetch state.
  636. from the database.
  637. Returns:
  638. Dict of state group to state map.
  639. """
  640. return self.stores.state._get_state_for_groups(
  641. groups, state_filter or StateFilter.all()
  642. )
  643. async def get_state_group_for_events(
  644. self,
  645. event_ids: Collection[str],
  646. await_full_state: bool = True,
  647. ) -> Mapping[str, int]:
  648. """Returns mapping event_id -> state_group
  649. Args:
  650. event_ids: events to get state groups for
  651. await_full_state: if true, will block if we do not yet have complete
  652. state at this event.
  653. """
  654. if await_full_state:
  655. await self._partial_state_events_tracker.await_full_state(event_ids)
  656. return await self.stores.main._get_state_group_for_events(event_ids)
  657. async def store_state_group(
  658. self,
  659. event_id: str,
  660. room_id: str,
  661. prev_group: Optional[int],
  662. delta_ids: Optional[StateMap[str]],
  663. current_state_ids: StateMap[str],
  664. ) -> int:
  665. """Store a new set of state, returning a newly assigned state group.
  666. Args:
  667. event_id: The event ID for which the state was calculated.
  668. room_id: ID of the room for which the state was calculated.
  669. prev_group: A previous state group for the room, optional.
  670. delta_ids: The delta between state at `prev_group` and
  671. `current_state_ids`, if `prev_group` was given. Same format as
  672. `current_state_ids`.
  673. current_state_ids: The state to store. Map of (type, state_key)
  674. to event_id.
  675. Returns:
  676. The state group ID
  677. """
  678. return await self.stores.state.store_state_group(
  679. event_id, room_id, prev_group, delta_ids, current_state_ids
  680. )