test_state.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from tests import unittest
  16. from twisted.internet import defer
  17. from synapse.events import FrozenEvent
  18. from synapse.api.auth import Auth
  19. from synapse.api.constants import EventTypes, Membership
  20. from synapse.state import StateHandler
  21. from .utils import MockClock
  22. from mock import Mock
  23. _next_event_id = 1000
  24. def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
  25. prev_events=[], **kwargs):
  26. global _next_event_id
  27. if not event_id:
  28. _next_event_id += 1
  29. event_id = "$%s:test" % (_next_event_id,)
  30. if not name:
  31. if state_key is not None:
  32. name = "<%s-%s, %s>" % (type, state_key, event_id,)
  33. else:
  34. name = "<%s, %s>" % (type, event_id,)
  35. d = {
  36. "event_id": event_id,
  37. "type": type,
  38. "sender": "@user_id:example.com",
  39. "room_id": "!room_id:example.com",
  40. "depth": depth,
  41. "prev_events": prev_events,
  42. }
  43. if state_key is not None:
  44. d["state_key"] = state_key
  45. d.update(kwargs)
  46. event = FrozenEvent(d)
  47. return event
  48. class StateGroupStore(object):
  49. def __init__(self):
  50. self._event_to_state_group = {}
  51. self._group_to_state = {}
  52. self._event_id_to_event = {}
  53. self._next_group = 1
  54. def get_state_groups_ids(self, room_id, event_ids):
  55. groups = {}
  56. for event_id in event_ids:
  57. group = self._event_to_state_group.get(event_id)
  58. if group:
  59. groups[group] = self._group_to_state[group]
  60. return defer.succeed(groups)
  61. def store_state_groups(self, event, context):
  62. if context.current_state_ids is None:
  63. return
  64. state_events = dict(context.current_state_ids)
  65. self._group_to_state[context.state_group] = state_events
  66. self._event_to_state_group[event.event_id] = context.state_group
  67. def get_events(self, event_ids, **kwargs):
  68. return {
  69. e_id: self._event_id_to_event[e_id] for e_id in event_ids
  70. if e_id in self._event_id_to_event
  71. }
  72. def register_events(self, events):
  73. for e in events:
  74. self._event_id_to_event[e.event_id] = e
  75. class DictObj(dict):
  76. def __init__(self, **kwargs):
  77. super(DictObj, self).__init__(kwargs)
  78. self.__dict__ = self
  79. class Graph(object):
  80. def __init__(self, nodes, edges):
  81. events = {}
  82. clobbered = set(events.keys())
  83. for event_id, fields in nodes.items():
  84. refs = edges.get(event_id)
  85. if refs:
  86. clobbered.difference_update(refs)
  87. prev_events = [(r, {}) for r in refs]
  88. else:
  89. prev_events = []
  90. events[event_id] = create_event(
  91. event_id=event_id,
  92. prev_events=prev_events,
  93. **fields
  94. )
  95. self._leaves = clobbered
  96. self._events = sorted(events.values(), key=lambda e: e.depth)
  97. def walk(self):
  98. return iter(self._events)
  99. def get_leaves(self):
  100. return (self._events[i] for i in self._leaves)
  101. class StateTestCase(unittest.TestCase):
  102. def setUp(self):
  103. self.store = Mock(
  104. spec_set=[
  105. "get_state_groups_ids",
  106. "add_event_hashes",
  107. "get_events",
  108. "get_next_state_group",
  109. "get_state_group_delta",
  110. ]
  111. )
  112. hs = Mock(spec_set=[
  113. "get_datastore", "get_auth", "get_state_handler", "get_clock",
  114. ])
  115. hs.get_datastore.return_value = self.store
  116. hs.get_state_handler.return_value = None
  117. hs.get_clock.return_value = MockClock()
  118. hs.get_auth.return_value = Auth(hs)
  119. self.store.get_next_state_group.side_effect = Mock
  120. self.store.get_state_group_delta.return_value = (None, None)
  121. self.state = StateHandler(hs)
  122. self.event_id = 0
  123. @defer.inlineCallbacks
  124. def test_branch_no_conflict(self):
  125. graph = Graph(
  126. nodes={
  127. "START": DictObj(
  128. type=EventTypes.Create,
  129. state_key="",
  130. depth=1,
  131. ),
  132. "A": DictObj(
  133. type=EventTypes.Message,
  134. depth=2,
  135. ),
  136. "B": DictObj(
  137. type=EventTypes.Message,
  138. depth=3,
  139. ),
  140. "C": DictObj(
  141. type=EventTypes.Name,
  142. state_key="",
  143. depth=3,
  144. ),
  145. "D": DictObj(
  146. type=EventTypes.Message,
  147. depth=4,
  148. ),
  149. },
  150. edges={
  151. "A": ["START"],
  152. "B": ["A"],
  153. "C": ["A"],
  154. "D": ["B", "C"]
  155. }
  156. )
  157. store = StateGroupStore()
  158. self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
  159. context_store = {}
  160. for event in graph.walk():
  161. context = yield self.state.compute_event_context(event)
  162. store.store_state_groups(event, context)
  163. context_store[event.event_id] = context
  164. self.assertEqual(2, len(context_store["D"].prev_state_ids))
  165. @defer.inlineCallbacks
  166. def test_branch_basic_conflict(self):
  167. graph = Graph(
  168. nodes={
  169. "START": DictObj(
  170. type=EventTypes.Create,
  171. state_key="",
  172. content={"creator": "@user_id:example.com"},
  173. depth=1,
  174. ),
  175. "A": DictObj(
  176. type=EventTypes.Member,
  177. state_key="@user_id:example.com",
  178. content={"membership": Membership.JOIN},
  179. membership=Membership.JOIN,
  180. depth=2,
  181. ),
  182. "B": DictObj(
  183. type=EventTypes.Name,
  184. state_key="",
  185. depth=3,
  186. ),
  187. "C": DictObj(
  188. type=EventTypes.Name,
  189. state_key="",
  190. depth=4,
  191. ),
  192. "D": DictObj(
  193. type=EventTypes.Message,
  194. depth=5,
  195. ),
  196. },
  197. edges={
  198. "A": ["START"],
  199. "B": ["A"],
  200. "C": ["A"],
  201. "D": ["B", "C"]
  202. }
  203. )
  204. store = StateGroupStore()
  205. self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
  206. self.store.get_events = store.get_events
  207. store.register_events(graph.walk())
  208. context_store = {}
  209. for event in graph.walk():
  210. context = yield self.state.compute_event_context(event)
  211. store.store_state_groups(event, context)
  212. context_store[event.event_id] = context
  213. self.assertSetEqual(
  214. {"START", "A", "C"},
  215. {e_id for e_id in context_store["D"].prev_state_ids.values()}
  216. )
  217. @defer.inlineCallbacks
  218. def test_branch_have_banned_conflict(self):
  219. graph = Graph(
  220. nodes={
  221. "START": DictObj(
  222. type=EventTypes.Create,
  223. state_key="",
  224. content={"creator": "@user_id:example.com"},
  225. depth=1,
  226. ),
  227. "A": DictObj(
  228. type=EventTypes.Member,
  229. state_key="@user_id:example.com",
  230. content={"membership": Membership.JOIN},
  231. membership=Membership.JOIN,
  232. depth=2,
  233. ),
  234. "B": DictObj(
  235. type=EventTypes.Name,
  236. state_key="",
  237. depth=3,
  238. ),
  239. "C": DictObj(
  240. type=EventTypes.Member,
  241. state_key="@user_id_2:example.com",
  242. content={"membership": Membership.BAN},
  243. membership=Membership.BAN,
  244. depth=4,
  245. ),
  246. "D": DictObj(
  247. type=EventTypes.Name,
  248. state_key="",
  249. depth=4,
  250. sender="@user_id_2:example.com",
  251. ),
  252. "E": DictObj(
  253. type=EventTypes.Message,
  254. depth=5,
  255. ),
  256. },
  257. edges={
  258. "A": ["START"],
  259. "B": ["A"],
  260. "C": ["B"],
  261. "D": ["B"],
  262. "E": ["C", "D"]
  263. }
  264. )
  265. store = StateGroupStore()
  266. self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
  267. self.store.get_events = store.get_events
  268. store.register_events(graph.walk())
  269. context_store = {}
  270. for event in graph.walk():
  271. context = yield self.state.compute_event_context(event)
  272. store.store_state_groups(event, context)
  273. context_store[event.event_id] = context
  274. self.assertSetEqual(
  275. {"START", "A", "B", "C"},
  276. {e for e in context_store["E"].prev_state_ids.values()}
  277. )
  278. @defer.inlineCallbacks
  279. def test_branch_have_perms_conflict(self):
  280. userid1 = "@user_id:example.com"
  281. userid2 = "@user_id2:example.com"
  282. nodes = {
  283. "A1": DictObj(
  284. type=EventTypes.Create,
  285. state_key="",
  286. content={"creator": userid1},
  287. depth=1,
  288. ),
  289. "A2": DictObj(
  290. type=EventTypes.Member,
  291. state_key=userid1,
  292. content={"membership": Membership.JOIN},
  293. membership=Membership.JOIN,
  294. ),
  295. "A3": DictObj(
  296. type=EventTypes.Member,
  297. state_key=userid2,
  298. content={"membership": Membership.JOIN},
  299. membership=Membership.JOIN,
  300. ),
  301. "A4": DictObj(
  302. type=EventTypes.PowerLevels,
  303. state_key="",
  304. content={
  305. "events": {"m.room.name": 50},
  306. "users": {userid1: 100,
  307. userid2: 60},
  308. },
  309. ),
  310. "A5": DictObj(
  311. type=EventTypes.Name,
  312. state_key="",
  313. ),
  314. "B": DictObj(
  315. type=EventTypes.PowerLevels,
  316. state_key="",
  317. content={
  318. "events": {"m.room.name": 50},
  319. "users": {userid2: 30},
  320. },
  321. ),
  322. "C": DictObj(
  323. type=EventTypes.Name,
  324. state_key="",
  325. sender=userid2,
  326. ),
  327. "D": DictObj(
  328. type=EventTypes.Message,
  329. ),
  330. }
  331. edges = {
  332. "A2": ["A1"],
  333. "A3": ["A2"],
  334. "A4": ["A3"],
  335. "A5": ["A4"],
  336. "B": ["A5"],
  337. "C": ["A5"],
  338. "D": ["B", "C"]
  339. }
  340. self._add_depths(nodes, edges)
  341. graph = Graph(nodes, edges)
  342. store = StateGroupStore()
  343. self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
  344. self.store.get_events = store.get_events
  345. store.register_events(graph.walk())
  346. context_store = {}
  347. for event in graph.walk():
  348. context = yield self.state.compute_event_context(event)
  349. store.store_state_groups(event, context)
  350. context_store[event.event_id] = context
  351. self.assertSetEqual(
  352. {"A1", "A2", "A3", "A5", "B"},
  353. {e for e in context_store["D"].prev_state_ids.values()}
  354. )
  355. def _add_depths(self, nodes, edges):
  356. def _get_depth(ev):
  357. node = nodes[ev]
  358. if 'depth' not in node:
  359. prevs = edges[ev]
  360. depth = max(_get_depth(prev) for prev in prevs) + 1
  361. node['depth'] = depth
  362. return node['depth']
  363. for n in nodes:
  364. _get_depth(n)
  365. @defer.inlineCallbacks
  366. def test_annotate_with_old_message(self):
  367. event = create_event(type="test_message", name="event")
  368. old_state = [
  369. create_event(type="test1", state_key="1"),
  370. create_event(type="test1", state_key="2"),
  371. create_event(type="test2", state_key=""),
  372. ]
  373. context = yield self.state.compute_event_context(
  374. event, old_state=old_state
  375. )
  376. self.assertEqual(
  377. set(e.event_id for e in old_state), set(context.current_state_ids.values())
  378. )
  379. self.assertIsNotNone(context.state_group)
  380. @defer.inlineCallbacks
  381. def test_annotate_with_old_state(self):
  382. event = create_event(type="state", state_key="", name="event")
  383. old_state = [
  384. create_event(type="test1", state_key="1"),
  385. create_event(type="test1", state_key="2"),
  386. create_event(type="test2", state_key=""),
  387. ]
  388. context = yield self.state.compute_event_context(
  389. event, old_state=old_state
  390. )
  391. self.assertEqual(
  392. set(e.event_id for e in old_state), set(context.prev_state_ids.values())
  393. )
  394. @defer.inlineCallbacks
  395. def test_trivial_annotate_message(self):
  396. event = create_event(type="test_message", name="event")
  397. old_state = [
  398. create_event(type="test1", state_key="1"),
  399. create_event(type="test1", state_key="2"),
  400. create_event(type="test2", state_key=""),
  401. ]
  402. group_name = "group_name_1"
  403. self.store.get_state_groups_ids.return_value = {
  404. group_name: {(e.type, e.state_key): e.event_id for e in old_state},
  405. }
  406. context = yield self.state.compute_event_context(event)
  407. self.assertEqual(
  408. set([e.event_id for e in old_state]),
  409. set(context.current_state_ids.values())
  410. )
  411. self.assertEqual(group_name, context.state_group)
  412. @defer.inlineCallbacks
  413. def test_trivial_annotate_state(self):
  414. event = create_event(type="state", state_key="", name="event")
  415. old_state = [
  416. create_event(type="test1", state_key="1"),
  417. create_event(type="test1", state_key="2"),
  418. create_event(type="test2", state_key=""),
  419. ]
  420. group_name = "group_name_1"
  421. self.store.get_state_groups_ids.return_value = {
  422. group_name: {(e.type, e.state_key): e.event_id for e in old_state},
  423. }
  424. context = yield self.state.compute_event_context(event)
  425. self.assertEqual(
  426. set([e.event_id for e in old_state]),
  427. set(context.prev_state_ids.values())
  428. )
  429. self.assertIsNotNone(context.state_group)
  430. @defer.inlineCallbacks
  431. def test_resolve_message_conflict(self):
  432. event = create_event(type="test_message", name="event")
  433. creation = create_event(
  434. type=EventTypes.Create, state_key=""
  435. )
  436. old_state_1 = [
  437. creation,
  438. create_event(type="test1", state_key="1"),
  439. create_event(type="test1", state_key="2"),
  440. create_event(type="test2", state_key=""),
  441. ]
  442. old_state_2 = [
  443. creation,
  444. create_event(type="test1", state_key="1"),
  445. create_event(type="test3", state_key="2"),
  446. create_event(type="test4", state_key=""),
  447. ]
  448. store = StateGroupStore()
  449. store.register_events(old_state_1)
  450. store.register_events(old_state_2)
  451. self.store.get_events = store.get_events
  452. context = yield self._get_context(event, old_state_1, old_state_2)
  453. self.assertEqual(len(context.current_state_ids), 6)
  454. self.assertIsNotNone(context.state_group)
  455. @defer.inlineCallbacks
  456. def test_resolve_state_conflict(self):
  457. event = create_event(type="test4", state_key="", name="event")
  458. creation = create_event(
  459. type=EventTypes.Create, state_key=""
  460. )
  461. old_state_1 = [
  462. creation,
  463. create_event(type="test1", state_key="1"),
  464. create_event(type="test1", state_key="2"),
  465. create_event(type="test2", state_key=""),
  466. ]
  467. old_state_2 = [
  468. creation,
  469. create_event(type="test1", state_key="1"),
  470. create_event(type="test3", state_key="2"),
  471. create_event(type="test4", state_key=""),
  472. ]
  473. store = StateGroupStore()
  474. store.register_events(old_state_1)
  475. store.register_events(old_state_2)
  476. self.store.get_events = store.get_events
  477. context = yield self._get_context(event, old_state_1, old_state_2)
  478. self.assertEqual(len(context.current_state_ids), 6)
  479. self.assertIsNotNone(context.state_group)
  480. @defer.inlineCallbacks
  481. def test_standard_depth_conflict(self):
  482. event = create_event(type="test4", name="event")
  483. member_event = create_event(
  484. type=EventTypes.Member,
  485. state_key="@user_id:example.com",
  486. content={
  487. "membership": Membership.JOIN,
  488. }
  489. )
  490. creation = create_event(
  491. type=EventTypes.Create, state_key="",
  492. content={"creator": "@foo:bar"}
  493. )
  494. old_state_1 = [
  495. creation,
  496. member_event,
  497. create_event(type="test1", state_key="1", depth=1),
  498. ]
  499. old_state_2 = [
  500. creation,
  501. member_event,
  502. create_event(type="test1", state_key="1", depth=2),
  503. ]
  504. store = StateGroupStore()
  505. store.register_events(old_state_1)
  506. store.register_events(old_state_2)
  507. self.store.get_events = store.get_events
  508. context = yield self._get_context(event, old_state_1, old_state_2)
  509. self.assertEqual(
  510. old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
  511. )
  512. # Reverse the depth to make sure we are actually using the depths
  513. # during state resolution.
  514. old_state_1 = [
  515. creation,
  516. member_event,
  517. create_event(type="test1", state_key="1", depth=2),
  518. ]
  519. old_state_2 = [
  520. creation,
  521. member_event,
  522. create_event(type="test1", state_key="1", depth=1),
  523. ]
  524. store.register_events(old_state_1)
  525. store.register_events(old_state_2)
  526. context = yield self._get_context(event, old_state_1, old_state_2)
  527. self.assertEqual(
  528. old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
  529. )
  530. def _get_context(self, event, old_state_1, old_state_2):
  531. group_name_1 = "group_name_1"
  532. group_name_2 = "group_name_2"
  533. self.store.get_state_groups_ids.return_value = {
  534. group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
  535. group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
  536. }
  537. return self.state.compute_event_context(event)