test_state.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624
  1. # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from frozendict import frozendict
  16. from twisted.test.proto_helpers import MemoryReactor
  17. from synapse.api.constants import EventTypes, Membership
  18. from synapse.api.room_versions import RoomVersions
  19. from synapse.events import EventBase
  20. from synapse.server import HomeServer
  21. from synapse.types import JsonDict, RoomID, StateMap, UserID
  22. from synapse.types.state import StateFilter
  23. from synapse.util import Clock
  24. from tests.unittest import HomeserverTestCase
  25. logger = logging.getLogger(__name__)
  26. class StateStoreTestCase(HomeserverTestCase):
  27. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  28. self.store = hs.get_datastores().main
  29. self.storage = hs.get_storage_controllers()
  30. self.state_datastore = self.storage.state.stores.state
  31. self.event_builder_factory = hs.get_event_builder_factory()
  32. self.event_creation_handler = hs.get_event_creation_handler()
  33. self.u_alice = UserID.from_string("@alice:test")
  34. self.u_bob = UserID.from_string("@bob:test")
  35. self.room = RoomID.from_string("!abc123:test")
  36. self.get_success(
  37. self.store.store_room(
  38. self.room.to_string(),
  39. room_creator_user_id="@creator:text",
  40. is_public=True,
  41. room_version=RoomVersions.V1,
  42. )
  43. )
  44. def inject_state_event(
  45. self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
  46. ) -> EventBase:
  47. builder = self.event_builder_factory.for_room_version(
  48. RoomVersions.V1,
  49. {
  50. "type": typ,
  51. "sender": sender.to_string(),
  52. "state_key": state_key,
  53. "room_id": room.to_string(),
  54. "content": content,
  55. },
  56. )
  57. event, unpersisted_context = self.get_success(
  58. self.event_creation_handler.create_new_client_event(builder)
  59. )
  60. context = self.get_success(unpersisted_context.persist(event))
  61. assert self.storage.persistence is not None
  62. self.get_success(self.storage.persistence.persist_event(event, context))
  63. return event
  64. def assertStateMapEqual(
  65. self, s1: StateMap[EventBase], s2: StateMap[EventBase]
  66. ) -> None:
  67. for t in s1:
  68. # just compare event IDs for simplicity
  69. self.assertEqual(s1[t].event_id, s2[t].event_id)
  70. self.assertEqual(len(s1), len(s2))
  71. def test_get_state_groups_ids(self) -> None:
  72. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  73. e2 = self.inject_state_event(
  74. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  75. )
  76. state_group_map = self.get_success(
  77. self.storage.state.get_state_groups_ids(
  78. self.room.to_string(), [e2.event_id]
  79. )
  80. )
  81. self.assertEqual(len(state_group_map), 1)
  82. state_map = list(state_group_map.values())[0]
  83. self.assertDictEqual(
  84. state_map,
  85. {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
  86. )
  87. def test_get_state_groups(self) -> None:
  88. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  89. e2 = self.inject_state_event(
  90. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  91. )
  92. state_group_map = self.get_success(
  93. self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id])
  94. )
  95. self.assertEqual(len(state_group_map), 1)
  96. state_list = list(state_group_map.values())[0]
  97. self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
  98. def test_get_state_for_event(self) -> None:
  99. # this defaults to a linear DAG as each new injection defaults to whatever
  100. # forward extremities are currently in the DB for this room.
  101. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  102. e2 = self.inject_state_event(
  103. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  104. )
  105. e3 = self.inject_state_event(
  106. self.room,
  107. self.u_alice,
  108. EventTypes.Member,
  109. self.u_alice.to_string(),
  110. {"membership": Membership.JOIN},
  111. )
  112. e4 = self.inject_state_event(
  113. self.room,
  114. self.u_bob,
  115. EventTypes.Member,
  116. self.u_bob.to_string(),
  117. {"membership": Membership.JOIN},
  118. )
  119. e5 = self.inject_state_event(
  120. self.room,
  121. self.u_bob,
  122. EventTypes.Member,
  123. self.u_bob.to_string(),
  124. {"membership": Membership.LEAVE},
  125. )
  126. # check we get the full state as of the final event
  127. state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
  128. self.assertIsNotNone(e4)
  129. self.assertStateMapEqual(
  130. {
  131. (e1.type, e1.state_key): e1,
  132. (e2.type, e2.state_key): e2,
  133. (e3.type, e3.state_key): e3,
  134. # e4 is overwritten by e5
  135. (e5.type, e5.state_key): e5,
  136. },
  137. state,
  138. )
  139. # check we can filter to the m.room.name event (with a '' state key)
  140. state = self.get_success(
  141. self.storage.state.get_state_for_event(
  142. e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
  143. )
  144. )
  145. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  146. # check we can filter to the m.room.name event (with a wildcard None state key)
  147. state = self.get_success(
  148. self.storage.state.get_state_for_event(
  149. e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
  150. )
  151. )
  152. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  153. # check we can grab the m.room.member events (with a wildcard None state key)
  154. state = self.get_success(
  155. self.storage.state.get_state_for_event(
  156. e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
  157. )
  158. )
  159. self.assertStateMapEqual(
  160. {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
  161. )
  162. # check we can grab a specific room member without filtering out the
  163. # other event types
  164. state = self.get_success(
  165. self.storage.state.get_state_for_event(
  166. e5.event_id,
  167. state_filter=StateFilter(
  168. types=frozendict(
  169. {EventTypes.Member: frozenset({self.u_alice.to_string()})}
  170. ),
  171. include_others=True,
  172. ),
  173. )
  174. )
  175. self.assertStateMapEqual(
  176. {
  177. (e1.type, e1.state_key): e1,
  178. (e2.type, e2.state_key): e2,
  179. (e3.type, e3.state_key): e3,
  180. },
  181. state,
  182. )
  183. # check that we can grab everything except members
  184. state = self.get_success(
  185. self.storage.state.get_state_for_event(
  186. e5.event_id,
  187. state_filter=StateFilter(
  188. types=frozendict({EventTypes.Member: frozenset()}),
  189. include_others=True,
  190. ),
  191. )
  192. )
  193. self.assertStateMapEqual(
  194. {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
  195. )
  196. #######################################################
  197. # _get_state_for_group_using_cache tests against a full cache
  198. #######################################################
  199. room_id = self.room.to_string()
  200. group_ids = self.get_success(
  201. self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
  202. )
  203. group = list(group_ids.keys())[0]
  204. # test _get_state_for_group_using_cache correctly filters out members
  205. # with types=[]
  206. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  207. self.state_datastore._state_group_cache,
  208. group,
  209. state_filter=StateFilter(
  210. types=frozendict({EventTypes.Member: frozenset()}), include_others=True
  211. ),
  212. )
  213. self.assertEqual(is_all, True)
  214. self.assertDictEqual(
  215. {
  216. (e1.type, e1.state_key): e1.event_id,
  217. (e2.type, e2.state_key): e2.event_id,
  218. },
  219. state_dict,
  220. )
  221. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  222. self.state_datastore._state_group_members_cache,
  223. group,
  224. state_filter=StateFilter(
  225. types=frozendict({EventTypes.Member: frozenset()}), include_others=True
  226. ),
  227. )
  228. self.assertEqual(is_all, True)
  229. self.assertDictEqual({}, state_dict)
  230. # test _get_state_for_group_using_cache correctly filters in members
  231. # with wildcard types
  232. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  233. self.state_datastore._state_group_cache,
  234. group,
  235. state_filter=StateFilter(
  236. types=frozendict({EventTypes.Member: None}), include_others=True
  237. ),
  238. )
  239. self.assertEqual(is_all, True)
  240. self.assertDictEqual(
  241. {
  242. (e1.type, e1.state_key): e1.event_id,
  243. (e2.type, e2.state_key): e2.event_id,
  244. },
  245. state_dict,
  246. )
  247. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  248. self.state_datastore._state_group_members_cache,
  249. group,
  250. state_filter=StateFilter(
  251. types=frozendict({EventTypes.Member: None}), include_others=True
  252. ),
  253. )
  254. self.assertEqual(is_all, True)
  255. self.assertDictEqual(
  256. {
  257. (e3.type, e3.state_key): e3.event_id,
  258. # e4 is overwritten by e5
  259. (e5.type, e5.state_key): e5.event_id,
  260. },
  261. state_dict,
  262. )
  263. # test _get_state_for_group_using_cache correctly filters in members
  264. # with specific types
  265. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  266. self.state_datastore._state_group_cache,
  267. group,
  268. state_filter=StateFilter(
  269. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  270. include_others=True,
  271. ),
  272. )
  273. self.assertEqual(is_all, True)
  274. self.assertDictEqual(
  275. {
  276. (e1.type, e1.state_key): e1.event_id,
  277. (e2.type, e2.state_key): e2.event_id,
  278. },
  279. state_dict,
  280. )
  281. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  282. self.state_datastore._state_group_members_cache,
  283. group,
  284. state_filter=StateFilter(
  285. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  286. include_others=True,
  287. ),
  288. )
  289. self.assertEqual(is_all, True)
  290. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  291. # test _get_state_for_group_using_cache correctly filters in members
  292. # with specific types
  293. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  294. self.state_datastore._state_group_members_cache,
  295. group,
  296. state_filter=StateFilter(
  297. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  298. include_others=False,
  299. ),
  300. )
  301. self.assertEqual(is_all, True)
  302. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  303. #######################################################
  304. # deliberately remove e2 (room name) from the _state_group_cache
  305. cache_entry = self.state_datastore._state_group_cache.get(group)
  306. state_dict_ids = cache_entry.value
  307. self.assertEqual(cache_entry.full, True)
  308. self.assertEqual(cache_entry.known_absent, set())
  309. self.assertDictEqual(
  310. state_dict_ids,
  311. {
  312. (e1.type, e1.state_key): e1.event_id,
  313. (e2.type, e2.state_key): e2.event_id,
  314. },
  315. )
  316. state_dict_ids.pop((e2.type, e2.state_key))
  317. self.state_datastore._state_group_cache.invalidate(group)
  318. self.state_datastore._state_group_cache.update(
  319. sequence=self.state_datastore._state_group_cache.sequence,
  320. key=group,
  321. value=state_dict_ids,
  322. # list fetched keys so it knows it's partial
  323. fetched_keys=((e1.type, e1.state_key),),
  324. )
  325. cache_entry = self.state_datastore._state_group_cache.get(group)
  326. state_dict_ids = cache_entry.value
  327. self.assertEqual(cache_entry.full, False)
  328. self.assertEqual(cache_entry.known_absent, set())
  329. self.assertDictEqual(state_dict_ids, {})
  330. ############################################
  331. # test that things work with a partial cache
  332. # test _get_state_for_group_using_cache correctly filters out members
  333. # with types=[]
  334. room_id = self.room.to_string()
  335. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  336. self.state_datastore._state_group_cache,
  337. group,
  338. state_filter=StateFilter(
  339. types=frozendict({EventTypes.Member: frozenset()}), include_others=True
  340. ),
  341. )
  342. self.assertEqual(is_all, False)
  343. self.assertDictEqual({}, state_dict)
  344. room_id = self.room.to_string()
  345. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  346. self.state_datastore._state_group_members_cache,
  347. group,
  348. state_filter=StateFilter(
  349. types=frozendict({EventTypes.Member: frozenset()}), include_others=True
  350. ),
  351. )
  352. self.assertEqual(is_all, True)
  353. self.assertDictEqual({}, state_dict)
  354. # test _get_state_for_group_using_cache correctly filters in members
  355. # wildcard types
  356. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  357. self.state_datastore._state_group_cache,
  358. group,
  359. state_filter=StateFilter(
  360. types=frozendict({EventTypes.Member: None}), include_others=True
  361. ),
  362. )
  363. self.assertEqual(is_all, False)
  364. self.assertDictEqual({}, state_dict)
  365. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  366. self.state_datastore._state_group_members_cache,
  367. group,
  368. state_filter=StateFilter(
  369. types=frozendict({EventTypes.Member: None}), include_others=True
  370. ),
  371. )
  372. self.assertEqual(is_all, True)
  373. self.assertDictEqual(
  374. {
  375. (e3.type, e3.state_key): e3.event_id,
  376. (e5.type, e5.state_key): e5.event_id,
  377. },
  378. state_dict,
  379. )
  380. # test _get_state_for_group_using_cache correctly filters in members
  381. # with specific types
  382. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  383. self.state_datastore._state_group_cache,
  384. group,
  385. state_filter=StateFilter(
  386. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  387. include_others=True,
  388. ),
  389. )
  390. self.assertEqual(is_all, False)
  391. self.assertDictEqual({}, state_dict)
  392. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  393. self.state_datastore._state_group_members_cache,
  394. group,
  395. state_filter=StateFilter(
  396. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  397. include_others=True,
  398. ),
  399. )
  400. self.assertEqual(is_all, True)
  401. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  402. # test _get_state_for_group_using_cache correctly filters in members
  403. # with specific types
  404. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  405. self.state_datastore._state_group_cache,
  406. group,
  407. state_filter=StateFilter(
  408. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  409. include_others=False,
  410. ),
  411. )
  412. self.assertEqual(is_all, False)
  413. self.assertDictEqual({}, state_dict)
  414. state_dict, is_all = self.state_datastore._get_state_for_group_using_cache(
  415. self.state_datastore._state_group_members_cache,
  416. group,
  417. state_filter=StateFilter(
  418. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  419. include_others=False,
  420. ),
  421. )
  422. self.assertEqual(is_all, True)
  423. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  424. def test_batched_state_group_storing(self) -> None:
  425. creation_event = self.inject_state_event(
  426. self.room, self.u_alice, EventTypes.Create, "", {}
  427. )
  428. state_to_event = self.get_success(
  429. self.storage.state.get_state_groups(
  430. self.room.to_string(), [creation_event.event_id]
  431. )
  432. )
  433. current_state_group = list(state_to_event.keys())[0]
  434. # create some unpersisted events and event contexts to store against room
  435. events_and_context = []
  436. builder = self.event_builder_factory.for_room_version(
  437. RoomVersions.V1,
  438. {
  439. "type": EventTypes.Name,
  440. "sender": self.u_alice.to_string(),
  441. "state_key": "",
  442. "room_id": self.room.to_string(),
  443. "content": {"name": "first rename of room"},
  444. },
  445. )
  446. event1, unpersisted_context1 = self.get_success(
  447. self.event_creation_handler.create_new_client_event(builder)
  448. )
  449. events_and_context.append((event1, unpersisted_context1))
  450. builder2 = self.event_builder_factory.for_room_version(
  451. RoomVersions.V1,
  452. {
  453. "type": EventTypes.JoinRules,
  454. "sender": self.u_alice.to_string(),
  455. "state_key": "",
  456. "room_id": self.room.to_string(),
  457. "content": {"join_rule": "private"},
  458. },
  459. )
  460. event2, unpersisted_context2 = self.get_success(
  461. self.event_creation_handler.create_new_client_event(builder2)
  462. )
  463. events_and_context.append((event2, unpersisted_context2))
  464. builder3 = self.event_builder_factory.for_room_version(
  465. RoomVersions.V1,
  466. {
  467. "type": EventTypes.Message,
  468. "sender": self.u_alice.to_string(),
  469. "room_id": self.room.to_string(),
  470. "content": {"body": "hello from event 3", "msgtype": "m.text"},
  471. },
  472. )
  473. event3, unpersisted_context3 = self.get_success(
  474. self.event_creation_handler.create_new_client_event(builder3)
  475. )
  476. events_and_context.append((event3, unpersisted_context3))
  477. builder4 = self.event_builder_factory.for_room_version(
  478. RoomVersions.V1,
  479. {
  480. "type": EventTypes.JoinRules,
  481. "sender": self.u_alice.to_string(),
  482. "state_key": "",
  483. "room_id": self.room.to_string(),
  484. "content": {"join_rule": "public"},
  485. },
  486. )
  487. event4, unpersisted_context4 = self.get_success(
  488. self.event_creation_handler.create_new_client_event(builder4)
  489. )
  490. events_and_context.append((event4, unpersisted_context4))
  491. processed_events_and_context = self.get_success(
  492. self.hs.get_datastores().state.store_state_deltas_for_batched(
  493. events_and_context, self.room.to_string(), current_state_group
  494. )
  495. )
  496. # check that only state events are in state_groups, and all state events are in state_groups
  497. res = self.get_success(
  498. self.store.db_pool.simple_select_list(
  499. table="state_groups",
  500. keyvalues=None,
  501. retcols=("event_id",),
  502. )
  503. )
  504. events = []
  505. for result in res:
  506. self.assertNotIn(event3.event_id, result)
  507. events.append(result.get("event_id"))
  508. for event, _ in processed_events_and_context:
  509. if event.is_state():
  510. self.assertIn(event.event_id, events)
  511. # check that each unique state has state group in state_groups_state and that the
  512. # type/state key is correct, and check that each state event's state group
  513. # has an entry and prev event in state_group_edges
  514. for event, context in processed_events_and_context:
  515. if event.is_state():
  516. state = self.get_success(
  517. self.store.db_pool.simple_select_list(
  518. table="state_groups_state",
  519. keyvalues={"state_group": context.state_group_after_event},
  520. retcols=("type", "state_key"),
  521. )
  522. )
  523. self.assertEqual(event.type, state[0].get("type"))
  524. self.assertEqual(event.state_key, state[0].get("state_key"))
  525. groups = self.get_success(
  526. self.store.db_pool.simple_select_list(
  527. table="state_group_edges",
  528. keyvalues={"state_group": str(context.state_group_after_event)},
  529. retcols=("*",),
  530. )
  531. )
  532. self.assertEqual(
  533. context.state_group_before_event, groups[0].get("prev_state_group")
  534. )