Browse Source

Make StateFilter frozen so we can hash it (#10816)

Also enables Mypy for related tests.
reivilibre 2 years ago
parent
commit
8eb7cb2e0d
4 changed files with 63 additions and 30 deletions
  1. 1 0
      changelog.d/10816.misc
  2. 1 0
      mypy.ini
  3. 32 13
      synapse/storage/state.py
  4. 29 17
      tests/storage/test_state.py

+ 1 - 0
changelog.d/10816.misc

@@ -0,0 +1 @@
+Make `StateFilter` frozen so it is hashable.

+ 1 - 0
mypy.ini

@@ -86,6 +86,7 @@ files =
   tests/handlers/test_sync.py,
   tests/rest/client/test_login.py,
   tests/rest/client/test_auth.py,
+  tests/storage/test_state.py,
   tests/util/test_itertools.py,
   tests/util/test_stream_change_cache.py
 

+ 32 - 13
synapse/storage/state.py

@@ -25,12 +25,15 @@ from typing import (
 )
 
 import attr
+from frozendict import frozendict
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
 
 if TYPE_CHECKING:
+    from typing import FrozenSet  # noqa: used within quoted type hint; flake8 sad
+
     from synapse.server import HomeServer
     from synapse.storage.databases import Databases
 
@@ -40,7 +43,7 @@ logger = logging.getLogger(__name__)
 T = TypeVar("T")
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, frozen=True)
 class StateFilter:
     """A filter used when querying for state.
 
@@ -53,14 +56,19 @@ class StateFilter:
             appear in `types`.
     """
 
-    types = attr.ib(type=Dict[str, Optional[Set[str]]])
+    types = attr.ib(type="frozendict[str, Optional[FrozenSet[str]]]")
     include_others = attr.ib(default=False, type=bool)
 
     def __attrs_post_init__(self):
         # If `include_others` is set we canonicalise the filter by removing
         # wildcards from the types dictionary
         if self.include_others:
-            self.types = {k: v for k, v in self.types.items() if v is not None}
+            # this is needed to work around the fact that StateFilter is frozen
+            object.__setattr__(
+                self,
+                "types",
+                frozendict({k: v for k, v in self.types.items() if v is not None}),
+            )
 
     @staticmethod
     def all() -> "StateFilter":
@@ -69,7 +77,7 @@ class StateFilter:
         Returns:
             The new state filter.
         """
-        return StateFilter(types={}, include_others=True)
+        return StateFilter(types=frozendict(), include_others=True)
 
     @staticmethod
     def none() -> "StateFilter":
@@ -78,7 +86,7 @@ class StateFilter:
         Returns:
             The new state filter.
         """
-        return StateFilter(types={}, include_others=False)
+        return StateFilter(types=frozendict(), include_others=False)
 
     @staticmethod
     def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
@@ -103,7 +111,12 @@ class StateFilter:
 
             type_dict.setdefault(typ, set()).add(s)  # type: ignore
 
-        return StateFilter(types=type_dict)
+        return StateFilter(
+            types=frozendict(
+                (k, frozenset(v) if v is not None else None)
+                for k, v in type_dict.items()
+            )
+        )
 
     @staticmethod
     def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
@@ -116,7 +129,10 @@ class StateFilter:
         Returns:
             The new state filter
         """
-        return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
+        return StateFilter(
+            types=frozendict({EventTypes.Member: frozenset(members)}),
+            include_others=True,
+        )
 
     def return_expanded(self) -> "StateFilter":
         """Creates a new StateFilter where type wild cards have been removed
@@ -173,7 +189,7 @@ class StateFilter:
             # We want to return all non-members, but only particular
             # memberships
             return StateFilter(
-                types={EventTypes.Member: self.types[EventTypes.Member]},
+                types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
                 include_others=True,
             )
 
@@ -245,14 +261,15 @@ class StateFilter:
 
         return len(self.concrete_types())
 
-    def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
-        """Returns the state filtered with by this StateFilter
+    def filter_state(self, state_dict: StateMap[T]) -> MutableStateMap[T]:
+        """Returns the state filtered with by this StateFilter.
 
         Args:
             state: The state map to filter
 
         Returns:
-            The filtered state map
+            The filtered state map.
+            This is a copy, so it's safe to mutate.
         """
         if self.is_full():
             return dict(state_dict)
@@ -324,14 +341,16 @@ class StateFilter:
             if state_keys is None:
                 member_filter = StateFilter.all()
             else:
-                member_filter = StateFilter({EventTypes.Member: state_keys})
+                member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
         elif self.include_others:
             member_filter = StateFilter.all()
         else:
             member_filter = StateFilter.none()
 
         non_member_filter = StateFilter(
-            types={k: v for k, v in self.types.items() if k != EventTypes.Member},
+            types=frozendict(
+                {k: v for k, v in self.types.items() if k != EventTypes.Member}
+            ),
             include_others=self.include_others,
         )
 

+ 29 - 17
tests/storage/test_state.py

@@ -14,6 +14,8 @@
 
 import logging
 
+from frozendict import frozendict
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.storage.state import StateFilter
@@ -183,7 +185,9 @@ class StateStoreTestCase(HomeserverTestCase):
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
-                    types={EventTypes.Member: {self.u_alice.to_string()}},
+                    types=frozendict(
+                        {EventTypes.Member: frozenset({self.u_alice.to_string()})}
+                    ),
                     include_others=True,
                 ),
             )
@@ -203,7 +207,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.storage.state.get_state_for_event(
                 e5.event_id,
                 state_filter=StateFilter(
-                    types={EventTypes.Member: set()}, include_others=True
+                    types=frozendict({EventTypes.Member: frozenset()}),
+                    include_others=True,
                 ),
             )
         )
@@ -228,7 +233,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset()}), include_others=True
             ),
         )
 
@@ -245,7 +250,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset()}), include_others=True
             ),
         )
 
@@ -258,7 +263,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
 
@@ -275,7 +280,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
 
@@ -295,7 +300,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=True,
             ),
         )
 
@@ -312,7 +318,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=True,
             ),
         )
 
@@ -325,7 +332,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=False
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=False,
             ),
         )
 
@@ -375,7 +383,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset()}), include_others=True
             ),
         )
 
@@ -387,7 +395,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: set()}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset()}), include_others=True
             ),
         )
 
@@ -400,7 +408,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
 
@@ -411,7 +419,7 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: None}, include_others=True
+                types=frozendict({EventTypes.Member: None}), include_others=True
             ),
         )
 
@@ -430,7 +438,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=True,
             ),
         )
 
@@ -441,7 +450,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=True
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=True,
             ),
         )
 
@@ -454,7 +464,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=False
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=False,
             ),
         )
 
@@ -465,7 +476,8 @@ class StateStoreTestCase(HomeserverTestCase):
             self.state_datastore._state_group_members_cache,
             group,
             state_filter=StateFilter(
-                types={EventTypes.Member: {e5.state_key}}, include_others=False
+                types=frozendict({EventTypes.Member: frozenset({e5.state_key})}),
+                include_others=False,
             ),
         )