test_state.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2018 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from twisted.internet import defer
  17. from synapse.api.constants import EventTypes, Membership
  18. from synapse.storage.state import StateFilter
  19. from synapse.types import RoomID, UserID
  20. import tests.unittest
  21. import tests.utils
  22. logger = logging.getLogger(__name__)
  23. class StateStoreTestCase(tests.unittest.TestCase):
  24. def __init__(self, *args, **kwargs):
  25. super(StateStoreTestCase, self).__init__(*args, **kwargs)
  26. self.store = None # type: synapse.storage.DataStore
  27. @defer.inlineCallbacks
  28. def setUp(self):
  29. hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
  30. self.store = hs.get_datastore()
  31. self.event_builder_factory = hs.get_event_builder_factory()
  32. self.event_creation_handler = hs.get_event_creation_handler()
  33. self.u_alice = UserID.from_string("@alice:test")
  34. self.u_bob = UserID.from_string("@bob:test")
  35. self.room = RoomID.from_string("!abc123:test")
  36. yield self.store.store_room(
  37. self.room.to_string(), room_creator_user_id="@creator:text", is_public=True
  38. )
  39. @defer.inlineCallbacks
  40. def inject_state_event(self, room, sender, typ, state_key, content):
  41. builder = self.event_builder_factory.new(
  42. {
  43. "type": typ,
  44. "sender": sender.to_string(),
  45. "state_key": state_key,
  46. "room_id": room.to_string(),
  47. "content": content,
  48. }
  49. )
  50. event, context = yield self.event_creation_handler.create_new_client_event(
  51. builder
  52. )
  53. yield self.store.persist_event(event, context)
  54. defer.returnValue(event)
  55. def assertStateMapEqual(self, s1, s2):
  56. for t in s1:
  57. # just compare event IDs for simplicity
  58. self.assertEqual(s1[t].event_id, s2[t].event_id)
  59. self.assertEqual(len(s1), len(s2))
  60. @defer.inlineCallbacks
  61. def test_get_state_groups_ids(self):
  62. e1 = yield self.inject_state_event(
  63. self.room, self.u_alice, EventTypes.Create, '', {}
  64. )
  65. e2 = yield self.inject_state_event(
  66. self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
  67. )
  68. state_group_map = yield self.store.get_state_groups_ids(self.room, [e2.event_id])
  69. self.assertEqual(len(state_group_map), 1)
  70. state_map = list(state_group_map.values())[0]
  71. self.assertDictEqual(
  72. state_map,
  73. {
  74. (EventTypes.Create, ''): e1.event_id,
  75. (EventTypes.Name, ''): e2.event_id,
  76. },
  77. )
  78. @defer.inlineCallbacks
  79. def test_get_state_groups(self):
  80. e1 = yield self.inject_state_event(
  81. self.room, self.u_alice, EventTypes.Create, '', {}
  82. )
  83. e2 = yield self.inject_state_event(
  84. self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
  85. )
  86. state_group_map = yield self.store.get_state_groups(
  87. self.room, [e2.event_id])
  88. self.assertEqual(len(state_group_map), 1)
  89. state_list = list(state_group_map.values())[0]
  90. self.assertEqual(
  91. {ev.event_id for ev in state_list},
  92. {e1.event_id, e2.event_id},
  93. )
  94. @defer.inlineCallbacks
  95. def test_get_state_for_event(self):
  96. # this defaults to a linear DAG as each new injection defaults to whatever
  97. # forward extremities are currently in the DB for this room.
  98. e1 = yield self.inject_state_event(
  99. self.room, self.u_alice, EventTypes.Create, '', {}
  100. )
  101. e2 = yield self.inject_state_event(
  102. self.room, self.u_alice, EventTypes.Name, '', {"name": "test room"}
  103. )
  104. e3 = yield self.inject_state_event(
  105. self.room,
  106. self.u_alice,
  107. EventTypes.Member,
  108. self.u_alice.to_string(),
  109. {"membership": Membership.JOIN},
  110. )
  111. e4 = yield self.inject_state_event(
  112. self.room,
  113. self.u_bob,
  114. EventTypes.Member,
  115. self.u_bob.to_string(),
  116. {"membership": Membership.JOIN},
  117. )
  118. e5 = yield self.inject_state_event(
  119. self.room,
  120. self.u_bob,
  121. EventTypes.Member,
  122. self.u_bob.to_string(),
  123. {"membership": Membership.LEAVE},
  124. )
  125. # check we get the full state as of the final event
  126. state = yield self.store.get_state_for_event(
  127. e5.event_id,
  128. )
  129. self.assertIsNotNone(e4)
  130. self.assertStateMapEqual(
  131. {
  132. (e1.type, e1.state_key): e1,
  133. (e2.type, e2.state_key): e2,
  134. (e3.type, e3.state_key): e3,
  135. # e4 is overwritten by e5
  136. (e5.type, e5.state_key): e5,
  137. },
  138. state,
  139. )
  140. # check we can filter to the m.room.name event (with a '' state key)
  141. state = yield self.store.get_state_for_event(
  142. e5.event_id, StateFilter.from_types([(EventTypes.Name, '')])
  143. )
  144. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  145. # check we can filter to the m.room.name event (with a wildcard None state key)
  146. state = yield self.store.get_state_for_event(
  147. e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
  148. )
  149. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  150. # check we can grab the m.room.member events (with a wildcard None state key)
  151. state = yield self.store.get_state_for_event(
  152. e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
  153. )
  154. self.assertStateMapEqual(
  155. {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
  156. )
  157. # check we can grab a specific room member without filtering out the
  158. # other event types
  159. state = yield self.store.get_state_for_event(
  160. e5.event_id,
  161. state_filter=StateFilter(
  162. types={EventTypes.Member: {self.u_alice.to_string()}},
  163. include_others=True,
  164. )
  165. )
  166. self.assertStateMapEqual(
  167. {
  168. (e1.type, e1.state_key): e1,
  169. (e2.type, e2.state_key): e2,
  170. (e3.type, e3.state_key): e3,
  171. },
  172. state,
  173. )
  174. # check that we can grab everything except members
  175. state = yield self.store.get_state_for_event(
  176. e5.event_id, state_filter=StateFilter(
  177. types={EventTypes.Member: set()},
  178. include_others=True,
  179. ),
  180. )
  181. self.assertStateMapEqual(
  182. {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
  183. )
  184. #######################################################
  185. # _get_state_for_group_using_cache tests against a full cache
  186. #######################################################
  187. room_id = self.room.to_string()
  188. group_ids = yield self.store.get_state_groups_ids(room_id, [e5.event_id])
  189. group = list(group_ids.keys())[0]
  190. # test _get_state_for_group_using_cache correctly filters out members
  191. # with types=[]
  192. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  193. self.store._state_group_cache, group,
  194. state_filter=StateFilter(
  195. types={EventTypes.Member: set()},
  196. include_others=True,
  197. ),
  198. )
  199. self.assertEqual(is_all, True)
  200. self.assertDictEqual(
  201. {
  202. (e1.type, e1.state_key): e1.event_id,
  203. (e2.type, e2.state_key): e2.event_id,
  204. },
  205. state_dict,
  206. )
  207. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  208. self.store._state_group_members_cache,
  209. group,
  210. state_filter=StateFilter(
  211. types={EventTypes.Member: set()},
  212. include_others=True,
  213. ),
  214. )
  215. self.assertEqual(is_all, True)
  216. self.assertDictEqual({}, state_dict)
  217. # test _get_state_for_group_using_cache correctly filters in members
  218. # with wildcard types
  219. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  220. self.store._state_group_cache,
  221. group,
  222. state_filter=StateFilter(
  223. types={EventTypes.Member: None},
  224. include_others=True,
  225. ),
  226. )
  227. self.assertEqual(is_all, True)
  228. self.assertDictEqual(
  229. {
  230. (e1.type, e1.state_key): e1.event_id,
  231. (e2.type, e2.state_key): e2.event_id,
  232. },
  233. state_dict,
  234. )
  235. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  236. self.store._state_group_members_cache,
  237. group,
  238. state_filter=StateFilter(
  239. types={EventTypes.Member: None},
  240. include_others=True,
  241. ),
  242. )
  243. self.assertEqual(is_all, True)
  244. self.assertDictEqual(
  245. {
  246. (e3.type, e3.state_key): e3.event_id,
  247. # e4 is overwritten by e5
  248. (e5.type, e5.state_key): e5.event_id,
  249. },
  250. state_dict,
  251. )
  252. # test _get_state_for_group_using_cache correctly filters in members
  253. # with specific types
  254. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  255. self.store._state_group_cache,
  256. group,
  257. state_filter=StateFilter(
  258. types={EventTypes.Member: {e5.state_key}},
  259. include_others=True,
  260. ),
  261. )
  262. self.assertEqual(is_all, True)
  263. self.assertDictEqual(
  264. {
  265. (e1.type, e1.state_key): e1.event_id,
  266. (e2.type, e2.state_key): e2.event_id,
  267. },
  268. state_dict,
  269. )
  270. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  271. self.store._state_group_members_cache,
  272. group,
  273. state_filter=StateFilter(
  274. types={EventTypes.Member: {e5.state_key}},
  275. include_others=True,
  276. ),
  277. )
  278. self.assertEqual(is_all, True)
  279. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  280. # test _get_state_for_group_using_cache correctly filters in members
  281. # with specific types
  282. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  283. self.store._state_group_members_cache,
  284. group,
  285. state_filter=StateFilter(
  286. types={EventTypes.Member: {e5.state_key}},
  287. include_others=False,
  288. ),
  289. )
  290. self.assertEqual(is_all, True)
  291. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  292. #######################################################
  293. # deliberately remove e2 (room name) from the _state_group_cache
  294. (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
  295. group
  296. )
  297. self.assertEqual(is_all, True)
  298. self.assertEqual(known_absent, set())
  299. self.assertDictEqual(
  300. state_dict_ids,
  301. {
  302. (e1.type, e1.state_key): e1.event_id,
  303. (e2.type, e2.state_key): e2.event_id,
  304. },
  305. )
  306. state_dict_ids.pop((e2.type, e2.state_key))
  307. self.store._state_group_cache.invalidate(group)
  308. self.store._state_group_cache.update(
  309. sequence=self.store._state_group_cache.sequence,
  310. key=group,
  311. value=state_dict_ids,
  312. # list fetched keys so it knows it's partial
  313. fetched_keys=((e1.type, e1.state_key),),
  314. )
  315. (is_all, known_absent, state_dict_ids) = self.store._state_group_cache.get(
  316. group
  317. )
  318. self.assertEqual(is_all, False)
  319. self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
  320. self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
  321. ############################################
  322. # test that things work with a partial cache
  323. # test _get_state_for_group_using_cache correctly filters out members
  324. # with types=[]
  325. room_id = self.room.to_string()
  326. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  327. self.store._state_group_cache, group,
  328. state_filter=StateFilter(
  329. types={EventTypes.Member: set()},
  330. include_others=True,
  331. ),
  332. )
  333. self.assertEqual(is_all, False)
  334. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  335. room_id = self.room.to_string()
  336. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  337. self.store._state_group_members_cache,
  338. group,
  339. state_filter=StateFilter(
  340. types={EventTypes.Member: set()},
  341. include_others=True,
  342. ),
  343. )
  344. self.assertEqual(is_all, True)
  345. self.assertDictEqual({}, state_dict)
  346. # test _get_state_for_group_using_cache correctly filters in members
  347. # wildcard types
  348. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  349. self.store._state_group_cache,
  350. group,
  351. state_filter=StateFilter(
  352. types={EventTypes.Member: None},
  353. include_others=True,
  354. ),
  355. )
  356. self.assertEqual(is_all, False)
  357. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  358. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  359. self.store._state_group_members_cache,
  360. group,
  361. state_filter=StateFilter(
  362. types={EventTypes.Member: None},
  363. include_others=True,
  364. ),
  365. )
  366. self.assertEqual(is_all, True)
  367. self.assertDictEqual(
  368. {
  369. (e3.type, e3.state_key): e3.event_id,
  370. (e5.type, e5.state_key): e5.event_id,
  371. },
  372. state_dict,
  373. )
  374. # test _get_state_for_group_using_cache correctly filters in members
  375. # with specific types
  376. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  377. self.store._state_group_cache,
  378. group,
  379. state_filter=StateFilter(
  380. types={EventTypes.Member: {e5.state_key}},
  381. include_others=True,
  382. ),
  383. )
  384. self.assertEqual(is_all, False)
  385. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  386. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  387. self.store._state_group_members_cache,
  388. group,
  389. state_filter=StateFilter(
  390. types={EventTypes.Member: {e5.state_key}},
  391. include_others=True,
  392. ),
  393. )
  394. self.assertEqual(is_all, True)
  395. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  396. # test _get_state_for_group_using_cache correctly filters in members
  397. # with specific types
  398. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  399. self.store._state_group_cache,
  400. group,
  401. state_filter=StateFilter(
  402. types={EventTypes.Member: {e5.state_key}},
  403. include_others=False,
  404. ),
  405. )
  406. self.assertEqual(is_all, False)
  407. self.assertDictEqual({}, state_dict)
  408. (state_dict, is_all) = yield self.store._get_state_for_group_using_cache(
  409. self.store._state_group_members_cache,
  410. group,
  411. state_filter=StateFilter(
  412. types={EventTypes.Member: {e5.state_key}},
  413. include_others=False,
  414. ),
  415. )
  416. self.assertEqual(is_all, True)
  417. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)