test_state.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import logging
  15. from frozendict import frozendict
  16. from synapse.api.constants import EventTypes, Membership
  17. from synapse.api.room_versions import RoomVersions
  18. from synapse.storage.state import StateFilter
  19. from synapse.types import RoomID, UserID
  20. from tests.unittest import HomeserverTestCase
  21. logger = logging.getLogger(__name__)
  22. class StateStoreTestCase(HomeserverTestCase):
  23. def prepare(self, reactor, clock, hs):
  24. self.store = hs.get_datastore()
  25. self.storage = hs.get_storage()
  26. self.state_datastore = self.storage.state.stores.state
  27. self.event_builder_factory = hs.get_event_builder_factory()
  28. self.event_creation_handler = hs.get_event_creation_handler()
  29. self.u_alice = UserID.from_string("@alice:test")
  30. self.u_bob = UserID.from_string("@bob:test")
  31. self.room = RoomID.from_string("!abc123:test")
  32. self.get_success(
  33. self.store.store_room(
  34. self.room.to_string(),
  35. room_creator_user_id="@creator:text",
  36. is_public=True,
  37. room_version=RoomVersions.V1,
  38. )
  39. )
  40. def inject_state_event(self, room, sender, typ, state_key, content):
  41. builder = self.event_builder_factory.for_room_version(
  42. RoomVersions.V1,
  43. {
  44. "type": typ,
  45. "sender": sender.to_string(),
  46. "state_key": state_key,
  47. "room_id": room.to_string(),
  48. "content": content,
  49. },
  50. )
  51. event, context = self.get_success(
  52. self.event_creation_handler.create_new_client_event(builder)
  53. )
  54. self.get_success(self.storage.persistence.persist_event(event, context))
  55. return event
  56. def assertStateMapEqual(self, s1, s2):
  57. for t in s1:
  58. # just compare event IDs for simplicity
  59. self.assertEqual(s1[t].event_id, s2[t].event_id)
  60. self.assertEqual(len(s1), len(s2))
  61. def test_get_state_groups_ids(self):
  62. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  63. e2 = self.inject_state_event(
  64. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  65. )
  66. state_group_map = self.get_success(
  67. self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
  68. )
  69. self.assertEqual(len(state_group_map), 1)
  70. state_map = list(state_group_map.values())[0]
  71. self.assertDictEqual(
  72. state_map,
  73. {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
  74. )
  75. def test_get_state_groups(self):
  76. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  77. e2 = self.inject_state_event(
  78. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  79. )
  80. state_group_map = self.get_success(
  81. self.storage.state.get_state_groups(self.room, [e2.event_id])
  82. )
  83. self.assertEqual(len(state_group_map), 1)
  84. state_list = list(state_group_map.values())[0]
  85. self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
  86. def test_get_state_for_event(self):
  87. # this defaults to a linear DAG as each new injection defaults to whatever
  88. # forward extremities are currently in the DB for this room.
  89. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  90. e2 = self.inject_state_event(
  91. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  92. )
  93. e3 = self.inject_state_event(
  94. self.room,
  95. self.u_alice,
  96. EventTypes.Member,
  97. self.u_alice.to_string(),
  98. {"membership": Membership.JOIN},
  99. )
  100. e4 = self.inject_state_event(
  101. self.room,
  102. self.u_bob,
  103. EventTypes.Member,
  104. self.u_bob.to_string(),
  105. {"membership": Membership.JOIN},
  106. )
  107. e5 = self.inject_state_event(
  108. self.room,
  109. self.u_bob,
  110. EventTypes.Member,
  111. self.u_bob.to_string(),
  112. {"membership": Membership.LEAVE},
  113. )
  114. # check we get the full state as of the final event
  115. state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
  116. self.assertIsNotNone(e4)
  117. self.assertStateMapEqual(
  118. {
  119. (e1.type, e1.state_key): e1,
  120. (e2.type, e2.state_key): e2,
  121. (e3.type, e3.state_key): e3,
  122. # e4 is overwritten by e5
  123. (e5.type, e5.state_key): e5,
  124. },
  125. state,
  126. )
  127. # check we can filter to the m.room.name event (with a '' state key)
  128. state = self.get_success(
  129. self.storage.state.get_state_for_event(
  130. e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
  131. )
  132. )
  133. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  134. # check we can filter to the m.room.name event (with a wildcard None state key)
  135. state = self.get_success(
  136. self.storage.state.get_state_for_event(
  137. e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
  138. )
  139. )
  140. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  141. # check we can grab the m.room.member events (with a wildcard None state key)
  142. state = self.get_success(
  143. self.storage.state.get_state_for_event(
  144. e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
  145. )
  146. )
  147. self.assertStateMapEqual(
  148. {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
  149. )
  150. # check we can grab a specific room member without filtering out the
  151. # other event types
  152. state = self.get_success(
  153. self.storage.state.get_state_for_event(
  154. e5.event_id,
  155. state_filter=StateFilter(
  156. types=frozendict(
  157. {EventTypes.Member: frozenset({self.u_alice.to_string()})}
  158. ),
  159. include_others=True,
  160. ),
  161. )
  162. )
  163. self.assertStateMapEqual(
  164. {
  165. (e1.type, e1.state_key): e1,
  166. (e2.type, e2.state_key): e2,
  167. (e3.type, e3.state_key): e3,
  168. },
  169. state,
  170. )
  171. # check that we can grab everything except members
  172. state = self.get_success(
  173. self.storage.state.get_state_for_event(
  174. e5.event_id,
  175. state_filter=StateFilter(
  176. types=frozendict({EventTypes.Member: frozenset()}),
  177. include_others=True,
  178. ),
  179. )
  180. )
  181. self.assertStateMapEqual(
  182. {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
  183. )
  184. #######################################################
  185. # _get_state_for_group_using_cache tests against a full cache
  186. #######################################################
  187. room_id = self.room.to_string()
  188. group_ids = self.get_success(
  189. self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
  190. )
  191. group = list(group_ids.keys())[0]
  192. # test _get_state_for_group_using_cache correctly filters out members
  193. # with types=[]
  194. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  195. self.state_datastore._state_group_cache,
  196. group,
  197. state_filter=StateFilter(
  198. types=frozendict({EventTypes.Member: frozenset()}), include_others=True
  199. ),
  200. )
  201. self.assertEqual(is_all, True)
  202. self.assertDictEqual(
  203. {
  204. (e1.type, e1.state_key): e1.event_id,
  205. (e2.type, e2.state_key): e2.event_id,
  206. },
  207. state_dict,
  208. )
  209. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  210. self.state_datastore._state_group_members_cache,
  211. group,
  212. state_filter=StateFilter(
  213. types=frozendict({EventTypes.Member: frozenset()}), include_others=True
  214. ),
  215. )
  216. self.assertEqual(is_all, True)
  217. self.assertDictEqual({}, state_dict)
  218. # test _get_state_for_group_using_cache correctly filters in members
  219. # with wildcard types
  220. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  221. self.state_datastore._state_group_cache,
  222. group,
  223. state_filter=StateFilter(
  224. types=frozendict({EventTypes.Member: None}), include_others=True
  225. ),
  226. )
  227. self.assertEqual(is_all, True)
  228. self.assertDictEqual(
  229. {
  230. (e1.type, e1.state_key): e1.event_id,
  231. (e2.type, e2.state_key): e2.event_id,
  232. },
  233. state_dict,
  234. )
  235. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  236. self.state_datastore._state_group_members_cache,
  237. group,
  238. state_filter=StateFilter(
  239. types=frozendict({EventTypes.Member: None}), include_others=True
  240. ),
  241. )
  242. self.assertEqual(is_all, True)
  243. self.assertDictEqual(
  244. {
  245. (e3.type, e3.state_key): e3.event_id,
  246. # e4 is overwritten by e5
  247. (e5.type, e5.state_key): e5.event_id,
  248. },
  249. state_dict,
  250. )
  251. # test _get_state_for_group_using_cache correctly filters in members
  252. # with specific types
  253. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  254. self.state_datastore._state_group_cache,
  255. group,
  256. state_filter=StateFilter(
  257. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  258. include_others=True,
  259. ),
  260. )
  261. self.assertEqual(is_all, True)
  262. self.assertDictEqual(
  263. {
  264. (e1.type, e1.state_key): e1.event_id,
  265. (e2.type, e2.state_key): e2.event_id,
  266. },
  267. state_dict,
  268. )
  269. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  270. self.state_datastore._state_group_members_cache,
  271. group,
  272. state_filter=StateFilter(
  273. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  274. include_others=True,
  275. ),
  276. )
  277. self.assertEqual(is_all, True)
  278. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  279. # test _get_state_for_group_using_cache correctly filters in members
  280. # with specific types
  281. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  282. self.state_datastore._state_group_members_cache,
  283. group,
  284. state_filter=StateFilter(
  285. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  286. include_others=False,
  287. ),
  288. )
  289. self.assertEqual(is_all, True)
  290. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  291. #######################################################
  292. # deliberately remove e2 (room name) from the _state_group_cache
  293. cache_entry = self.state_datastore._state_group_cache.get(group)
  294. state_dict_ids = cache_entry.value
  295. self.assertEqual(cache_entry.full, True)
  296. self.assertEqual(cache_entry.known_absent, set())
  297. self.assertDictEqual(
  298. state_dict_ids,
  299. {
  300. (e1.type, e1.state_key): e1.event_id,
  301. (e2.type, e2.state_key): e2.event_id,
  302. },
  303. )
  304. state_dict_ids.pop((e2.type, e2.state_key))
  305. self.state_datastore._state_group_cache.invalidate(group)
  306. self.state_datastore._state_group_cache.update(
  307. sequence=self.state_datastore._state_group_cache.sequence,
  308. key=group,
  309. value=state_dict_ids,
  310. # list fetched keys so it knows it's partial
  311. fetched_keys=((e1.type, e1.state_key),),
  312. )
  313. cache_entry = self.state_datastore._state_group_cache.get(group)
  314. state_dict_ids = cache_entry.value
  315. self.assertEqual(cache_entry.full, False)
  316. self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
  317. self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
  318. ############################################
  319. # test that things work with a partial cache
  320. # test _get_state_for_group_using_cache correctly filters out members
  321. # with types=[]
  322. room_id = self.room.to_string()
  323. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  324. self.state_datastore._state_group_cache,
  325. group,
  326. state_filter=StateFilter(
  327. types=frozendict({EventTypes.Member: frozenset()}), include_others=True
  328. ),
  329. )
  330. self.assertEqual(is_all, False)
  331. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  332. room_id = self.room.to_string()
  333. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  334. self.state_datastore._state_group_members_cache,
  335. group,
  336. state_filter=StateFilter(
  337. types=frozendict({EventTypes.Member: frozenset()}), include_others=True
  338. ),
  339. )
  340. self.assertEqual(is_all, True)
  341. self.assertDictEqual({}, state_dict)
  342. # test _get_state_for_group_using_cache correctly filters in members
  343. # wildcard types
  344. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  345. self.state_datastore._state_group_cache,
  346. group,
  347. state_filter=StateFilter(
  348. types=frozendict({EventTypes.Member: None}), include_others=True
  349. ),
  350. )
  351. self.assertEqual(is_all, False)
  352. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  353. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  354. self.state_datastore._state_group_members_cache,
  355. group,
  356. state_filter=StateFilter(
  357. types=frozendict({EventTypes.Member: None}), include_others=True
  358. ),
  359. )
  360. self.assertEqual(is_all, True)
  361. self.assertDictEqual(
  362. {
  363. (e3.type, e3.state_key): e3.event_id,
  364. (e5.type, e5.state_key): e5.event_id,
  365. },
  366. state_dict,
  367. )
  368. # test _get_state_for_group_using_cache correctly filters in members
  369. # with specific types
  370. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  371. self.state_datastore._state_group_cache,
  372. group,
  373. state_filter=StateFilter(
  374. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  375. include_others=True,
  376. ),
  377. )
  378. self.assertEqual(is_all, False)
  379. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  380. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  381. self.state_datastore._state_group_members_cache,
  382. group,
  383. state_filter=StateFilter(
  384. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  385. include_others=True,
  386. ),
  387. )
  388. self.assertEqual(is_all, True)
  389. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  390. # test _get_state_for_group_using_cache correctly filters in members
  391. # with specific types
  392. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  393. self.state_datastore._state_group_cache,
  394. group,
  395. state_filter=StateFilter(
  396. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  397. include_others=False,
  398. ),
  399. )
  400. self.assertEqual(is_all, False)
  401. self.assertDictEqual({}, state_dict)
  402. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  403. self.state_datastore._state_group_members_cache,
  404. group,
  405. state_filter=StateFilter(
  406. types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
  407. include_others=False,
  408. ),
  409. )
  410. self.assertEqual(is_all, True)
  411. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)