test_push_rule_evaluator.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
  1. # Copyright 2020 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Dict, Optional, Union
  15. import frozendict
  16. from twisted.test.proto_helpers import MemoryReactor
  17. import synapse.rest.admin
  18. from synapse.api.constants import EventTypes, HistoryVisibility, Membership
  19. from synapse.api.room_versions import RoomVersions
  20. from synapse.appservice import ApplicationService
  21. from synapse.events import FrozenEvent
  22. from synapse.push.bulk_push_rule_evaluator import _flatten_dict
  23. from synapse.push.httppusher import tweaks_for_actions
  24. from synapse.rest import admin
  25. from synapse.rest.client import login, register, room
  26. from synapse.server import HomeServer
  27. from synapse.storage.databases.main.appservice import _make_exclusive_regex
  28. from synapse.synapse_rust.push import PushRuleEvaluator
  29. from synapse.types import JsonDict, UserID
  30. from synapse.util import Clock
  31. from tests import unittest
  32. from tests.test_utils.event_injection import create_event, inject_member_event
  33. class PushRuleEvaluatorTestCase(unittest.TestCase):
  34. def _get_evaluator(
  35. self, content: JsonDict, related_events=None
  36. ) -> PushRuleEvaluator:
  37. event = FrozenEvent(
  38. {
  39. "event_id": "$event_id",
  40. "type": "m.room.history_visibility",
  41. "sender": "@user:test",
  42. "state_key": "",
  43. "room_id": "#room:test",
  44. "content": content,
  45. },
  46. RoomVersions.V1,
  47. )
  48. room_member_count = 0
  49. sender_power_level = 0
  50. power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
  51. return PushRuleEvaluator(
  52. _flatten_dict(event),
  53. room_member_count,
  54. sender_power_level,
  55. power_levels.get("notifications", {}),
  56. {} if related_events is None else related_events,
  57. True,
  58. )
  59. def test_display_name(self) -> None:
  60. """Check for a matching display name in the body of the event."""
  61. evaluator = self._get_evaluator({"body": "foo bar baz"})
  62. condition = {
  63. "kind": "contains_display_name",
  64. }
  65. # Blank names are skipped.
  66. self.assertFalse(evaluator.matches(condition, "@user:test", ""))
  67. # Check a display name that doesn't match.
  68. self.assertFalse(evaluator.matches(condition, "@user:test", "not found"))
  69. # Check a display name which matches.
  70. self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
  71. # A display name that matches, but not a full word does not result in a match.
  72. self.assertFalse(evaluator.matches(condition, "@user:test", "ba"))
  73. # A display name should not be interpreted as a regular expression.
  74. self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]"))
  75. # A display name with spaces should work fine.
  76. self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
  77. def _assert_matches(
  78. self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
  79. ) -> None:
  80. evaluator = self._get_evaluator(content)
  81. self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
  82. def _assert_not_matches(
  83. self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
  84. ) -> None:
  85. evaluator = self._get_evaluator(content)
  86. self.assertFalse(
  87. evaluator.matches(condition, "@user:test", "display_name"), msg
  88. )
  89. def test_event_match_body(self) -> None:
  90. """Check that event_match conditions on content.body work as expected"""
  91. # if the key is `content.body`, the pattern matches substrings.
  92. # non-wildcards should match
  93. condition = {
  94. "kind": "event_match",
  95. "key": "content.body",
  96. "pattern": "foobaz",
  97. }
  98. self._assert_matches(
  99. condition,
  100. {"body": "aaa FoobaZ zzz"},
  101. "patterns should match and be case-insensitive",
  102. )
  103. self._assert_not_matches(
  104. condition,
  105. {"body": "aa xFoobaZ yy"},
  106. "pattern should only match at word boundaries",
  107. )
  108. self._assert_not_matches(
  109. condition,
  110. {"body": "aa foobazx yy"},
  111. "pattern should only match at word boundaries",
  112. )
  113. # wildcards should match
  114. condition = {
  115. "kind": "event_match",
  116. "key": "content.body",
  117. "pattern": "f?o*baz",
  118. }
  119. self._assert_matches(
  120. condition,
  121. {"body": "aaa FoobarbaZ zzz"},
  122. "* should match string and pattern should be case-insensitive",
  123. )
  124. self._assert_matches(
  125. condition, {"body": "aa foobaz yy"}, "* should match 0 characters"
  126. )
  127. self._assert_not_matches(
  128. condition, {"body": "aa fobbaz yy"}, "? should not match 0 characters"
  129. )
  130. self._assert_not_matches(
  131. condition, {"body": "aa fiiobaz yy"}, "? should not match 2 characters"
  132. )
  133. self._assert_not_matches(
  134. condition,
  135. {"body": "aa xfooxbaz yy"},
  136. "pattern should only match at word boundaries",
  137. )
  138. self._assert_not_matches(
  139. condition,
  140. {"body": "aa fooxbazx yy"},
  141. "pattern should only match at word boundaries",
  142. )
  143. # test backslashes
  144. condition = {
  145. "kind": "event_match",
  146. "key": "content.body",
  147. "pattern": r"f\oobaz",
  148. }
  149. self._assert_matches(
  150. condition,
  151. {"body": r"F\oobaz"},
  152. "backslash should match itself",
  153. )
  154. condition = {
  155. "kind": "event_match",
  156. "key": "content.body",
  157. "pattern": r"f\?obaz",
  158. }
  159. self._assert_matches(
  160. condition,
  161. {"body": r"F\oobaz"},
  162. r"? after \ should match any character",
  163. )
  164. def test_event_match_non_body(self) -> None:
  165. """Check that event_match conditions on other keys work as expected"""
  166. # if the key is anything other than 'content.body', the pattern must match the
  167. # whole value.
  168. # non-wildcards should match
  169. condition = {
  170. "kind": "event_match",
  171. "key": "content.value",
  172. "pattern": "foobaz",
  173. }
  174. self._assert_matches(
  175. condition,
  176. {"value": "FoobaZ"},
  177. "patterns should match and be case-insensitive",
  178. )
  179. self._assert_not_matches(
  180. condition,
  181. {"value": "xFoobaZ"},
  182. "pattern should only match at the start/end of the value",
  183. )
  184. self._assert_not_matches(
  185. condition,
  186. {"value": "FoobaZz"},
  187. "pattern should only match at the start/end of the value",
  188. )
  189. # it should work on frozendicts too
  190. self._assert_matches(
  191. condition,
  192. frozendict.frozendict({"value": "FoobaZ"}),
  193. "patterns should match on frozendicts",
  194. )
  195. # wildcards should match
  196. condition = {
  197. "kind": "event_match",
  198. "key": "content.value",
  199. "pattern": "f?o*baz",
  200. }
  201. self._assert_matches(
  202. condition,
  203. {"value": "FoobarbaZ"},
  204. "* should match string and pattern should be case-insensitive",
  205. )
  206. self._assert_matches(
  207. condition, {"value": "foobaz"}, "* should match 0 characters"
  208. )
  209. self._assert_not_matches(
  210. condition, {"value": "fobbaz"}, "? should not match 0 characters"
  211. )
  212. self._assert_not_matches(
  213. condition, {"value": "fiiobaz"}, "? should not match 2 characters"
  214. )
  215. self._assert_not_matches(
  216. condition,
  217. {"value": "xfooxbaz"},
  218. "pattern should only match at the start/end of the value",
  219. )
  220. self._assert_not_matches(
  221. condition,
  222. {"value": "fooxbazx"},
  223. "pattern should only match at the start/end of the value",
  224. )
  225. self._assert_not_matches(
  226. condition,
  227. {"value": "x\nfooxbaz"},
  228. "pattern should not match after a newline",
  229. )
  230. self._assert_not_matches(
  231. condition,
  232. {"value": "fooxbaz\nx"},
  233. "pattern should not match before a newline",
  234. )
  235. def test_no_body(self) -> None:
  236. """Not having a body shouldn't break the evaluator."""
  237. evaluator = self._get_evaluator({})
  238. condition = {
  239. "kind": "contains_display_name",
  240. }
  241. self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
  242. def test_invalid_body(self) -> None:
  243. """A non-string body should not break the evaluator."""
  244. condition = {
  245. "kind": "contains_display_name",
  246. }
  247. for body in (1, True, {"foo": "bar"}):
  248. evaluator = self._get_evaluator({"body": body})
  249. self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
  250. def test_tweaks_for_actions(self) -> None:
  251. """
  252. This tests the behaviour of tweaks_for_actions.
  253. """
  254. actions = [
  255. {"set_tweak": "sound", "value": "default"},
  256. {"set_tweak": "highlight"},
  257. "notify",
  258. ]
  259. self.assertEqual(
  260. tweaks_for_actions(actions),
  261. {"sound": "default", "highlight": True},
  262. )
  263. def test_related_event_match(self):
  264. evaluator = self._get_evaluator(
  265. {
  266. "m.relates_to": {
  267. "event_id": "$parent_event_id",
  268. "key": "😀",
  269. "rel_type": "m.annotation",
  270. "m.in_reply_to": {
  271. "event_id": "$parent_event_id",
  272. },
  273. }
  274. },
  275. {
  276. "m.in_reply_to": {
  277. "event_id": "$parent_event_id",
  278. "type": "m.room.message",
  279. "sender": "@other_user:test",
  280. "room_id": "!room:test",
  281. "content.msgtype": "m.text",
  282. "content.body": "Original message",
  283. },
  284. "m.annotation": {
  285. "event_id": "$parent_event_id",
  286. "type": "m.room.message",
  287. "sender": "@other_user:test",
  288. "room_id": "!room:test",
  289. "content.msgtype": "m.text",
  290. "content.body": "Original message",
  291. },
  292. },
  293. )
  294. self.assertTrue(
  295. evaluator.matches(
  296. {
  297. "kind": "im.nheko.msc3664.related_event_match",
  298. "key": "sender",
  299. "rel_type": "m.in_reply_to",
  300. "pattern": "@other_user:test",
  301. },
  302. "@user:test",
  303. "display_name",
  304. )
  305. )
  306. self.assertFalse(
  307. evaluator.matches(
  308. {
  309. "kind": "im.nheko.msc3664.related_event_match",
  310. "key": "sender",
  311. "rel_type": "m.in_reply_to",
  312. "pattern": "@user:test",
  313. },
  314. "@other_user:test",
  315. "display_name",
  316. )
  317. )
  318. self.assertTrue(
  319. evaluator.matches(
  320. {
  321. "kind": "im.nheko.msc3664.related_event_match",
  322. "key": "sender",
  323. "rel_type": "m.annotation",
  324. "pattern": "@other_user:test",
  325. },
  326. "@other_user:test",
  327. "display_name",
  328. )
  329. )
  330. self.assertFalse(
  331. evaluator.matches(
  332. {
  333. "kind": "im.nheko.msc3664.related_event_match",
  334. "key": "sender",
  335. "rel_type": "m.in_reply_to",
  336. },
  337. "@user:test",
  338. "display_name",
  339. )
  340. )
  341. self.assertTrue(
  342. evaluator.matches(
  343. {
  344. "kind": "im.nheko.msc3664.related_event_match",
  345. "rel_type": "m.in_reply_to",
  346. },
  347. "@user:test",
  348. "display_name",
  349. )
  350. )
  351. self.assertFalse(
  352. evaluator.matches(
  353. {
  354. "kind": "im.nheko.msc3664.related_event_match",
  355. "rel_type": "m.replace",
  356. },
  357. "@other_user:test",
  358. "display_name",
  359. )
  360. )
  361. def test_related_event_match_with_fallback(self):
  362. evaluator = self._get_evaluator(
  363. {
  364. "m.relates_to": {
  365. "event_id": "$parent_event_id",
  366. "key": "😀",
  367. "rel_type": "m.thread",
  368. "is_falling_back": True,
  369. "m.in_reply_to": {
  370. "event_id": "$parent_event_id",
  371. },
  372. }
  373. },
  374. {
  375. "m.in_reply_to": {
  376. "event_id": "$parent_event_id",
  377. "type": "m.room.message",
  378. "sender": "@other_user:test",
  379. "room_id": "!room:test",
  380. "content.msgtype": "m.text",
  381. "content.body": "Original message",
  382. "im.vector.is_falling_back": "",
  383. },
  384. "m.thread": {
  385. "event_id": "$parent_event_id",
  386. "type": "m.room.message",
  387. "sender": "@other_user:test",
  388. "room_id": "!room:test",
  389. "content.msgtype": "m.text",
  390. "content.body": "Original message",
  391. },
  392. },
  393. )
  394. self.assertTrue(
  395. evaluator.matches(
  396. {
  397. "kind": "im.nheko.msc3664.related_event_match",
  398. "key": "sender",
  399. "rel_type": "m.in_reply_to",
  400. "pattern": "@other_user:test",
  401. "include_fallbacks": True,
  402. },
  403. "@user:test",
  404. "display_name",
  405. )
  406. )
  407. self.assertFalse(
  408. evaluator.matches(
  409. {
  410. "kind": "im.nheko.msc3664.related_event_match",
  411. "key": "sender",
  412. "rel_type": "m.in_reply_to",
  413. "pattern": "@other_user:test",
  414. "include_fallbacks": False,
  415. },
  416. "@user:test",
  417. "display_name",
  418. )
  419. )
  420. self.assertFalse(
  421. evaluator.matches(
  422. {
  423. "kind": "im.nheko.msc3664.related_event_match",
  424. "key": "sender",
  425. "rel_type": "m.in_reply_to",
  426. "pattern": "@other_user:test",
  427. },
  428. "@user:test",
  429. "display_name",
  430. )
  431. )
  432. def test_related_event_match_no_related_event(self):
  433. evaluator = self._get_evaluator(
  434. {"msgtype": "m.text", "body": "Message without related event"}
  435. )
  436. self.assertFalse(
  437. evaluator.matches(
  438. {
  439. "kind": "im.nheko.msc3664.related_event_match",
  440. "key": "sender",
  441. "rel_type": "m.in_reply_to",
  442. "pattern": "@other_user:test",
  443. },
  444. "@user:test",
  445. "display_name",
  446. )
  447. )
  448. self.assertFalse(
  449. evaluator.matches(
  450. {
  451. "kind": "im.nheko.msc3664.related_event_match",
  452. "key": "sender",
  453. "rel_type": "m.in_reply_to",
  454. },
  455. "@user:test",
  456. "display_name",
  457. )
  458. )
  459. self.assertFalse(
  460. evaluator.matches(
  461. {
  462. "kind": "im.nheko.msc3664.related_event_match",
  463. "rel_type": "m.in_reply_to",
  464. },
  465. "@user:test",
  466. "display_name",
  467. )
  468. )
  469. class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
  470. """Tests for the bulk push rule evaluator"""
  471. servlets = [
  472. synapse.rest.admin.register_servlets_for_client_rest_resource,
  473. login.register_servlets,
  474. register.register_servlets,
  475. room.register_servlets,
  476. ]
  477. def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
  478. # Define an application service so that we can register appservice users
  479. self._service_token = "some_token"
  480. self._service = ApplicationService(
  481. self._service_token,
  482. "as1",
  483. "@as.sender:test",
  484. namespaces={
  485. "users": [
  486. {"regex": "@_as_.*:test", "exclusive": True},
  487. {"regex": "@as.sender:test", "exclusive": True},
  488. ]
  489. },
  490. msc3202_transaction_extensions=True,
  491. )
  492. self.hs.get_datastores().main.services_cache = [self._service]
  493. self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
  494. [self._service]
  495. )
  496. self._as_user, _ = self.register_appservice_user(
  497. "_as_user", self._service_token
  498. )
  499. self.evaluator = self.hs.get_bulk_push_rule_evaluator()
  500. def test_ignore_appservice_users(self) -> None:
  501. "Test that we don't generate push for appservice users"
  502. user_id = self.register_user("user", "pass")
  503. token = self.login("user", "pass")
  504. room_id = self.helper.create_room_as(user_id, tok=token)
  505. self.get_success(
  506. inject_member_event(self.hs, room_id, self._as_user, Membership.JOIN)
  507. )
  508. event, context = self.get_success(
  509. create_event(
  510. self.hs,
  511. type=EventTypes.Message,
  512. room_id=room_id,
  513. sender=user_id,
  514. content={"body": "test", "msgtype": "m.text"},
  515. )
  516. )
  517. # Assert the returned push rules do not contain the app service user
  518. rules = self.get_success(self.evaluator._get_rules_for_event(event))
  519. self.assertTrue(self._as_user not in rules)
  520. # Assert that no push actions have been added to the staging table (the
  521. # sender should not be pushed for the event)
  522. users_with_push_actions = self.get_success(
  523. self.hs.get_datastores().main.db_pool.simple_select_onecol(
  524. table="event_push_actions_staging",
  525. keyvalues={"event_id": event.event_id},
  526. retcol="user_id",
  527. desc="test_ignore_appservice_users",
  528. )
  529. )
  530. self.assertEqual(len(users_with_push_actions), 0)
  531. class BulkPushRuleEvaluatorTestCase(unittest.HomeserverTestCase):
  532. servlets = [
  533. admin.register_servlets,
  534. login.register_servlets,
  535. room.register_servlets,
  536. ]
  537. def prepare(
  538. self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
  539. ) -> None:
  540. self.main_store = homeserver.get_datastores().main
  541. self.user_id1 = self.register_user("user1", "password")
  542. self.tok1 = self.login(self.user_id1, "password")
  543. self.user_id2 = self.register_user("user2", "password")
  544. self.tok2 = self.login(self.user_id2, "password")
  545. self.room_id = self.helper.create_room_as(tok=self.tok1)
  546. # We want to test history visibility works correctly.
  547. self.helper.send_state(
  548. self.room_id,
  549. EventTypes.RoomHistoryVisibility,
  550. {"history_visibility": HistoryVisibility.JOINED},
  551. tok=self.tok1,
  552. )
  553. def get_notif_count(self, user_id: str) -> int:
  554. return self.get_success(
  555. self.main_store.db_pool.simple_select_one_onecol(
  556. table="event_push_actions",
  557. keyvalues={"user_id": user_id},
  558. retcol="COALESCE(SUM(notif), 0)",
  559. desc="get_staging_notif_count",
  560. )
  561. )
  562. def test_plain_message(self) -> None:
  563. """Test that sending a normal message in a room will trigger a
  564. notification
  565. """
  566. # Have user2 join the room and cle
  567. self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
  568. # They start off with no notifications, but get them when messages are
  569. # sent.
  570. self.assertEqual(self.get_notif_count(self.user_id2), 0)
  571. user1 = UserID.from_string(self.user_id1)
  572. self.create_and_send_event(self.room_id, user1)
  573. self.assertEqual(self.get_notif_count(self.user_id2), 1)
  574. def test_delayed_message(self) -> None:
  575. """Test that a delayed message that was from before a user joined
  576. doesn't cause a notification for the joined user.
  577. """
  578. user1 = UserID.from_string(self.user_id1)
  579. # Send a message before user2 joins
  580. event_id1 = self.create_and_send_event(self.room_id, user1)
  581. # Have user2 join the room
  582. self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
  583. # They start off with no notifications
  584. self.assertEqual(self.get_notif_count(self.user_id2), 0)
  585. # Send another message that references the event before the join to
  586. # simulate a "delayed" event
  587. self.create_and_send_event(self.room_id, user1, prev_event_ids=[event_id1])
  588. # user2 should not be notified about it, because they can't see it.
  589. self.assertEqual(self.get_notif_count(self.user_id2), 0)