test_state.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from mock import Mock
  16. from twisted.internet import defer
  17. from synapse.api.auth import Auth
  18. from synapse.api.constants import EventTypes, Membership
  19. from synapse.api.room_versions import RoomVersions
  20. from synapse.events import make_event_from_dict
  21. from synapse.events.snapshot import EventContext
  22. from synapse.state import StateHandler, StateResolutionHandler
  23. from tests import unittest
  24. from .utils import MockClock, default_config
  25. _next_event_id = 1000
  26. def create_event(
  27. name=None,
  28. type=None,
  29. state_key=None,
  30. depth=2,
  31. event_id=None,
  32. prev_events=[],
  33. **kwargs
  34. ):
  35. global _next_event_id
  36. if not event_id:
  37. _next_event_id += 1
  38. event_id = "$%s:test" % (_next_event_id,)
  39. if not name:
  40. if state_key is not None:
  41. name = "<%s-%s, %s>" % (type, state_key, event_id)
  42. else:
  43. name = "<%s, %s>" % (type, event_id)
  44. d = {
  45. "event_id": event_id,
  46. "type": type,
  47. "sender": "@user_id:example.com",
  48. "room_id": "!room_id:example.com",
  49. "depth": depth,
  50. "prev_events": prev_events,
  51. }
  52. if state_key is not None:
  53. d["state_key"] = state_key
  54. d.update(kwargs)
  55. event = make_event_from_dict(d)
  56. return event
  57. class StateGroupStore(object):
  58. def __init__(self):
  59. self._event_to_state_group = {}
  60. self._group_to_state = {}
  61. self._event_id_to_event = {}
  62. self._next_group = 1
  63. def get_state_groups_ids(self, room_id, event_ids):
  64. groups = {}
  65. for event_id in event_ids:
  66. group = self._event_to_state_group.get(event_id)
  67. if group:
  68. groups[group] = self._group_to_state[group]
  69. return defer.succeed(groups)
  70. def store_state_group(
  71. self, event_id, room_id, prev_group, delta_ids, current_state_ids
  72. ):
  73. state_group = self._next_group
  74. self._next_group += 1
  75. self._group_to_state[state_group] = dict(current_state_ids)
  76. return state_group
  77. def get_events(self, event_ids, **kwargs):
  78. return {
  79. e_id: self._event_id_to_event[e_id]
  80. for e_id in event_ids
  81. if e_id in self._event_id_to_event
  82. }
  83. def get_state_group_delta(self, name):
  84. return None, None
  85. def register_events(self, events):
  86. for e in events:
  87. self._event_id_to_event[e.event_id] = e
  88. def register_event_context(self, event, context):
  89. self._event_to_state_group[event.event_id] = context.state_group
  90. def register_event_id_state_group(self, event_id, state_group):
  91. self._event_to_state_group[event_id] = state_group
  92. def get_room_version_id(self, room_id):
  93. return RoomVersions.V1.identifier
  94. class DictObj(dict):
  95. def __init__(self, **kwargs):
  96. super(DictObj, self).__init__(kwargs)
  97. self.__dict__ = self
  98. class Graph(object):
  99. def __init__(self, nodes, edges):
  100. events = {}
  101. clobbered = set(events.keys())
  102. for event_id, fields in nodes.items():
  103. refs = edges.get(event_id)
  104. if refs:
  105. clobbered.difference_update(refs)
  106. prev_events = [(r, {}) for r in refs]
  107. else:
  108. prev_events = []
  109. events[event_id] = create_event(
  110. event_id=event_id, prev_events=prev_events, **fields
  111. )
  112. self._leaves = clobbered
  113. self._events = sorted(events.values(), key=lambda e: e.depth)
  114. def walk(self):
  115. return iter(self._events)
  116. def get_leaves(self):
  117. return (self._events[i] for i in self._leaves)
  118. class StateTestCase(unittest.TestCase):
  119. def setUp(self):
  120. self.store = StateGroupStore()
  121. storage = Mock(main=self.store, state=self.store)
  122. hs = Mock(
  123. spec_set=[
  124. "config",
  125. "get_datastore",
  126. "get_storage",
  127. "get_auth",
  128. "get_state_handler",
  129. "get_clock",
  130. "get_state_resolution_handler",
  131. ]
  132. )
  133. hs.config = default_config("tesths", True)
  134. hs.get_datastore.return_value = self.store
  135. hs.get_state_handler.return_value = None
  136. hs.get_clock.return_value = MockClock()
  137. hs.get_auth.return_value = Auth(hs)
  138. hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
  139. hs.get_storage.return_value = storage
  140. self.state = StateHandler(hs)
  141. self.event_id = 0
  142. @defer.inlineCallbacks
  143. def test_branch_no_conflict(self):
  144. graph = Graph(
  145. nodes={
  146. "START": DictObj(
  147. type=EventTypes.Create, state_key="", content={}, depth=1
  148. ),
  149. "A": DictObj(type=EventTypes.Message, depth=2),
  150. "B": DictObj(type=EventTypes.Message, depth=3),
  151. "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
  152. "D": DictObj(type=EventTypes.Message, depth=4),
  153. },
  154. edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
  155. )
  156. self.store.register_events(graph.walk())
  157. context_store = {} # type: dict[str, EventContext]
  158. for event in graph.walk():
  159. context = yield self.state.compute_event_context(event)
  160. self.store.register_event_context(event, context)
  161. context_store[event.event_id] = context
  162. ctx_c = context_store["C"]
  163. ctx_d = context_store["D"]
  164. prev_state_ids = yield ctx_d.get_prev_state_ids()
  165. self.assertEqual(2, len(prev_state_ids))
  166. self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
  167. self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
  168. @defer.inlineCallbacks
  169. def test_branch_basic_conflict(self):
  170. graph = Graph(
  171. nodes={
  172. "START": DictObj(
  173. type=EventTypes.Create,
  174. state_key="",
  175. content={"creator": "@user_id:example.com"},
  176. depth=1,
  177. ),
  178. "A": DictObj(
  179. type=EventTypes.Member,
  180. state_key="@user_id:example.com",
  181. content={"membership": Membership.JOIN},
  182. membership=Membership.JOIN,
  183. depth=2,
  184. ),
  185. "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
  186. "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
  187. "D": DictObj(type=EventTypes.Message, depth=5),
  188. },
  189. edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
  190. )
  191. self.store.register_events(graph.walk())
  192. context_store = {}
  193. for event in graph.walk():
  194. context = yield self.state.compute_event_context(event)
  195. self.store.register_event_context(event, context)
  196. context_store[event.event_id] = context
  197. # C ends up winning the resolution between B and C
  198. ctx_c = context_store["C"]
  199. ctx_d = context_store["D"]
  200. prev_state_ids = yield ctx_d.get_prev_state_ids()
  201. self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
  202. self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
  203. self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
  204. @defer.inlineCallbacks
  205. def test_branch_have_banned_conflict(self):
  206. graph = Graph(
  207. nodes={
  208. "START": DictObj(
  209. type=EventTypes.Create,
  210. state_key="",
  211. content={"creator": "@user_id:example.com"},
  212. depth=1,
  213. ),
  214. "A": DictObj(
  215. type=EventTypes.Member,
  216. state_key="@user_id:example.com",
  217. content={"membership": Membership.JOIN},
  218. membership=Membership.JOIN,
  219. depth=2,
  220. ),
  221. "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
  222. "C": DictObj(
  223. type=EventTypes.Member,
  224. state_key="@user_id_2:example.com",
  225. content={"membership": Membership.BAN},
  226. membership=Membership.BAN,
  227. depth=4,
  228. ),
  229. "D": DictObj(
  230. type=EventTypes.Name,
  231. state_key="",
  232. depth=4,
  233. sender="@user_id_2:example.com",
  234. ),
  235. "E": DictObj(type=EventTypes.Message, depth=5),
  236. },
  237. edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
  238. )
  239. self.store.register_events(graph.walk())
  240. context_store = {}
  241. for event in graph.walk():
  242. context = yield self.state.compute_event_context(event)
  243. self.store.register_event_context(event, context)
  244. context_store[event.event_id] = context
  245. # C ends up winning the resolution between C and D because bans win over other
  246. # changes
  247. ctx_c = context_store["C"]
  248. ctx_e = context_store["E"]
  249. prev_state_ids = yield ctx_e.get_prev_state_ids()
  250. self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
  251. self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
  252. self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
  253. @defer.inlineCallbacks
  254. def test_branch_have_perms_conflict(self):
  255. userid1 = "@user_id:example.com"
  256. userid2 = "@user_id2:example.com"
  257. nodes = {
  258. "A1": DictObj(
  259. type=EventTypes.Create,
  260. state_key="",
  261. content={"creator": userid1},
  262. depth=1,
  263. ),
  264. "A2": DictObj(
  265. type=EventTypes.Member,
  266. state_key=userid1,
  267. content={"membership": Membership.JOIN},
  268. membership=Membership.JOIN,
  269. ),
  270. "A3": DictObj(
  271. type=EventTypes.Member,
  272. state_key=userid2,
  273. content={"membership": Membership.JOIN},
  274. membership=Membership.JOIN,
  275. ),
  276. "A4": DictObj(
  277. type=EventTypes.PowerLevels,
  278. state_key="",
  279. content={
  280. "events": {"m.room.name": 50},
  281. "users": {userid1: 100, userid2: 60},
  282. },
  283. ),
  284. "A5": DictObj(type=EventTypes.Name, state_key=""),
  285. "B": DictObj(
  286. type=EventTypes.PowerLevels,
  287. state_key="",
  288. content={"events": {"m.room.name": 50}, "users": {userid2: 30}},
  289. ),
  290. "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
  291. "D": DictObj(type=EventTypes.Message),
  292. }
  293. edges = {
  294. "A2": ["A1"],
  295. "A3": ["A2"],
  296. "A4": ["A3"],
  297. "A5": ["A4"],
  298. "B": ["A5"],
  299. "C": ["A5"],
  300. "D": ["B", "C"],
  301. }
  302. self._add_depths(nodes, edges)
  303. graph = Graph(nodes, edges)
  304. self.store.register_events(graph.walk())
  305. context_store = {}
  306. for event in graph.walk():
  307. context = yield self.state.compute_event_context(event)
  308. self.store.register_event_context(event, context)
  309. context_store[event.event_id] = context
  310. # B ends up winning the resolution between B and C because power levels
  311. # win over other changes.
  312. ctx_b = context_store["B"]
  313. ctx_d = context_store["D"]
  314. prev_state_ids = yield ctx_d.get_prev_state_ids()
  315. self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
  316. self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
  317. self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
  318. def _add_depths(self, nodes, edges):
  319. def _get_depth(ev):
  320. node = nodes[ev]
  321. if "depth" not in node:
  322. prevs = edges[ev]
  323. depth = max(_get_depth(prev) for prev in prevs) + 1
  324. node["depth"] = depth
  325. return node["depth"]
  326. for n in nodes:
  327. _get_depth(n)
  328. @defer.inlineCallbacks
  329. def test_annotate_with_old_message(self):
  330. event = create_event(type="test_message", name="event")
  331. old_state = [
  332. create_event(type="test1", state_key="1"),
  333. create_event(type="test1", state_key="2"),
  334. create_event(type="test2", state_key=""),
  335. ]
  336. context = yield self.state.compute_event_context(event, old_state=old_state)
  337. prev_state_ids = yield context.get_prev_state_ids()
  338. self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
  339. current_state_ids = yield context.get_current_state_ids()
  340. self.assertCountEqual(
  341. (e.event_id for e in old_state), current_state_ids.values()
  342. )
  343. self.assertIsNotNone(context.state_group_before_event)
  344. self.assertEqual(context.state_group_before_event, context.state_group)
  345. @defer.inlineCallbacks
  346. def test_annotate_with_old_state(self):
  347. event = create_event(type="state", state_key="", name="event")
  348. old_state = [
  349. create_event(type="test1", state_key="1"),
  350. create_event(type="test1", state_key="2"),
  351. create_event(type="test2", state_key=""),
  352. ]
  353. context = yield self.state.compute_event_context(event, old_state=old_state)
  354. prev_state_ids = yield context.get_prev_state_ids()
  355. self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
  356. current_state_ids = yield context.get_current_state_ids()
  357. self.assertCountEqual(
  358. (e.event_id for e in old_state + [event]), current_state_ids.values()
  359. )
  360. self.assertIsNotNone(context.state_group_before_event)
  361. self.assertNotEqual(context.state_group_before_event, context.state_group)
  362. self.assertEqual(context.state_group_before_event, context.prev_group)
  363. self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
  364. @defer.inlineCallbacks
  365. def test_trivial_annotate_message(self):
  366. prev_event_id = "prev_event_id"
  367. event = create_event(
  368. type="test_message", name="event2", prev_events=[(prev_event_id, {})]
  369. )
  370. old_state = [
  371. create_event(type="test1", state_key="1"),
  372. create_event(type="test1", state_key="2"),
  373. create_event(type="test2", state_key=""),
  374. ]
  375. group_name = self.store.store_state_group(
  376. prev_event_id,
  377. event.room_id,
  378. None,
  379. None,
  380. {(e.type, e.state_key): e.event_id for e in old_state},
  381. )
  382. self.store.register_event_id_state_group(prev_event_id, group_name)
  383. context = yield self.state.compute_event_context(event)
  384. current_state_ids = yield context.get_current_state_ids()
  385. self.assertEqual(
  386. {e.event_id for e in old_state}, set(current_state_ids.values())
  387. )
  388. self.assertEqual(group_name, context.state_group)
  389. @defer.inlineCallbacks
  390. def test_trivial_annotate_state(self):
  391. prev_event_id = "prev_event_id"
  392. event = create_event(
  393. type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})]
  394. )
  395. old_state = [
  396. create_event(type="test1", state_key="1"),
  397. create_event(type="test1", state_key="2"),
  398. create_event(type="test2", state_key=""),
  399. ]
  400. group_name = self.store.store_state_group(
  401. prev_event_id,
  402. event.room_id,
  403. None,
  404. None,
  405. {(e.type, e.state_key): e.event_id for e in old_state},
  406. )
  407. self.store.register_event_id_state_group(prev_event_id, group_name)
  408. context = yield self.state.compute_event_context(event)
  409. prev_state_ids = yield context.get_prev_state_ids()
  410. self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
  411. self.assertIsNotNone(context.state_group)
  412. @defer.inlineCallbacks
  413. def test_resolve_message_conflict(self):
  414. prev_event_id1 = "event_id1"
  415. prev_event_id2 = "event_id2"
  416. event = create_event(
  417. type="test_message",
  418. name="event3",
  419. prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
  420. )
  421. creation = create_event(type=EventTypes.Create, state_key="")
  422. old_state_1 = [
  423. creation,
  424. create_event(type="test1", state_key="1"),
  425. create_event(type="test1", state_key="2"),
  426. create_event(type="test2", state_key=""),
  427. ]
  428. old_state_2 = [
  429. creation,
  430. create_event(type="test1", state_key="1"),
  431. create_event(type="test3", state_key="2"),
  432. create_event(type="test4", state_key=""),
  433. ]
  434. self.store.register_events(old_state_1)
  435. self.store.register_events(old_state_2)
  436. context = yield self._get_context(
  437. event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
  438. )
  439. current_state_ids = yield context.get_current_state_ids()
  440. self.assertEqual(len(current_state_ids), 6)
  441. self.assertIsNotNone(context.state_group)
  442. @defer.inlineCallbacks
  443. def test_resolve_state_conflict(self):
  444. prev_event_id1 = "event_id1"
  445. prev_event_id2 = "event_id2"
  446. event = create_event(
  447. type="test4",
  448. state_key="",
  449. name="event",
  450. prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
  451. )
  452. creation = create_event(type=EventTypes.Create, state_key="")
  453. old_state_1 = [
  454. creation,
  455. create_event(type="test1", state_key="1"),
  456. create_event(type="test1", state_key="2"),
  457. create_event(type="test2", state_key=""),
  458. ]
  459. old_state_2 = [
  460. creation,
  461. create_event(type="test1", state_key="1"),
  462. create_event(type="test3", state_key="2"),
  463. create_event(type="test4", state_key=""),
  464. ]
  465. store = StateGroupStore()
  466. store.register_events(old_state_1)
  467. store.register_events(old_state_2)
  468. self.store.get_events = store.get_events
  469. context = yield self._get_context(
  470. event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
  471. )
  472. current_state_ids = yield context.get_current_state_ids()
  473. self.assertEqual(len(current_state_ids), 6)
  474. self.assertIsNotNone(context.state_group)
  475. @defer.inlineCallbacks
  476. def test_standard_depth_conflict(self):
  477. prev_event_id1 = "event_id1"
  478. prev_event_id2 = "event_id2"
  479. event = create_event(
  480. type="test4",
  481. name="event",
  482. prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
  483. )
  484. member_event = create_event(
  485. type=EventTypes.Member,
  486. state_key="@user_id:example.com",
  487. content={"membership": Membership.JOIN},
  488. )
  489. power_levels = create_event(
  490. type=EventTypes.PowerLevels,
  491. state_key="",
  492. content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}},
  493. )
  494. creation = create_event(
  495. type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
  496. )
  497. old_state_1 = [
  498. creation,
  499. power_levels,
  500. member_event,
  501. create_event(type="test1", state_key="1", depth=1),
  502. ]
  503. old_state_2 = [
  504. creation,
  505. power_levels,
  506. member_event,
  507. create_event(type="test1", state_key="1", depth=2),
  508. ]
  509. store = StateGroupStore()
  510. store.register_events(old_state_1)
  511. store.register_events(old_state_2)
  512. self.store.get_events = store.get_events
  513. context = yield self._get_context(
  514. event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
  515. )
  516. current_state_ids = yield context.get_current_state_ids()
  517. self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
  518. # Reverse the depth to make sure we are actually using the depths
  519. # during state resolution.
  520. old_state_1 = [
  521. creation,
  522. power_levels,
  523. member_event,
  524. create_event(type="test1", state_key="1", depth=2),
  525. ]
  526. old_state_2 = [
  527. creation,
  528. power_levels,
  529. member_event,
  530. create_event(type="test1", state_key="1", depth=1),
  531. ]
  532. store.register_events(old_state_1)
  533. store.register_events(old_state_2)
  534. context = yield self._get_context(
  535. event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
  536. )
  537. current_state_ids = yield context.get_current_state_ids()
  538. self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
  539. def _get_context(
  540. self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
  541. ):
  542. sg1 = self.store.store_state_group(
  543. prev_event_id_1,
  544. event.room_id,
  545. None,
  546. None,
  547. {(e.type, e.state_key): e.event_id for e in old_state_1},
  548. )
  549. self.store.register_event_id_state_group(prev_event_id_1, sg1)
  550. sg2 = self.store.store_state_group(
  551. prev_event_id_2,
  552. event.room_id,
  553. None,
  554. None,
  555. {(e.type, e.state_key): e.event_id for e in old_state_2},
  556. )
  557. self.store.register_event_id_state_group(prev_event_id_2, sg2)
  558. return self.state.compute_event_context(event)