test_state.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2018 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from twisted.internet import defer
  17. from synapse.api.constants import EventTypes, Membership
  18. from synapse.api.room_versions import RoomVersions
  19. from synapse.storage.state import StateFilter
  20. from synapse.types import RoomID, UserID
  21. import tests.unittest
  22. import tests.utils
  23. logger = logging.getLogger(__name__)
  24. class StateStoreTestCase(tests.unittest.TestCase):
  25. @defer.inlineCallbacks
  26. def setUp(self):
  27. hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
  28. self.store = hs.get_datastore()
  29. self.storage = hs.get_storage()
  30. self.state_datastore = self.store
  31. self.event_builder_factory = hs.get_event_builder_factory()
  32. self.event_creation_handler = hs.get_event_creation_handler()
  33. self.u_alice = UserID.from_string("@alice:test")
  34. self.u_bob = UserID.from_string("@bob:test")
  35. self.room = RoomID.from_string("!abc123:test")
  36. yield self.store.store_room(
  37. self.room.to_string(), room_creator_user_id="@creator:text", is_public=True
  38. )
  39. @defer.inlineCallbacks
  40. def inject_state_event(self, room, sender, typ, state_key, content):
  41. builder = self.event_builder_factory.for_room_version(
  42. RoomVersions.V1,
  43. {
  44. "type": typ,
  45. "sender": sender.to_string(),
  46. "state_key": state_key,
  47. "room_id": room.to_string(),
  48. "content": content,
  49. },
  50. )
  51. event, context = yield self.event_creation_handler.create_new_client_event(
  52. builder
  53. )
  54. yield self.storage.persistence.persist_event(event, context)
  55. return event
  56. def assertStateMapEqual(self, s1, s2):
  57. for t in s1:
  58. # just compare event IDs for simplicity
  59. self.assertEqual(s1[t].event_id, s2[t].event_id)
  60. self.assertEqual(len(s1), len(s2))
  61. @defer.inlineCallbacks
  62. def test_get_state_groups_ids(self):
  63. e1 = yield self.inject_state_event(
  64. self.room, self.u_alice, EventTypes.Create, "", {}
  65. )
  66. e2 = yield self.inject_state_event(
  67. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  68. )
  69. state_group_map = yield self.storage.state.get_state_groups_ids(
  70. self.room, [e2.event_id]
  71. )
  72. self.assertEqual(len(state_group_map), 1)
  73. state_map = list(state_group_map.values())[0]
  74. self.assertDictEqual(
  75. state_map,
  76. {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
  77. )
  78. @defer.inlineCallbacks
  79. def test_get_state_groups(self):
  80. e1 = yield self.inject_state_event(
  81. self.room, self.u_alice, EventTypes.Create, "", {}
  82. )
  83. e2 = yield self.inject_state_event(
  84. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  85. )
  86. state_group_map = yield self.storage.state.get_state_groups(
  87. self.room, [e2.event_id]
  88. )
  89. self.assertEqual(len(state_group_map), 1)
  90. state_list = list(state_group_map.values())[0]
  91. self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
  92. @defer.inlineCallbacks
  93. def test_get_state_for_event(self):
  94. # this defaults to a linear DAG as each new injection defaults to whatever
  95. # forward extremities are currently in the DB for this room.
  96. e1 = yield self.inject_state_event(
  97. self.room, self.u_alice, EventTypes.Create, "", {}
  98. )
  99. e2 = yield self.inject_state_event(
  100. self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
  101. )
  102. e3 = yield self.inject_state_event(
  103. self.room,
  104. self.u_alice,
  105. EventTypes.Member,
  106. self.u_alice.to_string(),
  107. {"membership": Membership.JOIN},
  108. )
  109. e4 = yield self.inject_state_event(
  110. self.room,
  111. self.u_bob,
  112. EventTypes.Member,
  113. self.u_bob.to_string(),
  114. {"membership": Membership.JOIN},
  115. )
  116. e5 = yield self.inject_state_event(
  117. self.room,
  118. self.u_bob,
  119. EventTypes.Member,
  120. self.u_bob.to_string(),
  121. {"membership": Membership.LEAVE},
  122. )
  123. # check we get the full state as of the final event
  124. state = yield self.storage.state.get_state_for_event(e5.event_id)
  125. self.assertIsNotNone(e4)
  126. self.assertStateMapEqual(
  127. {
  128. (e1.type, e1.state_key): e1,
  129. (e2.type, e2.state_key): e2,
  130. (e3.type, e3.state_key): e3,
  131. # e4 is overwritten by e5
  132. (e5.type, e5.state_key): e5,
  133. },
  134. state,
  135. )
  136. # check we can filter to the m.room.name event (with a '' state key)
  137. state = yield self.storage.state.get_state_for_event(
  138. e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
  139. )
  140. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  141. # check we can filter to the m.room.name event (with a wildcard None state key)
  142. state = yield self.storage.state.get_state_for_event(
  143. e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
  144. )
  145. self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
  146. # check we can grab the m.room.member events (with a wildcard None state key)
  147. state = yield self.storage.state.get_state_for_event(
  148. e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
  149. )
  150. self.assertStateMapEqual(
  151. {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
  152. )
  153. # check we can grab a specific room member without filtering out the
  154. # other event types
  155. state = yield self.storage.state.get_state_for_event(
  156. e5.event_id,
  157. state_filter=StateFilter(
  158. types={EventTypes.Member: {self.u_alice.to_string()}},
  159. include_others=True,
  160. ),
  161. )
  162. self.assertStateMapEqual(
  163. {
  164. (e1.type, e1.state_key): e1,
  165. (e2.type, e2.state_key): e2,
  166. (e3.type, e3.state_key): e3,
  167. },
  168. state,
  169. )
  170. # check that we can grab everything except members
  171. state = yield self.storage.state.get_state_for_event(
  172. e5.event_id,
  173. state_filter=StateFilter(
  174. types={EventTypes.Member: set()}, include_others=True
  175. ),
  176. )
  177. self.assertStateMapEqual(
  178. {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
  179. )
  180. #######################################################
  181. # _get_state_for_group_using_cache tests against a full cache
  182. #######################################################
  183. room_id = self.room.to_string()
  184. group_ids = yield self.storage.state.get_state_groups_ids(
  185. room_id, [e5.event_id]
  186. )
  187. group = list(group_ids.keys())[0]
  188. # test _get_state_for_group_using_cache correctly filters out members
  189. # with types=[]
  190. (
  191. state_dict,
  192. is_all,
  193. ) = yield self.state_datastore._get_state_for_group_using_cache(
  194. self.state_datastore._state_group_cache,
  195. group,
  196. state_filter=StateFilter(
  197. types={EventTypes.Member: set()}, include_others=True
  198. ),
  199. )
  200. self.assertEqual(is_all, True)
  201. self.assertDictEqual(
  202. {
  203. (e1.type, e1.state_key): e1.event_id,
  204. (e2.type, e2.state_key): e2.event_id,
  205. },
  206. state_dict,
  207. )
  208. (
  209. state_dict,
  210. is_all,
  211. ) = yield self.state_datastore._get_state_for_group_using_cache(
  212. self.state_datastore._state_group_members_cache,
  213. group,
  214. state_filter=StateFilter(
  215. types={EventTypes.Member: set()}, include_others=True
  216. ),
  217. )
  218. self.assertEqual(is_all, True)
  219. self.assertDictEqual({}, state_dict)
  220. # test _get_state_for_group_using_cache correctly filters in members
  221. # with wildcard types
  222. (
  223. state_dict,
  224. is_all,
  225. ) = yield self.state_datastore._get_state_for_group_using_cache(
  226. self.state_datastore._state_group_cache,
  227. group,
  228. state_filter=StateFilter(
  229. types={EventTypes.Member: None}, include_others=True
  230. ),
  231. )
  232. self.assertEqual(is_all, True)
  233. self.assertDictEqual(
  234. {
  235. (e1.type, e1.state_key): e1.event_id,
  236. (e2.type, e2.state_key): e2.event_id,
  237. },
  238. state_dict,
  239. )
  240. (
  241. state_dict,
  242. is_all,
  243. ) = yield self.state_datastore._get_state_for_group_using_cache(
  244. self.state_datastore._state_group_members_cache,
  245. group,
  246. state_filter=StateFilter(
  247. types={EventTypes.Member: None}, include_others=True
  248. ),
  249. )
  250. self.assertEqual(is_all, True)
  251. self.assertDictEqual(
  252. {
  253. (e3.type, e3.state_key): e3.event_id,
  254. # e4 is overwritten by e5
  255. (e5.type, e5.state_key): e5.event_id,
  256. },
  257. state_dict,
  258. )
  259. # test _get_state_for_group_using_cache correctly filters in members
  260. # with specific types
  261. (
  262. state_dict,
  263. is_all,
  264. ) = yield self.state_datastore._get_state_for_group_using_cache(
  265. self.state_datastore._state_group_cache,
  266. group,
  267. state_filter=StateFilter(
  268. types={EventTypes.Member: {e5.state_key}}, include_others=True
  269. ),
  270. )
  271. self.assertEqual(is_all, True)
  272. self.assertDictEqual(
  273. {
  274. (e1.type, e1.state_key): e1.event_id,
  275. (e2.type, e2.state_key): e2.event_id,
  276. },
  277. state_dict,
  278. )
  279. (
  280. state_dict,
  281. is_all,
  282. ) = yield self.state_datastore._get_state_for_group_using_cache(
  283. self.state_datastore._state_group_members_cache,
  284. group,
  285. state_filter=StateFilter(
  286. types={EventTypes.Member: {e5.state_key}}, include_others=True
  287. ),
  288. )
  289. self.assertEqual(is_all, True)
  290. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  291. # test _get_state_for_group_using_cache correctly filters in members
  292. # with specific types
  293. (
  294. state_dict,
  295. is_all,
  296. ) = yield self.state_datastore._get_state_for_group_using_cache(
  297. self.state_datastore._state_group_members_cache,
  298. group,
  299. state_filter=StateFilter(
  300. types={EventTypes.Member: {e5.state_key}}, include_others=False
  301. ),
  302. )
  303. self.assertEqual(is_all, True)
  304. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  305. #######################################################
  306. # deliberately remove e2 (room name) from the _state_group_cache
  307. (
  308. is_all,
  309. known_absent,
  310. state_dict_ids,
  311. ) = self.state_datastore._state_group_cache.get(group)
  312. self.assertEqual(is_all, True)
  313. self.assertEqual(known_absent, set())
  314. self.assertDictEqual(
  315. state_dict_ids,
  316. {
  317. (e1.type, e1.state_key): e1.event_id,
  318. (e2.type, e2.state_key): e2.event_id,
  319. },
  320. )
  321. state_dict_ids.pop((e2.type, e2.state_key))
  322. self.state_datastore._state_group_cache.invalidate(group)
  323. self.state_datastore._state_group_cache.update(
  324. sequence=self.state_datastore._state_group_cache.sequence,
  325. key=group,
  326. value=state_dict_ids,
  327. # list fetched keys so it knows it's partial
  328. fetched_keys=((e1.type, e1.state_key),),
  329. )
  330. (
  331. is_all,
  332. known_absent,
  333. state_dict_ids,
  334. ) = self.state_datastore._state_group_cache.get(group)
  335. self.assertEqual(is_all, False)
  336. self.assertEqual(known_absent, set([(e1.type, e1.state_key)]))
  337. self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
  338. ############################################
  339. # test that things work with a partial cache
  340. # test _get_state_for_group_using_cache correctly filters out members
  341. # with types=[]
  342. room_id = self.room.to_string()
  343. (
  344. state_dict,
  345. is_all,
  346. ) = yield self.state_datastore._get_state_for_group_using_cache(
  347. self.state_datastore._state_group_cache,
  348. group,
  349. state_filter=StateFilter(
  350. types={EventTypes.Member: set()}, include_others=True
  351. ),
  352. )
  353. self.assertEqual(is_all, False)
  354. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  355. room_id = self.room.to_string()
  356. (
  357. state_dict,
  358. is_all,
  359. ) = yield self.state_datastore._get_state_for_group_using_cache(
  360. self.state_datastore._state_group_members_cache,
  361. group,
  362. state_filter=StateFilter(
  363. types={EventTypes.Member: set()}, include_others=True
  364. ),
  365. )
  366. self.assertEqual(is_all, True)
  367. self.assertDictEqual({}, state_dict)
  368. # test _get_state_for_group_using_cache correctly filters in members
  369. # wildcard types
  370. (
  371. state_dict,
  372. is_all,
  373. ) = yield self.state_datastore._get_state_for_group_using_cache(
  374. self.state_datastore._state_group_cache,
  375. group,
  376. state_filter=StateFilter(
  377. types={EventTypes.Member: None}, include_others=True
  378. ),
  379. )
  380. self.assertEqual(is_all, False)
  381. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  382. (
  383. state_dict,
  384. is_all,
  385. ) = yield self.state_datastore._get_state_for_group_using_cache(
  386. self.state_datastore._state_group_members_cache,
  387. group,
  388. state_filter=StateFilter(
  389. types={EventTypes.Member: None}, include_others=True
  390. ),
  391. )
  392. self.assertEqual(is_all, True)
  393. self.assertDictEqual(
  394. {
  395. (e3.type, e3.state_key): e3.event_id,
  396. (e5.type, e5.state_key): e5.event_id,
  397. },
  398. state_dict,
  399. )
  400. # test _get_state_for_group_using_cache correctly filters in members
  401. # with specific types
  402. (
  403. state_dict,
  404. is_all,
  405. ) = yield self.state_datastore._get_state_for_group_using_cache(
  406. self.state_datastore._state_group_cache,
  407. group,
  408. state_filter=StateFilter(
  409. types={EventTypes.Member: {e5.state_key}}, include_others=True
  410. ),
  411. )
  412. self.assertEqual(is_all, False)
  413. self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)
  414. (
  415. state_dict,
  416. is_all,
  417. ) = yield self.state_datastore._get_state_for_group_using_cache(
  418. self.state_datastore._state_group_members_cache,
  419. group,
  420. state_filter=StateFilter(
  421. types={EventTypes.Member: {e5.state_key}}, include_others=True
  422. ),
  423. )
  424. self.assertEqual(is_all, True)
  425. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
  426. # test _get_state_for_group_using_cache correctly filters in members
  427. # with specific types
  428. (
  429. state_dict,
  430. is_all,
  431. ) = yield self.state_datastore._get_state_for_group_using_cache(
  432. self.state_datastore._state_group_cache,
  433. group,
  434. state_filter=StateFilter(
  435. types={EventTypes.Member: {e5.state_key}}, include_others=False
  436. ),
  437. )
  438. self.assertEqual(is_all, False)
  439. self.assertDictEqual({}, state_dict)
  440. (
  441. state_dict,
  442. is_all,
  443. ) = yield self.state_datastore._get_state_for_group_using_cache(
  444. self.state_datastore._state_group_members_cache,
  445. group,
  446. state_filter=StateFilter(
  447. types={EventTypes.Member: {e5.state_key}}, include_others=False
  448. ),
  449. )
  450. self.assertEqual(is_all, True)
  451. self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)