test_state.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. # Copyright 2018-2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from synapse.api.constants import EventTypes, Membership
  16. from synapse.api.room_versions import RoomVersions
  17. from synapse.storage.state import StateFilter
  18. from synapse.types import RoomID, UserID
  19. from tests.unittest import HomeserverTestCase
# Module-level logger, named after this module via logging.getLogger(__name__).
logger = logging.getLogger(__name__)
  21. class StateStoreTestCase(HomeserverTestCase):
  22. def prepare(self, reactor, clock, hs):
  23. self.store = hs.get_datastore()
  24. self.storage = hs.get_storage()
  25. self.state_datastore = self.storage.state.stores.state
  26. self.event_builder_factory = hs.get_event_builder_factory()
  27. self.event_creation_handler = hs.get_event_creation_handler()
  28. self.u_alice = UserID.from_string("@alice:test")
  29. self.u_bob = UserID.from_string("@bob:test")
  30. self.room = RoomID.from_string("!abc123:test")
  31. self.get_success(
  32. self.store.store_room(
  33. self.room.to_string(),
  34. room_creator_user_id="@creator:text",
  35. is_public=True,
  36. room_version=RoomVersions.V1,
  37. )
  38. )
  39. def inject_state_event(self, room, sender, typ, state_key, content):
  40. builder = self.event_builder_factory.for_room_version(
  41. RoomVersions.V1,
  42. {
  43. "type": typ,
  44. "sender": sender.to_string(),
  45. "state_key": state_key,
  46. "room_id": room.to_string(),
  47. "content": content,
  48. },
  49. )
  50. event, context = self.get_success(
  51. self.event_creation_handler.create_new_client_event(builder)
  52. )
  53. self.get_success(self.storage.persistence.persist_event(event, context))
  54. return event
  55. def assertStateMapEqual(self, s1, s2):
  56. for t in s1:
  57. # just compare event IDs for simplicity
  58. self.assertEqual(s1[t].event_id, s2[t].event_id)
  59. self.assertEqual(len(s1), len(s2))
  60. def test_get_state_groups_ids(self):
  61. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  62. e2 = self.inject_state_event(
  63. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  64. )
  65. state_group_map = self.get_success(
  66. self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
  67. )
  68. self.assertEqual(len(state_group_map), 1)
  69. state_map = list(state_group_map.values())[0]
  70. self.assertDictEqual(
  71. state_map,
  72. {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
  73. )
  74. def test_get_state_groups(self):
  75. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  76. e2 = self.inject_state_event(
  77. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  78. )
  79. state_group_map = self.get_success(
  80. self.storage.state.get_state_groups(self.room, [e2.event_id])
  81. )
  82. self.assertEqual(len(state_group_map), 1)
  83. state_list = list(state_group_map.values())[0]
  84. self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
  85. def test_get_state_for_event(self):
  86. # this defaults to a linear DAG as each new injection defaults to whatever
  87. # forward extremities are currently in the DB for this room.
  88. e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
  89. e2 = self.inject_state_event(
  90. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  91. )
  92. e3 = self.inject_state_event(
  93. self.room,
  94. self.u_alice,
  95. EventTypes.Member,
  96. self.u_alice.to_string(),
  97. {"membership": Membership.JOIN},
  98. )
  99. e4 = self.inject_state_event(
  100. self.room,
  101. self.u_bob,
  102. EventTypes.Member,
  103. self.u_bob.to_string(),
  104. {"membership": Membership.JOIN},
  105. )
  106. e5 = self.inject_state_event(
  107. self.room,
  108. self.u_bob,
  109. EventTypes.Member,
  110. self.u_bob.to_string(),
  111. {"membership": Membership.LEAVE},
  112. )
  113. # check we get the full state as of the final event
  114. state = self.get_success(self.storage.state.get_state_for_event(e5.event_id))
  115. self.assertIsNotNone(e4)
  116. self.assertStateMapEqual(
  117. {
  118. (e1.type, e1.state_key): e1,
  119. (e2.type, e2.state_key): e2,
  120. (e3.type, e3.state_key): e3,
  121. # e4 is overwritten by e5
  122. (e5.type, e5.state_key): e5,
  123. },
  124. state,
  125. )
  126. # check we can filter to the m.room.name event (with a '' state key)
  127. state = self.get_success(
  128. self.storage.state.get_state_for_event(
  129. e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
  130. )
  131. )
  132. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  133. # check we can filter to the m.room.name event (with a wildcard None state key)
  134. state = self.get_success(
  135. self.storage.state.get_state_for_event(
  136. e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
  137. )
  138. )
  139. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  140. # check we can grab the m.room.member events (with a wildcard None state key)
  141. state = self.get_success(
  142. self.storage.state.get_state_for_event(
  143. e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
  144. )
  145. )
  146. self.assertStateMapEqual(
  147. {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
  148. )
  149. # check we can grab a specific room member without filtering out the
  150. # other event types
  151. state = self.get_success(
  152. self.storage.state.get_state_for_event(
  153. e5.event_id,
  154. state_filter=StateFilter(
  155. types={EventTypes.Member: {self.u_alice.to_string()}},
  156. include_others=True,
  157. ),
  158. )
  159. )
  160. self.assertStateMapEqual(
  161. {
  162. (e1.type, e1.state_key): e1,
  163. (e2.type, e2.state_key): e2,
  164. (e3.type, e3.state_key): e3,
  165. },
  166. state,
  167. )
  168. # check that we can grab everything except members
  169. state = self.get_success(
  170. self.storage.state.get_state_for_event(
  171. e5.event_id,
  172. state_filter=StateFilter(
  173. types={EventTypes.Member: set()}, include_others=True
  174. ),
  175. )
  176. )
  177. self.assertStateMapEqual(
  178. {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
  179. )
  180. #######################################################
  181. # _get_state_for_group_using_cache tests against a full cache
  182. #######################################################
  183. room_id = self.room.to_string()
  184. group_ids = self.get_success(
  185. self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
  186. )
  187. group = list(group_ids.keys())[0]
  188. # test _get_state_for_group_using_cache correctly filters out members
  189. # with types=[]
  190. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  191. self.state_datastore._state_group_cache,
  192. group,
  193. state_filter=StateFilter(
  194. types={EventTypes.Member: set()}, include_others=True
  195. ),
  196. )
  197. self.assertEqual(is_all, True)
  198. self.assertDictEqual(
  199. {
  200. (e1.type, e1.state_key): e1.event_id,
  201. (e2.type, e2.state_key): e2.event_id,
  202. },
  203. state_dict,
  204. )
  205. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  206. self.state_datastore._state_group_members_cache,
  207. group,
  208. state_filter=StateFilter(
  209. types={EventTypes.Member: set()}, include_others=True
  210. ),
  211. )
  212. self.assertEqual(is_all, True)
  213. self.assertDictEqual({}, state_dict)
  214. # test _get_state_for_group_using_cache correctly filters in members
  215. # with wildcard types
  216. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  217. self.state_datastore._state_group_cache,
  218. group,
  219. state_filter=StateFilter(
  220. types={EventTypes.Member: None}, include_others=True
  221. ),
  222. )
  223. self.assertEqual(is_all, True)
  224. self.assertDictEqual(
  225. {
  226. (e1.type, e1.state_key): e1.event_id,
  227. (e2.type, e2.state_key): e2.event_id,
  228. },
  229. state_dict,
  230. )
  231. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  232. self.state_datastore._state_group_members_cache,
  233. group,
  234. state_filter=StateFilter(
  235. types={EventTypes.Member: None}, include_others=True
  236. ),
  237. )
  238. self.assertEqual(is_all, True)
  239. self.assertDictEqual(
  240. {
  241. (e3.type, e3.state_key): e3.event_id,
  242. # e4 is overwritten by e5
  243. (e5.type, e5.state_key): e5.event_id,
  244. },
  245. state_dict,
  246. )
  247. # test _get_state_for_group_using_cache correctly filters in members
  248. # with specific types
  249. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  250. self.state_datastore._state_group_cache,
  251. group,
  252. state_filter=StateFilter(
  253. types={EventTypes.Member: {e5.state_key}}, include_others=True
  254. ),
  255. )
  256. self.assertEqual(is_all, True)
  257. self.assertDictEqual(
  258. {
  259. (e1.type, e1.state_key): e1.event_id,
  260. (e2.type, e2.state_key): e2.event_id,
  261. },
  262. state_dict,
  263. )
  264. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  265. self.state_datastore._state_group_members_cache,
  266. group,
  267. state_filter=StateFilter(
  268. types={EventTypes.Member: {e5.state_key}}, include_others=True
  269. ),
  270. )
  271. self.assertEqual(is_all, True)
  272. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  273. # test _get_state_for_group_using_cache correctly filters in members
  274. # with specific types
  275. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  276. self.state_datastore._state_group_members_cache,
  277. group,
  278. state_filter=StateFilter(
  279. types={EventTypes.Member: {e5.state_key}}, include_others=False
  280. ),
  281. )
  282. self.assertEqual(is_all, True)
  283. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  284. #######################################################
  285. # deliberately remove e2 (room name) from the _state_group_cache
  286. cache_entry = self.state_datastore._state_group_cache.get(group)
  287. state_dict_ids = cache_entry.value
  288. self.assertEqual(cache_entry.full, True)
  289. self.assertEqual(cache_entry.known_absent, set())
  290. self.assertDictEqual(
  291. state_dict_ids,
  292. {
  293. (e1.type, e1.state_key): e1.event_id,
  294. (e2.type, e2.state_key): e2.event_id,
  295. },
  296. )
  297. state_dict_ids.pop((e2.type, e2.state_key))
  298. self.state_datastore._state_group_cache.invalidate(group)
  299. self.state_datastore._state_group_cache.update(
  300. sequence=self.state_datastore._state_group_cache.sequence,
  301. key=group,
  302. value=state_dict_ids,
  303. # list fetched keys so it knows it's partial
  304. fetched_keys=((e1.type, e1.state_key),),
  305. )
  306. cache_entry = self.state_datastore._state_group_cache.get(group)
  307. state_dict_ids = cache_entry.value
  308. self.assertEqual(cache_entry.full, False)
  309. self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
  310. self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
  311. ############################################
  312. # test that things work with a partial cache
  313. # test _get_state_for_group_using_cache correctly filters out members
  314. # with types=[]
  315. room_id = self.room.to_string()
  316. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  317. self.state_datastore._state_group_cache,
  318. group,
  319. state_filter=StateFilter(
  320. types={EventTypes.Member: set()}, include_others=True
  321. ),
  322. )
  323. self.assertEqual(is_all, False)
  324. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  325. room_id = self.room.to_string()
  326. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  327. self.state_datastore._state_group_members_cache,
  328. group,
  329. state_filter=StateFilter(
  330. types={EventTypes.Member: set()}, include_others=True
  331. ),
  332. )
  333. self.assertEqual(is_all, True)
  334. self.assertDictEqual({}, state_dict)
  335. # test _get_state_for_group_using_cache correctly filters in members
  336. # wildcard types
  337. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  338. self.state_datastore._state_group_cache,
  339. group,
  340. state_filter=StateFilter(
  341. types={EventTypes.Member: None}, include_others=True
  342. ),
  343. )
  344. self.assertEqual(is_all, False)
  345. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  346. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  347. self.state_datastore._state_group_members_cache,
  348. group,
  349. state_filter=StateFilter(
  350. types={EventTypes.Member: None}, include_others=True
  351. ),
  352. )
  353. self.assertEqual(is_all, True)
  354. self.assertDictEqual(
  355. {
  356. (e3.type, e3.state_key): e3.event_id,
  357. (e5.type, e5.state_key): e5.event_id,
  358. },
  359. state_dict,
  360. )
  361. # test _get_state_for_group_using_cache correctly filters in members
  362. # with specific types
  363. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  364. self.state_datastore._state_group_cache,
  365. group,
  366. state_filter=StateFilter(
  367. types={EventTypes.Member: {e5.state_key}}, include_others=True
  368. ),
  369. )
  370. self.assertEqual(is_all, False)
  371. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  372. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  373. self.state_datastore._state_group_members_cache,
  374. group,
  375. state_filter=StateFilter(
  376. types={EventTypes.Member: {e5.state_key}}, include_others=True
  377. ),
  378. )
  379. self.assertEqual(is_all, True)
  380. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  381. # test _get_state_for_group_using_cache correctly filters in members
  382. # with specific types
  383. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  384. self.state_datastore._state_group_cache,
  385. group,
  386. state_filter=StateFilter(
  387. types={EventTypes.Member: {e5.state_key}}, include_others=False
  388. ),
  389. )
  390. self.assertEqual(is_all, False)
  391. self.assertDictEqual({}, state_dict)
  392. (state_dict, is_all,) = self.state_datastore._get_state_for_group_using_cache(
  393. self.state_datastore._state_group_members_cache,
  394. group,
  395. state_filter=StateFilter(
  396. types={EventTypes.Member: {e5.state_key}}, include_others=False
  397. ),
  398. )
  399. self.assertEqual(is_all, True)
  400. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)