state.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014, 2015 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from twisted.internet import defer
  16. from synapse.util.logutils import log_function
  17. from synapse.util.async import run_on_reactor
  18. from synapse.util.caches.expiringcache import ExpiringCache
  19. from synapse.api.constants import EventTypes
  20. from synapse.api.errors import AuthError
  21. from synapse.api.auth import AuthEventTypes
  22. from synapse.events.snapshot import EventContext
  23. from collections import namedtuple
  24. import logging
  25. import hashlib
  26. logger = logging.getLogger(__name__)
  27. def _get_state_key_from_event(event):
  28. return event.state_key
  29. KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
  30. SIZE_OF_CACHE = 1000
  31. EVICTION_TIMEOUT_SECONDS = 20
  32. class _StateCacheEntry(object):
  33. def __init__(self, state, state_group, ts):
  34. self.state = state
  35. self.state_group = state_group
  36. class StateHandler(object):
  37. """ Responsible for doing state conflict resolution.
  38. """
  39. def __init__(self, hs):
  40. self.clock = hs.get_clock()
  41. self.store = hs.get_datastore()
  42. self.hs = hs
  43. # dict of set of event_ids -> _StateCacheEntry.
  44. self._state_cache = None
  45. def start_caching(self):
  46. logger.debug("start_caching")
  47. self._state_cache = ExpiringCache(
  48. cache_name="state_cache",
  49. clock=self.clock,
  50. max_len=SIZE_OF_CACHE,
  51. expiry_ms=EVICTION_TIMEOUT_SECONDS*1000,
  52. reset_expiry_on_get=True,
  53. )
  54. self._state_cache.start()
  55. @defer.inlineCallbacks
  56. def get_current_state(self, room_id, event_type=None, state_key=""):
  57. """ Returns the current state for the room as a list. This is done by
  58. calling `get_latest_events_in_room` to get the leading edges of the
  59. event graph and then resolving any of the state conflicts.
  60. This is equivalent to getting the state of an event that were to send
  61. next before receiving any new events.
  62. If `event_type` is specified, then the method returns only the one
  63. event (or None) with that `event_type` and `state_key`.
  64. """
  65. event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
  66. cache = None
  67. if self._state_cache is not None:
  68. cache = self._state_cache.get(frozenset(event_ids), None)
  69. if cache:
  70. cache.ts = self.clock.time_msec()
  71. state = cache.state
  72. else:
  73. res = yield self.resolve_state_groups(room_id, event_ids)
  74. state = res[1]
  75. if event_type:
  76. defer.returnValue(state.get((event_type, state_key)))
  77. return
  78. defer.returnValue(state)
  79. @defer.inlineCallbacks
  80. def compute_event_context(self, event, old_state=None, outlier=False):
  81. """ Fills out the context with the `current state` of the graph. The
  82. `current state` here is defined to be the state of the event graph
  83. just before the event - i.e. it never includes `event`
  84. If `event` has `auth_events` then this will also fill out the
  85. `auth_events` field on `context` from the `current_state`.
  86. Args:
  87. event (EventBase)
  88. Returns:
  89. an EventContext
  90. """
  91. yield run_on_reactor()
  92. context = EventContext()
  93. if outlier:
  94. # If this is an outlier, then we know it shouldn't have any current
  95. # state. Certainly store.get_current_state won't return any, and
  96. # persisting the event won't store the state group.
  97. if old_state:
  98. context.current_state = {
  99. (s.type, s.state_key): s for s in old_state
  100. }
  101. else:
  102. context.current_state = {}
  103. context.prev_state_events = []
  104. context.state_group = None
  105. defer.returnValue(context)
  106. if old_state:
  107. context.current_state = {
  108. (s.type, s.state_key): s for s in old_state
  109. }
  110. context.state_group = None
  111. if event.is_state():
  112. key = (event.type, event.state_key)
  113. if key in context.current_state:
  114. replaces = context.current_state[key]
  115. if replaces.event_id != event.event_id: # Paranoia check
  116. event.unsigned["replaces_state"] = replaces.event_id
  117. context.prev_state_events = []
  118. defer.returnValue(context)
  119. if event.is_state():
  120. ret = yield self.resolve_state_groups(
  121. event.room_id, [e for e, _ in event.prev_events],
  122. event_type=event.type,
  123. state_key=event.state_key,
  124. )
  125. else:
  126. ret = yield self.resolve_state_groups(
  127. event.room_id, [e for e, _ in event.prev_events],
  128. )
  129. group, curr_state, prev_state = ret
  130. context.current_state = curr_state
  131. context.state_group = group if not event.is_state() else None
  132. if event.is_state():
  133. key = (event.type, event.state_key)
  134. if key in context.current_state:
  135. replaces = context.current_state[key]
  136. event.unsigned["replaces_state"] = replaces.event_id
  137. context.prev_state_events = prev_state
  138. defer.returnValue(context)
  139. @defer.inlineCallbacks
  140. @log_function
  141. def resolve_state_groups(self, room_id, event_ids, event_type=None, state_key=""):
  142. """ Given a list of event_ids this method fetches the state at each
  143. event, resolves conflicts between them and returns them.
  144. Return format is a tuple: (`state_group`, `state_events`), where the
  145. first is the name of a state group if one and only one is involved,
  146. otherwise `None`.
  147. """
  148. logger.debug("resolve_state_groups event_ids %s", event_ids)
  149. if self._state_cache is not None:
  150. cache = self._state_cache.get(frozenset(event_ids), None)
  151. if cache and cache.state_group:
  152. cache.ts = self.clock.time_msec()
  153. prev_state = cache.state.get((event_type, state_key), None)
  154. if prev_state:
  155. prev_state = prev_state.event_id
  156. prev_states = [prev_state]
  157. else:
  158. prev_states = []
  159. defer.returnValue(
  160. (cache.state_group, cache.state, prev_states)
  161. )
  162. state_groups = yield self.store.get_state_groups(
  163. room_id, event_ids
  164. )
  165. logger.debug(
  166. "resolve_state_groups state_groups %s",
  167. state_groups.keys()
  168. )
  169. group_names = set(state_groups.keys())
  170. if len(group_names) == 1:
  171. name, state_list = state_groups.items().pop()
  172. state = {
  173. (e.type, e.state_key): e
  174. for e in state_list
  175. }
  176. prev_state = state.get((event_type, state_key), None)
  177. if prev_state:
  178. prev_state = prev_state.event_id
  179. prev_states = [prev_state]
  180. else:
  181. prev_states = []
  182. if self._state_cache is not None:
  183. cache = _StateCacheEntry(
  184. state=state,
  185. state_group=name,
  186. ts=self.clock.time_msec()
  187. )
  188. self._state_cache[frozenset(event_ids)] = cache
  189. defer.returnValue((name, state, prev_states))
  190. new_state, prev_states = self._resolve_events(
  191. state_groups.values(), event_type, state_key
  192. )
  193. if self._state_cache is not None:
  194. cache = _StateCacheEntry(
  195. state=new_state,
  196. state_group=None,
  197. ts=self.clock.time_msec()
  198. )
  199. self._state_cache[frozenset(event_ids)] = cache
  200. defer.returnValue((None, new_state, prev_states))
  201. def resolve_events(self, state_sets, event):
  202. if event.is_state():
  203. return self._resolve_events(
  204. state_sets, event.type, event.state_key
  205. )
  206. else:
  207. return self._resolve_events(state_sets)
  208. def _resolve_events(self, state_sets, event_type=None, state_key=""):
  209. state = {}
  210. for st in state_sets:
  211. for e in st:
  212. state.setdefault(
  213. (e.type, e.state_key),
  214. {}
  215. )[e.event_id] = e
  216. unconflicted_state = {
  217. k: v.values()[0] for k, v in state.items()
  218. if len(v.values()) == 1
  219. }
  220. conflicted_state = {
  221. k: v.values()
  222. for k, v in state.items()
  223. if len(v.values()) > 1
  224. }
  225. if event_type:
  226. prev_states_events = conflicted_state.get(
  227. (event_type, state_key), []
  228. )
  229. prev_states = [s.event_id for s in prev_states_events]
  230. else:
  231. prev_states = []
  232. auth_events = {
  233. k: e for k, e in unconflicted_state.items()
  234. if k[0] in AuthEventTypes
  235. }
  236. try:
  237. resolved_state = self._resolve_state_events(
  238. conflicted_state, auth_events
  239. )
  240. except:
  241. logger.exception("Failed to resolve state")
  242. raise
  243. new_state = unconflicted_state
  244. new_state.update(resolved_state)
  245. return new_state, prev_states
  246. @log_function
  247. def _resolve_state_events(self, conflicted_state, auth_events):
  248. """ This is where we actually decide which of the conflicted state to
  249. use.
  250. We resolve conflicts in the following order:
  251. 1. power levels
  252. 2. memberships
  253. 3. other events.
  254. """
  255. resolved_state = {}
  256. power_key = (EventTypes.PowerLevels, "")
  257. if power_key in conflicted_state.items():
  258. power_levels = conflicted_state[power_key]
  259. resolved_state[power_key] = self._resolve_auth_events(power_levels)
  260. auth_events.update(resolved_state)
  261. for key, events in conflicted_state.items():
  262. if key[0] == EventTypes.JoinRules:
  263. resolved_state[key] = self._resolve_auth_events(
  264. events,
  265. auth_events
  266. )
  267. auth_events.update(resolved_state)
  268. for key, events in conflicted_state.items():
  269. if key[0] == EventTypes.Member:
  270. resolved_state[key] = self._resolve_auth_events(
  271. events,
  272. auth_events
  273. )
  274. auth_events.update(resolved_state)
  275. for key, events in conflicted_state.items():
  276. if key not in resolved_state:
  277. resolved_state[key] = self._resolve_normal_events(
  278. events, auth_events
  279. )
  280. return resolved_state
  281. def _resolve_auth_events(self, events, auth_events):
  282. reverse = [i for i in reversed(self._ordered_events(events))]
  283. auth_events = dict(auth_events)
  284. prev_event = reverse[0]
  285. for event in reverse[1:]:
  286. auth_events[(prev_event.type, prev_event.state_key)] = prev_event
  287. try:
  288. # FIXME: hs.get_auth() is bad style, but we need to do it to
  289. # get around circular deps.
  290. self.hs.get_auth().check(event, auth_events)
  291. prev_event = event
  292. except AuthError:
  293. return prev_event
  294. return event
  295. def _resolve_normal_events(self, events, auth_events):
  296. for event in self._ordered_events(events):
  297. try:
  298. # FIXME: hs.get_auth() is bad style, but we need to do it to
  299. # get around circular deps.
  300. self.hs.get_auth().check(event, auth_events)
  301. return event
  302. except AuthError:
  303. pass
  304. # Use the last event (the one with the least depth) if they all fail
  305. # the auth check.
  306. return event
  307. def _ordered_events(self, events):
  308. def key_func(e):
  309. return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
  310. return sorted(events, key=key_func)