test_push_rule_evaluator.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660
  1. # Copyright 2020 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Dict, Optional, Union
  15. import frozendict
  16. from twisted.test.proto_helpers import MemoryReactor
  17. import synapse.rest.admin
  18. from synapse.api.constants import EventTypes, HistoryVisibility, Membership
  19. from synapse.api.room_versions import RoomVersions
  20. from synapse.appservice import ApplicationService
  21. from synapse.events import FrozenEvent
  22. from synapse.push.bulk_push_rule_evaluator import _flatten_dict
  23. from synapse.push.httppusher import tweaks_for_actions
  24. from synapse.rest import admin
  25. from synapse.rest.client import login, register, room
  26. from synapse.server import HomeServer
  27. from synapse.storage.databases.main.appservice import _make_exclusive_regex
  28. from synapse.synapse_rust.push import PushRuleEvaluator
  29. from synapse.types import JsonDict, UserID
  30. from synapse.util import Clock
  31. from tests import unittest
  32. from tests.test_utils.event_injection import create_event, inject_member_event
  33. class PushRuleEvaluatorTestCase(unittest.TestCase):
  34. def _get_evaluator(
  35. self, content: JsonDict, related_events=None
  36. ) -> PushRuleEvaluator:
  37. event = FrozenEvent(
  38. {
  39. "event_id": "$event_id",
  40. "type": "m.room.history_visibility",
  41. "sender": "@user:test",
  42. "state_key": "",
  43. "room_id": "#room:test",
  44. "content": content,
  45. },
  46. RoomVersions.V1,
  47. )
  48. room_member_count = 0
  49. sender_power_level = 0
  50. power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
  51. return PushRuleEvaluator(
  52. _flatten_dict(event),
  53. room_member_count,
  54. sender_power_level,
  55. power_levels.get("notifications", {}),
  56. {} if related_events is None else related_events,
  57. True,
  58. event.room_version.msc3931_push_features,
  59. True,
  60. )
  61. def test_display_name(self) -> None:
  62. """Check for a matching display name in the body of the event."""
  63. evaluator = self._get_evaluator({"body": "foo bar baz"})
  64. condition = {
  65. "kind": "contains_display_name",
  66. }
  67. # Blank names are skipped.
  68. self.assertFalse(evaluator.matches(condition, "@user:test", ""))
  69. # Check a display name that doesn't match.
  70. self.assertFalse(evaluator.matches(condition, "@user:test", "not found"))
  71. # Check a display name which matches.
  72. self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
  73. # A display name that matches, but not a full word does not result in a match.
  74. self.assertFalse(evaluator.matches(condition, "@user:test", "ba"))
  75. # A display name should not be interpreted as a regular expression.
  76. self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]"))
  77. # A display name with spaces should work fine.
  78. self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
  79. def _assert_matches(
  80. self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
  81. ) -> None:
  82. evaluator = self._get_evaluator(content)
  83. self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
  84. def _assert_not_matches(
  85. self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
  86. ) -> None:
  87. evaluator = self._get_evaluator(content)
  88. self.assertFalse(
  89. evaluator.matches(condition, "@user:test", "display_name"), msg
  90. )
  91. def test_event_match_body(self) -> None:
  92. """Check that event_match conditions on content.body work as expected"""
  93. # if the key is `content.body`, the pattern matches substrings.
  94. # non-wildcards should match
  95. condition = {
  96. "kind": "event_match",
  97. "key": "content.body",
  98. "pattern": "foobaz",
  99. }
  100. self._assert_matches(
  101. condition,
  102. {"body": "aaa FoobaZ zzz"},
  103. "patterns should match and be case-insensitive",
  104. )
  105. self._assert_not_matches(
  106. condition,
  107. {"body": "aa xFoobaZ yy"},
  108. "pattern should only match at word boundaries",
  109. )
  110. self._assert_not_matches(
  111. condition,
  112. {"body": "aa foobazx yy"},
  113. "pattern should only match at word boundaries",
  114. )
  115. # wildcards should match
  116. condition = {
  117. "kind": "event_match",
  118. "key": "content.body",
  119. "pattern": "f?o*baz",
  120. }
  121. self._assert_matches(
  122. condition,
  123. {"body": "aaa FoobarbaZ zzz"},
  124. "* should match string and pattern should be case-insensitive",
  125. )
  126. self._assert_matches(
  127. condition, {"body": "aa foobaz yy"}, "* should match 0 characters"
  128. )
  129. self._assert_not_matches(
  130. condition, {"body": "aa fobbaz yy"}, "? should not match 0 characters"
  131. )
  132. self._assert_not_matches(
  133. condition, {"body": "aa fiiobaz yy"}, "? should not match 2 characters"
  134. )
  135. self._assert_not_matches(
  136. condition,
  137. {"body": "aa xfooxbaz yy"},
  138. "pattern should only match at word boundaries",
  139. )
  140. self._assert_not_matches(
  141. condition,
  142. {"body": "aa fooxbazx yy"},
  143. "pattern should only match at word boundaries",
  144. )
  145. # test backslashes
  146. condition = {
  147. "kind": "event_match",
  148. "key": "content.body",
  149. "pattern": r"f\oobaz",
  150. }
  151. self._assert_matches(
  152. condition,
  153. {"body": r"F\oobaz"},
  154. "backslash should match itself",
  155. )
  156. condition = {
  157. "kind": "event_match",
  158. "key": "content.body",
  159. "pattern": r"f\?obaz",
  160. }
  161. self._assert_matches(
  162. condition,
  163. {"body": r"F\oobaz"},
  164. r"? after \ should match any character",
  165. )
  166. def test_event_match_non_body(self) -> None:
  167. """Check that event_match conditions on other keys work as expected"""
  168. # if the key is anything other than 'content.body', the pattern must match the
  169. # whole value.
  170. # non-wildcards should match
  171. condition = {
  172. "kind": "event_match",
  173. "key": "content.value",
  174. "pattern": "foobaz",
  175. }
  176. self._assert_matches(
  177. condition,
  178. {"value": "FoobaZ"},
  179. "patterns should match and be case-insensitive",
  180. )
  181. self._assert_not_matches(
  182. condition,
  183. {"value": "xFoobaZ"},
  184. "pattern should only match at the start/end of the value",
  185. )
  186. self._assert_not_matches(
  187. condition,
  188. {"value": "FoobaZz"},
  189. "pattern should only match at the start/end of the value",
  190. )
  191. # it should work on frozendicts too
  192. self._assert_matches(
  193. condition,
  194. frozendict.frozendict({"value": "FoobaZ"}),
  195. "patterns should match on frozendicts",
  196. )
  197. # wildcards should match
  198. condition = {
  199. "kind": "event_match",
  200. "key": "content.value",
  201. "pattern": "f?o*baz",
  202. }
  203. self._assert_matches(
  204. condition,
  205. {"value": "FoobarbaZ"},
  206. "* should match string and pattern should be case-insensitive",
  207. )
  208. self._assert_matches(
  209. condition, {"value": "foobaz"}, "* should match 0 characters"
  210. )
  211. self._assert_not_matches(
  212. condition, {"value": "fobbaz"}, "? should not match 0 characters"
  213. )
  214. self._assert_not_matches(
  215. condition, {"value": "fiiobaz"}, "? should not match 2 characters"
  216. )
  217. self._assert_not_matches(
  218. condition,
  219. {"value": "xfooxbaz"},
  220. "pattern should only match at the start/end of the value",
  221. )
  222. self._assert_not_matches(
  223. condition,
  224. {"value": "fooxbazx"},
  225. "pattern should only match at the start/end of the value",
  226. )
  227. self._assert_not_matches(
  228. condition,
  229. {"value": "x\nfooxbaz"},
  230. "pattern should not match after a newline",
  231. )
  232. self._assert_not_matches(
  233. condition,
  234. {"value": "fooxbaz\nx"},
  235. "pattern should not match before a newline",
  236. )
  237. def test_no_body(self) -> None:
  238. """Not having a body shouldn't break the evaluator."""
  239. evaluator = self._get_evaluator({})
  240. condition = {
  241. "kind": "contains_display_name",
  242. }
  243. self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
  244. def test_invalid_body(self) -> None:
  245. """A non-string body should not break the evaluator."""
  246. condition = {
  247. "kind": "contains_display_name",
  248. }
  249. for body in (1, True, {"foo": "bar"}):
  250. evaluator = self._get_evaluator({"body": body})
  251. self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
  252. def test_tweaks_for_actions(self) -> None:
  253. """
  254. This tests the behaviour of tweaks_for_actions.
  255. """
  256. actions = [
  257. {"set_tweak": "sound", "value": "default"},
  258. {"set_tweak": "highlight"},
  259. "notify",
  260. ]
  261. self.assertEqual(
  262. tweaks_for_actions(actions),
  263. {"sound": "default", "highlight": True},
  264. )
  265. def test_related_event_match(self):
  266. evaluator = self._get_evaluator(
  267. {
  268. "m.relates_to": {
  269. "event_id": "$parent_event_id",
  270. "key": "😀",
  271. "rel_type": "m.annotation",
  272. "m.in_reply_to": {
  273. "event_id": "$parent_event_id",
  274. },
  275. }
  276. },
  277. {
  278. "m.in_reply_to": {
  279. "event_id": "$parent_event_id",
  280. "type": "m.room.message",
  281. "sender": "@other_user:test",
  282. "room_id": "!room:test",
  283. "content.msgtype": "m.text",
  284. "content.body": "Original message",
  285. },
  286. "m.annotation": {
  287. "event_id": "$parent_event_id",
  288. "type": "m.room.message",
  289. "sender": "@other_user:test",
  290. "room_id": "!room:test",
  291. "content.msgtype": "m.text",
  292. "content.body": "Original message",
  293. },
  294. },
  295. )
  296. self.assertTrue(
  297. evaluator.matches(
  298. {
  299. "kind": "im.nheko.msc3664.related_event_match",
  300. "key": "sender",
  301. "rel_type": "m.in_reply_to",
  302. "pattern": "@other_user:test",
  303. },
  304. "@user:test",
  305. "display_name",
  306. )
  307. )
  308. self.assertFalse(
  309. evaluator.matches(
  310. {
  311. "kind": "im.nheko.msc3664.related_event_match",
  312. "key": "sender",
  313. "rel_type": "m.in_reply_to",
  314. "pattern": "@user:test",
  315. },
  316. "@other_user:test",
  317. "display_name",
  318. )
  319. )
  320. self.assertTrue(
  321. evaluator.matches(
  322. {
  323. "kind": "im.nheko.msc3664.related_event_match",
  324. "key": "sender",
  325. "rel_type": "m.annotation",
  326. "pattern": "@other_user:test",
  327. },
  328. "@other_user:test",
  329. "display_name",
  330. )
  331. )
  332. self.assertFalse(
  333. evaluator.matches(
  334. {
  335. "kind": "im.nheko.msc3664.related_event_match",
  336. "key": "sender",
  337. "rel_type": "m.in_reply_to",
  338. },
  339. "@user:test",
  340. "display_name",
  341. )
  342. )
  343. self.assertTrue(
  344. evaluator.matches(
  345. {
  346. "kind": "im.nheko.msc3664.related_event_match",
  347. "rel_type": "m.in_reply_to",
  348. },
  349. "@user:test",
  350. "display_name",
  351. )
  352. )
  353. self.assertFalse(
  354. evaluator.matches(
  355. {
  356. "kind": "im.nheko.msc3664.related_event_match",
  357. "rel_type": "m.replace",
  358. },
  359. "@other_user:test",
  360. "display_name",
  361. )
  362. )
  363. def test_related_event_match_with_fallback(self):
  364. evaluator = self._get_evaluator(
  365. {
  366. "m.relates_to": {
  367. "event_id": "$parent_event_id",
  368. "key": "😀",
  369. "rel_type": "m.thread",
  370. "is_falling_back": True,
  371. "m.in_reply_to": {
  372. "event_id": "$parent_event_id",
  373. },
  374. }
  375. },
  376. {
  377. "m.in_reply_to": {
  378. "event_id": "$parent_event_id",
  379. "type": "m.room.message",
  380. "sender": "@other_user:test",
  381. "room_id": "!room:test",
  382. "content.msgtype": "m.text",
  383. "content.body": "Original message",
  384. "im.vector.is_falling_back": "",
  385. },
  386. "m.thread": {
  387. "event_id": "$parent_event_id",
  388. "type": "m.room.message",
  389. "sender": "@other_user:test",
  390. "room_id": "!room:test",
  391. "content.msgtype": "m.text",
  392. "content.body": "Original message",
  393. },
  394. },
  395. )
  396. self.assertTrue(
  397. evaluator.matches(
  398. {
  399. "kind": "im.nheko.msc3664.related_event_match",
  400. "key": "sender",
  401. "rel_type": "m.in_reply_to",
  402. "pattern": "@other_user:test",
  403. "include_fallbacks": True,
  404. },
  405. "@user:test",
  406. "display_name",
  407. )
  408. )
  409. self.assertFalse(
  410. evaluator.matches(
  411. {
  412. "kind": "im.nheko.msc3664.related_event_match",
  413. "key": "sender",
  414. "rel_type": "m.in_reply_to",
  415. "pattern": "@other_user:test",
  416. "include_fallbacks": False,
  417. },
  418. "@user:test",
  419. "display_name",
  420. )
  421. )
  422. self.assertFalse(
  423. evaluator.matches(
  424. {
  425. "kind": "im.nheko.msc3664.related_event_match",
  426. "key": "sender",
  427. "rel_type": "m.in_reply_to",
  428. "pattern": "@other_user:test",
  429. },
  430. "@user:test",
  431. "display_name",
  432. )
  433. )
  434. def test_related_event_match_no_related_event(self):
  435. evaluator = self._get_evaluator(
  436. {"msgtype": "m.text", "body": "Message without related event"}
  437. )
  438. self.assertFalse(
  439. evaluator.matches(
  440. {
  441. "kind": "im.nheko.msc3664.related_event_match",
  442. "key": "sender",
  443. "rel_type": "m.in_reply_to",
  444. "pattern": "@other_user:test",
  445. },
  446. "@user:test",
  447. "display_name",
  448. )
  449. )
  450. self.assertFalse(
  451. evaluator.matches(
  452. {
  453. "kind": "im.nheko.msc3664.related_event_match",
  454. "key": "sender",
  455. "rel_type": "m.in_reply_to",
  456. },
  457. "@user:test",
  458. "display_name",
  459. )
  460. )
  461. self.assertFalse(
  462. evaluator.matches(
  463. {
  464. "kind": "im.nheko.msc3664.related_event_match",
  465. "rel_type": "m.in_reply_to",
  466. },
  467. "@user:test",
  468. "display_name",
  469. )
  470. )
  471. class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
  472. """Tests for the bulk push rule evaluator"""
  473. servlets = [
  474. synapse.rest.admin.register_servlets_for_client_rest_resource,
  475. login.register_servlets,
  476. register.register_servlets,
  477. room.register_servlets,
  478. ]
  479. def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
  480. # Define an application service so that we can register appservice users
  481. self._service_token = "some_token"
  482. self._service = ApplicationService(
  483. self._service_token,
  484. "as1",
  485. "@as.sender:test",
  486. namespaces={
  487. "users": [
  488. {"regex": "@_as_.*:test", "exclusive": True},
  489. {"regex": "@as.sender:test", "exclusive": True},
  490. ]
  491. },
  492. msc3202_transaction_extensions=True,
  493. )
  494. self.hs.get_datastores().main.services_cache = [self._service]
  495. self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
  496. [self._service]
  497. )
  498. self._as_user, _ = self.register_appservice_user(
  499. "_as_user", self._service_token
  500. )
  501. self.evaluator = self.hs.get_bulk_push_rule_evaluator()
  502. def test_ignore_appservice_users(self) -> None:
  503. "Test that we don't generate push for appservice users"
  504. user_id = self.register_user("user", "pass")
  505. token = self.login("user", "pass")
  506. room_id = self.helper.create_room_as(user_id, tok=token)
  507. self.get_success(
  508. inject_member_event(self.hs, room_id, self._as_user, Membership.JOIN)
  509. )
  510. event, context = self.get_success(
  511. create_event(
  512. self.hs,
  513. type=EventTypes.Message,
  514. room_id=room_id,
  515. sender=user_id,
  516. content={"body": "test", "msgtype": "m.text"},
  517. )
  518. )
  519. # Assert the returned push rules do not contain the app service user
  520. rules = self.get_success(self.evaluator._get_rules_for_event(event))
  521. self.assertTrue(self._as_user not in rules)
  522. # Assert that no push actions have been added to the staging table (the
  523. # sender should not be pushed for the event)
  524. users_with_push_actions = self.get_success(
  525. self.hs.get_datastores().main.db_pool.simple_select_onecol(
  526. table="event_push_actions_staging",
  527. keyvalues={"event_id": event.event_id},
  528. retcol="user_id",
  529. desc="test_ignore_appservice_users",
  530. )
  531. )
  532. self.assertEqual(len(users_with_push_actions), 0)
  533. class BulkPushRuleEvaluatorTestCase(unittest.HomeserverTestCase):
  534. servlets = [
  535. admin.register_servlets,
  536. login.register_servlets,
  537. room.register_servlets,
  538. ]
  539. def prepare(
  540. self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
  541. ) -> None:
  542. self.main_store = homeserver.get_datastores().main
  543. self.user_id1 = self.register_user("user1", "password")
  544. self.tok1 = self.login(self.user_id1, "password")
  545. self.user_id2 = self.register_user("user2", "password")
  546. self.tok2 = self.login(self.user_id2, "password")
  547. self.room_id = self.helper.create_room_as(tok=self.tok1)
  548. # We want to test history visibility works correctly.
  549. self.helper.send_state(
  550. self.room_id,
  551. EventTypes.RoomHistoryVisibility,
  552. {"history_visibility": HistoryVisibility.JOINED},
  553. tok=self.tok1,
  554. )
  555. def get_notif_count(self, user_id: str) -> int:
  556. return self.get_success(
  557. self.main_store.db_pool.simple_select_one_onecol(
  558. table="event_push_actions",
  559. keyvalues={"user_id": user_id},
  560. retcol="COALESCE(SUM(notif), 0)",
  561. desc="get_staging_notif_count",
  562. )
  563. )
  564. def test_plain_message(self) -> None:
  565. """Test that sending a normal message in a room will trigger a
  566. notification
  567. """
  568. # Have user2 join the room and cle
  569. self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
  570. # They start off with no notifications, but get them when messages are
  571. # sent.
  572. self.assertEqual(self.get_notif_count(self.user_id2), 0)
  573. user1 = UserID.from_string(self.user_id1)
  574. self.create_and_send_event(self.room_id, user1)
  575. self.assertEqual(self.get_notif_count(self.user_id2), 1)
  576. def test_delayed_message(self) -> None:
  577. """Test that a delayed message that was from before a user joined
  578. doesn't cause a notification for the joined user.
  579. """
  580. user1 = UserID.from_string(self.user_id1)
  581. # Send a message before user2 joins
  582. event_id1 = self.create_and_send_event(self.room_id, user1)
  583. # Have user2 join the room
  584. self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
  585. # They start off with no notifications
  586. self.assertEqual(self.get_notif_count(self.user_id2), 0)
  587. # Send another message that references the event before the join to
  588. # simulate a "delayed" event
  589. self.create_and_send_event(self.room_id, user1, prev_event_ids=[event_id1])
  590. # user2 should not be notified about it, because they can't see it.
  591. self.assertEqual(self.get_notif_count(self.user_id2), 0)