test_state.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from mock import Mock
  16. from twisted.internet import defer
  17. from synapse.api.auth import Auth
  18. from synapse.api.constants import EventTypes, Membership, RoomVersions
  19. from synapse.events import FrozenEvent
  20. from synapse.state import StateHandler, StateResolutionHandler
  21. from tests import unittest
  22. from .utils import MockClock
  # Module-level counter used to mint unique event IDs when none is supplied.
  _next_event_id = 1000


  def create_event(
      name=None,
      type=None,
      state_key=None,
      depth=2,
      event_id=None,
      prev_events=None,
      **kwargs
  ):
      """Build a FrozenEvent for use in these tests.

      Args:
          name: Optional human-readable label. Accepted so call sites can tag
              events for readability; it is not stored on the returned event.
          type: The event type.
          state_key: If not None, stored as the event's ``state_key`` (making
              it a state event).
          depth: The event's depth.
          event_id: Explicit event ID. If falsy, a fresh ``"$<n>:test"`` ID is
              generated from the module-level counter.
          prev_events: List of prev-event references. Defaults to a new empty
              list on every call.
          **kwargs: Extra top-level event fields (e.g. ``content``, ``sender``),
              merged over the defaults built here.

      Returns:
          FrozenEvent: the constructed event.
      """
      global _next_event_id

      if not event_id:
          _next_event_id += 1
          event_id = "$%s:test" % (_next_event_id,)

      if prev_events is None:
          # Fresh list per call so events never share one mutable default.
          prev_events = []

      d = {
          "event_id": event_id,
          "type": type,
          "sender": "@user_id:example.com",
          "room_id": "!room_id:example.com",
          "depth": depth,
          "prev_events": prev_events,
      }

      if state_key is not None:
          d["state_key"] = state_key

      # Caller-supplied fields win over the defaults above.
      d.update(kwargs)

      return FrozenEvent(d)
  class StateGroupStore(object):
      """In-memory fake of the datastore methods used by the state handler.

      State groups are numbered sequentially from 1. Each group maps to a
      private copy of the ``(type, state_key) -> event_id`` mapping it was
      stored with.
      """

      def __init__(self):
          self._event_to_state_group = {}
          self._group_to_state = {}
          self._event_id_to_event = {}
          self._next_group = 1

      def get_state_groups_ids(self, room_id, event_ids):
          """Return a Deferred of ``{state_group: state_ids}`` for ``event_ids``.

          Events with no (or a falsy) registered state group are skipped.
          ``room_id`` is ignored.
          """
          lookup = self._event_to_state_group.get
          found = (lookup(event_id) for event_id in event_ids)
          return defer.succeed(
              {group: self._group_to_state[group] for group in found if group}
          )

      def store_state_group(
          self, event_id, room_id, prev_group, delta_ids, current_state_ids
      ):
          """Allocate a new state group holding a copy of ``current_state_ids``.

          Only ``current_state_ids`` is used; the other arguments exist to match
          the real store's signature. Returns the new group ID (not a Deferred).
          """
          group_id, self._next_group = self._next_group, self._next_group + 1
          self._group_to_state[group_id] = dict(current_state_ids)
          return group_id

      def get_events(self, event_ids, **kwargs):
          """Return ``{event_id: event}`` for the registered events among
          ``event_ids``; unknown IDs are silently omitted."""
          known = self._event_id_to_event
          return {e_id: known[e_id] for e_id in event_ids if e_id in known}

      def get_state_group_delta(self, name):
          """Always report that no delta is available."""
          return (None, None)

      def register_events(self, events):
          """Index ``events`` by their ``event_id`` for later ``get_events``."""
          self._event_id_to_event.update((ev.event_id, ev) for ev in events)

      def register_event_context(self, event, context):
          """Record ``context.state_group`` as the state group of ``event``."""
          self.register_event_id_state_group(event.event_id, context.state_group)

      def register_event_id_state_group(self, event_id, state_group):
          """Record ``state_group`` as the state group of ``event_id``."""
          self._event_to_state_group[event_id] = state_group

      def get_room_version(self, room_id):
          """Every room in these tests is a v1 room."""
          return RoomVersions.V1
  class DictObj(dict):
      """A dict whose keys can also be read and written as attributes."""

      def __init__(self, **kwargs):
          dict.__init__(self, **kwargs)
          # Point the instance namespace at the dict itself, so ``obj.key`` and
          # ``obj["key"]`` share one storage.
          self.__dict__ = self
  class Graph(object):
      """A small DAG of test events built from node and edge descriptions.

      Args:
          nodes: Mapping of event_id -> dict of extra ``create_event`` keyword
              arguments (type, state_key, depth, ...).
          edges: Mapping of event_id -> list of that event's prev event IDs.
              Nodes without an entry (or with an empty list) have no prevs.
      """

      def __init__(self, nodes, edges):
          events = {}
          # IDs that some node lists as a prev event; these cannot be leaves.
          referenced = set()

          for event_id, fields in nodes.items():
              refs = edges.get(event_id)
              if refs:
                  referenced.update(refs)
                  prev_events = [(r, {}) for r in refs]
              else:
                  prev_events = []

              events[event_id] = create_event(
                  event_id=event_id, prev_events=prev_events, **fields
              )

          # Leaves are the nodes that no other node points back at. This must
          # be computed after all nodes are built and all edges are collected.
          self._leaves = set(events) - referenced
          self._events = sorted(events.values(), key=lambda e: e.depth)

      def walk(self):
          """Return an iterator over all events in ascending depth order."""
          return iter(self._events)

      def get_leaves(self):
          """Return a generator over the leaf events, in ascending depth order."""
          # ``_leaves`` holds event IDs, so look the events up by ID rather
          # than indexing the depth-sorted list with them.
          return (e for e in self._events if e.event_id in self._leaves)
  class StateTestCase(unittest.TestCase):
      """Tests for StateHandler.compute_event_context and state resolution,
      run against the in-memory StateGroupStore fake."""

      def setUp(self):
          self.store = StateGroupStore()
          hs = Mock(
              spec_set=[
                  "get_datastore",
                  "get_auth",
                  "get_state_handler",
                  "get_clock",
                  "get_state_resolution_handler",
              ]
          )
          hs.get_datastore.return_value = self.store
          hs.get_state_handler.return_value = None
          hs.get_clock.return_value = MockClock()
          hs.get_auth.return_value = Auth(hs)
          # Builds a new StateResolutionHandler on every call.
          hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
          self.state = StateHandler(hs)
          # NOTE(review): not read anywhere in this file.
          self.event_id = 0

      @defer.inlineCallbacks
      def _compute_graph_contexts(self, graph):
          """Register ``graph``'s events and compute a context for each one.

          Events are processed in ascending depth order. Each event's context is
          registered with the store before the next event is handled, so later
          events can find the state groups of their prev events.

          Returns:
              Deferred[dict]: event_id -> the context returned by
              ``compute_event_context`` for that event.
          """
          self.store.register_events(graph.walk())
          context_store = {}
          for event in graph.walk():
              context = yield self.state.compute_event_context(event)
              self.store.register_event_context(event, context)
              context_store[event.event_id] = context
          defer.returnValue(context_store)

      @defer.inlineCallbacks
      def test_branch_no_conflict(self):
          graph = Graph(
              nodes={
                  "START": DictObj(
                      type=EventTypes.Create, state_key="", content={}, depth=1,
                  ),
                  "A": DictObj(type=EventTypes.Message, depth=2),
                  "B": DictObj(type=EventTypes.Message, depth=3),
                  "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
                  "D": DictObj(type=EventTypes.Message, depth=4),
              },
              edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
          )

          context_store = yield self._compute_graph_contexts(graph)

          prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
          self.assertEqual(2, len(prev_state_ids))

      @defer.inlineCallbacks
      def test_branch_basic_conflict(self):
          graph = Graph(
              nodes={
                  "START": DictObj(
                      type=EventTypes.Create,
                      state_key="",
                      content={"creator": "@user_id:example.com"},
                      depth=1,
                  ),
                  "A": DictObj(
                      type=EventTypes.Member,
                      state_key="@user_id:example.com",
                      content={"membership": Membership.JOIN},
                      membership=Membership.JOIN,
                      depth=2,
                  ),
                  "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
                  "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
                  "D": DictObj(type=EventTypes.Message, depth=5),
              },
              edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
          )

          context_store = yield self._compute_graph_contexts(graph)

          prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
          self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))

      @defer.inlineCallbacks
      def test_branch_have_banned_conflict(self):
          graph = Graph(
              nodes={
                  "START": DictObj(
                      type=EventTypes.Create,
                      state_key="",
                      content={"creator": "@user_id:example.com"},
                      depth=1,
                  ),
                  "A": DictObj(
                      type=EventTypes.Member,
                      state_key="@user_id:example.com",
                      content={"membership": Membership.JOIN},
                      membership=Membership.JOIN,
                      depth=2,
                  ),
                  "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
                  "C": DictObj(
                      type=EventTypes.Member,
                      state_key="@user_id_2:example.com",
                      content={"membership": Membership.BAN},
                      membership=Membership.BAN,
                      depth=4,
                  ),
                  "D": DictObj(
                      type=EventTypes.Name,
                      state_key="",
                      depth=4,
                      sender="@user_id_2:example.com",
                  ),
                  "E": DictObj(type=EventTypes.Message, depth=5),
              },
              edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
          )

          context_store = yield self._compute_graph_contexts(graph)

          prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
          self.assertSetEqual(
              {"START", "A", "B", "C"}, set(prev_state_ids.values())
          )

      @defer.inlineCallbacks
      def test_branch_have_perms_conflict(self):
          userid1 = "@user_id:example.com"
          userid2 = "@user_id2:example.com"

          nodes = {
              "A1": DictObj(
                  type=EventTypes.Create,
                  state_key="",
                  content={"creator": userid1},
                  depth=1,
              ),
              "A2": DictObj(
                  type=EventTypes.Member,
                  state_key=userid1,
                  content={"membership": Membership.JOIN},
                  membership=Membership.JOIN,
              ),
              "A3": DictObj(
                  type=EventTypes.Member,
                  state_key=userid2,
                  content={"membership": Membership.JOIN},
                  membership=Membership.JOIN,
              ),
              "A4": DictObj(
                  type=EventTypes.PowerLevels,
                  state_key="",
                  content={
                      "events": {"m.room.name": 50},
                      "users": {userid1: 100, userid2: 60},
                  },
              ),
              "A5": DictObj(type=EventTypes.Name, state_key=""),
              "B": DictObj(
                  type=EventTypes.PowerLevels,
                  state_key="",
                  content={"events": {"m.room.name": 50}, "users": {userid2: 30}},
              ),
              "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
              "D": DictObj(type=EventTypes.Message),
          }
          edges = {
              "A2": ["A1"],
              "A3": ["A2"],
              "A4": ["A3"],
              "A5": ["A4"],
              "B": ["A5"],
              "C": ["A5"],
              "D": ["B", "C"],
          }
          # Only A1 has an explicit depth; derive the rest from the edges.
          self._add_depths(nodes, edges)
          graph = Graph(nodes, edges)

          context_store = yield self._compute_graph_contexts(graph)

          prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
          self.assertSetEqual(
              {"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values())
          )

      def _add_depths(self, nodes, edges):
          """Fill in a ``depth`` for every node that lacks one, in place.

          A node's depth is one more than the maximum depth of its prev events
          in ``edges``. Nodes without an explicit depth must have an entry in
          ``edges``; otherwise a KeyError is raised.
          """

          def _get_depth(ev):
              node = nodes[ev]
              if "depth" not in node:
                  prevs = edges[ev]
                  depth = max(_get_depth(prev) for prev in prevs) + 1
                  node["depth"] = depth
              return node["depth"]

          for n in nodes:
              _get_depth(n)

      @defer.inlineCallbacks
      def test_annotate_with_old_message(self):
          event = create_event(type="test_message", name="event")

          old_state = [
              create_event(type="test1", state_key="1"),
              create_event(type="test1", state_key="2"),
              create_event(type="test2", state_key=""),
          ]

          context = yield self.state.compute_event_context(event, old_state=old_state)

          current_state_ids = yield context.get_current_state_ids(self.store)

          self.assertEqual(
              {e.event_id for e in old_state}, set(current_state_ids.values())
          )

          self.assertIsNotNone(context.state_group)

      @defer.inlineCallbacks
      def test_annotate_with_old_state(self):
          event = create_event(type="state", state_key="", name="event")

          old_state = [
              create_event(type="test1", state_key="1"),
              create_event(type="test1", state_key="2"),
              create_event(type="test2", state_key=""),
          ]

          context = yield self.state.compute_event_context(event, old_state=old_state)

          prev_state_ids = yield context.get_prev_state_ids(self.store)

          self.assertEqual(
              {e.event_id for e in old_state}, set(prev_state_ids.values())
          )

      @defer.inlineCallbacks
      def test_trivial_annotate_message(self):
          prev_event_id = "prev_event_id"
          event = create_event(
              type="test_message", name="event2", prev_events=[(prev_event_id, {})]
          )

          old_state = [
              create_event(type="test1", state_key="1"),
              create_event(type="test1", state_key="2"),
              create_event(type="test2", state_key=""),
          ]

          group_name = self.store.store_state_group(
              prev_event_id,
              event.room_id,
              None,
              None,
              {(e.type, e.state_key): e.event_id for e in old_state},
          )
          self.store.register_event_id_state_group(prev_event_id, group_name)

          context = yield self.state.compute_event_context(event)

          current_state_ids = yield context.get_current_state_ids(self.store)

          self.assertEqual(
              {e.event_id for e in old_state}, set(current_state_ids.values())
          )

          # A non-state event with a single prev event reuses its state group.
          self.assertEqual(group_name, context.state_group)

      @defer.inlineCallbacks
      def test_trivial_annotate_state(self):
          prev_event_id = "prev_event_id"
          event = create_event(
              type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})]
          )

          old_state = [
              create_event(type="test1", state_key="1"),
              create_event(type="test1", state_key="2"),
              create_event(type="test2", state_key=""),
          ]

          group_name = self.store.store_state_group(
              prev_event_id,
              event.room_id,
              None,
              None,
              {(e.type, e.state_key): e.event_id for e in old_state},
          )
          self.store.register_event_id_state_group(prev_event_id, group_name)

          context = yield self.state.compute_event_context(event)

          prev_state_ids = yield context.get_prev_state_ids(self.store)

          self.assertEqual(
              {e.event_id for e in old_state}, set(prev_state_ids.values())
          )

          self.assertIsNotNone(context.state_group)

      @defer.inlineCallbacks
      def test_resolve_message_conflict(self):
          prev_event_id1 = "event_id1"
          prev_event_id2 = "event_id2"
          event = create_event(
              type="test_message",
              name="event3",
              prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
          )

          creation = create_event(type=EventTypes.Create, state_key="")

          old_state_1 = [
              creation,
              create_event(type="test1", state_key="1"),
              create_event(type="test1", state_key="2"),
              create_event(type="test2", state_key=""),
          ]

          old_state_2 = [
              creation,
              create_event(type="test1", state_key="1"),
              create_event(type="test3", state_key="2"),
              create_event(type="test4", state_key=""),
          ]

          self.store.register_events(old_state_1)
          self.store.register_events(old_state_2)

          context = yield self._get_context(
              event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
          )

          current_state_ids = yield context.get_current_state_ids(self.store)

          self.assertEqual(len(current_state_ids), 6)

          self.assertIsNotNone(context.state_group)

      @defer.inlineCallbacks
      def test_resolve_state_conflict(self):
          prev_event_id1 = "event_id1"
          prev_event_id2 = "event_id2"
          event = create_event(
              type="test4",
              state_key="",
              name="event",
              prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
          )

          creation = create_event(type=EventTypes.Create, state_key="")

          old_state_1 = [
              creation,
              create_event(type="test1", state_key="1"),
              create_event(type="test1", state_key="2"),
              create_event(type="test2", state_key=""),
          ]

          old_state_2 = [
              creation,
              create_event(type="test1", state_key="1"),
              create_event(type="test3", state_key="2"),
              create_event(type="test4", state_key=""),
          ]

          # Serve event lookups from a separate store that knows both states.
          store = StateGroupStore()
          store.register_events(old_state_1)
          store.register_events(old_state_2)
          self.store.get_events = store.get_events

          context = yield self._get_context(
              event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
          )

          current_state_ids = yield context.get_current_state_ids(self.store)

          self.assertEqual(len(current_state_ids), 6)

          self.assertIsNotNone(context.state_group)

      @defer.inlineCallbacks
      def test_standard_depth_conflict(self):
          prev_event_id1 = "event_id1"
          prev_event_id2 = "event_id2"
          event = create_event(
              type="test4",
              name="event",
              prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
          )

          member_event = create_event(
              type=EventTypes.Member,
              state_key="@user_id:example.com",
              content={"membership": Membership.JOIN},
          )

          power_levels = create_event(
              type=EventTypes.PowerLevels,
              state_key="",
              content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}},
          )

          creation = create_event(
              type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
          )

          old_state_1 = [
              creation,
              power_levels,
              member_event,
              create_event(type="test1", state_key="1", depth=1),
          ]

          old_state_2 = [
              creation,
              power_levels,
              member_event,
              create_event(type="test1", state_key="1", depth=2),
          ]

          store = StateGroupStore()
          store.register_events(old_state_1)
          store.register_events(old_state_2)
          self.store.get_events = store.get_events

          context = yield self._get_context(
              event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
          )

          current_state_ids = yield context.get_current_state_ids(self.store)

          self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])

          # Reverse the depth to make sure we are actually using the depths
          # during state resolution.

          old_state_1 = [
              creation,
              power_levels,
              member_event,
              create_event(type="test1", state_key="1", depth=2),
          ]

          old_state_2 = [
              creation,
              power_levels,
              member_event,
              create_event(type="test1", state_key="1", depth=1),
          ]

          store.register_events(old_state_1)
          store.register_events(old_state_2)

          context = yield self._get_context(
              event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
          )

          current_state_ids = yield context.get_current_state_ids(self.store)

          self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])

      def _get_context(
          self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
      ):
          """Give each of two prev events its own state group, then compute the
          context for ``event``.

          ``event`` is expected to list ``prev_event_id_1`` and
          ``prev_event_id_2`` as its prev events.

          Returns:
              Deferred: the result of ``compute_event_context(event)``.
          """
          for prev_event_id, old_state in (
              (prev_event_id_1, old_state_1),
              (prev_event_id_2, old_state_2),
          ):
              state_group = self.store.store_state_group(
                  prev_event_id,
                  event.room_id,
                  None,
                  None,
                  {(e.type, e.state_key): e.event_id for e in old_state},
              )
              self.store.register_event_id_state_group(prev_event_id, state_group)

          return self.state.compute_event_context(event)
  527. )
  528. self.store.register_event_id_state_group(prev_event_id_2, sg2)
  529. return self.state.compute_event_context(event)