test_state.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2018 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from twisted.internet import defer
  17. from synapse.api.constants import EventTypes, Membership, RoomVersions
  18. from synapse.storage.state import StateFilter
  19. from synapse.types import RoomID, UserID
  20. import tests.unittest
  21. import tests.utils
  22. logger = logging.getLogger(__name__)
  23. class StateStoreTestCase(tests.unittest.TestCase):
  @defer.inlineCallbacks
  def setUp(self):
      """Create a test homeserver and register the room the tests use.

      The datastore, event builder factory and event creation handler are
      stashed on the test case, together with two user IDs and a room ID.
      The room is then stored as a public room.
      """
      homeserver = yield tests.utils.setup_test_homeserver(self.addCleanup)

      # Handles onto the pieces of the homeserver the tests drive directly.
      self.store = homeserver.get_datastore()
      self.event_builder_factory = homeserver.get_event_builder_factory()
      self.event_creation_handler = homeserver.get_event_creation_handler()

      # Fixed identities shared by every test in this case.
      self.u_alice = UserID.from_string("@alice:test")
      self.u_bob = UserID.from_string("@bob:test")
      self.room = RoomID.from_string("!abc123:test")

      room_id = self.room.to_string()
      yield self.store.store_room(
          room_id, room_creator_user_id="@creator:text", is_public=True
      )
  @defer.inlineCallbacks
  def inject_state_event(self, room, sender, typ, state_key, content):
      """Build a state event, persist it to the store and hand it back.

      Args:
          room: the room to send into; ``room.to_string()`` becomes the
              event's ``room_id``.
          sender: the sending user; ``sender.to_string()`` becomes the
              event's ``sender``.
          typ: value for the event's ``type`` field.
          state_key: value for the event's ``state_key`` field.
          content: value for the event's ``content`` field.

      Returns:
          Deferred which resolves to the newly created event.
      """
      event_dict = {
          "type": typ,
          "sender": sender.to_string(),
          "state_key": state_key,
          "room_id": room.to_string(),
          "content": content,
      }
      event_builder = self.event_builder_factory.new(RoomVersions.V1, event_dict)

      created = yield self.event_creation_handler.create_new_client_event(
          event_builder
      )
      event, context = created

      yield self.store.persist_event(event, context)
      defer.returnValue(event)
  def assertStateMapEqual(self, s1, s2):
      """Assert that two state maps hold the same events under the same keys.

      Args:
          s1: mapping from state key to an object with an ``event_id``
              attribute.
          s2: mapping of the same shape to compare against ``s1``.

      Raises:
          AssertionError (via ``self.assertEqual``) if the key sets differ
          or if any pair of events under the same key has different IDs.
      """
      # Compare the key sets up front: indexing s2 with a key that only
      # exists in s1 would otherwise raise KeyError and surface as a test
      # *error* rather than a readable assertion failure. Equal key sets
      # also imply equal lengths.
      self.assertEqual(set(s1), set(s2))

      for key in s1:
          # just compare event IDs for simplicity
          self.assertEqual(s1[key].event_id, s2[key].event_id)
  @defer.inlineCallbacks
  def test_get_state_groups_ids(self):
      """Check get_state_groups_ids returns the state event IDs for an event.

      A create event and then a name event are injected; looking up the name
      event must yield exactly one state group whose map points each
      (type, state_key) pair at the ID of the matching event.
      """
      e1 = yield self.inject_state_event(
          self.room, self.u_alice, EventTypes.Create, '', {}
      )
      e2 = yield self.inject_state_event(
          self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
      )

      # Pass the room ID as a string rather than the RoomID object, matching
      # how test_get_state_for_event calls this same store method.
      state_group_map = yield self.store.get_state_groups_ids(
          self.room.to_string(), [e2.event_id]
      )
      self.assertEqual(len(state_group_map), 1)

      # Only one group came back, so take its state map.
      state_map = list(state_group_map.values())[0]
      self.assertDictEqual(
          state_map,
          {
              (EventTypes.Create, ''): e1.event_id,
              (EventTypes.Name, ''): e2.event_id,
          },
      )
  @defer.inlineCallbacks
  def test_get_state_groups(self):
      """Check get_state_groups returns the state events for an event.

      A create event and then a name event are injected; looking up the name
      event must yield exactly one state group whose events are precisely
      those two (compared by event ID).
      """
      e1 = yield self.inject_state_event(
          self.room, self.u_alice, EventTypes.Create, '', {}
      )
      e2 = yield self.inject_state_event(
          self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
      )

      # Identify the room by its string ID rather than the RoomID object,
      # as the other store calls in this file do (e.g. store_room in setUp).
      state_group_map = yield self.store.get_state_groups(
          self.room.to_string(), [e2.event_id]
      )
      self.assertEqual(len(state_group_map), 1)

      # Only one group came back, so take its list of events.
      state_list = list(state_group_map.values())[0]
      self.assertEqual(
          {ev.event_id for ev in state_list},
          {e1.event_id, e2.event_id},
      )
  92. @defer.inlineCallbacks
  93. def test_get_state_for_event(self):
  94. # this defaults to a linear DAG as each new injection defaults to whatever
  95. # forward extremities are currently in the DB for this room.
  96. e1 = yield self.inject_state_event(
  97. self.room, self.u_alice, EventTypes.Create, '', {}
  98. )
  99. e2 = yield self.inject_state_event(
  100. self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
  101. )
  102. e3 = yield self.inject_state_event(
  103. self.room,
  104. self.u_alice,
  105. EventTypes.Member,
  106. self.u_alice.to_string(),
  107. {"membership": Membership.JOIN},
  108. )
  109. e4 = yield self.inject_state_event(
  110. self.room,
  111. self.u_bob,
  112. EventTypes.Member,
  113. self.u_bob.to_string(),
  114. {"membership": Membership.JOIN},
  115. )
  116. e5 = yield self.inject_state_event(
  117. self.room,
  118. self.u_bob,
  119. EventTypes.Member,
  120. self.u_bob.to_string(),
  121. {"membership": Membership.LEAVE},
  122. )
  123. # check we get the full state as of the final event
  124. state = yield self.store.get_state_for_event(
  125. e5.event_id,
  126. )
  127. self.assertIsNotNone(e4)
  128. self.assertStateMapEqual(
  129. {
  130. (e1.type, e1.state_key): e1,
  131. (e2.type, e2.state_key): e2,
  132. (e3.type, e3.state_key): e3,
  133. # e4 is overwritten by e5
  134. (e5.type, e5.state_key): e5,
  135. },
  136. state,
  137. )
  138. # check we can filter to the m.room.name event (with a '' state key)
  139. state = yield self.store.get_state_for_event(
  140. e5.event_id, StateFilter.from_types([(EventTypes.Name, '')])
  141. )
  142. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  143. # check we can filter to the m.room.name event (with a wildcard None state key)
  144. state = yield self.store.get_state_for_event(
  145. e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
  146. )
  147. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  148. # check we can grab the m.room.member events (with a wildcard None state key)
  149. state = yield self.store.get_state_for_event(
  150. e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
  151. )
  152. self.assertStateMapEqual(
  153. {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
  154. )
  155. # check we can grab a specific room member without filtering out the
  156. # other event types
  157. state = yield self.store.get_state_for_event(
  158. e5.event_id,
  159. state_filter=StateFilter(
  160. types={EventTypes.Member: {self.u_alice.to_string()}},
  161. include_others=True,
  162. )
  163. )
  164. self.assertStateMapEqual(
  165. {
  166. (e1.type, e1.state_key): e1,
  167. (e2.type, e2.state_key): e2,
  168. (e3.type, e3.state_key): e3,
  169. },
  170. state,
  171. )
  172. # check that we can grab everything except members
  173. state = yield self.store.get_state_for_event(
  174. e5.event_id, state_filter=StateFilter(
  175. types={EventTypes.Member: set()},
  176. include_others=True,
  177. ),
  178. )
  179. self.assertStateMapEqual(
  180. {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
  181. )
  182. #######################################################
  183. # _get_state_for_group_using_cache tests against a full cache
  184. #######################################################
  185. room_id = self.room.to_string()
  186. group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
  187. group = list(group_ids.keys())[0]
  188. # test _get_state_for_group_using_cache correctly filters out members
  189. # with types=[]
  190. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  191. self.store._state_group_cache, group,
  192. state_filter=StateFilter(
  193. types={EventTypes.Member: set()},
  194. include_others=True,
  195. ),
  196. )
  197. self.assertEqual(is_all, True)
  198. self.assertDictEqual(
  199. {
  200. (e1.type, e1.state_key): e1.event_id,
  201. (e2.type, e2.state_key): e2.event_id,
  202. },
  203. state_dict,
  204. )
  205. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  206. self.store._state_group_members_cache,
  207. group,
  208. state_filter=StateFilter(
  209. types={EventTypes.Member: set()},
  210. include_others=True,
  211. ),
  212. )
  213. self.assertEqual(is_all, True)
  214. self.assertDictEqual({}, state_dict)
  215. # test _get_state_for_group_using_cache correctly filters in members
  216. # with wildcard types
  217. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  218. self.store._state_group_cache,
  219. group,
  220. state_filter=StateFilter(
  221. types={EventTypes.Member: None},
  222. include_others=True,
  223. ),
  224. )
  225. self.assertEqual(is_all, True)
  226. self.assertDictEqual(
  227. {
  228. (e1.type, e1.state_key): e1.event_id,
  229. (e2.type, e2.state_key): e2.event_id,
  230. },
  231. state_dict,
  232. )
  233. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  234. self.store._state_group_members_cache,
  235. group,
  236. state_filter=StateFilter(
  237. types={EventTypes.Member: None},
  238. include_others=True,
  239. ),
  240. )
  241. self.assertEqual(is_all, True)
  242. self.assertDictEqual(
  243. {
  244. (e3.type, e3.state_key): e3.event_id,
  245. # e4 is overwritten by e5
  246. (e5.type, e5.state_key): e5.event_id,
  247. },
  248. state_dict,
  249. )
  250. # test _get_state_for_group_using_cache correctly filters in members
  251. # with specific types
  252. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  253. self.store._state_group_cache,
  254. group,
  255. state_filter=StateFilter(
  256. types={EventTypes.Member: {e5.state_key}},
  257. include_others=True,
  258. ),
  259. )
  260. self.assertEqual(is_all, True)
  261. self.assertDictEqual(
  262. {
  263. (e1.type, e1.state_key): e1.event_id,
  264. (e2.type, e2.state_key): e2.event_id,
  265. },
  266. state_dict,
  267. )
  268. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  269. self.store._state_group_members_cache,
  270. group,
  271. state_filter=StateFilter(
  272. types={EventTypes.Member: {e5.state_key}},
  273. include_others=True,
  274. ),
  275. )
  276. self.assertEqual(is_all, True)
  277. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  278. # test _get_state_for_group_using_cache correctly filters in members
  279. # with specific types
  280. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  281. self.store._state_group_members_cache,
  282. group,
  283. state_filter=StateFilter(
  284. types={EventTypes.Member: {e5.state_key}},
  285. include_others=False,
  286. ),
  287. )
  288. self.assertEqual(is_all, True)
  289. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  290. #######################################################
  291. # deliberately remove e2 (room name) from the _state_group_cache
  292. (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
  293. group
  294. )
  295. self.assertEqual(is_all, True)
  296. self.assertEqual(known_absent, set())
  297. self.assertDictEqual(
  298. state_dict_ids,
  299. {
  300. (e1.type, e1.state_key): e1.event_id,
  301. (e2.type, e2.state_key): e2.event_id,
  302. },
  303. )
  304. state_dict_ids.pop((e2.type, e2.state_key))
  305. self.store._state_group_cache.invalidate(group)
  306. self.store._state_group_cache.update(
  307. sequence=self.store._state_group_cache.sequence,
  308. key=group,
  309. value=state_dict_ids,
  310. # list fetched keys so it knows it's partial
  311. fetched_keys=((e1.type, e1.state_key),),
  312. )
  313. (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
  314. group
  315. )
  316. self.assertEqual(is_all, False)
  317. self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
  318. self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
  319. ############################################
  320. # test that things work with a partial cache
  321. # test _get_state_for_group_using_cache correctly filters out members
  322. # with types=[]
  323. room_id = self.room.to_string()
  324. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  325. self.store._state_group_cache, group,
  326. state_filter=StateFilter(
  327. types={EventTypes.Member: set()},
  328. include_others=True,
  329. ),
  330. )
  331. self.assertEqual(is_all, False)
  332. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  333. room_id = self.room.to_string()
  334. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  335. self.store._state_group_members_cache,
  336. group,
  337. state_filter=StateFilter(
  338. types={EventTypes.Member: set()},
  339. include_others=True,
  340. ),
  341. )
  342. self.assertEqual(is_all, True)
  343. self.assertDictEqual({}, state_dict)
  344. # test _get_state_for_group_using_cache correctly filters in members
  345. # wildcard types
  346. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  347. self.store._state_group_cache,
  348. group,
  349. state_filter=StateFilter(
  350. types={EventTypes.Member: None},
  351. include_others=True,
  352. ),
  353. )
  354. self.assertEqual(is_all, False)
  355. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  356. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  357. self.store._state_group_members_cache,
  358. group,
  359. state_filter=StateFilter(
  360. types={EventTypes.Member: None},
  361. include_others=True,
  362. ),
  363. )
  364. self.assertEqual(is_all, True)
  365. self.assertDictEqual(
  366. {
  367. (e3.type, e3.state_key): e3.event_id,
  368. (e5.type, e5.state_key): e5.event_id,
  369. },
  370. state_dict,
  371. )
  372. # test _get_state_for_group_using_cache correctly filters in members
  373. # with specific types
  374. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  375. self.store._state_group_cache,
  376. group,
  377. state_filter=StateFilter(
  378. types={EventTypes.Member: {e5.state_key}},
  379. include_others=True,
  380. ),
  381. )
  382. self.assertEqual(is_all, False)
  383. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  384. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  385. self.store._state_group_members_cache,
  386. group,
  387. state_filter=StateFilter(
  388. types={EventTypes.Member: {e5.state_key}},
  389. include_others=True,
  390. ),
  391. )
  392. self.assertEqual(is_all, True)
  393. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  394. # test _get_state_for_group_using_cache correctly filters in members
  395. # with specific types
  396. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  397. self.store._state_group_cache,
  398. group,
  399. state_filter=StateFilter(
  400. types={EventTypes.Member: {e5.state_key}},
  401. include_others=False,
  402. ),
  403. )
  404. self.assertEqual(is_all, False)
  405. self.assertDictEqual({}, state_dict)
  406. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  407. self.store._state_group_members_cache,
  408. group,
  409. state_filter=StateFilter(
  410. types={EventTypes.Member: {e5.state_key}},
  411. include_others=False,
  412. ),
  413. )
  414. self.assertEqual(is_all, True)
  415. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)