test_event_chain.py 25 KB


# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import Dict, List, Set, Tuple
  15. from twisted.trial import unittest
  16. from synapse.api.constants import EventTypes
  17. from synapse.api.room_versions import RoomVersions
  18. from synapse.events import EventBase
  19. from synapse.events.snapshot import EventContext
  20. from synapse.rest import admin
  21. from synapse.rest.client import login, room
  22. from synapse.storage.databases.main.events import _LinkMap
  23. from synapse.types import create_requester
  24. from tests.unittest import HomeserverTestCase
  25. class EventChainStoreTestCase(HomeserverTestCase):
  26. def prepare(self, reactor, clock, hs):
  27. self.store = hs.get_datastores().main
  28. self._next_stream_ordering = 1
  29. def test_simple(self):
  30. """Test that the example in `docs/auth_chain_difference_algorithm.md`
  31. works.
  32. """
  33. event_factory = self.hs.get_event_builder_factory()
  34. bob = "@creator:test"
  35. alice = "@alice:test"
  36. room_id = "!room:test"
  37. # Ensure that we have a rooms entry so that we generate the chain index.
  38. self.get_success(
  39. self.store.store_room(
  40. room_id=room_id,
  41. room_creator_user_id="",
  42. is_public=True,
  43. room_version=RoomVersions.V6,
  44. )
  45. )
  46. create = self.get_success(
  47. event_factory.for_room_version(
  48. RoomVersions.V6,
  49. {
  50. "type": EventTypes.Create,
  51. "state_key": "",
  52. "sender": bob,
  53. "room_id": room_id,
  54. "content": {"tag": "create"},
  55. },
  56. ).build(prev_event_ids=[], auth_event_ids=[])
  57. )
  58. bob_join = self.get_success(
  59. event_factory.for_room_version(
  60. RoomVersions.V6,
  61. {
  62. "type": EventTypes.Member,
  63. "state_key": bob,
  64. "sender": bob,
  65. "room_id": room_id,
  66. "content": {"tag": "bob_join"},
  67. },
  68. ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
  69. )
  70. power = self.get_success(
  71. event_factory.for_room_version(
  72. RoomVersions.V6,
  73. {
  74. "type": EventTypes.PowerLevels,
  75. "state_key": "",
  76. "sender": bob,
  77. "room_id": room_id,
  78. "content": {"tag": "power"},
  79. },
  80. ).build(
  81. prev_event_ids=[],
  82. auth_event_ids=[create.event_id, bob_join.event_id],
  83. )
  84. )
  85. alice_invite = self.get_success(
  86. event_factory.for_room_version(
  87. RoomVersions.V6,
  88. {
  89. "type": EventTypes.Member,
  90. "state_key": alice,
  91. "sender": bob,
  92. "room_id": room_id,
  93. "content": {"tag": "alice_invite"},
  94. },
  95. ).build(
  96. prev_event_ids=[],
  97. auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
  98. )
  99. )
  100. alice_join = self.get_success(
  101. event_factory.for_room_version(
  102. RoomVersions.V6,
  103. {
  104. "type": EventTypes.Member,
  105. "state_key": alice,
  106. "sender": alice,
  107. "room_id": room_id,
  108. "content": {"tag": "alice_join"},
  109. },
  110. ).build(
  111. prev_event_ids=[],
  112. auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
  113. )
  114. )
  115. power_2 = self.get_success(
  116. event_factory.for_room_version(
  117. RoomVersions.V6,
  118. {
  119. "type": EventTypes.PowerLevels,
  120. "state_key": "",
  121. "sender": bob,
  122. "room_id": room_id,
  123. "content": {"tag": "power_2"},
  124. },
  125. ).build(
  126. prev_event_ids=[],
  127. auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
  128. )
  129. )
  130. bob_join_2 = self.get_success(
  131. event_factory.for_room_version(
  132. RoomVersions.V6,
  133. {
  134. "type": EventTypes.Member,
  135. "state_key": bob,
  136. "sender": bob,
  137. "room_id": room_id,
  138. "content": {"tag": "bob_join_2"},
  139. },
  140. ).build(
  141. prev_event_ids=[],
  142. auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
  143. )
  144. )
  145. alice_join2 = self.get_success(
  146. event_factory.for_room_version(
  147. RoomVersions.V6,
  148. {
  149. "type": EventTypes.Member,
  150. "state_key": alice,
  151. "sender": alice,
  152. "room_id": room_id,
  153. "content": {"tag": "alice_join2"},
  154. },
  155. ).build(
  156. prev_event_ids=[],
  157. auth_event_ids=[
  158. create.event_id,
  159. alice_join.event_id,
  160. power_2.event_id,
  161. ],
  162. )
  163. )
  164. events = [
  165. create,
  166. bob_join,
  167. power,
  168. alice_invite,
  169. alice_join,
  170. bob_join_2,
  171. power_2,
  172. alice_join2,
  173. ]
  174. expected_links = [
  175. (bob_join, create),
  176. (power, create),
  177. (power, bob_join),
  178. (alice_invite, create),
  179. (alice_invite, power),
  180. (alice_invite, bob_join),
  181. (bob_join_2, power),
  182. (alice_join2, power_2),
  183. ]
  184. self.persist(events)
  185. chain_map, link_map = self.fetch_chains(events)
  186. # Check that the expected links and only the expected links have been
  187. # added.
  188. self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
  189. for start, end in expected_links:
  190. start_id, start_seq = chain_map[start.event_id]
  191. end_id, end_seq = chain_map[end.event_id]
  192. self.assertIn(
  193. (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
  194. )
  195. # Test that everything can reach the create event, but the create event
  196. # can't reach anything.
  197. for event in events[1:]:
  198. self.assertTrue(
  199. link_map.exists_path_from(
  200. chain_map[event.event_id], chain_map[create.event_id]
  201. ),
  202. )
  203. self.assertFalse(
  204. link_map.exists_path_from(
  205. chain_map[create.event_id],
  206. chain_map[event.event_id],
  207. ),
  208. )
  209. def test_out_of_order_events(self):
  210. """Test that we handle persisting events that we don't have the full
  211. auth chain for yet (which should only happen for out of band memberships).
  212. """
  213. event_factory = self.hs.get_event_builder_factory()
  214. bob = "@creator:test"
  215. alice = "@alice:test"
  216. room_id = "!room:test"
  217. # Ensure that we have a rooms entry so that we generate the chain index.
  218. self.get_success(
  219. self.store.store_room(
  220. room_id=room_id,
  221. room_creator_user_id="",
  222. is_public=True,
  223. room_version=RoomVersions.V6,
  224. )
  225. )
  226. # First persist the base room.
  227. create = self.get_success(
  228. event_factory.for_room_version(
  229. RoomVersions.V6,
  230. {
  231. "type": EventTypes.Create,
  232. "state_key": "",
  233. "sender": bob,
  234. "room_id": room_id,
  235. "content": {"tag": "create"},
  236. },
  237. ).build(prev_event_ids=[], auth_event_ids=[])
  238. )
  239. bob_join = self.get_success(
  240. event_factory.for_room_version(
  241. RoomVersions.V6,
  242. {
  243. "type": EventTypes.Member,
  244. "state_key": bob,
  245. "sender": bob,
  246. "room_id": room_id,
  247. "content": {"tag": "bob_join"},
  248. },
  249. ).build(prev_event_ids=[], auth_event_ids=[create.event_id])
  250. )
  251. power = self.get_success(
  252. event_factory.for_room_version(
  253. RoomVersions.V6,
  254. {
  255. "type": EventTypes.PowerLevels,
  256. "state_key": "",
  257. "sender": bob,
  258. "room_id": room_id,
  259. "content": {"tag": "power"},
  260. },
  261. ).build(
  262. prev_event_ids=[],
  263. auth_event_ids=[create.event_id, bob_join.event_id],
  264. )
  265. )
  266. self.persist([create, bob_join, power])
  267. # Now persist an invite and a couple of memberships out of order.
  268. alice_invite = self.get_success(
  269. event_factory.for_room_version(
  270. RoomVersions.V6,
  271. {
  272. "type": EventTypes.Member,
  273. "state_key": alice,
  274. "sender": bob,
  275. "room_id": room_id,
  276. "content": {"tag": "alice_invite"},
  277. },
  278. ).build(
  279. prev_event_ids=[],
  280. auth_event_ids=[create.event_id, bob_join.event_id, power.event_id],
  281. )
  282. )
  283. alice_join = self.get_success(
  284. event_factory.for_room_version(
  285. RoomVersions.V6,
  286. {
  287. "type": EventTypes.Member,
  288. "state_key": alice,
  289. "sender": alice,
  290. "room_id": room_id,
  291. "content": {"tag": "alice_join"},
  292. },
  293. ).build(
  294. prev_event_ids=[],
  295. auth_event_ids=[create.event_id, alice_invite.event_id, power.event_id],
  296. )
  297. )
  298. alice_join2 = self.get_success(
  299. event_factory.for_room_version(
  300. RoomVersions.V6,
  301. {
  302. "type": EventTypes.Member,
  303. "state_key": alice,
  304. "sender": alice,
  305. "room_id": room_id,
  306. "content": {"tag": "alice_join2"},
  307. },
  308. ).build(
  309. prev_event_ids=[],
  310. auth_event_ids=[create.event_id, alice_join.event_id, power.event_id],
  311. )
  312. )
  313. self.persist([alice_join])
  314. self.persist([alice_join2])
  315. self.persist([alice_invite])
  316. # The end result should be sane.
  317. events = [create, bob_join, power, alice_invite, alice_join]
  318. chain_map, link_map = self.fetch_chains(events)
  319. expected_links = [
  320. (bob_join, create),
  321. (power, create),
  322. (power, bob_join),
  323. (alice_invite, create),
  324. (alice_invite, power),
  325. (alice_invite, bob_join),
  326. ]
  327. # Check that the expected links and only the expected links have been
  328. # added.
  329. self.assertEqual(len(expected_links), len(list(link_map.get_additions())))
  330. for start, end in expected_links:
  331. start_id, start_seq = chain_map[start.event_id]
  332. end_id, end_seq = chain_map[end.event_id]
  333. self.assertIn(
  334. (start_seq, end_seq), list(link_map.get_links_between(start_id, end_id))
  335. )
  336. def persist(
  337. self,
  338. events: List[EventBase],
  339. ):
  340. """Persist the given events and check that the links generated match
  341. those given.
  342. """
  343. persist_events_store = self.hs.get_datastores().persist_events
  344. for e in events:
  345. e.internal_metadata.stream_ordering = self._next_stream_ordering
  346. self._next_stream_ordering += 1
  347. def _persist(txn):
  348. # We need to persist the events to the events and state_events
  349. # tables.
  350. persist_events_store._store_event_txn(
  351. txn,
  352. [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
  353. )
  354. # Actually call the function that calculates the auth chain stuff.
  355. persist_events_store._persist_event_auth_chain_txn(txn, events)
  356. self.get_success(
  357. persist_events_store.db_pool.runInteraction(
  358. "_persist",
  359. _persist,
  360. )
  361. )
  362. def fetch_chains(
  363. self, events: List[EventBase]
  364. ) -> Tuple[Dict[str, Tuple[int, int]], _LinkMap]:
  365. # Fetch the map from event ID -> (chain ID, sequence number)
  366. rows = self.get_success(
  367. self.store.db_pool.simple_select_many_batch(
  368. table="event_auth_chains",
  369. column="event_id",
  370. iterable=[e.event_id for e in events],
  371. retcols=("event_id", "chain_id", "sequence_number"),
  372. keyvalues={},
  373. )
  374. )
  375. chain_map = {
  376. row["event_id"]: (row["chain_id"], row["sequence_number"]) for row in rows
  377. }
  378. # Fetch all the links and pass them to the _LinkMap.
  379. rows = self.get_success(
  380. self.store.db_pool.simple_select_many_batch(
  381. table="event_auth_chain_links",
  382. column="origin_chain_id",
  383. iterable=[chain_id for chain_id, _ in chain_map.values()],
  384. retcols=(
  385. "origin_chain_id",
  386. "origin_sequence_number",
  387. "target_chain_id",
  388. "target_sequence_number",
  389. ),
  390. keyvalues={},
  391. )
  392. )
  393. link_map = _LinkMap()
  394. for row in rows:
  395. added = link_map.add_link(
  396. (row["origin_chain_id"], row["origin_sequence_number"]),
  397. (row["target_chain_id"], row["target_sequence_number"]),
  398. )
  399. # We shouldn't have persisted any redundant links
  400. self.assertTrue(added)
  401. return chain_map, link_map
  402. class LinkMapTestCase(unittest.TestCase):
  403. def test_simple(self):
  404. """Basic tests for the LinkMap."""
  405. link_map = _LinkMap()
  406. link_map.add_link((1, 1), (2, 1), new=False)
  407. self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
  408. self.assertCountEqual(link_map.get_links_from((1, 1)), [(2, 1)])
  409. self.assertCountEqual(link_map.get_additions(), [])
  410. self.assertTrue(link_map.exists_path_from((1, 5), (2, 1)))
  411. self.assertFalse(link_map.exists_path_from((1, 5), (2, 2)))
  412. self.assertTrue(link_map.exists_path_from((1, 5), (1, 1)))
  413. self.assertFalse(link_map.exists_path_from((1, 1), (1, 5)))
  414. # Attempting to add a redundant link is ignored.
  415. self.assertFalse(link_map.add_link((1, 4), (2, 1)))
  416. self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1)])
  417. # Adding new non-redundant links works
  418. self.assertTrue(link_map.add_link((1, 3), (2, 3)))
  419. self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
  420. self.assertTrue(link_map.add_link((2, 5), (1, 3)))
  421. self.assertCountEqual(link_map.get_links_between(2, 1), [(5, 3)])
  422. self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
  423. self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
  424. class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
  425. servlets = [
  426. admin.register_servlets,
  427. room.register_servlets,
  428. login.register_servlets,
  429. ]
  430. def prepare(self, reactor, clock, hs):
  431. self.store = hs.get_datastores().main
  432. self.user_id = self.register_user("foo", "pass")
  433. self.token = self.login("foo", "pass")
  434. self.requester = create_requester(self.user_id)
  435. def _generate_room(self) -> Tuple[str, List[Set[str]]]:
  436. """Insert a room without a chain cover index."""
  437. room_id = self.helper.create_room_as(self.user_id, tok=self.token)
  438. # Mark the room as not having a chain cover index
  439. self.get_success(
  440. self.store.db_pool.simple_update(
  441. table="rooms",
  442. keyvalues={"room_id": room_id},
  443. updatevalues={"has_auth_chain_index": False},
  444. desc="test",
  445. )
  446. )
  447. # Create a fork in the DAG with different events.
  448. event_handler = self.hs.get_event_creation_handler()
  449. latest_event_ids = self.get_success(
  450. self.store.get_prev_events_for_room(room_id)
  451. )
  452. event, context = self.get_success(
  453. event_handler.create_event(
  454. self.requester,
  455. {
  456. "type": "some_state_type",
  457. "state_key": "",
  458. "content": {},
  459. "room_id": room_id,
  460. "sender": self.user_id,
  461. },
  462. prev_event_ids=latest_event_ids,
  463. )
  464. )
  465. self.get_success(
  466. event_handler.handle_new_client_event(
  467. self.requester, events_and_context=[(event, context)]
  468. )
  469. )
  470. state1 = set(self.get_success(context.get_current_state_ids()).values())
  471. event, context = self.get_success(
  472. event_handler.create_event(
  473. self.requester,
  474. {
  475. "type": "some_state_type",
  476. "state_key": "",
  477. "content": {},
  478. "room_id": room_id,
  479. "sender": self.user_id,
  480. },
  481. prev_event_ids=latest_event_ids,
  482. )
  483. )
  484. self.get_success(
  485. event_handler.handle_new_client_event(
  486. self.requester, events_and_context=[(event, context)]
  487. )
  488. )
  489. state2 = set(self.get_success(context.get_current_state_ids()).values())
  490. # Delete the chain cover info.
  491. def _delete_tables(txn):
  492. txn.execute("DELETE FROM event_auth_chains")
  493. txn.execute("DELETE FROM event_auth_chain_links")
  494. self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
  495. return room_id, [state1, state2]
  496. def test_background_update_single_room(self):
  497. """Test that the background update to calculate auth chains for historic
  498. rooms works correctly.
  499. """
  500. # Create a room
  501. room_id, states = self._generate_room()
  502. # Insert and run the background update.
  503. self.get_success(
  504. self.store.db_pool.simple_insert(
  505. "background_updates",
  506. {"update_name": "chain_cover", "progress_json": "{}"},
  507. )
  508. )
  509. # Ugh, have to reset this flag
  510. self.store.db_pool.updates._all_done = False
  511. self.wait_for_background_updates()
  512. # Test that the `has_auth_chain_index` has been set
  513. self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
  514. # Test that calculating the auth chain difference using the newly
  515. # calculated chain cover works.
  516. self.get_success(
  517. self.store.db_pool.runInteraction(
  518. "test",
  519. self.store._get_auth_chain_difference_using_cover_index_txn,
  520. room_id,
  521. states,
  522. )
  523. )
  524. def test_background_update_multiple_rooms(self):
  525. """Test that the background update to calculate auth chains for historic
  526. rooms works correctly.
  527. """
  528. # Create a room
  529. room_id1, states1 = self._generate_room()
  530. room_id2, states2 = self._generate_room()
  531. room_id3, states2 = self._generate_room()
  532. # Insert and run the background update.
  533. self.get_success(
  534. self.store.db_pool.simple_insert(
  535. "background_updates",
  536. {"update_name": "chain_cover", "progress_json": "{}"},
  537. )
  538. )
  539. # Ugh, have to reset this flag
  540. self.store.db_pool.updates._all_done = False
  541. self.wait_for_background_updates()
  542. # Test that the `has_auth_chain_index` has been set
  543. self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
  544. self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
  545. self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3)))
  546. # Test that calculating the auth chain difference using the newly
  547. # calculated chain cover works.
  548. self.get_success(
  549. self.store.db_pool.runInteraction(
  550. "test",
  551. self.store._get_auth_chain_difference_using_cover_index_txn,
  552. room_id1,
  553. states1,
  554. )
  555. )
  556. def test_background_update_single_large_room(self):
  557. """Test that the background update to calculate auth chains for historic
  558. rooms works correctly.
  559. """
  560. # Create a room
  561. room_id, states = self._generate_room()
  562. # Add a bunch of state so that it takes multiple iterations of the
  563. # background update to process the room.
  564. for i in range(0, 150):
  565. self.helper.send_state(
  566. room_id, event_type="m.test", body={"index": i}, tok=self.token
  567. )
  568. # Insert and run the background update.
  569. self.get_success(
  570. self.store.db_pool.simple_insert(
  571. "background_updates",
  572. {"update_name": "chain_cover", "progress_json": "{}"},
  573. )
  574. )
  575. # Ugh, have to reset this flag
  576. self.store.db_pool.updates._all_done = False
  577. iterations = 0
  578. while not self.get_success(
  579. self.store.db_pool.updates.has_completed_background_updates()
  580. ):
  581. iterations += 1
  582. self.get_success(
  583. self.store.db_pool.updates.do_next_background_update(False), by=0.1
  584. )
  585. # Ensure that we did actually take multiple iterations to process the
  586. # room.
  587. self.assertGreater(iterations, 1)
  588. # Test that the `has_auth_chain_index` has been set
  589. self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
  590. # Test that calculating the auth chain difference using the newly
  591. # calculated chain cover works.
  592. self.get_success(
  593. self.store.db_pool.runInteraction(
  594. "test",
  595. self.store._get_auth_chain_difference_using_cover_index_txn,
  596. room_id,
  597. states,
  598. )
  599. )
  600. def test_background_update_multiple_large_room(self):
  601. """Test that the background update to calculate auth chains for historic
  602. rooms works correctly.
  603. """
  604. # Create the rooms
  605. room_id1, _ = self._generate_room()
  606. room_id2, _ = self._generate_room()
  607. # Add a bunch of state so that it takes multiple iterations of the
  608. # background update to process the room.
  609. for i in range(0, 150):
  610. self.helper.send_state(
  611. room_id1, event_type="m.test", body={"index": i}, tok=self.token
  612. )
  613. for i in range(0, 150):
  614. self.helper.send_state(
  615. room_id2, event_type="m.test", body={"index": i}, tok=self.token
  616. )
  617. # Insert and run the background update.
  618. self.get_success(
  619. self.store.db_pool.simple_insert(
  620. "background_updates",
  621. {"update_name": "chain_cover", "progress_json": "{}"},
  622. )
  623. )
  624. # Ugh, have to reset this flag
  625. self.store.db_pool.updates._all_done = False
  626. iterations = 0
  627. while not self.get_success(
  628. self.store.db_pool.updates.has_completed_background_updates()
  629. ):
  630. iterations += 1
  631. self.get_success(
  632. self.store.db_pool.updates.do_next_background_update(False), by=0.1
  633. )
  634. # Ensure that we did actually take multiple iterations to process the
  635. # room.
  636. self.assertGreater(iterations, 1)
  637. # Test that the `has_auth_chain_index` has been set
  638. self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
  639. self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))