state.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  import logging
  from typing import Iterable, List, Optional, TypeVar

  from six import iteritems, itervalues

  import attr

  from twisted.internet import defer

  from synapse.api.constants import EventTypes
  from synapse.types import StateMap
  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)

  # Used for generic functions below: e.g. `StateFilter.filter_state` accepts a
  # `StateMap[T]` and returns a `StateMap[T]` with the same value type.
  T = TypeVar("T")
  @attr.s(slots=True)
  class StateFilter(object):
      """A filter used when querying for state.

      Attributes:
          types (dict[str, set[str]|None]): Map from type to set of state keys (or
              None). This specifies which state_keys for the given type to fetch
              from the DB. If None then all events with that type are fetched. If
              the set is empty then no events with that type are fetched.
          include_others (bool): Whether to fetch events with types that do not
              appear in `types`.
      """

      types = attr.ib()
      include_others = attr.ib(default=False)

      def __attrs_post_init__(self):
          # If `include_others` is set we canonicalise the filter by removing
          # wildcards from the types dictionary. A `None` entry is redundant in
          # that case: once the type is absent from `types`, `include_others`
          # already matches all of its events.
          if self.include_others:
              self.types = {k: v for k, v in iteritems(self.types) if v is not None}

      @staticmethod
      def all():
          """Creates a filter that fetches everything.

          Returns:
              StateFilter
          """
          return StateFilter(types={}, include_others=True)

      @staticmethod
      def none():
          """Creates a filter that fetches nothing.

          Returns:
              StateFilter
          """
          return StateFilter(types={}, include_others=False)

      @staticmethod
      def from_types(types):
          """Creates a filter that only fetches the given types.

          Args:
              types (Iterable[tuple[str, str|None]]): A list of type and state
                  keys to fetch. A state_key of None fetches everything for
                  that type.

          Returns:
              StateFilter
          """
          type_dict = {}
          for typ, s in types:
              if typ in type_dict:
                  if type_dict[typ] is None:
                      # A wildcard for this type was already seen; it subsumes
                      # any specific state key.
                      continue

              if s is None:
                  # Wildcard: replaces any specific keys collected so far.
                  type_dict[typ] = None
                  continue

              type_dict.setdefault(typ, set()).add(s)

          return StateFilter(types=type_dict)

      @staticmethod
      def from_lazy_load_member_list(members):
          """Creates a filter that returns all non-member events, plus the member
          events for the given users.

          Args:
              members (iterable[str]): Set of user IDs

          Returns:
              StateFilter
          """
          return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)

      def return_expanded(self):
          """Creates a new StateFilter where type wild cards have been removed
          (except for memberships). The returned filter is a superset of the
          current one, i.e. anything that passes the current filter will pass
          the returned filter.

          This helps the caching as the DictionaryCache knows if it has *all* the
          state, but does not know if it has all of the keys of a particular type,
          which makes wildcard lookups expensive unless we have a complete cache.
          Hence, if we are doing a wildcard lookup, populate the cache fully so
          that we can do an efficient lookup next time.

          Note that since we have two caches, one for membership events and one for
          other events, we can be a bit more clever than simply returning
          `StateFilter.all()` if `has_wildcards()` is True.

          We return a StateFilter where:
              1. the list of membership events to return is the same
              2. if there is a wildcard that matches non-member events we
                 return all non-member events

          Returns:
              StateFilter
          """
          if self.is_full():
              # If we're going to return everything then there's nothing to do
              return self

          if not self.has_wildcards():
              # If there are no wild cards, there's nothing to do
              return self

          if EventTypes.Member in self.types:
              get_all_members = self.types[EventTypes.Member] is None
          else:
              get_all_members = self.include_others

          has_non_member_wildcard = self.include_others or any(
              state_keys is None
              for t, state_keys in iteritems(self.types)
              if t != EventTypes.Member
          )

          if not has_non_member_wildcard:
              # If there are no non-member wild cards we can just return ourselves
              return self

          if get_all_members:
              # We want to return everything.
              return StateFilter.all()

          if EventTypes.Member in self.types:
              # We want to return all non-members, but only particular
              # memberships
              return StateFilter(
                  types={EventTypes.Member: self.types[EventTypes.Member]},
                  include_others=True,
              )

          # No membership events were requested at all (`include_others` must be
          # False here, otherwise `get_all_members` would be True). Return all
          # non-member events and no member events; indexing
          # `self.types[EventTypes.Member]` here would raise KeyError.
          return StateFilter(types={EventTypes.Member: set()}, include_others=True)

      def make_sql_filter_clause(self):
          """Converts the filter to an SQL clause.

          For example:

              f = StateFilter.from_types([("m.room.create", "")])
              clause, args = f.make_sql_filter_clause()
              clause == "(type = ? AND state_key = ?)"
              args == ['m.room.create', '']

          Returns:
              tuple[str, list]: The SQL string (may be empty) and arguments. An
              empty SQL string is returned when the filter matches everything
              (i.e. is "full"); the clause "1 = 2" is returned when the filter
              can match nothing.
          """
          where_clause = ""
          where_args = []

          if self.is_full():
              return where_clause, where_args

          # First we build up a list of clauses for each type/state_key combo
          clauses = []
          for etype, state_keys in iteritems(self.types):
              if state_keys is None:
                  clauses.append("(type = ?)")
                  where_args.append(etype)
                  continue

              for state_key in state_keys:
                  clauses.append("(type = ? AND state_key = ?)")
                  where_args.extend((etype, state_key))

          if not clauses and not self.include_others:
              # Nothing can match: either `types` is empty or every type maps to
              # an empty set of state keys. We must not fall through and return
              # an empty clause, since that means "match everything".
              return "1 = 2", []

          # This will match anything that appears in `self.types`
          where_clause = " OR ".join(clauses)

          # If we want to include stuff that's not in the types dict then we add
          # a `OR type NOT IN (...)` clause to the end.
          if self.include_others:
              if where_clause:
                  where_clause += " OR "

              where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
              where_args.extend(self.types)

          return where_clause, where_args

      def max_entries_returned(self):
          """Returns the maximum number of entries this filter will return if
          known, otherwise returns None.

          For example a simple state filter asking for `("m.room.create", "")`
          will return 1, whereas the default state filter will return None.

          This is used to bail out early if the right number of entries have been
          fetched.

          Returns:
              int|None
          """
          if self.has_wildcards():
              return None

          return len(self.concrete_types())

      def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
          """Returns the state filtered by this StateFilter.

          Args:
              state_dict: The state map to filter

          Returns:
              The filtered state map (always a new dict)
          """
          if self.is_full():
              return dict(state_dict)

          filtered_state = {}
          for k, v in iteritems(state_dict):
              typ, state_key = k
              if typ in self.types:
                  state_keys = self.types[typ]
                  if state_keys is None or state_key in state_keys:
                      filtered_state[k] = v
              elif self.include_others:
                  filtered_state[k] = v

          return filtered_state

      def is_full(self):
          """Whether this filter fetches everything or not.

          Returns:
              bool
          """
          return self.include_others and not self.types

      def has_wildcards(self):
          """Whether the filter includes wildcards or is attempting to fetch
          specific state.

          Returns:
              bool
          """
          return self.include_others or any(
              state_keys is None for state_keys in itervalues(self.types)
          )

      def concrete_types(self):
          """Returns a list of concrete type/state_keys (i.e. not None) that
          will be fetched. This will be a complete list if `has_wildcards`
          returns False, but otherwise will be a subset (or even empty).

          Returns:
              list[tuple[str,str]]
          """
          return [
              (t, s)
              for t, state_keys in iteritems(self.types)
              if state_keys is not None
              for s in state_keys
          ]

      def get_member_split(self):
          """Return the filter split into two: one which assumes it's exclusively
          matching against member state, and one which assumes it's matching
          against non member state.

          This is useful due to the returned filters giving correct results for
          `is_full()`, `has_wildcards()`, etc, when operating against maps that
          either exclusively contain member events or only contain non-member
          events. (Which is the case when dealing with the member vs non-member
          state caches).

          Returns:
              tuple[StateFilter, StateFilter]: The member and non member filters
          """
          if EventTypes.Member in self.types:
              state_keys = self.types[EventTypes.Member]
              if state_keys is None:
                  member_filter = StateFilter.all()
              else:
                  member_filter = StateFilter({EventTypes.Member: state_keys})
          elif self.include_others:
              member_filter = StateFilter.all()
          else:
              member_filter = StateFilter.none()

          non_member_filter = StateFilter(
              types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member},
              include_others=self.include_others,
          )

          return member_filter, non_member_filter
  class StateGroupStorage(object):
      """High level interface to fetching state for event.

      Args:
          hs: currently unused.
          stores: object exposing `main` and `state` data stores; every query
              made by this class is delegated to one of them.
      """

      def __init__(self, hs, stores):
          self.stores = stores

      def get_state_group_delta(self, state_group: int):
          """Given a state group try to return a previous group and a delta between
          the old and the new.

          Args:
              state_group: The state group to look up.

          Returns:
              Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
                  (prev_group, delta_ids)
          """
          return self.stores.state.get_state_group_delta(state_group)

      @defer.inlineCallbacks
      def get_state_groups_ids(self, _room_id, event_ids):
          """Get the event IDs of all the state for the state groups for the given events

          Args:
              _room_id (str): id of the room for these events (not used here)
              event_ids (iterable[str]): ids of the events

          Returns:
              Deferred[dict[int, StateMap[str]]]:
                  dict of state_group_id -> (dict of (type, state_key) -> event id)
          """
          if not event_ids:
              return {}

          event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)

          groups = set(itervalues(event_to_groups))
          group_to_state = yield self.stores.state._get_state_for_groups(groups)

          return group_to_state

      @defer.inlineCallbacks
      def get_state_ids_for_group(self, state_group):
          """Get the event IDs of all the state in the given state group

          Args:
              state_group (int)

          Returns:
              Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
          """
          group_to_state = yield self._get_state_for_groups((state_group,))

          return group_to_state[state_group]

      @defer.inlineCallbacks
      def get_state_groups(self, room_id, event_ids):
          """Get the state groups for the given list of event_ids

          Args:
              room_id (str): id of the room for these events
              event_ids (iterable[str]): ids of the events

          Returns:
              Deferred[dict[int, list[EventBase]]]:
                  dict of state_group_id -> list of state events.
          """
          if not event_ids:
              return {}

          group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)

          state_event_map = yield self.stores.main.get_events(
              [
                  ev_id
                  for group_ids in itervalues(group_to_ids)
                  for ev_id in itervalues(group_ids)
              ],
              get_prev_content=False,
          )

          # Event IDs that are missing from `state_event_map` are skipped.
          return {
              group: [
                  state_event_map[v]
                  for v in itervalues(event_id_map)
                  if v in state_event_map
              ]
              for group, event_id_map in iteritems(group_to_ids)
          }

      def _get_state_groups_from_groups(
          self, groups: List[int], state_filter: StateFilter
      ):
          """Returns the state groups for a given set of groups, filtering on
          types of state events.

          Args:
              groups: list of state group IDs to query
              state_filter: The state filter used to fetch state
                  from the database.

          Returns:
              Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
          """
          return self.stores.state._get_state_groups_from_groups(groups, state_filter)

      @defer.inlineCallbacks
      def get_state_for_events(self, event_ids, state_filter=None):
          """Given a list of event_ids and a state filter, return a state dict
          for each event.

          Args:
              event_ids (list[str])
              state_filter (StateFilter|None): The state filter used to fetch
                  state from the database. Defaults to `StateFilter.all()`.

          Returns:
              Deferred[dict[str, StateMap[EventBase]]]: A dict of
                  event_id -> (type, state_key) -> state event
          """
          # A fresh filter per call avoids sharing one mutable default instance.
          if state_filter is None:
              state_filter = StateFilter.all()

          event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)

          groups = set(itervalues(event_to_groups))
          group_to_state = yield self.stores.state._get_state_for_groups(
              groups, state_filter
          )

          state_event_map = yield self.stores.main.get_events(
              [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
              get_prev_content=False,
          )

          # Event IDs that are missing from `state_event_map` are skipped.
          event_to_state = {
              event_id: {
                  k: state_event_map[v]
                  for k, v in iteritems(group_to_state[group])
                  if v in state_event_map
              }
              for event_id, group in iteritems(event_to_groups)
          }

          return {event: event_to_state[event] for event in event_ids}

      @defer.inlineCallbacks
      def get_state_ids_for_events(self, event_ids, state_filter=None):
          """Get the state dicts corresponding to a list of events, containing
          the event_ids of the state events (as opposed to the events themselves)

          Args:
              event_ids (list[str]): events whose state should be returned
              state_filter (StateFilter|None): The state filter used to fetch
                  state from the database. Defaults to `StateFilter.all()`.

          Returns:
              A deferred dict from event_id -> (type, state_key) -> event_id
          """
          if state_filter is None:
              state_filter = StateFilter.all()

          event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)

          groups = set(itervalues(event_to_groups))
          group_to_state = yield self.stores.state._get_state_for_groups(
              groups, state_filter
          )

          event_to_state = {
              event_id: group_to_state[group]
              for event_id, group in iteritems(event_to_groups)
          }

          return {event: event_to_state[event] for event in event_ids}

      @defer.inlineCallbacks
      def get_state_for_event(self, event_id, state_filter=None):
          """Get the state dict corresponding to a particular event

          Args:
              event_id (str): event whose state should be returned
              state_filter (StateFilter|None): The state filter used to fetch
                  state from the database. Defaults to `StateFilter.all()`.

          Returns:
              A deferred dict from (type, state_key) -> state_event
          """
          state_map = yield self.get_state_for_events([event_id], state_filter)
          return state_map[event_id]

      @defer.inlineCallbacks
      def get_state_ids_for_event(self, event_id, state_filter=None):
          """Get the state dict corresponding to a particular event, containing
          the event_ids of the state events.

          Args:
              event_id (str): event whose state should be returned
              state_filter (StateFilter|None): The state filter used to fetch
                  state from the database. Defaults to `StateFilter.all()`.

          Returns:
              A deferred dict from (type, state_key) -> event_id
          """
          state_map = yield self.get_state_ids_for_events([event_id], state_filter)
          return state_map[event_id]

      def _get_state_for_groups(
          self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
      ):
          """Gets the state at each of a list of state groups, optionally
          filtering by type/state_key

          Args:
              groups (iterable[int]): list of state groups for which we want
                  to get the state.
              state_filter (StateFilter|None): The state filter used to fetch
                  state from the database. Defaults to `StateFilter.all()`.

          Returns:
              Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
          """
          if state_filter is None:
              state_filter = StateFilter.all()

          return self.stores.state._get_state_for_groups(groups, state_filter)

      def store_state_group(
          self, event_id, room_id, prev_group, delta_ids, current_state_ids
      ):
          """Store a new set of state, returning a newly assigned state group.

          Args:
              event_id (str): The event ID for which the state was calculated
              room_id (str)
              prev_group (int|None): A previous state group for the room, optional.
              delta_ids (dict|None): The delta between state at `prev_group` and
                  `current_state_ids`, if `prev_group` was given. Same format as
                  `current_state_ids`.
              current_state_ids (dict): The state to store. Map of (type, state_key)
                  to event_id.

          Returns:
              Deferred[int]: The state group ID
          """
          return self.stores.state.store_state_group(
              event_id, room_id, prev_group, delta_ids, current_state_ids
          )