123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420 |
- # Copyright 2016 OpenMarket Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import logging
- from typing import Any, Callable, Iterable, List, Optional, Tuple
- from canonicaljson import encode_canonical_json
- from parameterized import parameterized
- from twisted.test.proto_helpers import MemoryReactor
- from synapse.api.constants import ReceiptTypes
- from synapse.api.room_versions import RoomVersions
- from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
- from synapse.events.snapshot import EventContext
- from synapse.handlers.room import RoomEventSource
- from synapse.server import HomeServer
- from synapse.storage.databases.main.event_push_actions import (
- NotifCounts,
- RoomNotifCounts,
- )
- from synapse.storage.databases.main.events_worker import EventsWorkerStore
- from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
- from synapse.types import PersistedEventPosition
- from synapse.util import Clock
- from tests.server import FakeTransport
- from ._base import BaseWorkerStoreTestCase
logger = logging.getLogger(__name__)

# Fixed identifiers shared by every test in this module.
ROOM_ID = "!room:test"
USER_ID = "@feeling:test"
USER_ID_2 = "@bright:test"

# An internal-metadata dict flagging an event as an outlier; presumably meant
# for ``build_event(internal=...)`` (not referenced in the visible tests).
OUTLIER = dict(outlier=True)
def dict_equals(self: EventBase, other: object) -> bool:
    """Compare two events by the canonical JSON encoding of their PDUs.

    Installed as ``__eq__`` by ``patch__eq__`` so that ``assertEqual`` can
    compare events (and lists of events) structurally rather than by identity.

    Returns ``NotImplemented`` when ``other`` is not an ``EventBase``. Python
    then falls back to its default comparison, so comparing an event against
    an unrelated object (e.g. ``None`` from a failed lookup) is simply unequal
    and produces a readable assertion failure instead of an ``AttributeError``.
    """
    if not isinstance(other, EventBase):
        return NotImplemented
    me = encode_canonical_json(self.get_pdu_json())
    them = encode_canonical_json(other.get_pdu_json())
    return me == them
def patch__eq__(cls: object) -> Callable[[], None]:
    """Replace ``cls.__eq__`` with ``dict_equals``.

    Returns:
        A callable that undoes the patch. If ``cls`` defined its own
        ``__eq__``, that method is restored. If ``__eq__`` was only
        inherited, the override is deleted so that inheritance applies
        again. Calling the undo function more than once is harmless.
    """
    # Look only at the class's own namespace. ``getattr`` would also find an
    # inherited ``__eq__`` (every class inherits one from ``object``), and
    # "restoring" that would pin a copy of the inherited method onto ``cls``.
    own_eq = vars(cls).get("__eq__")
    cls.__eq__ = dict_equals  # type: ignore[assignment]

    def unpatch() -> None:
        if own_eq is not None:
            cls.__eq__ = own_eq  # type: ignore[assignment]
        elif vars(cls).get("__eq__") is dict_equals:
            # Only remove our own override, so a repeated call is a no-op.
            delattr(cls, "__eq__")

    return unpatch
class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
    """Check that an ``EventsWorkerStore`` on a worker reflects events
    persisted on the master once replication has caught up.

    Tests persist events on the master (``persist`` / ``build_event``), call
    ``self.replicate()``, and then assert on worker-side store methods, mostly
    via the base class's ``self.check(method_name, args, expected)``.
    """

    STORE_TYPE = EventsWorkerStore

    # Counter used by ``build_event`` to mint event IDs, depths and
    # timestamps. The class attribute is only the starting value: the first
    # ``self.event_id += 1`` creates a per-instance counter.
    event_id = 0

    def setUp(self) -> None:
        # Patch up the equality operator for events so that we can check
        # whether lists of events match using assertEqual.
        self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
        try:
            super().setUp()
        except BaseException:
            # tearDown is not run when setUp fails, so undo the patches here
            # instead of leaking them into every later test in the process.
            self._unpatch_eq()
            raise

    def _unpatch_eq(self) -> None:
        """Undo the ``__eq__`` patches installed by ``setUp``, newest first."""
        for unpatch in reversed(self.unpatches):
            unpatch()

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        super().prepare(reactor, clock, hs)

        self.get_success(
            self.master_store.store_room(
                ROOM_ID,
                USER_ID,
                is_public=False,
                room_version=RoomVersions.V1,
            )
        )

    def tearDown(self) -> None:
        try:
            super().tearDown()
        finally:
            # Always restore the original equality operators, even if the
            # parent tearDown fails.
            self._unpatch_eq()

    def test_get_latest_event_ids_in_room(self) -> None:
        create = self.persist(type="m.room.create", key="", creator=USER_ID)
        self.replicate()
        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])

        join = self.persist(
            type="m.room.member",
            key=USER_ID,
            membership="join",
            prev_events=[(create.event_id, {})],
        )
        self.replicate()
        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])

    def test_redactions(self) -> None:
        self._check_redaction_is_replicated(backfill=False)

    def test_backfilled_redactions(self) -> None:
        self._check_redaction_is_replicated(backfill=True)

    def _check_redaction_is_replicated(self, backfill: bool) -> None:
        """Persist a message and then a redaction of it, and check that the
        worker returns the redacted form of the message after replication.

        Args:
            backfill: whether the redaction is persisted as a backfilled event.
        """
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")

        msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
        self.replicate()
        self.check("get_event", [msg.event_id], msg)

        redaction = self.persist(
            type="m.room.redaction", redacts=msg.event_id, backfill=backfill
        )
        self.replicate()

        # Build the event we expect back: empty content, and an unsigned
        # section pointing at the redaction.
        msg_dict = msg.get_dict()
        msg_dict["content"] = {}
        msg_dict["unsigned"]["redacted_by"] = redaction.event_id
        msg_dict["unsigned"]["redacted_because"] = redaction
        redacted = make_event_from_dict(
            msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
        )
        self.check("get_event", [msg.event_id], redacted)

    def test_invites(self) -> None:
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
        event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
        assert event.internal_metadata.stream_ordering is not None

        self.replicate()

        self.check(
            "get_invited_rooms_for_local_user",
            [USER_ID_2],
            [
                RoomsForUser(
                    ROOM_ID,
                    USER_ID,
                    "invite",
                    event.event_id,
                    event.internal_metadata.stream_ordering,
                    RoomVersions.V1.identifier,
                )
            ],
        )

    @parameterized.expand([(True,), (False,)])
    def test_push_actions_for_user(self, send_receipt: bool) -> None:
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")
        self.persist(
            type="m.room.member", sender=USER_ID, key=USER_ID_2, membership="join"
        )
        event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
        self.replicate()

        if send_receipt:
            self.get_success(
                self.master_store.insert_receipt(
                    ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {}
                )
            )

        # No push actions have been staged yet, so all counts start at zero.
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2],
            RoomNotifCounts(
                NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {}
            ),
        )

        self.persist(
            type="m.room.message",
            msgtype="m.text",
            body="world",
            push_actions=[(USER_ID_2, ["notify"])],
        )
        self.replicate()
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2],
            RoomNotifCounts(
                NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {}
            ),
        )

        self.persist(
            type="m.room.message",
            msgtype="m.text",
            body="world",
            push_actions=[
                (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
            ],
        )
        self.replicate()
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2],
            RoomNotifCounts(
                NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {}
            ),
        )

    def test_get_rooms_for_user_with_stream_ordering(self) -> None:
        """Check that the cache on get_rooms_for_user_with_stream_ordering is
        invalidated by rows in the events stream.
        """
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")
        self.replicate()
        self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())

        j2 = self.persist(
            type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
        )
        assert j2.internal_metadata.stream_ordering is not None
        self.replicate()

        expected_pos = PersistedEventPosition(
            "master", j2.internal_metadata.stream_ordering
        )
        self.check(
            "get_rooms_for_user_with_stream_ordering",
            (USER_ID_2,),
            {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
        )

    def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
        self,
    ) -> None:
        """Check that current_state invalidation happens correctly with multiple
        events in the persistence batch.

        This test attempts to reproduce a race condition between the event
        persistence loop and a worker-based Sync handler.

        The problem occurred when the master persisted several events in one
        batch. It only updates the current_state at the end of each batch, so
        the obvious thing to do is then to issue a current_state_delta stream
        update corresponding to the last stream_id in the batch.

        However, that raises the possibility that a worker will see the
        replication notification for a join event before the current_state
        caches are invalidated.

        The test involves:
          * creating a join and a message event for a user, and persisting
            them in the same batch;
          * controlling the replication stream so that updates are sent
            gradually;
          * between each bunch of replication updates, checking that we see a
            consistent snapshot of the state.
        """
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")
        self.replicate()
        self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())

        # Limit the replication rate: the transport only sends data when we
        # explicitly flush it below.
        repl_transport = self._server_transport
        assert isinstance(repl_transport, FakeTransport)
        repl_transport.autoflush = False

        # Build the join and message events and persist them in the same batch.
        logger.info("----- build test events ------")
        j2, j2ctx = self.build_event(
            type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
        )
        msg, msgctx = self.build_event()
        self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
        self.replicate()
        assert j2.internal_metadata.stream_ordering is not None

        event_source = RoomEventSource(self.hs)
        event_source.store = self.worker_store
        current_token = event_source.get_current_key()

        # Gradually stream out the replication, 30 bytes at a time.
        while repl_transport.buffer:
            logger.info("------ flush ------")
            repl_transport.flush(30)
            self.pump(0)

            prev_token = current_token
            current_token = event_source.get_current_key()

            # Attempt to replicate the behaviour of the sync handler.
            #
            # First, we get a list of the rooms we are joined to.
            joined_rooms = self.get_success(
                self.worker_store.get_rooms_for_user_with_stream_ordering(USER_ID_2)
            )

            # Then, we get a list of the events since the last sync.
            membership_changes = self.get_success(
                self.worker_store.get_membership_changes_for_user(
                    USER_ID_2, prev_token, current_token
                )
            )

            logger.info(
                "%s->%s: joined_rooms=%r membership_changes=%r",
                prev_token,
                current_token,
                joined_rooms,
                membership_changes,
            )

            # The membership change is only any use to us if the room is in
            # the joined_rooms list.
            if membership_changes:
                expected_pos = PersistedEventPosition(
                    "master", j2.internal_metadata.stream_ordering
                )
                self.assertEqual(
                    joined_rooms,
                    {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
                )

    def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
        """Build an event with ``build_event`` and persist it on the master.

        Args:
            backfill: if True, persist the event via ``persist_events`` with
                ``backfilled=True``; otherwise via ``persist_event``.
            **kwargs: passed through unchanged to ``build_event``.

        Returns:
            The event that was persisted.
        """
        event, context = self.build_event(**kwargs)

        if backfill:
            self.get_success(
                self.persistance.persist_events([(event, context)], backfilled=True)
            )
        else:
            self.get_success(self.persistance.persist_event(event, context))

        return event

    def build_event(
        self,
        sender: str = USER_ID,
        room_id: str = ROOM_ID,
        type: str = "m.room.message",
        key: Optional[str] = None,
        internal: Optional[dict] = None,
        depth: Optional[int] = None,
        prev_events: Optional[List[Tuple[str, dict]]] = None,
        auth_events: Optional[List[str]] = None,
        prev_state: Optional[List[str]] = None,
        redacts: Optional[str] = None,
        push_actions: Iterable = frozenset(),
        **content: object,
    ) -> Tuple[EventBase, EventContext]:
        """Build an event, compute its context, and stage its push actions.

        The event ID is ``$<n>:blue``, where ``n`` is this test's running
        ``event_id`` counter. ``n`` is also used as ``origin_server_ts`` and,
        unless ``depth`` is given, as the depth. The counter is incremented
        afterwards.

        Args:
            sender: the event's sender.
            room_id: the room the event belongs to.
            type: the event type.
            key: the state key. When not None, the event gets a
                ``state_key`` and a ``prev_state`` field.
            internal: internal metadata dict for the event.
            depth: explicit depth; defaults to the counter value.
            prev_events: ``(event_id, dict)`` pairs. If None or empty, they
                default to the room's latest event IDs as read from the
                master store.
            auth_events: copied into the event dict (default empty).
            prev_state: copied into state events (default empty).
            redacts: when not None, copied into the ``redacts`` field.
            push_actions: ``(user_id, actions)`` pairs staged on the master
                for this event via ``add_push_actions_to_staging``.
            **content: becomes the event's ``content``.

        Returns:
            The built event and the context computed by the state handler.
        """
        prev_events = prev_events or []
        auth_events = auth_events or []
        prev_state = prev_state or []

        if depth is None:
            depth = self.event_id

        if not prev_events:
            latest_event_ids = self.get_success(
                self.master_store.get_latest_event_ids_in_room(room_id)
            )
            prev_events = [(ev_id, {}) for ev_id in latest_event_ids]

        event_dict = {
            "sender": sender,
            "type": type,
            "content": content,
            "event_id": "$%d:blue" % (self.event_id,),
            "room_id": room_id,
            "depth": depth,
            "origin_server_ts": self.event_id,
            "prev_events": prev_events,
            "auth_events": auth_events,
        }
        if key is not None:
            event_dict["state_key"] = key
            event_dict["prev_state"] = prev_state

        if redacts is not None:
            event_dict["redacts"] = redacts

        event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {})

        self.event_id += 1
        state_handler = self.hs.get_state_handler()
        context = self.get_success(state_handler.compute_event_context(event))

        # NOTE(review): the positional False / "main" are presumably the
        # count-as-unread flag and the thread ID; confirm against the store's
        # add_push_actions_to_staging signature.
        self.get_success(
            self.master_store.add_push_actions_to_staging(
                event.event_id,
                dict(push_actions),
                False,
                "main",
            )
        )
        return event, context
|