filtering.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2015, 2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from synapse.api.errors import SynapseError
  16. from synapse.types import UserID, RoomID
  17. import ujson as json
  18. class Filtering(object):
  19. def __init__(self, hs):
  20. super(Filtering, self).__init__()
  21. self.store = hs.get_datastore()
  22. def get_user_filter(self, user_localpart, filter_id):
  23. result = self.store.get_user_filter(user_localpart, filter_id)
  24. result.addCallback(FilterCollection)
  25. return result
  26. def add_user_filter(self, user_localpart, user_filter):
  27. self.check_valid_filter(user_filter)
  28. return self.store.add_user_filter(user_localpart, user_filter)
  29. # TODO(paul): surely we should probably add a delete_user_filter or
  30. # replace_user_filter at some point? There's no REST API specified for
  31. # them however
  32. def check_valid_filter(self, user_filter_json):
  33. """Check if the provided filter is valid.
  34. This inspects all definitions contained within the filter.
  35. Args:
  36. user_filter_json(dict): The filter
  37. Raises:
  38. SynapseError: If the filter is not valid.
  39. """
  40. # NB: Filters are the complete json blobs. "Definitions" are an
  41. # individual top-level key e.g. public_user_data. Filters are made of
  42. # many definitions.
  43. top_level_definitions = [
  44. "presence", "account_data"
  45. ]
  46. room_level_definitions = [
  47. "state", "timeline", "ephemeral", "account_data"
  48. ]
  49. for key in top_level_definitions:
  50. if key in user_filter_json:
  51. self._check_definition(user_filter_json[key])
  52. if "room" in user_filter_json:
  53. self._check_definition_room_lists(user_filter_json["room"])
  54. for key in room_level_definitions:
  55. if key in user_filter_json["room"]:
  56. self._check_definition(user_filter_json["room"][key])
  57. def _check_definition_room_lists(self, definition):
  58. """Check that "rooms" and "not_rooms" are lists of room ids if they
  59. are present
  60. Args:
  61. definition(dict): The filter definition
  62. Raises:
  63. SynapseError: If there was a problem with this definition.
  64. """
  65. # check rooms are valid room IDs
  66. room_id_keys = ["rooms", "not_rooms"]
  67. for key in room_id_keys:
  68. if key in definition:
  69. if type(definition[key]) != list:
  70. raise SynapseError(400, "Expected %s to be a list." % key)
  71. for room_id in definition[key]:
  72. RoomID.from_string(room_id)
  73. def _check_definition(self, definition):
  74. """Check if the provided definition is valid.
  75. This inspects not only the types but also the values to make sure they
  76. make sense.
  77. Args:
  78. definition(dict): The filter definition
  79. Raises:
  80. SynapseError: If there was a problem with this definition.
  81. """
  82. # NB: Filters are the complete json blobs. "Definitions" are an
  83. # individual top-level key e.g. public_user_data. Filters are made of
  84. # many definitions.
  85. if type(definition) != dict:
  86. raise SynapseError(
  87. 400, "Expected JSON object, not %s" % (definition,)
  88. )
  89. self._check_definition_room_lists(definition)
  90. # check senders are valid user IDs
  91. user_id_keys = ["senders", "not_senders"]
  92. for key in user_id_keys:
  93. if key in definition:
  94. if type(definition[key]) != list:
  95. raise SynapseError(400, "Expected %s to be a list." % key)
  96. for user_id in definition[key]:
  97. UserID.from_string(user_id)
  98. # TODO: We don't limit event type values but we probably should...
  99. # check types are valid event types
  100. event_keys = ["types", "not_types"]
  101. for key in event_keys:
  102. if key in definition:
  103. if type(definition[key]) != list:
  104. raise SynapseError(400, "Expected %s to be a list." % key)
  105. for event_type in definition[key]:
  106. if not isinstance(event_type, basestring):
  107. raise SynapseError(400, "Event type should be a string")
  108. class FilterCollection(object):
  109. def __init__(self, filter_json):
  110. self._filter_json = filter_json
  111. room_filter_json = self._filter_json.get("room", {})
  112. self._room_filter = Filter({
  113. k: v for k, v in room_filter_json.items()
  114. if k in ("rooms", "not_rooms")
  115. })
  116. self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
  117. self._room_state_filter = Filter(room_filter_json.get("state", {}))
  118. self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
  119. self._room_account_data = Filter(room_filter_json.get("account_data", {}))
  120. self._presence_filter = Filter(filter_json.get("presence", {}))
  121. self._account_data = Filter(filter_json.get("account_data", {}))
  122. self.include_leave = filter_json.get("room", {}).get(
  123. "include_leave", False
  124. )
  125. def __repr__(self):
  126. return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
  127. def get_filter_json(self):
  128. return self._filter_json
  129. def timeline_limit(self):
  130. return self._room_timeline_filter.limit()
  131. def presence_limit(self):
  132. return self._presence_filter.limit()
  133. def ephemeral_limit(self):
  134. return self._room_ephemeral_filter.limit()
  135. def filter_presence(self, events):
  136. return self._presence_filter.filter(events)
  137. def filter_account_data(self, events):
  138. return self._account_data.filter(events)
  139. def filter_room_state(self, events):
  140. return self._room_state_filter.filter(self._room_filter.filter(events))
  141. def filter_room_timeline(self, events):
  142. return self._room_timeline_filter.filter(self._room_filter.filter(events))
  143. def filter_room_ephemeral(self, events):
  144. return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
  145. def filter_room_account_data(self, events):
  146. return self._room_account_data.filter(self._room_filter.filter(events))
  147. class Filter(object):
  148. def __init__(self, filter_json):
  149. self.filter_json = filter_json
  150. def check(self, event):
  151. """Checks whether the filter matches the given event.
  152. Returns:
  153. bool: True if the event matches
  154. """
  155. sender = event.get("sender", None)
  156. if not sender:
  157. # Presence events have their 'sender' in content.user_id
  158. content = event.get("content")
  159. # account_data has been allowed to have non-dict content, so check type first
  160. if isinstance(content, dict):
  161. sender = content.get("user_id")
  162. return self.check_fields(
  163. event.get("room_id", None),
  164. sender,
  165. event.get("type", None),
  166. )
  167. def check_fields(self, room_id, sender, event_type):
  168. """Checks whether the filter matches the given event fields.
  169. Returns:
  170. bool: True if the event fields match
  171. """
  172. literal_keys = {
  173. "rooms": lambda v: room_id == v,
  174. "senders": lambda v: sender == v,
  175. "types": lambda v: _matches_wildcard(event_type, v)
  176. }
  177. for name, match_func in literal_keys.items():
  178. not_name = "not_%s" % (name,)
  179. disallowed_values = self.filter_json.get(not_name, [])
  180. if any(map(match_func, disallowed_values)):
  181. return False
  182. allowed_values = self.filter_json.get(name, None)
  183. if allowed_values is not None:
  184. if not any(map(match_func, allowed_values)):
  185. return False
  186. return True
  187. def filter_rooms(self, room_ids):
  188. """Apply the 'rooms' filter to a given list of rooms.
  189. Args:
  190. room_ids (list): A list of room_ids.
  191. Returns:
  192. list: A list of room_ids that match the filter
  193. """
  194. room_ids = set(room_ids)
  195. disallowed_rooms = set(self.filter_json.get("not_rooms", []))
  196. room_ids -= disallowed_rooms
  197. allowed_rooms = self.filter_json.get("rooms", None)
  198. if allowed_rooms is not None:
  199. room_ids &= set(allowed_rooms)
  200. return room_ids
  201. def filter(self, events):
  202. return filter(self.check, events)
  203. def limit(self):
  204. return self.filter_json.get("limit", 10)
  205. def _matches_wildcard(actual_value, filter_value):
  206. if filter_value.endswith("*"):
  207. type_prefix = filter_value[:-1]
  208. return actual_value.startswith(type_prefix)
  209. else:
  210. return actual_value == filter_value
  211. DEFAULT_FILTER_COLLECTION = FilterCollection({})