push_rule_evaluator.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2015, 2016 OpenMarket Ltd
  3. # Copyright 2017 New Vector Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import logging
  17. import re
  18. from typing import Pattern
  19. from synapse.events import EventBase
  20. from synapse.types import UserID
  21. from synapse.util.caches import register_cache
  22. from synapse.util.caches.lrucache import LruCache
  23. logger = logging.getLogger(__name__)
  24. GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]")
  25. IS_GLOB = re.compile(r"[\?\*\[\]]")
  26. INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
  27. def _room_member_count(ev, condition, room_member_count):
  28. return _test_ineq_condition(condition, room_member_count)
  29. def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
  30. notif_level_key = condition.get("key")
  31. if notif_level_key is None:
  32. return False
  33. notif_levels = power_levels.get("notifications", {})
  34. room_notif_level = notif_levels.get(notif_level_key, 50)
  35. return sender_power_level >= room_notif_level
  36. def _test_ineq_condition(condition, number):
  37. if "is" not in condition:
  38. return False
  39. m = INEQUALITY_EXPR.match(condition["is"])
  40. if not m:
  41. return False
  42. ineq = m.group(1)
  43. rhs = m.group(2)
  44. if not rhs.isdigit():
  45. return False
  46. rhs_int = int(rhs)
  47. if ineq == "" or ineq == "==":
  48. return number == rhs_int
  49. elif ineq == "<":
  50. return number < rhs_int
  51. elif ineq == ">":
  52. return number > rhs_int
  53. elif ineq == ">=":
  54. return number >= rhs_int
  55. elif ineq == "<=":
  56. return number <= rhs_int
  57. else:
  58. return False
  59. def tweaks_for_actions(actions):
  60. tweaks = {}
  61. for a in actions:
  62. if not isinstance(a, dict):
  63. continue
  64. if "set_tweak" in a and "value" in a:
  65. tweaks[a["set_tweak"]] = a["value"]
  66. return tweaks
  67. class PushRuleEvaluatorForEvent(object):
  68. def __init__(
  69. self,
  70. event: EventBase,
  71. room_member_count: int,
  72. sender_power_level: int,
  73. power_levels: dict,
  74. ):
  75. self._event = event
  76. self._room_member_count = room_member_count
  77. self._sender_power_level = sender_power_level
  78. self._power_levels = power_levels
  79. # Maps strings of e.g. 'content.body' -> event["content"]["body"]
  80. self._value_cache = _flatten_dict(event)
  81. def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
  82. if condition["kind"] == "event_match":
  83. return self._event_match(condition, user_id)
  84. elif condition["kind"] == "contains_display_name":
  85. return self._contains_display_name(display_name)
  86. elif condition["kind"] == "room_member_count":
  87. return _room_member_count(self._event, condition, self._room_member_count)
  88. elif condition["kind"] == "sender_notification_permission":
  89. return _sender_notification_permission(
  90. self._event, condition, self._sender_power_level, self._power_levels
  91. )
  92. else:
  93. return True
  94. def _event_match(self, condition: dict, user_id: str) -> bool:
  95. pattern = condition.get("pattern", None)
  96. if not pattern:
  97. pattern_type = condition.get("pattern_type", None)
  98. if pattern_type == "user_id":
  99. pattern = user_id
  100. elif pattern_type == "user_localpart":
  101. pattern = UserID.from_string(user_id).localpart
  102. if not pattern:
  103. logger.warning("event_match condition with no pattern")
  104. return False
  105. # XXX: optimisation: cache our pattern regexps
  106. if condition["key"] == "content.body":
  107. body = self._event.content.get("body", None)
  108. if not body or not isinstance(body, str):
  109. return False
  110. return _glob_matches(pattern, body, word_boundary=True)
  111. else:
  112. haystack = self._get_value(condition["key"])
  113. if haystack is None:
  114. return False
  115. return _glob_matches(pattern, haystack)
  116. def _contains_display_name(self, display_name: str) -> bool:
  117. if not display_name:
  118. return False
  119. body = self._event.content.get("body", None)
  120. if not body or not isinstance(body, str):
  121. return False
  122. # Similar to _glob_matches, but do not treat display_name as a glob.
  123. r = regex_cache.get((display_name, False, True), None)
  124. if not r:
  125. r = re.escape(display_name)
  126. r = _re_word_boundary(r)
  127. r = re.compile(r, flags=re.IGNORECASE)
  128. regex_cache[(display_name, False, True)] = r
  129. return r.search(body)
  130. def _get_value(self, dotted_key: str) -> str:
  131. return self._value_cache.get(dotted_key, None)
  132. # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
  133. regex_cache = LruCache(50000)
  134. register_cache("cache", "regex_push_cache", regex_cache)
  135. def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
  136. """Tests if value matches glob.
  137. Args:
  138. glob
  139. value: String to test against glob.
  140. word_boundary: Whether to match against word boundaries or entire
  141. string. Defaults to False.
  142. """
  143. try:
  144. r = regex_cache.get((glob, True, word_boundary), None)
  145. if not r:
  146. r = _glob_to_re(glob, word_boundary)
  147. regex_cache[(glob, True, word_boundary)] = r
  148. return r.search(value)
  149. except re.error:
  150. logger.warning("Failed to parse glob to regex: %r", glob)
  151. return False
  152. def _glob_to_re(glob: str, word_boundary: bool) -> Pattern:
  153. """Generates regex for a given glob.
  154. Args:
  155. glob
  156. word_boundary: Whether to match against word boundaries or entire string.
  157. """
  158. if IS_GLOB.search(glob):
  159. r = re.escape(glob)
  160. r = r.replace(r"\*", ".*?")
  161. r = r.replace(r"\?", ".")
  162. # handle [abc], [a-z] and [!a-z] style ranges.
  163. r = GLOB_REGEX.sub(
  164. lambda x: (
  165. "[%s%s]" % (x.group(1) and "^" or "", x.group(2).replace(r"\\\-", "-"))
  166. ),
  167. r,
  168. )
  169. if word_boundary:
  170. r = _re_word_boundary(r)
  171. return re.compile(r, flags=re.IGNORECASE)
  172. else:
  173. r = "^" + r + "$"
  174. return re.compile(r, flags=re.IGNORECASE)
  175. elif word_boundary:
  176. r = re.escape(glob)
  177. r = _re_word_boundary(r)
  178. return re.compile(r, flags=re.IGNORECASE)
  179. else:
  180. r = "^" + re.escape(glob) + "$"
  181. return re.compile(r, flags=re.IGNORECASE)
  182. def _re_word_boundary(r: str) -> str:
  183. """
  184. Adds word boundary characters to the start and end of an
  185. expression to require that the match occur as a whole word,
  186. but do so respecting the fact that strings starting or ending
  187. with non-word characters will change word boundaries.
  188. """
  189. # we can't use \b as it chokes on unicode. however \W seems to be okay
  190. # as shorthand for [^0-9A-Za-z_].
  191. return r"(^|\W)%s(\W|$)" % (r,)
  192. def _flatten_dict(d, prefix=[], result=None):
  193. if result is None:
  194. result = {}
  195. for key, value in d.items():
  196. if isinstance(value, str):
  197. result[".".join(prefix + [key])] = value.lower()
  198. elif hasattr(value, "items"):
  199. _flatten_dict(value, prefix=(prefix + [key]), result=result)
  200. return result