Sfoglia il codice sorgente

Use a stream id generator for backfilled ids

Mark Haines 8 anni fa
parent
commit
e36bfbab38

+ 7 - 13
synapse/storage/__init__.py

@@ -88,15 +88,6 @@ class DataStore(RoomMemberStore, RoomStore,
         self.hs = hs
         self.database_engine = hs.database_engine
 
-        cur = db_conn.cursor()
-        try:
-            cur.execute("SELECT MIN(stream_ordering) FROM events",)
-            rows = cur.fetchall()
-            self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
-            self.min_stream_token = min(self.min_stream_token, -1)
-        finally:
-            cur.close()
-
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen",
             keylen=4,
@@ -105,6 +96,9 @@ class DataStore(RoomMemberStore, RoomStore,
         self._stream_id_gen = StreamIdGenerator(
             db_conn, "events", "stream_ordering"
         )
+        self._backfill_id_gen = StreamIdGenerator(
+            db_conn, "events", "stream_ordering", direction=-1
+        )
         self._receipts_id_gen = StreamIdGenerator(
             db_conn, "receipts_linearized", "stream_id"
         )
@@ -129,7 +123,7 @@ class DataStore(RoomMemberStore, RoomStore,
             extra_tables=[("deleted_pushers", "stream_id")],
         )
 
-        events_max = self._stream_id_gen.get_max_token()
+        events_max = self._stream_id_gen.get_current_token()
         event_cache_prefill, min_event_val = self._get_cache_dict(
             db_conn, "events",
             entity_column="room_id",
@@ -145,7 +139,7 @@ class DataStore(RoomMemberStore, RoomStore,
             "MembershipStreamChangeCache", events_max,
         )
 
-        account_max = self._account_data_id_gen.get_max_token()
+        account_max = self._account_data_id_gen.get_current_token()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max,
         )
@@ -156,7 +150,7 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "presence_stream",
             entity_column="user_id",
             stream_column="stream_id",
-            max_value=self._presence_id_gen.get_max_token(),
+            max_value=self._presence_id_gen.get_current_token(),
         )
         self.presence_stream_cache = StreamChangeCache(
             "PresenceStreamChangeCache", min_presence_val,
@@ -167,7 +161,7 @@ class DataStore(RoomMemberStore, RoomStore,
             db_conn, "push_rules_stream",
             entity_column="user_id",
             stream_column="stream_id",
-            max_value=self._push_rules_stream_id_gen.get_max_token()[0],
+            max_value=self._push_rules_stream_id_gen.get_current_token()[0],
         )
 
         self.push_rules_stream_cache = StreamChangeCache(

+ 2 - 2
synapse/storage/account_data.py

@@ -200,7 +200,7 @@ class AccountDataStore(SQLBaseStore):
                 "add_room_account_data", add_account_data_txn, next_id
             )
 
-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
     @defer.inlineCallbacks
@@ -239,7 +239,7 @@ class AccountDataStore(SQLBaseStore):
                 "add_user_account_data", add_account_data_txn, next_id
             )
 
-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
     def _update_max_stream_id(self, txn, next_id):

+ 5 - 14
synapse/storage/events.py

@@ -24,7 +24,6 @@ from synapse.util.logutils import log_function
 from synapse.api.constants import EventTypes
 
 from canonicaljson import encode_canonical_json
-from contextlib import contextmanager
 from collections import namedtuple
 
 import logging
@@ -66,14 +65,9 @@ class EventsStore(SQLBaseStore):
             return
 
         if backfilled:
-            start = self.min_stream_token - 1
-            self.min_stream_token -= len(events_and_contexts) + 1
-            stream_orderings = range(start, self.min_stream_token, -1)
-
-            @contextmanager
-            def stream_ordering_manager():
-                yield stream_orderings
-            stream_ordering_manager = stream_ordering_manager()
+            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
+                len(events_and_contexts)
+            )
         else:
             stream_ordering_manager = self._stream_id_gen.get_next_mult(
                 len(events_and_contexts)
@@ -130,7 +124,7 @@ class EventsStore(SQLBaseStore):
         except _RollbackButIsFineException:
             pass
 
-        max_persisted_id = yield self._stream_id_gen.get_max_token()
+        max_persisted_id = yield self._stream_id_gen.get_current_token()
         defer.returnValue((stream_ordering, max_persisted_id))
 
     @defer.inlineCallbacks
@@ -1117,10 +1111,7 @@ class EventsStore(SQLBaseStore):
 
     def get_current_backfill_token(self):
         """The current minimum token that backfilled events have reached"""
-
-        # TODO: Fix race with the persit_event txn by using one of the
-        # stream id managers
-        return -self.min_stream_token
+        return -self._backfill_id_gen.get_current_token()
 
     def get_all_new_events(self, last_backfill_id, last_forward_id,
                            current_backfill_id, current_forward_id, limit):

+ 4 - 2
synapse/storage/presence.py

@@ -68,7 +68,9 @@ class PresenceStore(SQLBaseStore):
                 self._update_presence_txn, stream_orderings, presence_states,
             )
 
-        defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token()))
+        defer.returnValue((
+            stream_orderings[-1], self._presence_id_gen.get_current_token()
+        ))
 
     def _update_presence_txn(self, txn, stream_orderings, presence_states):
         for stream_id, state in zip(stream_orderings, presence_states):
@@ -155,7 +157,7 @@ class PresenceStore(SQLBaseStore):
         defer.returnValue([UserPresenceState(**row) for row in rows])
 
     def get_current_presence_token(self):
-        return self._presence_id_gen.get_max_token()
+        return self._presence_id_gen.get_current_token()
 
     def allow_presence_visible(self, observed_localpart, observer_userid):
         return self._simple_insert(

+ 1 - 1
synapse/storage/push_rule.py

@@ -392,7 +392,7 @@ class PushRuleStore(SQLBaseStore):
         """Get the position of the push rules stream.
         Returns a pair of a stream id for the push_rules stream and the
         room stream ordering it corresponds to."""
-        return self._push_rules_stream_id_gen.get_max_token()
+        return self._push_rules_stream_id_gen.get_current_token()
 
     def have_push_rules_changed_for_user(self, user_id, last_id):
         if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):

+ 1 - 1
synapse/storage/pusher.py

@@ -78,7 +78,7 @@ class PusherStore(SQLBaseStore):
         defer.returnValue(rows)
 
     def get_pushers_stream_token(self):
-        return self._pushers_id_gen.get_max_token()
+        return self._pushers_id_gen.get_current_token()
 
     def get_all_updated_pushers(self, last_id, current_id, limit):
         def get_all_updated_pushers_txn(txn):

+ 3 - 3
synapse/storage/receipts.py

@@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore):
         super(ReceiptsStore, self).__init__(hs)
 
         self._receipts_stream_cache = StreamChangeCache(
-            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token()
+            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
         )
 
     @cached(num_args=2)
@@ -221,7 +221,7 @@ class ReceiptsStore(SQLBaseStore):
         defer.returnValue(results)
 
     def get_max_receipt_stream_id(self):
-        return self._receipts_id_gen.get_max_token()
+        return self._receipts_id_gen.get_current_token()
 
     def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
                                       user_id, event_id, data, stream_id):
@@ -346,7 +346,7 @@ class ReceiptsStore(SQLBaseStore):
             room_id, receipt_type, user_id, event_ids, data
         )
 
-        max_persisted_id = self._stream_id_gen.get_max_token()
+        max_persisted_id = self._stream_id_gen.get_current_token()
 
         defer.returnValue((stream_id, max_persisted_id))
 

+ 1 - 1
synapse/storage/state.py

@@ -458,4 +458,4 @@ class StateStore(SQLBaseStore):
         )
 
     def get_state_stream_token(self):
-        return self._state_groups_id_gen.get_max_token()
+        return self._state_groups_id_gen.get_current_token()

+ 1 - 1
synapse/storage/stream.py

@@ -539,7 +539,7 @@ class StreamStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_room_events_max_id(self, direction='f'):
-        token = yield self._stream_id_gen.get_max_token()
+        token = yield self._stream_id_gen.get_current_token()
         if direction != 'b':
             defer.returnValue("s%d" % (token,))
         else:

+ 3 - 3
synapse/storage/tags.py

@@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore):
         Returns:
             A deferred int.
         """
-        return self._account_data_id_gen.get_max_token()
+        return self._account_data_id_gen.get_current_token()
 
     @cached()
     def get_tags_for_user(self, user_id):
@@ -200,7 +200,7 @@ class TagsStore(SQLBaseStore):
 
         self.get_tags_for_user.invalidate((user_id,))
 
-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
     @defer.inlineCallbacks
@@ -222,7 +222,7 @@ class TagsStore(SQLBaseStore):
 
         self.get_tags_for_user.invalidate((user_id,))
 
-        result = self._account_data_id_gen.get_max_token()
+        result = self._account_data_id_gen.get_current_token()
         defer.returnValue(result)
 
     def _update_revision_txn(self, txn, user_id, room_id, next_id):

+ 41 - 20
synapse/storage/util/id_generators.py

@@ -21,7 +21,7 @@ import threading
 class IdGenerator(object):
     def __init__(self, db_conn, table, column):
         self._lock = threading.Lock()
-        self._next_id = _load_max_id(db_conn, table, column)
+        self._next_id = _load_current_id(db_conn, table, column)
 
     def get_next(self):
         with self._lock:
@@ -29,12 +29,16 @@ class IdGenerator(object):
             return self._next_id
 
 
-def _load_max_id(db_conn, table, column):
+def _load_current_id(db_conn, table, column, direction=1):
     cur = db_conn.cursor()
-    cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+    if direction == 1:
+        cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
+    else:
+        cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
     val, = cur.fetchone()
     cur.close()
-    return int(val) if val else 1
+    current_id = int(val) if val else direction
+    return (max if direction == 1 else min)(current_id, direction)
 
 
 class StreamIdGenerator(object):
@@ -45,17 +49,30 @@ class StreamIdGenerator(object):
     all ids less than or equal to it have completed. This handles the fact that
     persistence of events can complete out of order.
 
+    :param connection db_conn:  A database connection to use to fetch the
+        initial value of the generator from.
+    :param str table: A database table to read the initial value of the id
+        generator from.
+    :param str column: The column of the database table to read the initial
+        value from the id generator from.
+    :param list extra_tables: List of pairs of database tables and columns to
+        use to source the initial value of the generator from. The value with
+        the largest magnitude is used.
+    :param int direction: which direction the stream ids grow in. +1 to grow
+        upwards, -1 to grow downwards.
+
     Usage:
         with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
-    def __init__(self, db_conn, table, column, extra_tables=[]):
+    def __init__(self, db_conn, table, column, extra_tables=[], direction=1):
         self._lock = threading.Lock()
-        self._current_max = _load_max_id(db_conn, table, column)
+        self._direction = direction
+        self._current = _load_current_id(db_conn, table, column, direction)
         for table, column in extra_tables:
-            self._current_max = max(
-                self._current_max,
-                _load_max_id(db_conn, table, column)
+            self._current = (max if direction > 0 else min)(
+                self._current,
+                _load_current_id(db_conn, table, column, direction)
             )
         self._unfinished_ids = deque()
 
@@ -66,8 +83,8 @@ class StreamIdGenerator(object):
                 # ... persist event ...
         """
         with self._lock:
-            self._current_max += 1
-            next_id = self._current_max
+            self._current += self._direction
+            next_id = self._current
 
             self._unfinished_ids.append(next_id)
 
@@ -88,8 +105,12 @@ class StreamIdGenerator(object):
                 # ... persist events ...
         """
         with self._lock:
-            next_ids = range(self._current_max + 1, self._current_max + n + 1)
-            self._current_max += n
+            next_ids = range(
+                self._current + self._direction,
+                self._current + self._direction * (n + 1),
+                self._direction
+            )
+            self._current += n
 
             for next_id in next_ids:
                 self._unfinished_ids.append(next_id)
@@ -105,15 +126,15 @@ class StreamIdGenerator(object):
 
         return manager()
 
-    def get_max_token(self):
+    def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
         """
         with self._lock:
             if self._unfinished_ids:
-                return self._unfinished_ids[0] - 1
+                return self._unfinished_ids[0] - self._direction
 
-            return self._current_max
+            return self._current
 
 
 class ChainedIdGenerator(object):
@@ -125,7 +146,7 @@ class ChainedIdGenerator(object):
     def __init__(self, chained_generator, db_conn, table, column):
         self.chained_generator = chained_generator
         self._lock = threading.Lock()
-        self._current_max = _load_max_id(db_conn, table, column)
+        self._current_max = _load_current_id(db_conn, table, column)
         self._unfinished_ids = deque()
 
     def get_next(self):
@@ -137,7 +158,7 @@ class ChainedIdGenerator(object):
         with self._lock:
             self._current_max += 1
             next_id = self._current_max
-            chained_id = self.chained_generator.get_max_token()
+            chained_id = self.chained_generator.get_current_token()
 
             self._unfinished_ids.append((next_id, chained_id))
 
@@ -151,7 +172,7 @@ class ChainedIdGenerator(object):
 
         return manager()
 
-    def get_max_token(self):
+    def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
         """
@@ -160,4 +181,4 @@ class ChainedIdGenerator(object):
                 stream_id, chained_id = self._unfinished_ids[0]
                 return (stream_id - 1, chained_id)
 
-            return (self._current_max, self.chained_generator.get_max_token())
+            return (self._current_max, self.chained_generator.get_current_token())