test_push_rule_evaluator.py 22 KB


# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from typing import Dict, List, Optional, Union, cast
  15. import frozendict
  16. from twisted.test.proto_helpers import MemoryReactor
  17. import synapse.rest.admin
  18. from synapse.api.constants import EventTypes, HistoryVisibility, Membership
  19. from synapse.api.room_versions import RoomVersions
  20. from synapse.appservice import ApplicationService
  21. from synapse.events import FrozenEvent
  22. from synapse.push.bulk_push_rule_evaluator import _flatten_dict
  23. from synapse.push.httppusher import tweaks_for_actions
  24. from synapse.rest import admin
  25. from synapse.rest.client import login, register, room
  26. from synapse.server import HomeServer
  27. from synapse.storage.databases.main.appservice import _make_exclusive_regex
  28. from synapse.synapse_rust.push import PushRuleEvaluator
  29. from synapse.types import JsonDict, JsonMapping, UserID
  30. from synapse.util import Clock
  31. from tests import unittest
  32. from tests.test_utils.event_injection import create_event, inject_member_event
  33. class PushRuleEvaluatorTestCase(unittest.TestCase):
  34. def _get_evaluator(
  35. self, content: JsonMapping, related_events: Optional[JsonDict] = None
  36. ) -> PushRuleEvaluator:
  37. event = FrozenEvent(
  38. {
  39. "event_id": "$event_id",
  40. "type": "m.room.history_visibility",
  41. "sender": "@user:test",
  42. "state_key": "",
  43. "room_id": "#room:test",
  44. "content": content,
  45. },
  46. RoomVersions.V1,
  47. )
  48. room_member_count = 0
  49. sender_power_level = 0
  50. power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
  51. return PushRuleEvaluator(
  52. _flatten_dict(event),
  53. room_member_count,
  54. sender_power_level,
  55. cast(Dict[str, int], power_levels.get("notifications", {})),
  56. {} if related_events is None else related_events,
  57. True,
  58. event.room_version.msc3931_push_features,
  59. True,
  60. )
  61. def test_display_name(self) -> None:
  62. """Check for a matching display name in the body of the event."""
  63. evaluator = self._get_evaluator({"body": "foo bar baz"})
  64. condition = {"kind": "contains_display_name"}
  65. # Blank names are skipped.
  66. self.assertFalse(evaluator.matches(condition, "@user:test", ""))
  67. # Check a display name that doesn't match.
  68. self.assertFalse(evaluator.matches(condition, "@user:test", "not found"))
  69. # Check a display name which matches.
  70. self.assertTrue(evaluator.matches(condition, "@user:test", "foo"))
  71. # A display name that matches, but not a full word does not result in a match.
  72. self.assertFalse(evaluator.matches(condition, "@user:test", "ba"))
  73. # A display name should not be interpreted as a regular expression.
  74. self.assertFalse(evaluator.matches(condition, "@user:test", "ba[rz]"))
  75. # A display name with spaces should work fine.
  76. self.assertTrue(evaluator.matches(condition, "@user:test", "foo bar"))
  77. def _assert_matches(
  78. self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
  79. ) -> None:
  80. evaluator = self._get_evaluator(content)
  81. self.assertTrue(evaluator.matches(condition, "@user:test", "display_name"), msg)
  82. def _assert_not_matches(
  83. self, condition: JsonDict, content: JsonDict, msg: Optional[str] = None
  84. ) -> None:
  85. evaluator = self._get_evaluator(content)
  86. self.assertFalse(
  87. evaluator.matches(condition, "@user:test", "display_name"), msg
  88. )
  89. def test_event_match_body(self) -> None:
  90. """Check that event_match conditions on content.body work as expected"""
  91. # if the key is `content.body`, the pattern matches substrings.
  92. # non-wildcards should match
  93. condition = {
  94. "kind": "event_match",
  95. "key": "content.body",
  96. "pattern": "foobaz",
  97. }
  98. self._assert_matches(
  99. condition,
  100. {"body": "aaa FoobaZ zzz"},
  101. "patterns should match and be case-insensitive",
  102. )
  103. self._assert_not_matches(
  104. condition,
  105. {"body": "aa xFoobaZ yy"},
  106. "pattern should only match at word boundaries",
  107. )
  108. self._assert_not_matches(
  109. condition,
  110. {"body": "aa foobazx yy"},
  111. "pattern should only match at word boundaries",
  112. )
  113. # wildcards should match
  114. condition = {
  115. "kind": "event_match",
  116. "key": "content.body",
  117. "pattern": "f?o*baz",
  118. }
  119. self._assert_matches(
  120. condition,
  121. {"body": "aaa FoobarbaZ zzz"},
  122. "* should match string and pattern should be case-insensitive",
  123. )
  124. self._assert_matches(
  125. condition, {"body": "aa foobaz yy"}, "* should match 0 characters"
  126. )
  127. self._assert_not_matches(
  128. condition, {"body": "aa fobbaz yy"}, "? should not match 0 characters"
  129. )
  130. self._assert_not_matches(
  131. condition, {"body": "aa fiiobaz yy"}, "? should not match 2 characters"
  132. )
  133. self._assert_not_matches(
  134. condition,
  135. {"body": "aa xfooxbaz yy"},
  136. "pattern should only match at word boundaries",
  137. )
  138. self._assert_not_matches(
  139. condition,
  140. {"body": "aa fooxbazx yy"},
  141. "pattern should only match at word boundaries",
  142. )
  143. # test backslashes
  144. condition = {
  145. "kind": "event_match",
  146. "key": "content.body",
  147. "pattern": r"f\oobaz",
  148. }
  149. self._assert_matches(
  150. condition,
  151. {"body": r"F\oobaz"},
  152. "backslash should match itself",
  153. )
  154. condition = {
  155. "kind": "event_match",
  156. "key": "content.body",
  157. "pattern": r"f\?obaz",
  158. }
  159. self._assert_matches(
  160. condition,
  161. {"body": r"F\oobaz"},
  162. r"? after \ should match any character",
  163. )
  164. def test_event_match_non_body(self) -> None:
  165. """Check that event_match conditions on other keys work as expected"""
  166. # if the key is anything other than 'content.body', the pattern must match the
  167. # whole value.
  168. # non-wildcards should match
  169. condition = {
  170. "kind": "event_match",
  171. "key": "content.value",
  172. "pattern": "foobaz",
  173. }
  174. self._assert_matches(
  175. condition,
  176. {"value": "FoobaZ"},
  177. "patterns should match and be case-insensitive",
  178. )
  179. self._assert_not_matches(
  180. condition,
  181. {"value": "xFoobaZ"},
  182. "pattern should only match at the start/end of the value",
  183. )
  184. self._assert_not_matches(
  185. condition,
  186. {"value": "FoobaZz"},
  187. "pattern should only match at the start/end of the value",
  188. )
  189. # it should work on frozendicts too
  190. self._assert_matches(
  191. condition,
  192. frozendict.frozendict({"value": "FoobaZ"}),
  193. "patterns should match on frozendicts",
  194. )
  195. # wildcards should match
  196. condition = {
  197. "kind": "event_match",
  198. "key": "content.value",
  199. "pattern": "f?o*baz",
  200. }
  201. self._assert_matches(
  202. condition,
  203. {"value": "FoobarbaZ"},
  204. "* should match string and pattern should be case-insensitive",
  205. )
  206. self._assert_matches(
  207. condition, {"value": "foobaz"}, "* should match 0 characters"
  208. )
  209. self._assert_not_matches(
  210. condition, {"value": "fobbaz"}, "? should not match 0 characters"
  211. )
  212. self._assert_not_matches(
  213. condition, {"value": "fiiobaz"}, "? should not match 2 characters"
  214. )
  215. self._assert_not_matches(
  216. condition,
  217. {"value": "xfooxbaz"},
  218. "pattern should only match at the start/end of the value",
  219. )
  220. self._assert_not_matches(
  221. condition,
  222. {"value": "fooxbazx"},
  223. "pattern should only match at the start/end of the value",
  224. )
  225. self._assert_not_matches(
  226. condition,
  227. {"value": "x\nfooxbaz"},
  228. "pattern should not match after a newline",
  229. )
  230. self._assert_not_matches(
  231. condition,
  232. {"value": "fooxbaz\nx"},
  233. "pattern should not match before a newline",
  234. )
  235. def test_no_body(self) -> None:
  236. """Not having a body shouldn't break the evaluator."""
  237. evaluator = self._get_evaluator({})
  238. condition = {
  239. "kind": "contains_display_name",
  240. }
  241. self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
  242. def test_invalid_body(self) -> None:
  243. """A non-string body should not break the evaluator."""
  244. condition = {
  245. "kind": "contains_display_name",
  246. }
  247. for body in (1, True, {"foo": "bar"}):
  248. evaluator = self._get_evaluator({"body": body})
  249. self.assertFalse(evaluator.matches(condition, "@user:test", "foo"))
  250. def test_tweaks_for_actions(self) -> None:
  251. """
  252. This tests the behaviour of tweaks_for_actions.
  253. """
  254. actions: List[Union[Dict[str, str], str]] = [
  255. {"set_tweak": "sound", "value": "default"},
  256. {"set_tweak": "highlight"},
  257. "notify",
  258. ]
  259. self.assertEqual(
  260. tweaks_for_actions(actions),
  261. {"sound": "default", "highlight": True},
  262. )
  263. def test_related_event_match(self) -> None:
  264. evaluator = self._get_evaluator(
  265. {
  266. "m.relates_to": {
  267. "event_id": "$parent_event_id",
  268. "key": "😀",
  269. "rel_type": "m.annotation",
  270. "m.in_reply_to": {
  271. "event_id": "$parent_event_id",
  272. },
  273. }
  274. },
  275. {
  276. "m.in_reply_to": {
  277. "event_id": "$parent_event_id",
  278. "type": "m.room.message",
  279. "sender": "@other_user:test",
  280. "room_id": "!room:test",
  281. "content.msgtype": "m.text",
  282. "content.body": "Original message",
  283. },
  284. "m.annotation": {
  285. "event_id": "$parent_event_id",
  286. "type": "m.room.message",
  287. "sender": "@other_user:test",
  288. "room_id": "!room:test",
  289. "content.msgtype": "m.text",
  290. "content.body": "Original message",
  291. },
  292. },
  293. )
  294. self.assertTrue(
  295. evaluator.matches(
  296. {
  297. "kind": "im.nheko.msc3664.related_event_match",
  298. "key": "sender",
  299. "rel_type": "m.in_reply_to",
  300. "pattern": "@other_user:test",
  301. },
  302. "@user:test",
  303. "display_name",
  304. )
  305. )
  306. self.assertFalse(
  307. evaluator.matches(
  308. {
  309. "kind": "im.nheko.msc3664.related_event_match",
  310. "key": "sender",
  311. "rel_type": "m.in_reply_to",
  312. "pattern": "@user:test",
  313. },
  314. "@other_user:test",
  315. "display_name",
  316. )
  317. )
  318. self.assertTrue(
  319. evaluator.matches(
  320. {
  321. "kind": "im.nheko.msc3664.related_event_match",
  322. "key": "sender",
  323. "rel_type": "m.annotation",
  324. "pattern": "@other_user:test",
  325. },
  326. "@other_user:test",
  327. "display_name",
  328. )
  329. )
  330. self.assertFalse(
  331. evaluator.matches(
  332. {
  333. "kind": "im.nheko.msc3664.related_event_match",
  334. "key": "sender",
  335. "rel_type": "m.in_reply_to",
  336. },
  337. "@user:test",
  338. "display_name",
  339. )
  340. )
  341. self.assertTrue(
  342. evaluator.matches(
  343. {
  344. "kind": "im.nheko.msc3664.related_event_match",
  345. "rel_type": "m.in_reply_to",
  346. },
  347. "@user:test",
  348. "display_name",
  349. )
  350. )
  351. self.assertFalse(
  352. evaluator.matches(
  353. {
  354. "kind": "im.nheko.msc3664.related_event_match",
  355. "rel_type": "m.replace",
  356. },
  357. "@other_user:test",
  358. "display_name",
  359. )
  360. )
  361. def test_related_event_match_with_fallback(self) -> None:
  362. evaluator = self._get_evaluator(
  363. {
  364. "m.relates_to": {
  365. "event_id": "$parent_event_id",
  366. "key": "😀",
  367. "rel_type": "m.thread",
  368. "is_falling_back": True,
  369. "m.in_reply_to": {
  370. "event_id": "$parent_event_id",
  371. },
  372. }
  373. },
  374. {
  375. "m.in_reply_to": {
  376. "event_id": "$parent_event_id",
  377. "type": "m.room.message",
  378. "sender": "@other_user:test",
  379. "room_id": "!room:test",
  380. "content.msgtype": "m.text",
  381. "content.body": "Original message",
  382. "im.vector.is_falling_back": "",
  383. },
  384. "m.thread": {
  385. "event_id": "$parent_event_id",
  386. "type": "m.room.message",
  387. "sender": "@other_user:test",
  388. "room_id": "!room:test",
  389. "content.msgtype": "m.text",
  390. "content.body": "Original message",
  391. },
  392. },
  393. )
  394. self.assertTrue(
  395. evaluator.matches(
  396. {
  397. "kind": "im.nheko.msc3664.related_event_match",
  398. "key": "sender",
  399. "rel_type": "m.in_reply_to",
  400. "pattern": "@other_user:test",
  401. "include_fallbacks": True,
  402. },
  403. "@user:test",
  404. "display_name",
  405. )
  406. )
  407. self.assertFalse(
  408. evaluator.matches(
  409. {
  410. "kind": "im.nheko.msc3664.related_event_match",
  411. "key": "sender",
  412. "rel_type": "m.in_reply_to",
  413. "pattern": "@other_user:test",
  414. "include_fallbacks": False,
  415. },
  416. "@user:test",
  417. "display_name",
  418. )
  419. )
  420. self.assertFalse(
  421. evaluator.matches(
  422. {
  423. "kind": "im.nheko.msc3664.related_event_match",
  424. "key": "sender",
  425. "rel_type": "m.in_reply_to",
  426. "pattern": "@other_user:test",
  427. },
  428. "@user:test",
  429. "display_name",
  430. )
  431. )
  432. def test_related_event_match_no_related_event(self) -> None:
  433. evaluator = self._get_evaluator(
  434. {"msgtype": "m.text", "body": "Message without related event"}
  435. )
  436. self.assertFalse(
  437. evaluator.matches(
  438. {
  439. "kind": "im.nheko.msc3664.related_event_match",
  440. "key": "sender",
  441. "rel_type": "m.in_reply_to",
  442. "pattern": "@other_user:test",
  443. },
  444. "@user:test",
  445. "display_name",
  446. )
  447. )
  448. self.assertFalse(
  449. evaluator.matches(
  450. {
  451. "kind": "im.nheko.msc3664.related_event_match",
  452. "key": "sender",
  453. "rel_type": "m.in_reply_to",
  454. },
  455. "@user:test",
  456. "display_name",
  457. )
  458. )
  459. self.assertFalse(
  460. evaluator.matches(
  461. {
  462. "kind": "im.nheko.msc3664.related_event_match",
  463. "rel_type": "m.in_reply_to",
  464. },
  465. "@user:test",
  466. "display_name",
  467. )
  468. )
  469. class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
  470. """Tests for the bulk push rule evaluator"""
  471. servlets = [
  472. synapse.rest.admin.register_servlets_for_client_rest_resource,
  473. login.register_servlets,
  474. register.register_servlets,
  475. room.register_servlets,
  476. ]
  477. def prepare(
  478. self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
  479. ) -> None:
  480. # Define an application service so that we can register appservice users
  481. self._service_token = "some_token"
  482. self._service = ApplicationService(
  483. self._service_token,
  484. "as1",
  485. "@as.sender:test",
  486. namespaces={
  487. "users": [
  488. {"regex": "@_as_.*:test", "exclusive": True},
  489. {"regex": "@as.sender:test", "exclusive": True},
  490. ]
  491. },
  492. msc3202_transaction_extensions=True,
  493. )
  494. self.hs.get_datastores().main.services_cache = [self._service]
  495. self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
  496. [self._service]
  497. )
  498. self._as_user, _ = self.register_appservice_user(
  499. "_as_user", self._service_token
  500. )
  501. self.evaluator = self.hs.get_bulk_push_rule_evaluator()
  502. def test_ignore_appservice_users(self) -> None:
  503. "Test that we don't generate push for appservice users"
  504. user_id = self.register_user("user", "pass")
  505. token = self.login("user", "pass")
  506. room_id = self.helper.create_room_as(user_id, tok=token)
  507. self.get_success(
  508. inject_member_event(self.hs, room_id, self._as_user, Membership.JOIN)
  509. )
  510. event, context = self.get_success(
  511. create_event(
  512. self.hs,
  513. type=EventTypes.Message,
  514. room_id=room_id,
  515. sender=user_id,
  516. content={"body": "test", "msgtype": "m.text"},
  517. )
  518. )
  519. # Assert the returned push rules do not contain the app service user
  520. rules = self.get_success(self.evaluator._get_rules_for_event(event))
  521. self.assertTrue(self._as_user not in rules)
  522. # Assert that no push actions have been added to the staging table (the
  523. # sender should not be pushed for the event)
  524. users_with_push_actions = self.get_success(
  525. self.hs.get_datastores().main.db_pool.simple_select_onecol(
  526. table="event_push_actions_staging",
  527. keyvalues={"event_id": event.event_id},
  528. retcol="user_id",
  529. desc="test_ignore_appservice_users",
  530. )
  531. )
  532. self.assertEqual(len(users_with_push_actions), 0)
  533. class BulkPushRuleEvaluatorTestCase(unittest.HomeserverTestCase):
  534. servlets = [
  535. admin.register_servlets,
  536. login.register_servlets,
  537. room.register_servlets,
  538. ]
  539. def prepare(
  540. self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
  541. ) -> None:
  542. self.main_store = homeserver.get_datastores().main
  543. self.user_id1 = self.register_user("user1", "password")
  544. self.tok1 = self.login(self.user_id1, "password")
  545. self.user_id2 = self.register_user("user2", "password")
  546. self.tok2 = self.login(self.user_id2, "password")
  547. self.room_id = self.helper.create_room_as(tok=self.tok1)
  548. # We want to test history visibility works correctly.
  549. self.helper.send_state(
  550. self.room_id,
  551. EventTypes.RoomHistoryVisibility,
  552. {"history_visibility": HistoryVisibility.JOINED},
  553. tok=self.tok1,
  554. )
  555. def get_notif_count(self, user_id: str) -> int:
  556. return self.get_success(
  557. self.main_store.db_pool.simple_select_one_onecol(
  558. table="event_push_actions",
  559. keyvalues={"user_id": user_id},
  560. retcol="COALESCE(SUM(notif), 0)",
  561. desc="get_staging_notif_count",
  562. )
  563. )
  564. def test_plain_message(self) -> None:
  565. """Test that sending a normal message in a room will trigger a
  566. notification
  567. """
  568. # Have user2 join the room and cle
  569. self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
  570. # They start off with no notifications, but get them when messages are
  571. # sent.
  572. self.assertEqual(self.get_notif_count(self.user_id2), 0)
  573. user1 = UserID.from_string(self.user_id1)
  574. self.create_and_send_event(self.room_id, user1)
  575. self.assertEqual(self.get_notif_count(self.user_id2), 1)
  576. def test_delayed_message(self) -> None:
  577. """Test that a delayed message that was from before a user joined
  578. doesn't cause a notification for the joined user.
  579. """
  580. user1 = UserID.from_string(self.user_id1)
  581. # Send a message before user2 joins
  582. event_id1 = self.create_and_send_event(self.room_id, user1)
  583. # Have user2 join the room
  584. self.helper.join(self.room_id, self.user_id2, tok=self.tok2)
  585. # They start off with no notifications
  586. self.assertEqual(self.get_notif_count(self.user_id2), 0)
  587. # Send another message that references the event before the join to
  588. # simulate a "delayed" event
  589. self.create_and_send_event(self.room_id, user1, prev_event_ids=[event_id1])
  590. # user2 should not be notified about it, because they can't see it.
  591. self.assertEqual(self.get_notif_count(self.user_id2), 0)