
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Iterable, List, Optional, Tuple

from canonicaljson import encode_canonical_json
from parameterized import parameterized

from twisted.test.proto_helpers import MemoryReactor

from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.handlers.room import RoomEventSource
from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import (
    NotifCounts,
    RoomNotifCounts,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
from synapse.types import PersistedEventPosition
from synapse.util import Clock

from tests.server import FakeTransport

from ._base import BaseWorkerStoreTestCase

USER_ID = "@feeling:test"
USER_ID_2 = "@bright:test"

OUTLIER = {"outlier": True}
ROOM_ID = "!room:test"

logger = logging.getLogger(__name__)


def dict_equals(self: EventBase, other: EventBase) -> bool:
    me = encode_canonical_json(self.get_pdu_json())
    them = encode_canonical_json(other.get_pdu_json())
    return me == them


def patch__eq__(cls: object) -> Callable[[], None]:
    eq = getattr(cls, "__eq__", None)
    cls.__eq__ = dict_equals  # type: ignore[assignment]

    def unpatch() -> None:
        if eq is not None:
            cls.__eq__ = eq  # type: ignore[assignment]

    return unpatch


class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase):
    STORE_TYPE = EventsWorkerStore

    def setUp(self) -> None:
        # Patch up the equality operator for events so that we can check
        # whether lists of events match using assertEqual
        self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)]
        super().setUp()

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        super().prepare(reactor, clock, hs)

        self.get_success(
            self.master_store.store_room(
                ROOM_ID,
                USER_ID,
                is_public=False,
                room_version=RoomVersions.V1,
            )
        )

    def tearDown(self) -> None:
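        # Undo the __eq__ patches applied in setUp.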
        [unpatch() for unpatch in self.unpatches]

    def test_get_latest_event_ids_in_room(self) -> None:
        create = self.persist(type="m.room.create", key="", creator=USER_ID)
        self.replicate()
        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])

        join = self.persist(
            type="m.room.member",
            key=USER_ID,
            membership="join",
            prev_events=[(create.event_id, {})],
        )
        self.replicate()
        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])

    def test_redactions(self) -> None:
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")

        msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
        self.replicate()
        self.check("get_event", [msg.event_id], msg)

        redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
        self.replicate()
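        # Build the event we expect the store to return once the redaction has
        # replicated: the original event with empty content and the redaction
        # recorded in `unsigned`.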
        msg_dict = msg.get_dict()
        msg_dict["content"] = {}
        msg_dict["unsigned"]["redacted_by"] = redaction.event_id
        msg_dict["unsigned"]["redacted_because"] = redaction
        redacted = make_event_from_dict(
            msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
        )
        self.check("get_event", [msg.event_id], redacted)

    def test_backfilled_redactions(self) -> None:
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")

        msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
        self.replicate()
        self.check("get_event", [msg.event_id], msg)

        redaction = self.persist(
            type="m.room.redaction", redacts=msg.event_id, backfill=True
        )
        self.replicate()

        msg_dict = msg.get_dict()
        msg_dict["content"] = {}
        msg_dict["unsigned"]["redacted_by"] = redaction.event_id
        msg_dict["unsigned"]["redacted_because"] = redaction
        redacted = make_event_from_dict(
            msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict()
        )
        self.check("get_event", [msg.event_id], redacted)

    def test_invites(self) -> None:
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.check("get_invited_rooms_for_local_user", [USER_ID_2], [])
        event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
        assert event.internal_metadata.stream_ordering is not None

        self.replicate()

        self.check(
            "get_invited_rooms_for_local_user",
            [USER_ID_2],
            [
                RoomsForUser(
                    ROOM_ID,
                    USER_ID,
                    "invite",
                    event.event_id,
                    event.internal_metadata.stream_ordering,
                    RoomVersions.V1.identifier,
                )
            ],
        )

    @parameterized.expand([(True,), (False,)])
    def test_push_actions_for_user(self, send_receipt: bool) -> None:
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")
        self.persist(
            type="m.room.member", sender=USER_ID, key=USER_ID_2, membership="join"
        )
        event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
        self.replicate()

        if send_receipt:
            self.get_success(
                self.master_store.insert_receipt(
                    ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {}
                )
            )
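        # Whether or not a receipt was sent, there should be no unread
        # notifications yet: nothing with push actions for USER_ID_2 has been
        # persisted.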
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2],
            RoomNotifCounts(
                NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {}
            ),
        )

        self.persist(
            type="m.room.message",
            msgtype="m.text",
            body="world",
            push_actions=[(USER_ID_2, ["notify"])],
        )
        self.replicate()
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2],
            RoomNotifCounts(
                NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {}
            ),
        )

        self.persist(
            type="m.room.message",
            msgtype="m.text",
            body="world",
            push_actions=[
                (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
            ],
        )
        self.replicate()
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2],
            RoomNotifCounts(
                NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {}
            ),
        )

    def test_get_rooms_for_user_with_stream_ordering(self) -> None:
        """Check that the cache on get_rooms_for_user_with_stream_ordering is invalidated
        by rows in the events stream
        """
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")
        self.replicate()
        self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())

        j2 = self.persist(
            type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
        )
        assert j2.internal_metadata.stream_ordering is not None
        self.replicate()
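        # The worker should now see the room, keyed by the stream position at
        # which the join was persisted on the master.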
        expected_pos = PersistedEventPosition(
            "master", j2.internal_metadata.stream_ordering
        )
        self.check(
            "get_rooms_for_user_with_stream_ordering",
            (USER_ID_2,),
            {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
        )

    def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(
        self,
    ) -> None:
        """Check that current_state invalidation happens correctly with multiple events
        in the persistence batch.

        This test attempts to reproduce a race condition between the event persistence
        loop and a worker-based Sync handler.

        The problem occurred when the master persisted several events in one batch. It
        only updates the current_state at the end of each batch, so the obvious thing
        to do is then to issue a current_state_delta stream update corresponding to the
        last stream_id in the batch.

        However, that raises the possibility that a worker will see the replication
        notification for a join event before the current_state caches are invalidated.

        The test involves:
         * creating a join and a message event for a user, and persisting them in the
           same batch

         * controlling the replication stream so that updates are sent gradually

         * between each bunch of replication updates, checking that we see a consistent
           snapshot of the state.
        """
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")
        self.replicate()
        self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())

        # limit the replication rate
        repl_transport = self._server_transport
        assert isinstance(repl_transport, FakeTransport)
        repl_transport.autoflush = False

        # build the join and message events and persist them in the same batch.
        logger.info("----- build test events ------")
        j2, j2ctx = self.build_event(
            type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join"
        )
        msg, msgctx = self.build_event()
        self.get_success(self.persistance.persist_events([(j2, j2ctx), (msg, msgctx)]))
        self.replicate()
        assert j2.internal_metadata.stream_ordering is not None

        event_source = RoomEventSource(self.hs)
        event_source.store = self.worker_store
        current_token = event_source.get_current_key()

        # gradually stream out the replication
        while repl_transport.buffer:
            logger.info("------ flush ------")
            repl_transport.flush(30)
            self.pump(0)

            prev_token = current_token
            current_token = event_source.get_current_key()

            # attempt to replicate the behaviour of the sync handler.
            #
            # First, we get a list of the rooms we are joined to
            joined_rooms = self.get_success(
                self.worker_store.get_rooms_for_user_with_stream_ordering(USER_ID_2)
            )

            # Then, we get a list of the events since the last sync
            membership_changes = self.get_success(
                self.worker_store.get_membership_changes_for_user(
                    USER_ID_2, prev_token, current_token
                )
            )

            logger.info(
                "%s->%s: joined_rooms=%r membership_changes=%r",
                prev_token,
                current_token,
                joined_rooms,
                membership_changes,
            )

            # the membership change is only any use to us if the room is in the
            # joined_rooms list.
            if membership_changes:
                expected_pos = PersistedEventPosition(
                    "master", j2.internal_metadata.stream_ordering
                )
                self.assertEqual(
                    joined_rooms,
                    {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
                )

    # Counter used by build_event() to generate unique event IDs, default
    # depths and origin_server_ts values.
    event_id = 0

    def persist(self, backfill: bool = False, **kwargs: Any) -> EventBase:
        """
        Returns:
            The event that was persisted.
        """
        event, context = self.build_event(**kwargs)
        if backfill:
            self.get_success(
                self.persistance.persist_events([(event, context)], backfilled=True)
            )
        else:
            self.get_success(self.persistance.persist_event(event, context))
        return event

    def build_event(
        self,
        sender: str = USER_ID,
        room_id: str = ROOM_ID,
        type: str = "m.room.message",
        key: Optional[str] = None,
        internal: Optional[dict] = None,
        depth: Optional[int] = None,
        prev_events: Optional[List[Tuple[str, dict]]] = None,
        auth_events: Optional[List[str]] = None,
        prev_state: Optional[List[str]] = None,
        redacts: Optional[str] = None,
        push_actions: Iterable = frozenset(),
        **content: object,
    ) -> Tuple[EventBase, EventContext]:
        prev_events = prev_events or []
        auth_events = auth_events or []
        prev_state = prev_state or []

        if depth is None:
            depth = self.event_id
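        # If no prev_events were given, point the event at the room's current
        # forward extremities.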
        if not prev_events:
            latest_event_ids = self.get_success(
                self.master_store.get_latest_event_ids_in_room(room_id)
            )
            prev_events = [(ev_id, {}) for ev_id in latest_event_ids]

        event_dict = {
            "sender": sender,
            "type": type,
            "content": content,
            "event_id": "$%d:blue" % (self.event_id,),
            "room_id": room_id,
            "depth": depth,
            "origin_server_ts": self.event_id,
            "prev_events": prev_events,
            "auth_events": auth_events,
        }
        if key is not None:
            event_dict["state_key"] = key
            event_dict["prev_state"] = prev_state

        if redacts is not None:
            event_dict["redacts"] = redacts

        event = make_event_from_dict(event_dict, internal_metadata_dict=internal or {})

        self.event_id += 1
        state_handler = self.hs.get_state_handler()
        context = self.get_success(state_handler.compute_event_context(event))
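        # Stage any push actions for this event so that they are written out
        # when the event is persisted.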
        self.get_success(
            self.master_store.add_push_actions_to_staging(
                event.event_id,
                dict(push_actions),
                False,
                "main",
            )
        )

        return event, context