test_state.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from frozendict import frozendict
  16. from twisted.test.proto_helpers import MemoryReactor
  17. from synapse.api.constants import EventTypes, Membership
  18. from synapse.api.room_versions import RoomVersions
  19. from synapse.events import EventBase
  20. from synapse.server import HomeServer
  21. from synapse.types import JsonDict, RoomID, StateMap, UserID
  22. from synapse.types.state import StateFilter
  23. from synapse.util import Clock
  24. from tests.unittest import HomeserverTestCase
  25. logger = logging.getLogger(__name__)
  26. class StateStoreTestCase(HomeserverTestCase):
  27. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  28. self.store = hs.get_datastores().main
  29. self.storage = hs.get_storage_controllers()
  30. self.state_datastore = self.storage.state.stores.state
  31. self.event_builder_factory = hs.get_event_builder_factory()
  32. self.event_creation_handler = hs.get_event_creation_handler()
  33. self.u_alice = UserID.from_string("@alice:test")
  34. self.u_bob = UserID.from_string("@bob:test")
  35. self.room = RoomID.from_string("!abc123:test")
  36. self.get_success(
  37. self.store.store_room(
  38. self.room.to_string(),
  39. room_creator_user_id="@creator:text",
  40. is_public=True,
  41. room_version=RoomVersions.V1,
  42. )
  43. )
  44. def inject_state_event(
  45. self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
  46. ) -> EventBase:
  47. builder = self.event_builder_factory.for_room_version(
  48. RoomVersions.V1,
  49. {
  50. "type": typ,
  51. "sender": sender.to_string(),
  52. "state_key": state_key,
  53. "room_id": room.to_string(),
  54. "content": content,
  55. },
  56. )
  57. event, context = self.get_success(
  58. self.event_creation_handler.create_new_client_event(builder)
  59. )
  60. assert self.storage.persistence is not None
  61. self.get_success(self.storage.persistence.persist_event(event, context))
  62. return event
  63. def assertStateMapEqual(
  64. self, s1: StateMap[EventBase], s2: StateMap[EventBase]
  65. ) -> None:
  66. for t in s1:
  67. # just compare event IDs for simplicity
  68. self.assertEqual(s1[t].event_id, s2[t].event_id)
  69. self.assertEqual(len(s1), len(s2))
  70. def test_get_state_groups_ids(self) -> None:
  71. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  72. e2 = self.inject_state_event(
  73. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  74. )
  75. state_group_map = self.get_success(
  76. self.storage.state.get_state_groups_ids(
  77. self.room.to_string(), [e2.event_id]
  78. )
  79. )
  80. self.assertEqual(len(state_group_map), 1)
  81. state_map = list(state_group_map.values())[0]
  82. self.assertDictEqual(
  83. state_map,
  84. {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
  85. )
  86. def test_get_state_groups(self) -> None:
  87. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  88. e2 = self.inject_state_event(
  89. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  90. )
  91. state_group_map = self.get_success(
  92. self.storage.state.get_state_groups(self.room.to_string(), [e2.event_id])
  93. )
  94. self.assertEqual(len(state_group_map), 1)
  95. state_list = list(state_group_map.values())[0]
  96. self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
  97. def test_get_state_for_event(self) -> None:
  98. # this defaults to a linear DAG as each new injection defaults to whatever
  99. # forward extremities are currently in the DB for this room.
  100. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  101. e2 = self.inject_state_event(
  102. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  103. )
  104. e3 = self.inject_state_event(
  105. self.room,
  106. self.u_alice,
  107. EventTypes.Member,
  108. self.u_alice.to_string(),
  109. {"membership": Membership.JOIN},
  110. )
  111. e4 = self.inject_state_event(
  112. self.room,
  113. self.u_bob,
  114. EventTypes.Member,
  115. self.u_bob.to_string(),
  116. {"membership": Membership.JOIN},
  117. )
  118. e5 = self.inject_state_event(
  119. self.room,
  120. self.u_bob,
  121. EventTypes.Member,
  122. self.u_bob.to_string(),
  123. {"membership": Membership.LEAVE},
  124. )
  125. # check we get the full state as of the final event
  126. state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
  127. self.assertIsNotNone(e4)
  128. self.assertStateMapEqual(
  129. {
  130. (e1.type, e1.state_key): e1,
  131. (e2.type, e2.state_key): e2,
  132. (e3.type, e3.state_key): e3,
  133. # e4 is overwritten by e5
  134. (e5.type, e5.state_key): e5,
  135. },
  136. state,
  137. )
  138. # check we can filter to the m.room.name event (with a '' state key)
  139. state = self.get_success(
  140. self.storage.state.get_state_for_event(
  141. e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
  142. )
  143. )
  144. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  145. # check we can filter to the m.room.name event (with a wildcard None state key)
  146. state = self.get_success(
  147. self.storage.state.get_state_for_event(
  148. e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
  149. )
  150. )
  151. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  152. # check we can grab the m.room.member events (with a wildcard None state key)
  153. state = self.get_success(
  154. self.storage.state.get_state_for_event(
  155. e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
  156. )
  157. )
  158. self.assertStateMapEqual(
  159. {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
  160. )
  161. # check we can grab a specific room member without filtering out the
  162. # other event types
  163. state = self.get_success(
  164. self.storage.state.get_state_for_event(
  165. e5.event_id,
  166. state_filter=StateFilter(
  167. types=frozendict(
  168. {EventTypes.Member: frozenset({self.u_alice.to_string()})}
  169. ),
  170. include_others=True,
  171. ),
  172. )
  173. )
  174. self.assertStateMapEqual(
  175. {
  176. (e1.type, e1.state_key): e1,
  177. (e2.type, e2.state_key): e2,
  178. (e3.type, e3.state_key): e3,
  179. },
  180. state,
  181. )
  182. # check that we can grab everything except members
  183. state = self.get_success(
  184. self.storage.state.get_state_for_event(
  185. e5.event_id,
  186. state_filter=StateFilter(
  187. types=frozendict({EventTypes.Member: frozenset()}),
  188. include_others=True,
  189. ),
  190. )
  191. )
  192. self.assertStateMapEqual(
  193. {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
  194. )
  195. #######################################################
  196. # _get_state_for_group_using_cache tests against a full cache
  197. #######################################################
  198. room_id = self.room.to_string()
  199. group_ids = self.get_success(
  200. self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
  201. )
  202. group = list(group_ids.keys())[0]
  203. # test _get_state_for_group_using_cache correctly filters out members
  204. # with types=[]
  205. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  206. self.state_datastore._state_group_cache,
  207. group,
  208. state_filter=StateFilter(
  209. types=frozendict({EventTypes.Member: frozenset()}), include_others=True
  210. ),
  211. )
  212. self.assertEqual(is_all, True)
  213. self.assertDictEqual(
  214. {
  215. (e1.type, e1.state_key): e1.event_id,
  216. (e2.type, e2.state_key): e2.event_id,
  217. },
  218. state_dict,
  219. )
  220. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  221. self.state_datastore._state_group_members_cache,
  222. group,
  223. state_filter=StateFilter(
  224. types=frozendict({EventTypes.Member: frozenset()}), include_others=True
  225. ),
  226. )
  227. self.assertEqual(is_all, True)
  228. self.assertDictEqual({}, state_dict)
  229. # test _get_state_for_group_using_cache correctly filters in members
  230. # with wildcard types
  231. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  232. self.state_datastore._state_group_cache,
  233. group,
  234. state_filter=StateFilter(
  235. types=frozendict({EventTypes.Member: None}), include_others=True
  236. ),
  237. )
  238. self.assertEqual(is_all, True)
  239. self.assertDictEqual(
  240. {
  241. (e1.type, e1.state_key): e1.event_id,
  242. (e2.type, e2.state_key): e2.event_id,
  243. },
  244. state_dict,
  245. )
  246. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  247. self.state_datastore._state_group_members_cache,
  248. group,
  249. state_filter=StateFilter(
  250. types=frozendict({EventTypes.Member: None}), include_others=True
  251. ),
  252. )
  253. self.assertEqual(is_all, True)
  254. self.assertDictEqual(
  255. {
  256. (e3.type, e3.state_key): e3.event_id,
  257. # e4 is overwritten by e5
  258. (e5.type, e5.state_key): e5.event_id,
  259. },
  260. state_dict,
  261. )
  262. # test _get_state_for_group_using_cache correctly filters in members
  263. # with specific types
  264. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  265. self.state_datastore._state_group_cache,
  266. group,
  267. state_filter=StateFilter(
  268. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  269. include_others=True,
  270. ),
  271. )
  272. self.assertEqual(is_all, True)
  273. self.assertDictEqual(
  274. {
  275. (e1.type, e1.state_key): e1.event_id,
  276. (e2.type, e2.state_key): e2.event_id,
  277. },
  278. state_dict,
  279. )
  280. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  281. self.state_datastore._state_group_members_cache,
  282. group,
  283. state_filter=StateFilter(
  284. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  285. include_others=True,
  286. ),
  287. )
  288. self.assertEqual(is_all, True)
  289. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  290. # test _get_state_for_group_using_cache correctly filters in members
  291. # with specific types
  292. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  293. self.state_datastore._state_group_members_cache,
  294. group,
  295. state_filter=StateFilter(
  296. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  297. include_others=False,
  298. ),
  299. )
  300. self.assertEqual(is_all, True)
  301. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  302. #######################################################
  303. # deliberately remove e2 (room name) from the _state_group_cache
  304. cache_entry = self.state_datastore._state_group_cache.get(group)
  305. state_dict_ids = cache_entry.value
  306. self.assertEqual(cache_entry.full, True)
  307. self.assertEqual(cache_entry.known_absent, set())
  308. self.assertDictEqual(
  309. state_dict_ids,
  310. {
  311. (e1.type, e1.state_key): e1.event_id,
  312. (e2.type, e2.state_key): e2.event_id,
  313. },
  314. )
  315. state_dict_ids.pop((e2.type, e2.state_key))
  316. self.state_datastore._state_group_cache.invalidate(group)
  317. self.state_datastore._state_group_cache.update(
  318. sequence=self.state_datastore._state_group_cache.sequence,
  319. key=group,
  320. value=state_dict_ids,
  321. # list fetched keys so it knows it's partial
  322. fetched_keys=((e1.type, e1.state_key),),
  323. )
  324. cache_entry = self.state_datastore._state_group_cache.get(group)
  325. state_dict_ids = cache_entry.value
  326. self.assertEqual(cache_entry.full, False)
  327. self.assertEqual(cache_entry.known_absent, set())
  328. self.assertDictEqual(state_dict_ids, {})
  329. ############################################
  330. # test that things work with a partial cache
  331. # test _get_state_for_group_using_cache correctly filters out members
  332. # with types=[]
  333. room_id = self.room.to_string()
  334. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  335. self.state_datastore._state_group_cache,
  336. group,
  337. state_filter=StateFilter(
  338. types=frozendict({EventTypes.Member: frozenset()}), include_others=True
  339. ),
  340. )
  341. self.assertEqual(is_all, False)
  342. self.assertDictEqual({}, state_dict)
  343. room_id = self.room.to_string()
  344. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  345. self.state_datastore._state_group_members_cache,
  346. group,
  347. state_filter=StateFilter(
  348. types=frozendict({EventTypes.Member: frozenset()}), include_others=True
  349. ),
  350. )
  351. self.assertEqual(is_all, True)
  352. self.assertDictEqual({}, state_dict)
  353. # test _get_state_for_group_using_cache correctly filters in members
  354. # wildcard types
  355. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  356. self.state_datastore._state_group_cache,
  357. group,
  358. state_filter=StateFilter(
  359. types=frozendict({EventTypes.Member: None}), include_others=True
  360. ),
  361. )
  362. self.assertEqual(is_all, False)
  363. self.assertDictEqual({}, state_dict)
  364. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  365. self.state_datastore._state_group_members_cache,
  366. group,
  367. state_filter=StateFilter(
  368. types=frozendict({EventTypes.Member: None}), include_others=True
  369. ),
  370. )
  371. self.assertEqual(is_all, True)
  372. self.assertDictEqual(
  373. {
  374. (e3.type, e3.state_key): e3.event_id,
  375. (e5.type, e5.state_key): e5.event_id,
  376. },
  377. state_dict,
  378. )
  379. # test _get_state_for_group_using_cache correctly filters in members
  380. # with specific types
  381. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  382. self.state_datastore._state_group_cache,
  383. group,
  384. state_filter=StateFilter(
  385. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  386. include_others=True,
  387. ),
  388. )
  389. self.assertEqual(is_all, False)
  390. self.assertDictEqual({}, state_dict)
  391. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  392. self.state_datastore._state_group_members_cache,
  393. group,
  394. state_filter=StateFilter(
  395. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  396. include_others=True,
  397. ),
  398. )
  399. self.assertEqual(is_all, True)
  400. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  401. # test _get_state_for_group_using_cache correctly filters in members
  402. # with specific types
  403. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  404. self.state_datastore._state_group_cache,
  405. group,
  406. state_filter=StateFilter(
  407. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  408. include_others=False,
  409. ),
  410. )
  411. self.assertEqual(is_all, False)
  412. self.assertDictEqual({}, state_dict)
  413. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  414. self.state_datastore._state_group_members_cache,
  415. group,
  416. state_filter=StateFilter(
  417. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  418. include_others=False,
  419. ),
  420. )
  421. self.assertEqual(is_all, True)
  422. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)