receipts.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from ._base import SQLBaseStore
  16. from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
  17. from synapse.util.caches.stream_change_cache import StreamChangeCache
  18. from twisted.internet import defer
  19. import logging
  20. import ujson as json
  21. logger = logging.getLogger(__name__)
  22. class ReceiptsStore(SQLBaseStore):
  23. def __init__(self, hs):
  24. super(ReceiptsStore, self).__init__(hs)
  25. self._receipts_stream_cache = StreamChangeCache(
  26. "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
  27. )
  28. @cached(num_args=2)
  29. def get_receipts_for_room(self, room_id, receipt_type):
  30. return self._simple_select_list(
  31. table="receipts_linearized",
  32. keyvalues={
  33. "room_id": room_id,
  34. "receipt_type": receipt_type,
  35. },
  36. retcols=("user_id", "event_id"),
  37. desc="get_receipts_for_room",
  38. )
  39. @cached(num_args=3)
  40. def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
  41. return self._simple_select_one_onecol(
  42. table="receipts_linearized",
  43. keyvalues={
  44. "room_id": room_id,
  45. "receipt_type": receipt_type,
  46. "user_id": user_id
  47. },
  48. retcol="event_id",
  49. desc="get_own_receipt_for_user",
  50. allow_none=True,
  51. )
  52. @cachedInlineCallbacks(num_args=2)
  53. def get_receipts_for_user(self, user_id, receipt_type):
  54. rows = yield self._simple_select_list(
  55. table="receipts_linearized",
  56. keyvalues={
  57. "user_id": user_id,
  58. "receipt_type": receipt_type,
  59. },
  60. retcols=("room_id", "event_id"),
  61. desc="get_receipts_for_user",
  62. )
  63. defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
  64. @defer.inlineCallbacks
  65. def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
  66. """Get receipts for multiple rooms for sending to clients.
  67. Args:
  68. room_ids (list): List of room_ids.
  69. to_key (int): Max stream id to fetch receipts upto.
  70. from_key (int): Min stream id to fetch receipts from. None fetches
  71. from the start.
  72. Returns:
  73. list: A list of receipts.
  74. """
  75. room_ids = set(room_ids)
  76. if from_key:
  77. room_ids = yield self._receipts_stream_cache.get_entities_changed(
  78. room_ids, from_key
  79. )
  80. results = yield self._get_linearized_receipts_for_rooms(
  81. room_ids, to_key, from_key=from_key
  82. )
  83. defer.returnValue([ev for res in results.values() for ev in res])
  84. @cachedInlineCallbacks(num_args=3, max_entries=5000)
  85. def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
  86. """Get receipts for a single room for sending to clients.
  87. Args:
  88. room_ids (str): The room id.
  89. to_key (int): Max stream id to fetch receipts upto.
  90. from_key (int): Min stream id to fetch receipts from. None fetches
  91. from the start.
  92. Returns:
  93. list: A list of receipts.
  94. """
  95. def f(txn):
  96. if from_key:
  97. sql = (
  98. "SELECT * FROM receipts_linearized WHERE"
  99. " room_id = ? AND stream_id > ? AND stream_id <= ?"
  100. )
  101. txn.execute(
  102. sql,
  103. (room_id, from_key, to_key)
  104. )
  105. else:
  106. sql = (
  107. "SELECT * FROM receipts_linearized WHERE"
  108. " room_id = ? AND stream_id <= ?"
  109. )
  110. txn.execute(
  111. sql,
  112. (room_id, to_key)
  113. )
  114. rows = self.cursor_to_dict(txn)
  115. return rows
  116. rows = yield self.runInteraction(
  117. "get_linearized_receipts_for_room", f
  118. )
  119. if not rows:
  120. defer.returnValue([])
  121. content = {}
  122. for row in rows:
  123. content.setdefault(
  124. row["event_id"], {}
  125. ).setdefault(
  126. row["receipt_type"], {}
  127. )[row["user_id"]] = json.loads(row["data"])
  128. defer.returnValue([{
  129. "type": "m.receipt",
  130. "room_id": room_id,
  131. "content": content,
  132. }])
  133. @cachedList(cached_method_name="get_linearized_receipts_for_room",
  134. list_name="room_ids", num_args=3, inlineCallbacks=True)
  135. def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
  136. if not room_ids:
  137. defer.returnValue({})
  138. def f(txn):
  139. if from_key:
  140. sql = (
  141. "SELECT * FROM receipts_linearized WHERE"
  142. " room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
  143. ) % (
  144. ",".join(["?"] * len(room_ids))
  145. )
  146. args = list(room_ids)
  147. args.extend([from_key, to_key])
  148. txn.execute(sql, args)
  149. else:
  150. sql = (
  151. "SELECT * FROM receipts_linearized WHERE"
  152. " room_id IN (%s) AND stream_id <= ?"
  153. ) % (
  154. ",".join(["?"] * len(room_ids))
  155. )
  156. args = list(room_ids)
  157. args.append(to_key)
  158. txn.execute(sql, args)
  159. return self.cursor_to_dict(txn)
  160. txn_results = yield self.runInteraction(
  161. "_get_linearized_receipts_for_rooms", f
  162. )
  163. results = {}
  164. for row in txn_results:
  165. # We want a single event per room, since we want to batch the
  166. # receipts by room, event and type.
  167. room_event = results.setdefault(row["room_id"], {
  168. "type": "m.receipt",
  169. "room_id": row["room_id"],
  170. "content": {},
  171. })
  172. # The content is of the form:
  173. # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
  174. event_entry = room_event["content"].setdefault(row["event_id"], {})
  175. receipt_type = event_entry.setdefault(row["receipt_type"], {})
  176. receipt_type[row["user_id"]] = json.loads(row["data"])
  177. results = {
  178. room_id: [results[room_id]] if room_id in results else []
  179. for room_id in room_ids
  180. }
  181. defer.returnValue(results)
  182. def get_max_receipt_stream_id(self):
  183. return self._receipts_id_gen.get_current_token()
  184. def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
  185. user_id, event_id, data, stream_id):
  186. txn.call_after(
  187. self.get_receipts_for_room.invalidate, (room_id, receipt_type)
  188. )
  189. txn.call_after(
  190. self.get_receipts_for_user.invalidate, (user_id, receipt_type)
  191. )
  192. # FIXME: This shouldn't invalidate the whole cache
  193. txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
  194. txn.call_after(
  195. self._receipts_stream_cache.entity_has_changed,
  196. room_id, stream_id
  197. )
  198. txn.call_after(
  199. self.get_last_receipt_event_id_for_user.invalidate,
  200. (user_id, room_id, receipt_type)
  201. )
  202. # We don't want to clobber receipts for more recent events, so we
  203. # have to compare orderings of existing receipts
  204. sql = (
  205. "SELECT topological_ordering, stream_ordering, event_id FROM events"
  206. " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
  207. " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
  208. )
  209. txn.execute(sql, (room_id, receipt_type, user_id))
  210. results = txn.fetchall()
  211. if results:
  212. res = self._simple_select_one_txn(
  213. txn,
  214. table="events",
  215. retcols=["topological_ordering", "stream_ordering"],
  216. keyvalues={"event_id": event_id},
  217. )
  218. topological_ordering = int(res["topological_ordering"])
  219. stream_ordering = int(res["stream_ordering"])
  220. for to, so, _ in results:
  221. if int(to) > topological_ordering:
  222. return False
  223. elif int(to) == topological_ordering and int(so) >= stream_ordering:
  224. return False
  225. self._simple_delete_txn(
  226. txn,
  227. table="receipts_linearized",
  228. keyvalues={
  229. "room_id": room_id,
  230. "receipt_type": receipt_type,
  231. "user_id": user_id,
  232. }
  233. )
  234. self._simple_insert_txn(
  235. txn,
  236. table="receipts_linearized",
  237. values={
  238. "stream_id": stream_id,
  239. "room_id": room_id,
  240. "receipt_type": receipt_type,
  241. "user_id": user_id,
  242. "event_id": event_id,
  243. "data": json.dumps(data),
  244. }
  245. )
  246. return True
  247. @defer.inlineCallbacks
  248. def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
  249. """Insert a receipt, either from local client or remote server.
  250. Automatically does conversion between linearized and graph
  251. representations.
  252. """
  253. if not event_ids:
  254. return
  255. if len(event_ids) == 1:
  256. linearized_event_id = event_ids[0]
  257. else:
  258. # we need to points in graph -> linearized form.
  259. # TODO: Make this better.
  260. def graph_to_linear(txn):
  261. query = (
  262. "SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
  263. " SELECT max(stream_ordering) WHERE event_id IN (%s)"
  264. ")"
  265. ) % (",".join(["?"] * len(event_ids)))
  266. txn.execute(query, [room_id] + event_ids)
  267. rows = txn.fetchall()
  268. if rows:
  269. return rows[0][0]
  270. else:
  271. raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
  272. linearized_event_id = yield self.runInteraction(
  273. "insert_receipt_conv", graph_to_linear
  274. )
  275. stream_id_manager = self._receipts_id_gen.get_next()
  276. with stream_id_manager as stream_id:
  277. have_persisted = yield self.runInteraction(
  278. "insert_linearized_receipt",
  279. self.insert_linearized_receipt_txn,
  280. room_id, receipt_type, user_id, linearized_event_id,
  281. data,
  282. stream_id=stream_id,
  283. )
  284. if not have_persisted:
  285. defer.returnValue(None)
  286. yield self.insert_graph_receipt(
  287. room_id, receipt_type, user_id, event_ids, data
  288. )
  289. max_persisted_id = self._stream_id_gen.get_current_token()
  290. defer.returnValue((stream_id, max_persisted_id))
  291. def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
  292. data):
  293. return self.runInteraction(
  294. "insert_graph_receipt",
  295. self.insert_graph_receipt_txn,
  296. room_id, receipt_type, user_id, event_ids, data
  297. )
  298. def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
  299. user_id, event_ids, data):
  300. txn.call_after(
  301. self.get_receipts_for_room.invalidate, (room_id, receipt_type)
  302. )
  303. txn.call_after(
  304. self.get_receipts_for_user.invalidate, (user_id, receipt_type)
  305. )
  306. # FIXME: This shouldn't invalidate the whole cache
  307. txn.call_after(self.get_linearized_receipts_for_room.invalidate_all)
  308. self._simple_delete_txn(
  309. txn,
  310. table="receipts_graph",
  311. keyvalues={
  312. "room_id": room_id,
  313. "receipt_type": receipt_type,
  314. "user_id": user_id,
  315. }
  316. )
  317. self._simple_insert_txn(
  318. txn,
  319. table="receipts_graph",
  320. values={
  321. "room_id": room_id,
  322. "receipt_type": receipt_type,
  323. "user_id": user_id,
  324. "event_ids": json.dumps(event_ids),
  325. "data": json.dumps(data),
  326. }
  327. )
  328. def get_all_updated_receipts(self, last_id, current_id, limit):
  329. def get_all_updated_receipts_txn(txn):
  330. sql = (
  331. "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
  332. " FROM receipts_linearized"
  333. " WHERE ? < stream_id AND stream_id <= ?"
  334. " ORDER BY stream_id ASC"
  335. " LIMIT ?"
  336. )
  337. txn.execute(sql, (last_id, current_id, limit))
  338. return txn.fetchall()
  339. return self.runInteraction(
  340. "get_all_updated_receipts", get_all_updated_receipts_txn
  341. )