Bläddra i källkod

Add ephemeral messages support (MSC2228) (#6409)

Implement part [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228). The parts that differ are:

* the feature is hidden behind a configuration flag (`enable_ephemeral_messages`)
* self-destruction doesn't happen for state events
* only implement support for the `m.self_destruct_after` field (not the `m.self_destruct` one)
* doesn't send synthetic redactions to clients because for this specific case we consider the clients to be able to destroy an event themselves, instead we just censor it (by pruning its JSON) in the database
Brendan Abolivier 4 år sedan
förälder
incheckning
54dd5dc12b

+ 1 - 0
changelog.d/6409.feature

@@ -0,0 +1 @@
+Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228).

+ 4 - 0
synapse/api/constants.py

@@ -147,3 +147,7 @@ class EventContentFields(object):
 
     # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
     LABELS = "org.matrix.labels"
+
+    # Timestamp to delete the event after
+    # cf https://github.com/matrix-org/matrix-doc/pull/2228
+    SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"

+ 2 - 0
synapse/config/server.py

@@ -490,6 +490,8 @@ class ServerConfig(Config):
             "cleanup_extremities_with_dummy_events", True
         )
 
+        self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False)
+
     def has_tls_listener(self) -> bool:
         return any(l["tls"] for l in self.listeners)
 

+ 8 - 0
synapse/handlers/federation.py

@@ -121,6 +121,7 @@ class FederationHandler(BaseHandler):
         self.pusher_pool = hs.get_pusherpool()
         self.spam_checker = hs.get_spam_checker()
         self.event_creation_handler = hs.get_event_creation_handler()
+        self._message_handler = hs.get_message_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
         self.config = hs.config
         self.http_client = hs.get_simple_http_client()
@@ -141,6 +142,8 @@ class FederationHandler(BaseHandler):
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
+        self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+
     @defer.inlineCallbacks
     def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
         """ Process a PDU received via a federation /send/ transaction, or
@@ -2715,6 +2718,11 @@ class FederationHandler(BaseHandler):
                 event_and_contexts, backfilled=backfilled
             )
 
+            if self._ephemeral_messages_enabled:
+                for (event, context) in event_and_contexts:
+                    # If there's an expiry timestamp on the event, schedule its expiry.
+                    self._message_handler.maybe_schedule_expiry(event)
+
             if not backfilled:  # Never notify for backfilled events
                 for event, _ in event_and_contexts:
                     yield self._notify_persisted_event(event, max_stream_id)

+ 122 - 1
synapse/handlers/message.py

@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Optional
 
 from six import iteritems, itervalues, string_types
 
@@ -22,9 +23,16 @@ from canonicaljson import encode_canonical_json, json
 
 from twisted.internet import defer
 from twisted.internet.defer import succeed
+from twisted.internet.interfaces import IDelayedCall
 
 from synapse import event_auth
-from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    Membership,
+    RelationTypes,
+    UserTypes,
+)
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -62,6 +70,17 @@ class MessageHandler(object):
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
         self._event_serializer = hs.get_event_client_serializer()
+        self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+        self._is_worker_app = bool(hs.config.worker_app)
+
+        # The scheduled call to self._expire_event. None if no call is currently
+        # scheduled.
+        self._scheduled_expiry = None  # type: Optional[IDelayedCall]
+
+        if not hs.config.worker_app:
+            run_as_background_process(
+                "_schedule_next_expiry", self._schedule_next_expiry
+            )
 
     @defer.inlineCallbacks
     def get_room_data(
@@ -225,6 +244,100 @@ class MessageHandler(object):
             for user_id, profile in iteritems(users_with_profile)
         }
 
+    def maybe_schedule_expiry(self, event):
+        """Schedule the expiry of an event if there's not already one scheduled,
+        or if the one running is for an event that will expire after the provided
+        timestamp.
+
+        This function needs to invalidate the event cache, which is only possible on
+        the master process, and therefore needs to be run on there.
+
+        Args:
+            event (EventBase): The event to schedule the expiry of.
+        """
+        assert not self._is_worker_app
+
+        expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
+        if not isinstance(expiry_ts, int) or event.is_state():
+            return
+
+        # _schedule_expiry_for_event won't actually schedule anything if there's already
+        # a task scheduled for a timestamp that's sooner than the provided one.
+        self._schedule_expiry_for_event(event.event_id, expiry_ts)
+
+    @defer.inlineCallbacks
+    def _schedule_next_expiry(self):
+        """Retrieve the ID and the expiry timestamp of the next event to be expired,
+        and schedule an expiry task for it.
+
+        If there's no event left to expire, set _expiry_scheduled to None so that a
+        future call to save_expiry_ts can schedule a new expiry task.
+        """
+        # Try to get the expiry timestamp of the next event to expire.
+        res = yield self.store.get_next_event_to_expire()
+        if res:
+            event_id, expiry_ts = res
+            self._schedule_expiry_for_event(event_id, expiry_ts)
+
+    def _schedule_expiry_for_event(self, event_id, expiry_ts):
+        """Schedule an expiry task for the provided event if there's not already one
+        scheduled at a timestamp that's sooner than the provided one.
+
+        Args:
+            event_id (str): The ID of the event to expire.
+            expiry_ts (int): The timestamp at which to expire the event.
+        """
+        if self._scheduled_expiry:
+            # If the provided timestamp refers to a time before the scheduled time of the
+            # next expiry task, cancel that task and reschedule it for this timestamp.
+            next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000
+            if expiry_ts < next_scheduled_expiry_ts:
+                self._scheduled_expiry.cancel()
+            else:
+                return
+
+        # Figure out how many seconds we need to wait before expiring the event.
+        now_ms = self.clock.time_msec()
+        delay = (expiry_ts - now_ms) / 1000
+
+        # callLater doesn't support negative delays, so trim the delay to 0 if we're
+        # in that case.
+        if delay < 0:
+            delay = 0
+
+        logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay)
+
+        self._scheduled_expiry = self.clock.call_later(
+            delay,
+            run_as_background_process,
+            "_expire_event",
+            self._expire_event,
+            event_id,
+        )
+
+    @defer.inlineCallbacks
+    def _expire_event(self, event_id):
+        """Retrieve and expire an event that needs to be expired from the database.
+
+        If the event doesn't exist in the database, log it and delete the expiry date
+        from the database (so that we don't try to expire it again).
+        """
+        assert self._ephemeral_events_enabled
+
+        self._scheduled_expiry = None
+
+        logger.info("Expiring event %s", event_id)
+
+        try:
+            # Expire the event if we know about it. This function also deletes the expiry
+            # date from the database in the same database transaction.
+            yield self.store.expire_event(event_id)
+        except Exception as e:
+            logger.error("Could not expire event %s: %r", event_id, e)
+
+        # Schedule the expiry of the next event to expire.
+        yield self._schedule_next_expiry()
+
 
 # The duration (in ms) after which rooms should be removed
 # `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try
@@ -295,6 +408,10 @@ class EventCreationHandler(object):
                 5 * 60 * 1000,
             )
 
+        self._message_handler = hs.get_message_handler()
+
+        self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+
     @defer.inlineCallbacks
     def create_event(
         self,
@@ -877,6 +994,10 @@ class EventCreationHandler(object):
             event, context=context
         )
 
+        if self._ephemeral_events_enabled:
+            # If there's an expiry timestamp on the event, schedule its expiry.
+            self._message_handler.maybe_schedule_expiry(event)
+
         yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
 
         def _notify():

+ 120 - 6
synapse/storage/data_stores/main/events.py

@@ -130,6 +130,8 @@ class EventsStore(
         if self.hs.config.redaction_retention_period is not None:
             hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
 
+        self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+
     @defer.inlineCallbacks
     def _read_forward_extremities(self):
         def fetch(txn):
@@ -940,6 +942,12 @@ class EventsStore(
                     txn, event.event_id, labels, event.room_id, event.depth
                 )
 
+            if self._ephemeral_messages_enabled:
+                # If there's an expiry timestamp on the event, store it.
+                expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
+                if isinstance(expiry_ts, int) and not event.is_state():
+                    self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
+
         # Insert into the room_memberships table.
         self._store_room_members_txn(
             txn,
@@ -1101,12 +1109,7 @@ class EventsStore(
         def _update_censor_txn(txn):
             for redaction_id, event_id, pruned_json in updates:
                 if pruned_json:
-                    self._simple_update_one_txn(
-                        txn,
-                        table="event_json",
-                        keyvalues={"event_id": event_id},
-                        updatevalues={"json": pruned_json},
-                    )
+                    self._censor_event_txn(txn, event_id, pruned_json)
 
                 self._simple_update_one_txn(
                     txn,
@@ -1117,6 +1120,22 @@ class EventsStore(
 
         yield self.runInteraction("_update_censor_txn", _update_censor_txn)
 
+    def _censor_event_txn(self, txn, event_id, pruned_json):
+        """Censor an event by replacing its JSON in the event_json table with the
+        provided pruned JSON.
+
+        Args:
+            txn (LoggingTransaction): The database transaction.
+            event_id (str): The ID of the event to censor.
+            pruned_json (str): The pruned JSON
+        """
+        self._simple_update_one_txn(
+            txn,
+            table="event_json",
+            keyvalues={"event_id": event_id},
+            updatevalues={"json": pruned_json},
+        )
+
     @defer.inlineCallbacks
     def count_daily_messages(self):
         """
@@ -1957,6 +1976,101 @@ class EventsStore(
             ],
         )
 
+    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+        """Save the expiry timestamp associated with a given event ID.
+
+        Args:
+            txn (LoggingTransaction): The database transaction to use.
+            event_id (str): The event ID the expiry timestamp is associated with.
+            expiry_ts (int): The timestamp at which to expire (delete) the event.
+        """
+        return self._simple_insert_txn(
+            txn=txn,
+            table="event_expiry",
+            values={"event_id": event_id, "expiry_ts": expiry_ts},
+        )
+
+    @defer.inlineCallbacks
+    def expire_event(self, event_id):
+        """Retrieve and expire an event that has expired, and delete its associated
+        expiry timestamp. If the event can't be retrieved, delete its associated
+        timestamp so we don't try to expire it again in the future.
+
+        Args:
+             event_id (str): The ID of the event to delete.
+        """
+        # Try to retrieve the event's content from the database or the event cache.
+        event = yield self.get_event(event_id)
+
+        def delete_expired_event_txn(txn):
+            # Delete the expiry timestamp associated with this event from the database.
+            self._delete_event_expiry_txn(txn, event_id)
+
+            if not event:
+                # If we can't find the event, log a warning and delete the expiry date
+                # from the database so that we don't try to expire it again in the
+                # future.
+                logger.warning(
+                    "Can't expire event %s because we don't have it.", event_id
+                )
+                return
+
+            # Prune the event's dict then convert it to JSON.
+            pruned_json = encode_json(prune_event_dict(event.get_dict()))
+
+            # Update the event_json table to replace the event's JSON with the pruned
+            # JSON.
+            self._censor_event_txn(txn, event.event_id, pruned_json)
+
+            # We need to invalidate the event cache entry for this event because we
+            # changed its content in the database. We can't call
+            # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
+            # right type.
+            txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+            # Send that invalidation to replication so that other workers also invalidate
+            # the event cache.
+            self._send_invalidation_to_replication(
+                txn, "_get_event_cache", (event.event_id,)
+            )
+
+        yield self.runInteraction("delete_expired_event", delete_expired_event_txn)
+
+    def _delete_event_expiry_txn(self, txn, event_id):
+        """Delete the expiry timestamp associated with an event ID without deleting the
+        actual event.
+
+        Args:
+            txn (LoggingTransaction): The transaction to use to perform the deletion.
+            event_id (str): The event ID to delete the associated expiry timestamp of.
+        """
+        return self._simple_delete_txn(
+            txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+        )
+
+    def get_next_event_to_expire(self):
+        """Retrieve the entry with the lowest expiry timestamp in the event_expiry
+        table, or None if there's no more event to expire.
+
+        Returns: Deferred[Optional[Tuple[str, int]]]
+            A tuple containing the event ID as its first element and an expiry timestamp
+            as its second one, if there's at least one row in the event_expiry table.
+            None otherwise.
+        """
+
+        def get_next_event_to_expire_txn(txn):
+            txn.execute(
+                """
+                SELECT event_id, expiry_ts FROM event_expiry
+                ORDER BY expiry_ts ASC LIMIT 1
+                """
+            )
+
+            return txn.fetchone()
+
+        return self.runInteraction(
+            desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
+        )
+
 
 AllNewEventsResult = namedtuple(
     "AllNewEventsResult",

+ 21 - 0
synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql

@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_expiry (
+    event_id TEXT PRIMARY KEY,
+    expiry_ts BIGINT NOT NULL
+);
+
+CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts);

+ 101 - 0
tests/rest/client/test_ephemeral_message.py

@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.constants import EventContentFields, EventTypes
+from synapse.rest import admin
+from synapse.rest.client.v1 import room
+
+from tests import unittest
+
+
+class EphemeralMessageTestCase(unittest.HomeserverTestCase):
+
+    user_id = "@user:test"
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+
+        config["enable_ephemeral_messages"] = True
+
+        self.hs = self.setup_test_homeserver(config=config)
+        return self.hs
+
+    def prepare(self, reactor, clock, homeserver):
+        self.room_id = self.helper.create_room_as(self.user_id)
+
+    def test_message_expiry_no_delay(self):
+        """Tests that sending a message sent with a m.self_destruct_after field set to the
+        past results in that event being deleted right away.
+        """
+        # Send a message in the room that has expired. From here, the reactor clock is
+        # at 200ms, so 0 is in the past, and even if that wasn't the case and the clock
+        # is at 0ms the code path is the same if the event's expiry timestamp is the
+        # current timestamp.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "hello",
+                EventContentFields.SELF_DESTRUCT_AFTER: 0,
+            },
+        )
+        event_id = res["event_id"]
+
+        # Check that we can't retrieve the content of the event.
+        event_content = self.get_event(self.room_id, event_id)["content"]
+        self.assertFalse(bool(event_content), event_content)
+
+    def test_message_expiry_delay(self):
+        """Tests that sending a message with a m.self_destruct_after field set to the
+        future results in that event not being deleted right away, but advancing the
+        clock to after that expiry timestamp causes the event to be deleted.
+        """
+        # Send a message in the room that'll expire in 1s.
+        res = self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "hello",
+                EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000,
+            },
+        )
+        event_id = res["event_id"]
+
+        # Check that we can retrieve the content of the event before it has expired.
+        event_content = self.get_event(self.room_id, event_id)["content"]
+        self.assertTrue(bool(event_content), event_content)
+
+        # Advance the clock to after the deletion.
+        self.reactor.advance(1)
+
+        # Check that we can't retrieve the content of the event anymore.
+        event_content = self.get_event(self.room_id, event_id)["content"]
+        self.assertFalse(bool(event_content), event_content)
+
+    def get_event(self, room_id, event_id, expected_code=200):
+        url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
+
+        request, channel = self.make_request("GET", url)
+        self.render(request)
+
+        self.assertEqual(channel.code, expected_code, channel.result)
+
+        return channel.json_body