test_state.py 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from mock import Mock
  16. from twisted.internet import defer
  17. from synapse.api.auth import Auth
  18. from synapse.api.constants import EventTypes, Membership
  19. from synapse.api.room_versions import RoomVersions
  20. from synapse.events import FrozenEvent
  21. from synapse.events.snapshot import EventContext
  22. from synapse.state import StateHandler, StateResolutionHandler
  23. from tests import unittest
  24. from .utils import MockClock, default_config
# Module-level counter: create_event() increments it to mint a fresh
# "$<n>:test" event id whenever the caller does not supply one.
_next_event_id = 1000
  26. def create_event(
  27. name=None,
  28. type=None,
  29. state_key=None,
  30. depth=2,
  31. event_id=None,
  32. prev_events=[],
  33. **kwargs
  34. ):
  35. global _next_event_id
  36. if not event_id:
  37. _next_event_id += 1
  38. event_id = "$%s:test" % (_next_event_id,)
  39. if not name:
  40. if state_key is not None:
  41. name = "<%s-%s, %s>" % (type, state_key, event_id)
  42. else:
  43. name = "<%s, %s>" % (type, event_id)
  44. d = {
  45. "event_id": event_id,
  46. "type": type,
  47. "sender": "@user_id:example.com",
  48. "room_id": "!room_id:example.com",
  49. "depth": depth,
  50. "prev_events": prev_events,
  51. }
  52. if state_key is not None:
  53. d["state_key"] = state_key
  54. d.update(kwargs)
  55. event = FrozenEvent(d)
  56. return event
  57. class StateGroupStore(object):
  58. def __init__(self):
  59. self._event_to_state_group = {}
  60. self._group_to_state = {}
  61. self._event_id_to_event = {}
  62. self._next_group = 1
  63. def get_state_groups_ids(self, room_id, event_ids):
  64. groups = {}
  65. for event_id in event_ids:
  66. group = self._event_to_state_group.get(event_id)
  67. if group:
  68. groups[group] = self._group_to_state[group]
  69. return defer.succeed(groups)
  70. def store_state_group(
  71. self, event_id, room_id, prev_group, delta_ids, current_state_ids
  72. ):
  73. state_group = self._next_group
  74. self._next_group += 1
  75. self._group_to_state[state_group] = dict(current_state_ids)
  76. return state_group
  77. def get_events(self, event_ids, **kwargs):
  78. return {
  79. e_id: self._event_id_to_event[e_id]
  80. for e_id in event_ids
  81. if e_id in self._event_id_to_event
  82. }
  83. def get_state_group_delta(self, name):
  84. return None, None
  85. def register_events(self, events):
  86. for e in events:
  87. self._event_id_to_event[e.event_id] = e
  88. def register_event_context(self, event, context):
  89. self._event_to_state_group[event.event_id] = context.state_group
  90. def register_event_id_state_group(self, event_id, state_group):
  91. self._event_to_state_group[event_id] = state_group
  92. def get_room_version(self, room_id):
  93. return RoomVersions.V1.identifier
  94. class DictObj(dict):
  95. def __init__(self, **kwargs):
  96. super(DictObj, self).__init__(kwargs)
  97. self.__dict__ = self
  98. class Graph(object):
  99. def __init__(self, nodes, edges):
  100. events = {}
  101. clobbered = set(events.keys())
  102. for event_id, fields in nodes.items():
  103. refs = edges.get(event_id)
  104. if refs:
  105. clobbered.difference_update(refs)
  106. prev_events = [(r, {}) for r in refs]
  107. else:
  108. prev_events = []
  109. events[event_id] = create_event(
  110. event_id=event_id, prev_events=prev_events, **fields
  111. )
  112. self._leaves = clobbered
  113. self._events = sorted(events.values(), key=lambda e: e.depth)
  114. def walk(self):
  115. return iter(self._events)
  116. def get_leaves(self):
  117. return (self._events[i] for i in self._leaves)
  118. class StateTestCase(unittest.TestCase):
  119. def setUp(self):
  120. self.store = StateGroupStore()
  121. storage = Mock(main=self.store, state=self.store)
  122. hs = Mock(
  123. spec_set=[
  124. "config",
  125. "get_datastore",
  126. "get_storage",
  127. "get_auth",
  128. "get_state_handler",
  129. "get_clock",
  130. "get_state_resolution_handler",
  131. ]
  132. )
  133. hs.config = default_config("tesths", True)
  134. hs.get_datastore.return_value = self.store
  135. hs.get_state_handler.return_value = None
  136. hs.get_clock.return_value = MockClock()
  137. hs.get_auth.return_value = Auth(hs)
  138. hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
  139. hs.get_storage.return_value = storage
  140. self.state = StateHandler(hs)
  141. self.event_id = 0
  142. @defer.inlineCallbacks
  143. def test_branch_no_conflict(self):
  144. graph = Graph(
  145. nodes={
  146. "START": DictObj(
  147. type=EventTypes.Create, state_key="", content={}, depth=1
  148. ),
  149. "A": DictObj(type=EventTypes.Message, depth=2),
  150. "B": DictObj(type=EventTypes.Message, depth=3),
  151. "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
  152. "D": DictObj(type=EventTypes.Message, depth=4),
  153. },
  154. edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
  155. )
  156. self.store.register_events(graph.walk())
  157. context_store = {} # type: dict[str, EventContext]
  158. for event in graph.walk():
  159. context = yield self.state.compute_event_context(event)
  160. self.store.register_event_context(event, context)
  161. context_store[event.event_id] = context
  162. ctx_c = context_store["C"]
  163. ctx_d = context_store["D"]
  164. prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
  165. self.assertEqual(2, len(prev_state_ids))
  166. self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
  167. self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
  168. @defer.inlineCallbacks
  169. def test_branch_basic_conflict(self):
  170. graph = Graph(
  171. nodes={
  172. "START": DictObj(
  173. type=EventTypes.Create,
  174. state_key="",
  175. content={"creator": "@user_id:example.com"},
  176. depth=1,
  177. ),
  178. "A": DictObj(
  179. type=EventTypes.Member,
  180. state_key="@user_id:example.com",
  181. content={"membership": Membership.JOIN},
  182. membership=Membership.JOIN,
  183. depth=2,
  184. ),
  185. "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
  186. "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
  187. "D": DictObj(type=EventTypes.Message, depth=5),
  188. },
  189. edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
  190. )
  191. self.store.register_events(graph.walk())
  192. context_store = {}
  193. for event in graph.walk():
  194. context = yield self.state.compute_event_context(event)
  195. self.store.register_event_context(event, context)
  196. context_store[event.event_id] = context
  197. # C ends up winning the resolution between B and C
  198. ctx_c = context_store["C"]
  199. ctx_d = context_store["D"]
  200. prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
  201. self.assertSetEqual(
  202. {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
  203. )
  204. self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
  205. self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
  206. @defer.inlineCallbacks
  207. def test_branch_have_banned_conflict(self):
  208. graph = Graph(
  209. nodes={
  210. "START": DictObj(
  211. type=EventTypes.Create,
  212. state_key="",
  213. content={"creator": "@user_id:example.com"},
  214. depth=1,
  215. ),
  216. "A": DictObj(
  217. type=EventTypes.Member,
  218. state_key="@user_id:example.com",
  219. content={"membership": Membership.JOIN},
  220. membership=Membership.JOIN,
  221. depth=2,
  222. ),
  223. "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
  224. "C": DictObj(
  225. type=EventTypes.Member,
  226. state_key="@user_id_2:example.com",
  227. content={"membership": Membership.BAN},
  228. membership=Membership.BAN,
  229. depth=4,
  230. ),
  231. "D": DictObj(
  232. type=EventTypes.Name,
  233. state_key="",
  234. depth=4,
  235. sender="@user_id_2:example.com",
  236. ),
  237. "E": DictObj(type=EventTypes.Message, depth=5),
  238. },
  239. edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
  240. )
  241. self.store.register_events(graph.walk())
  242. context_store = {}
  243. for event in graph.walk():
  244. context = yield self.state.compute_event_context(event)
  245. self.store.register_event_context(event, context)
  246. context_store[event.event_id] = context
  247. # C ends up winning the resolution between C and D because bans win over other
  248. # changes
  249. ctx_c = context_store["C"]
  250. ctx_e = context_store["E"]
  251. prev_state_ids = yield ctx_e.get_prev_state_ids(self.store)
  252. self.assertSetEqual(
  253. {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
  254. )
  255. self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
  256. self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
  257. @defer.inlineCallbacks
  258. def test_branch_have_perms_conflict(self):
  259. userid1 = "@user_id:example.com"
  260. userid2 = "@user_id2:example.com"
  261. nodes = {
  262. "A1": DictObj(
  263. type=EventTypes.Create,
  264. state_key="",
  265. content={"creator": userid1},
  266. depth=1,
  267. ),
  268. "A2": DictObj(
  269. type=EventTypes.Member,
  270. state_key=userid1,
  271. content={"membership": Membership.JOIN},
  272. membership=Membership.JOIN,
  273. ),
  274. "A3": DictObj(
  275. type=EventTypes.Member,
  276. state_key=userid2,
  277. content={"membership": Membership.JOIN},
  278. membership=Membership.JOIN,
  279. ),
  280. "A4": DictObj(
  281. type=EventTypes.PowerLevels,
  282. state_key="",
  283. content={
  284. "events": {"m.room.name": 50},
  285. "users": {userid1: 100, userid2: 60},
  286. },
  287. ),
  288. "A5": DictObj(type=EventTypes.Name, state_key=""),
  289. "B": DictObj(
  290. type=EventTypes.PowerLevels,
  291. state_key="",
  292. content={"events": {"m.room.name": 50}, "users": {userid2: 30}},
  293. ),
  294. "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
  295. "D": DictObj(type=EventTypes.Message),
  296. }
  297. edges = {
  298. "A2": ["A1"],
  299. "A3": ["A2"],
  300. "A4": ["A3"],
  301. "A5": ["A4"],
  302. "B": ["A5"],
  303. "C": ["A5"],
  304. "D": ["B", "C"],
  305. }
  306. self._add_depths(nodes, edges)
  307. graph = Graph(nodes, edges)
  308. self.store.register_events(graph.walk())
  309. context_store = {}
  310. for event in graph.walk():
  311. context = yield self.state.compute_event_context(event)
  312. self.store.register_event_context(event, context)
  313. context_store[event.event_id] = context
  314. # B ends up winning the resolution between B and C because power levels
  315. # win over other changes.
  316. ctx_b = context_store["B"]
  317. ctx_d = context_store["D"]
  318. prev_state_ids = yield ctx_d.get_prev_state_ids(self.store)
  319. self.assertSetEqual(
  320. {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
  321. )
  322. self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
  323. self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
  324. def _add_depths(self, nodes, edges):
  325. def _get_depth(ev):
  326. node = nodes[ev]
  327. if "depth" not in node:
  328. prevs = edges[ev]
  329. depth = max(_get_depth(prev) for prev in prevs) + 1
  330. node["depth"] = depth
  331. return node["depth"]
  332. for n in nodes:
  333. _get_depth(n)
  334. @defer.inlineCallbacks
  335. def test_annotate_with_old_message(self):
  336. event = create_event(type="test_message", name="event")
  337. old_state = [
  338. create_event(type="test1", state_key="1"),
  339. create_event(type="test1", state_key="2"),
  340. create_event(type="test2", state_key=""),
  341. ]
  342. context = yield self.state.compute_event_context(event, old_state=old_state)
  343. prev_state_ids = yield context.get_prev_state_ids(self.store)
  344. self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
  345. current_state_ids = yield context.get_current_state_ids(self.store)
  346. self.assertCountEqual(
  347. (e.event_id for e in old_state), current_state_ids.values()
  348. )
  349. self.assertIsNotNone(context.state_group_before_event)
  350. self.assertEqual(context.state_group_before_event, context.state_group)
  351. @defer.inlineCallbacks
  352. def test_annotate_with_old_state(self):
  353. event = create_event(type="state", state_key="", name="event")
  354. old_state = [
  355. create_event(type="test1", state_key="1"),
  356. create_event(type="test1", state_key="2"),
  357. create_event(type="test2", state_key=""),
  358. ]
  359. context = yield self.state.compute_event_context(event, old_state=old_state)
  360. prev_state_ids = yield context.get_prev_state_ids(self.store)
  361. self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
  362. current_state_ids = yield context.get_current_state_ids(self.store)
  363. self.assertCountEqual(
  364. (e.event_id for e in old_state + [event]), current_state_ids.values()
  365. )
  366. self.assertIsNotNone(context.state_group_before_event)
  367. self.assertNotEqual(context.state_group_before_event, context.state_group)
  368. self.assertEqual(context.state_group_before_event, context.prev_group)
  369. self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
  370. @defer.inlineCallbacks
  371. def test_trivial_annotate_message(self):
  372. prev_event_id = "prev_event_id"
  373. event = create_event(
  374. type="test_message", name="event2", prev_events=[(prev_event_id, {})]
  375. )
  376. old_state = [
  377. create_event(type="test1", state_key="1"),
  378. create_event(type="test1", state_key="2"),
  379. create_event(type="test2", state_key=""),
  380. ]
  381. group_name = self.store.store_state_group(
  382. prev_event_id,
  383. event.room_id,
  384. None,
  385. None,
  386. {(e.type, e.state_key): e.event_id for e in old_state},
  387. )
  388. self.store.register_event_id_state_group(prev_event_id, group_name)
  389. context = yield self.state.compute_event_context(event)
  390. current_state_ids = yield context.get_current_state_ids(self.store)
  391. self.assertEqual(
  392. set([e.event_id for e in old_state]), set(current_state_ids.values())
  393. )
  394. self.assertEqual(group_name, context.state_group)
  395. @defer.inlineCallbacks
  396. def test_trivial_annotate_state(self):
  397. prev_event_id = "prev_event_id"
  398. event = create_event(
  399. type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})]
  400. )
  401. old_state = [
  402. create_event(type="test1", state_key="1"),
  403. create_event(type="test1", state_key="2"),
  404. create_event(type="test2", state_key=""),
  405. ]
  406. group_name = self.store.store_state_group(
  407. prev_event_id,
  408. event.room_id,
  409. None,
  410. None,
  411. {(e.type, e.state_key): e.event_id for e in old_state},
  412. )
  413. self.store.register_event_id_state_group(prev_event_id, group_name)
  414. context = yield self.state.compute_event_context(event)
  415. prev_state_ids = yield context.get_prev_state_ids(self.store)
  416. self.assertEqual(
  417. set([e.event_id for e in old_state]), set(prev_state_ids.values())
  418. )
  419. self.assertIsNotNone(context.state_group)
  420. @defer.inlineCallbacks
  421. def test_resolve_message_conflict(self):
  422. prev_event_id1 = "event_id1"
  423. prev_event_id2 = "event_id2"
  424. event = create_event(
  425. type="test_message",
  426. name="event3",
  427. prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
  428. )
  429. creation = create_event(type=EventTypes.Create, state_key="")
  430. old_state_1 = [
  431. creation,
  432. create_event(type="test1", state_key="1"),
  433. create_event(type="test1", state_key="2"),
  434. create_event(type="test2", state_key=""),
  435. ]
  436. old_state_2 = [
  437. creation,
  438. create_event(type="test1", state_key="1"),
  439. create_event(type="test3", state_key="2"),
  440. create_event(type="test4", state_key=""),
  441. ]
  442. self.store.register_events(old_state_1)
  443. self.store.register_events(old_state_2)
  444. context = yield self._get_context(
  445. event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
  446. )
  447. current_state_ids = yield context.get_current_state_ids(self.store)
  448. self.assertEqual(len(current_state_ids), 6)
  449. self.assertIsNotNone(context.state_group)
  450. @defer.inlineCallbacks
  451. def test_resolve_state_conflict(self):
  452. prev_event_id1 = "event_id1"
  453. prev_event_id2 = "event_id2"
  454. event = create_event(
  455. type="test4",
  456. state_key="",
  457. name="event",
  458. prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
  459. )
  460. creation = create_event(type=EventTypes.Create, state_key="")
  461. old_state_1 = [
  462. creation,
  463. create_event(type="test1", state_key="1"),
  464. create_event(type="test1", state_key="2"),
  465. create_event(type="test2", state_key=""),
  466. ]
  467. old_state_2 = [
  468. creation,
  469. create_event(type="test1", state_key="1"),
  470. create_event(type="test3", state_key="2"),
  471. create_event(type="test4", state_key=""),
  472. ]
  473. store = StateGroupStore()
  474. store.register_events(old_state_1)
  475. store.register_events(old_state_2)
  476. self.store.get_events = store.get_events
  477. context = yield self._get_context(
  478. event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
  479. )
  480. current_state_ids = yield context.get_current_state_ids(self.store)
  481. self.assertEqual(len(current_state_ids), 6)
  482. self.assertIsNotNone(context.state_group)
  483. @defer.inlineCallbacks
  484. def test_standard_depth_conflict(self):
  485. prev_event_id1 = "event_id1"
  486. prev_event_id2 = "event_id2"
  487. event = create_event(
  488. type="test4",
  489. name="event",
  490. prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
  491. )
  492. member_event = create_event(
  493. type=EventTypes.Member,
  494. state_key="@user_id:example.com",
  495. content={"membership": Membership.JOIN},
  496. )
  497. power_levels = create_event(
  498. type=EventTypes.PowerLevels,
  499. state_key="",
  500. content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}},
  501. )
  502. creation = create_event(
  503. type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
  504. )
  505. old_state_1 = [
  506. creation,
  507. power_levels,
  508. member_event,
  509. create_event(type="test1", state_key="1", depth=1),
  510. ]
  511. old_state_2 = [
  512. creation,
  513. power_levels,
  514. member_event,
  515. create_event(type="test1", state_key="1", depth=2),
  516. ]
  517. store = StateGroupStore()
  518. store.register_events(old_state_1)
  519. store.register_events(old_state_2)
  520. self.store.get_events = store.get_events
  521. context = yield self._get_context(
  522. event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
  523. )
  524. current_state_ids = yield context.get_current_state_ids(self.store)
  525. self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
  526. # Reverse the depth to make sure we are actually using the depths
  527. # during state resolution.
  528. old_state_1 = [
  529. creation,
  530. power_levels,
  531. member_event,
  532. create_event(type="test1", state_key="1", depth=2),
  533. ]
  534. old_state_2 = [
  535. creation,
  536. power_levels,
  537. member_event,
  538. create_event(type="test1", state_key="1", depth=1),
  539. ]
  540. store.register_events(old_state_1)
  541. store.register_events(old_state_2)
  542. context = yield self._get_context(
  543. event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
  544. )
  545. current_state_ids = yield context.get_current_state_ids(self.store)
  546. self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
  547. def _get_context(
  548. self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
  549. ):
  550. sg1 = self.store.store_state_group(
  551. prev_event_id_1,
  552. event.room_id,
  553. None,
  554. None,
  555. {(e.type, e.state_key): e.event_id for e in old_state_1},
  556. )
  557. self.store.register_event_id_state_group(prev_event_id_1, sg1)
  558. sg2 = self.store.store_state_group(
  559. prev_event_id_2,
  560. event.room_id,
  561. None,
  562. None,
  563. {(e.type, e.state_key): e.event_id for e in old_state_2},
  564. )
  565. self.store.register_event_id_state_group(prev_event_id_2, sg2)
  566. return self.state.compute_event_context(event)