deviceinbox.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from canonicaljson import json
  17. from twisted.internet import defer
  18. from synapse.logging.opentracing import log_kv, set_tag, trace
  19. from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
  20. from synapse.storage.database import Database
  21. from synapse.util.caches.expiringcache import ExpiringCache
  22. logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore):
    """Read/delete operations for to-device messages.

    Covers both the local device inbox (messages destined for devices on
    this homeserver) and the federation outbox (messages queued for
    delivery to remote homeservers).
    """

    def get_to_device_stream_token(self):
        """Return the current position of the to-device message stream."""
        return self._device_inbox_id_gen.get_current_token()

    def get_new_messages_for_device(
        self, user_id, device_id, last_stream_id, current_stream_id, limit=100
    ):
        """Fetch pending to-device messages for a single local device.

        Args:
            user_id(str): The recipient user_id.
            device_id(str): The recipient device_id.
            last_stream_id(int): Stream position the device has already been
                sent messages up to (exclusive lower bound).
            current_stream_id(int): The current position of the to device
                message stream.
            limit(int): Maximum number of messages to return.
        Returns:
            Deferred ([dict], int): List of messages for the device and where
            in the stream the messages got to.
        """
        # Cheap in-memory check first: if nothing for this user has changed
        # since last_stream_id there is no point hitting the database.
        has_changed = self._device_inbox_stream_cache.has_entity_changed(
            user_id, last_stream_id
        )
        if not has_changed:
            return defer.succeed(([], current_stream_id))

        def get_new_messages_for_device_txn(txn):
            sql = (
                "SELECT stream_id, message_json FROM device_inbox"
                " WHERE user_id = ? AND device_id = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?"
            )
            txn.execute(
                sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
            )
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(json.loads(row[1]))
            # If we returned fewer rows than the limit then we have drained
            # everything up to current_stream_id; otherwise only advance the
            # returned position to the last row actually read, since more
            # messages may remain beyond it.
            if len(messages) < limit:
                stream_pos = current_stream_id
            return messages, stream_pos

        return self.db.runInteraction(
            "get_new_messages_for_device", get_new_messages_for_device_txn
        )

    @trace
    @defer.inlineCallbacks
    def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
        """Delete a local device's to-device messages up to a stream position.

        Args:
            user_id(str): The recipient user_id.
            device_id(str): The recipient device_id.
            up_to_stream_id(int): Where to delete messages up to.
        Returns:
            A deferred that resolves to the number of messages deleted.
        """
        # If we have cached the last stream id we've deleted up to, we can
        # check if there is likely to be anything that needs deleting
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), None
        )

        set_tag("last_deleted_stream_id", last_deleted_stream_id)

        if last_deleted_stream_id:
            has_changed = self._device_inbox_stream_cache.has_entity_changed(
                user_id, last_deleted_stream_id
            )
            if not has_changed:
                # Nothing new since the last delete, so the delete would be
                # a no-op; skip the database round-trip entirely.
                log_kv({"message": "No changes in cache since last check"})
                return 0

        def delete_messages_for_device_txn(txn):
            sql = (
                "DELETE FROM device_inbox"
                " WHERE user_id = ? AND device_id = ?"
                " AND stream_id <= ?"
            )
            txn.execute(sql, (user_id, device_id, up_to_stream_id))
            # Number of rows removed by the DELETE.
            return txn.rowcount

        count = yield self.db.runInteraction(
            "delete_messages_for_device", delete_messages_for_device_txn
        )

        log_kv(
            {"message": "deleted {} messages for device".format(count), "count": count}
        )

        # Update the cache, ensuring that we only ever increase the value
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), 0
        )
        self._last_device_delete_cache[(user_id, device_id)] = max(
            last_deleted_stream_id, up_to_stream_id
        )

        return count

    @trace
    def get_new_device_msgs_for_remote(
        self, destination, last_stream_id, current_stream_id, limit
    ):
        """Fetch queued federation to-device EDUs for a remote server.

        Args:
            destination(str): The name of the remote server.
            last_stream_id(int|long): The last position of the device message stream
                that the server sent up to.
            current_stream_id(int|long): The current position of the device
                message stream.
            limit(int): Maximum number of messages to return.
        Returns:
            Deferred ([dict], int|long): List of messages for the device and where
            in the stream the messages got to.
        """
        set_tag("destination", destination)
        set_tag("last_stream_id", last_stream_id)
        set_tag("current_stream_id", current_stream_id)
        set_tag("limit", limit)

        # Short-circuit via the stream cache before touching the database.
        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
            destination, last_stream_id
        )
        if not has_changed or last_stream_id == current_stream_id:
            log_kv({"message": "No new messages in stream"})
            return defer.succeed(([], current_stream_id))

        if limit <= 0:
            # This can happen if we run out of room for EDUs in the transaction.
            return defer.succeed(([], last_stream_id))

        @trace
        def get_new_messages_for_remote_destination_txn(txn):
            sql = (
                "SELECT stream_id, messages_json FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?"
            )
            txn.execute(sql, (destination, last_stream_id, current_stream_id, limit))
            messages = []
            for row in txn:
                stream_pos = row[0]
                messages.append(json.loads(row[1]))
            # Fewer rows than the limit means we drained the outbox up to
            # current_stream_id; otherwise report only as far as we read.
            if len(messages) < limit:
                log_kv({"message": "Set stream position to current position"})
                stream_pos = current_stream_id
            return messages, stream_pos

        return self.db.runInteraction(
            "get_new_device_msgs_for_remote",
            get_new_messages_for_remote_destination_txn,
        )

    @trace
    def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
        """Used to delete messages when the remote destination acknowledges
        their receipt.

        Args:
            destination(str): The destination server_name
            up_to_stream_id(int): Where to delete messages up to.
        Returns:
            A deferred that resolves when the messages have been deleted.
        """

        def delete_messages_for_remote_destination_txn(txn):
            sql = (
                "DELETE FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND stream_id <= ?"
            )
            txn.execute(sql, (destination, up_to_stream_id))

        return self.db.runInteraction(
            "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
        )
  181. class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
  182. DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
  183. def __init__(self, database: Database, db_conn, hs):
  184. super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
  185. self.db.updates.register_background_index_update(
  186. "device_inbox_stream_index",
  187. index_name="device_inbox_stream_id_user_id",
  188. table="device_inbox",
  189. columns=["stream_id", "user_id"],
  190. )
  191. self.db.updates.register_background_update_handler(
  192. self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
  193. )
  194. @defer.inlineCallbacks
  195. def _background_drop_index_device_inbox(self, progress, batch_size):
  196. def reindex_txn(conn):
  197. txn = conn.cursor()
  198. txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
  199. txn.close()
  200. yield self.db.runWithConnection(reindex_txn)
  201. yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
  202. return 1
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
    """Write operations for to-device messages, local and federated."""

    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"

    def __init__(self, database: Database, db_conn, hs):
        super(DeviceInboxStore, self).__init__(database, db_conn, hs)

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This is so that we can no op deletions.
        self._last_device_delete_cache = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

    @trace
    @defer.inlineCallbacks
    def add_messages_to_device_inbox(
        self, local_messages_by_user_then_device, remote_messages_by_destination
    ):
        """Used to send messages from this server.

        Args:
            local_messages_by_user_then_device(dict):
                Dictionary of user_id to device_id to message.
            remote_messages_by_destination(dict):
                Dictionary of destination server_name to the EDU JSON to send.
        Returns:
            A deferred stream_id that resolves when the messages have been
            inserted.
        """

        def add_messages_txn(txn, now_ms, stream_id):
            # Add the local messages directly to the local inbox.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device
            )

            # Add the remote messages to the federation outbox.
            # We'll send them to a remote server when we next send a
            # federation transaction to that destination.
            sql = (
                "INSERT INTO device_federation_outbox"
                " (destination, stream_id, queued_ts, messages_json)"
                " VALUES (?,?,?,?)"
            )
            rows = []
            for destination, edu in remote_messages_by_destination.items():
                edu_json = json.dumps(edu)
                rows.append((destination, stream_id, now_ms, edu_json))
            txn.executemany(sql, rows)

        # Allocate the stream position and run the insert inside it so the
        # id generator can track the write; caches are only poked afterwards,
        # once the transaction has committed.
        with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self.clock.time_msec()
            yield self.db.runInteraction(
                "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
            for destination in remote_messages_by_destination.keys():
                self._device_federation_outbox_stream_cache.entity_has_changed(
                    destination, stream_id
                )

        return self._device_inbox_id_gen.get_current_token()

    @defer.inlineCallbacks
    def add_messages_from_remote_to_device_inbox(
        self, origin, message_id, local_messages_by_user_then_device
    ):
        """Deliver to-device messages received over federation.

        Args:
            origin(str): The server_name the messages came from.
            message_id(str): The remote server's ID for this batch, used for
                de-duplication on redelivery.
            local_messages_by_user_then_device(dict):
                Dictionary of user_id to device_id to message.
        Returns:
            A deferred stream_id that resolves when the messages have been
            inserted.
        """

        def add_messages_txn(txn, now_ms, stream_id):
            # Check if we've already inserted a matching message_id for that
            # origin. This can happen if the origin doesn't receive our
            # acknowledgement from the first time we received the message.
            already_inserted = self.db.simple_select_one_txn(
                txn,
                table="device_federation_inbox",
                keyvalues={"origin": origin, "message_id": message_id},
                retcols=("message_id",),
                allow_none=True,
            )
            if already_inserted is not None:
                return

            # Add an entry for this message_id so that we know we've processed
            # it.
            self.db.simple_insert_txn(
                txn,
                table="device_federation_inbox",
                values={
                    "origin": origin,
                    "message_id": message_id,
                    "received_ts": now_ms,
                },
            )

            # Add the messages to the appropriate local device inboxes so that
            # they'll be sent to the devices when they next sync.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device
            )

        with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self.clock.time_msec()
            yield self.db.runInteraction(
                "add_messages_from_remote_to_device_inbox",
                add_messages_txn,
                now_ms,
                stream_id,
            )
            # Invalidate stream caches only after the transaction commits.
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)

        return stream_id

    def _add_messages_to_local_device_inbox_txn(
        self, txn, stream_id, messages_by_user_then_device
    ):
        """Insert messages into device_inbox for local devices (txn helper).

        Resolves each user's target device_ids against the devices table
        (expanding the "*" wildcard to all of the user's devices) and inserts
        one row per (user, device) at stream_id. Must run inside a database
        transaction.
        """
        # Bump the recorded max stream id (only ever forwards).
        sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
        txn.execute(sql, (stream_id, stream_id))

        local_by_user_then_device = {}
        for user_id, messages_by_device in messages_by_user_then_device.items():
            messages_json_for_user = {}
            devices = list(messages_by_device.keys())
            if len(devices) == 1 and devices[0] == "*":
                # Handle wildcard device_ids.
                sql = "SELECT device_id FROM devices WHERE user_id = ?"
                txn.execute(sql, (user_id,))
                message_json = json.dumps(messages_by_device["*"])
                for row in txn:
                    # Add the message for all devices for this user on this
                    # server.
                    device = row[0]
                    messages_json_for_user[device] = message_json
            else:
                if not devices:
                    continue

                clause, args = make_in_list_sql_clause(
                    txn.database_engine, "device_id", devices
                )
                sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause

                # TODO: Maybe this needs to be done in batches if there are
                # too many local devices for a given user.
                txn.execute(sql, [user_id] + list(args))
                for row in txn:
                    # Only insert into the local inbox if the device exists on
                    # this server
                    device = row[0]
                    message_json = json.dumps(messages_by_device[device])
                    messages_json_for_user[device] = message_json

            if messages_json_for_user:
                local_by_user_then_device[user_id] = messages_json_for_user

        if not local_by_user_then_device:
            return

        sql = (
            "INSERT INTO device_inbox"
            " (user_id, device_id, stream_id, message_json)"
            " VALUES (?,?,?,?)"
        )
        rows = []
        for user_id, messages_by_device in local_by_user_then_device.items():
            for device_id, message_json in messages_by_device.items():
                rows.append((user_id, device_id, stream_id, message_json))

        txn.executemany(sql, rows)

    def get_all_new_device_messages(self, last_pos, current_pos, limit):
        """Fetch updates from both the inbox and federation outbox streams.

        Args:
            last_pos(int): Stream position to fetch from (exclusive).
            current_pos(int): Current stream position (inclusive upper bound).
            limit(int): Soft cap on how far past last_pos to scan.
        Returns:
            A deferred list of rows from the device inbox
        """
        if last_pos == current_pos:
            return defer.succeed([])

        def get_all_new_device_messages_txn(txn):
            # We limit like this as we might have multiple rows per stream_id, and
            # we want to make sure we always get all entries for any stream_id
            # we return.
            upper_pos = min(current_pos, last_pos + limit)
            sql = (
                "SELECT max(stream_id), user_id"
                " FROM device_inbox"
                " WHERE ? < stream_id AND stream_id <= ?"
                " GROUP BY user_id"
            )
            txn.execute(sql, (last_pos, upper_pos))
            rows = txn.fetchall()

            sql = (
                "SELECT max(stream_id), destination"
                " FROM device_federation_outbox"
                " WHERE ? < stream_id AND stream_id <= ?"
                " GROUP BY destination"
            )
            txn.execute(sql, (last_pos, upper_pos))
            rows.extend(txn)

            # Order by ascending stream ordering
            rows.sort()

            return rows

        return self.db.runInteraction(
            "get_all_new_device_messages", get_all_new_device_messages_txn
        )