Browse Source

Improve get auth chain difference algorithm. (#7095)

It was originally implemented by pulling the full auth chain of all
state sets out of the database and doing set comparison. However, that
can take a lot work if the state and auth chains are large.

Instead, lets try and fetch the auth chains at the same time and
calculate the difference on the fly, allowing us to bail early if all
the auth chains converge. Assuming that the auth chains do converge more
often than not, this should improve performance. Hopefully.
Erik Johnston 4 years ago
parent
commit
4a17a647a9

+ 1 - 0
changelog.d/7095.misc

@@ -0,0 +1 @@
+Attempt to improve performance of state res v2 algorithm.

+ 8 - 20
synapse/state/__init__.py

@@ -662,28 +662,16 @@ class StateResolutionStore(object):
             allow_rejected=allow_rejected,
         )
 
-    def get_auth_chain(self, event_ids: List[str], ignore_events: Set[str]):
-        """Gets the full auth chain for a set of events (including rejected
-        events).
-
-        Includes the given event IDs in the result.
-
-        Note that:
-            1. All events must be state events.
-            2. For v1 rooms this may not have the full auth chain in the
-               presence of rejected events
-
-        Args:
-            event_ids: The event IDs of the events to fetch the auth chain for.
-                Must be state events.
-            ignore_events: Set of events to exclude from the returned auth
-                chain.
+    def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+        """Given sets of state events figure out the auth chain difference (as
+        per state res v2 algorithm).
 
+        This equivalent to fetching the full auth chain for each set of state
+        and returning the events that don't appear in each and every auth
+        chain.
 
         Returns:
-            Deferred[list[str]]: List of event IDs of the auth chain.
+            Deferred[Set[str]]: Set of event IDs.
         """
 
-        return self.store.get_auth_chain_ids(
-            event_ids, include_given=True, ignore_events=ignore_events,
-        )
+        return self.store.get_auth_chain_difference(state_sets)

+ 4 - 28
synapse/state/v2.py

@@ -227,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
     Returns:
         Deferred[set[str]]: Set of event IDs
     """
-    common = set(itervalues(state_sets[0])).intersection(
-        *(itervalues(s) for s in state_sets[1:])
-    )
-
-    auth_sets = []
-    for state_set in state_sets:
-        auth_ids = {
-            eid
-            for key, eid in iteritems(state_set)
-            if (
-                key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
-                or key
-                in (
-                    (EventTypes.PowerLevels, ""),
-                    (EventTypes.Create, ""),
-                    (EventTypes.JoinRules, ""),
-                )
-            )
-            and eid not in common
-        }
 
-        auth_chain = yield state_res_store.get_auth_chain(auth_ids, common)
-        auth_ids.update(auth_chain)
-
-        auth_sets.append(auth_ids)
-
-    intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
-    union = set().union(*auth_sets)
+    difference = yield state_res_store.get_auth_chain_difference(
+        [set(state_set.values()) for state_set in state_sets]
+    )
 
-    return union - intersection
+    return difference
 
 
 def _seperate(state_sets):

+ 149 - 1
synapse/storage/data_stores/main/event_federation.py

@@ -14,7 +14,7 @@
 # limitations under the License.
 import itertools
 import logging
-from typing import List, Optional, Set
+from typing import Dict, List, Optional, Set, Tuple
 
 from six.moves.queue import Empty, PriorityQueue
 
@@ -103,6 +103,154 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return list(results)
 
+    def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+        """Given sets of state events figure out the auth chain difference (as
+        per state res v2 algorithm).
+
+        This equivalent to fetching the full auth chain for each set of state
+        and returning the events that don't appear in each and every auth
+        chain.
+
+        Returns:
+            Deferred[Set[str]]
+        """
+
+        return self.db.runInteraction(
+            "get_auth_chain_difference",
+            self._get_auth_chain_difference_txn,
+            state_sets,
+        )
+
+    def _get_auth_chain_difference_txn(
+        self, txn, state_sets: List[Set[str]]
+    ) -> Set[str]:
+
+        # Algorithm Description
+        # ~~~~~~~~~~~~~~~~~~~~~
+        #
+        # The idea here is to basically walk the auth graph of each state set in
+        # tandem, keeping track of which auth events are reachable by each state
+        # set. If we reach an auth event we've already visited (via a different
+        # state set) then we mark that auth event and all ancestors as reachable
+        # by the state set. This requires that we keep track of the auth chains
+        # in memory.
+        #
+        # Doing it in a such a way means that we can stop early if all auth
+        # events we're currently walking are reachable by all state sets.
+        #
+        # *Note*: We can't stop walking an event's auth chain if it is reachable
+        # by all state sets. This is because other auth chains we're walking
+        # might be reachable only via the original auth chain. For example,
+        # given the following auth chain:
+        #
+        #       A -> C -> D -> E
+        #           /         /
+        #       B -´---------´
+        #
+        # and state sets {A} and {B} then walking the auth chains of A and B
+        # would immediately show that C is reachable by both. However, if we
+        # stopped at C then we'd only reach E via the auth chain of B and so E
+        # would errornously get included in the returned difference.
+        #
+        # The other thing that we do is limit the number of auth chains we walk
+        # at once, due to practical limits (i.e. we can only query the database
+        # with a limited set of parameters). We pick the auth chains we walk
+        # each iteration based on their depth, in the hope that events with a
+        # lower depth are likely reachable by those with higher depths.
+        #
+        # We could use any ordering that we believe would give a rough
+        # topological ordering, e.g. origin server timestamp. If the ordering
+        # chosen is not topological then the algorithm still produces the right
+        # result, but perhaps a bit more inefficiently. This is why it is safe
+        # to use "depth" here.
+
+        initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+        # Dict from events in auth chains to which sets *cannot* reach them.
+        # I.e. if the set is empty then all sets can reach the event.
+        event_to_missing_sets = {
+            event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
+            for event_id in initial_events
+        }
+
+        # We need to get the depth of the initial events for sorting purposes.
+        sql = """
+            SELECT depth, event_id FROM events
+            WHERE %s
+            ORDER BY depth ASC
+        """
+        clause, args = make_in_list_sql_clause(
+            txn.database_engine, "event_id", initial_events
+        )
+        txn.execute(sql % (clause,), args)
+
+        # The sorted list of events whose auth chains we should walk.
+        search = txn.fetchall()  # type: List[Tuple[int, str]]
+
+        # Map from event to its auth events
+        event_to_auth_events = {}  # type: Dict[str, Set[str]]
+
+        base_sql = """
+            SELECT a.event_id, auth_id, depth
+            FROM event_auth AS a
+            INNER JOIN events AS e ON (e.event_id = a.auth_id)
+            WHERE
+        """
+
+        while search:
+            # Check whether all our current walks are reachable by all state
+            # sets. If so we can bail.
+            if all(not event_to_missing_sets[eid] for _, eid in search):
+                break
+
+            # Fetch the auth events and their depths of the N last events we're
+            # currently walking
+            search, chunk = search[:-100], search[-100:]
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
+            )
+            txn.execute(base_sql + clause, args)
+
+            for event_id, auth_event_id, auth_event_depth in txn:
+                event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
+
+                sets = event_to_missing_sets.get(auth_event_id)
+                if sets is None:
+                    # First time we're seeing this event, so we add it to the
+                    # queue of things to fetch.
+                    search.append((auth_event_depth, auth_event_id))
+
+                    # Assume that this event is unreachable from any of the
+                    # state sets until proven otherwise
+                    sets = event_to_missing_sets[auth_event_id] = set(
+                        range(len(state_sets))
+                    )
+                else:
+                    # We've previously seen this event, so look up its auth
+                    # events and recursively mark all ancestors as reachable
+                    # by the current event's state set.
+                    a_ids = event_to_auth_events.get(auth_event_id)
+                    while a_ids:
+                        new_aids = set()
+                        for a_id in a_ids:
+                            event_to_missing_sets[a_id].intersection_update(
+                                event_to_missing_sets[event_id]
+                            )
+
+                            b = event_to_auth_events.get(a_id)
+                            if b:
+                                new_aids.update(b)
+
+                        a_ids = new_aids
+
+                # Mark that the auth event is reachable by the approriate sets.
+                sets.intersection_update(event_to_missing_sets[event_id])
+
+            search.sort()
+
+        # Return all events where not all sets can reach them.
+        return {eid for eid, n in event_to_missing_sets.items() if n}
+
     def get_oldest_events_in_room(self, room_id):
         return self.db.runInteraction(
             "get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id

+ 8 - 5
tests/state/test_v2.py

@@ -603,7 +603,7 @@ class TestStateResolutionStore(object):
 
         return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
 
-    def get_auth_chain(self, event_ids, ignore_events):
+    def _get_auth_chain(self, event_ids):
         """Gets the full auth chain for a set of events (including rejected
         events).
 
@@ -617,9 +617,6 @@ class TestStateResolutionStore(object):
         Args:
             event_ids (list): The event IDs of the events to fetch the auth
                 chain for. Must be state events.
-            ignore_events: Set of events to exclude from the returned auth
-                chain.
-
         Returns:
             Deferred[list[str]]: List of event IDs of the auth chain.
         """
@@ -629,7 +626,7 @@ class TestStateResolutionStore(object):
         stack = list(event_ids)
         while stack:
             event_id = stack.pop()
-            if event_id in result or event_id in ignore_events:
+            if event_id in result:
                 continue
 
             result.add(event_id)
@@ -639,3 +636,9 @@ class TestStateResolutionStore(object):
                 stack.append(aid)
 
         return list(result)
+
+    def get_auth_chain_difference(self, auth_sets):
+        chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
+
+        common = set(chains[0]).intersection(*chains[1:])
+        return set(chains[0]).union(*chains[1:]) - common

+ 140 - 17
tests/storage/test_event_federation.py

@@ -13,19 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 import tests.unittest
 import tests.utils
 
 
-class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield tests.utils.setup_test_homeserver(self.addCleanup)
+class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
     def test_get_prev_events_for_room(self):
         room_id = "@ROOM:local"
 
@@ -61,15 +56,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
             )
 
         for i in range(0, 20):
-            yield self.store.db.runInteraction("insert", insert_event, i)
+            self.get_success(self.store.db.runInteraction("insert", insert_event, i))
 
         # this should get the last ten
-        r = yield self.store.get_prev_events_for_room(room_id)
+        r = self.get_success(self.store.get_prev_events_for_room(room_id))
         self.assertEqual(10, len(r))
         for i in range(0, 10):
             self.assertEqual("$event_%i:local" % (19 - i), r[i])
 
-    @defer.inlineCallbacks
     def test_get_rooms_with_many_extremities(self):
         room1 = "#room1"
         room2 = "#room2"
@@ -86,25 +80,154 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
             )
 
         for i in range(0, 20):
-            yield self.store.db.runInteraction("insert", insert_event, i, room1)
-            yield self.store.db.runInteraction("insert", insert_event, i, room2)
-            yield self.store.db.runInteraction("insert", insert_event, i, room3)
+            self.get_success(
+                self.store.db.runInteraction("insert", insert_event, i, room1)
+            )
+            self.get_success(
+                self.store.db.runInteraction("insert", insert_event, i, room2)
+            )
+            self.get_success(
+                self.store.db.runInteraction("insert", insert_event, i, room3)
+            )
 
         # Test simple case
-        r = yield self.store.get_rooms_with_many_extremities(5, 5, [])
+        r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, []))
         self.assertEqual(len(r), 3)
 
         # Does filter work?
 
-        r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1])
+        r = self.get_success(self.store.get_rooms_with_many_extremities(5, 5, [room1]))
         self.assertTrue(room2 in r)
         self.assertTrue(room3 in r)
         self.assertEqual(len(r), 2)
 
-        r = yield self.store.get_rooms_with_many_extremities(5, 5, [room1, room2])
+        r = self.get_success(
+            self.store.get_rooms_with_many_extremities(5, 5, [room1, room2])
+        )
         self.assertEqual(r, [room3])
 
         # Does filter and limit work?
 
-        r = yield self.store.get_rooms_with_many_extremities(5, 1, [room1])
+        r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1]))
         self.assertTrue(r == [room2] or r == [room3])
+
+    def test_auth_difference(self):
+        room_id = "@ROOM:local"
+
+        # The silly auth graph we use to test the auth difference algorithm,
+        # where the top are the most recent events.
+        #
+        #   A   B
+        #    \ /
+        #  D  E
+        #  \  |
+        #   ` F   C
+        #     |  /|
+        #     G ´ |
+        #     | \ |
+        #     H   I
+        #     |   |
+        #     K   J
+
+        auth_graph = {
+            "a": ["e"],
+            "b": ["e"],
+            "c": ["g", "i"],
+            "d": ["f"],
+            "e": ["f"],
+            "f": ["g"],
+            "g": ["h", "i"],
+            "h": ["k"],
+            "i": ["j"],
+            "k": [],
+            "j": [],
+        }
+
+        depth_map = {
+            "a": 7,
+            "b": 7,
+            "c": 4,
+            "d": 6,
+            "e": 6,
+            "f": 5,
+            "g": 3,
+            "h": 2,
+            "i": 2,
+            "k": 1,
+            "j": 1,
+        }
+
+        # We rudely fiddle with the appropriate tables directly, as that's much
+        # easier than constructing events properly.
+
+        def insert_event(txn, event_id, stream_ordering):
+
+            depth = depth_map[event_id]
+
+            self.store.db.simple_insert_txn(
+                txn,
+                table="events",
+                values={
+                    "event_id": event_id,
+                    "room_id": room_id,
+                    "depth": depth,
+                    "topological_ordering": depth,
+                    "type": "m.test",
+                    "processed": True,
+                    "outlier": False,
+                    "stream_ordering": stream_ordering,
+                },
+            )
+
+            self.store.db.simple_insert_many_txn(
+                txn,
+                table="event_auth",
+                values=[
+                    {"event_id": event_id, "room_id": room_id, "auth_id": a}
+                    for a in auth_graph[event_id]
+                ],
+            )
+
+        next_stream_ordering = 0
+        for event_id in auth_graph:
+            next_stream_ordering += 1
+            self.get_success(
+                self.store.db.runInteraction(
+                    "insert", insert_event, event_id, next_stream_ordering
+                )
+            )
+
+        # Now actually test that various combinations give the right result:
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference([{"a"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "d", "e"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
+        )
+        self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
+
+        difference = self.get_success(
+            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
+        self.assertSetEqual(difference, set())