1
0

test_federation.py 29 KB


  1. # Copyright 2019 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from typing import Collection, Optional, cast
  16. from unittest import TestCase
  17. from unittest.mock import AsyncMock, Mock, patch
  18. from twisted.internet.defer import Deferred
  19. from twisted.test.proto_helpers import MemoryReactor
  20. from synapse.api.constants import EventTypes
  21. from synapse.api.errors import (
  22. AuthError,
  23. Codes,
  24. LimitExceededError,
  25. NotFoundError,
  26. SynapseError,
  27. )
  28. from synapse.api.room_versions import RoomVersions
  29. from synapse.events import EventBase, make_event_from_dict
  30. from synapse.federation.federation_base import event_from_pdu_json
  31. from synapse.federation.federation_client import SendJoinResult
  32. from synapse.logging.context import LoggingContext, run_in_background
  33. from synapse.rest import admin
  34. from synapse.rest.client import login, room
  35. from synapse.server import HomeServer
  36. from synapse.storage.databases.main.events_worker import EventCacheEntry
  37. from synapse.util import Clock
  38. from synapse.util.stringutils import random_string
  39. from tests import unittest
  40. from tests.test_utils import event_injection
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  42. def generate_fake_event_id() -> str:
  43. return "$fake_" + random_string(43)
  44. class FederationTestCase(unittest.FederatingHomeserverTestCase):
  45. servlets = [
  46. admin.register_servlets,
  47. login.register_servlets,
  48. room.register_servlets,
  49. ]
  50. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  51. hs = self.setup_test_homeserver()
  52. self.handler = hs.get_federation_handler()
  53. self.store = hs.get_datastores().main
  54. return hs
  55. def test_exchange_revoked_invite(self) -> None:
  56. user_id = self.register_user("kermit", "test")
  57. tok = self.login("kermit", "test")
  58. room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
  59. # Send a 3PID invite event with an empty body so it's considered as a revoked one.
  60. invite_token = "sometoken"
  61. self.helper.send_state(
  62. room_id=room_id,
  63. event_type=EventTypes.ThirdPartyInvite,
  64. state_key=invite_token,
  65. body={},
  66. tok=tok,
  67. )
  68. d = self.handler.on_exchange_third_party_invite_request(
  69. event_dict={
  70. "type": EventTypes.Member,
  71. "room_id": room_id,
  72. "sender": user_id,
  73. "state_key": "@someone:example.org",
  74. "content": {
  75. "membership": "invite",
  76. "third_party_invite": {
  77. "display_name": "alice",
  78. "signed": {
  79. "mxid": "@alice:localhost",
  80. "token": invite_token,
  81. "signatures": {
  82. "magic.forest": {
  83. "ed25519:3": "fQpGIW1Snz+pwLZu6sTy2aHy/DYWWTspTJRPyNp0PKkymfIsNffysMl6ObMMFdIJhk6g6pwlIqZ54rxo8SLmAg"
  84. }
  85. },
  86. },
  87. },
  88. },
  89. },
  90. )
  91. failure = self.get_failure(d, AuthError).value
  92. self.assertEqual(failure.code, 403, failure)
  93. self.assertEqual(failure.errcode, Codes.FORBIDDEN, failure)
  94. self.assertEqual(failure.msg, "You are not invited to this room.")
  95. def test_rejected_message_event_state(self) -> None:
  96. """
  97. Check that we store the state group correctly for rejected non-state events.
  98. Regression test for #6289.
  99. """
  100. OTHER_SERVER = "otherserver"
  101. OTHER_USER = "@otheruser:" + OTHER_SERVER
  102. # create the room
  103. user_id = self.register_user("kermit", "test")
  104. tok = self.login("kermit", "test")
  105. room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
  106. room_version = self.get_success(self.store.get_room_version(room_id))
  107. # pretend that another server has joined
  108. join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
  109. # check the state group
  110. sg = self.get_success(
  111. self.store._get_state_group_for_event(join_event.event_id)
  112. )
  113. # build and send an event which will be rejected
  114. ev = event_from_pdu_json(
  115. {
  116. "type": EventTypes.Message,
  117. "content": {},
  118. "room_id": room_id,
  119. "sender": "@yetanotheruser:" + OTHER_SERVER,
  120. "depth": cast(int, join_event["depth"]) + 1,
  121. "prev_events": [join_event.event_id],
  122. "auth_events": [],
  123. "origin_server_ts": self.clock.time_msec(),
  124. },
  125. room_version,
  126. )
  127. with LoggingContext("send_rejected"):
  128. d = run_in_background(
  129. self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev
  130. )
  131. self.get_success(d)
  132. # that should have been rejected
  133. e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
  134. self.assertIsNotNone(e.rejected_reason)
  135. # ... and the state group should be the same as before
  136. sg2 = self.get_success(self.store._get_state_group_for_event(ev.event_id))
  137. self.assertEqual(sg, sg2)
  138. def test_rejected_state_event_state(self) -> None:
  139. """
  140. Check that we store the state group correctly for rejected state events.
  141. Regression test for #6289.
  142. """
  143. OTHER_SERVER = "otherserver"
  144. OTHER_USER = "@otheruser:" + OTHER_SERVER
  145. # create the room
  146. user_id = self.register_user("kermit", "test")
  147. tok = self.login("kermit", "test")
  148. room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
  149. room_version = self.get_success(self.store.get_room_version(room_id))
  150. # pretend that another server has joined
  151. join_event = self._build_and_send_join_event(OTHER_SERVER, OTHER_USER, room_id)
  152. # check the state group
  153. sg = self.get_success(
  154. self.store._get_state_group_for_event(join_event.event_id)
  155. )
  156. # build and send an event which will be rejected
  157. ev = event_from_pdu_json(
  158. {
  159. "type": "org.matrix.test",
  160. "state_key": "test_key",
  161. "content": {},
  162. "room_id": room_id,
  163. "sender": "@yetanotheruser:" + OTHER_SERVER,
  164. "depth": cast(int, join_event["depth"]) + 1,
  165. "prev_events": [join_event.event_id],
  166. "auth_events": [],
  167. "origin_server_ts": self.clock.time_msec(),
  168. },
  169. room_version,
  170. )
  171. with LoggingContext("send_rejected"):
  172. d = run_in_background(
  173. self.hs.get_federation_event_handler().on_receive_pdu, OTHER_SERVER, ev
  174. )
  175. self.get_success(d)
  176. # that should have been rejected
  177. e = self.get_success(self.store.get_event(ev.event_id, allow_rejected=True))
  178. self.assertIsNotNone(e.rejected_reason)
  179. # ... and the state group should be the same as before
  180. sg2 = self.get_success(self.store._get_state_group_for_event(ev.event_id))
  181. self.assertEqual(sg, sg2)
  182. def test_backfill_with_many_backward_extremities(self) -> None:
  183. """
  184. Check that we can backfill with many backward extremities.
  185. The goal is to make sure that when we only use a portion
  186. of backwards extremities(the magic number is more than 5),
  187. no errors are thrown.
  188. Regression test, see #11027
  189. """
  190. # create the room
  191. user_id = self.register_user("kermit", "test")
  192. tok = self.login("kermit", "test")
  193. room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
  194. room_version = self.get_success(self.store.get_room_version(room_id))
  195. # we need a user on the remote server to be a member, so that we can send
  196. # extremity-causing events.
  197. remote_server_user_id = f"@user:{self.OTHER_SERVER_NAME}"
  198. self.get_success(
  199. event_injection.inject_member_event(
  200. self.hs, room_id, remote_server_user_id, "join"
  201. )
  202. )
  203. send_result = self.helper.send(room_id, "first message", tok=tok)
  204. ev1 = self.get_success(
  205. self.store.get_event(send_result["event_id"], allow_none=False)
  206. )
  207. current_state = self.get_success(
  208. self.store.get_events_as_list(
  209. (
  210. self.get_success(self.store.get_partial_current_state_ids(room_id))
  211. ).values()
  212. )
  213. )
  214. # Create "many" backward extremities. The magic number we're trying to
  215. # create more than is 5 which corresponds to the number of backward
  216. # extremities we slice off in `_maybe_backfill_inner`
  217. federation_event_handler = self.hs.get_federation_event_handler()
  218. auth_events = [
  219. ev
  220. for ev in current_state
  221. if (ev.type, ev.state_key)
  222. in {("m.room.create", ""), ("m.room.member", remote_server_user_id)}
  223. ]
  224. for _ in range(0, 8):
  225. event = make_event_from_dict(
  226. self.add_hashes_and_signatures_from_other_server(
  227. {
  228. "origin_server_ts": 1,
  229. "type": "m.room.message",
  230. "content": {
  231. "msgtype": "m.text",
  232. "body": "message connected to fake event",
  233. },
  234. "room_id": room_id,
  235. "sender": remote_server_user_id,
  236. "prev_events": [
  237. ev1.event_id,
  238. # We're creating an backward extremity each time thanks
  239. # to this fake event
  240. generate_fake_event_id(),
  241. ],
  242. "auth_events": [ev.event_id for ev in auth_events],
  243. "depth": ev1.depth + 1,
  244. },
  245. room_version,
  246. ),
  247. room_version,
  248. )
  249. # we poke this directly into _process_received_pdu, to avoid the
  250. # federation handler wanting to backfill the fake event.
  251. state_handler = self.hs.get_state_handler()
  252. context = self.get_success(
  253. state_handler.compute_event_context(
  254. event,
  255. state_ids_before_event={
  256. (e.type, e.state_key): e.event_id for e in current_state
  257. },
  258. partial_state=False,
  259. )
  260. )
  261. self.get_success(
  262. federation_event_handler._process_received_pdu(
  263. self.OTHER_SERVER_NAME,
  264. event,
  265. context,
  266. )
  267. )
  268. # we should now have 8 backwards extremities.
  269. backwards_extremities = self.get_success(
  270. self.store.db_pool.simple_select_list(
  271. "event_backward_extremities",
  272. keyvalues={"room_id": room_id},
  273. retcols=["event_id"],
  274. )
  275. )
  276. self.assertEqual(len(backwards_extremities), 8)
  277. current_depth = 1
  278. limit = 100
  279. with LoggingContext("receive_pdu"):
  280. # Make sure backfill still works
  281. d = run_in_background(
  282. self.hs.get_federation_handler().maybe_backfill,
  283. room_id,
  284. current_depth,
  285. limit,
  286. )
  287. self.get_success(d)
  288. def test_backfill_ignores_known_events(self) -> None:
  289. """
  290. Tests that events that we already know about are ignored when backfilling.
  291. """
  292. # Set up users
  293. user_id = self.register_user("kermit", "test")
  294. tok = self.login("kermit", "test")
  295. other_server = "otherserver"
  296. other_user = "@otheruser:" + other_server
  297. # Create a room to backfill events into
  298. room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
  299. room_version = self.get_success(self.store.get_room_version(room_id))
  300. # Build an event to backfill
  301. event = event_from_pdu_json(
  302. {
  303. "type": EventTypes.Message,
  304. "content": {"body": "hello world", "msgtype": "m.text"},
  305. "room_id": room_id,
  306. "sender": other_user,
  307. "depth": 32,
  308. "prev_events": [],
  309. "auth_events": [],
  310. "origin_server_ts": self.clock.time_msec(),
  311. },
  312. room_version,
  313. )
  314. # Ensure the event is not already in the DB
  315. self.get_failure(
  316. self.store.get_event(event.event_id),
  317. NotFoundError,
  318. )
  319. # Backfill the event and check that it has entered the DB.
  320. # We mock out the FederationClient.backfill method, to pretend that a remote
  321. # server has returned our fake event.
  322. federation_client_backfill_mock = AsyncMock(return_value=[event])
  323. self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[method-assign]
  324. # We also mock the persist method with a side effect of itself. This allows us
  325. # to track when it has been called while preserving its function.
  326. persist_events_and_notify_mock = Mock(
  327. side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
  328. )
  329. self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[method-assign]
  330. persist_events_and_notify_mock
  331. )
  332. # Small side-tangent. We populate the event cache with the event, even though
  333. # it is not yet in the DB. This is an invalid scenario that can currently occur
  334. # due to not properly invalidating the event cache.
  335. # See https://github.com/matrix-org/synapse/issues/13476.
  336. #
  337. # As a result, backfill should not rely on the event cache to check whether
  338. # we already have an event in the DB.
  339. # TODO: Remove this bit when the event cache is properly invalidated.
  340. cache_entry = EventCacheEntry(
  341. event=event,
  342. redacted_event=None,
  343. )
  344. self.store._get_event_cache.set_local((event.event_id,), cache_entry)
  345. # We now call FederationEventHandler.backfill (a separate method) to trigger
  346. # a backfill request. It should receive the fake event.
  347. self.get_success(
  348. self.hs.get_federation_event_handler().backfill(
  349. other_user,
  350. room_id,
  351. limit=10,
  352. extremities=[],
  353. )
  354. )
  355. # Check that our fake event was persisted.
  356. persist_events_and_notify_mock.assert_called_once()
  357. persist_events_and_notify_mock.reset_mock()
  358. # Now we repeat the backfill, having the homeserver receive the fake event
  359. # again.
  360. self.get_success(
  361. self.hs.get_federation_event_handler().backfill(
  362. other_user,
  363. room_id,
  364. limit=10,
  365. extremities=[],
  366. ),
  367. )
  368. # This time, we expect no event persistence to have occurred, as we already
  369. # have this event.
  370. persist_events_and_notify_mock.assert_not_called()
  371. @unittest.override_config(
  372. {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
  373. )
  374. def test_invite_by_user_ratelimit(self) -> None:
  375. """Tests that invites from federation to a particular user are
  376. actually rate-limited.
  377. """
  378. other_server = "otherserver"
  379. other_user = "@otheruser:" + other_server
  380. # create the room
  381. user_id = self.register_user("kermit", "test")
  382. tok = self.login("kermit", "test")
  383. def create_invite() -> EventBase:
  384. room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
  385. room_version = self.get_success(self.store.get_room_version(room_id))
  386. return event_from_pdu_json(
  387. {
  388. "type": EventTypes.Member,
  389. "content": {"membership": "invite"},
  390. "room_id": room_id,
  391. "sender": other_user,
  392. "state_key": "@user:test",
  393. "depth": 32,
  394. "prev_events": [],
  395. "auth_events": [],
  396. "origin_server_ts": self.clock.time_msec(),
  397. },
  398. room_version,
  399. )
  400. for _ in range(3):
  401. event = create_invite()
  402. self.get_success(
  403. self.handler.on_invite_request(
  404. other_server,
  405. event,
  406. event.room_version,
  407. )
  408. )
  409. event = create_invite()
  410. self.get_failure(
  411. self.handler.on_invite_request(
  412. other_server,
  413. event,
  414. event.room_version,
  415. ),
  416. exc=LimitExceededError,
  417. )
  418. def _build_and_send_join_event(
  419. self, other_server: str, other_user: str, room_id: str
  420. ) -> EventBase:
  421. join_event = self.get_success(
  422. self.handler.on_make_join_request(other_server, room_id, other_user)
  423. )
  424. # the auth code requires that a signature exists, but doesn't check that
  425. # signature... go figure.
  426. join_event.signatures[other_server] = {"x": "y"}
  427. with LoggingContext("send_join"):
  428. d = run_in_background(
  429. self.hs.get_federation_event_handler().on_send_membership_event,
  430. other_server,
  431. join_event,
  432. )
  433. self.get_success(d)
  434. # sanity-check: the room should show that the new user is a member
  435. r = self.get_success(self.store.get_partial_current_state_ids(room_id))
  436. self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
  437. return join_event
  438. class EventFromPduTestCase(TestCase):
  439. def test_valid_json(self) -> None:
  440. """Valid JSON should be turned into an event."""
  441. ev = event_from_pdu_json(
  442. {
  443. "type": EventTypes.Message,
  444. "content": {"bool": True, "null": None, "int": 1, "str": "foobar"},
  445. "room_id": "!room:test",
  446. "sender": "@user:test",
  447. "depth": 1,
  448. "prev_events": [],
  449. "auth_events": [],
  450. "origin_server_ts": 1234,
  451. },
  452. RoomVersions.V6,
  453. )
  454. self.assertIsInstance(ev, EventBase)
  455. def test_invalid_numbers(self) -> None:
  456. """Invalid values for an integer should be rejected, all floats should be rejected."""
  457. for value in [
  458. -(2**53),
  459. 2**53,
  460. 1.0,
  461. float("inf"),
  462. float("-inf"),
  463. float("nan"),
  464. ]:
  465. with self.assertRaises(SynapseError):
  466. event_from_pdu_json(
  467. {
  468. "type": EventTypes.Message,
  469. "content": {"foo": value},
  470. "room_id": "!room:test",
  471. "sender": "@user:test",
  472. "depth": 1,
  473. "prev_events": [],
  474. "auth_events": [],
  475. "origin_server_ts": 1234,
  476. },
  477. RoomVersions.V6,
  478. )
  479. def test_invalid_nested(self) -> None:
  480. """List and dictionaries are recursively searched."""
  481. with self.assertRaises(SynapseError):
  482. event_from_pdu_json(
  483. {
  484. "type": EventTypes.Message,
  485. "content": {"foo": [{"bar": 2**56}]},
  486. "room_id": "!room:test",
  487. "sender": "@user:test",
  488. "depth": 1,
  489. "prev_events": [],
  490. "auth_events": [],
  491. "origin_server_ts": 1234,
  492. },
  493. RoomVersions.V6,
  494. )
  495. class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
  496. def test_failed_partial_join_is_clean(self) -> None:
  497. """
  498. Tests that, when failing to partial-join a room, we don't get stuck with
  499. a partial-state flag on a room.
  500. """
  501. fed_handler = self.hs.get_federation_handler()
  502. fed_client = fed_handler.federation_client
  503. room_id = "!room:example.com"
  504. EVENT_CREATE = make_event_from_dict(
  505. {
  506. "room_id": room_id,
  507. "type": "m.room.create",
  508. "sender": "@kristina:example.com",
  509. "state_key": "",
  510. "depth": 0,
  511. "content": {"creator": "@kristina:example.com", "room_version": "10"},
  512. "auth_events": [],
  513. "origin_server_ts": 1,
  514. },
  515. room_version=RoomVersions.V10,
  516. )
  517. EVENT_CREATOR_MEMBERSHIP = make_event_from_dict(
  518. {
  519. "room_id": room_id,
  520. "type": "m.room.member",
  521. "sender": "@kristina:example.com",
  522. "state_key": "@kristina:example.com",
  523. "content": {"membership": "join"},
  524. "depth": 1,
  525. "prev_events": [EVENT_CREATE.event_id],
  526. "auth_events": [EVENT_CREATE.event_id],
  527. "origin_server_ts": 1,
  528. },
  529. room_version=RoomVersions.V10,
  530. )
  531. EVENT_INVITATION_MEMBERSHIP = make_event_from_dict(
  532. {
  533. "room_id": room_id,
  534. "type": "m.room.member",
  535. "sender": "@kristina:example.com",
  536. "state_key": "@alice:test",
  537. "content": {"membership": "invite"},
  538. "depth": 2,
  539. "prev_events": [EVENT_CREATOR_MEMBERSHIP.event_id],
  540. "auth_events": [
  541. EVENT_CREATE.event_id,
  542. EVENT_CREATOR_MEMBERSHIP.event_id,
  543. ],
  544. "origin_server_ts": 1,
  545. },
  546. room_version=RoomVersions.V10,
  547. )
  548. membership_event = make_event_from_dict(
  549. {
  550. "room_id": room_id,
  551. "type": "m.room.member",
  552. "sender": "@alice:test",
  553. "state_key": "@alice:test",
  554. "content": {"membership": "join"},
  555. "prev_events": [EVENT_INVITATION_MEMBERSHIP.event_id],
  556. },
  557. RoomVersions.V10,
  558. )
  559. mock_make_membership_event = AsyncMock(
  560. return_value=(
  561. "example.com",
  562. membership_event,
  563. RoomVersions.V10,
  564. )
  565. )
  566. mock_send_join = AsyncMock(
  567. return_value=SendJoinResult(
  568. membership_event,
  569. "example.com",
  570. state=[
  571. EVENT_CREATE,
  572. EVENT_CREATOR_MEMBERSHIP,
  573. EVENT_INVITATION_MEMBERSHIP,
  574. ],
  575. auth_chain=[
  576. EVENT_CREATE,
  577. EVENT_CREATOR_MEMBERSHIP,
  578. EVENT_INVITATION_MEMBERSHIP,
  579. ],
  580. partial_state=True,
  581. servers_in_room={"example.com"},
  582. )
  583. )
  584. with patch.object(
  585. fed_client, "make_membership_event", mock_make_membership_event
  586. ), patch.object(fed_client, "send_join", mock_send_join):
  587. # Join and check that our join event is rejected
  588. # (The join event is rejected because it doesn't have any signatures)
  589. join_exc = self.get_failure(
  590. fed_handler.do_invite_join(["example.com"], room_id, "@alice:test", {}),
  591. SynapseError,
  592. )
  593. self.assertIn("Join event was rejected", str(join_exc))
  594. store = self.hs.get_datastores().main
  595. # Check that we don't have a left-over partial_state entry.
  596. self.assertFalse(
  597. self.get_success(store.is_partial_state_room(room_id)),
  598. f"Stale partial-stated room flag left over for {room_id} after a"
  599. f" failed do_invite_join!",
  600. )
  601. def test_duplicate_partial_state_room_syncs(self) -> None:
  602. """
  603. Tests that concurrent partial state syncs are not started for the same room.
  604. """
  605. is_partial_state = True
  606. end_sync: "Deferred[None]" = Deferred()
  607. async def is_partial_state_room(room_id: str) -> bool:
  608. return is_partial_state
  609. async def sync_partial_state_room(
  610. initial_destination: Optional[str],
  611. other_destinations: Collection[str],
  612. room_id: str,
  613. ) -> None:
  614. nonlocal end_sync
  615. try:
  616. await end_sync
  617. finally:
  618. end_sync = Deferred()
  619. mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
  620. mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
  621. fed_handler = self.hs.get_federation_handler()
  622. store = self.hs.get_datastores().main
  623. with patch.object(
  624. fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
  625. ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
  626. # Start the partial state sync.
  627. fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
  628. self.assertEqual(mock_sync_partial_state_room.call_count, 1)
  629. # Try to start another partial state sync.
  630. # Nothing should happen.
  631. fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
  632. self.assertEqual(mock_sync_partial_state_room.call_count, 1)
  633. # End the partial state sync
  634. is_partial_state = False
  635. end_sync.callback(None)
  636. # The partial state sync should not be restarted.
  637. self.assertEqual(mock_sync_partial_state_room.call_count, 1)
  638. # The next attempt to start the partial state sync should work.
  639. is_partial_state = True
  640. fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
  641. self.assertEqual(mock_sync_partial_state_room.call_count, 2)
  642. def test_partial_state_room_sync_restart(self) -> None:
  643. """
  644. Tests that partial state syncs are restarted when a second partial state sync
  645. was deduplicated and the first partial state sync fails.
  646. """
  647. is_partial_state = True
  648. end_sync: "Deferred[None]" = Deferred()
  649. async def is_partial_state_room(room_id: str) -> bool:
  650. return is_partial_state
  651. async def sync_partial_state_room(
  652. initial_destination: Optional[str],
  653. other_destinations: Collection[str],
  654. room_id: str,
  655. ) -> None:
  656. nonlocal end_sync
  657. try:
  658. await end_sync
  659. finally:
  660. end_sync = Deferred()
  661. mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
  662. mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
  663. fed_handler = self.hs.get_federation_handler()
  664. store = self.hs.get_datastores().main
  665. with patch.object(
  666. fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
  667. ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
  668. # Start the partial state sync.
  669. fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
  670. self.assertEqual(mock_sync_partial_state_room.call_count, 1)
  671. # Fail the partial state sync.
  672. # The partial state sync should not be restarted.
  673. end_sync.errback(Exception("Failed to request /state_ids"))
  674. self.assertEqual(mock_sync_partial_state_room.call_count, 1)
  675. # Start the partial state sync again.
  676. fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
  677. self.assertEqual(mock_sync_partial_state_room.call_count, 2)
  678. # Deduplicate another partial state sync.
  679. fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
  680. self.assertEqual(mock_sync_partial_state_room.call_count, 2)
  681. # Fail the partial state sync.
  682. # It should restart with the latest parameters.
  683. end_sync.errback(Exception("Failed to request /state_ids"))
  684. self.assertEqual(mock_sync_partial_state_room.call_count, 3)
  685. mock_sync_partial_state_room.assert_called_with(
  686. initial_destination="hs3",
  687. other_destinations={"hs2"},
  688. room_id="room_id",
  689. )