push_rule.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from ._base import SQLBaseStore
  16. from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
  17. from synapse.push.baserules import list_with_base_rules
  18. from twisted.internet import defer
  19. import logging
  20. import simplejson as json
  21. logger = logging.getLogger(__name__)
  22. def _load_rules(rawrules, enabled_map):
  23. ruleslist = []
  24. for rawrule in rawrules:
  25. rule = dict(rawrule)
  26. rule["conditions"] = json.loads(rawrule["conditions"])
  27. rule["actions"] = json.loads(rawrule["actions"])
  28. ruleslist.append(rule)
  29. # We're going to be mutating this a lot, so do a deep copy
  30. rules = list(list_with_base_rules(ruleslist))
  31. for i, rule in enumerate(rules):
  32. rule_id = rule['rule_id']
  33. if rule_id in enabled_map:
  34. if rule.get('enabled', True) != bool(enabled_map[rule_id]):
  35. # Rules are cached across users.
  36. rule = dict(rule)
  37. rule['enabled'] = bool(enabled_map[rule_id])
  38. rules[i] = rule
  39. return rules
  40. class PushRuleStore(SQLBaseStore):
  41. @cachedInlineCallbacks()
  42. def get_push_rules_for_user(self, user_id):
  43. rows = yield self._simple_select_list(
  44. table="push_rules",
  45. keyvalues={
  46. "user_name": user_id,
  47. },
  48. retcols=(
  49. "user_name", "rule_id", "priority_class", "priority",
  50. "conditions", "actions",
  51. ),
  52. desc="get_push_rules_enabled_for_user",
  53. )
  54. rows.sort(
  55. key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
  56. )
  57. enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
  58. rules = _load_rules(rows, enabled_map)
  59. defer.returnValue(rules)
  60. @cachedInlineCallbacks()
  61. def get_push_rules_enabled_for_user(self, user_id):
  62. results = yield self._simple_select_list(
  63. table="push_rules_enable",
  64. keyvalues={
  65. 'user_name': user_id
  66. },
  67. retcols=(
  68. "user_name", "rule_id", "enabled",
  69. ),
  70. desc="get_push_rules_enabled_for_user",
  71. )
  72. defer.returnValue({
  73. r['rule_id']: False if r['enabled'] == 0 else True for r in results
  74. })
  75. @cachedList(cached_method_name="get_push_rules_for_user",
  76. list_name="user_ids", num_args=1, inlineCallbacks=True)
  77. def bulk_get_push_rules(self, user_ids):
  78. if not user_ids:
  79. defer.returnValue({})
  80. results = {
  81. user_id: []
  82. for user_id in user_ids
  83. }
  84. rows = yield self._simple_select_many_batch(
  85. table="push_rules",
  86. column="user_name",
  87. iterable=user_ids,
  88. retcols=("*",),
  89. desc="bulk_get_push_rules",
  90. )
  91. rows.sort(
  92. key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
  93. )
  94. for row in rows:
  95. results.setdefault(row['user_name'], []).append(row)
  96. enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
  97. for user_id, rules in results.items():
  98. results[user_id] = _load_rules(
  99. rules, enabled_map_by_user.get(user_id, {})
  100. )
  101. defer.returnValue(results)
  102. def bulk_get_push_rules_for_room(self, event, context):
  103. state_group = context.state_group
  104. if not state_group:
  105. # If state_group is None it means it has yet to be assigned a
  106. # state group, i.e. we need to make sure that calls with a state_group
  107. # of None don't hit previous cached calls with a None state_group.
  108. # To do this we set the state_group to a new object as object() != object()
  109. state_group = object()
  110. return self._bulk_get_push_rules_for_room(
  111. event.room_id, state_group, context.current_state_ids, event=event
  112. )
  113. @cachedInlineCallbacks(num_args=2, cache_context=True)
  114. def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
  115. cache_context, event=None):
  116. # We don't use `state_group`, its there so that we can cache based
  117. # on it. However, its important that its never None, since two current_state's
  118. # with a state_group of None are likely to be different.
  119. # See bulk_get_push_rules_for_room for how we work around this.
  120. assert state_group is not None
  121. # We also will want to generate notifs for other people in the room so
  122. # their unread countss are correct in the event stream, but to avoid
  123. # generating them for bot / AS users etc, we only do so for people who've
  124. # sent a read receipt into the room.
  125. users_in_room = yield self._get_joined_users_from_context(
  126. room_id, state_group, current_state_ids,
  127. on_invalidate=cache_context.invalidate,
  128. event=event,
  129. )
  130. # We ignore app service users for now. This is so that we don't fill
  131. # up the `get_if_users_have_pushers` cache with AS entries that we
  132. # know don't have pushers, nor even read receipts.
  133. local_users_in_room = set(
  134. u for u in users_in_room
  135. if self.hs.is_mine_id(u)
  136. and not self.get_if_app_services_interested_in_user(u)
  137. )
  138. # users in the room who have pushers need to get push rules run because
  139. # that's how their pushers work
  140. if_users_with_pushers = yield self.get_if_users_have_pushers(
  141. local_users_in_room,
  142. on_invalidate=cache_context.invalidate,
  143. )
  144. user_ids = set(
  145. uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
  146. )
  147. users_with_receipts = yield self.get_users_with_read_receipts_in_room(
  148. room_id, on_invalidate=cache_context.invalidate,
  149. )
  150. # any users with pushers must be ours: they have pushers
  151. for uid in users_with_receipts:
  152. if uid in local_users_in_room:
  153. user_ids.add(uid)
  154. rules_by_user = yield self.bulk_get_push_rules(
  155. user_ids, on_invalidate=cache_context.invalidate,
  156. )
  157. rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
  158. defer.returnValue(rules_by_user)
  159. @cachedList(cached_method_name="get_push_rules_enabled_for_user",
  160. list_name="user_ids", num_args=1, inlineCallbacks=True)
  161. def bulk_get_push_rules_enabled(self, user_ids):
  162. if not user_ids:
  163. defer.returnValue({})
  164. results = {
  165. user_id: {}
  166. for user_id in user_ids
  167. }
  168. rows = yield self._simple_select_many_batch(
  169. table="push_rules_enable",
  170. column="user_name",
  171. iterable=user_ids,
  172. retcols=("user_name", "rule_id", "enabled",),
  173. desc="bulk_get_push_rules_enabled",
  174. )
  175. for row in rows:
  176. enabled = bool(row['enabled'])
  177. results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
  178. defer.returnValue(results)
  179. @defer.inlineCallbacks
  180. def add_push_rule(
  181. self, user_id, rule_id, priority_class, conditions, actions,
  182. before=None, after=None
  183. ):
  184. conditions_json = json.dumps(conditions)
  185. actions_json = json.dumps(actions)
  186. with self._push_rules_stream_id_gen.get_next() as ids:
  187. stream_id, event_stream_ordering = ids
  188. if before or after:
  189. yield self.runInteraction(
  190. "_add_push_rule_relative_txn",
  191. self._add_push_rule_relative_txn,
  192. stream_id, event_stream_ordering, user_id, rule_id, priority_class,
  193. conditions_json, actions_json, before, after,
  194. )
  195. else:
  196. yield self.runInteraction(
  197. "_add_push_rule_highest_priority_txn",
  198. self._add_push_rule_highest_priority_txn,
  199. stream_id, event_stream_ordering, user_id, rule_id, priority_class,
  200. conditions_json, actions_json,
  201. )
  202. def _add_push_rule_relative_txn(
  203. self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
  204. conditions_json, actions_json, before, after
  205. ):
  206. # Lock the table since otherwise we'll have annoying races between the
  207. # SELECT here and the UPSERT below.
  208. self.database_engine.lock_table(txn, "push_rules")
  209. relative_to_rule = before or after
  210. res = self._simple_select_one_txn(
  211. txn,
  212. table="push_rules",
  213. keyvalues={
  214. "user_name": user_id,
  215. "rule_id": relative_to_rule,
  216. },
  217. retcols=["priority_class", "priority"],
  218. allow_none=True,
  219. )
  220. if not res:
  221. raise RuleNotFoundException(
  222. "before/after rule not found: %s" % (relative_to_rule,)
  223. )
  224. base_priority_class = res["priority_class"]
  225. base_rule_priority = res["priority"]
  226. if base_priority_class != priority_class:
  227. raise InconsistentRuleException(
  228. "Given priority class does not match class of relative rule"
  229. )
  230. if before:
  231. # Higher priority rules are executed first, So adding a rule before
  232. # a rule means giving it a higher priority than that rule.
  233. new_rule_priority = base_rule_priority + 1
  234. else:
  235. # We increment the priority of the existing rules to make space for
  236. # the new rule. Therefore if we want this rule to appear after
  237. # an existing rule we give it the priority of the existing rule,
  238. # and then increment the priority of the existing rule.
  239. new_rule_priority = base_rule_priority
  240. sql = (
  241. "UPDATE push_rules SET priority = priority + 1"
  242. " WHERE user_name = ? AND priority_class = ? AND priority >= ?"
  243. )
  244. txn.execute(sql, (user_id, priority_class, new_rule_priority))
  245. self._upsert_push_rule_txn(
  246. txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
  247. new_rule_priority, conditions_json, actions_json,
  248. )
  249. def _add_push_rule_highest_priority_txn(
  250. self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
  251. conditions_json, actions_json
  252. ):
  253. # Lock the table since otherwise we'll have annoying races between the
  254. # SELECT here and the UPSERT below.
  255. self.database_engine.lock_table(txn, "push_rules")
  256. # find the highest priority rule in that class
  257. sql = (
  258. "SELECT COUNT(*), MAX(priority) FROM push_rules"
  259. " WHERE user_name = ? and priority_class = ?"
  260. )
  261. txn.execute(sql, (user_id, priority_class))
  262. res = txn.fetchall()
  263. (how_many, highest_prio) = res[0]
  264. new_prio = 0
  265. if how_many > 0:
  266. new_prio = highest_prio + 1
  267. self._upsert_push_rule_txn(
  268. txn,
  269. stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_prio,
  270. conditions_json, actions_json,
  271. )
  272. def _upsert_push_rule_txn(
  273. self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
  274. priority, conditions_json, actions_json, update_stream=True
  275. ):
  276. """Specialised version of _simple_upsert_txn that picks a push_rule_id
  277. using the _push_rule_id_gen if it needs to insert the rule. It assumes
  278. that the "push_rules" table is locked"""
  279. sql = (
  280. "UPDATE push_rules"
  281. " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
  282. " WHERE user_name = ? AND rule_id = ?"
  283. )
  284. txn.execute(sql, (
  285. priority_class, priority, conditions_json, actions_json,
  286. user_id, rule_id,
  287. ))
  288. if txn.rowcount == 0:
  289. # We didn't update a row with the given rule_id so insert one
  290. push_rule_id = self._push_rule_id_gen.get_next()
  291. self._simple_insert_txn(
  292. txn,
  293. table="push_rules",
  294. values={
  295. "id": push_rule_id,
  296. "user_name": user_id,
  297. "rule_id": rule_id,
  298. "priority_class": priority_class,
  299. "priority": priority,
  300. "conditions": conditions_json,
  301. "actions": actions_json,
  302. },
  303. )
  304. if update_stream:
  305. self._insert_push_rules_update_txn(
  306. txn, stream_id, event_stream_ordering, user_id, rule_id,
  307. op="ADD",
  308. data={
  309. "priority_class": priority_class,
  310. "priority": priority,
  311. "conditions": conditions_json,
  312. "actions": actions_json,
  313. }
  314. )
  315. @defer.inlineCallbacks
  316. def delete_push_rule(self, user_id, rule_id):
  317. """
  318. Delete a push rule. Args specify the row to be deleted and can be
  319. any of the columns in the push_rule table, but below are the
  320. standard ones
  321. Args:
  322. user_id (str): The matrix ID of the push rule owner
  323. rule_id (str): The rule_id of the rule to be deleted
  324. """
  325. def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
  326. self._simple_delete_one_txn(
  327. txn,
  328. "push_rules",
  329. {'user_name': user_id, 'rule_id': rule_id},
  330. )
  331. self._insert_push_rules_update_txn(
  332. txn, stream_id, event_stream_ordering, user_id, rule_id,
  333. op="DELETE"
  334. )
  335. with self._push_rules_stream_id_gen.get_next() as ids:
  336. stream_id, event_stream_ordering = ids
  337. yield self.runInteraction(
  338. "delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering
  339. )
  340. @defer.inlineCallbacks
  341. def set_push_rule_enabled(self, user_id, rule_id, enabled):
  342. with self._push_rules_stream_id_gen.get_next() as ids:
  343. stream_id, event_stream_ordering = ids
  344. yield self.runInteraction(
  345. "_set_push_rule_enabled_txn",
  346. self._set_push_rule_enabled_txn,
  347. stream_id, event_stream_ordering, user_id, rule_id, enabled
  348. )
  349. def _set_push_rule_enabled_txn(
  350. self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
  351. ):
  352. new_id = self._push_rules_enable_id_gen.get_next()
  353. self._simple_upsert_txn(
  354. txn,
  355. "push_rules_enable",
  356. {'user_name': user_id, 'rule_id': rule_id},
  357. {'enabled': 1 if enabled else 0},
  358. {'id': new_id},
  359. )
  360. self._insert_push_rules_update_txn(
  361. txn, stream_id, event_stream_ordering, user_id, rule_id,
  362. op="ENABLE" if enabled else "DISABLE"
  363. )
  364. @defer.inlineCallbacks
  365. def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
  366. actions_json = json.dumps(actions)
  367. def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
  368. if is_default_rule:
  369. # Add a dummy rule to the rules table with the user specified
  370. # actions.
  371. priority_class = -1
  372. priority = 1
  373. self._upsert_push_rule_txn(
  374. txn, stream_id, event_stream_ordering, user_id, rule_id,
  375. priority_class, priority, "[]", actions_json,
  376. update_stream=False
  377. )
  378. else:
  379. self._simple_update_one_txn(
  380. txn,
  381. "push_rules",
  382. {'user_name': user_id, 'rule_id': rule_id},
  383. {'actions': actions_json},
  384. )
  385. self._insert_push_rules_update_txn(
  386. txn, stream_id, event_stream_ordering, user_id, rule_id,
  387. op="ACTIONS", data={"actions": actions_json}
  388. )
  389. with self._push_rules_stream_id_gen.get_next() as ids:
  390. stream_id, event_stream_ordering = ids
  391. yield self.runInteraction(
  392. "set_push_rule_actions", set_push_rule_actions_txn,
  393. stream_id, event_stream_ordering
  394. )
  395. def _insert_push_rules_update_txn(
  396. self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
  397. ):
  398. values = {
  399. "stream_id": stream_id,
  400. "event_stream_ordering": event_stream_ordering,
  401. "user_id": user_id,
  402. "rule_id": rule_id,
  403. "op": op,
  404. }
  405. if data is not None:
  406. values.update(data)
  407. self._simple_insert_txn(txn, "push_rules_stream", values=values)
  408. txn.call_after(
  409. self.get_push_rules_for_user.invalidate, (user_id,)
  410. )
  411. txn.call_after(
  412. self.get_push_rules_enabled_for_user.invalidate, (user_id,)
  413. )
  414. txn.call_after(
  415. self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
  416. )
  417. def get_all_push_rule_updates(self, last_id, current_id, limit):
  418. """Get all the push rules changes that have happend on the server"""
  419. if last_id == current_id:
  420. return defer.succeed([])
  421. def get_all_push_rule_updates_txn(txn):
  422. sql = (
  423. "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
  424. " op, priority_class, priority, conditions, actions"
  425. " FROM push_rules_stream"
  426. " WHERE ? < stream_id AND stream_id <= ?"
  427. " ORDER BY stream_id ASC LIMIT ?"
  428. )
  429. txn.execute(sql, (last_id, current_id, limit))
  430. return txn.fetchall()
  431. return self.runInteraction(
  432. "get_all_push_rule_updates", get_all_push_rule_updates_txn
  433. )
  434. def get_push_rules_stream_token(self):
  435. """Get the position of the push rules stream.
  436. Returns a pair of a stream id for the push_rules stream and the
  437. room stream ordering it corresponds to."""
  438. return self._push_rules_stream_id_gen.get_current_token()
  439. def have_push_rules_changed_for_user(self, user_id, last_id):
  440. if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
  441. return defer.succeed(False)
  442. else:
  443. def have_push_rules_changed_txn(txn):
  444. sql = (
  445. "SELECT COUNT(stream_id) FROM push_rules_stream"
  446. " WHERE user_id = ? AND ? < stream_id"
  447. )
  448. txn.execute(sql, (user_id, last_id))
  449. count, = txn.fetchone()
  450. return bool(count)
  451. return self.runInteraction(
  452. "have_push_rules_changed", have_push_rules_changed_txn
  453. )
  454. class RuleNotFoundException(Exception):
  455. pass
  456. class InconsistentRuleException(Exception):
  457. pass