Browse Source

Use vector clocks for room stream tokens. (#8439)

Currently when using multiple event persisters we (in the worst case) don't tell clients about events until all event persisters have persisted new events after the original event. This is a suboptimal, especially if one of the event persisters goes down.

To handle this, we encode the position of each event persister in the room tokens so that we can send events to clients immediately. To reduce the size of the token we do two things:

1. We create a unique immutable persistent mapping between instance names and a generated small integer ID, which we can encode in the tokens instead of the instance name; and
2. We encode the "persisted upto position" of the room token and then only explicitly include instances that have positions strictly greater than that.

The new tokens look something like: `m3478~1.3488~2.3489`, where the first number is the min position, and the subsequent `-` separated pairs are the instance ID to positions map. (We use `.` and `~` as separators as they're URL safe and not already used by `StreamToken`).
Erik Johnston 3 years ago
parent
commit
52a50e8686

+ 1 - 0
changelog.d/8439.misc

@@ -0,0 +1 @@
+Allow events to be sent to clients sooner when using sharded event persisters.

+ 25 - 0
synapse/storage/databases/main/schema/delta/58/19instance_map.sql.postgres

@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- A unique and immutable mapping between instance name and an integer ID. This
+-- lets us refer to instances via a small ID in e.g. stream tokens, without
+-- having to encode the full name.
+CREATE TABLE IF NOT EXISTS instance_map (
+    instance_id SERIAL PRIMARY KEY,
+    instance_name TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS instance_map_idx ON instance_map(instance_name);

+ 243 - 37
synapse/storage/databases/main/stream.py

@@ -53,7 +53,9 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
+from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 if TYPE_CHECKING:
@@ -208,6 +210,55 @@ def _make_generic_sql_bound(
     )
 
 
+def _filter_results(
+    lower_token: Optional[RoomStreamToken],
+    upper_token: Optional[RoomStreamToken],
+    instance_name: str,
+    topological_ordering: int,
+    stream_ordering: int,
+) -> bool:
+    """Returns True if the event persisted by the given instance at the given
+    topological/stream_ordering falls between the two tokens (taking a None
+    token to mean unbounded).
+
+    Used to filter results from fetching events in the DB against the given
+    tokens. This is necessary to handle the case where the tokens include
+    position maps, which we handle by fetching more than necessary from the DB
+    and then filtering (rather than attempting to construct a complicated SQL
+    query).
+    """
+
+    event_historical_tuple = (
+        topological_ordering,
+        stream_ordering,
+    )
+
+    if lower_token:
+        if lower_token.topological is not None:
+            # If these are historical tokens we compare the `(topological, stream)`
+            # tuples.
+            if event_historical_tuple <= lower_token.as_historical_tuple():
+                return False
+
+        else:
+            # If these are live tokens we compare the stream ordering against the
+            # writers stream position.
+            if stream_ordering <= lower_token.get_stream_pos_for_instance(
+                instance_name
+            ):
+                return False
+
+    if upper_token:
+        if upper_token.topological is not None:
+            if upper_token.as_historical_tuple() < event_historical_tuple:
+                return False
+        else:
+            if upper_token.get_stream_pos_for_instance(instance_name) < stream_ordering:
+                return False
+
+    return True
+
+
 def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
@@ -305,7 +356,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     def get_room_max_token(self) -> RoomStreamToken:
-        return RoomStreamToken(None, self.get_room_max_stream_ordering())
+        """Get a `RoomStreamToken` that marks the current maximum persisted
+        position of the events stream. Useful to get a token that represents
+        "now".
+
+        The token returned is a "live" token that may have an instance_map
+        component.
+        """
+
+        min_pos = self._stream_id_gen.get_current_token()
+
+        positions = {}
+        if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+            # The `min_pos` is the minimum position that we know all instances
+            # have finished persisting to, so we only care about instances whose
+            # positions are ahead of that. (Instance positions can be behind the
+            # min position as there are times we can work out that the minimum
+            # position is ahead of the naive minimum across all current
+            # positions. See MultiWriterIdGenerator for details)
+            positions = {
+                i: p
+                for i, p in self._stream_id_gen.get_positions().items()
+                if p > min_pos
+            }
+
+        return RoomStreamToken(None, min_pos, positions)
 
     async def get_room_events_stream_for_rooms(
         self,
@@ -404,25 +479,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         if from_key == to_key:
             return [], from_key
 
-        from_id = from_key.stream
-        to_id = to_key.stream
-
-        has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+        has_changed = self._events_stream_cache.has_entity_changed(
+            room_id, from_key.stream
+        )
 
         if not has_changed:
             return [], from_key
 
         def f(txn):
-            sql = (
-                "SELECT event_id, stream_ordering FROM events WHERE"
-                " room_id = ?"
-                " AND not outlier"
-                " AND stream_ordering > ? AND stream_ordering <= ?"
-                " ORDER BY stream_ordering %s LIMIT ?"
-            ) % (order,)
-            txn.execute(sql, (room_id, from_id, to_id, limit))
-
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            # To handle tokens with a non-empty instance_map we fetch more
+            # results than necessary and then filter down
+            min_from_id = from_key.stream
+            max_to_id = to_key.get_max_stream_pos()
+
+            sql = """
+                SELECT event_id, instance_name, topological_ordering, stream_ordering
+                FROM events
+                WHERE
+                    room_id = ?
+                    AND not outlier
+                    AND stream_ordering > ? AND stream_ordering <= ?
+                ORDER BY stream_ordering %s LIMIT ?
+            """ % (
+                order,
+            )
+            txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
+
+            rows = [
+                _EventDictReturn(event_id, None, stream_ordering)
+                for event_id, instance_name, topological_ordering, stream_ordering in txn
+                if _filter_results(
+                    from_key,
+                    to_key,
+                    instance_name,
+                    topological_ordering,
+                    stream_ordering,
+                )
+            ][:limit]
             return rows
 
         rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -431,7 +524,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             [r.event_id for r in rows], get_prev_content=True
         )
 
-        self._set_before_and_after(ret, rows, topo_order=from_id is None)
+        self._set_before_and_after(ret, rows, topo_order=False)
 
         if order.lower() == "desc":
             ret.reverse()
@@ -448,31 +541,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     async def get_membership_changes_for_user(
         self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
-        from_id = from_key.stream
-        to_id = to_key.stream
-
         if from_key == to_key:
             return []
 
-        if from_id:
+        if from_key:
             has_changed = self._membership_stream_cache.has_entity_changed(
-                user_id, int(from_id)
+                user_id, int(from_key.stream)
             )
             if not has_changed:
                 return []
 
         def f(txn):
-            sql = (
-                "SELECT m.event_id, stream_ordering FROM events AS e,"
-                " room_memberships AS m"
-                " WHERE e.event_id = m.event_id"
-                " AND m.user_id = ?"
-                " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
-            )
-            txn.execute(sql, (user_id, from_id, to_id))
-
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            # To handle tokens with a non-empty instance_map we fetch more
+            # results than necessary and then filter down
+            min_from_id = from_key.stream
+            max_to_id = to_key.get_max_stream_pos()
+
+            sql = """
+                SELECT m.event_id, instance_name, topological_ordering, stream_ordering
+                FROM events AS e, room_memberships AS m
+                WHERE e.event_id = m.event_id
+                    AND m.user_id = ?
+                    AND e.stream_ordering > ? AND e.stream_ordering <= ?
+                ORDER BY e.stream_ordering ASC
+            """
+            txn.execute(sql, (user_id, min_from_id, max_to_id,))
+
+            rows = [
+                _EventDictReturn(event_id, None, stream_ordering)
+                for event_id, instance_name, topological_ordering, stream_ordering in txn
+                if _filter_results(
+                    from_key,
+                    to_key,
+                    instance_name,
+                    topological_ordering,
+                    stream_ordering,
+                )
+            ]
 
             return rows
 
@@ -966,11 +1071,46 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         else:
             order = "ASC"
 
+        # The bounds for the stream tokens are complicated by the fact
+        # that we need to handle the instance_map part of the tokens. We do this
+        # by fetching all events between the min stream token and the maximum
+        # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+        # then filtering the results.
+        if from_token.topological is not None:
+            from_bound = (
+                from_token.as_historical_tuple()
+            )  # type: Tuple[Optional[int], int]
+        elif direction == "b":
+            from_bound = (
+                None,
+                from_token.get_max_stream_pos(),
+            )
+        else:
+            from_bound = (
+                None,
+                from_token.stream,
+            )
+
+        to_bound = None  # type: Optional[Tuple[Optional[int], int]]
+        if to_token:
+            if to_token.topological is not None:
+                to_bound = to_token.as_historical_tuple()
+            elif direction == "b":
+                to_bound = (
+                    None,
+                    to_token.stream,
+                )
+            else:
+                to_bound = (
+                    None,
+                    to_token.get_max_stream_pos(),
+                )
+
         bounds = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token.as_tuple(),
-            to_token=to_token.as_tuple() if to_token else None,
+            from_token=from_bound,
+            to_token=to_bound,
             engine=self.database_engine,
         )
 
@@ -980,7 +1120,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             bounds += " AND " + filter_clause
             args.extend(filter_args)
 
-        args.append(int(limit))
+        # We fetch more events as we'll filter the result set
+        args.append(int(limit) * 2)
 
         select_keywords = "SELECT"
         join_clause = ""
@@ -1002,7 +1143,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
                 select_keywords += "DISTINCT"
 
         sql = """
-            %(select_keywords)s event_id, topological_ordering, stream_ordering
+            %(select_keywords)s
+                event_id, instance_name,
+                topological_ordering, stream_ordering
             FROM events
             %(join_clause)s
             WHERE outlier = ? AND room_id = ? AND %(bounds)s
@@ -1017,7 +1160,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         txn.execute(sql, args)
 
-        rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+        # Filter the result set.
+        rows = [
+            _EventDictReturn(event_id, topological_ordering, stream_ordering)
+            for event_id, instance_name, topological_ordering, stream_ordering in txn
+            if _filter_results(
+                lower_token=to_token if direction == "b" else from_token,
+                upper_token=from_token if direction == "b" else to_token,
+                instance_name=instance_name,
+                topological_ordering=topological_ordering,
+                stream_ordering=stream_ordering,
+            )
+        ][:limit]
 
         if rows:
             topo = rows[-1].topological_ordering
@@ -1082,6 +1236,58 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
 
         return (events, token)
 
+    @cached()
+    async def get_id_for_instance(self, instance_name: str) -> int:
+        """Get a unique, immutable ID that corresponds to the given Synapse worker instance.
+        """
+
+        def _get_id_for_instance_txn(txn):
+            instance_id = self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                retcol="instance_id",
+                allow_none=True,
+            )
+            if instance_id is not None:
+                return instance_id
+
+            # If we don't have an entry upsert one.
+            #
+            # We could do this before the first check, and rely on the cache for
+            # efficiency, but each UPSERT causes the next ID to increment which
+            # can quickly bloat the size of the generated IDs for new instances.
+            self.db_pool.simple_upsert_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                values={},
+            )
+
+            return self.db_pool.simple_select_one_onecol_txn(
+                txn,
+                table="instance_map",
+                keyvalues={"instance_name": instance_name},
+                retcol="instance_id",
+            )
+
+        return await self.db_pool.runInteraction(
+            "get_id_for_instance", _get_id_for_instance_txn
+        )
+
+    @cached()
+    async def get_name_from_instance_id(self, instance_id: int) -> str:
+        """Get the instance name from an ID previously returned by
+        `get_id_for_instance`.
+        """
+
+        return await self.db_pool.simple_select_one_onecol(
+            table="instance_map",
+            keyvalues={"instance_id": instance_id},
+            retcol="instance_name",
+            desc="get_name_from_instance_id",
+        )
+
 
 class StreamStore(StreamWorkerStore):
     def get_room_max_stream_ordering(self) -> int:

+ 111 - 5
synapse/types.py

@@ -22,6 +22,7 @@ from typing import (
     TYPE_CHECKING,
     Any,
     Dict,
+    Iterable,
     Mapping,
     MutableMapping,
     Optional,
@@ -43,7 +44,7 @@ if TYPE_CHECKING:
 if sys.version_info[:3] >= (3, 6, 0):
     from typing import Collection
 else:
-    from typing import Container, Iterable, Sized
+    from typing import Container, Sized
 
     T_co = TypeVar("T_co", covariant=True)
 
@@ -375,7 +376,7 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
     return username.decode("ascii")
 
 
-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, cmp=False)
 class RoomStreamToken:
     """Tokens are positions between events. The token "s1" comes after event 1.
 
@@ -397,6 +398,31 @@ class RoomStreamToken:
     event it comes after. Historic tokens start with a "t" followed by the
     "topological_ordering" id of the event it comes after, followed by "-",
     followed by the "stream_ordering" id of the event it comes after.
+
+    There is also a third mode for live tokens where the token starts with "m",
+    which is sometimes used when using sharded event persisters. In this case
+    the events stream is considered to be a set of streams (one for each writer)
+    and the token encodes the vector clock of positions of each writer in their
+    respective streams.
+
+    The format of the token in such case is an initial integer min position,
+    followed by the mapping of instance ID to position separated by '.' and '~':
+
+        m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}. ...
+
+    The `min_pos` corresponds to the minimum position all writers have persisted
+    up to, and then only writers that are ahead of that position need to be
+    encoded. An example token is:
+
+        m56~2.58~3.59
+
+    Which corresponds to a set of three (or more writers) where instances 2 and
+    3 (these are instance IDs that can be looked up in the DB to fetch the more
+    commonly used instance names) are at positions 58 and 59 respectively, and
+    all other instances are at position 56.
+
+    Note: The `RoomStreamToken` cannot have both a topological part and an
+    instance map.
     """
 
     topological = attr.ib(
@@ -405,6 +431,25 @@ class RoomStreamToken:
     )
     stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
 
+    instance_map = attr.ib(
+        type=Dict[str, int],
+        factory=dict,
+        validator=attr.validators.deep_mapping(
+            key_validator=attr.validators.instance_of(str),
+            value_validator=attr.validators.instance_of(int),
+            mapping_validator=attr.validators.instance_of(dict),
+        ),
+    )
+
+    def __attrs_post_init__(self):
+        """Validates that both `topological` and `instance_map` aren't set.
+        """
+
+        if self.instance_map and self.topological:
+            raise ValueError(
+                "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
+            )
+
     @classmethod
     async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
         try:
@@ -413,6 +458,20 @@ class RoomStreamToken:
             if string[0] == "t":
                 parts = string[1:].split("-", 1)
                 return cls(topological=int(parts[0]), stream=int(parts[1]))
+            if string[0] == "m":
+                parts = string[1:].split("~")
+                stream = int(parts[0])
+
+                instance_map = {}
+                for part in parts[1:]:
+                    key, value = part.split(".")
+                    instance_id = int(key)
+                    pos = int(value)
+
+                    instance_name = await store.get_name_from_instance_id(instance_id)
+                    instance_map[instance_name] = pos
+
+                return cls(topological=None, stream=stream, instance_map=instance_map,)
         except Exception:
             pass
         raise SynapseError(400, "Invalid token %r" % (string,))
@@ -436,14 +495,61 @@ class RoomStreamToken:
 
         max_stream = max(self.stream, other.stream)
 
-        return RoomStreamToken(None, max_stream)
+        instance_map = {
+            instance: max(
+                self.instance_map.get(instance, self.stream),
+                other.instance_map.get(instance, other.stream),
+            )
+            for instance in set(self.instance_map).union(other.instance_map)
+        }
+
+        return RoomStreamToken(None, max_stream, instance_map)
+
+    def as_historical_tuple(self) -> Tuple[int, int]:
+        """Returns a tuple of `(topological, stream)` for historical tokens.
+
+        Raises if not an historical token (i.e. doesn't have a topological part).
+        """
+        if self.topological is None:
+            raise Exception(
+                "Cannot call `RoomStreamToken.as_historical_tuple` on live token"
+            )
 
-    def as_tuple(self) -> Tuple[Optional[int], int]:
         return (self.topological, self.stream)
 
+    def get_stream_pos_for_instance(self, instance_name: str) -> int:
+        """Get the stream position that the given writer was at at this token.
+
+        This only makes sense for "live" tokens that may have a vector clock
+        component, and so asserts that this is a "live" token.
+        """
+        assert self.topological is None
+
+        # If we don't have an entry for the instance we can assume that it was
+        # at `self.stream`.
+        return self.instance_map.get(instance_name, self.stream)
+
+    def get_max_stream_pos(self) -> int:
+        """Get the maximum stream position referenced in this token.
+
+        The corresponding "min" position is, by definition just `self.stream`.
+
+        This is used to handle tokens that have non-empty `instance_map`, and so
+        reference stream positions after the `self.stream` position.
+        """
+        return max(self.instance_map.values(), default=self.stream)
+
     async def to_string(self, store: "DataStore") -> str:
         if self.topological is not None:
             return "t%d-%d" % (self.topological, self.stream)
+        elif self.instance_map:
+            entries = []
+            for name, pos in self.instance_map.items():
+                instance_id = await store.get_id_for_instance(name)
+                entries.append("{}.{}".format(instance_id, pos))
+
+            encoded_map = "~".join(entries)
+            return "m{}~{}".format(self.stream, encoded_map)
         else:
             return "s%d" % (self.stream,)
 
@@ -535,7 +641,7 @@ class PersistedEventPosition:
     stream = attr.ib(type=int)
 
     def persisted_after(self, token: RoomStreamToken) -> bool:
-        return token.stream < self.stream
+        return token.get_stream_pos_for_instance(self.instance_name) < self.stream
 
     def to_room_stream_token(self) -> RoomStreamToken:
         """Converts the position to a room stream token such that events