push_rule_evaluator.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302
  1. # Copyright 2015, 2016 OpenMarket Ltd
  2. # Copyright 2017 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. import re
  17. from typing import Any, Dict, List, Mapping, Optional, Pattern, Tuple, Union
  18. from matrix_common.regex import glob_to_regex, to_word_pattern
  19. from synapse.events import EventBase
  20. from synapse.types import UserID
  21. from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)

# NOTE(review): GLOB_REGEX and IS_GLOB are not referenced anywhere in the code
# visible here; presumably legacy or imported by other modules -- verify before
# removing.
# GLOB_REGEX looks like it matches a re.escape()'d glob character class such as
# "\[\!abc\]": group 1 captures the optional escaped "!" negation and group 2
# the class contents.
GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]")
# Matches any glob metacharacter: "?", "*", "[" or "]".
IS_GLOB = re.compile(r"[\?\*\[\]]")
# Parses a room_member_count "is" expression: group 1 is an optional run of
# comparison characters (e.g. "", "==", "<", ">="), group 2 an optional run of
# ASCII digits. Either group may be empty; _test_ineq_condition validates both.
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
  26. def _room_member_count(
  27. ev: EventBase, condition: Dict[str, Any], room_member_count: int
  28. ) -> bool:
  29. return _test_ineq_condition(condition, room_member_count)
  30. def _sender_notification_permission(
  31. ev: EventBase,
  32. condition: Dict[str, Any],
  33. sender_power_level: int,
  34. power_levels: Dict[str, Union[int, Dict[str, int]]],
  35. ) -> bool:
  36. notif_level_key = condition.get("key")
  37. if notif_level_key is None:
  38. return False
  39. notif_levels = power_levels.get("notifications", {})
  40. assert isinstance(notif_levels, dict)
  41. room_notif_level = notif_levels.get(notif_level_key, 50)
  42. return sender_power_level >= room_notif_level
  43. def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
  44. if "is" not in condition:
  45. return False
  46. m = INEQUALITY_EXPR.match(condition["is"])
  47. if not m:
  48. return False
  49. ineq = m.group(1)
  50. rhs = m.group(2)
  51. if not rhs.isdigit():
  52. return False
  53. rhs_int = int(rhs)
  54. if ineq == "" or ineq == "==":
  55. return number == rhs_int
  56. elif ineq == "<":
  57. return number < rhs_int
  58. elif ineq == ">":
  59. return number > rhs_int
  60. elif ineq == ">=":
  61. return number >= rhs_int
  62. elif ineq == "<=":
  63. return number <= rhs_int
  64. else:
  65. return False
  66. def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
  67. """
  68. Converts a list of actions into a `tweaks` dict (which can then be passed to
  69. the push gateway).
  70. This function ignores all actions other than `set_tweak` actions, and treats
  71. absent `value`s as `True`, which agrees with the only spec-defined treatment
  72. of absent `value`s (namely, for `highlight` tweaks).
  73. Args:
  74. actions: list of actions
  75. e.g. [
  76. {"set_tweak": "a", "value": "AAA"},
  77. {"set_tweak": "b", "value": "BBB"},
  78. {"set_tweak": "highlight"},
  79. "notify"
  80. ]
  81. Returns:
  82. dictionary of tweaks for those actions
  83. e.g. {"a": "AAA", "b": "BBB", "highlight": True}
  84. """
  85. tweaks = {}
  86. for a in actions:
  87. if not isinstance(a, dict):
  88. continue
  89. if "set_tweak" in a:
  90. # value is allowed to be absent in which case the value assumed
  91. # should be True.
  92. tweaks[a["set_tweak"]] = a.get("value", True)
  93. return tweaks
  94. class PushRuleEvaluatorForEvent:
  95. def __init__(
  96. self,
  97. event: EventBase,
  98. room_member_count: int,
  99. sender_power_level: int,
  100. power_levels: Dict[str, Union[int, Dict[str, int]]],
  101. ):
  102. self._event = event
  103. self._room_member_count = room_member_count
  104. self._sender_power_level = sender_power_level
  105. self._power_levels = power_levels
  106. # Maps strings of e.g. 'content.body' -> event["content"]["body"]
  107. self._value_cache = _flatten_dict(event)
  108. # Maps cache keys to final values.
  109. self._condition_cache: Dict[str, bool] = {}
  110. def check_conditions(
  111. self, conditions: List[dict], uid: str, display_name: Optional[str]
  112. ) -> bool:
  113. """
  114. Returns true if a user's conditions/user ID/display name match the event.
  115. Args:
  116. conditions: The user's conditions to match.
  117. uid: The user's MXID.
  118. display_name: The display name.
  119. Returns:
  120. True if all conditions match the event, False otherwise.
  121. """
  122. for cond in conditions:
  123. _cache_key = cond.get("_cache_key", None)
  124. if _cache_key:
  125. res = self._condition_cache.get(_cache_key, None)
  126. if res is False:
  127. return False
  128. elif res is True:
  129. continue
  130. res = self.matches(cond, uid, display_name)
  131. if _cache_key:
  132. self._condition_cache[_cache_key] = bool(res)
  133. if not res:
  134. return False
  135. return True
  136. def matches(
  137. self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
  138. ) -> bool:
  139. """
  140. Returns true if a user's condition/user ID/display name match the event.
  141. Args:
  142. condition: The user's condition to match.
  143. uid: The user's MXID.
  144. display_name: The display name, or None if there is not one.
  145. Returns:
  146. True if the condition matches the event, False otherwise.
  147. """
  148. if condition["kind"] == "event_match":
  149. return self._event_match(condition, user_id)
  150. elif condition["kind"] == "contains_display_name":
  151. return self._contains_display_name(display_name)
  152. elif condition["kind"] == "room_member_count":
  153. return _room_member_count(self._event, condition, self._room_member_count)
  154. elif condition["kind"] == "sender_notification_permission":
  155. return _sender_notification_permission(
  156. self._event, condition, self._sender_power_level, self._power_levels
  157. )
  158. else:
  159. return True
  160. def _event_match(self, condition: dict, user_id: str) -> bool:
  161. """
  162. Check an "event_match" push rule condition.
  163. Args:
  164. condition: The "event_match" push rule condition to match.
  165. user_id: The user's MXID.
  166. Returns:
  167. True if the condition matches the event, False otherwise.
  168. """
  169. pattern = condition.get("pattern", None)
  170. if not pattern:
  171. pattern_type = condition.get("pattern_type", None)
  172. if pattern_type == "user_id":
  173. pattern = user_id
  174. elif pattern_type == "user_localpart":
  175. pattern = UserID.from_string(user_id).localpart
  176. if not pattern:
  177. logger.warning("event_match condition with no pattern")
  178. return False
  179. # XXX: optimisation: cache our pattern regexps
  180. if condition["key"] == "content.body":
  181. body = self._event.content.get("body", None)
  182. if not body or not isinstance(body, str):
  183. return False
  184. return _glob_matches(pattern, body, word_boundary=True)
  185. else:
  186. haystack = self._value_cache.get(condition["key"], None)
  187. if haystack is None:
  188. return False
  189. return _glob_matches(pattern, haystack)
  190. def _contains_display_name(self, display_name: Optional[str]) -> bool:
  191. """
  192. Check an "event_match" push rule condition.
  193. Args:
  194. display_name: The display name, or None if there is not one.
  195. Returns:
  196. True if the display name is found in the event body, False otherwise.
  197. """
  198. if not display_name:
  199. return False
  200. body = self._event.content.get("body", None)
  201. if not body or not isinstance(body, str):
  202. return False
  203. # Similar to _glob_matches, but do not treat display_name as a glob.
  204. r = regex_cache.get((display_name, False, True), None)
  205. if not r:
  206. r1 = re.escape(display_name)
  207. r1 = to_word_pattern(r1)
  208. r = re.compile(r1, flags=re.IGNORECASE)
  209. regex_cache[(display_name, False, True)] = r
  210. return bool(r.search(body))
# Caches (string, is_glob, word_boundary) -> compiled regex for push rule
# evaluation, shared by two writers:
#   - _glob_matches stores glob_to_regex(glob) under (glob, True, word_boundary);
#   - PushRuleEvaluatorForEvent._contains_display_name stores a regex-escaped,
#     case-insensitive display-name pattern under (display_name, False, True).
# The is_glob flag keeps a display name from colliding with an identical glob.
# NOTE(review): 50000 is presumably the maximum entry count and
# "regex_push_cache" the cache's name -- confirm against LruCache's signature.
regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
    50000, "regex_push_cache"
)
  215. def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
  216. """Tests if value matches glob.
  217. Args:
  218. glob
  219. value: String to test against glob.
  220. word_boundary: Whether to match against word boundaries or entire
  221. string. Defaults to False.
  222. """
  223. try:
  224. r = regex_cache.get((glob, True, word_boundary), None)
  225. if not r:
  226. r = glob_to_regex(glob, word_boundary=word_boundary)
  227. regex_cache[(glob, True, word_boundary)] = r
  228. return bool(r.search(value))
  229. except re.error:
  230. logger.warning("Failed to parse glob to regex: %r", glob)
  231. return False
  232. def _flatten_dict(
  233. d: Union[EventBase, Mapping[str, Any]],
  234. prefix: Optional[List[str]] = None,
  235. result: Optional[Dict[str, str]] = None,
  236. ) -> Dict[str, str]:
  237. if prefix is None:
  238. prefix = []
  239. if result is None:
  240. result = {}
  241. for key, value in d.items():
  242. if isinstance(value, str):
  243. result[".".join(prefix + [key])] = value.lower()
  244. elif isinstance(value, Mapping):
  245. _flatten_dict(value, prefix=(prefix + [key]), result=result)
  246. return result