|
@@ -15,7 +15,17 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
import json
|
|
|
-from typing import List
|
|
|
+from typing import (
|
|
|
+ TYPE_CHECKING,
|
|
|
+ Awaitable,
|
|
|
+ Container,
|
|
|
+ Iterable,
|
|
|
+ List,
|
|
|
+ Optional,
|
|
|
+ Set,
|
|
|
+ TypeVar,
|
|
|
+ Union,
|
|
|
+)
|
|
|
|
|
|
import jsonschema
|
|
|
from jsonschema import FormatChecker
|
|
@@ -23,7 +33,11 @@ from jsonschema import FormatChecker
|
|
|
from synapse.api.constants import EventContentFields
|
|
|
from synapse.api.errors import SynapseError
|
|
|
from synapse.api.presence import UserPresenceState
|
|
|
-from synapse.types import RoomID, UserID
|
|
|
+from synapse.events import EventBase
|
|
|
+from synapse.types import JsonDict, RoomID, UserID
|
|
|
+
|
|
|
+if TYPE_CHECKING:
|
|
|
+ from synapse.server import HomeServer
|
|
|
|
|
|
FILTER_SCHEMA = {
|
|
|
"additionalProperties": False,
|
|
@@ -120,25 +134,29 @@ USER_FILTER_SCHEMA = {
|
|
|
|
|
|
|
|
|
@FormatChecker.cls_checks("matrix_room_id")
|
|
|
-def matrix_room_id_validator(room_id_str):
|
|
|
+def matrix_room_id_validator(room_id_str: str) -> RoomID:
|
|
|
return RoomID.from_string(room_id_str)
|
|
|
|
|
|
|
|
|
@FormatChecker.cls_checks("matrix_user_id")
|
|
|
-def matrix_user_id_validator(user_id_str):
|
|
|
+def matrix_user_id_validator(user_id_str: str) -> UserID:
|
|
|
return UserID.from_string(user_id_str)
|
|
|
|
|
|
|
|
|
class Filtering:
|
|
|
- def __init__(self, hs):
|
|
|
+ def __init__(self, hs: "HomeServer"):
|
|
|
super().__init__()
|
|
|
self.store = hs.get_datastore()
|
|
|
|
|
|
- async def get_user_filter(self, user_localpart, filter_id):
|
|
|
+ async def get_user_filter(
|
|
|
+ self, user_localpart: str, filter_id: Union[int, str]
|
|
|
+ ) -> "FilterCollection":
|
|
|
result = await self.store.get_user_filter(user_localpart, filter_id)
|
|
|
return FilterCollection(result)
|
|
|
|
|
|
- def add_user_filter(self, user_localpart, user_filter):
|
|
|
+ def add_user_filter(
|
|
|
+ self, user_localpart: str, user_filter: JsonDict
|
|
|
+ ) -> Awaitable[int]:
|
|
|
self.check_valid_filter(user_filter)
|
|
|
return self.store.add_user_filter(user_localpart, user_filter)
|
|
|
|
|
@@ -146,13 +164,13 @@ class Filtering:
|
|
|
# replace_user_filter at some point? There's no REST API specified for
|
|
|
# them however
|
|
|
|
|
|
- def check_valid_filter(self, user_filter_json):
|
|
|
+ def check_valid_filter(self, user_filter_json: JsonDict) -> None:
|
|
|
"""Check if the provided filter is valid.
|
|
|
|
|
|
This inspects all definitions contained within the filter.
|
|
|
|
|
|
Args:
|
|
|
- user_filter_json(dict): The filter
|
|
|
+ user_filter_json: The filter
|
|
|
Raises:
|
|
|
SynapseError: If the filter is not valid.
|
|
|
"""
|
|
@@ -167,8 +185,12 @@ class Filtering:
|
|
|
raise SynapseError(400, str(e))
|
|
|
|
|
|
|
|
|
+# Filters work across events, presence EDUs, and account data.
|
|
|
+FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
|
|
|
+
|
|
|
+
|
|
|
class FilterCollection:
|
|
|
- def __init__(self, filter_json):
|
|
|
+ def __init__(self, filter_json: JsonDict):
|
|
|
self._filter_json = filter_json
|
|
|
|
|
|
room_filter_json = self._filter_json.get("room", {})
|
|
@@ -188,25 +210,25 @@ class FilterCollection:
|
|
|
self.event_fields = filter_json.get("event_fields", [])
|
|
|
self.event_format = filter_json.get("event_format", "client")
|
|
|
|
|
|
- def __repr__(self):
|
|
|
+ def __repr__(self) -> str:
|
|
|
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
|
|
|
|
|
|
- def get_filter_json(self):
|
|
|
+ def get_filter_json(self) -> JsonDict:
|
|
|
return self._filter_json
|
|
|
|
|
|
- def timeline_limit(self):
|
|
|
+ def timeline_limit(self) -> int:
|
|
|
return self._room_timeline_filter.limit()
|
|
|
|
|
|
- def presence_limit(self):
|
|
|
+ def presence_limit(self) -> int:
|
|
|
return self._presence_filter.limit()
|
|
|
|
|
|
- def ephemeral_limit(self):
|
|
|
+ def ephemeral_limit(self) -> int:
|
|
|
return self._room_ephemeral_filter.limit()
|
|
|
|
|
|
- def lazy_load_members(self):
|
|
|
+ def lazy_load_members(self) -> bool:
|
|
|
return self._room_state_filter.lazy_load_members()
|
|
|
|
|
|
- def include_redundant_members(self):
|
|
|
+ def include_redundant_members(self) -> bool:
|
|
|
return self._room_state_filter.include_redundant_members()
|
|
|
|
|
|
def filter_presence(self, events):
|
|
@@ -218,29 +240,31 @@ class FilterCollection:
|
|
|
def filter_room_state(self, events):
|
|
|
return self._room_state_filter.filter(self._room_filter.filter(events))
|
|
|
|
|
|
- def filter_room_timeline(self, events):
|
|
|
+ def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
|
|
|
return self._room_timeline_filter.filter(self._room_filter.filter(events))
|
|
|
|
|
|
- def filter_room_ephemeral(self, events):
|
|
|
+ def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
|
|
|
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
|
|
|
|
|
|
- def filter_room_account_data(self, events):
|
|
|
+ def filter_room_account_data(
|
|
|
+ self, events: Iterable[FilterEvent]
|
|
|
+ ) -> List[FilterEvent]:
|
|
|
return self._room_account_data.filter(self._room_filter.filter(events))
|
|
|
|
|
|
- def blocks_all_presence(self):
|
|
|
+ def blocks_all_presence(self) -> bool:
|
|
|
return (
|
|
|
self._presence_filter.filters_all_types()
|
|
|
or self._presence_filter.filters_all_senders()
|
|
|
)
|
|
|
|
|
|
- def blocks_all_room_ephemeral(self):
|
|
|
+ def blocks_all_room_ephemeral(self) -> bool:
|
|
|
return (
|
|
|
self._room_ephemeral_filter.filters_all_types()
|
|
|
or self._room_ephemeral_filter.filters_all_senders()
|
|
|
or self._room_ephemeral_filter.filters_all_rooms()
|
|
|
)
|
|
|
|
|
|
- def blocks_all_room_timeline(self):
|
|
|
+ def blocks_all_room_timeline(self) -> bool:
|
|
|
return (
|
|
|
self._room_timeline_filter.filters_all_types()
|
|
|
or self._room_timeline_filter.filters_all_senders()
|
|
@@ -249,7 +273,7 @@ class FilterCollection:
|
|
|
|
|
|
|
|
|
class Filter:
|
|
|
- def __init__(self, filter_json):
|
|
|
+ def __init__(self, filter_json: JsonDict):
|
|
|
self.filter_json = filter_json
|
|
|
|
|
|
self.types = self.filter_json.get("types", None)
|
|
@@ -266,20 +290,20 @@ class Filter:
|
|
|
self.labels = self.filter_json.get("org.matrix.labels", None)
|
|
|
self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
|
|
|
|
|
|
- def filters_all_types(self):
|
|
|
+ def filters_all_types(self) -> bool:
|
|
|
return "*" in self.not_types
|
|
|
|
|
|
- def filters_all_senders(self):
|
|
|
+ def filters_all_senders(self) -> bool:
|
|
|
return "*" in self.not_senders
|
|
|
|
|
|
- def filters_all_rooms(self):
|
|
|
+ def filters_all_rooms(self) -> bool:
|
|
|
return "*" in self.not_rooms
|
|
|
|
|
|
- def check(self, event):
|
|
|
+ def check(self, event: FilterEvent) -> bool:
|
|
|
"""Checks whether the filter matches the given event.
|
|
|
|
|
|
Returns:
|
|
|
- bool: True if the event matches
|
|
|
+ True if the event matches
|
|
|
"""
|
|
|
# We usually get the full "events" as dictionaries coming through,
|
|
|
# except for presence which actually gets passed around as its own
|
|
@@ -305,18 +329,25 @@ class Filter:
|
|
|
room_id = event.get("room_id", None)
|
|
|
ev_type = event.get("type", None)
|
|
|
|
|
|
- content = event.get("content", {})
|
|
|
+ content = event.get("content") or {}
|
|
|
# check if there is a string url field in the content for filtering purposes
|
|
|
contains_url = isinstance(content.get("url"), str)
|
|
|
labels = content.get(EventContentFields.LABELS, [])
|
|
|
|
|
|
return self.check_fields(room_id, sender, ev_type, labels, contains_url)
|
|
|
|
|
|
- def check_fields(self, room_id, sender, event_type, labels, contains_url):
|
|
|
+ def check_fields(
|
|
|
+ self,
|
|
|
+ room_id: Optional[str],
|
|
|
+ sender: Optional[str],
|
|
|
+ event_type: Optional[str],
|
|
|
+ labels: Container[str],
|
|
|
+ contains_url: bool,
|
|
|
+ ) -> bool:
|
|
|
"""Checks whether the filter matches the given event fields.
|
|
|
|
|
|
Returns:
|
|
|
- bool: True if the event fields match
|
|
|
+ True if the event fields match
|
|
|
"""
|
|
|
literal_keys = {
|
|
|
"rooms": lambda v: room_id == v,
|
|
@@ -343,14 +374,14 @@ class Filter:
|
|
|
|
|
|
return True
|
|
|
|
|
|
- def filter_rooms(self, room_ids):
|
|
|
+ def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
|
|
|
"""Apply the 'rooms' filter to a given list of rooms.
|
|
|
|
|
|
Args:
|
|
|
- room_ids (list): A list of room_ids.
|
|
|
+ room_ids: A list of room_ids.
|
|
|
|
|
|
Returns:
|
|
|
- list: A list of room_ids that match the filter
|
|
|
+ A list of room_ids that match the filter
|
|
|
"""
|
|
|
room_ids = set(room_ids)
|
|
|
|
|
@@ -363,23 +394,23 @@ class Filter:
|
|
|
|
|
|
return room_ids
|
|
|
|
|
|
- def filter(self, events):
|
|
|
+ def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
|
|
|
return list(filter(self.check, events))
|
|
|
|
|
|
- def limit(self):
|
|
|
+ def limit(self) -> int:
|
|
|
return self.filter_json.get("limit", 10)
|
|
|
|
|
|
- def lazy_load_members(self):
|
|
|
+ def lazy_load_members(self) -> bool:
|
|
|
return self.filter_json.get("lazy_load_members", False)
|
|
|
|
|
|
- def include_redundant_members(self):
|
|
|
+ def include_redundant_members(self) -> bool:
|
|
|
return self.filter_json.get("include_redundant_members", False)
|
|
|
|
|
|
- def with_room_ids(self, room_ids):
|
|
|
+ def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
|
|
|
"""Returns a new filter with the given room IDs appended.
|
|
|
|
|
|
Args:
|
|
|
- room_ids (iterable[unicode]): The room_ids to add
|
|
|
+ room_ids: The room_ids to add
|
|
|
|
|
|
Returns:
|
|
|
filter: A new filter including the given rooms and the old
|
|
@@ -390,8 +421,8 @@ class Filter:
|
|
|
return newFilter
|
|
|
|
|
|
|
|
|
-def _matches_wildcard(actual_value, filter_value):
|
|
|
- if filter_value.endswith("*"):
|
|
|
+def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
|
|
|
+ if filter_value.endswith("*") and isinstance(actual_value, str):
|
|
|
type_prefix = filter_value[:-1]
|
|
|
return actual_value.startswith(type_prefix)
|
|
|
else:
|