deviceinbox.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016
  1. # Copyright 2016 OpenMarket Ltd
  2. # Copyright 2021 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from typing import (
  17. TYPE_CHECKING,
  18. Collection,
  19. Dict,
  20. Iterable,
  21. List,
  22. Optional,
  23. Set,
  24. Tuple,
  25. cast,
  26. )
  27. from synapse.api.constants import EventContentFields
  28. from synapse.logging import issue9533_logger
  29. from synapse.logging.opentracing import (
  30. SynapseTags,
  31. log_kv,
  32. set_tag,
  33. start_active_span,
  34. trace,
  35. )
  36. from synapse.replication.tcp.streams import ToDeviceStream
  37. from synapse.storage._base import SQLBaseStore, db_to_json
  38. from synapse.storage.database import (
  39. DatabasePool,
  40. LoggingDatabaseConnection,
  41. LoggingTransaction,
  42. make_in_list_sql_clause,
  43. )
  44. from synapse.storage.engines import PostgresEngine
  45. from synapse.storage.util.id_generators import (
  46. AbstractStreamIdGenerator,
  47. MultiWriterIdGenerator,
  48. StreamIdGenerator,
  49. )
  50. from synapse.types import JsonDict
  51. from synapse.util import json_encoder
  52. from synapse.util.caches.expiringcache import ExpiringCache
  53. from synapse.util.caches.stream_change_cache import StreamChangeCache
  54. if TYPE_CHECKING:
  55. from synapse.server import HomeServer
  56. logger = logging.getLogger(__name__)
  57. class DeviceInboxWorkerStore(SQLBaseStore):
  def __init__(
      self,
      database: DatabasePool,
      db_conn: LoggingDatabaseConnection,
      hs: "HomeServer",
  ):
      """Set up the to-device stream ID generator and its change caches.

      Args:
          database: The database pool backing this store.
          db_conn: Connection handed to the ID generator and to the queries
              that prefill the stream change caches.
          hs: The homeserver; supplies the instance name, the replication
              notifier and the configured to-device writers.
      """
      super().__init__(database, db_conn, hs)

      self._instance_name = hs.get_instance_name()

      # (user_id, device_id) -> the stream_id we last deleted up to. Lets
      # repeated deletions be turned into no-ops.
      self._last_device_delete_cache: ExpiringCache[
          Tuple[str, Optional[str]], int
      ] = ExpiringCache(
          cache_name="last_device_delete_cache",
          clock=self._clock,
          max_len=10000,
          expiry_ms=30 * 60 * 1000,
      )

      if isinstance(database.engine, PostgresEngine):
          # On Postgres several instances may write; only those listed as
          # to-device writers are allowed to.
          to_device_writers = hs.config.worker.writers.to_device
          self._can_write_to_device = self._instance_name in to_device_writers

          notifier = hs.get_replication_notifier()
          self._device_inbox_id_gen: AbstractStreamIdGenerator = (
              MultiWriterIdGenerator(
                  db_conn=db_conn,
                  db=database,
                  notifier=notifier,
                  stream_name="to_device",
                  instance_name=self._instance_name,
                  tables=[("device_inbox", "instance_name", "stream_id")],
                  sequence_name="device_inbox_sequence",
                  writers=to_device_writers,
              )
          )
      else:
          self._can_write_to_device = True

          notifier = hs.get_replication_notifier()
          self._device_inbox_id_gen = StreamIdGenerator(
              db_conn, notifier, "device_inbox", "stream_id"
          )

      current_token = self._device_inbox_id_gen.get_current_token()

      # Prefill the local inbox change cache from the most recent rows.
      inbox_prefill, inbox_min_id = self.db_pool.get_cache_dict(
          db_conn,
          "device_inbox",
          entity_column="user_id",
          stream_column="stream_id",
          max_value=current_token,
          limit=1000,
      )
      self._device_inbox_stream_cache = StreamChangeCache(
          "DeviceInboxStreamChangeCache",
          inbox_min_id,
          prefilled_cache=inbox_prefill,
      )

      # The federation outbox shares its stream_id generator with the local
      # device inbox, hence the same upper bound is used here.
      outbox_prefill, outbox_min_id = self.db_pool.get_cache_dict(
          db_conn,
          "device_federation_outbox",
          entity_column="destination",
          stream_column="stream_id",
          max_value=current_token,
          limit=1000,
      )
      self._device_federation_outbox_stream_cache = StreamChangeCache(
          "DeviceFederationOutboxStreamChangeCache",
          outbox_min_id,
          prefilled_cache=outbox_prefill,
      )
  def process_replication_rows(
      self,
      stream_name: str,
      instance_name: str,
      token: int,
      rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
  ) -> None:
      """Apply replicated to-device stream rows to the local caches.

      For the to-device stream, the ID generator is advanced and every row's
      entity is marked as changed in the matching stream change cache. All
      streams are then passed on to the parent implementation.

      Args:
          stream_name: Name of the stream the rows belong to.
          instance_name: The instance that wrote the rows.
          token: The stream position the rows were written at.
          rows: The replicated rows.
      """
      if stream_name == ToDeviceStream.NAME:
          # If replication is happening then postgres must be being used.
          id_gen = self._device_inbox_id_gen
          assert isinstance(id_gen, MultiWriterIdGenerator)
          id_gen.advance(instance_name, token)

          for row in rows:
              entity = row.entity
              # Entities starting with "@" (presumably user IDs) belong to the
              # local inbox; anything else is treated as a federation
              # destination.
              cache = (
                  self._device_inbox_stream_cache
                  if entity.startswith("@")
                  else self._device_federation_outbox_stream_cache
              )
              cache.entity_has_changed(entity, token)

      return super().process_replication_rows(stream_name, instance_name, token, rows)
  def process_replication_position(
      self, stream_name: str, instance_name: str, token: int
  ) -> None:
      """Advance the to-device ID generator when that stream's position moves.

      Args:
          stream_name: Name of the stream whose position was reported.
          instance_name: The instance the position belongs to.
          token: The new stream position.
      """
      is_to_device_stream = stream_name == ToDeviceStream.NAME
      if is_to_device_stream:
          self._device_inbox_id_gen.advance(instance_name, token)

      super().process_replication_position(stream_name, instance_name, token)
  def get_to_device_stream_token(self) -> int:
      """Return the current token of the to-device stream ID generator."""
      id_gen = self._device_inbox_id_gen
      return id_gen.get_current_token()
  async def get_messages_for_user_devices(
      self,
      user_ids: Collection[str],
      from_stream_id: int,
      to_stream_id: int,
  ) -> Dict[Tuple[str, str], List[JsonDict]]:
      """
      Fetch the to-device messages of several users in one go.

      Only messages whose stream id X satisfies `from < X <= to` are returned.

      Args:
          user_ids: The users whose to-device messages should be fetched.
          from_stream_id: Exclusive lower bound on the stream id.
          to_stream_id: Inclusive upper bound on the stream id.

      Returns:
          A mapping of (user id, device id) -> list of to-device messages.
      """
      # No limit is passed, so the helper is expected to have consumed the
      # whole range; its stream position is checked but not returned.
      messages_by_device, stream_pos = await self._get_device_messages(
          user_ids=user_ids,
          from_stream_id=from_stream_id,
          to_stream_id=to_stream_id,
      )

      assert (
          stream_pos == to_stream_id
      ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"

      return messages_by_device
  async def get_messages_for_device(
      self,
      user_id: str,
      device_id: str,
      from_stream_id: int,
      to_stream_id: int,
      limit: int = 100,
  ) -> Tuple[List[JsonDict], int]:
      """
      Fetch the to-device messages of one device of one user.

      Only messages whose stream id X satisfies `from < X <= to` are returned.

      Args:
          user_id: The user owning the device.
          device_id: The device whose to-device messages should be fetched.
          from_stream_id: Exclusive lower bound on the stream id.
          to_stream_id: Inclusive upper bound on the stream id.
          limit: Maximum number of to-device messages to return.

      Returns:
          A tuple of:
              * The to-device messages in the given stream id range that are
                addressed to this user / device pair.
              * The last-processed stream ID, to be passed as
                'from_stream_id' on the next call for the same device.
      """
      messages_by_device, stream_pos = await self._get_device_messages(
          user_ids=[user_id],
          device_id=device_id,
          from_stream_id=from_stream_id,
          to_stream_id=to_stream_id,
          limit=limit,
      )

      if messages_by_device:
          # Only the message list is of interest; the caller already knows the
          # user and device ID.
          return messages_by_device.get((user_id, device_id), []), stream_pos

      # Nothing was found at all, so the whole range has been covered.
      return [], to_stream_id
  async def _get_device_messages(
      self,
      user_ids: Collection[str],
      from_stream_id: int,
      to_stream_id: int,
      device_id: Optional[str] = None,
      limit: Optional[int] = None,
  ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
      """
      Retrieve pending to-device messages for a collection of user devices.

      Only to-device messages with stream ids between the given boundaries
      (from < X <= to) are returned.

      Note that a stream ID can be shared by multiple copies of the same message with
      different recipient devices. Stream IDs are only unique in the context of a single
      user ID / device ID pair. Thus, applying a limit (of messages to return) when working
      with a sliding window of stream IDs is only possible when querying messages of a
      single user device.

      Finally, note that device IDs are not unique across users.

      Args:
          user_ids: The user IDs to filter device messages by.
          from_stream_id: The lower boundary of stream id to filter with (exclusive).
          to_stream_id: The upper boundary of stream id to filter with (inclusive).
          device_id: A device ID to query to-device messages for. If not provided, to-device
              messages from all device IDs for the given user IDs will be queried. May not be
              provided if `user_ids` contains more than one entry.
          limit: The maximum number of to-device messages to return. Can only be used when
              passing a single user ID / device ID tuple.

      Returns:
          A tuple containing:
              * A dict of (user_id, device_id) -> list of to-device messages
              * The last-processed stream ID. If this is less than `to_stream_id`, then
                  there may be more messages to retrieve. If `limit` is not set, then this
                  is always equal to 'to_stream_id'.

      Raises:
          AssertionError: if `device_id` is given together with more than one
              user ID, or if `limit` is given without a `device_id`.
      """
      if not user_ids:
          logger.warning("No users provided upon querying for device IDs")
          return {}, to_stream_id

      # Prevent a query for one user's device also retrieving another user's device with
      # the same device ID (device IDs are not unique across users).
      if len(user_ids) > 1 and device_id is not None:
          raise AssertionError(
              "Programming error: 'device_id' cannot be supplied to "
              "_get_device_messages when >1 user_id has been provided"
          )

      # A limit can only be applied when querying for a single user ID / device ID tuple.
      # See the docstring of this function for more details.
      if limit is not None and device_id is None:
          raise AssertionError(
              "Programming error: _get_device_messages was passed 'limit' "
              "without a specific user_id/device_id"
          )

      # Determine which users have devices with pending messages. A
      # comprehension is used deliberately: it does not leak a loop variable
      # into this scope, so the transaction closure below cannot accidentally
      # capture "the last user iterated over".
      user_ids_to_query: Set[str] = {
          candidate_user_id
          for candidate_user_id in user_ids
          if self._device_inbox_stream_cache.has_entity_changed(
              candidate_user_id, from_stream_id
          )
      }

      if not user_ids_to_query:
          return {}, to_stream_id

      device_ids_to_query: Set[str] = set()

      # Note that a device ID could be an empty str
      if device_id is not None:
          # If a device ID was passed, use it to filter results.
          # Otherwise, device IDs will be derived from the given collection of user IDs.
          device_ids_to_query.add(device_id)

      def get_device_messages_txn(
          txn: LoggingTransaction,
      ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
          # Build a query to select messages from any of the given devices that
          # are between the given stream id bounds.

          # If a list of device IDs was not provided, retrieve all devices IDs
          # for the given users. We explicitly do not query hidden devices, as
          # hidden devices should not receive to-device messages.
          # Note that this is more efficient than just dropping `device_id` from the query,
          # since device_inbox has an index on `(user_id, device_id, stream_id)`
          if not device_ids_to_query:
              # The user filter comes solely from `iterable`; `keyvalues` must
              # not pin a single user_id or the devices of every other user in
              # `user_ids_to_query` would be skipped.
              user_device_dicts = self.db_pool.simple_select_many_txn(
                  txn,
                  table="devices",
                  column="user_id",
                  iterable=user_ids_to_query,
                  keyvalues={"hidden": False},
                  retcols=("device_id",),
              )

              device_ids_to_query.update(
                  {row["device_id"] for row in user_device_dicts}
              )

          if not device_ids_to_query:
              # We've ended up with no devices to query.
              return {}, to_stream_id

          # We include both user IDs and device IDs in this query, as we have an index
          # (device_inbox_user_stream_id) for them.
          user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
              self.database_engine, "user_id", user_ids_to_query
          )
          (
              device_id_many_clause_sql,
              device_id_many_clause_args,
          ) = make_in_list_sql_clause(
              self.database_engine, "device_id", device_ids_to_query
          )

          sql = f"""
              SELECT stream_id, user_id, device_id, message_json FROM device_inbox
              WHERE {user_id_many_clause_sql}
              AND {device_id_many_clause_sql}
              AND ? < stream_id AND stream_id <= ?
              ORDER BY stream_id ASC
          """
          sql_args = (
              *user_id_many_clause_args,
              *device_id_many_clause_args,
              from_stream_id,
              to_stream_id,
          )

          # If a limit was provided, limit the data retrieved from the database
          if limit is not None:
              sql += "LIMIT ?"
              sql_args += (limit,)

          txn.execute(sql, sql_args)

          # Create and fill a dictionary of (user ID, device ID) -> list of messages
          # intended for each device.
          last_processed_stream_pos = to_stream_id
          recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
          rowcount = 0
          for row in txn:
              rowcount += 1

              last_processed_stream_pos = row[0]
              recipient_user_id = row[1]
              recipient_device_id = row[2]
              message_dict = db_to_json(row[3])

              # Store the device details
              recipient_device_to_messages.setdefault(
                  (recipient_user_id, recipient_device_id), []
              ).append(message_dict)

              # start a new span for each message, so that we can tag each separately
              with start_active_span("get_to_device_message"):
                  set_tag(SynapseTags.TO_DEVICE_TYPE, message_dict["type"])
                  set_tag(SynapseTags.TO_DEVICE_SENDER, message_dict["sender"])
                  set_tag(SynapseTags.TO_DEVICE_RECIPIENT, recipient_user_id)
                  set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, recipient_device_id)
                  set_tag(
                      SynapseTags.TO_DEVICE_MSGID,
                      message_dict["content"].get(EventContentFields.TO_DEVICE_MSGID),
                  )

          if limit is not None and rowcount == limit:
              # We ended up bumping up against the message limit. There may be more messages
              # to retrieve. Return what we have, as well as the last stream position that
              # was processed.
              #
              # The caller is expected to set this as the lower (exclusive) bound
              # for the next query of this device.
              return recipient_device_to_messages, last_processed_stream_pos

          # The limit was not reached, thus we know that recipient_device_to_messages
          # contains all to-device messages for the given device and stream id range.
          #
          # We return to_stream_id, which the caller should then provide as the lower
          # (exclusive) bound on the next query of this device.
          return recipient_device_to_messages, to_stream_id

      return await self.db_pool.runInteraction(
          "get_device_messages", get_device_messages_txn
      )
  @trace
  async def delete_messages_for_device(
      self, user_id: str, device_id: Optional[str], up_to_stream_id: int
  ) -> int:
      """Delete a device's to-device messages up to a given stream position.

      Args:
          user_id: The recipient user_id.
          device_id: The recipient device_id.
          up_to_stream_id: Where to delete messages up to (inclusive).

      Returns:
          The number of messages deleted.
      """
      cache_key = (user_id, device_id)

      # If we remember how far we deleted last time and the user's inbox has
      # not changed since then, there is nothing left to delete.
      last_deleted_stream_id = self._last_device_delete_cache.get(cache_key, None)
      set_tag("last_deleted_stream_id", str(last_deleted_stream_id))

      if last_deleted_stream_id and not (
          self._device_inbox_stream_cache.has_entity_changed(
              user_id, last_deleted_stream_id
          )
      ):
          log_kv({"message": "No changes in cache since last check"})
          return 0

      def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
          txn.execute(
              "DELETE FROM device_inbox WHERE user_id = ? AND device_id = ?"
              " AND stream_id <= ?",
              (user_id, device_id, up_to_stream_id),
          )
          return txn.rowcount

      count = await self.db_pool.runInteraction(
          "delete_messages_for_device", delete_messages_for_device_txn
      )

      log_kv({"message": f"deleted {count} messages for device", "count": count})

      # Re-read the cache before writing so that the stored value only ever
      # moves forwards.
      previously_deleted_up_to = self._last_device_delete_cache.get(cache_key, 0)
      self._last_device_delete_cache[cache_key] = max(
          previously_deleted_up_to, up_to_stream_id
      )

      return count
  @trace
  async def get_new_device_msgs_for_remote(
      self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
  ) -> Tuple[List[JsonDict], int]:
      """Fetch queued outgoing to-device messages for a remote server.

      Args:
          destination: The name of the remote server.
          last_stream_id: The last position of the device message stream
              that the server sent up to (exclusive lower bound).
          current_stream_id: The current position of the device message
              stream (inclusive upper bound).
          limit: The maximum number of outbox rows to return.

      Returns:
          A list of messages for the destination and where in the stream the
          messages got to.
      """
      set_tag("destination", destination)
      set_tag("last_stream_id", last_stream_id)
      set_tag("current_stream_id", current_stream_id)
      set_tag("limit", limit)

      has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
          destination, last_stream_id
      )
      if last_stream_id == current_stream_id or not has_changed:
          log_kv({"message": "No new messages in stream"})
          return [], current_stream_id

      if limit <= 0:
          # This can happen if we run out of room for EDUs in the transaction.
          return [], last_stream_id

      @trace
      def get_new_messages_for_remote_destination_txn(
          txn: LoggingTransaction,
      ) -> Tuple[List[JsonDict], int]:
          txn.execute(
              "SELECT stream_id, messages_json FROM device_federation_outbox"
              " WHERE destination = ?"
              " AND ? < stream_id AND stream_id <= ?"
              " ORDER BY stream_id ASC"
              " LIMIT ?",
              (destination, last_stream_id, current_stream_id, limit),
          )
          rows = list(txn)
          messages = [db_to_json(messages_json) for _, messages_json in rows]

          if len(rows) < limit:
              # The limit was not reached, so there is no more data for this
              # destination up to current_stream_id.
              log_kv({"message": "Set stream position to current position"})
              return messages, current_stream_id

          # The limit was hit (and is positive, so there is at least one row):
          # resume from the last row actually returned.
          return messages, rows[-1][0]

      return await self.db_pool.runInteraction(
          "get_new_device_msgs_for_remote",
          get_new_messages_for_remote_destination_txn,
      )
  @trace
  async def delete_device_msgs_for_remote(
      self, destination: str, up_to_stream_id: int
  ) -> None:
      """Drop outbox messages once the remote destination has acknowledged
      receiving them.

      Args:
          destination: The destination server_name
          up_to_stream_id: Where to delete messages up to (inclusive).
      """

      def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None:
          txn.execute(
              "DELETE FROM device_federation_outbox WHERE destination = ?"
              " AND stream_id <= ?",
              (destination, up_to_stream_id),
          )

      await self.db_pool.runInteraction(
          "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
      )
  async def get_all_new_device_messages(
      self, instance_name: str, last_id: int, current_id: int, limit: int
  ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
      """Get updates for to device replication stream.

      Args:
          instance_name: The writer we want to fetch updates from. Unused
              here since there is only ever one writer.
          last_id: The token to fetch updates from. Exclusive.
          current_id: The token to fetch updates up to. Inclusive.
          limit: The requested limit for the number of rows to return. The
              function may return more or fewer rows.

      Returns:
          A tuple consisting of: the updates, a token to use to fetch
          subsequent updates, and whether we returned fewer rows than exists
          between the requested tokens due to the limit.

          The token returned can be used in a subsequent call to this
          function to get further updates.

          The updates are a list of 2-tuples of stream ID and the row data
      """
      if last_id == current_id:
          return [], current_id, False

      # One query per table; each yields the newest stream_id per entity.
      per_entity_queries = (
          "SELECT max(stream_id), user_id"
          " FROM device_inbox"
          " WHERE ? < stream_id AND stream_id <= ?"
          " GROUP BY user_id",
          "SELECT max(stream_id), destination"
          " FROM device_federation_outbox"
          " WHERE ? < stream_id AND stream_id <= ?"
          " GROUP BY destination",
      )

      def get_all_new_device_messages_txn(
          txn: LoggingTransaction,
      ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
          # The limit is applied to the token range rather than the row count:
          # a stream_id may have several rows, and every row of any stream_id
          # we return must be included.
          upto_token = min(current_id, last_id + limit)

          updates: List[Tuple[int, tuple]] = []
          for sql in per_entity_queries:
              txn.execute(sql, (last_id, upto_token))
              updates.extend((row[0], row[1:]) for row in txn)

          # Order by ascending stream ordering
          updates.sort()

          limited = upto_token < current_id
          return updates, upto_token, limited

      return await self.db_pool.runInteraction(
          "get_all_new_device_messages", get_all_new_device_messages_txn
      )
  @trace
  async def add_messages_to_device_inbox(
      self,
      local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
      remote_messages_by_destination: Dict[str, JsonDict],
  ) -> int:
      """Used to send messages from this server.

      Local messages are written to the local device inbox and remote
      messages to the federation outbox, all under one newly allocated
      stream ID. The stream change caches are then updated for every
      affected user and destination.

      Args:
          local_messages_by_user_then_device:
              Dictionary of recipient user_id to recipient device_id to message.
          remote_messages_by_destination:
              Dictionary of destination server_name to the EDU JSON to send.

      Returns:
          The new stream_id.
      """

      assert self._can_write_to_device

      def add_messages_txn(
          txn: LoggingTransaction, now_ms: int, stream_id: int
      ) -> None:
          # Add the local messages directly to the local inbox.
          self._add_messages_to_local_device_inbox_txn(
              txn, stream_id, local_messages_by_user_then_device
          )

          # Add the remote messages to the federation outbox.
          # We'll send them to a remote server when we next send a
          # federation transaction to that destination.
          self.db_pool.simple_insert_many_txn(
              txn,
              table="device_federation_outbox",
              keys=(
                  "destination",
                  "stream_id",
                  "queued_ts",
                  "messages_json",
                  "instance_name",
              ),
              values=[
                  (
                      destination,
                      stream_id,
                      now_ms,
                      json_encoder.encode(edu),
                      self._instance_name,
                  )
                  for destination, edu in remote_messages_by_destination.items()
              ],
          )

          for destination, edu in remote_messages_by_destination.items():
              # Building the per-recipient list is only worth it when the
              # debug logger will actually emit it.
              if issue9533_logger.isEnabledFor(logging.DEBUG):
                  issue9533_logger.debug(
                      "Queued outgoing to-device messages with "
                      "stream_id %i, EDU message_id %s, type %s for %s: %s",
                      stream_id,
                      edu["message_id"],
                      edu["type"],
                      destination,
                      [
                          f"{user_id}/{device_id} (msgid "
                          f"{msg.get(EventContentFields.TO_DEVICE_MSGID)})"
                          for (user_id, messages_by_device) in edu["messages"].items()
                          for (device_id, msg) in messages_by_device.items()
                      ],
                  )

              # One span per (recipient user, recipient device) so that each
              # message can be tagged separately.
              for user_id, messages_by_device in edu["messages"].items():
                  for device_id, msg in messages_by_device.items():
                      with start_active_span("store_outgoing_to_device_message"):
                          # The sender gets its own tag; writing it under
                          # TO_DEVICE_EDU_ID would be overwritten by the
                          # message id on the next line.
                          set_tag(SynapseTags.TO_DEVICE_SENDER, edu["sender"])
                          set_tag(SynapseTags.TO_DEVICE_EDU_ID, edu["message_id"])
                          set_tag(SynapseTags.TO_DEVICE_TYPE, edu["type"])
                          set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
                          set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id)
                          set_tag(
                              SynapseTags.TO_DEVICE_MSGID,
                              msg.get(EventContentFields.TO_DEVICE_MSGID),
                          )

      async with self._device_inbox_id_gen.get_next() as stream_id:
          now_ms = self._clock.time_msec()
          await self.db_pool.runInteraction(
              "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
          )
          # Record the change for every recipient user and destination.
          for user_id in local_messages_by_user_then_device:
              self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
          for destination in remote_messages_by_destination:
              self._device_federation_outbox_stream_cache.entity_has_changed(
                  destination, stream_id
              )

      return self._device_inbox_id_gen.get_current_token()
  async def add_messages_from_remote_to_device_inbox(
      self,
      origin: str,
      message_id: str,
      local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
  ) -> int:
      """Store to-device messages received from a remote origin.

      The (origin, message_id) pair is recorded in `device_federation_inbox`;
      if it is already present the messages are not inserted again.

      Args:
          origin: The origin the messages came from.
          message_id: The origin's ID for this batch of messages, used for
              de-duplication.
          local_messages_by_user_then_device:
              Dictionary of recipient user_id to recipient device_id to message.

      Returns:
          The stream_id allocated for this call.
      """
      assert self._can_write_to_device

      def add_messages_txn(
          txn: LoggingTransaction, now_ms: int, stream_id: int
      ) -> None:
          # The origin may resend a message if it never received our
          # acknowledgement for the first delivery, so look for an existing
          # record of this (origin, message_id) pair first.
          existing_record = self.db_pool.simple_select_one_txn(
              txn,
              table="device_federation_inbox",
              keyvalues={"origin": origin, "message_id": message_id},
              retcols=("message_id",),
              allow_none=True,
          )
          if existing_record is not None:
              return

          # Remember that this message_id has now been processed.
          self.db_pool.simple_insert_txn(
              txn,
              table="device_federation_inbox",
              values={
                  "origin": origin,
                  "message_id": message_id,
                  "received_ts": now_ms,
              },
          )

          # Put the messages into the relevant local device inboxes, from
          # where they are delivered when the devices next sync.
          self._add_messages_to_local_device_inbox_txn(
              txn, stream_id, local_messages_by_user_then_device
          )

      async with self._device_inbox_id_gen.get_next() as stream_id:
          received_at_ms = self._clock.time_msec()
          await self.db_pool.runInteraction(
              "add_messages_from_remote_to_device_inbox",
              add_messages_txn,
              received_at_ms,
              stream_id,
          )
          for recipient_user_id in local_messages_by_user_then_device.keys():
              self._device_inbox_stream_cache.entity_has_changed(
                  recipient_user_id, stream_id
              )

      return stream_id
  def _add_messages_to_local_device_inbox_txn(
      self,
      txn: LoggingTransaction,
      stream_id: int,
      messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
  ) -> None:
      """Insert to-device messages into the `device_inbox` table.

      Only devices present in the `devices` table with `hidden` set to False
      receive a row. When a user's only device key is "*", the message is
      stored once for each such device of that user.

      Args:
          txn: The transaction to run the queries in.
          stream_id: The stream ID written to every inserted row.
          messages_by_user_then_device: Map of user ID to a map of device ID
              (or "*") to the message to deliver.
      """
      assert self._can_write_to_device

      # user_id -> device_id -> JSON-encoded message, for rows we will insert.
      encoded_by_user = {}

      for user_id, messages_by_device in messages_by_user_then_device.items():
          requested_devices = list(messages_by_device)

          if requested_devices == ["*"]:
              # Wildcard: deliver to every device of this user on this server.
              # Hidden devices (such as cross-signing keys) are left out, as
              # they are not expected to receive to-device messages.
              known_devices = self.db_pool.simple_select_onecol_txn(
                  txn,
                  table="devices",
                  keyvalues={"user_id": user_id, "hidden": False},
                  retcol="device_id",
              )
              # The same payload goes to every device, so encode it once.
              wildcard_json = json_encoder.encode(messages_by_device["*"])
              encoded_for_user = {
                  device_id: wildcard_json for device_id in known_devices
              }
          elif requested_devices:
              # Explicit device IDs: keep only those that exist on this
              # server and are not hidden (such as cross-signing keys).
              existing_rows = self.db_pool.simple_select_many_txn(
                  txn,
                  table="devices",
                  keyvalues={"user_id": user_id, "hidden": False},
                  column="device_id",
                  iterable=requested_devices,
                  retcols=("device_id",),
              )

              encoded_for_user = {}
              for existing in existing_rows:
                  device_id = existing["device_id"]

                  with start_active_span("serialise_to_device_message"):
                      msg = messages_by_device[device_id]
                      set_tag(SynapseTags.TO_DEVICE_TYPE, msg["type"])
                      set_tag(SynapseTags.TO_DEVICE_SENDER, msg["sender"])
                      set_tag(SynapseTags.TO_DEVICE_RECIPIENT, user_id)
                      set_tag(SynapseTags.TO_DEVICE_RECIPIENT_DEVICE, device_id)
                      set_tag(
                          SynapseTags.TO_DEVICE_MSGID,
                          msg["content"].get(EventContentFields.TO_DEVICE_MSGID),
                      )
                      encoded = json_encoder.encode(msg)

                  encoded_for_user[device_id] = encoded
          else:
              # No devices were named for this user: nothing to store.
              continue

          if encoded_for_user:
              encoded_by_user[user_id] = encoded_for_user

      if not encoded_by_user:
          return

      inbox_rows = [
          (user_id, device_id, stream_id, encoded, self._instance_name)
          for user_id, encoded_for_user in encoded_by_user.items()
          for device_id, encoded in encoded_for_user.items()
      ]
      self.db_pool.simple_insert_many_txn(
          txn,
          table="device_inbox",
          keys=("user_id", "device_id", "stream_id", "message_json", "instance_name"),
          values=inbox_rows,
      )

      if not issue9533_logger.isEnabledFor(logging.DEBUG):
          return

      # NOTE(review): this lists every requested user/device pair from the
      # input (including a literal "*"), not only the rows inserted above.
      requested_descriptions = [
          f"{user_id}/{device_id} (msgid "
          f"{msg['content'].get(EventContentFields.TO_DEVICE_MSGID)})"
          for user_id, messages_by_device in messages_by_user_then_device.items()
          for device_id, msg in messages_by_device.items()
      ]
      issue9533_logger.debug(
          "Stored to-device messages with stream_id %i: %s",
          stream_id,
          requested_descriptions,
      )
  class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
      """Registers and implements the background updates for `device_inbox`.

      On construction this registers:
        * a background index update, "device_inbox_stream_index", for the
          index `device_inbox_stream_id_user_id` on (stream_id, user_id);
        * a handler that drops the `device_inbox_stream_id` index;
        * a handler that deletes inbox rows which do not belong to a known,
          unhidden device.
      """

      # Names under which the two background update handlers are registered.
      DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
      REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"

      def __init__(
          self,
          database: DatabasePool,
          db_conn: LoggingDatabaseConnection,
          hs: "HomeServer",
      ):
          super().__init__(database, db_conn, hs)

          self.db_pool.updates.register_background_index_update(
              "device_inbox_stream_index",
              index_name="device_inbox_stream_id_user_id",
              table="device_inbox",
              columns=["stream_id", "user_id"],
          )

          self.db_pool.updates.register_background_update_handler(
              self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
          )

          self.db_pool.updates.register_background_update_handler(
              self.REMOVE_DEAD_DEVICES_FROM_INBOX,
              self._remove_dead_devices_from_device_inbox,
          )

      async def _background_drop_index_device_inbox(
          self, progress: JsonDict, batch_size: int
      ) -> int:
          """A background update that drops the `device_inbox_stream_id` index.

          The whole update is done in a single step, after which it is marked
          as finished.

          Args:
              progress: The update's progress dict (unused).
              batch_size: The batch size for this update (unused).

          Returns:
              Always 1.
          """

          def reindex_txn(conn: LoggingDatabaseConnection) -> None:
              txn = conn.cursor()
              try:
                  txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
              finally:
                  # Close the cursor even if the DROP fails, so that it is
                  # not left open on the connection.
                  txn.close()

          await self.db_pool.runWithConnection(reindex_txn)

          await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)

          return 1

      async def _remove_dead_devices_from_device_inbox(
          self,
          progress: JsonDict,
          batch_size: int,
      ) -> int:
          """A background update to remove devices that were either deleted or hidden from
          the device_inbox table.

          Each run handles the window of stream IDs
          `[progress["stream_id"], progress["stream_id"] + batch_size)` and
          then records the end of that window as the new progress.

          Args:
              progress: The update's progress dict. May contain "stream_id"
                  (where to resume from, defaulting to 0) and "max_stream_id"
                  (the maximum stream ID seen when the update first ran).
              batch_size: The batch size for this update.

          Returns:
              `batch_size`. The number of rows actually deleted is not
              tracked.
          """

          def _remove_dead_devices_from_device_inbox_txn(
              txn: LoggingTransaction,
          ) -> bool:
              """Process one window; return True once the update is complete."""

              if "max_stream_id" in progress:
                  max_stream_id = progress["max_stream_id"]
              else:
                  txn.execute("SELECT max(stream_id) FROM device_inbox")
                  # There's a type mismatch here between how we want to type the row and
                  # what fetchone says it returns, but we silence it because we know that
                  # res can't be None.
                  res = cast(Tuple[Optional[int]], txn.fetchone())
                  if res[0] is None:
                      # this can only happen if the `device_inbox` table is empty, in which
                      # case we have no work to do.
                      return True
                  max_stream_id = res[0]

              start = progress.get("stream_id", 0)
              stop = start + batch_size

              # delete rows in `device_inbox` which do *not* correspond to a known,
              # unhidden device.
              sql = """
                  DELETE FROM device_inbox
                  WHERE
                      stream_id >= ? AND stream_id < ?
                      AND NOT EXISTS (
                          SELECT * FROM devices d
                          WHERE
                              d.device_id=device_inbox.device_id
                              AND d.user_id=device_inbox.user_id
                              AND NOT hidden
                      )
                  """

              txn.execute(sql, (start, stop))

              self.db_pool.updates._background_update_progress_txn(
                  txn,
                  self.REMOVE_DEAD_DEVICES_FROM_INBOX,
                  {
                      "stream_id": stop,
                      "max_stream_id": max_stream_id,
                  },
              )

              # The window is half-open, so we are only done once it has moved
              # strictly past the maximum stream ID.
              return stop > max_stream_id

          finished = await self.db_pool.runInteraction(
              "_remove_devices_from_device_inbox_txn",
              _remove_dead_devices_from_device_inbox_txn,
          )

          if finished:
              await self.db_pool.updates._end_background_update(
                  self.REMOVE_DEAD_DEVICES_FROM_INBOX,
              )

          return batch_size
  class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
      """Combines `DeviceInboxWorkerStore` and `DeviceInboxBackgroundUpdateStore`.

      Adds no attributes or methods of its own.
      """