1
0

test_state.py 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808
  1. # Copyright 2014-2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Collection, Dict, List, Optional, cast
  15. from unittest.mock import Mock
  16. from twisted.internet import defer
  17. from synapse.api.auth import Auth
  18. from synapse.api.constants import EventTypes, Membership
  19. from synapse.api.room_versions import RoomVersions
  20. from synapse.events import make_event_from_dict
  21. from synapse.events.snapshot import EventContext
  22. from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry
  23. from synapse.util import Clock
  24. from synapse.util.macaroons import MacaroonGenerator
  25. from tests import unittest
  26. from .utils import MockClock, default_config
  27. _next_event_id = 1000
  28. def create_event(
  29. name=None,
  30. type=None,
  31. state_key=None,
  32. depth=2,
  33. event_id=None,
  34. prev_events: Optional[List[str]] = None,
  35. **kwargs,
  36. ):
  37. global _next_event_id
  38. if not event_id:
  39. _next_event_id += 1
  40. event_id = "$%s:test" % (_next_event_id,)
  41. if not name:
  42. if state_key is not None:
  43. name = "<%s-%s, %s>" % (type, state_key, event_id)
  44. else:
  45. name = "<%s, %s>" % (type, event_id)
  46. d = {
  47. "event_id": event_id,
  48. "type": type,
  49. "sender": "@user_id:example.com",
  50. "room_id": "!room_id:example.com",
  51. "depth": depth,
  52. "prev_events": prev_events or [],
  53. }
  54. if state_key is not None:
  55. d["state_key"] = state_key
  56. d.update(kwargs)
  57. event = make_event_from_dict(d)
  58. return event
  59. class _DummyStore:
  60. def __init__(self):
  61. self._event_to_state_group = {}
  62. self._group_to_state = {}
  63. self._event_id_to_event = {}
  64. self._next_group = 1
  65. async def get_state_groups_ids(self, room_id, event_ids):
  66. groups = {}
  67. for event_id in event_ids:
  68. group = self._event_to_state_group.get(event_id)
  69. if group:
  70. groups[group] = self._group_to_state[group]
  71. return groups
  72. async def get_state_ids_for_group(self, state_group, state_filter=None):
  73. return self._group_to_state[state_group]
  74. async def store_state_group(
  75. self, event_id, room_id, prev_group, delta_ids, current_state_ids
  76. ):
  77. state_group = self._next_group
  78. self._next_group += 1
  79. if current_state_ids is None:
  80. current_state_ids = dict(self._group_to_state[prev_group])
  81. current_state_ids.update(delta_ids)
  82. self._group_to_state[state_group] = dict(current_state_ids)
  83. return state_group
  84. async def get_events(self, event_ids, **kwargs):
  85. return {
  86. e_id: self._event_id_to_event[e_id]
  87. for e_id in event_ids
  88. if e_id in self._event_id_to_event
  89. }
  90. async def get_partial_state_events(
  91. self, event_ids: Collection[str]
  92. ) -> Dict[str, bool]:
  93. return {e: False for e in event_ids}
  94. async def get_state_group_delta(self, name):
  95. return None, None
  96. def register_events(self, events):
  97. for e in events:
  98. self._event_id_to_event[e.event_id] = e
  99. def register_event_context(self, event, context):
  100. self._event_to_state_group[event.event_id] = context.state_group
  101. def register_event_id_state_group(self, event_id, state_group):
  102. self._event_to_state_group[event_id] = state_group
  103. async def get_room_version_id(self, room_id):
  104. return RoomVersions.V1.identifier
  105. async def get_state_group_for_events(
  106. self, event_ids, await_full_state: bool = True
  107. ):
  108. res = {}
  109. for event in event_ids:
  110. res[event] = self._event_to_state_group[event]
  111. return res
  112. async def get_state_for_groups(self, groups):
  113. res = {}
  114. for group in groups:
  115. state = self._group_to_state[group]
  116. res[group] = state
  117. return res
  118. class DictObj(dict):
  119. def __init__(self, **kwargs):
  120. super().__init__(kwargs)
  121. self.__dict__ = self
  122. class Graph:
  123. def __init__(self, nodes, edges):
  124. events = {}
  125. clobbered = set(events.keys())
  126. for event_id, fields in nodes.items():
  127. refs = edges.get(event_id)
  128. if refs:
  129. clobbered.difference_update(refs)
  130. prev_events = [(r, {}) for r in refs]
  131. else:
  132. prev_events = []
  133. events[event_id] = create_event(
  134. event_id=event_id, prev_events=prev_events, **fields
  135. )
  136. self._leaves = clobbered
  137. self._events = sorted(events.values(), key=lambda e: e.depth)
  138. def walk(self):
  139. return iter(self._events)
  140. def get_leaves(self):
  141. return (self._events[i] for i in self._leaves)
  142. class StateTestCase(unittest.TestCase):
    def setUp(self):
        """Create a ``StateHandler`` backed by a mocked homeserver and a ``_DummyStore``."""
        self.dummy_store = _DummyStore()
        # The same in-memory store is exposed as both the `main` and `state`
        # storage controllers.
        storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store)
        # `spec_set` limits the mock homeserver to exactly these attributes.
        hs = Mock(
            spec_set=[
                "config",
                "get_datastores",
                "get_storage_controllers",
                "get_auth",
                "get_state_handler",
                "get_clock",
                "get_state_resolution_handler",
                "get_account_validity_handler",
                "get_macaroon_generator",
                "get_instance_name",
                "get_simple_http_client",
                "hostname",
            ]
        )
        clock = cast(Clock, MockClock())
        hs.config = default_config("tesths", True)
        hs.get_datastores.return_value = Mock(main=self.dummy_store)
        hs.get_state_handler.return_value = None
        hs.get_clock.return_value = clock
        hs.get_macaroon_generator.return_value = MacaroonGenerator(
            clock, "tesths", b"verysecret"
        )
        # Auth is constructed from `hs`, so the attributes above are set first.
        hs.get_auth.return_value = Auth(hs)
        hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
        hs.get_storage_controllers.return_value = storage_controllers

        # The handler under test.
        self.state = StateHandler(hs)
        self.event_id = 0
  175. @defer.inlineCallbacks
  176. def test_branch_no_conflict(self):
  177. graph = Graph(
  178. nodes={
  179. "START": DictObj(
  180. type=EventTypes.Create, state_key="", content={}, depth=1
  181. ),
  182. "A": DictObj(type=EventTypes.Message, depth=2),
  183. "B": DictObj(type=EventTypes.Message, depth=3),
  184. "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
  185. "D": DictObj(type=EventTypes.Message, depth=4),
  186. },
  187. edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
  188. )
  189. self.dummy_store.register_events(graph.walk())
  190. context_store: dict[str, EventContext] = {}
  191. for event in graph.walk():
  192. context = yield defer.ensureDeferred(
  193. self.state.compute_event_context(event)
  194. )
  195. self.dummy_store.register_event_context(event, context)
  196. context_store[event.event_id] = context
  197. ctx_c = context_store["C"]
  198. ctx_d = context_store["D"]
  199. prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
  200. self.assertEqual(2, len(prev_state_ids))
  201. self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
  202. self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
  203. @defer.inlineCallbacks
  204. def test_branch_basic_conflict(self):
  205. graph = Graph(
  206. nodes={
  207. "START": DictObj(
  208. type=EventTypes.Create,
  209. state_key="",
  210. content={"creator": "@user_id:example.com"},
  211. depth=1,
  212. ),
  213. "A": DictObj(
  214. type=EventTypes.Member,
  215. state_key="@user_id:example.com",
  216. content={"membership": Membership.JOIN},
  217. membership=Membership.JOIN,
  218. depth=2,
  219. ),
  220. "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
  221. "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
  222. "D": DictObj(type=EventTypes.Message, depth=5),
  223. },
  224. edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
  225. )
  226. self.dummy_store.register_events(graph.walk())
  227. context_store = {}
  228. for event in graph.walk():
  229. context = yield defer.ensureDeferred(
  230. self.state.compute_event_context(event)
  231. )
  232. self.dummy_store.register_event_context(event, context)
  233. context_store[event.event_id] = context
  234. # C ends up winning the resolution between B and C
  235. ctx_c = context_store["C"]
  236. ctx_d = context_store["D"]
  237. prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
  238. self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
  239. self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
  240. self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
  241. @defer.inlineCallbacks
  242. def test_branch_have_banned_conflict(self):
  243. graph = Graph(
  244. nodes={
  245. "START": DictObj(
  246. type=EventTypes.Create,
  247. state_key="",
  248. content={"creator": "@user_id:example.com"},
  249. depth=1,
  250. ),
  251. "A": DictObj(
  252. type=EventTypes.Member,
  253. state_key="@user_id:example.com",
  254. content={"membership": Membership.JOIN},
  255. membership=Membership.JOIN,
  256. depth=2,
  257. ),
  258. "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
  259. "C": DictObj(
  260. type=EventTypes.Member,
  261. state_key="@user_id_2:example.com",
  262. content={"membership": Membership.BAN},
  263. membership=Membership.BAN,
  264. depth=4,
  265. ),
  266. "D": DictObj(
  267. type=EventTypes.Name,
  268. state_key="",
  269. depth=4,
  270. sender="@user_id_2:example.com",
  271. ),
  272. "E": DictObj(type=EventTypes.Message, depth=5),
  273. },
  274. edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
  275. )
  276. self.dummy_store.register_events(graph.walk())
  277. context_store = {}
  278. for event in graph.walk():
  279. context = yield defer.ensureDeferred(
  280. self.state.compute_event_context(event)
  281. )
  282. self.dummy_store.register_event_context(event, context)
  283. context_store[event.event_id] = context
  284. # C ends up winning the resolution between C and D because bans win over other
  285. # changes
  286. ctx_c = context_store["C"]
  287. ctx_e = context_store["E"]
  288. prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
  289. self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
  290. self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
  291. self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
  292. @defer.inlineCallbacks
  293. def test_branch_have_perms_conflict(self):
  294. userid1 = "@user_id:example.com"
  295. userid2 = "@user_id2:example.com"
  296. nodes = {
  297. "A1": DictObj(
  298. type=EventTypes.Create,
  299. state_key="",
  300. content={"creator": userid1},
  301. depth=1,
  302. ),
  303. "A2": DictObj(
  304. type=EventTypes.Member,
  305. state_key=userid1,
  306. content={"membership": Membership.JOIN},
  307. membership=Membership.JOIN,
  308. ),
  309. "A3": DictObj(
  310. type=EventTypes.Member,
  311. state_key=userid2,
  312. content={"membership": Membership.JOIN},
  313. membership=Membership.JOIN,
  314. ),
  315. "A4": DictObj(
  316. type=EventTypes.PowerLevels,
  317. state_key="",
  318. content={
  319. "events": {"m.room.name": 50},
  320. "users": {userid1: 100, userid2: 60},
  321. },
  322. ),
  323. "A5": DictObj(type=EventTypes.Name, state_key=""),
  324. "B": DictObj(
  325. type=EventTypes.PowerLevels,
  326. state_key="",
  327. content={"events": {"m.room.name": 50}, "users": {userid2: 30}},
  328. ),
  329. "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
  330. "D": DictObj(type=EventTypes.Message),
  331. }
  332. edges = {
  333. "A2": ["A1"],
  334. "A3": ["A2"],
  335. "A4": ["A3"],
  336. "A5": ["A4"],
  337. "B": ["A5"],
  338. "C": ["A5"],
  339. "D": ["B", "C"],
  340. }
  341. self._add_depths(nodes, edges)
  342. graph = Graph(nodes, edges)
  343. self.dummy_store.register_events(graph.walk())
  344. context_store = {}
  345. for event in graph.walk():
  346. context = yield defer.ensureDeferred(
  347. self.state.compute_event_context(event)
  348. )
  349. self.dummy_store.register_event_context(event, context)
  350. context_store[event.event_id] = context
  351. # B ends up winning the resolution between B and C because power levels
  352. # win over other changes.
  353. ctx_b = context_store["B"]
  354. ctx_d = context_store["D"]
  355. prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
  356. self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
  357. self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
  358. self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
  359. def _add_depths(self, nodes, edges):
  360. def _get_depth(ev):
  361. node = nodes[ev]
  362. if "depth" not in node:
  363. prevs = edges[ev]
  364. depth = max(_get_depth(prev) for prev in prevs) + 1
  365. node["depth"] = depth
  366. return node["depth"]
  367. for n in nodes:
  368. _get_depth(n)
  369. @defer.inlineCallbacks
  370. def test_annotate_with_old_message(self):
  371. event = create_event(type="test_message", name="event")
  372. old_state = [
  373. create_event(type="test1", state_key="1"),
  374. create_event(type="test1", state_key="2"),
  375. create_event(type="test2", state_key=""),
  376. ]
  377. context = yield defer.ensureDeferred(
  378. self.state.compute_event_context(
  379. event,
  380. state_ids_before_event={
  381. (e.type, e.state_key): e.event_id for e in old_state
  382. },
  383. partial_state=False,
  384. )
  385. )
  386. prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
  387. self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
  388. current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
  389. self.assertCountEqual(
  390. (e.event_id for e in old_state), current_state_ids.values()
  391. )
  392. self.assertIsNotNone(context.state_group_before_event)
  393. self.assertEqual(context.state_group_before_event, context.state_group)
  394. @defer.inlineCallbacks
  395. def test_annotate_with_old_state(self):
  396. event = create_event(type="state", state_key="", name="event")
  397. old_state = [
  398. create_event(type="test1", state_key="1"),
  399. create_event(type="test1", state_key="2"),
  400. create_event(type="test2", state_key=""),
  401. ]
  402. context = yield defer.ensureDeferred(
  403. self.state.compute_event_context(
  404. event,
  405. state_ids_before_event={
  406. (e.type, e.state_key): e.event_id for e in old_state
  407. },
  408. partial_state=False,
  409. )
  410. )
  411. prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
  412. self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
  413. current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
  414. self.assertCountEqual(
  415. (e.event_id for e in old_state + [event]), current_state_ids.values()
  416. )
  417. self.assertIsNotNone(context.state_group_before_event)
  418. self.assertNotEqual(context.state_group_before_event, context.state_group)
  419. self.assertEqual(context.state_group_before_event, context.prev_group)
  420. self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
  421. @defer.inlineCallbacks
  422. def test_trivial_annotate_message(self):
  423. prev_event_id = "prev_event_id"
  424. event = create_event(
  425. type="test_message", name="event2", prev_events=[(prev_event_id, {})]
  426. )
  427. old_state = [
  428. create_event(type="test1", state_key="1"),
  429. create_event(type="test1", state_key="2"),
  430. create_event(type="test2", state_key=""),
  431. ]
  432. group_name = yield defer.ensureDeferred(
  433. self.dummy_store.store_state_group(
  434. prev_event_id,
  435. event.room_id,
  436. None,
  437. None,
  438. {(e.type, e.state_key): e.event_id for e in old_state},
  439. )
  440. )
  441. self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
  442. context = yield defer.ensureDeferred(self.state.compute_event_context(event))
  443. current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
  444. self.assertEqual(
  445. {e.event_id for e in old_state}, set(current_state_ids.values())
  446. )
  447. self.assertEqual(group_name, context.state_group)
  448. @defer.inlineCallbacks
  449. def test_trivial_annotate_state(self):
  450. prev_event_id = "prev_event_id"
  451. event = create_event(
  452. type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})]
  453. )
  454. old_state = [
  455. create_event(type="test1", state_key="1"),
  456. create_event(type="test1", state_key="2"),
  457. create_event(type="test2", state_key=""),
  458. ]
  459. group_name = yield defer.ensureDeferred(
  460. self.dummy_store.store_state_group(
  461. prev_event_id,
  462. event.room_id,
  463. None,
  464. None,
  465. {(e.type, e.state_key): e.event_id for e in old_state},
  466. )
  467. )
  468. self.dummy_store.register_event_id_state_group(prev_event_id, group_name)
  469. context = yield defer.ensureDeferred(self.state.compute_event_context(event))
  470. prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
  471. self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
  472. self.assertIsNotNone(context.state_group)
  473. @defer.inlineCallbacks
  474. def test_resolve_message_conflict(self):
  475. prev_event_id1 = "event_id1"
  476. prev_event_id2 = "event_id2"
  477. event = create_event(
  478. type="test_message",
  479. name="event3",
  480. prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
  481. )
  482. creation = create_event(type=EventTypes.Create, state_key="")
  483. old_state_1 = [
  484. creation,
  485. create_event(type="test1", state_key="1"),
  486. create_event(type="test1", state_key="2"),
  487. create_event(type="test2", state_key=""),
  488. ]
  489. old_state_2 = [
  490. creation,
  491. create_event(type="test1", state_key="1"),
  492. create_event(type="test3", state_key="2"),
  493. create_event(type="test4", state_key=""),
  494. ]
  495. self.dummy_store.register_events(old_state_1)
  496. self.dummy_store.register_events(old_state_2)
  497. context = yield self._get_context(
  498. event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
  499. )
  500. current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
  501. self.assertEqual(len(current_state_ids), 6)
  502. self.assertIsNotNone(context.state_group)
  503. @defer.inlineCallbacks
  504. def test_resolve_state_conflict(self):
  505. prev_event_id1 = "event_id1"
  506. prev_event_id2 = "event_id2"
  507. event = create_event(
  508. type="test4",
  509. state_key="",
  510. name="event",
  511. prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
  512. )
  513. creation = create_event(type=EventTypes.Create, state_key="")
  514. old_state_1 = [
  515. creation,
  516. create_event(type="test1", state_key="1"),
  517. create_event(type="test1", state_key="2"),
  518. create_event(type="test2", state_key=""),
  519. ]
  520. old_state_2 = [
  521. creation,
  522. create_event(type="test1", state_key="1"),
  523. create_event(type="test3", state_key="2"),
  524. create_event(type="test4", state_key=""),
  525. ]
  526. store = _DummyStore()
  527. store.register_events(old_state_1)
  528. store.register_events(old_state_2)
  529. self.dummy_store.get_events = store.get_events
  530. context = yield self._get_context(
  531. event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
  532. )
  533. current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
  534. self.assertEqual(len(current_state_ids), 6)
  535. self.assertIsNotNone(context.state_group)
  536. @defer.inlineCallbacks
  537. def test_standard_depth_conflict(self):
  538. prev_event_id1 = "event_id1"
  539. prev_event_id2 = "event_id2"
  540. event = create_event(
  541. type="test4",
  542. name="event",
  543. prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
  544. )
  545. member_event = create_event(
  546. type=EventTypes.Member,
  547. state_key="@user_id:example.com",
  548. content={"membership": Membership.JOIN},
  549. )
  550. power_levels = create_event(
  551. type=EventTypes.PowerLevels,
  552. state_key="",
  553. content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}},
  554. )
  555. creation = create_event(
  556. type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
  557. )
  558. old_state_1 = [
  559. creation,
  560. power_levels,
  561. member_event,
  562. create_event(type="test1", state_key="1", depth=1),
  563. ]
  564. old_state_2 = [
  565. creation,
  566. power_levels,
  567. member_event,
  568. create_event(type="test1", state_key="1", depth=2),
  569. ]
  570. store = _DummyStore()
  571. store.register_events(old_state_1)
  572. store.register_events(old_state_2)
  573. self.dummy_store.get_events = store.get_events
  574. context = yield self._get_context(
  575. event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
  576. )
  577. current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
  578. self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
  579. # Reverse the depth to make sure we are actually using the depths
  580. # during state resolution.
  581. old_state_1 = [
  582. creation,
  583. power_levels,
  584. member_event,
  585. create_event(type="test1", state_key="1", depth=2),
  586. ]
  587. old_state_2 = [
  588. creation,
  589. power_levels,
  590. member_event,
  591. create_event(type="test1", state_key="1", depth=1),
  592. ]
  593. store.register_events(old_state_1)
  594. store.register_events(old_state_2)
  595. context = yield self._get_context(
  596. event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
  597. )
  598. current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
  599. self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
  600. @defer.inlineCallbacks
  601. def _get_context(
  602. self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
  603. ):
  604. sg1 = yield defer.ensureDeferred(
  605. self.dummy_store.store_state_group(
  606. prev_event_id_1,
  607. event.room_id,
  608. None,
  609. None,
  610. {(e.type, e.state_key): e.event_id for e in old_state_1},
  611. )
  612. )
  613. self.dummy_store.register_event_id_state_group(prev_event_id_1, sg1)
  614. sg2 = yield defer.ensureDeferred(
  615. self.dummy_store.store_state_group(
  616. prev_event_id_2,
  617. event.room_id,
  618. None,
  619. None,
  620. {(e.type, e.state_key): e.event_id for e in old_state_2},
  621. )
  622. )
  623. self.dummy_store.register_event_id_state_group(prev_event_id_2, sg2)
  624. result = yield defer.ensureDeferred(self.state.compute_event_context(event))
  625. return result
  626. def test_make_state_cache_entry(self):
  627. "Test that calculating a prev_group and delta is correct"
  628. new_state = {
  629. ("a", ""): "E",
  630. ("b", ""): "E",
  631. ("c", ""): "E",
  632. ("d", ""): "E",
  633. }
  634. # old_state_1 has fewer differences to new_state than old_state_2, but
  635. # the delta involves deleting a key, which isn't allowed in the deltas,
  636. # so we should pick old_state_2 as the prev_group.
  637. # `old_state_1` has two differences: `a` and `e`
  638. old_state_1 = {
  639. ("a", ""): "F",
  640. ("b", ""): "E",
  641. ("c", ""): "E",
  642. ("d", ""): "E",
  643. ("e", ""): "E",
  644. }
  645. # `old_state_2` has three differences: `a`, `c` and `d`
  646. old_state_2 = {
  647. ("a", ""): "F",
  648. ("b", ""): "E",
  649. ("c", ""): "F",
  650. ("d", ""): "F",
  651. }
  652. entry = _make_state_cache_entry(new_state, {1: old_state_1, 2: old_state_2})
  653. self.assertEqual(entry.prev_group, 2)
  654. # There are three changes from `old_state_2` to `new_state`
  655. self.assertEqual(
  656. entry.delta_ids, {("a", ""): "E", ("c", ""): "E", ("d", ""): "E"}
  657. )