Browse Source

Implement changes to MSC2285 (hidden read receipts) (#12168)

* Changes hidden read receipts to be a separate receipt type
  (instead of a field on `m.read`).
* Updates the `/receipts` endpoint to accept `m.fully_read`.
Šimon Brandner 2 years ago
parent
commit
116a4c8340

+ 1 - 0
changelog.d/12168.feature

@@ -0,0 +1 @@
+Implement [changes](https://github.com/matrix-org/matrix-spec-proposals/pull/2285/commits/4a77139249c2e830aec3c7d6bd5501a514d1cc27) to [MSC2285 (hidden read receipts)](https://github.com/matrix-org/matrix-spec-proposals/pull/2285). Contributed by @SimonBrandner.

+ 2 - 4
synapse/api/constants.py

@@ -255,7 +255,5 @@ class GuestAccess:
 
 class ReceiptTypes:
     READ: Final = "m.read"
-
-
-class ReadReceiptEventFields:
-    MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"
+    READ_PRIVATE: Final = "org.matrix.msc2285.read.private"
+    FULLY_READ: Final = "m.fully_read"

+ 27 - 38
synapse/handlers/receipts.py

@@ -14,7 +14,7 @@
 import logging
 from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
-from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
+from synapse.api.constants import ReceiptTypes
 from synapse.appservice import ApplicationService
 from synapse.streams import EventSource
 from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
@@ -112,7 +112,7 @@ class ReceiptsHandler:
             )
 
             if not res:
-                # res will be None if this read receipt is 'old'
+                # res will be None if this receipt is 'old'
                 continue
 
             stream_id, max_persisted_id = res
@@ -138,7 +138,7 @@ class ReceiptsHandler:
         return True
 
     async def received_client_receipt(
-        self, room_id: str, receipt_type: str, user_id: str, event_id: str, hidden: bool
+        self, room_id: str, receipt_type: str, user_id: str, event_id: str
     ) -> None:
         """Called when a client tells us a local user has read up to the given
         event_id in the room.
@@ -148,16 +148,14 @@ class ReceiptsHandler:
             receipt_type=receipt_type,
             user_id=user_id,
             event_ids=[event_id],
-            data={"ts": int(self.clock.time_msec()), "hidden": hidden},
+            data={"ts": int(self.clock.time_msec())},
         )
 
         is_new = await self._handle_new_receipts([receipt])
         if not is_new:
             return
 
-        if self.federation_sender and not (
-            self.hs.config.experimental.msc2285_enabled and hidden
-        ):
+        if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE:
             await self.federation_sender.send_read_receipt(receipt)
 
 
@@ -168,6 +166,13 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 
     @staticmethod
     def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]:
+        """
+        This method takes in what is returned by
+        get_linearized_receipts_for_rooms() and goes through read receipts
+        filtering out m.read.private receipts if they were not sent by the
+        current user.
+        """
+
         visible_events = []
 
         # filter out hidden receipts the user shouldn't see
@@ -176,37 +181,21 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
             new_event = event.copy()
             new_event["content"] = {}
 
-            for event_id in content.keys():
-                event_content = content.get(event_id, {})
-                m_read = event_content.get(ReceiptTypes.READ, {})
-
-                # If m_read is missing copy over the original event_content as there is nothing to process here
-                if not m_read:
-                    new_event["content"][event_id] = event_content.copy()
-                    continue
-
-                new_users = {}
-                for rr_user_id, user_rr in m_read.items():
-                    try:
-                        hidden = user_rr.get("hidden")
-                    except AttributeError:
-                        # Due to https://github.com/matrix-org/synapse/issues/10376
-                        # there are cases where user_rr is a string, in those cases
-                        # we just ignore the read receipt
-                        continue
-
-                    if hidden is not True or rr_user_id == user_id:
-                        new_users[rr_user_id] = user_rr.copy()
-                        # If hidden has a value replace hidden with the correct prefixed key
-                        if hidden is not None:
-                            new_users[rr_user_id].pop("hidden")
-                            new_users[rr_user_id][
-                                ReadReceiptEventFields.MSC2285_HIDDEN
-                            ] = hidden
-
-                # Set new users unless empty
-                if len(new_users.keys()) > 0:
-                    new_event["content"][event_id] = {ReceiptTypes.READ: new_users}
+            for event_id, event_content in content.items():
+                receipt_event = {}
+                for receipt_type, receipt_content in event_content.items():
+                    if receipt_type == ReceiptTypes.READ_PRIVATE:
+                        user_rr = receipt_content.get(user_id, None)
+                        if user_rr:
+                            receipt_event[ReceiptTypes.READ_PRIVATE] = {
+                                user_id: user_rr.copy()
+                            }
+                    else:
+                        receipt_event[receipt_type] = receipt_content.copy()
+
+                # Only include the receipt event if it is non-empty.
+                if receipt_event:
+                    new_event["content"][event_id] = receipt_event
 
             # Append new_event to visible_events unless empty
             if len(new_event["content"].keys()) > 0:

+ 1 - 1
synapse/handlers/sync.py

@@ -1045,7 +1045,7 @@ class SyncHandler:
             last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
                 user_id=sync_config.user.to_string(),
                 room_id=room_id,
-                receipt_type=ReceiptTypes.READ,
+                receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
             )
 
             return await self.store.get_unread_event_push_actions_by_room_for_user(

+ 3 - 1
synapse/push/push_tools.py

@@ -24,7 +24,9 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
     invites = await store.get_invited_rooms_for_local_user(user_id)
     joins = await store.get_rooms_for_user(user_id)
 
-    my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ)
+    my_receipts_by_room = await store.get_receipts_for_user(
+        user_id, (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
+    )
 
     badge = len(invites)
 

+ 1 - 1
synapse/rest/client/notifications.py

@@ -58,7 +58,7 @@ class NotificationsServlet(RestServlet):
         )
 
         receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
-            user_id, ReceiptTypes.READ
+            user_id, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
         )
 
         notif_event_ids = [pa.event_id for pa in push_actions]

+ 22 - 10
synapse/rest/client/read_marker.py

@@ -15,8 +15,8 @@
 import logging
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.constants import ReceiptTypes
+from synapse.api.errors import SynapseError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
@@ -36,6 +36,7 @@ class ReadMarkerRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.auth = hs.get_auth()
+        self.config = hs.config
         self.receipts_handler = hs.get_receipts_handler()
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
@@ -48,27 +49,38 @@ class ReadMarkerRestServlet(RestServlet):
         await self.presence_handler.bump_presence_active_time(requester.user)
 
         body = parse_json_object_from_request(request)
-        read_event_id = body.get(ReceiptTypes.READ, None)
-        hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
 
-        if not isinstance(hidden, bool):
+        valid_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ}
+        if self.config.experimental.msc2285_enabled:
+            valid_receipt_types.add(ReceiptTypes.READ_PRIVATE)
+
+        if set(body.keys()) > valid_receipt_types:
             raise SynapseError(
                 400,
-                "Param %s must be a boolean, if given"
-                % ReadReceiptEventFields.MSC2285_HIDDEN,
-                Codes.BAD_JSON,
+                "Receipt type must be 'm.read', 'org.matrix.msc2285.read.private' or 'm.fully_read'"
+                if self.config.experimental.msc2285_enabled
+                else "Receipt type must be 'm.read' or 'm.fully_read'",
             )
 
+        read_event_id = body.get(ReceiptTypes.READ, None)
         if read_event_id:
             await self.receipts_handler.received_client_receipt(
                 room_id,
                 ReceiptTypes.READ,
                 user_id=requester.user.to_string(),
                 event_id=read_event_id,
-                hidden=hidden,
             )
 
-        read_marker_event_id = body.get("m.fully_read", None)
+        read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None)
+        if read_private_event_id and self.config.experimental.msc2285_enabled:
+            await self.receipts_handler.received_client_receipt(
+                room_id,
+                ReceiptTypes.READ_PRIVATE,
+                user_id=requester.user.to_string(),
+                event_id=read_private_event_id,
+            )
+
+        read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None)
         if read_marker_event_id:
             await self.read_marker_handler.received_client_read_marker(
                 room_id,

+ 31 - 20
synapse/rest/client/receipts.py

@@ -16,8 +16,8 @@ import logging
 import re
 from typing import TYPE_CHECKING, Tuple
 
-from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.constants import ReceiptTypes
+from synapse.api.errors import SynapseError
 from synapse.http import get_request_user_agent
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -46,6 +46,7 @@ class ReceiptRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.receipts_handler = hs.get_receipts_handler()
+        self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
 
     async def on_POST(
@@ -53,7 +54,19 @@ class ReceiptRestServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
-        if receipt_type != ReceiptTypes.READ:
+        if self.hs.config.experimental.msc2285_enabled and receipt_type not in [
+            ReceiptTypes.READ,
+            ReceiptTypes.READ_PRIVATE,
+            ReceiptTypes.FULLY_READ,
+        ]:
+            raise SynapseError(
+                400,
+                "Receipt type must be 'm.read', 'org.matrix.msc2285.read.private' or 'm.fully_read'",
+            )
+        elif (
+            not self.hs.config.experimental.msc2285_enabled
+            and receipt_type != ReceiptTypes.READ
+        ):
             raise SynapseError(400, "Receipt type must be 'm.read'")
 
         # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
@@ -62,26 +75,24 @@ class ReceiptRestServlet(RestServlet):
         if "Android" in user_agent:
             if pattern.match(user_agent) or "Riot" in user_agent:
                 allow_empty_body = True
-        body = parse_json_object_from_request(request, allow_empty_body)
-        hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
-
-        if not isinstance(hidden, bool):
-            raise SynapseError(
-                400,
-                "Param %s must be a boolean, if given"
-                % ReadReceiptEventFields.MSC2285_HIDDEN,
-                Codes.BAD_JSON,
-            )
+        # This call makes sure possible empty body is handled correctly
+        parse_json_object_from_request(request, allow_empty_body)
 
         await self.presence_handler.bump_presence_active_time(requester.user)
 
-        await self.receipts_handler.received_client_receipt(
-            room_id,
-            receipt_type,
-            user_id=requester.user.to_string(),
-            event_id=event_id,
-            hidden=hidden,
-        )
+        if receipt_type == ReceiptTypes.FULLY_READ:
+            await self.read_marker_handler.received_client_read_marker(
+                room_id,
+                user_id=requester.user.to_string(),
+                event_id=event_id,
+            )
+        else:
+            await self.receipts_handler.received_client_receipt(
+                room_id,
+                receipt_type,
+                user_id=requester.user.to_string(),
+                event_id=event_id,
+            )
 
         return 200, {}
 

+ 110 - 32
synapse/storage/databases/main/receipts.py

@@ -144,43 +144,77 @@ class ReceiptsWorkerStore(SQLBaseStore):
             desc="get_receipts_for_room",
         )
 
-    @cached()
     async def get_last_receipt_event_id_for_user(
-        self, user_id: str, room_id: str, receipt_type: str
+        self, user_id: str, room_id: str, receipt_types: Iterable[str]
     ) -> Optional[str]:
         """
-        Fetch the event ID for the latest receipt in a room with the given receipt type.
+        Fetch the event ID for the latest receipt in a room with one of the given receipt types.
 
         Args:
             user_id: The user to fetch receipts for.
             room_id: The room ID to fetch the receipt for.
-            receipt_type: The receipt type to fetch.
+            receipt_type: The receipt types to fetch. Earlier receipt types
+                are given priority if multiple receipts point to the same event.
 
         Returns:
-            The event ID of the latest receipt, if one exists; otherwise `None`.
+            The latest receipt, if one exists.
         """
-        return await self.db_pool.simple_select_one_onecol(
-            table="receipts_linearized",
-            keyvalues={
-                "room_id": room_id,
-                "receipt_type": receipt_type,
-                "user_id": user_id,
-            },
-            retcol="event_id",
-            desc="get_own_receipt_for_user",
-            allow_none=True,
-        )
+        latest_event_id: Optional[str] = None
+        latest_stream_ordering = 0
+        for receipt_type in receipt_types:
+            result = await self._get_last_receipt_event_id_for_user(
+                user_id, room_id, receipt_type
+            )
+            if result is None:
+                continue
+            event_id, stream_ordering = result
+
+            if latest_event_id is None or latest_stream_ordering < stream_ordering:
+                latest_event_id = event_id
+                latest_stream_ordering = stream_ordering
+
+        return latest_event_id
 
     @cached()
+    async def _get_last_receipt_event_id_for_user(
+        self, user_id: str, room_id: str, receipt_type: str
+    ) -> Optional[Tuple[str, int]]:
+        """
+        Fetch the event ID and stream ordering for the latest receipt.
+
+        Args:
+            user_id: The user to fetch receipts for.
+            room_id: The room ID to fetch the receipt for.
+            receipt_type: The receipt type to fetch.
+
+        Returns:
+            The event ID and stream ordering of the latest receipt, if one exists;
+            otherwise `None`.
+        """
+        sql = """
+            SELECT event_id, stream_ordering
+            FROM receipts_linearized
+            INNER JOIN events USING (room_id, event_id)
+            WHERE user_id = ?
+            AND room_id = ?
+            AND receipt_type = ?
+        """
+
+        def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]:
+            txn.execute(sql, (user_id, room_id, receipt_type))
+            return cast(Optional[Tuple[str, int]], txn.fetchone())
+
+        return await self.db_pool.runInteraction("get_own_receipt_for_user", f)
+
     async def get_receipts_for_user(
-        self, user_id: str, receipt_type: str
+        self, user_id: str, receipt_types: Iterable[str]
     ) -> Dict[str, str]:
         """
         Fetch the event IDs for the latest receipts sent by the given user.
 
         Args:
             user_id: The user to fetch receipts for.
-            receipt_type: The receipt type to fetch.
+            receipt_types: The receipt types to check.
 
         Returns:
             A map of room ID to the event ID of the latest receipt for that room.
@@ -188,16 +222,48 @@ class ReceiptsWorkerStore(SQLBaseStore):
             If the user has not sent a receipt to a room then it will not appear
             in the returned dictionary.
         """
-        rows = await self.db_pool.simple_select_list(
-            table="receipts_linearized",
-            keyvalues={"user_id": user_id, "receipt_type": receipt_type},
-            retcols=("room_id", "event_id"),
-            desc="get_receipts_for_user",
+        results = await self.get_receipts_for_user_with_orderings(
+            user_id, receipt_types
         )
 
-        return {row["room_id"]: row["event_id"] for row in rows}
+        # Reduce the result to room ID -> event ID.
+        return {
+            room_id: room_result["event_id"] for room_id, room_result in results.items()
+        }
 
     async def get_receipts_for_user_with_orderings(
+        self, user_id: str, receipt_types: Iterable[str]
+    ) -> JsonDict:
+        """
+        Fetch receipts for all rooms that the given user is joined to.
+
+        Args:
+            user_id: The user to fetch receipts for.
+            receipt_types: The receipt types to fetch. Earlier receipt types
+                are given priority if multiple receipts point to the same event.
+
+        Returns:
+            A map of room ID to the latest receipt (for the given types).
+        """
+        results: JsonDict = {}
+        for receipt_type in receipt_types:
+            partial_result = await self._get_receipts_for_user_with_orderings(
+                user_id, receipt_type
+            )
+            for room_id, room_result in partial_result.items():
+                # If the room has not yet been seen, or the receipt is newer,
+                # use it.
+                if (
+                    room_id not in results
+                    or results[room_id]["stream_ordering"]
+                    < room_result["stream_ordering"]
+                ):
+                    results[room_id] = room_result
+
+        return results
+
+    @cached()
+    async def _get_receipts_for_user_with_orderings(
         self, user_id: str, receipt_type: str
     ) -> JsonDict:
         """
@@ -220,8 +286,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 " WHERE rl.room_id = e.room_id"
                 " AND rl.event_id = e.event_id"
                 " AND user_id = ?"
+                " AND receipt_type = ?"
             )
-            txn.execute(sql, (user_id,))
+            txn.execute(sql, (user_id, receipt_type))
             return cast(List[Tuple[str, str, int, int]], txn.fetchall())
 
         rows = await self.db_pool.runInteraction(
@@ -552,9 +619,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
     def invalidate_caches_for_receipt(
         self, room_id: str, receipt_type: str, user_id: str
     ) -> None:
-        self.get_receipts_for_user.invalidate((user_id, receipt_type))
+        self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type))
         self._get_linearized_receipts_for_room.invalidate((room_id,))
-        self.get_last_receipt_event_id_for_user.invalidate(
+        self._get_last_receipt_event_id_for_user.invalidate(
             (user_id, room_id, receipt_type)
         )
         self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
@@ -590,8 +657,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
         """Inserts a receipt into the database if it's newer than the current one.
 
         Returns:
-            None if the RR is older than the current RR
-            otherwise, the rx timestamp of the event that the RR corresponds to
+            None if the receipt is older than the current receipt
+            otherwise, the rx timestamp of the event that the receipt corresponds to
                 (or 0 if the event is unknown)
         """
         assert self._can_write_to_receipts
@@ -612,7 +679,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         if stream_ordering is not None:
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
-                " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
+                " INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
                 " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
             )
             txn.execute(sql, (room_id, receipt_type, user_id))
@@ -653,7 +720,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
             lock=False,
         )
 
-        if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
+        if (
+            receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
+            and stream_ordering is not None
+        ):
             self._remove_old_push_actions_before_txn(  # type: ignore[attr-defined]
                 txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
             )
@@ -672,6 +742,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         Automatically does conversion between linearized and graph
         representations.
+
+        Returns:
+            The new receipts stream ID and token, if the receipt is newer than
+            what was previously persisted. None, otherwise.
         """
         assert self._can_write_to_receipts
 
@@ -719,6 +793,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 stream_id=stream_id,
             )
 
+        # If the receipt was older than the currently persisted one, nothing to do.
         if event_ts is None:
             return None
 
@@ -774,7 +849,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
             receipt_type,
             user_id,
         )
-        txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
+        txn.call_after(
+            self._get_receipts_for_user_with_orderings.invalidate,
+            (user_id, receipt_type),
+        )
         # FIXME: This shouldn't invalidate the whole cache
         txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))
 

+ 76 - 53
tests/handlers/test_receipts.py

@@ -15,7 +15,7 @@
 
 from typing import List
 
-from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
+from synapse.api.constants import ReceiptTypes
 from synapse.types import JsonDict
 
 from tests import unittest
@@ -25,20 +25,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.event_source = hs.get_event_sources().sources.receipt
 
-    # In the first param of _test_filters_hidden we use "hidden" instead of
-    # ReadReceiptEventFields.MSC2285_HIDDEN. We do this because we're mocking
-    # the data from the database which doesn't use the prefix
-
     def test_filters_out_hidden_receipt(self):
         self._test_filters_hidden(
             [
                 {
                     "content": {
                         "$1435641916114394fHBLK:matrix.org": {
-                            ReceiptTypes.READ: {
+                            ReceiptTypes.READ_PRIVATE: {
                                 "@rikj:jki.re": {
                                     "ts": 1436451550453,
-                                    "hidden": True,
                                 }
                             }
                         }
@@ -50,58 +45,23 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
             [],
         )
 
-    def test_does_not_filter_out_our_hidden_receipt(self):
-        self._test_filters_hidden(
-            [
-                {
-                    "content": {
-                        "$1435641916hfgh4394fHBLK:matrix.org": {
-                            ReceiptTypes.READ: {
-                                "@me:server.org": {
-                                    "ts": 1436451550453,
-                                    "hidden": True,
-                                },
-                            }
-                        }
-                    },
-                    "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
-                    "type": "m.receipt",
-                }
-            ],
-            [
-                {
-                    "content": {
-                        "$1435641916hfgh4394fHBLK:matrix.org": {
-                            ReceiptTypes.READ: {
-                                "@me:server.org": {
-                                    "ts": 1436451550453,
-                                    ReadReceiptEventFields.MSC2285_HIDDEN: True,
-                                },
-                            }
-                        }
-                    },
-                    "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
-                    "type": "m.receipt",
-                }
-            ],
-        )
-
     def test_filters_out_hidden_receipt_and_ignores_rest(self):
         self._test_filters_hidden(
             [
                 {
                     "content": {
                         "$1dgdgrd5641916114394fHBLK:matrix.org": {
-                            ReceiptTypes.READ: {
+                            ReceiptTypes.READ_PRIVATE: {
                                 "@rikj:jki.re": {
                                     "ts": 1436451550453,
-                                    "hidden": True,
                                 },
+                            },
+                            ReceiptTypes.READ: {
                                 "@user:jki.re": {
                                     "ts": 1436451550453,
                                 },
-                            }
-                        }
+                            },
+                        },
                     },
                     "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
                     "type": "m.receipt",
@@ -130,10 +90,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
                 {
                     "content": {
                         "$14356419edgd14394fHBLK:matrix.org": {
-                            ReceiptTypes.READ: {
+                            ReceiptTypes.READ_PRIVATE: {
                                 "@rikj:jki.re": {
                                     "ts": 1436451550453,
-                                    "hidden": True,
                                 },
                             }
                         },
@@ -223,7 +182,6 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
             [
                 {
                     "content": {
-                        "$143564gdfg6114394fHBLK:matrix.org": {},
                         "$1435641916114394fHBLK:matrix.org": {
                             ReceiptTypes.READ: {
                                 "@user:jki.re": {
@@ -244,10 +202,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
                 {
                     "content": {
                         "$14356419edgd14394fHBLK:matrix.org": {
-                            ReceiptTypes.READ: {
+                            ReceiptTypes.READ_PRIVATE: {
                                 "@rikj:jki.re": {
                                     "ts": 1436451550453,
-                                    "hidden": True,
                                 },
                             }
                         },
@@ -306,7 +263,73 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
                     "type": "m.receipt",
                 },
             ],
-            [],
+            [
+                {
+                    "content": {
+                        "$14356419edgd14394fHBLK:matrix.org": {
+                            ReceiptTypes.READ: {
+                                "@rikj:jki.re": "string",
+                            }
+                        },
+                    },
+                    "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
+                    "type": "m.receipt",
+                },
+            ],
+        )
+
+    def test_leaves_our_hidden_and_their_public(self):
+        self._test_filters_hidden(
+            [
+                {
+                    "content": {
+                        "$1dgdgrd5641916114394fHBLK:matrix.org": {
+                            ReceiptTypes.READ_PRIVATE: {
+                                "@me:server.org": {
+                                    "ts": 1436451550453,
+                                },
+                            },
+                            ReceiptTypes.READ: {
+                                "@rikj:jki.re": {
+                                    "ts": 1436451550453,
+                                },
+                            },
+                            "a.receipt.type": {
+                                "@rikj:jki.re": {
+                                    "ts": 1436451550453,
+                                },
+                            },
+                        },
+                    },
+                    "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
+                    "type": "m.receipt",
+                }
+            ],
+            [
+                {
+                    "content": {
+                        "$1dgdgrd5641916114394fHBLK:matrix.org": {
+                            ReceiptTypes.READ_PRIVATE: {
+                                "@me:server.org": {
+                                    "ts": 1436451550453,
+                                },
+                            },
+                            ReceiptTypes.READ: {
+                                "@rikj:jki.re": {
+                                    "ts": 1436451550453,
+                                },
+                            },
+                            "a.receipt.type": {
+                                "@rikj:jki.re": {
+                                    "ts": 1436451550453,
+                                },
+                            },
+                        }
+                    },
+                    "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
+                    "type": "m.receipt",
+                }
+            ],
         )
 
     def _test_filters_hidden(

+ 229 - 9
tests/replication/slave/storage/test_receipts.py

@@ -14,26 +14,246 @@
 
 from synapse.api.constants import ReceiptTypes
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.types import UserID, create_requester
+
+from tests.test_utils.event_injection import create_event
 
 from ._base import BaseSlavedStoreTestCase
 
-USER_ID = "@feeling:blue"
-ROOM_ID = "!room:blue"
-EVENT_ID = "$event:blue"
+OTHER_USER_ID = "@other:test"
+OUR_USER_ID = "@our:test"
 
 
 class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
 
     STORE_TYPE = SlavedReceiptsStore
 
-    def test_receipt(self):
-        self.check("get_receipts_for_user", [USER_ID, ReceiptTypes.READ], {})
+    def prepare(self, reactor, clock, homeserver):
+        super().prepare(reactor, clock, homeserver)
+        self.room_creator = homeserver.get_room_creation_handler()
+        self.persist_event_storage = self.hs.get_storage().persistence
+
+        # Create a test user
+        self.ourUser = UserID.from_string(OUR_USER_ID)
+        self.ourRequester = create_requester(self.ourUser)
+
+        # Create a second test user
+        self.otherUser = UserID.from_string(OTHER_USER_ID)
+        self.otherRequester = create_requester(self.otherUser)
+
+        # Create a test room
+        info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {}))
+        self.room_id1 = info["room_id"]
+
+        # Create a second test room
+        info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {}))
+        self.room_id2 = info["room_id"]
+
+        # Join the second user to the first room
+        memberEvent, memberEventContext = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id1,
+                type="m.room.member",
+                sender=self.otherRequester.user.to_string(),
+                state_key=self.otherRequester.user.to_string(),
+                content={"membership": "join"},
+            )
+        )
+        self.get_success(
+            self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+        )
+
+        # Join the second user to the second room
+        memberEvent, memberEventContext = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id2,
+                type="m.room.member",
+                sender=self.otherRequester.user.to_string(),
+                state_key=self.otherRequester.user.to_string(),
+                content={"membership": "join"},
+            )
+        )
+        self.get_success(
+            self.persist_event_storage.persist_event(memberEvent, memberEventContext)
+        )
+
+    def test_return_empty_with_no_data(self):
+        res = self.get_success(
+            self.master_store.get_receipts_for_user(
+                OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+            )
+        )
+        self.assertEqual(res, {})
+
+        res = self.get_success(
+            self.master_store.get_receipts_for_user_with_orderings(
+                OUR_USER_ID,
+                [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+            )
+        )
+        self.assertEqual(res, {})
+
+        res = self.get_success(
+            self.master_store.get_last_receipt_event_id_for_user(
+                OUR_USER_ID,
+                self.room_id1,
+                [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+            )
+        )
+        self.assertEqual(res, None)
+
+    def test_get_receipts_for_user(self):
+        # Send some events into the first room
+        event1_1_id = self.create_and_send_event(
+            self.room_id1, UserID.from_string(OTHER_USER_ID)
+        )
+        event1_2_id = self.create_and_send_event(
+            self.room_id1, UserID.from_string(OTHER_USER_ID)
+        )
+
+        # Send public read receipt for the first event
+        self.get_success(
+            self.master_store.insert_receipt(
+                self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
+            )
+        )
+        # Send private read receipt for the second event
         self.get_success(
             self.master_store.insert_receipt(
-                ROOM_ID, ReceiptTypes.READ, USER_ID, [EVENT_ID], {}
+                self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+            )
+        )
+
+        # Test we get the latest event when we want both private and public receipts
+        res = self.get_success(
+            self.master_store.get_receipts_for_user(
+                OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
             )
         )
-        self.replicate()
-        self.check(
-            "get_receipts_for_user", [USER_ID, ReceiptTypes.READ], {ROOM_ID: EVENT_ID}
+        self.assertEqual(res, {self.room_id1: event1_2_id})
+
+        # Test we get the older event when we want only public receipt
+        res = self.get_success(
+            self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ])
+        )
+        self.assertEqual(res, {self.room_id1: event1_1_id})
+
+        # Test we get the latest event when we want only the public receipt
+        res = self.get_success(
+            self.master_store.get_receipts_for_user(
+                OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]
+            )
+        )
+        self.assertEqual(res, {self.room_id1: event1_2_id})
+
+        # Test receipt updating
+        self.get_success(
+            self.master_store.insert_receipt(
+                self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
+            )
+        )
+        res = self.get_success(
+            self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ])
+        )
+        self.assertEqual(res, {self.room_id1: event1_2_id})
+
+        # Send some events into the second room
+        event2_1_id = self.create_and_send_event(
+            self.room_id2, UserID.from_string(OTHER_USER_ID)
+        )
+
+        # Test new room is reflected in what the method returns
+        self.get_success(
+            self.master_store.insert_receipt(
+                self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
+            )
+        )
+        res = self.get_success(
+            self.master_store.get_receipts_for_user(
+                OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+            )
+        )
+        self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id})
+
+    def test_get_last_receipt_event_id_for_user(self):
+        # Send some events into the first room
+        event1_1_id = self.create_and_send_event(
+            self.room_id1, UserID.from_string(OTHER_USER_ID)
+        )
+        event1_2_id = self.create_and_send_event(
+            self.room_id1, UserID.from_string(OTHER_USER_ID)
+        )
+
+        # Send public read receipt for the first event
+        self.get_success(
+            self.master_store.insert_receipt(
+                self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
+            )
+        )
+        # Send private read receipt for the second event
+        self.get_success(
+            self.master_store.insert_receipt(
+                self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+            )
+        )
+
+        # Test we get the latest event when we want both private and public receipts
+        res = self.get_success(
+            self.master_store.get_last_receipt_event_id_for_user(
+                OUR_USER_ID,
+                self.room_id1,
+                [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+            )
+        )
+        self.assertEqual(res, event1_2_id)
+
+        # Test we get the older event when we want only public receipt
+        res = self.get_success(
+            self.master_store.get_last_receipt_event_id_for_user(
+                OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
+            )
+        )
+        self.assertEqual(res, event1_1_id)
+
+        # Test we get the latest event when we want only the private receipt
+        res = self.get_success(
+            self.master_store.get_last_receipt_event_id_for_user(
+                OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
+            )
+        )
+        self.assertEqual(res, event1_2_id)
+
+        # Test receipt updating
+        self.get_success(
+            self.master_store.insert_receipt(
+                self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
+            )
+        )
+        res = self.get_success(
+            self.master_store.get_last_receipt_event_id_for_user(
+                OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
+            )
+        )
+        self.assertEqual(res, event1_2_id)
+
+        # Send some events into the second room
+        event2_1_id = self.create_and_send_event(
+            self.room_id2, UserID.from_string(OTHER_USER_ID)
+        )
+
+        # Test new room is reflected in what the method returns
+        self.get_success(
+            self.master_store.insert_receipt(
+                self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
+            )
+        )
+        res = self.get_success(
+            self.master_store.get_last_receipt_event_id_for_user(
+                OUR_USER_ID,
+                self.room_id2,
+                [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
+            )
         )
+        self.assertEqual(res, event2_1_id)

+ 144 - 17
tests/rest/client/test_sync.py

@@ -23,7 +23,6 @@ import synapse.rest.admin
 from synapse.api.constants import (
     EventContentFields,
     EventTypes,
-    ReadReceiptEventFields,
     ReceiptTypes,
     RelationTypes,
 )
@@ -347,7 +346,7 @@ class SyncKnockTestCase(
         # Knock on a room
         channel = self.make_request(
             "POST",
-            "/_matrix/client/r0/knock/%s" % (self.room_id,),
+            f"/_matrix/client/r0/knock/{self.room_id}",
             b"{}",
             self.knocker_tok,
         )
@@ -412,18 +411,79 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         # Send a message as the first user
         res = self.helper.send(self.room_id, body="hello", tok=self.tok)
 
-        # Send a read receipt to tell the server the first user's message was read
-        body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8")
+        # Send a private read receipt to tell the server the first user's message was read
         channel = self.make_request(
             "POST",
-            "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]),
-            body,
+            f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}",
+            {},
             access_token=self.tok2,
         )
         self.assertEqual(channel.code, 200)
 
-        # Test that the first user can't see the other user's hidden read receipt
-        self.assertEqual(self._get_read_receipt(), None)
+        # Test that the first user can't see the other user's private read receipt
+        self.assertIsNone(self._get_read_receipt())
+
+    @override_config({"experimental_features": {"msc2285_enabled": True}})
+    def test_public_receipt_can_override_private(self) -> None:
+        """
+        Sending a public read receipt to the same event which has a private read
+        receipt should cause that receipt to become public.
+        """
+        # Send a message as the first user
+        res = self.helper.send(self.room_id, body="hello", tok=self.tok)
+
+        # Send a private read receipt
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
+            {},
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertIsNone(self._get_read_receipt())
+
+        # Send a public read receipt
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
+            {},
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Test that we did override the private read receipt
+        self.assertNotEqual(self._get_read_receipt(), None)
+
+    @override_config({"experimental_features": {"msc2285_enabled": True}})
+    def test_private_receipt_cannot_override_public(self) -> None:
+        """
+        Sending a private read receipt to the same event which has a public read
+        receipt should cause no change.
+        """
+        # Send a message as the first user
+        res = self.helper.send(self.room_id, body="hello", tok=self.tok)
+
+        # Send a public read receipt
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}",
+            {},
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertNotEqual(self._get_read_receipt(), None)
+
+        # Send a private read receipt
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
+            {},
+            access_token=self.tok2,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # Test that we didn't override the public read receipt
+        self.assertIsNone(self._get_read_receipt())
 
     @parameterized.expand(
         [
@@ -455,7 +515,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         # Send a read receipt for this message with an empty body
         channel = self.make_request(
             "POST",
-            "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]),
+            f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}",
             access_token=self.tok2,
             custom_headers=[("User-Agent", user_agent)],
         )
@@ -479,6 +539,9 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         # Store the next batch for the next request.
         self.next_batch = channel.json_body["next_batch"]
 
+        if channel.json_body.get("rooms", None) is None:
+            return None
+
         # Return the read receipt
         ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][
             "ephemeral"
@@ -499,7 +562,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
 
     def default_config(self) -> JsonDict:
         config = super().default_config()
-        config["experimental_features"] = {"msc2654_enabled": True}
+        config["experimental_features"] = {
+            "msc2654_enabled": True,
+            "msc2285_enabled": True,
+        }
         return config
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -564,7 +630,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
         body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8")
         channel = self.make_request(
             "POST",
-            "/rooms/%s/read_markers" % self.room_id,
+            f"/rooms/{self.room_id}/read_markers",
             body,
             access_token=self.tok,
         )
@@ -578,11 +644,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
         self._check_unread_count(1)
 
         # Send a read receipt to tell the server we've read the latest event.
-        body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8")
         channel = self.make_request(
             "POST",
-            "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]),
-            body,
+            f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}",
+            {},
             access_token=self.tok,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -644,13 +709,73 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
         self._check_unread_count(4)
 
         # Check that tombstone events changes increase the unread counter.
-        self.helper.send_state(
+        res1 = self.helper.send_state(
             self.room_id,
             EventTypes.Tombstone,
             {"replacement_room": "!someroom:test"},
             tok=self.tok2,
         )
         self._check_unread_count(5)
+        res2 = self.helper.send(self.room_id, "hello", tok=self.tok2)
+
+        # Make sure both m.read and org.matrix.msc2285.read.private advance
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}",
+            {},
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        self._check_unread_count(1)
+
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}",
+            {},
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        self._check_unread_count(0)
+
+    # We test for both receipt types that influence notification counts
+    @parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE])
+    def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None:
+        # Join the new user
+        self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
+
+        # Send messages
+        res1 = self.helper.send(self.room_id, "hello", tok=self.tok2)
+        res2 = self.helper.send(self.room_id, "hello", tok=self.tok2)
+
+        # Read last event
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}",
+            {},
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        self._check_unread_count(0)
+
+        # Make sure neither m.read nor org.matrix.msc2285.read.private make the
+        # read receipt go up to an older event
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}",
+            {},
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        self._check_unread_count(0)
+
+        channel = self.make_request(
+            "POST",
+            f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}",
+            {},
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        self._check_unread_count(0)
 
     def _check_unread_count(self, expected_count: int) -> None:
         """Syncs and compares the unread count with the expected value."""
@@ -663,9 +788,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(channel.code, 200, channel.json_body)
 
-        room_entry = channel.json_body["rooms"]["join"][self.room_id]
+        room_entry = (
+            channel.json_body.get("rooms", {}).get("join", {}).get(self.room_id, {})
+        )
         self.assertEqual(
-            room_entry["org.matrix.msc2654.unread_count"],
+            room_entry.get("org.matrix.msc2654.unread_count", 0),
             expected_count,
             room_entry,
         )