test_state.py 29 KB


  1. # Copyright 2014-2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import (
  15. Any,
  16. Collection,
  17. Dict,
  18. Generator,
  19. Iterable,
  20. Iterator,
  21. List,
  22. Optional,
  23. Set,
  24. Tuple,
  25. cast,
  26. )
  27. from unittest.mock import Mock
  28. from twisted.internet import defer
  29. from synapse.api.auth.internal import InternalAuth
  30. from synapse.api.constants import EventTypes, Membership
  31. from synapse.api.room_versions import RoomVersions
  32. from synapse.events import EventBase, make_event_from_dict
  33. from synapse.events.snapshot import EventContext
  34. from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry
  35. from synapse.types import MutableStateMap, StateMap
  36. from synapse.types.state import StateFilter
  37. from synapse.util import Clock
  38. from synapse.util.macaroons import MacaroonGenerator
  39. from tests import unittest
  40. from .utils import MockClock, default_config
  # Source of the numeric part of auto-generated event IDs. `create_event`
  # increments it before each use, so the first generated ID is `$1001:test`.
  _next_event_id = 1000


  def create_event(
      name: Optional[str] = None,
      type: Optional[str] = None,
      state_key: Optional[str] = None,
      depth: int = 2,
      event_id: Optional[str] = None,
      prev_events: Optional[List[Tuple[str, dict]]] = None,
      **kwargs: Any,
  ) -> EventBase:
      """Build a test event sent by `@user_id:example.com` in `!room_id:example.com`.

      Args:
          name: Free-form label that callers use to describe the event at the
              call site. It is accepted for backward compatibility only and is
              not stored on the event.
          type: Value of the event's `type` field.
          state_key: When not None, added to the event as its `state_key` field.
          depth: Value of the event's `depth` field.
          event_id: Explicit event ID. When falsy, a fresh `$<n>:test` ID is
              allocated from the module-level `_next_event_id` counter.
          prev_events: Value of the event's `prev_events` field; defaults to an
              empty list.
          **kwargs: Extra top-level fields. They are merged in last, so they can
              add to or override any of the fields above.

      Returns:
          The result of `make_event_from_dict` on the assembled dict.
      """
      global _next_event_id

      if not event_id:
          _next_event_id += 1
          event_id = "$%s:test" % (_next_event_id,)

      event_dict: Dict[str, Any] = {
          "event_id": event_id,
          "type": type,
          "sender": "@user_id:example.com",
          "room_id": "!room_id:example.com",
          "depth": depth,
          "prev_events": prev_events or [],
      }

      if state_key is not None:
          event_dict["state_key"] = state_key

      # Applied last so that callers can add or override any field.
      event_dict.update(kwargs)

      return make_event_from_dict(event_dict)
  class _DummyStore:
      """In-memory stand-in for the storage layer used by these tests.

      State groups, the event -> state group mapping and the events themselves
      are all kept in plain dicts. Nothing is persisted.
      """

      def __init__(self) -> None:
          # event ID -> state group registered for that event
          self._event_to_state_group: Dict[str, int] = {}
          # state group -> full state map stored for that group
          self._group_to_state: Dict[int, MutableStateMap[str]] = {}
          # event ID -> event, populated by `register_events`
          self._event_id_to_event: Dict[str, EventBase] = {}
          # next state group ID handed out by `store_state_group`
          self._next_group = 1

      async def get_state_groups_ids(
          self, room_id: str, event_ids: Collection[str]
      ) -> Dict[int, MutableStateMap[str]]:
          """Return the state of every state group registered for `event_ids`.

          Events with no registered state group are skipped. `room_id` is ignored.
          """
          groups: Dict[int, MutableStateMap[str]] = {}
          for event_id in event_ids:
              group = self._event_to_state_group.get(event_id)
              # Compare against None rather than relying on truthiness, so that
              # a state group of 0 is not mistaken for "no group registered".
              if group is not None:
                  groups[group] = self._group_to_state[group]
          return groups

      async def get_state_ids_for_group(
          self, state_group: int, state_filter: Optional[StateFilter] = None
      ) -> MutableStateMap[str]:
          """Return the stored state map for `state_group`; the filter is ignored.

          Raises:
              KeyError: if the state group is unknown.
          """
          return self._group_to_state[state_group]

      async def store_state_group(
          self,
          event_id: str,
          room_id: str,
          prev_group: Optional[int],
          delta_ids: Optional[StateMap[str]],
          current_state_ids: Optional[StateMap[str]],
      ) -> int:
          """Allocate a new state group and record its full state.

          When `current_state_ids` is None the state is rebuilt from the state of
          `prev_group` with `delta_ids` applied on top, so both must be given.
          `event_id` and `room_id` are ignored; in particular the event is *not*
          associated with the new group (see `register_event_id_state_group`).

          Returns:
              The ID of the newly allocated state group.
          """
          state_group = self._next_group
          self._next_group += 1

          if current_state_ids is None:
              assert prev_group is not None
              assert delta_ids is not None
              new_state = dict(self._group_to_state[prev_group])
              new_state.update(delta_ids)
          else:
              # Copy so that later mutation by the caller cannot leak in.
              new_state = dict(current_state_ids)

          self._group_to_state[state_group] = new_state
          return state_group

      async def get_events(
          self, event_ids: Collection[str], **kwargs: Any
      ) -> Dict[str, EventBase]:
          """Return the registered events among `event_ids`, keyed by event ID.

          Unknown event IDs are silently omitted; extra keyword arguments are
          ignored.
          """
          return {
              e_id: self._event_id_to_event[e_id]
              for e_id in event_ids
              if e_id in self._event_id_to_event
          }

      async def get_partial_state_events(
          self, event_ids: Collection[str]
      ) -> Dict[str, bool]:
          """Report every event in `event_ids` as not having partial state."""
          return {e: False for e in event_ids}

      async def get_state_group_delta(
          self, name: str
      ) -> Tuple[Optional[int], Optional[StateMap[str]]]:
          """Always return `(None, None)`."""
          return None, None

      def register_events(self, events: Iterable[EventBase]) -> None:
          """Make `events` retrievable through `get_events`."""
          for e in events:
              self._event_id_to_event[e.event_id] = e

      def register_event_context(self, event: EventBase, context: EventContext) -> None:
          """Associate `event` with the state group carried by `context`."""
          assert context.state_group is not None
          self._event_to_state_group[event.event_id] = context.state_group

      def register_event_id_state_group(self, event_id: str, state_group: int) -> None:
          """Associate `event_id` with `state_group`."""
          self._event_to_state_group[event_id] = state_group

      async def get_room_version_id(self, room_id: str) -> str:
          """Return the identifier of `RoomVersions.V1` for every room."""
          return RoomVersions.V1.identifier

      async def get_state_group_for_events(
          self, event_ids: Collection[str], await_full_state: bool = True
      ) -> Dict[str, int]:
          """Map each event ID to its registered state group.

          Raises:
              KeyError: if any of the events has no registered state group.
          """
          return {
              event_id: self._event_to_state_group[event_id] for event_id in event_ids
          }

      async def get_state_for_groups(
          self, groups: Collection[int]
      ) -> Dict[int, MutableStateMap[str]]:
          """Map each state group to its stored state map.

          Raises:
              KeyError: if any of the groups is unknown.
          """
          return {group: self._group_to_state[group] for group in groups}
  class DictObj(dict):
      """A dict whose entries can also be read and written as attributes."""

      def __init__(self, **kwargs: Any) -> None:
          dict.__init__(self, **kwargs)
          # Use the mapping itself as the instance namespace, so `obj.key` and
          # `obj["key"]` always refer to the same storage.
          self.__dict__ = self
  class Graph:
      """A small event graph built from node descriptions and prev-event edges.

      Args:
          nodes: Maps each event ID to the keyword arguments passed to
              `create_event` for that event.
          edges: Maps an event ID to the IDs of its prev events. Events that are
              missing from the map, or mapped to an empty list, get no prev
              events.
      """

      def __init__(self, nodes: Dict[str, DictObj], edges: Dict[str, List[str]]):
          events: Dict[str, EventBase] = {}
          # IDs that at least one event in `nodes` lists as a prev event.
          referenced: Set[str] = set()

          for event_id, fields in nodes.items():
              refs = edges.get(event_id) or []
              referenced.update(refs)
              prev_events: List[Tuple[str, dict]] = [(ref, {}) for ref in refs]

              events[event_id] = create_event(
                  event_id=event_id, prev_events=prev_events, **fields
              )

          # The leaves are the events nothing else points at. This used to be
          # computed by calling `difference_update` on a set that started out
          # empty, which always left it empty.
          self._leaves: Set[str] = set(nodes) - referenced
          self._events = sorted(events.values(), key=lambda e: e.depth)

      def walk(self) -> Iterator[EventBase]:
          """Iterate over the events in order of increasing depth."""
          return iter(self._events)
  171. class StateTestCase(unittest.TestCase):
  def setUp(self) -> None:
      """Create a `StateHandler` backed by a mock homeserver and a `_DummyStore`."""
      store = _DummyStore()
      self.dummy_store = store

      # With `spec_set`, only these attributes may be read or set on the mock.
      homeserver_attributes = [
          "config",
          "get_datastores",
          "get_storage_controllers",
          "get_auth",
          "get_state_handler",
          "get_clock",
          "get_state_resolution_handler",
          "get_account_validity_handler",
          "get_macaroon_generator",
          "get_instance_name",
          "get_simple_http_client",
          "get_replication_client",
          "hostname",
      ]
      homeserver = Mock(spec_set=homeserver_attributes)
      mock_clock = cast(Clock, MockClock())

      homeserver.config = default_config("tesths", True)
      homeserver.get_datastores.return_value = Mock(main=store)
      homeserver.get_state_handler.return_value = None
      homeserver.get_clock.return_value = mock_clock
      homeserver.get_macaroon_generator.return_value = MacaroonGenerator(
          mock_clock, "tesths", b"verysecret"
      )
      homeserver.get_auth.return_value = InternalAuth(homeserver)

      def _new_resolution_handler() -> StateResolutionHandler:
          # A fresh handler is built on every call.
          return StateResolutionHandler(homeserver)

      homeserver.get_state_resolution_handler = _new_resolution_handler
      # The dummy store serves as both the main and the state storage.
      homeserver.get_storage_controllers.return_value = Mock(main=store, state=store)

      self.state = StateHandler(homeserver)
      self.event_id = 0
  @defer.inlineCallbacks
  def test_branch_no_conflict(self) -> Generator[defer.Deferred, Any, None]:
      """Merge two branches of which only one (C) carries a state event.

      Graph: START <- A <- {B, C} <- D.
      """
      nodes = {
          "START": DictObj(type=EventTypes.Create, state_key="", content={}, depth=1),
          "A": DictObj(type=EventTypes.Message, depth=2),
          "B": DictObj(type=EventTypes.Message, depth=3),
          "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
          "D": DictObj(type=EventTypes.Message, depth=4),
      }
      prev_events = {"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}
      graph = Graph(nodes=nodes, edges=prev_events)

      self.dummy_store.register_events(graph.walk())

      # Compute a context for each event in depth order, registering its state
      # group so that later events can look it up.
      contexts: Dict[str, EventContext] = {}
      for ev in graph.walk():
          ev_context = yield defer.ensureDeferred(
              self.state.compute_event_context(ev)
          )
          self.dummy_store.register_event_context(ev, ev_context)
          contexts[ev.event_id] = ev_context

      context_c = contexts["C"]
      context_d = contexts["D"]

      state_before_d: StateMap[str]
      state_before_d = yield defer.ensureDeferred(context_d.get_prev_state_ids())
      self.assertEqual(2, len(state_before_d))

      # D starts from C's state group and, not being a state event, keeps it.
      self.assertEqual(context_c.state_group, context_d.state_group_before_event)
      self.assertEqual(context_d.state_group_before_event, context_d.state_group)
  @defer.inlineCallbacks
  def test_branch_basic_conflict(
      self,
  ) -> Generator["defer.Deferred[object]", Any, None]:
      """Merge two branches that each set the room name (B and C).

      Graph: START <- A <- {B, C} <- D.
      """
      nodes = {
          "START": DictObj(
              type=EventTypes.Create,
              state_key="",
              content={"creator": "@user_id:example.com"},
              depth=1,
          ),
          "A": DictObj(
              type=EventTypes.Member,
              state_key="@user_id:example.com",
              content={"membership": Membership.JOIN},
              membership=Membership.JOIN,
              depth=2,
          ),
          "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
          "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
          "D": DictObj(type=EventTypes.Message, depth=5),
      }
      prev_events = {"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]}
      graph = Graph(nodes=nodes, edges=prev_events)

      self.dummy_store.register_events(graph.walk())

      # Compute a context for each event in depth order, registering its state
      # group so that later events can look it up.
      contexts: Dict[str, EventContext] = {}
      for ev in graph.walk():
          ev_context = yield defer.ensureDeferred(
              self.state.compute_event_context(ev)
          )
          self.dummy_store.register_event_context(ev, ev_context)
          contexts[ev.event_id] = ev_context

      # Of the two conflicting name events, C is the one expected to win.
      context_c = contexts["C"]
      context_d = contexts["D"]

      state_before_d: StateMap[str]
      state_before_d = yield defer.ensureDeferred(context_d.get_prev_state_ids())
      self.assertSetEqual({"START", "A", "C"}, set(state_before_d.values()))

      self.assertEqual(context_c.state_group, context_d.state_group_before_event)
      self.assertEqual(context_d.state_group_before_event, context_d.state_group)
  @defer.inlineCallbacks
  def test_branch_have_banned_conflict(
      self,
  ) -> Generator["defer.Deferred[object]", Any, None]:
      """Merge a branch that bans a user (C) with one where that user sets the name (D).

      Graph: START <- A <- B <- {C, D} <- E.
      """
      nodes = {
          "START": DictObj(
              type=EventTypes.Create,
              state_key="",
              content={"creator": "@user_id:example.com"},
              depth=1,
          ),
          "A": DictObj(
              type=EventTypes.Member,
              state_key="@user_id:example.com",
              content={"membership": Membership.JOIN},
              membership=Membership.JOIN,
              depth=2,
          ),
          "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
          "C": DictObj(
              type=EventTypes.Member,
              state_key="@user_id_2:example.com",
              content={"membership": Membership.BAN},
              membership=Membership.BAN,
              depth=4,
          ),
          "D": DictObj(
              type=EventTypes.Name,
              state_key="",
              depth=4,
              sender="@user_id_2:example.com",
          ),
          "E": DictObj(type=EventTypes.Message, depth=5),
      }
      prev_events = {
          "A": ["START"],
          "B": ["A"],
          "C": ["B"],
          "D": ["B"],
          "E": ["C", "D"],
      }
      graph = Graph(nodes=nodes, edges=prev_events)

      self.dummy_store.register_events(graph.walk())

      # Compute a context for each event in depth order, registering its state
      # group so that later events can look it up.
      contexts: Dict[str, EventContext] = {}
      for ev in graph.walk():
          ev_context = yield defer.ensureDeferred(
              self.state.compute_event_context(ev)
          )
          self.dummy_store.register_event_context(ev, ev_context)
          contexts[ev.event_id] = ev_context

      # Bans win over other changes, so C is expected to beat D.
      context_c = contexts["C"]
      context_e = contexts["E"]

      state_before_e: StateMap[str]
      state_before_e = yield defer.ensureDeferred(context_e.get_prev_state_ids())
      self.assertSetEqual({"START", "A", "B", "C"}, set(state_before_e.values()))

      self.assertEqual(context_c.state_group, context_e.state_group_before_event)
      self.assertEqual(context_e.state_group_before_event, context_e.state_group)
  @defer.inlineCallbacks
  def test_branch_have_perms_conflict(
      self,
  ) -> Generator["defer.Deferred[object]", Any, None]:
      """Merge a power-levels change (B) with a name change by the affected user (C).

      Graph: A1 <- A2 <- A3 <- A4 <- A5 <- {B, C} <- D. Only A1 is given an
      explicit depth; the others are derived by `_add_depths`.
      """
      first_user = "@user_id:example.com"
      second_user = "@user_id2:example.com"

      joined = {"membership": Membership.JOIN}

      nodes = {
          "A1": DictObj(
              type=EventTypes.Create,
              state_key="",
              content={"creator": first_user},
              depth=1,
          ),
          "A2": DictObj(
              type=EventTypes.Member,
              state_key=first_user,
              content=dict(joined),
              membership=Membership.JOIN,
          ),
          "A3": DictObj(
              type=EventTypes.Member,
              state_key=second_user,
              content=dict(joined),
              membership=Membership.JOIN,
          ),
          "A4": DictObj(
              type=EventTypes.PowerLevels,
              state_key="",
              content={
                  "events": {"m.room.name": 50},
                  "users": {first_user: 100, second_user: 60},
              },
          ),
          "A5": DictObj(type=EventTypes.Name, state_key=""),
          "B": DictObj(
              type=EventTypes.PowerLevels,
              state_key="",
              content={"events": {"m.room.name": 50}, "users": {second_user: 30}},
          ),
          "C": DictObj(type=EventTypes.Name, state_key="", sender=second_user),
          "D": DictObj(type=EventTypes.Message),
      }
      edges = {
          "A2": ["A1"],
          "A3": ["A2"],
          "A4": ["A3"],
          "A5": ["A4"],
          "B": ["A5"],
          "C": ["A5"],
          "D": ["B", "C"],
      }
      self._add_depths(nodes, edges)
      graph = Graph(nodes, edges)

      self.dummy_store.register_events(graph.walk())

      # Compute a context for each event in depth order, registering its state
      # group so that later events can look it up.
      contexts: Dict[str, EventContext] = {}
      for ev in graph.walk():
          ev_context = yield defer.ensureDeferred(
              self.state.compute_event_context(ev)
          )
          self.dummy_store.register_event_context(ev, ev_context)
          contexts[ev.event_id] = ev_context

      # Power levels win over other changes, so B is expected to beat C.
      context_b = contexts["B"]
      context_d = contexts["D"]

      state_before_d: StateMap[str]
      state_before_d = yield defer.ensureDeferred(context_d.get_prev_state_ids())
      self.assertSetEqual(
          {"A1", "A2", "A3", "A5", "B"}, set(state_before_d.values())
      )

      self.assertEqual(context_b.state_group, context_d.state_group_before_event)
      self.assertEqual(context_d.state_group_before_event, context_d.state_group)
  def _add_depths(
      self, nodes: Dict[str, DictObj], edges: Dict[str, List[str]]
  ) -> None:
      """Fill in a `depth` for every node that does not already have one.

      A node's depth is one more than the greatest depth among its prev events
      as listed in `edges`. A node with no prev events gets depth 1. Explicit
      depths are never overwritten.

      Args:
          nodes: Node descriptions keyed by event ID; updated in place.
          edges: Maps an event ID to the IDs of its prev events. Every
              referenced ID must be a key of `nodes`.
      """

      def _get_depth(ev: str) -> int:
          node = nodes[ev]
          if "depth" not in node:
              # `.get` plus `default=0` let a root node that has neither an
              # explicit depth nor an `edges` entry resolve to depth 1 instead
              # of raising KeyError / ValueError.
              prevs = edges.get(ev, ())
              node["depth"] = max((_get_depth(prev) for prev in prevs), default=0) + 1
          return node["depth"]

      for n in nodes:
          _get_depth(n)
  @defer.inlineCallbacks
  def test_annotate_with_old_message(
      self,
  ) -> Generator["defer.Deferred[object]", Any, None]:
      """Compute the context of a non-state event from explicitly supplied state.

      Both the state before and the state after the event must equal the
      supplied state, and the state group must not change.
      """
      event = create_event(type="test_message", name="event")

      old_state = [
          create_event(type="test1", state_key="1"),
          create_event(type="test1", state_key="2"),
          create_event(type="test2", state_key=""),
      ]
      old_state_event_ids = [e.event_id for e in old_state]
      state_before = {(e.type, e.state_key): e.event_id for e in old_state}

      context: EventContext
      context = yield defer.ensureDeferred(
          self.state.compute_event_context(
              event,
              state_ids_before_event=state_before,
              partial_state=False,
          )
      )

      prev_state_ids: StateMap[str]
      prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
      self.assertCountEqual(old_state_event_ids, prev_state_ids.values())

      current_state_ids: StateMap[str]
      current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
      self.assertCountEqual(old_state_event_ids, current_state_ids.values())

      self.assertIsNotNone(context.state_group_before_event)
      self.assertEqual(context.state_group_before_event, context.state_group)
  @defer.inlineCallbacks
  def test_annotate_with_old_state(
      self,
  ) -> Generator["defer.Deferred[object]", Any, None]:
      """Compute the context of a state event from explicitly supplied state.

      The state after the event must be the supplied state plus the event
      itself, recorded as a one-entry delta between two distinct state groups.
      """
      event = create_event(type="state", state_key="", name="event")

      old_state = [
          create_event(type="test1", state_key="1"),
          create_event(type="test1", state_key="2"),
          create_event(type="test2", state_key=""),
      ]
      old_state_event_ids = [e.event_id for e in old_state]
      state_before = {(e.type, e.state_key): e.event_id for e in old_state}

      context: EventContext
      context = yield defer.ensureDeferred(
          self.state.compute_event_context(
              event,
              state_ids_before_event=state_before,
              partial_state=False,
          )
      )

      prev_state_ids: StateMap[str]
      prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
      self.assertCountEqual(old_state_event_ids, prev_state_ids.values())

      current_state_ids: StateMap[str]
      current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
      self.assertCountEqual(
          [*old_state_event_ids, event.event_id], current_state_ids.values()
      )

      assert context.state_group_before_event is not None
      assert context.state_group is not None
      group_pair = (context.state_group_before_event, context.state_group)
      self.assertEqual(
          context.state_group_deltas.get(group_pair),
          {(event.type, event.state_key): event.event_id},
      )

      self.assertNotEqual(context.state_group_before_event, context.state_group)
  @defer.inlineCallbacks
  def test_trivial_annotate_message(
      self,
  ) -> Generator["defer.Deferred[object]", Any, None]:
      """A non-state event with one prev event reuses that event's state group."""
      parent_id = "prev_event_id"
      event = create_event(
          type="test_message", name="event2", prev_events=[(parent_id, {})]
      )

      old_state = [
          create_event(type="test1", state_key="1"),
          create_event(type="test1", state_key="2"),
          create_event(type="test2", state_key=""),
      ]
      parent_state = {(e.type, e.state_key): e.event_id for e in old_state}

      # Store the parent's state as a brand-new group (no prev group, no delta)
      # and point the parent event at it.
      group_name = yield defer.ensureDeferred(
          self.dummy_store.store_state_group(
              parent_id, event.room_id, None, None, parent_state
          )
      )
      self.dummy_store.register_event_id_state_group(parent_id, group_name)

      context: EventContext
      context = yield defer.ensureDeferred(self.state.compute_event_context(event))

      current_state_ids: StateMap[str]
      current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

      expected_ids = {e.event_id for e in old_state}
      self.assertEqual(expected_ids, set(current_state_ids.values()))

      self.assertEqual(group_name, context.state_group)
  @defer.inlineCallbacks
  def test_trivial_annotate_state(
      self,
  ) -> Generator["defer.Deferred[object]", Any, None]:
      """A state event with one prev event sees that event's state as its prev state."""
      parent_id = "prev_event_id"
      event = create_event(
          type="state", state_key="", name="event2", prev_events=[(parent_id, {})]
      )

      old_state = [
          create_event(type="test1", state_key="1"),
          create_event(type="test1", state_key="2"),
          create_event(type="test2", state_key=""),
      ]
      parent_state = {(e.type, e.state_key): e.event_id for e in old_state}

      # Store the parent's state as a brand-new group (no prev group, no delta)
      # and point the parent event at it.
      group_name = yield defer.ensureDeferred(
          self.dummy_store.store_state_group(
              parent_id, event.room_id, None, None, parent_state
          )
      )
      self.dummy_store.register_event_id_state_group(parent_id, group_name)

      context: EventContext
      context = yield defer.ensureDeferred(self.state.compute_event_context(event))

      prev_state_ids: StateMap[str]
      prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())

      expected_ids = {e.event_id for e in old_state}
      self.assertEqual(expected_ids, set(prev_state_ids.values()))

      self.assertIsNotNone(context.state_group)
  @defer.inlineCallbacks
  def test_resolve_message_conflict(
      self,
  ) -> Generator["defer.Deferred[Any]", Any, None]:
      """A non-state event with two prev events whose state sets differ.

      The two sets share the create event and between them cover six distinct
      state keys, all of which must appear in the resulting state.
      """
      first_parent_id = "event_id1"
      second_parent_id = "event_id2"
      event = create_event(
          type="test_message",
          name="event3",
          prev_events=[(first_parent_id, {}), (second_parent_id, {})],
      )

      creation = create_event(type=EventTypes.Create, state_key="")

      old_state_1 = [
          creation,
          create_event(type="test1", state_key="1"),
          create_event(type="test1", state_key="2"),
          create_event(type="test2", state_key=""),
      ]
      old_state_2 = [
          creation,
          create_event(type="test1", state_key="1"),
          create_event(type="test3", state_key="2"),
          create_event(type="test4", state_key=""),
      ]

      for state_events in (old_state_1, old_state_2):
          self.dummy_store.register_events(state_events)

      context: EventContext
      context = yield self._get_context(
          event, first_parent_id, old_state_1, second_parent_id, old_state_2
      )

      current_state_ids: StateMap[str]
      current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

      self.assertEqual(len(current_state_ids), 6)
      self.assertIsNotNone(context.state_group)
  @defer.inlineCallbacks
  def test_resolve_state_conflict(
      self,
  ) -> Generator["defer.Deferred[Any]", Any, None]:
      """A state event with two prev events whose state sets differ.

      The resulting state must hold six entries and the context must have a
      state group.
      """
      first_parent_id = "event_id1"
      second_parent_id = "event_id2"
      event = create_event(
          type="test4",
          state_key="",
          name="event",
          prev_events=[(first_parent_id, {}), (second_parent_id, {})],
      )

      creation = create_event(type=EventTypes.Create, state_key="")

      old_state_1 = [
          creation,
          create_event(type="test1", state_key="1"),
          create_event(type="test1", state_key="2"),
          create_event(type="test2", state_key=""),
      ]
      old_state_2 = [
          creation,
          create_event(type="test1", state_key="1"),
          create_event(type="test3", state_key="2"),
          create_event(type="test4", state_key=""),
      ]

      # Serve event lookups from a separate store that holds only the old
      # state events.
      event_store = _DummyStore()
      for state_events in (old_state_1, old_state_2):
          event_store.register_events(state_events)
      self.dummy_store.get_events = event_store.get_events  # type: ignore[assignment]

      context: EventContext
      context = yield self._get_context(
          event, first_parent_id, old_state_1, second_parent_id, old_state_2
      )

      current_state_ids: StateMap[str]
      current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

      self.assertEqual(len(current_state_ids), 6)
      self.assertIsNotNone(context.state_group)
  @defer.inlineCallbacks
  def test_standard_depth_conflict(
      self,
  ) -> Generator["defer.Deferred[Any]", Any, None]:
      """Two prev events disagree on `("test1", "1")`; the deeper event must win.

      The check is run twice with the depths swapped between the two sides, to
      show that the depth decides the outcome rather than the side.
      """
      first_parent_id = "event_id1"
      second_parent_id = "event_id2"

      event = create_event(
          type="test4",
          name="event",
          prev_events=[(first_parent_id, {}), (second_parent_id, {})],
      )

      member_event = create_event(
          type=EventTypes.Member,
          state_key="@user_id:example.com",
          content={"membership": Membership.JOIN},
      )

      power_levels = create_event(
          type=EventTypes.PowerLevels,
          state_key="",
          content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}},
      )

      creation = create_event(
          type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
      )

      def _state_with_test1_at(depth: int) -> List[EventBase]:
          # Shared room state plus a fresh `test1` event at the given depth,
          # which is always the last element (index 3).
          return [
              creation,
              power_levels,
              member_event,
              create_event(type="test1", state_key="1", depth=depth),
          ]

      old_state_1 = _state_with_test1_at(1)
      old_state_2 = _state_with_test1_at(2)

      event_store = _DummyStore()
      for state_events in (old_state_1, old_state_2):
          event_store.register_events(state_events)
      self.dummy_store.get_events = event_store.get_events  # type: ignore[assignment]

      context: EventContext
      context = yield self._get_context(
          event, first_parent_id, old_state_1, second_parent_id, old_state_2
      )

      current_state_ids: StateMap[str]
      current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

      self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])

      # Swap the depths to make sure the depths are what state resolution
      # actually looks at.
      old_state_1 = _state_with_test1_at(2)
      old_state_2 = _state_with_test1_at(1)

      for state_events in (old_state_1, old_state_2):
          event_store.register_events(state_events)

      context = yield self._get_context(
          event, first_parent_id, old_state_1, second_parent_id, old_state_2
      )

      current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())

      self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
  @defer.inlineCallbacks
  def _get_context(
      self,
      event: EventBase,
      prev_event_id_1: str,
      old_state_1: Collection[EventBase],
      prev_event_id_2: str,
      old_state_2: Collection[EventBase],
  ) -> Generator["defer.Deferred[object]", Any, EventContext]:
      """Compute the context of `event` given the state at its two prev events.

      For each prev event, its state is stored as a new state group (with no
      prev group or delta) and the prev event ID is pointed at that group.
      `compute_event_context` is then run on `event`.

      Returns:
          The context computed for `event`.
      """
      parents = (
          (prev_event_id_1, old_state_1),
          (prev_event_id_2, old_state_2),
      )

      state_group: int
      for parent_id, parent_state in parents:
          state_group = yield defer.ensureDeferred(
              self.dummy_store.store_state_group(
                  parent_id,
                  event.room_id,
                  None,
                  None,
                  {(e.type, e.state_key): e.event_id for e in parent_state},
              )
          )
          self.dummy_store.register_event_id_state_group(parent_id, state_group)

      context = yield defer.ensureDeferred(self.state.compute_event_context(event))
      return context
  def test_make_state_cache_entry(self) -> None:
      "Test that calculating a prev_group and delta is correct"

      new_state = {
          ("a", ""): "E",
          ("b", ""): "E",
          ("c", ""): "E",
          ("d", ""): "E",
      }

      # `old_state_1` differs from `new_state` in fewer places than
      # `old_state_2` does, but getting from it to `new_state` would mean
      # deleting a key, which a delta cannot express. `old_state_2` should
      # therefore be chosen as the prev_group.

      # Two differences from `new_state`: `a` is changed and `e` is extra.
      old_state_1 = {**new_state, ("a", ""): "F", ("e", ""): "E"}

      # Three differences from `new_state`: `a`, `c` and `d` are changed.
      old_state_2 = {**new_state, ("a", ""): "F", ("c", ""): "F", ("d", ""): "F"}

      candidates = {1: old_state_1, 2: old_state_2}
      entry = _make_state_cache_entry(new_state, candidates)

      self.assertEqual(entry.prev_group, 2)

      # The delta holds exactly the three keys that changed from `old_state_2`.
      expected_delta = {("a", ""): "E", ("c", ""): "E", ("d", ""): "E"}
      self.assertEqual(entry.delta_ids, expected_delta)