# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Collection,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    cast,
)

from synapse.api.constants import EduTypes, ReceiptTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.engines.postgres import PostgresEngine
from synapse.storage.util.id_generators import (
    AbstractStreamIdTracker,
    MultiWriterIdGenerator,
    StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class ReceiptsWorkerStore(SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        self._instance_name = hs.get_instance_name()
        self._receipts_id_gen: AbstractStreamIdTracker

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_receipts = (
                self._instance_name in hs.config.worker.writers.receipts
            )

            self._receipts_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="receipts",
                instance_name=self._instance_name,
                tables=[("receipts_linearized", "instance_name", "stream_id")],
                sequence_name="receipts_sequence",
                writers=hs.config.worker.writers.receipts,
            )
        else:
            self._can_write_to_receipts = True

            # We shouldn't be running in worker mode with SQLite, but it's useful
            # to support it for unit tests.
            #
            # If this process is the writer then we need to use
            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
            # updated over replication. (Multiple writers are not supported for
            # SQLite).
            if hs.get_instance_name() in hs.config.worker.writers.receipts:
                self._receipts_id_gen = StreamIdGenerator(
                    db_conn, "receipts_linearized", "stream_id"
                )
            else:
                self._receipts_id_gen = SlavedIdTracker(
                    db_conn, "receipts_linearized", "stream_id"
                )

        super().__init__(database, db_conn, hs)

        max_receipts_stream_id = self.get_max_receipt_stream_id()
        receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
            db_conn,
            "receipts_linearized",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=max_receipts_stream_id,
            limit=10000,
        )
        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache",
            min_receipts_stream_id,
            prefilled_cache=receipts_stream_prefill,
        )
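
    # Illustrative sketch, not from this file: on Postgres the receipts writer
    # is selected by homeserver config, assumed to look something like
    #
    #   stream_writers:
    #     receipts:
    #       - worker1
    #
    # in which case only the instance named "worker1" may call the insert
    # methods below; every other worker only tracks the stream position via
    # replication.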

    def get_max_receipt_stream_id(self) -> int:
        """Get the current max stream ID for receipts stream"""
        return self._receipts_id_gen.get_current_token()

    async def get_last_receipt_event_id_for_user(
        self, user_id: str, room_id: str, receipt_types: Iterable[str]
    ) -> Optional[str]:
        """
        Fetch the event ID for the latest receipt in a room with one of the given receipt types.

        Args:
            user_id: The user to fetch receipts for.
            room_id: The room ID to fetch the receipt for.
            receipt_types: The receipt types to fetch. Earlier receipt types
                are given priority if multiple receipts point to the same event.

        Returns:
            The event ID of the latest receipt, if one exists.
        """
        latest_event_id: Optional[str] = None
        latest_stream_ordering = 0
        for receipt_type in receipt_types:
            result = await self._get_last_receipt_event_id_for_user(
                user_id, room_id, receipt_type
            )
            if result is None:
                continue
            event_id, stream_ordering = result

            if latest_event_id is None or latest_stream_ordering < stream_ordering:
                latest_event_id = event_id
                latest_stream_ordering = stream_ordering

        return latest_event_id
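
    # Illustrative only (hypothetical names and values): a caller resolving
    # the event a user has read up to, considering both public and private
    # read receipts, might do:
    #
    #   event_id = await store.get_last_receipt_event_id_for_user(
    #       user_id="@alice:example.com",
    #       room_id="!room:example.com",
    #       receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
    #   )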

    @cached()
    async def _get_last_receipt_event_id_for_user(
        self, user_id: str, room_id: str, receipt_type: str
    ) -> Optional[Tuple[str, int]]:
        """
        Fetch the event ID and stream ordering for the latest receipt.

        Args:
            user_id: The user to fetch receipts for.
            room_id: The room ID to fetch the receipt for.
            receipt_type: The receipt type to fetch.

        Returns:
            The event ID and stream ordering of the latest receipt, if one exists;
            otherwise `None`.
        """
        sql = """
            SELECT event_id, stream_ordering
            FROM receipts_linearized
            INNER JOIN events USING (room_id, event_id)
            WHERE user_id = ?
                AND room_id = ?
                AND receipt_type = ?
        """

        def f(txn: LoggingTransaction) -> Optional[Tuple[str, int]]:
            txn.execute(sql, (user_id, room_id, receipt_type))
            return cast(Optional[Tuple[str, int]], txn.fetchone())

        return await self.db_pool.runInteraction("get_own_receipt_for_user", f)

    async def get_receipts_for_user(
        self, user_id: str, receipt_types: Iterable[str]
    ) -> Dict[str, str]:
        """
        Fetch the event IDs for the latest receipts sent by the given user.

        Args:
            user_id: The user to fetch receipts for.
            receipt_types: The receipt types to check.

        Returns:
            A map of room ID to the event ID of the latest receipt for that room.
            If the user has not sent a receipt to a room then it will not appear
            in the returned dictionary.
        """
        results = await self.get_receipts_for_user_with_orderings(
            user_id, receipt_types
        )

        # Reduce the result to room ID -> event ID.
        return {
            room_id: room_result["event_id"] for room_id, room_result in results.items()
        }
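
    # Illustrative only (hypothetical values): the reduced map has the shape
    #
    #   {"!room:example.com": "$latest_read_event_id"}
    #
    # with one entry per room in which the user has sent a matching receipt.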

    async def get_receipts_for_user_with_orderings(
        self, user_id: str, receipt_types: Iterable[str]
    ) -> JsonDict:
        """
        Fetch receipts for all rooms that the given user is joined to.

        Args:
            user_id: The user to fetch receipts for.
            receipt_types: The receipt types to fetch. Earlier receipt types
                are given priority if multiple receipts point to the same event.

        Returns:
            A map of room ID to the latest receipt (for the given types).
        """
        results: JsonDict = {}
        for receipt_type in receipt_types:
            partial_result = await self._get_receipts_for_user_with_orderings(
                user_id, receipt_type
            )
            for room_id, room_result in partial_result.items():
                # If the room has not yet been seen, or the receipt is newer,
                # use it.
                if (
                    room_id not in results
                    or results[room_id]["stream_ordering"]
                    < room_result["stream_ordering"]
                ):
                    results[room_id] = room_result

        return results

    @cached()
    async def _get_receipts_for_user_with_orderings(
        self, user_id: str, receipt_type: str
    ) -> JsonDict:
        """
        Fetch receipts for all rooms that the given user is joined to.

        Args:
            user_id: The user to fetch receipts for.
            receipt_type: The receipt type to fetch.

        Returns:
            A map of room ID to the latest receipt information.
        """

        def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
            sql = (
                "SELECT rl.room_id, rl.event_id,"
                " e.topological_ordering, e.stream_ordering"
                " FROM receipts_linearized AS rl"
                " INNER JOIN events AS e USING (room_id, event_id)"
                " WHERE rl.room_id = e.room_id"
                " AND rl.event_id = e.event_id"
                " AND user_id = ?"
                " AND receipt_type = ?"
            )
            txn.execute(sql, (user_id, receipt_type))
            return cast(List[Tuple[str, str, int, int]], txn.fetchall())

        rows = await self.db_pool.runInteraction(
            "get_receipts_for_user_with_orderings", f
        )
        return {
            row[0]: {
                "event_id": row[1],
                "topological_ordering": row[2],
                "stream_ordering": row[3],
            }
            for row in rows
        }

    async def get_linearized_receipts_for_rooms(
        self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
    ) -> List[dict]:
        """Get receipts for multiple rooms for sending to clients.

        Args:
            room_ids: The room IDs to fetch receipts of.
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        room_ids = set(room_ids)

        if from_key is not None:
            # Only ask the database about rooms where there have been new
            # receipts added since `from_key`
            room_ids = self._receipts_stream_cache.get_entities_changed(
                room_ids, from_key
            )

        results = await self._get_linearized_receipts_for_rooms(
            room_ids, to_key, from_key=from_key
        )

        return [ev for res in results.values() for ev in res]
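
    # Illustrative only (hypothetical names): a sync worker fetching receipt
    # EDUs for the rooms a user is in might call
    #
    #   receipts = await store.get_linearized_receipts_for_rooms(
    #       room_ids=joined_room_ids,
    #       from_key=since_token.receipt_key,
    #       to_key=store.get_max_receipt_stream_id(),
    #   )
    #
    # where `joined_room_ids` and `since_token` are assumed names; the
    # stream-change cache above means only rooms with receipts newer than
    # `from_key` actually hit the database.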

    async def get_linearized_receipts_for_room(
        self, room_id: str, to_key: int, from_key: Optional[int] = None
    ) -> List[dict]:
        """Get receipts for a single room for sending to clients.

        Args:
            room_id: The room ID.
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        if from_key is not None:
            # Check the cache first to see if any new receipts have been added
            # since `from_key`. If not we can no-op.
            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
                return []

        return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)

    @cached(tree=True)
    async def _get_linearized_receipts_for_room(
        self, room_id: str, to_key: int, from_key: Optional[int] = None
    ) -> List[JsonDict]:
        """See get_linearized_receipts_for_room"""

        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
            if from_key:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
                )

                txn.execute(sql, (room_id, from_key, to_key))
            else:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id <= ?"
                )

                txn.execute(sql, (room_id, to_key))

            rows = self.db_pool.cursor_to_dict(txn)

            return rows

        rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)

        if not rows:
            return []

        content: JsonDict = {}
        for row in rows:
            content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
                row["user_id"]
            ] = db_to_json(row["data"])

        return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
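
    # Illustrative only (hypothetical values): the single EDU returned above
    # batches every receipt in the room by event, receipt type and user, e.g.
    #
    #   {
    #       "type": "m.receipt",
    #       "room_id": "!room:example.com",
    #       "content": {
    #           "$event:example.com": {
    #               "m.read": {"@alice:example.com": {"ts": 1661000000000}}
    #           }
    #       },
    #   }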

    @cachedList(
        cached_method_name="_get_linearized_receipts_for_room",
        list_name="room_ids",
        num_args=3,
    )
    async def _get_linearized_receipts_for_rooms(
        self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
    ) -> Dict[str, List[JsonDict]]:
        if not room_ids:
            return {}

        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
            if from_key:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ? AND
                """
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )

                txn.execute(sql + clause, [from_key, to_key] + list(args))
            else:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id <= ? AND
                """

                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )

                txn.execute(sql + clause, [to_key] + list(args))

            return self.db_pool.cursor_to_dict(txn)

        txn_results = await self.db_pool.runInteraction(
            "_get_linearized_receipts_for_rooms", f
        )

        results: JsonDict = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
                row["room_id"],
                {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = db_to_json(row["data"])

        results = {
            room_id: [results[room_id]] if room_id in results else []
            for room_id in room_ids
        }
        return results

    @cached(
        num_args=2,
    )
    async def get_linearized_receipts_for_all_rooms(
        self, to_key: int, from_key: Optional[int] = None
    ) -> Dict[str, JsonDict]:
        """Get receipts for all rooms between two stream_ids, up
        to a limit of the latest 100 read receipts.

        Args:
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A dictionary of room IDs to a list of receipts.
        """

        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
            if from_key:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ?
                    ORDER BY stream_id DESC
                    LIMIT 100
                """
                txn.execute(sql, [from_key, to_key])
            else:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id <= ?
                    ORDER BY stream_id DESC
                    LIMIT 100
                """

                txn.execute(sql, [to_key])

            return self.db_pool.cursor_to_dict(txn)

        txn_results = await self.db_pool.runInteraction(
            "get_linearized_receipts_for_all_rooms", f
        )

        results: JsonDict = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
                row["room_id"],
                {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = db_to_json(row["data"])

        return results

    async def get_users_sent_receipts_between(
        self, last_id: int, current_id: int
    ) -> List[str]:
        """Get all users who sent receipts between `last_id` exclusive and
        `current_id` inclusive.

        Returns:
            The list of users.
        """

        if last_id == current_id:
            return []

        def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
            sql = """
                SELECT DISTINCT user_id FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
            """
            txn.execute(sql, (last_id, current_id))

            return [r[0] for r in txn]

        return await self.db_pool.runInteraction(
            "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
        )

    async def get_all_updated_receipts(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, list]], int, bool]:
        """Get updates for receipts replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data.
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_updated_receipts_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[Tuple[int, list]], int, bool]:
            sql = """
                SELECT stream_id, room_id, receipt_type, user_id, event_id, data
                FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))

            updates = cast(
                List[Tuple[int, list]],
                [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
            )

            limited = False
            upper_bound = current_id

            if len(updates) == limit:
                limited = True
                upper_bound = updates[-1][0]

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction(
            "get_all_updated_receipts", get_all_updated_receipts_txn
        )
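
    # Illustrative only (assumed names and values): a replication consumer
    # pages through the stream by feeding the returned token back in until
    # `limited` is False:
    #
    #   token = last_seen_token
    #   while True:
    #       updates, token, limited = await store.get_all_updated_receipts(
    #           "master", token, store.get_max_receipt_stream_id(), limit=100
    #       )
    #       if not limited:
    #           break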

    def invalidate_caches_for_receipt(
        self, room_id: str, receipt_type: str, user_id: str
    ) -> None:
        self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type))
        self._get_linearized_receipts_for_room.invalidate((room_id,))
        self._get_last_receipt_event_id_for_user.invalidate(
            (user_id, room_id, receipt_type)
        )

    def process_replication_rows(
        self,
        stream_name: str,
        instance_name: str,
        token: int,
        rows: Iterable[Any],
    ) -> None:
        if stream_name == ReceiptsStream.NAME:
            self._receipts_id_gen.advance(instance_name, token)
            for row in rows:
                self.invalidate_caches_for_receipt(
                    row.room_id, row.receipt_type, row.user_id
                )
                self._receipts_stream_cache.entity_has_changed(row.room_id, token)

        return super().process_replication_rows(stream_name, instance_name, token, rows)
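
    # Illustrative only: each `row` above is expected to be a ReceiptsStream
    # row carrying (room_id, receipt_type, user_id, ...), so non-writer
    # workers can drop stale cache entries and bump the stream-change cache
    # without re-reading the database.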

    def _insert_linearized_receipt_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_id: str,
        data: JsonDict,
        stream_id: int,
    ) -> Optional[int]:
        """Inserts a receipt into the database if it's newer than the current one.

        Returns:
            None if the receipt is older than the current receipt;
            otherwise, the rx timestamp of the event that the receipt corresponds to
            (or 0 if the event is unknown).
        """
        assert self._can_write_to_receipts

        res = self.db_pool.simple_select_one_txn(
            txn,
            table="events",
            retcols=["stream_ordering", "received_ts"],
            keyvalues={"event_id": event_id},
            allow_none=True,
        )

        stream_ordering = int(res["stream_ordering"]) if res else None
        rx_ts = res["received_ts"] if res else 0

        # We don't want to clobber receipts for more recent events, so we
        # have to compare orderings of existing receipts
        if stream_ordering is not None:
            sql = (
                "SELECT stream_ordering, event_id FROM events"
                " INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
            )
            txn.execute(sql, (room_id, receipt_type, user_id))

            for so, eid in txn:
                if int(so) >= stream_ordering:
                    logger.debug(
                        "Ignoring new receipt for %s in favour of existing "
                        "one for later event %s",
                        event_id,
                        eid,
                    )
                    return None

        txn.call_after(
            self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
        )

        txn.call_after(
            self._receipts_stream_cache.entity_has_changed, room_id, stream_id
        )

        self.db_pool.simple_upsert_txn(
            txn,
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
            values={
                "stream_id": stream_id,
                "event_id": event_id,
                "data": json_encoder.encode(data),
            },
            # receipts_linearized has a unique constraint on
            # (user_id, room_id, receipt_type), so no need to lock
            lock=False,
        )

        # When updating a local user's read receipt, remove any push actions
        # which resulted from the receipt's event and all earlier events.
        if (
            self.hs.is_mine_id(user_id)
            and receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
            and stream_ordering is not None
        ):
            self._remove_old_push_actions_before_txn(  # type: ignore[attr-defined]
                txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
            )

        return rx_ts
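
    # Illustrative only (hypothetical orderings): if the user's existing
    # receipt points at an event with stream_ordering 42, an incoming receipt
    # for an event at stream_ordering 40 is dropped (returning None), while
    # one at 43 replaces the stored row via the upsert above.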

    def _graph_to_linear(
        self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
    ) -> str:
        """
        Generate a linearized event from a list of events (i.e. a list of forward
        extremities in the room).

        This should allow for calculation of the correct read receipt even if
        servers have different event ordering.

        Args:
            txn: The transaction
            room_id: The room ID the events are in.
            event_ids: The list of event IDs to linearize.

        Returns:
            The linearized event ID.
        """
        # TODO: Make this better.
        clause, args = make_in_list_sql_clause(
            self.database_engine, "event_id", event_ids
        )

        sql = """
            SELECT event_id FROM events
            WHERE room_id = ? AND stream_ordering IN (
                SELECT max(stream_ordering) FROM events WHERE %s
            )
        """ % (
            clause,
        )

        txn.execute(sql, [room_id] + list(args))

        rows = txn.fetchall()
        if rows:
            return rows[0][0]
        else:
            raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
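
    # Illustrative only (hypothetical values): a federated receipt naming two
    # forward extremities, e.g. event_ids=["$a", "$b"], is collapsed to
    # whichever of the two this server persisted with the higher
    # stream_ordering, and the receipt is stored against that single event.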

    async def insert_receipt(
        self,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        data: dict,
    ) -> Optional[Tuple[int, int]]:
        """Insert a receipt, either from a local client or a remote server.

        Automatically does conversion between linearized and graph
        representations.

        Returns:
            The new receipts stream ID and token, if the receipt is newer than
            what was previously persisted. None, otherwise.
        """
        assert self._can_write_to_receipts

        if not event_ids:
            return None

        if len(event_ids) == 1:
            linearized_event_id = event_ids[0]
        else:
            # we need to convert the points in the graph representation
            # into a linearized form.
            linearized_event_id = await self.db_pool.runInteraction(
                "insert_receipt_conv", self._graph_to_linear, room_id, event_ids
            )

        async with self._receipts_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
            event_ts = await self.db_pool.runInteraction(
                "insert_linearized_receipt",
                self._insert_linearized_receipt_txn,
                room_id,
                receipt_type,
                user_id,
                linearized_event_id,
                data,
                stream_id=stream_id,
            )

        # If the receipt was older than the currently persisted one, nothing to do.
        if event_ts is None:
            return None

        now = self._clock.time_msec()
        logger.debug(
            "RR for event %s in %s (%i ms old)",
            linearized_event_id,
            room_id,
            now - event_ts,
        )

        await self.db_pool.runInteraction(
            "insert_graph_receipt",
            self._insert_graph_receipt_txn,
            room_id,
            receipt_type,
            user_id,
            event_ids,
            data,
        )

        max_persisted_id = self._receipts_id_gen.get_current_token()

        return stream_id, max_persisted_id
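
    # Illustrative only (hypothetical values): a writer persisting a local
    # read receipt might call
    #
    #   result = await store.insert_receipt(
    #       room_id="!room:example.com",
    #       receipt_type=ReceiptTypes.READ,
    #       user_id="@alice:example.com",
    #       event_ids=["$event:example.com"],
    #       data={"ts": 1661000000000},
    #   )
    #
    # and treat a None result as "already read up to a later event".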

    def _insert_graph_receipt_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        data: JsonDict,
    ) -> None:
        assert self._can_write_to_receipts

        txn.call_after(
            self._get_receipts_for_user_with_orderings.invalidate,
            (user_id, receipt_type),
        )
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))

        self.db_pool.simple_delete_txn(
            txn,
            table="receipts_graph",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
        )
        self.db_pool.simple_insert_txn(
            txn,
            table="receipts_graph",
            values={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_ids": json_encoder.encode(event_ids),
                "data": json_encoder.encode(data),
            },
        )
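
    # Illustrative only: receipts_graph appears to keep the original
    # (possibly multi-event) form of the receipt, while receipts_linearized
    # stores the single event resolved by _graph_to_linear; clients are
    # served the linearized copy.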


class ReceiptsStore(ReceiptsWorkerStore):
    pass