test_state.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2018 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from twisted.internet import defer
  17. from synapse.api.constants import EventTypes, Membership
  18. from synapse.api.room_versions import RoomVersions
  19. from synapse.storage.state import StateFilter
  20. from synapse.types import RoomID, UserID
  21. import tests.unittest
  22. import tests.utils
  23. logger = logging.getLogger(__name__)
  24. class StateStoreTestCase(tests.unittest.TestCase):
  25. @defer.inlineCallbacks
  26. def setUp(self):
  27. hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
  28. self.store = hs.get_datastore()
  29. self.event_builder_factory = hs.get_event_builder_factory()
  30. self.event_creation_handler = hs.get_event_creation_handler()
  31. self.u_alice = UserID.from_string("@alice:test")
  32. self.u_bob = UserID.from_string("@bob:test")
  33. self.room = RoomID.from_string("!abc123:test")
  34. yield self.store.store_room(
  35. self.room.to_string(), room_creator_user_id="@creator:text", is_public=True
  36. )
  37. @defer.inlineCallbacks
  38. def inject_state_event(self, room, sender, typ, state_key, content):
  39. builder = self.event_builder_factory.for_room_version(
  40. RoomVersions.V1,
  41. {
  42. "type": typ,
  43. "sender": sender.to_string(),
  44. "state_key": state_key,
  45. "room_id": room.to_string(),
  46. "content": content,
  47. },
  48. )
  49. event, context = yield self.event_creation_handler.create_new_client_event(
  50. builder
  51. )
  52. yield self.store.persist_event(event, context)
  53. return event
  54. def assertStateMapEqual(self, s1, s2):
  55. for t in s1:
  56. # just compare event IDs for simplicity
  57. self.assertEqual(s1[t].event_id, s2[t].event_id)
  58. self.assertEqual(len(s1), len(s2))
  59. @defer.inlineCallbacks
  60. def test_get_state_groups_ids(self):
  61. e1 = yield self.inject_state_event(
  62. self.room, self.u_alice, EventTypes.Create, "", {}
  63. )
  64. e2 = yield self.inject_state_event(
  65. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  66. )
  67. state_group_map = yield self.store.get_state_groups_ids(
  68. self.room, [e2.event_id]
  69. )
  70. self.assertEqual(len(state_group_map), 1)
  71. state_map = list(state_group_map.values())[0]
  72. self.assertDictEqual(
  73. state_map,
  74. {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
  75. )
  76. @defer.inlineCallbacks
  77. def test_get_state_groups(self):
  78. e1 = yield self.inject_state_event(
  79. self.room, self.u_alice, EventTypes.Create, "", {}
  80. )
  81. e2 = yield self.inject_state_event(
  82. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  83. )
  84. state_group_map = yield self.store.get_state_groups(self.room, [e2.event_id])
  85. self.assertEqual(len(state_group_map), 1)
  86. state_list = list(state_group_map.values())[0]
  87. self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
  88. @defer.inlineCallbacks
  89. def test_get_state_for_event(self):
  90. # this defaults to a linear DAG as each new injection defaults to whatever
  91. # forward extremities are currently in the DB for this room.
  92. e1 = yield self.inject_state_event(
  93. self.room, self.u_alice, EventTypes.Create, "", {}
  94. )
  95. e2 = yield self.inject_state_event(
  96. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  97. )
  98. e3 = yield self.inject_state_event(
  99. self.room,
  100. self.u_alice,
  101. EventTypes.Member,
  102. self.u_alice.to_string(),
  103. {"membership": Membership.JOIN},
  104. )
  105. e4 = yield self.inject_state_event(
  106. self.room,
  107. self.u_bob,
  108. EventTypes.Member,
  109. self.u_bob.to_string(),
  110. {"membership": Membership.JOIN},
  111. )
  112. e5 = yield self.inject_state_event(
  113. self.room,
  114. self.u_bob,
  115. EventTypes.Member,
  116. self.u_bob.to_string(),
  117. {"membership": Membership.LEAVE},
  118. )
  119. # check we get the full state as of the final event
  120. state = yield self.store.get_state_for_event(e5.event_id)
  121. self.assertIsNotNone(e4)
  122. self.assertStateMapEqual(
  123. {
  124. (e1.type, e1.state_key): e1,
  125. (e2.type, e2.state_key): e2,
  126. (e3.type, e3.state_key): e3,
  127. # e4 is overwritten by e5
  128. (e5.type, e5.state_key): e5,
  129. },
  130. state,
  131. )
  132. # check we can filter to the m.room.name event (with a '' state key)
  133. state = yield self.store.get_state_for_event(
  134. e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
  135. )
  136. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  137. # check we can filter to the m.room.name event (with a wildcard None state key)
  138. state = yield self.store.get_state_for_event(
  139. e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
  140. )
  141. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  142. # check we can grab the m.room.member events (with a wildcard None state key)
  143. state = yield self.store.get_state_for_event(
  144. e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
  145. )
  146. self.assertStateMapEqual(
  147. {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
  148. )
  149. # check we can grab a specific room member without filtering out the
  150. # other event types
  151. state = yield self.store.get_state_for_event(
  152. e5.event_id,
  153. state_filter=StateFilter(
  154. types={EventTypes.Member: {self.u_alice.to_string()}},
  155. include_others=True,
  156. ),
  157. )
  158. self.assertStateMapEqual(
  159. {
  160. (e1.type, e1.state_key): e1,
  161. (e2.type, e2.state_key): e2,
  162. (e3.type, e3.state_key): e3,
  163. },
  164. state,
  165. )
  166. # check that we can grab everything except members
  167. state = yield self.store.get_state_for_event(
  168. e5.event_id,
  169. state_filter=StateFilter(
  170. types={EventTypes.Member: set()}, include_others=True
  171. ),
  172. )
  173. self.assertStateMapEqual(
  174. {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
  175. )
  176. #######################################################
  177. # _get_state_for_group_using_cache tests against a full cache
  178. #######################################################
  179. room_id = self.room.to_string()
  180. group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
  181. group = list(group_ids.keys())[0]
  182. # test _get_state_for_group_using_cache correctly filters out members
  183. # with types=[]
  184. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  185. self.store._state_group_cache,
  186. group,
  187. state_filter=StateFilter(
  188. types={EventTypes.Member: set()}, include_others=True
  189. ),
  190. )
  191. self.assertEqual(is_all, True)
  192. self.assertDictEqual(
  193. {
  194. (e1.type, e1.state_key): e1.event_id,
  195. (e2.type, e2.state_key): e2.event_id,
  196. },
  197. state_dict,
  198. )
  199. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  200. self.store._state_group_members_cache,
  201. group,
  202. state_filter=StateFilter(
  203. types={EventTypes.Member: set()}, include_others=True
  204. ),
  205. )
  206. self.assertEqual(is_all, True)
  207. self.assertDictEqual({}, state_dict)
  208. # test _get_state_for_group_using_cache correctly filters in members
  209. # with wildcard types
  210. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  211. self.store._state_group_cache,
  212. group,
  213. state_filter=StateFilter(
  214. types={EventTypes.Member: None}, include_others=True
  215. ),
  216. )
  217. self.assertEqual(is_all, True)
  218. self.assertDictEqual(
  219. {
  220. (e1.type, e1.state_key): e1.event_id,
  221. (e2.type, e2.state_key): e2.event_id,
  222. },
  223. state_dict,
  224. )
  225. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  226. self.store._state_group_members_cache,
  227. group,
  228. state_filter=StateFilter(
  229. types={EventTypes.Member: None}, include_others=True
  230. ),
  231. )
  232. self.assertEqual(is_all, True)
  233. self.assertDictEqual(
  234. {
  235. (e3.type, e3.state_key): e3.event_id,
  236. # e4 is overwritten by e5
  237. (e5.type, e5.state_key): e5.event_id,
  238. },
  239. state_dict,
  240. )
  241. # test _get_state_for_group_using_cache correctly filters in members
  242. # with specific types
  243. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  244. self.store._state_group_cache,
  245. group,
  246. state_filter=StateFilter(
  247. types={EventTypes.Member: {e5.state_key}}, include_others=True
  248. ),
  249. )
  250. self.assertEqual(is_all, True)
  251. self.assertDictEqual(
  252. {
  253. (e1.type, e1.state_key): e1.event_id,
  254. (e2.type, e2.state_key): e2.event_id,
  255. },
  256. state_dict,
  257. )
  258. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  259. self.store._state_group_members_cache,
  260. group,
  261. state_filter=StateFilter(
  262. types={EventTypes.Member: {e5.state_key}}, include_others=True
  263. ),
  264. )
  265. self.assertEqual(is_all, True)
  266. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  267. # test _get_state_for_group_using_cache correctly filters in members
  268. # with specific types
  269. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  270. self.store._state_group_members_cache,
  271. group,
  272. state_filter=StateFilter(
  273. types={EventTypes.Member: {e5.state_key}}, include_others=False
  274. ),
  275. )
  276. self.assertEqual(is_all, True)
  277. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  278. #######################################################
  279. # deliberately remove e2 (room name) from the _state_group_cache
  280. (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
  281. group
  282. )
  283. self.assertEqual(is_all, True)
  284. self.assertEqual(known_absent, set())
  285. self.assertDictEqual(
  286. state_dict_ids,
  287. {
  288. (e1.type, e1.state_key): e1.event_id,
  289. (e2.type, e2.state_key): e2.event_id,
  290. },
  291. )
  292. state_dict_ids.pop((e2.type, e2.state_key))
  293. self.store._state_group_cache.invalidate(group)
  294. self.store._state_group_cache.update(
  295. sequence=self.store._state_group_cache.sequence,
  296. key=group,
  297. value=state_dict_ids,
  298. # list fetched keys so it knows it's partial
  299. fetched_keys=((e1.type, e1.state_key),),
  300. )
  301. (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
  302. group
  303. )
  304. self.assertEqual(is_all, False)
  305. self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
  306. self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
  307. ############################################
  308. # test that things work with a partial cache
  309. # test _get_state_for_group_using_cache correctly filters out members
  310. # with types=[]
  311. room_id = self.room.to_string()
  312. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  313. self.store._state_group_cache,
  314. group,
  315. state_filter=StateFilter(
  316. types={EventTypes.Member: set()}, include_others=True
  317. ),
  318. )
  319. self.assertEqual(is_all, False)
  320. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  321. room_id = self.room.to_string()
  322. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  323. self.store._state_group_members_cache,
  324. group,
  325. state_filter=StateFilter(
  326. types={EventTypes.Member: set()}, include_others=True
  327. ),
  328. )
  329. self.assertEqual(is_all, True)
  330. self.assertDictEqual({}, state_dict)
  331. # test _get_state_for_group_using_cache correctly filters in members
  332. # wildcard types
  333. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  334. self.store._state_group_cache,
  335. group,
  336. state_filter=StateFilter(
  337. types={EventTypes.Member: None}, include_others=True
  338. ),
  339. )
  340. self.assertEqual(is_all, False)
  341. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  342. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  343. self.store._state_group_members_cache,
  344. group,
  345. state_filter=StateFilter(
  346. types={EventTypes.Member: None}, include_others=True
  347. ),
  348. )
  349. self.assertEqual(is_all, True)
  350. self.assertDictEqual(
  351. {
  352. (e3.type, e3.state_key): e3.event_id,
  353. (e5.type, e5.state_key): e5.event_id,
  354. },
  355. state_dict,
  356. )
  357. # test _get_state_for_group_using_cache correctly filters in members
  358. # with specific types
  359. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  360. self.store._state_group_cache,
  361. group,
  362. state_filter=StateFilter(
  363. types={EventTypes.Member: {e5.state_key}}, include_others=True
  364. ),
  365. )
  366. self.assertEqual(is_all, False)
  367. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  368. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  369. self.store._state_group_members_cache,
  370. group,
  371. state_filter=StateFilter(
  372. types={EventTypes.Member: {e5.state_key}}, include_others=True
  373. ),
  374. )
  375. self.assertEqual(is_all, True)
  376. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  377. # test _get_state_for_group_using_cache correctly filters in members
  378. # with specific types
  379. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  380. self.store._state_group_cache,
  381. group,
  382. state_filter=StateFilter(
  383. types={EventTypes.Member: {e5.state_key}}, include_others=False
  384. ),
  385. )
  386. self.assertEqual(is_all, False)
  387. self.assertDictEqual({}, state_dict)
  388. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  389. self.store._state_group_members_cache,
  390. group,
  391. state_filter=StateFilter(
  392. types={EventTypes.Member: {e5.state_key}}, include_others=False
  393. ),
  394. )
  395. self.assertEqual(is_all, True)
  396. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)