test_state.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2018 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from twisted.internet import defer
  17. from synapse.api.constants import EventTypes, Membership
  18. from synapse.api.room_versions import RoomVersions
  19. from synapse.storage.state import StateFilter
  20. from synapse.types import RoomID, UserID
  21. import tests.unittest
  22. import tests.utils
# Module-level logger; no test in this file currently logs through it.
logger = logging.getLogger(__name__)
  24. class StateStoreTestCase(tests.unittest.TestCase):
  25. @defer.inlineCallbacks
  26. def setUp(self):
  27. hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
  28. self.store = hs.get_datastore()
  29. self.storage = hs.get_storage()
  30. self.state_datastore = self.storage.state.stores.state
  31. self.event_builder_factory = hs.get_event_builder_factory()
  32. self.event_creation_handler = hs.get_event_creation_handler()
  33. self.u_alice = UserID.from_string("@alice:test")
  34. self.u_bob = UserID.from_string("@bob:test")
  35. self.room = RoomID.from_string("!abc123:test")
  36. yield defer.ensureDeferred(
  37. self.store.store_room(
  38. self.room.to_string(),
  39. room_creator_user_id="@creator:text",
  40. is_public=True,
  41. room_version=RoomVersions.V1,
  42. )
  43. )
  44. @defer.inlineCallbacks
  45. def inject_state_event(self, room, sender, typ, state_key, content):
  46. builder = self.event_builder_factory.for_room_version(
  47. RoomVersions.V1,
  48. {
  49. "type": typ,
  50. "sender": sender.to_string(),
  51. "state_key": state_key,
  52. "room_id": room.to_string(),
  53. "content": content,
  54. },
  55. )
  56. event, context = yield defer.ensureDeferred(
  57. self.event_creation_handler.create_new_client_event(builder)
  58. )
  59. yield defer.ensureDeferred(
  60. self.storage.persistence.persist_event(event, context)
  61. )
  62. return event
  63. def assertStateMapEqual(self, s1, s2):
  64. for t in s1:
  65. # just compare event IDs for simplicity
  66. self.assertEqual(s1[t].event_id, s2[t].event_id)
  67. self.assertEqual(len(s1), len(s2))
  68. @defer.inlineCallbacks
  69. def test_get_state_groups_ids(self):
  70. e1 = yield self.inject_state_event(
  71. self.room, self.u_alice, EventTypes.Create, "", {}
  72. )
  73. e2 = yield self.inject_state_event(
  74. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  75. )
  76. state_group_map = yield defer.ensureDeferred(
  77. self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
  78. )
  79. self.assertEqual(len(state_group_map), 1)
  80. state_map = list(state_group_map.values())[0]
  81. self.assertDictEqual(
  82. state_map,
  83. {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
  84. )
  85. @defer.inlineCallbacks
  86. def test_get_state_groups(self):
  87. e1 = yield self.inject_state_event(
  88. self.room, self.u_alice, EventTypes.Create, "", {}
  89. )
  90. e2 = yield self.inject_state_event(
  91. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  92. )
  93. state_group_map = yield defer.ensureDeferred(
  94. self.storage.state.get_state_groups(self.room, [e2.event_id])
  95. )
  96. self.assertEqual(len(state_group_map), 1)
  97. state_list = list(state_group_map.values())[0]
  98. self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
  99. @defer.inlineCallbacks
  100. def test_get_state_for_event(self):
  101. # this defaults to a linear DAG as each new injection defaults to whatever
  102. # forward extremities are currently in the DB for this room.
  103. e1 = yield self.inject_state_event(
  104. self.room, self.u_alice, EventTypes.Create, "", {}
  105. )
  106. e2 = yield self.inject_state_event(
  107. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  108. )
  109. e3 = yield self.inject_state_event(
  110. self.room,
  111. self.u_alice,
  112. EventTypes.Member,
  113. self.u_alice.to_string(),
  114. {"membership": Membership.JOIN},
  115. )
  116. e4 = yield self.inject_state_event(
  117. self.room,
  118. self.u_bob,
  119. EventTypes.Member,
  120. self.u_bob.to_string(),
  121. {"membership": Membership.JOIN},
  122. )
  123. e5 = yield self.inject_state_event(
  124. self.room,
  125. self.u_bob,
  126. EventTypes.Member,
  127. self.u_bob.to_string(),
  128. {"membership": Membership.LEAVE},
  129. )
  130. # check we get the full state as of the final event
  131. state = yield defer.ensureDeferred(
  132. self.storage.state.get_state_for_event(e5.event_id)
  133. )
  134. self.assertIsNotNone(e4)
  135. self.assertStateMapEqual(
  136. {
  137. (e1.type, e1.state_key): e1,
  138. (e2.type, e2.state_key): e2,
  139. (e3.type, e3.state_key): e3,
  140. # e4 is overwritten by e5
  141. (e5.type, e5.state_key): e5,
  142. },
  143. state,
  144. )
  145. # check we can filter to the m.room.name event (with a '' state key)
  146. state = yield defer.ensureDeferred(
  147. self.storage.state.get_state_for_event(
  148. e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
  149. )
  150. )
  151. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  152. # check we can filter to the m.room.name event (with a wildcard None state key)
  153. state = yield defer.ensureDeferred(
  154. self.storage.state.get_state_for_event(
  155. e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
  156. )
  157. )
  158. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  159. # check we can grab the m.room.member events (with a wildcard None state key)
  160. state = yield defer.ensureDeferred(
  161. self.storage.state.get_state_for_event(
  162. e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
  163. )
  164. )
  165. self.assertStateMapEqual(
  166. {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
  167. )
  168. # check we can grab a specific room member without filtering out the
  169. # other event types
  170. state = yield defer.ensureDeferred(
  171. self.storage.state.get_state_for_event(
  172. e5.event_id,
  173. state_filter=StateFilter(
  174. types={EventTypes.Member: {self.u_alice.to_string()}},
  175. include_others=True,
  176. ),
  177. )
  178. )
  179. self.assertStateMapEqual(
  180. {
  181. (e1.type, e1.state_key): e1,
  182. (e2.type, e2.state_key): e2,
  183. (e3.type, e3.state_key): e3,
  184. },
  185. state,
  186. )
  187. # check that we can grab everything except members
  188. state = yield defer.ensureDeferred(
  189. self.storage.state.get_state_for_event(
  190. e5.event_id,
  191. state_filter=StateFilter(
  192. types={EventTypes.Member: set()}, include_others=True
  193. ),
  194. )
  195. )
  196. self.assertStateMapEqual(
  197. {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
  198. )
  199. #######################################################
  200. # _get_state_for_group_using_cache tests against a full cache
  201. #######################################################
  202. room_id = self.room.to_string()
  203. group_ids = yield defer.ensureDeferred(
  204. self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
  205. )
  206. group = list(group_ids.keys())[0]
  207. # test _get_state_for_group_using_cache correctly filters out members
  208. # with types=[]
  209. (
  210. state_dict,
  211. is_all,
  212. ) = yield self.state_datastore._get_state_for_group_using_cache(
  213. self.state_datastore._state_group_cache,
  214. group,
  215. state_filter=StateFilter(
  216. types={EventTypes.Member: set()}, include_others=True
  217. ),
  218. )
  219. self.assertEqual(is_all, True)
  220. self.assertDictEqual(
  221. {
  222. (e1.type, e1.state_key): e1.event_id,
  223. (e2.type, e2.state_key): e2.event_id,
  224. },
  225. state_dict,
  226. )
  227. (
  228. state_dict,
  229. is_all,
  230. ) = yield self.state_datastore._get_state_for_group_using_cache(
  231. self.state_datastore._state_group_members_cache,
  232. group,
  233. state_filter=StateFilter(
  234. types={EventTypes.Member: set()}, include_others=True
  235. ),
  236. )
  237. self.assertEqual(is_all, True)
  238. self.assertDictEqual({}, state_dict)
  239. # test _get_state_for_group_using_cache correctly filters in members
  240. # with wildcard types
  241. (
  242. state_dict,
  243. is_all,
  244. ) = yield self.state_datastore._get_state_for_group_using_cache(
  245. self.state_datastore._state_group_cache,
  246. group,
  247. state_filter=StateFilter(
  248. types={EventTypes.Member: None}, include_others=True
  249. ),
  250. )
  251. self.assertEqual(is_all, True)
  252. self.assertDictEqual(
  253. {
  254. (e1.type, e1.state_key): e1.event_id,
  255. (e2.type, e2.state_key): e2.event_id,
  256. },
  257. state_dict,
  258. )
  259. (
  260. state_dict,
  261. is_all,
  262. ) = yield self.state_datastore._get_state_for_group_using_cache(
  263. self.state_datastore._state_group_members_cache,
  264. group,
  265. state_filter=StateFilter(
  266. types={EventTypes.Member: None}, include_others=True
  267. ),
  268. )
  269. self.assertEqual(is_all, True)
  270. self.assertDictEqual(
  271. {
  272. (e3.type, e3.state_key): e3.event_id,
  273. # e4 is overwritten by e5
  274. (e5.type, e5.state_key): e5.event_id,
  275. },
  276. state_dict,
  277. )
  278. # test _get_state_for_group_using_cache correctly filters in members
  279. # with specific types
  280. (
  281. state_dict,
  282. is_all,
  283. ) = yield self.state_datastore._get_state_for_group_using_cache(
  284. self.state_datastore._state_group_cache,
  285. group,
  286. state_filter=StateFilter(
  287. types={EventTypes.Member: {e5.state_key}}, include_others=True
  288. ),
  289. )
  290. self.assertEqual(is_all, True)
  291. self.assertDictEqual(
  292. {
  293. (e1.type, e1.state_key): e1.event_id,
  294. (e2.type, e2.state_key): e2.event_id,
  295. },
  296. state_dict,
  297. )
  298. (
  299. state_dict,
  300. is_all,
  301. ) = yield self.state_datastore._get_state_for_group_using_cache(
  302. self.state_datastore._state_group_members_cache,
  303. group,
  304. state_filter=StateFilter(
  305. types={EventTypes.Member: {e5.state_key}}, include_others=True
  306. ),
  307. )
  308. self.assertEqual(is_all, True)
  309. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  310. # test _get_state_for_group_using_cache correctly filters in members
  311. # with specific types
  312. (
  313. state_dict,
  314. is_all,
  315. ) = yield self.state_datastore._get_state_for_group_using_cache(
  316. self.state_datastore._state_group_members_cache,
  317. group,
  318. state_filter=StateFilter(
  319. types={EventTypes.Member: {e5.state_key}}, include_others=False
  320. ),
  321. )
  322. self.assertEqual(is_all, True)
  323. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  324. #######################################################
  325. # deliberately remove e2 (room name) from the _state_group_cache
  326. (
  327. is_all,
  328. known_absent,
  329. state_dict_ids,
  330. ) = self.state_datastore._state_group_cache.get(group)
  331. self.assertEqual(is_all, True)
  332. self.assertEqual(known_absent, set())
  333. self.assertDictEqual(
  334. state_dict_ids,
  335. {
  336. (e1.type, e1.state_key): e1.event_id,
  337. (e2.type, e2.state_key): e2.event_id,
  338. },
  339. )
  340. state_dict_ids.pop((e2.type, e2.state_key))
  341. self.state_datastore._state_group_cache.invalidate(group)
  342. self.state_datastore._state_group_cache.update(
  343. sequence=self.state_datastore._state_group_cache.sequence,
  344. key=group,
  345. value=state_dict_ids,
  346. # list fetched keys so it knows it's partial
  347. fetched_keys=((e1.type, e1.state_key),),
  348. )
  349. (
  350. is_all,
  351. known_absent,
  352. state_dict_ids,
  353. ) = self.state_datastore._state_group_cache.get(group)
  354. self.assertEqual(is_all, False)
  355. self.assertEqual(known_absent, {(e1.type, e1.state_key)})
  356. self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
  357. ############################################
  358. # test that things work with a partial cache
  359. # test _get_state_for_group_using_cache correctly filters out members
  360. # with types=[]
  361. room_id = self.room.to_string()
  362. (
  363. state_dict,
  364. is_all,
  365. ) = yield self.state_datastore._get_state_for_group_using_cache(
  366. self.state_datastore._state_group_cache,
  367. group,
  368. state_filter=StateFilter(
  369. types={EventTypes.Member: set()}, include_others=True
  370. ),
  371. )
  372. self.assertEqual(is_all, False)
  373. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  374. room_id = self.room.to_string()
  375. (
  376. state_dict,
  377. is_all,
  378. ) = yield self.state_datastore._get_state_for_group_using_cache(
  379. self.state_datastore._state_group_members_cache,
  380. group,
  381. state_filter=StateFilter(
  382. types={EventTypes.Member: set()}, include_others=True
  383. ),
  384. )
  385. self.assertEqual(is_all, True)
  386. self.assertDictEqual({}, state_dict)
  387. # test _get_state_for_group_using_cache correctly filters in members
  388. # wildcard types
  389. (
  390. state_dict,
  391. is_all,
  392. ) = yield self.state_datastore._get_state_for_group_using_cache(
  393. self.state_datastore._state_group_cache,
  394. group,
  395. state_filter=StateFilter(
  396. types={EventTypes.Member: None}, include_others=True
  397. ),
  398. )
  399. self.assertEqual(is_all, False)
  400. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  401. (
  402. state_dict,
  403. is_all,
  404. ) = yield self.state_datastore._get_state_for_group_using_cache(
  405. self.state_datastore._state_group_members_cache,
  406. group,
  407. state_filter=StateFilter(
  408. types={EventTypes.Member: None}, include_others=True
  409. ),
  410. )
  411. self.assertEqual(is_all, True)
  412. self.assertDictEqual(
  413. {
  414. (e3.type, e3.state_key): e3.event_id,
  415. (e5.type, e5.state_key): e5.event_id,
  416. },
  417. state_dict,
  418. )
  419. # test _get_state_for_group_using_cache correctly filters in members
  420. # with specific types
  421. (
  422. state_dict,
  423. is_all,
  424. ) = yield self.state_datastore._get_state_for_group_using_cache(
  425. self.state_datastore._state_group_cache,
  426. group,
  427. state_filter=StateFilter(
  428. types={EventTypes.Member: {e5.state_key}}, include_others=True
  429. ),
  430. )
  431. self.assertEqual(is_all, False)
  432. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  433. (
  434. state_dict,
  435. is_all,
  436. ) = yield self.state_datastore._get_state_for_group_using_cache(
  437. self.state_datastore._state_group_members_cache,
  438. group,
  439. state_filter=StateFilter(
  440. types={EventTypes.Member: {e5.state_key}}, include_others=True
  441. ),
  442. )
  443. self.assertEqual(is_all, True)
  444. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  445. # test _get_state_for_group_using_cache correctly filters in members
  446. # with specific types
  447. (
  448. state_dict,
  449. is_all,
  450. ) = yield self.state_datastore._get_state_for_group_using_cache(
  451. self.state_datastore._state_group_cache,
  452. group,
  453. state_filter=StateFilter(
  454. types={EventTypes.Member: {e5.state_key}}, include_others=False
  455. ),
  456. )
  457. self.assertEqual(is_all, False)
  458. self.assertDictEqual({}, state_dict)
  459. (
  460. state_dict,
  461. is_all,
  462. ) = yield self.state_datastore._get_state_for_group_using_cache(
  463. self.state_datastore._state_group_members_cache,
  464. group,
  465. state_filter=StateFilter(
  466. types={EventTypes.Member: {e5.state_key}}, include_others=False
  467. ),
  468. )
  469. self.assertEqual(is_all, True)
  470. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)