state.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587
  1. # Copyright 2014-2016 OpenMarket Ltd
  2. # Copyright 2022 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from typing import (
  17. TYPE_CHECKING,
  18. Callable,
  19. Collection,
  20. Dict,
  21. Iterable,
  22. List,
  23. Mapping,
  24. Optional,
  25. Set,
  26. Tuple,
  27. TypeVar,
  28. )
  29. import attr
  30. from immutabledict import immutabledict
  31. from synapse.api.constants import EventTypes
  32. from synapse.types import MutableStateMap, StateKey, StateMap
  33. if TYPE_CHECKING:
  34. from typing import FrozenSet # noqa: used within quoted type hint; flake8 sad
  35. logger = logging.getLogger(__name__)
  36. # Used for generic functions below
  37. T = TypeVar("T")
  38. @attr.s(slots=True, frozen=True, auto_attribs=True)
  39. class StateFilter:
  40. """A filter used when querying for state.
  41. Attributes:
  42. types: Map from type to set of state keys (or None). This specifies
  43. which state_keys for the given type to fetch from the DB. If None
  44. then all events with that type are fetched. If the set is empty
  45. then no events with that type are fetched.
  46. include_others: Whether to fetch events with types that do not
  47. appear in `types`.
  48. """
  49. types: "immutabledict[str, Optional[FrozenSet[str]]]"
  50. include_others: bool = False
  51. def __attrs_post_init__(self) -> None:
  52. # If `include_others` is set we canonicalise the filter by removing
  53. # wildcards from the types dictionary
  54. if self.include_others:
  55. # this is needed to work around the fact that StateFilter is frozen
  56. object.__setattr__(
  57. self,
  58. "types",
  59. immutabledict({k: v for k, v in self.types.items() if v is not None}),
  60. )
  61. @staticmethod
  62. def all() -> "StateFilter":
  63. """Returns a filter that fetches everything.
  64. Returns:
  65. The state filter.
  66. """
  67. return _ALL_STATE_FILTER
  68. @staticmethod
  69. def none() -> "StateFilter":
  70. """Returns a filter that fetches nothing.
  71. Returns:
  72. The new state filter.
  73. """
  74. return _NONE_STATE_FILTER
  75. @staticmethod
  76. def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
  77. """Creates a filter that only fetches the given types
  78. Args:
  79. types: A list of type and state keys to fetch. A state_key of None
  80. fetches everything for that type
  81. Returns:
  82. The new state filter.
  83. """
  84. type_dict: Dict[str, Optional[Set[str]]] = {}
  85. for typ, s in types:
  86. if typ in type_dict:
  87. if type_dict[typ] is None:
  88. continue
  89. if s is None:
  90. type_dict[typ] = None
  91. continue
  92. type_dict.setdefault(typ, set()).add(s) # type: ignore
  93. return StateFilter(
  94. types=immutabledict(
  95. (k, frozenset(v) if v is not None else None)
  96. for k, v in type_dict.items()
  97. )
  98. )
  99. def to_types(self) -> Iterable[Tuple[str, Optional[str]]]:
  100. """The inverse to `from_types`."""
  101. for event_type, state_keys in self.types.items():
  102. if state_keys is None:
  103. yield event_type, None
  104. else:
  105. for state_key in state_keys:
  106. yield event_type, state_key
  107. @staticmethod
  108. def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
  109. """Creates a filter that returns all non-member events, plus the member
  110. events for the given users
  111. Args:
  112. members: Set of user IDs
  113. Returns:
  114. The new state filter
  115. """
  116. return StateFilter(
  117. types=immutabledict({EventTypes.Member: frozenset(members)}),
  118. include_others=True,
  119. )
  120. @staticmethod
  121. def freeze(
  122. types: Mapping[str, Optional[Collection[str]]], include_others: bool
  123. ) -> "StateFilter":
  124. """
  125. Returns a (frozen) StateFilter with the same contents as the parameters
  126. specified here, which can be made of mutable types.
  127. """
  128. types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {}
  129. for state_types, state_keys in types.items():
  130. if state_keys is not None:
  131. types_with_frozen_values[state_types] = frozenset(state_keys)
  132. else:
  133. types_with_frozen_values[state_types] = None
  134. return StateFilter(
  135. immutabledict(types_with_frozen_values), include_others=include_others
  136. )
  137. def return_expanded(self) -> "StateFilter":
  138. """Creates a new StateFilter where type wild cards have been removed
  139. (except for memberships). The returned filter is a superset of the
  140. current one, i.e. anything that passes the current filter will pass
  141. the returned filter.
  142. This helps the caching as the DictionaryCache knows if it has *all* the
  143. state, but does not know if it has all of the keys of a particular type,
  144. which makes wildcard lookups expensive unless we have a complete cache.
  145. Hence, if we are doing a wildcard lookup, populate the cache fully so
  146. that we can do an efficient lookup next time.
  147. Note that since we have two caches, one for membership events and one for
  148. other events, we can be a bit more clever than simply returning
  149. `StateFilter.all()` if `has_wildcards()` is True.
  150. We return a StateFilter where:
  151. 1. the list of membership events to return is the same
  152. 2. if there is a wildcard that matches non-member events we
  153. return all non-member events
  154. Returns:
  155. The new state filter.
  156. """
  157. if self.is_full():
  158. # If we're going to return everything then there's nothing to do
  159. return self
  160. if not self.has_wildcards():
  161. # If there are no wild cards, there's nothing to do
  162. return self
  163. if EventTypes.Member in self.types:
  164. get_all_members = self.types[EventTypes.Member] is None
  165. else:
  166. get_all_members = self.include_others
  167. has_non_member_wildcard = self.include_others or any(
  168. state_keys is None
  169. for t, state_keys in self.types.items()
  170. if t != EventTypes.Member
  171. )
  172. if not has_non_member_wildcard:
  173. # If there are no non-member wild cards we can just return ourselves
  174. return self
  175. if get_all_members:
  176. # We want to return everything.
  177. return StateFilter.all()
  178. elif EventTypes.Member in self.types:
  179. # We want to return all non-members, but only particular
  180. # memberships
  181. return StateFilter(
  182. types=immutabledict({EventTypes.Member: self.types[EventTypes.Member]}),
  183. include_others=True,
  184. )
  185. else:
  186. # We want to return all non-members
  187. return _ALL_NON_MEMBER_STATE_FILTER
  188. def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
  189. """Converts the filter to an SQL clause.
  190. For example:
  191. f = StateFilter.from_types([("m.room.create", "")])
  192. clause, args = f.make_sql_filter_clause()
  193. clause == "(type = ? AND state_key = ?)"
  194. args == ['m.room.create', '']
  195. Returns:
  196. The SQL string (may be empty) and arguments. An empty SQL string is
  197. returned when the filter matches everything (i.e. is "full").
  198. """
  199. where_clause = ""
  200. where_args: List[str] = []
  201. if self.is_full():
  202. return where_clause, where_args
  203. if not self.include_others and not self.types:
  204. # i.e. this is an empty filter, so we need to return a clause that
  205. # will match nothing
  206. return "1 = 2", []
  207. # First we build up a lost of clauses for each type/state_key combo
  208. clauses = []
  209. for etype, state_keys in self.types.items():
  210. if state_keys is None:
  211. clauses.append("(type = ?)")
  212. where_args.append(etype)
  213. continue
  214. for state_key in state_keys:
  215. clauses.append("(type = ? AND state_key = ?)")
  216. where_args.extend((etype, state_key))
  217. # This will match anything that appears in `self.types`
  218. where_clause = " OR ".join(clauses)
  219. # If we want to include stuff that's not in the types dict then we add
  220. # a `OR type NOT IN (...)` clause to the end.
  221. if self.include_others:
  222. if where_clause:
  223. where_clause += " OR "
  224. where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
  225. where_args.extend(self.types)
  226. return where_clause, where_args
  227. def max_entries_returned(self) -> Optional[int]:
  228. """Returns the maximum number of entries this filter will return if
  229. known, otherwise returns None.
  230. For example a simple state filter asking for `("m.room.create", "")`
  231. will return 1, whereas the default state filter will return None.
  232. This is used to bail out early if the right number of entries have been
  233. fetched.
  234. """
  235. if self.has_wildcards():
  236. return None
  237. return len(self.concrete_types())
  238. def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
  239. """Returns the state filtered with by this StateFilter.
  240. Args:
  241. state: The state map to filter
  242. Returns:
  243. The filtered state map.
  244. This is a copy, so it's safe to mutate.
  245. """
  246. if self.is_full():
  247. return dict(state_dict)
  248. filtered_state = {}
  249. for k, v in state_dict.items():
  250. typ, state_key = k
  251. if typ in self.types:
  252. state_keys = self.types[typ]
  253. if state_keys is None or state_key in state_keys:
  254. filtered_state[k] = v
  255. elif self.include_others:
  256. filtered_state[k] = v
  257. return filtered_state
  258. def is_full(self) -> bool:
  259. """Whether this filter fetches everything or not
  260. Returns:
  261. True if the filter fetches everything.
  262. """
  263. return self.include_others and not self.types
  264. def has_wildcards(self) -> bool:
  265. """Whether the filter includes wildcards or is attempting to fetch
  266. specific state.
  267. Returns:
  268. True if the filter includes wildcards.
  269. """
  270. return self.include_others or any(
  271. state_keys is None for state_keys in self.types.values()
  272. )
  273. def concrete_types(self) -> List[Tuple[str, str]]:
  274. """Returns a list of concrete type/state_keys (i.e. not None) that
  275. will be fetched. This will be a complete list if `has_wildcards`
  276. returns False, but otherwise will be a subset (or even empty).
  277. Returns:
  278. A list of type/state_keys tuples.
  279. """
  280. return [
  281. (t, s)
  282. for t, state_keys in self.types.items()
  283. if state_keys is not None
  284. for s in state_keys
  285. ]
  286. def wildcard_types(self) -> List[str]:
  287. """Returns a list of event types which require us to fetch all state keys.
  288. This will be empty unless `has_wildcards` returns True.
  289. Returns:
  290. A list of event types.
  291. """
  292. return [t for t, state_keys in self.types.items() if state_keys is None]
  293. def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
  294. """Return the filter split into two: one which assumes it's exclusively
  295. matching against member state, and one which assumes it's matching
  296. against non member state.
  297. This is useful due to the returned filters giving correct results for
  298. `is_full()`, `has_wildcards()`, etc, when operating against maps that
  299. either exclusively contain member events or only contain non-member
  300. events. (Which is the case when dealing with the member vs non-member
  301. state caches).
  302. Returns:
  303. The member and non member filters
  304. """
  305. if EventTypes.Member in self.types:
  306. state_keys = self.types[EventTypes.Member]
  307. if state_keys is None:
  308. member_filter = StateFilter.all()
  309. else:
  310. member_filter = StateFilter(
  311. immutabledict({EventTypes.Member: state_keys})
  312. )
  313. elif self.include_others:
  314. member_filter = StateFilter.all()
  315. else:
  316. member_filter = StateFilter.none()
  317. non_member_filter = StateFilter(
  318. types=immutabledict(
  319. {k: v for k, v in self.types.items() if k != EventTypes.Member}
  320. ),
  321. include_others=self.include_others,
  322. )
  323. return member_filter, non_member_filter
  324. def _decompose_into_four_parts(
  325. self,
  326. ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]:
  327. """
  328. Decomposes this state filter into 4 constituent parts, which can be
  329. thought of as this:
  330. all? - minus_wildcards + plus_wildcards + plus_state_keys
  331. where
  332. * all represents ALL state
  333. * minus_wildcards represents entire state types to remove
  334. * plus_wildcards represents entire state types to add
  335. * plus_state_keys represents individual state keys to add
  336. See `recompose_from_four_parts` for the other direction of this
  337. correspondence.
  338. """
  339. is_all = self.include_others
  340. excluded_types: Set[str] = {t for t in self.types if is_all}
  341. wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None}
  342. concrete_keys: Set[StateKey] = set(self.concrete_types())
  343. return (is_all, excluded_types), (wildcard_types, concrete_keys)
  344. @staticmethod
  345. def _recompose_from_four_parts(
  346. all_part: bool,
  347. minus_wildcards: Set[str],
  348. plus_wildcards: Set[str],
  349. plus_state_keys: Set[StateKey],
  350. ) -> "StateFilter":
  351. """
  352. Recomposes a state filter from 4 parts.
  353. See `decompose_into_four_parts` (the other direction of this
  354. correspondence) for descriptions on each of the parts.
  355. """
  356. # {state type -> set of state keys OR None for wildcard}
  357. # (The same structure as that of a StateFilter.)
  358. new_types: Dict[str, Optional[Set[str]]] = {}
  359. # if we start with all, insert the excluded statetypes as empty sets
  360. # to prevent them from being included
  361. if all_part:
  362. new_types.update({state_type: set() for state_type in minus_wildcards})
  363. # insert the plus wildcards
  364. new_types.update({state_type: None for state_type in plus_wildcards})
  365. # insert the specific state keys
  366. for state_type, state_key in plus_state_keys:
  367. if state_type in new_types:
  368. entry = new_types[state_type]
  369. if entry is not None:
  370. entry.add(state_key)
  371. elif not all_part:
  372. # don't insert if the entire type is already included by
  373. # include_others as this would actually shrink the state allowed
  374. # by this filter.
  375. new_types[state_type] = {state_key}
  376. return StateFilter.freeze(new_types, include_others=all_part)
  377. def approx_difference(self, other: "StateFilter") -> "StateFilter":
  378. """
  379. Returns a state filter which represents `self - other`.
  380. This is useful for determining what state remains to be pulled out of the
  381. database if we want the state included by `self` but already have the state
  382. included by `other`.
  383. The returned state filter
  384. - MUST include all state events that are included by this filter (`self`)
  385. unless they are included by `other`;
  386. - MUST NOT include state events not included by this filter (`self`); and
  387. - MAY be an over-approximation: the returned state filter
  388. MAY additionally include some state events from `other`.
  389. This implementation attempts to return the narrowest such state filter.
  390. In the case that `self` contains wildcards for state types where
  391. `other` contains specific state keys, an approximation must be made:
  392. the returned state filter keeps the wildcard, as state filters are not
  393. able to express 'all state keys except some given examples'.
  394. e.g.
  395. StateFilter(m.room.member -> None (wildcard))
  396. minus
  397. StateFilter(m.room.member -> {'@wombat:example.org'})
  398. is approximated as
  399. StateFilter(m.room.member -> None (wildcard))
  400. """
  401. # We first transform self and other into an alternative representation:
  402. # - whether or not they include all events to begin with ('all')
  403. # - if so, which event types are excluded? ('excludes')
  404. # - which entire event types to include ('wildcards')
  405. # - which concrete state keys to include ('concrete state keys')
  406. (self_all, self_excludes), (
  407. self_wildcards,
  408. self_concrete_keys,
  409. ) = self._decompose_into_four_parts()
  410. (other_all, other_excludes), (
  411. other_wildcards,
  412. other_concrete_keys,
  413. ) = other._decompose_into_four_parts()
  414. # Start with an estimate of the difference based on self
  415. new_all = self_all
  416. # Wildcards from the other can be added to the exclusion filter
  417. new_excludes = self_excludes | other_wildcards
  418. # We remove wildcards that appeared as wildcards in the other
  419. new_wildcards = self_wildcards - other_wildcards
  420. # We filter out the concrete state keys that appear in the other
  421. # as wildcards or concrete state keys.
  422. new_concrete_keys = {
  423. (state_type, state_key)
  424. for (state_type, state_key) in self_concrete_keys
  425. if state_type not in other_wildcards
  426. } - other_concrete_keys
  427. if other_all:
  428. if self_all:
  429. # If self starts with all, then we add as wildcards any
  430. # types which appear in the other's exclusion filter (but
  431. # aren't in the self exclusion filter). This is as the other
  432. # filter will return everything BUT the types in its exclusion, so
  433. # we need to add those excluded types that also match the self
  434. # filter as wildcard types in the new filter.
  435. new_wildcards |= other_excludes.difference(self_excludes)
  436. # If other is an `include_others` then the difference isn't.
  437. new_all = False
  438. # (We have no need for excludes when we don't start with all, as there
  439. # is nothing to exclude.)
  440. new_excludes = set()
  441. # We also filter out all state types that aren't in the exclusion
  442. # list of the other.
  443. new_wildcards &= other_excludes
  444. new_concrete_keys = {
  445. (state_type, state_key)
  446. for (state_type, state_key) in new_concrete_keys
  447. if state_type in other_excludes
  448. }
  449. # Transform our newly-constructed state filter from the alternative
  450. # representation back into the normal StateFilter representation.
  451. return StateFilter._recompose_from_four_parts(
  452. new_all, new_excludes, new_wildcards, new_concrete_keys
  453. )
  454. def must_await_full_state(self, is_mine_id: Callable[[str], bool]) -> bool:
  455. """Check if we need to wait for full state to complete to calculate this state
  456. If we have a state filter which is completely satisfied even with partial
  457. state, then we don't need to await_full_state before we can return it.
  458. Args:
  459. is_mine_id: a callable which confirms if a given state_key matches a mxid
  460. of a local user
  461. """
  462. # if we haven't requested membership events, then it depends on the value of
  463. # 'include_others'
  464. if EventTypes.Member not in self.types:
  465. return self.include_others
  466. # if we're looking for *all* membership events, then we have to wait
  467. member_state_keys = self.types[EventTypes.Member]
  468. if member_state_keys is None:
  469. return True
  470. # otherwise, consider whose membership we are looking for. If it's entirely
  471. # local users, then we don't need to wait.
  472. for state_key in member_state_keys:
  473. if not is_mine_id(state_key):
  474. # remote user
  475. return True
  476. # local users only
  477. return False
  478. _ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True)
  479. _ALL_NON_MEMBER_STATE_FILTER = StateFilter(
  480. types=immutabledict({EventTypes.Member: frozenset()}), include_others=True
  481. )
  482. _NONE_STATE_FILTER = StateFilter(types=immutabledict(), include_others=False)