test_state.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2018 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from twisted.internet import defer
  17. from synapse.api.constants import EventTypes, Membership
  18. from synapse.api.room_versions import RoomVersions
  19. from synapse.storage.state import StateFilter
  20. from synapse.types import RoomID, UserID
  21. import tests.unittest
  22. import tests.utils
  # Module-level logger named after this module. NOTE(review): no test in this
  # file currently logs through it.
  logger = logging.getLogger(name=__name__)
  class StateStoreTestCase(tests.unittest.TestCase):
      """Tests for the state storage layer.

      Covers looking up the state groups of events, fetching (optionally
      filtered) state at an event, and reading a state group back out of the
      state group caches when they are fully and partially populated.
      """

      @defer.inlineCallbacks
      def setUp(self):
          """Create a test homeserver, keep handles to its storage, and store a room.

          The room ``!abc123:test`` is stored at room version V1. ``@alice:test``
          and ``@bob:test`` are only constructed here; no events are sent for them.
          """
          hs = yield tests.utils.setup_test_homeserver(self.addCleanup)

          self.store = hs.get_datastore()
          self.storage = hs.get_storage()
          # the state store whose group caches are queried directly in the tests
          self.state_datastore = self.storage.state.stores.state
          self.event_builder_factory = hs.get_event_builder_factory()
          self.event_creation_handler = hs.get_event_creation_handler()

          self.u_alice = UserID.from_string("@alice:test")
          self.u_bob = UserID.from_string("@bob:test")

          self.room = RoomID.from_string("!abc123:test")

          yield self.store.store_room(
              self.room.to_string(),
              # same ":test" server name as every other ID in this class
              room_creator_user_id="@creator:test",
              is_public=True,
              room_version=RoomVersions.V1,
          )

      @defer.inlineCallbacks
      def inject_state_event(self, room, sender, typ, state_key, content):
          """Build a V1 state event, create it via the event creation handler,
          and persist it.

          Args:
              room (RoomID): room the event is sent in
              sender (UserID): user sending the event
              typ (str): event type, e.g. ``EventTypes.Member``
              state_key (str): state key of the event
              content (dict): event content

          Returns:
              Deferred: resolves to the persisted event
          """
          builder = self.event_builder_factory.for_room_version(
              RoomVersions.V1,
              {
                  "type": typ,
                  "sender": sender.to_string(),
                  "state_key": state_key,
                  "room_id": room.to_string(),
                  "content": content,
              },
          )

          event, context = yield self.event_creation_handler.create_new_client_event(
              builder
          )

          yield self.storage.persistence.persist_event(event, context)

          return event

      def assertStateMapEqual(self, s1, s2):
          """Assert two state maps have the same keys mapping to the same event IDs.

          Args:
              s1 (dict): expected state map, ``(type, state_key) -> event``
              s2 (dict): actual state map, ``(type, state_key) -> event``
          """
          # just compare event IDs for simplicity. Comparing whole dicts (rather
          # than indexing s2 with s1's keys) means a key missing from s2 is
          # reported as an assertion failure instead of erroring with KeyError.
          self.assertEqual(
              {key: event.event_id for key, event in s1.items()},
              {key: event.event_id for key, event in s2.items()},
          )

      def _get_state_for_group_from_cache(
          self, cache, group, member_state_keys, include_others=True
      ):
          """Query `cache` for state `group` via _get_state_for_group_using_cache.

          Only m.room.member events whose state key is in `member_state_keys`
          are requested (``None`` is a wildcard for all of them, an empty set
          requests none); other event types are requested iff `include_others`.

          Returns:
              whatever _get_state_for_group_using_cache returns - a
              ``(state_dict, is_all)`` pair. Callers ``yield`` the result, so
              this also works if it is a Deferred resolving to that pair.
          """
          return self.state_datastore._get_state_for_group_using_cache(
              cache,
              group,
              state_filter=StateFilter(
                  types={EventTypes.Member: member_state_keys},
                  include_others=include_others,
              ),
          )

      @defer.inlineCallbacks
      def test_get_state_groups_ids(self):
          """get_state_groups_ids maps the event's state group to its state event IDs."""
          e1 = yield self.inject_state_event(
              self.room, self.u_alice, EventTypes.Create, "", {}
          )
          e2 = yield self.inject_state_event(
              self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
          )

          # pass the room ID as a string, as the other storage calls in this
          # class do, rather than as a RoomID object
          state_group_map = yield self.storage.state.get_state_groups_ids(
              self.room.to_string(), [e2.event_id]
          )
          self.assertEqual(len(state_group_map), 1)
          state_map = list(state_group_map.values())[0]
          self.assertDictEqual(
              state_map,
              {(EventTypes.Create, ""): e1.event_id, (EventTypes.Name, ""): e2.event_id},
          )

      @defer.inlineCallbacks
      def test_get_state_groups(self):
          """get_state_groups maps the event's state group to its state events."""
          e1 = yield self.inject_state_event(
              self.room, self.u_alice, EventTypes.Create, "", {}
          )
          e2 = yield self.inject_state_event(
              self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
          )

          # pass the room ID as a string, as the other storage calls in this
          # class do, rather than as a RoomID object
          state_group_map = yield self.storage.state.get_state_groups(
              self.room.to_string(), [e2.event_id]
          )
          self.assertEqual(len(state_group_map), 1)
          state_list = list(state_group_map.values())[0]

          self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})

      @defer.inlineCallbacks
      def test_get_state_for_event(self):
          """Check get_state_for_event with various StateFilters, then check
          _get_state_for_group_using_cache against full and partial caches.
          """
          # this defaults to a linear DAG as each new injection defaults to whatever
          # forward extremities are currently in the DB for this room.
          e1 = yield self.inject_state_event(
              self.room, self.u_alice, EventTypes.Create, "", {}
          )
          e2 = yield self.inject_state_event(
              self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
          )
          e3 = yield self.inject_state_event(
              self.room,
              self.u_alice,
              EventTypes.Member,
              self.u_alice.to_string(),
              {"membership": Membership.JOIN},
          )
          e4 = yield self.inject_state_event(
              self.room,
              self.u_bob,
              EventTypes.Member,
              self.u_bob.to_string(),
              {"membership": Membership.JOIN},
          )
          e5 = yield self.inject_state_event(
              self.room,
              self.u_bob,
              EventTypes.Member,
              self.u_bob.to_string(),
              {"membership": Membership.LEAVE},
          )

          # check we get the full state as of the final event
          state = yield self.storage.state.get_state_for_event(e5.event_id)

          self.assertIsNotNone(e4)

          self.assertStateMapEqual(
              {
                  (e1.type, e1.state_key): e1,
                  (e2.type, e2.state_key): e2,
                  (e3.type, e3.state_key): e3,
                  # e4 is overwritten by e5
                  (e5.type, e5.state_key): e5,
              },
              state,
          )

          # check we can filter to the m.room.name event (with a '' state key)
          state = yield self.storage.state.get_state_for_event(
              e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
          )

          self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)

          # check we can filter to the m.room.name event (with a wildcard None state key)
          state = yield self.storage.state.get_state_for_event(
              e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
          )

          self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)

          # check we can grab the m.room.member events (with a wildcard None state key)
          state = yield self.storage.state.get_state_for_event(
              e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
          )

          self.assertStateMapEqual(
              {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
          )

          # check we can grab a specific room member without filtering out the
          # other event types
          state = yield self.storage.state.get_state_for_event(
              e5.event_id,
              state_filter=StateFilter(
                  types={EventTypes.Member: {self.u_alice.to_string()}},
                  include_others=True,
              ),
          )

          self.assertStateMapEqual(
              {
                  (e1.type, e1.state_key): e1,
                  (e2.type, e2.state_key): e2,
                  (e3.type, e3.state_key): e3,
              },
              state,
          )

          # check that we can grab everything except members
          state = yield self.storage.state.get_state_for_event(
              e5.event_id,
              state_filter=StateFilter(
                  types={EventTypes.Member: set()}, include_others=True
              ),
          )

          self.assertStateMapEqual(
              {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
          )

          #######################################################
          # _get_state_for_group_using_cache tests against a full cache
          #######################################################

          room_id = self.room.to_string()
          group_ids = yield self.storage.state.get_state_groups_ids(
              room_id, [e5.event_id]
          )
          group = list(group_ids.keys())[0]

          state_cache = self.state_datastore._state_group_cache
          members_cache = self.state_datastore._state_group_members_cache

          # test _get_state_for_group_using_cache correctly filters out members
          # when asked for an empty set of member state keys
          state_dict, is_all = yield self._get_state_for_group_from_cache(
              state_cache, group, set()
          )
          self.assertEqual(is_all, True)
          self.assertDictEqual(
              {
                  (e1.type, e1.state_key): e1.event_id,
                  (e2.type, e2.state_key): e2.event_id,
              },
              state_dict,
          )

          state_dict, is_all = yield self._get_state_for_group_from_cache(
              members_cache, group, set()
          )
          self.assertEqual(is_all, True)
          self.assertDictEqual({}, state_dict)

          # test _get_state_for_group_using_cache correctly filters in members
          # with wildcard types
          state_dict, is_all = yield self._get_state_for_group_from_cache(
              state_cache, group, None
          )
          self.assertEqual(is_all, True)
          self.assertDictEqual(
              {
                  (e1.type, e1.state_key): e1.event_id,
                  (e2.type, e2.state_key): e2.event_id,
              },
              state_dict,
          )

          state_dict, is_all = yield self._get_state_for_group_from_cache(
              members_cache, group, None
          )
          self.assertEqual(is_all, True)
          self.assertDictEqual(
              {
                  (e3.type, e3.state_key): e3.event_id,
                  # e4 is overwritten by e5
                  (e5.type, e5.state_key): e5.event_id,
              },
              state_dict,
          )

          # test _get_state_for_group_using_cache correctly filters in members
          # with specific types
          state_dict, is_all = yield self._get_state_for_group_from_cache(
              state_cache, group, {e5.state_key}
          )
          self.assertEqual(is_all, True)
          self.assertDictEqual(
              {
                  (e1.type, e1.state_key): e1.event_id,
                  (e2.type, e2.state_key): e2.event_id,
              },
              state_dict,
          )

          state_dict, is_all = yield self._get_state_for_group_from_cache(
              members_cache, group, {e5.state_key}
          )
          self.assertEqual(is_all, True)
          self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)

          # test _get_state_for_group_using_cache correctly filters in members
          # with specific types, excluding all other event types
          state_dict, is_all = yield self._get_state_for_group_from_cache(
              members_cache, group, {e5.state_key}, include_others=False
          )
          self.assertEqual(is_all, True)
          self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)

          #######################################################
          # deliberately remove e2 (room name) from the _state_group_cache

          is_all, known_absent, state_dict_ids = state_cache.get(group)

          self.assertEqual(is_all, True)
          self.assertEqual(known_absent, set())
          self.assertDictEqual(
              state_dict_ids,
              {
                  (e1.type, e1.state_key): e1.event_id,
                  (e2.type, e2.state_key): e2.event_id,
              },
          )

          state_dict_ids.pop((e2.type, e2.state_key))
          state_cache.invalidate(group)
          state_cache.update(
              # read the sequence *after* invalidating, so the update is accepted
              sequence=state_cache.sequence,
              key=group,
              value=state_dict_ids,
              # list fetched keys so it knows it's partial
              fetched_keys=((e1.type, e1.state_key),),
          )

          is_all, known_absent, state_dict_ids = state_cache.get(group)

          self.assertEqual(is_all, False)
          self.assertEqual(known_absent, {(e1.type, e1.state_key)})
          self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})

          ############################################
          # test that things work with a partial cache

          # test _get_state_for_group_using_cache correctly filters out members
          # when asked for an empty set of member state keys
          state_dict, is_all = yield self._get_state_for_group_from_cache(
              state_cache, group, set()
          )
          self.assertEqual(is_all, False)
          self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)

          state_dict, is_all = yield self._get_state_for_group_from_cache(
              members_cache, group, set()
          )
          self.assertEqual(is_all, True)
          self.assertDictEqual({}, state_dict)

          # test _get_state_for_group_using_cache correctly filters in members
          # wildcard types
          state_dict, is_all = yield self._get_state_for_group_from_cache(
              state_cache, group, None
          )
          self.assertEqual(is_all, False)
          self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)

          state_dict, is_all = yield self._get_state_for_group_from_cache(
              members_cache, group, None
          )
          self.assertEqual(is_all, True)
          self.assertDictEqual(
              {
                  (e3.type, e3.state_key): e3.event_id,
                  (e5.type, e5.state_key): e5.event_id,
              },
              state_dict,
          )

          # test _get_state_for_group_using_cache correctly filters in members
          # with specific types
          state_dict, is_all = yield self._get_state_for_group_from_cache(
              state_cache, group, {e5.state_key}
          )
          self.assertEqual(is_all, False)
          self.assertDictEqual({(e1.type, e1.state_key): e1.event_id}, state_dict)

          state_dict, is_all = yield self._get_state_for_group_from_cache(
              members_cache, group, {e5.state_key}
          )
          self.assertEqual(is_all, True)
          self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)

          # test _get_state_for_group_using_cache correctly filters in members
          # with specific types, excluding all other event types
          state_dict, is_all = yield self._get_state_for_group_from_cache(
              state_cache, group, {e5.state_key}, include_others=False
          )
          self.assertEqual(is_all, False)
          self.assertDictEqual({}, state_dict)

          state_dict, is_all = yield self._get_state_for_group_from_cache(
              members_cache, group, {e5.state_key}, include_others=False
          )
          self.assertEqual(is_all, True)
          self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)