Browse Source

Add an approximate difference method to StateFilters (#10825)

reivilibre 2 years ago
parent
commit
406f7bfa17
3 changed files with 683 additions and 3 deletions
  1. 1 0
      changelog.d/10825.misc
  2. 171 1
      synapse/storage/state.py
  3. 511 2
      tests/storage/test_state.py

+ 1 - 0
changelog.d/10825.misc

@@ -0,0 +1 @@
+Add an 'approximate difference' method to `StateFilter`.

+ 171 - 1
synapse/storage/state.py

@@ -15,9 +15,11 @@ import logging
 from typing import (
     TYPE_CHECKING,
     Awaitable,
+    Collection,
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -29,7 +31,7 @@ from frozendict import frozendict
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import MutableStateMap, StateKey, StateMap
 
 if TYPE_CHECKING:
     from typing import FrozenSet  # noqa: used within quoted type hint; flake8 sad
@@ -134,6 +136,23 @@ class StateFilter:
             include_others=True,
         )
 
+    @staticmethod
+    def freeze(types: Mapping[str, Optional[Collection[str]]], include_others: bool):
+        """
+        Returns a (frozen) StateFilter with the same contents as the parameters
+        specified here, which can be made of mutable types.
+        """
+        types_with_frozen_values: Dict[str, Optional[FrozenSet[str]]] = {}
+        for state_types, state_keys in types.items():
+            if state_keys is not None:
+                types_with_frozen_values[state_types] = frozenset(state_keys)
+            else:
+                types_with_frozen_values[state_types] = None
+
+        return StateFilter(
+            frozendict(types_with_frozen_values), include_others=include_others
+        )
+
     def return_expanded(self) -> "StateFilter":
         """Creates a new StateFilter where type wild cards have been removed
         (except for memberships). The returned filter is a superset of the
@@ -356,6 +375,157 @@ class StateFilter:
 
         return member_filter, non_member_filter
 
+    def _decompose_into_four_parts(
+        self,
+    ) -> Tuple[Tuple[bool, Set[str]], Tuple[Set[str], Set[StateKey]]]:
+        """
+        Decomposes this state filter into 4 constituent parts, which can be
+        thought of as this:
+            all? - minus_wildcards + plus_wildcards + plus_state_keys
+
+        where
+        * all represents ALL state
+        * minus_wildcards represents entire state types to remove
+        * plus_wildcards represents entire state types to add
+        * plus_state_keys represents individual state keys to add
+
+        See `recompose_from_four_parts` for the other direction of this
+        correspondence.
+        """
+        is_all = self.include_others
+        excluded_types: Set[str] = {t for t in self.types if is_all}
+        wildcard_types: Set[str] = {t for t, s in self.types.items() if s is None}
+        concrete_keys: Set[StateKey] = set(self.concrete_types())
+
+        return (is_all, excluded_types), (wildcard_types, concrete_keys)
+
+    @staticmethod
+    def _recompose_from_four_parts(
+        all_part: bool,
+        minus_wildcards: Set[str],
+        plus_wildcards: Set[str],
+        plus_state_keys: Set[StateKey],
+    ) -> "StateFilter":
+        """
+        Recomposes a state filter from 4 parts.
+
+        See `decompose_into_four_parts` (the other direction of this
+        correspondence) for descriptions on each of the parts.
+        """
+
+        # {state type -> set of state keys OR None for wildcard}
+        # (The same structure as that of a StateFilter.)
+        new_types: Dict[str, Optional[Set[str]]] = {}
+
+        # if we start with all, insert the excluded statetypes as empty sets
+        # to prevent them from being included
+        if all_part:
+            new_types.update({state_type: set() for state_type in minus_wildcards})
+
+        # insert the plus wildcards
+        new_types.update({state_type: None for state_type in plus_wildcards})
+
+        # insert the specific state keys
+        for state_type, state_key in plus_state_keys:
+            if state_type in new_types:
+                entry = new_types[state_type]
+                if entry is not None:
+                    entry.add(state_key)
+            elif not all_part:
+                # don't insert if the entire type is already included by
+                # include_others as this would actually shrink the state allowed
+                # by this filter.
+                new_types[state_type] = {state_key}
+
+        return StateFilter.freeze(new_types, include_others=all_part)
+
+    def approx_difference(self, other: "StateFilter") -> "StateFilter":
+        """
+        Returns a state filter which represents `self - other`.
+
+        This is useful for determining what state remains to be pulled out of the
+        database if we want the state included by `self` but already have the state
+        included by `other`.
+
+        The returned state filter
+        - MUST include all state events that are included by this filter (`self`)
+          unless they are included by `other`;
+        - MUST NOT include state events not included by this filter (`self`); and
+        - MAY be an over-approximation: the returned state filter
+          MAY additionally include some state events from `other`.
+
+        This implementation attempts to return the narrowest such state filter.
+        In the case that `self` contains wildcards for state types where
+        `other` contains specific state keys, an approximation must be made:
+        the returned state filter keeps the wildcard, as state filters are not
+        able to express 'all state keys except some given examples'.
+        e.g.
+            StateFilter(m.room.member -> None (wildcard))
+                minus
+            StateFilter(m.room.member -> {'@wombat:example.org'})
+                is approximated as
+            StateFilter(m.room.member -> None (wildcard))
+        """
+
+        # We first transform self and other into an alternative representation:
+        #   - whether or not they include all events to begin with ('all')
+        #   - if so, which event types are excluded? ('excludes')
+        #   - which entire event types to include ('wildcards')
+        #   - which concrete state keys to include ('concrete state keys')
+        (self_all, self_excludes), (
+            self_wildcards,
+            self_concrete_keys,
+        ) = self._decompose_into_four_parts()
+        (other_all, other_excludes), (
+            other_wildcards,
+            other_concrete_keys,
+        ) = other._decompose_into_four_parts()
+
+        # Start with an estimate of the difference based on self
+        new_all = self_all
+        # Wildcards from the other can be added to the exclusion filter
+        new_excludes = self_excludes | other_wildcards
+        # We remove wildcards that appeared as wildcards in the other
+        new_wildcards = self_wildcards - other_wildcards
+        # We filter out the concrete state keys that appear in the other
+        # as wildcards or concrete state keys.
+        new_concrete_keys = {
+            (state_type, state_key)
+            for (state_type, state_key) in self_concrete_keys
+            if state_type not in other_wildcards
+        } - other_concrete_keys
+
+        if other_all:
+            if self_all:
+                # If self starts with all, then we add as wildcards any
+                # types which appear in the other's exclusion filter (but
+                # aren't in the self exclusion filter). This is as the other
+                # filter will return everything BUT the types in its exclusion, so
+                # we need to add those excluded types that also match the self
+                # filter as wildcard types in the new filter.
+                new_wildcards |= other_excludes.difference(self_excludes)
+
+            # If other is an `include_others` then the difference isn't.
+            new_all = False
+            # (We have no need for excludes when we don't start with all, as there
+            #  is nothing to exclude.)
+            new_excludes = set()
+
+            # We also filter out all state types that aren't in the exclusion
+            # list of the other.
+            new_wildcards &= other_excludes
+            new_concrete_keys = {
+                (state_type, state_key)
+                for (state_type, state_key) in new_concrete_keys
+                if state_type in other_excludes
+            }
+
+        # Transform our newly-constructed state filter from the alternative
+        # representation back into the normal StateFilter representation.
+        return StateFilter._recompose_from_four_parts(
+            new_all, new_excludes, new_wildcards, new_concrete_keys
+        )
+
 
 class StateGroupStorage:
     """High level interface to fetching state for event."""

+ 511 - 2
tests/storage/test_state.py

@@ -21,7 +21,7 @@ from synapse.api.room_versions import RoomVersions
 from synapse.storage.state import StateFilter
 from synapse.types import RoomID, UserID
 
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, TestCase
 
 logger = logging.getLogger(__name__)
 
@@ -105,7 +105,6 @@ class StateStoreTestCase(HomeserverTestCase):
         self.assertEqual({ev.event_id for ev in state_list}, {e1.event_id, e2.event_id})
 
     def test_get_state_for_event(self):
-
         # this defaults to a linear DAG as each new injection defaults to whatever
         # forward extremities are currently in the DB for this room.
         e1 = self.inject_state_event(self.room, self.u_alice, EventTypes.Create, "", {})
@@ -483,3 +482,513 @@ class StateStoreTestCase(HomeserverTestCase):
 
         self.assertEqual(is_all, True)
         self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
+
+
+class StateFilterDifferenceTestCase(TestCase):
+    def assert_difference(
+        self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
+    ):
+        self.assertEqual(
+            minuend.approx_difference(subtrahend),
+            expected,
+            f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
+        )
+
+    def test_state_filter_difference_no_include_other_minus_no_include_other(self):
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), both a and b do not have the
+        include_others flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=False,
+            ),
+            StateFilter.freeze({EventTypes.Create: None}, include_others=False),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+            StateFilter.freeze(
+                {EventTypes.Member: {"@wombat:spqr"}},
+                include_others=False,
+            ),
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.CanonicalAlias: {""}},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+        )
+
+    def test_state_filter_difference_include_other_minus_no_include_other(self):
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), only a has the include_others flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Create: None,
+                    EventTypes.Member: set(),
+                    EventTypes.CanonicalAlias: set(),
+                },
+                include_others=True,
+            ),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        # This also shows that the resultant state filter is normalised.
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                    EventTypes.Create: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter(types=frozendict(), include_others=True),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter(
+                types=frozendict(),
+                include_others=True,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.CanonicalAlias: {""},
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+        )
+
+    def test_state_filter_difference_include_other_minus_include_other(self):
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), both a and b have the include_others
+        flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=True,
+            ),
+            StateFilter(types=frozendict(), include_others=False),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=False,
+            ),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter(
+                types=frozendict(),
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                    EventTypes.Create: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                    EventTypes.Create: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                    EventTypes.Create: {""},
+                },
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                },
+                include_others=False,
+            ),
+        )
+
+    def test_state_filter_difference_no_include_other_minus_include_other(self):
+        """
+        Tests the StateFilter.approx_difference method
+        where, in a.approx_difference(b), only b has the include_others flag set.
+        """
+        # (wildcard on state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.Create: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+                include_others=True,
+            ),
+            StateFilter(types=frozendict(), include_others=False),
+        )
+
+        # (wildcard on state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+            StateFilter.freeze(
+                {EventTypes.Member: {"@wombat:spqr"}},
+                include_others=True,
+            ),
+            StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+        )
+
+        # (wildcard on state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (wildcard on state keys):
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {EventTypes.Member: None},
+                include_others=True,
+            ),
+            StateFilter(
+                types=frozendict(),
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (specific state keys)
+        # This one is an over-approximation because we can't represent
+        # 'all state keys except a few named examples'
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr"},
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@spqr:spqr"},
+                },
+                include_others=False,
+            ),
+        )
+
+        # (specific state keys) - (no state keys)
+        self.assert_difference(
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                    EventTypes.CanonicalAlias: {""},
+                },
+                include_others=False,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: set(),
+                },
+                include_others=True,
+            ),
+            StateFilter.freeze(
+                {
+                    EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+                },
+                include_others=False,
+            ),
+        )
+
+    def test_state_filter_difference_simple_cases(self):
+        """
+        Tests some very simple cases of the StateFilter approx_difference,
+        that are not explicitly tested by the more in-depth tests.
+        """
+
+        self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
+
+        self.assert_difference(
+            StateFilter.all(),
+            StateFilter.none(),
+            StateFilter.all(),
+        )