devices.py 81 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253
  1. # Copyright 2016 OpenMarket Ltd
  2. # Copyright 2019 New Vector Ltd
  3. # Copyright 2019,2020 The Matrix.org Foundation C.I.C.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import logging
  17. from typing import (
  18. TYPE_CHECKING,
  19. Any,
  20. Collection,
  21. Dict,
  22. Iterable,
  23. List,
  24. Mapping,
  25. Optional,
  26. Set,
  27. Tuple,
  28. cast,
  29. )
  30. from typing_extensions import Literal
  31. from synapse.api.constants import EduTypes
  32. from synapse.api.errors import Codes, StoreError
  33. from synapse.logging.opentracing import (
  34. get_active_span_text_map,
  35. set_tag,
  36. trace,
  37. whitelisted_homeserver,
  38. )
  39. from synapse.metrics.background_process_metrics import wrap_as_background_process
  40. from synapse.replication.tcp.streams._base import DeviceListsStream
  41. from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
  42. from synapse.storage.database import (
  43. DatabasePool,
  44. LoggingDatabaseConnection,
  45. LoggingTransaction,
  46. make_tuple_comparison_clause,
  47. )
  48. from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
  49. from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
  50. from synapse.storage.types import Cursor
  51. from synapse.storage.util.id_generators import (
  52. AbstractStreamIdGenerator,
  53. StreamIdGenerator,
  54. )
  55. from synapse.types import JsonDict, StrCollection, get_verify_key_from_cross_signing_key
  56. from synapse.util import json_decoder, json_encoder
  57. from synapse.util.caches.descriptors import cached, cachedList
  58. from synapse.util.caches.lrucache import LruCache
  59. from synapse.util.caches.stream_change_cache import (
  60. AllEntitiesChangedResult,
  61. StreamChangeCache,
  62. )
  63. from synapse.util.cancellation import cancellable
  64. from synapse.util.iterutils import batch_iter
  65. from synapse.util.stringutils import shortstr
  66. if TYPE_CHECKING:
  67. from synapse.server import HomeServer
  # Standard per-module logger, plus a dedicated debug logger used below to
  # trace outbound device-list updates to remote servers.
  logger = logging.getLogger(__name__)
  issue_8631_logger = logging.getLogger("synapse.8631_debug")

  # Names of background updates for the device-list tables (NOTE(review):
  # presumably registered elsewhere in this file — confirm).
  DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = "drop_device_list_streams_non_unique_indexes"
  BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
  74. class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
  def __init__(
      self,
      database: DatabasePool,
      db_conn: LoggingDatabaseConnection,
      hs: "HomeServer",
  ):
      super().__init__(database, db_conn, hs)

      # ID generator for the device-list stream. Only the process with no
      # `worker_app` configured is the writer; everyone else just tracks the
      # stream position. NOTE(review): the extra tables are presumably
      # consulted so the generator's starting position also covers stream IDs
      # stored there — confirm against StreamIdGenerator.
      self._device_list_id_gen = StreamIdGenerator(
          db_conn,
          hs.get_replication_notifier(),
          "device_lists_stream",
          "stream_id",
          extra_tables=[
              ("user_signature_stream", "stream_id"),
              ("device_lists_outbound_pokes", "stream_id"),
              ("device_lists_changes_in_room", "stream_id"),
              ("device_lists_remote_pending", "stream_id"),
              ("device_lists_changes_converted_stream_position", "stream_id"),
          ],
          is_writer=hs.config.worker.worker_app is None,
      )

      # Every change cache below is prefilled up to the same upper bound: the
      # generator's current position.
      device_list_max = self._device_list_id_gen.get_current_token()

      def build_change_cache(
          cache_name: str, table: str, entity_column: str, limit: int
      ) -> StreamChangeCache:
          """Build a StreamChangeCache for `table`, prefilled via
          `get_cache_dict` (bounded by `limit`) up to `device_list_max`."""
          prefill, earliest_known_id = self.db_pool.get_cache_dict(
              db_conn,
              table,
              entity_column=entity_column,
              stream_column="stream_id",
              max_value=device_list_max,
              limit=limit,
          )
          return StreamChangeCache(
              cache_name, earliest_known_id, prefilled_cache=prefill
          )

      # Users whose device lists changed.
      self._device_list_stream_cache = build_change_cache(
          "DeviceListStreamChangeCache", "device_lists_stream", "user_id", 10000
      )
      # Users who made new cross-signing signatures.
      self._user_signature_stream_cache = build_change_cache(
          "UserSignatureStreamChangeCache",
          "user_signature_stream",
          "from_user_id",
          1000,
      )
      # Remote destinations with pending outbound device-list pokes.
      self._device_list_federation_stream_cache = build_change_cache(
          "DeviceListFederationStreamChangeCache",
          "device_lists_outbound_pokes",
          "destination",
          10000,
      )

      if hs.config.worker.run_background_tasks:
          one_hour_ms = 60 * 60 * 1000
          self._clock.looping_call(
              self._prune_old_outbound_device_pokes, one_hour_ms
          )
  def process_replication_rows(
      self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
  ) -> None:
      """Handle a batch of replication rows.

      Rows from the device-lists stream first invalidate this store's device
      caches; every batch is then passed on to the parent classes.
      """
      if stream_name != DeviceListsStream.NAME:
          return super().process_replication_rows(
              stream_name, instance_name, token, rows
          )

      self._invalidate_caches_for_devices(token, rows)
      return super().process_replication_rows(stream_name, instance_name, token, rows)
  def process_replication_position(
      self, stream_name: str, instance_name: str, token: int
  ) -> None:
      """Record a new replication position.

      For the device-lists stream, the device-list ID generator is advanced to
      `token` for `instance_name` before the parent classes are notified.
      """
      on_device_stream = stream_name == DeviceListsStream.NAME
      if on_device_stream:
          tracker = self._device_list_id_gen
          tracker.advance(instance_name, token)
      super().process_replication_position(stream_name, instance_name, token)
  def _invalidate_caches_for_devices(
      self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
  ) -> None:
      """Apply device-list replication rows at stream position `token`.

      - Signature rows mark their entity as changed in the user-signature
        stream cache.
      - Other rows whose entity starts with "@" are user IDs: the user is
        marked as changed and their per-user device caches are invalidated.
      - Any remaining row's entity is a remote server, which is marked as
        changed in the federation stream cache.
      """
      for row in rows:
          entity = row.entity
          if row.is_signature:
              self._user_signature_stream_cache.entity_has_changed(entity, token)
          elif entity.startswith("@"):
              self._device_list_stream_cache.entity_has_changed(entity, token)
              cache_key = (entity,)
              self.get_cached_devices_for_user.invalidate(cache_key)
              self._get_cached_user_device.invalidate(cache_key)
              self.get_device_list_last_stream_id_for_remote.invalidate(cache_key)
          else:
              self._device_list_federation_stream_cache.entity_has_changed(
                  entity, token
              )
  def get_device_stream_token(self) -> int:
      """Return the current position of the device-list stream."""
      current_position = self._device_list_id_gen.get_current_token()
      return current_position
  async def count_devices_by_users(
      self, user_ids: Optional[Collection[str]] = None
  ) -> int:
      """Count the non-hidden devices that belong to the given users.

      Args:
          user_ids: IDs of the users whose devices should be counted. `None`
              or an empty collection returns 0 without querying the database.

      Returns:
          The number of devices owned by those users that are not marked as
          hidden.
      """
      if not user_ids:
          return 0

      def _count_txn(txn: LoggingTransaction, users: Collection[str]) -> int:
          sql = """
              SELECT count(*)
              FROM devices
              WHERE
                  hidden = '0' AND
          """
          in_clause, in_args = make_in_list_sql_clause(
              txn.database_engine, "user_id", users
          )
          txn.execute(sql + in_clause, in_args)
          count_row = cast(Tuple[int], txn.fetchone())
          return count_row[0]

      return await self.db_pool.runInteraction(
          "count_devices_by_users", _count_txn, user_ids
      )
  async def get_device(
      self, user_id: str, device_id: str
  ) -> Optional[Dict[str, Any]]:
      """Look up one device, ignoring devices that are marked as hidden.

      Args:
          user_id: the owner of the device
          device_id: the ID of the device to look up

      Returns:
          A dict with the "user_id", "device_id" and "display_name" columns,
          or `None` if there is no such non-hidden device.
      """
      match_on = {"user_id": user_id, "device_id": device_id, "hidden": False}
      wanted_columns = ("user_id", "device_id", "display_name")
      return await self.db_pool.simple_select_one(
          table="devices",
          keyvalues=match_on,
          retcols=wanted_columns,
          desc="get_device",
          allow_none=True,
      )
  async def get_device_opt(
      self, user_id: str, device_id: str
  ) -> Optional[Dict[str, Any]]:
      """Retrieve a device. Only returns devices that are not marked as
      hidden.

      This runs the same query as `get_device`, but under its own transaction
      description. NOTE(review): the `desc` label is presumably what database
      logging/metrics report, so sharing "get_device" attributed these
      queries to the wrong method.

      Args:
          user_id: The ID of the user which owns the device
          device_id: The ID of the device to retrieve

      Returns:
          A dict with the "user_id", "device_id" and "display_name" columns,
          or None if the device does not exist (or is hidden).
      """
      return await self.db_pool.simple_select_one(
          table="devices",
          keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
          retcols=("user_id", "device_id", "display_name"),
          desc="get_device_opt",
          allow_none=True,
      )
  async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
      """List every non-hidden device registered to a user.

      Args:
          user_id: the user whose devices should be listed

      Returns:
          A dict keyed by device ID. Each value is the device's row, holding
          "user_id", "device_id" and "display_name".
      """
      rows = await self.db_pool.simple_select_list(
          table="devices",
          keyvalues={"user_id": user_id, "hidden": False},
          retcols=("user_id", "device_id", "display_name"),
          desc="get_devices_by_user",
      )
      by_device_id: Dict[str, Dict[str, str]] = {}
      for row in rows:
          by_device_id[row["device_id"]] = row
      return by_device_id
  async def get_devices_by_auth_provider_session_id(
      self, auth_provider_id: str, auth_provider_session_id: str
  ) -> List[Dict[str, Any]]:
      """Find the devices that were logged in through a given SSO IdP session.

      Args:
          auth_provider_id: the SSO identity provider's ID, as configured
          auth_provider_session_id: the session ID issued by that provider

      Returns:
          One dict per matching device, holding its "user_id" and "device_id".
      """
      session_filter = {
          "auth_provider_id": auth_provider_id,
          "auth_provider_session_id": auth_provider_session_id,
      }
      return await self.db_pool.simple_select_list(
          table="device_auth_providers",
          keyvalues=session_filter,
          retcols=("user_id", "device_id"),
          desc="get_devices_by_auth_provider_session_id",
      )
  @trace
  async def get_device_updates_by_remote(
      self, destination: str, from_stream_id: int, limit: int
  ) -> Tuple[int, List[Tuple[str, JsonDict]]]:
      """Get a stream of device updates to send to the given remote server.

      Rows that correspond to a user's current master or self-signing key are
      folded into one pair of signing-key-update EDUs per user (the stable
      type plus the unstable "org.matrix.signing_key_update"). This holds
      however many such rows the user has in the batch. Every other row
      becomes a device-list-update EDU.

      Args:
          destination: The host the device updates are intended for
          from_stream_id: The minimum stream_id to filter updates by, exclusive
          limit: Maximum number of device updates to return

      Returns:
          - The current stream id (i.e. the stream id of the last update included
            in the response); and
          - The list of updates, where each update is a pair of EDU type and
            EDU contents.
      """
      now_stream_id = self.get_device_stream_token()

      has_changed = self._device_list_federation_stream_cache.has_entity_changed(
          destination, int(from_stream_id)
      )
      if not has_changed:
          # debugging for https://github.com/matrix-org/synapse/issues/14251
          issue_8631_logger.debug(
              "%s: no change between %i and %i",
              destination,
              from_stream_id,
              now_stream_id,
          )
          return now_stream_id, []

      updates = await self.db_pool.runInteraction(
          "get_device_updates_by_remote",
          self._get_device_updates_by_remote_txn,
          destination,
          from_stream_id,
          now_stream_id,
          limit,
      )

      # We need to ensure the response doesn't grow too big. The query
      # applies `limit` in SQL, so len(updates) <= limit here.

      # Return an empty list if there are no updates
      if not updates:
          return now_stream_id, []

      if issue_8631_logger.isEnabledFor(logging.DEBUG):
          data = {(user, device): stream_id for user, device, stream_id, _ in updates}
          issue_8631_logger.debug(
              "device updates need to be sent to %s: %s", destination, data
          )

      # get the cross-signing keys of the users in the list, so that we can
      # determine which of the device changes were cross-signing keys
      users = {r[0] for r in updates}
      master_key_by_user = {}
      self_signing_key_by_user = {}
      for user in users:
          cross_signing_key = await self.get_e2e_cross_signing_key(user, "master")
          if cross_signing_key:
              _, verify_key = get_verify_key_from_cross_signing_key(
                  cross_signing_key
              )
              # verify_key is a VerifyKey from signedjson, which uses
              # .version to denote the portion of the key ID after the
              # algorithm and colon, which is the device ID
              master_key_by_user[user] = {
                  "key_info": cross_signing_key,
                  "device_id": verify_key.version,
              }

          cross_signing_key = await self.get_e2e_cross_signing_key(
              user, "self_signing"
          )
          if cross_signing_key:
              _, verify_key = get_verify_key_from_cross_signing_key(
                  cross_signing_key
              )
              self_signing_key_by_user[user] = {
                  "key_info": cross_signing_key,
                  "device_id": verify_key.version,
              }

      # Perform the equivalent of a GROUP BY
      #
      # Iterate through the updates list and copy non-duplicate
      # (user_id, device_id) entries into a map, with the value being
      # the max stream_id across each set of duplicate entries
      #
      # maps (user_id, device_id) -> (stream_id, opentracing_context)
      #
      # opentracing_context contains the opentracing metadata for the request
      # that created the poke
      #
      # The most recent request's opentracing_context is used as the
      # context which created the Edu.

      # This is the stream ID that we will return for the consumer to resume
      # following this stream later.
      last_processed_stream_id = from_stream_id

      # A map of (user ID, device ID) to (stream ID, context).
      query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]] = {}
      cross_signing_keys_by_user: Dict[str, Dict[str, object]] = {}
      for user_id, device_id, update_stream_id, update_context in updates:
          # Calculate the remaining length budget.
          # Note that, for now, each entry in `cross_signing_keys_by_user`
          # gives rise to two device updates in the result, so those cost twice
          # as much (and are the whole reason we need to separately calculate
          # the budget; we know len(updates) <= limit otherwise!)
          # N.B. len() on dicts is cheap since they store their size.
          remaining_length_budget = limit - (
              len(query_map) + 2 * len(cross_signing_keys_by_user)
          )
          assert remaining_length_budget >= 0

          is_master_key_update = (
              user_id in master_key_by_user
              and device_id == master_key_by_user[user_id]["device_id"]
          )
          is_self_signing_key_update = (
              user_id in self_signing_key_by_user
              and device_id == self_signing_key_by_user[user_id]["device_id"]
          )

          if is_master_key_update or is_self_signing_key_update:
              if user_id not in cross_signing_keys_by_user:
                  # The first cross-signing change for this user gives rise
                  # to 2 device updates.
                  # If we don't have the budget, stop here!
                  if remaining_length_budget < 2:
                      break

              # A later cross-signing change for a user who already has an
              # entry (e.g. both the master and self-signing keys changed in
              # this batch) is merged into that entry at no extra cost. It
              # must not fall through to the device-update branch, which
              # would send it as a device-list update for the key's
              # pseudo-device and leave it out of the signing-key EDU.
              result = cross_signing_keys_by_user.setdefault(user_id, {})
              if is_master_key_update:
                  result["master_key"] = master_key_by_user[user_id]["key_info"]
              else:
                  result["self_signing_key"] = self_signing_key_by_user[user_id][
                      "key_info"
                  ]
          else:
              key = (user_id, device_id)

              if key not in query_map and remaining_length_budget < 1:
                  # We don't have space for a new entry
                  break

              previous_update_stream_id, _ = query_map.get(key, (0, None))

              if update_stream_id > previous_update_stream_id:
                  # FIXME If this overwrites an older update, this discards the
                  #  previous OpenTracing context.
                  #  It might make it harder to track down issues using OpenTracing.
                  #  If there's a good reason why it doesn't matter, a comment here
                  #  about that would not hurt.
                  query_map[key] = (update_stream_id, update_context)

          # As this update has been added to the response, advance the stream
          # position.
          last_processed_stream_id = update_stream_id

      # A new query_map key costs 1 and a new cross-signing user costs 2, each
      # only when the budget allows, so here:
      #   len(query_map) + 2 * len(cross_signing_keys_by_user) <= limit.
      results = await self._get_device_update_edus_by_remote(
          destination, from_stream_id, query_map
      )

      # len(results) <= len(query_map) here,
      # so len(results) + 2 * len(cross_signing_keys_by_user) <= limit.

      # Add the updated cross-signing keys to the results list
      for user_id, result in cross_signing_keys_by_user.items():
          result["user_id"] = user_id
          results.append((EduTypes.SIGNING_KEY_UPDATE, result))
          # also send the unstable version
          # FIXME: remove this when enough servers have upgraded
          #        and remove the length budgeting above.
          results.append(("org.matrix.signing_key_update", result))

      if issue_8631_logger.isEnabledFor(logging.DEBUG):
          # Each result is an (EDU type, EDU content) pair; every EDU built
          # above carries the user ID in its content.
          for _edu_type, edu_content in results:
              issue_8631_logger.debug(
                  "device update to %s for %s from %s to %s: %s",
                  destination,
                  edu_content.get("user_id"),
                  from_stream_id,
                  last_processed_stream_id,
                  edu_content,
              )

      return last_processed_stream_id, results
  def _get_device_updates_by_remote_txn(
      self,
      txn: LoggingTransaction,
      destination: str,
      from_stream_id: int,
      now_stream_id: int,
      limit: int,
  ) -> List[Tuple[str, str, int, Optional[str]]]:
      """Fetch pending outbound device-list pokes for one remote server.

      Args:
          txn: the database transaction to run the query on
          destination: the remote server the pokes are addressed to
          from_stream_id: lower stream_id bound, exclusive
          now_stream_id: upper stream_id bound, inclusive
          limit: the maximum number of rows to return

      Returns:
          (user_id, device_id, stream_id, opentracing_context) tuples in
          ascending stream_id order.
      """
      query = """
          SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
          WHERE destination = ? AND ? < stream_id AND stream_id <= ?
          ORDER BY stream_id
          LIMIT ?
      """
      query_args = (destination, from_stream_id, now_stream_id, limit)
      txn.execute(query, query_args)
      pokes = txn.fetchall()
      return cast(List[Tuple[str, str, int, Optional[str]]], pokes)
  async def _get_device_update_edus_by_remote(
      self,
      destination: str,
      from_stream_id: int,
      query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]],
  ) -> List[Tuple[str, dict]]:
      """Returns a list of device update EDUs as well as E2EE keys

      Args:
          destination: The host the device updates are intended for
          from_stream_id: The minimum stream_id to filter updates by, exclusive
          query_map: Dictionary mapping (user_id, device_id) to
              (update stream_id, the relevant json-encoded opentracing context)

      Returns:
          List of (EDU type, EDU content) pairs, one device-list-update EDU per
          device returned by the key lookup. Within each user, the EDUs are
          ordered by stream_id and linked via "prev_id" to the previous one.

      Postconditions:
          The returned list has a length not exceeding that of the query_map:
              len(result) <= len(query_map)
      """
      devices = (
          await self.get_e2e_device_keys_and_signatures(
              # Because these are (user_id, device_id) tuples with all
              # device_ids not being None, the returned list's length will not
              # exceed that of query_map.
              query_map.keys(),
              include_all_devices=True,
              include_deleted_devices=True,
          )
          if query_map
          else {}
      )

      # The config flag cannot change during this call; read it once rather
      # than once per device.
      share_display_names = (
          self.hs.config.federation.allow_device_name_lookup_over_federation
      )

      results = []
      for user_id, user_devices in devices.items():
          # The prev_id for the first row is always the last row before
          # `from_stream_id`
          prev_id = await self._get_last_device_update_for_remote_user(
              destination, user_id, from_stream_id
          )

          # make sure we go through the devices in stream order
          device_ids = sorted(
              user_devices.keys(),
              key=lambda i: query_map[(user_id, i)][0],
          )

          for device_id in device_ids:
              device = user_devices[device_id]
              stream_id, opentracing_context = query_map[(user_id, device_id)]
              result = {
                  "user_id": user_id,
                  "device_id": device_id,
                  "prev_id": [prev_id] if prev_id else [],
                  "stream_id": stream_id,
              }
              # query_map's context is Optional: only forward a context that
              # carries information, i.e. neither None (which would otherwise
              # be sent as an explicit null) nor an empty JSON object.
              if opentracing_context and opentracing_context != "{}":
                  result["org.matrix.opentracing_context"] = opentracing_context

              prev_id = stream_id

              if device is not None:
                  keys = device.keys
                  if keys:
                      result["keys"] = keys

                  device_display_name = (
                      device.display_name if share_display_names else None
                  )
                  if device_display_name:
                      result["device_display_name"] = device_display_name
              else:
                  result["deleted"] = True

              results.append((EduTypes.DEVICE_LIST_UPDATE, result))

      return results
  async def _get_last_device_update_for_remote_user(
      self, destination: str, user_id: str, from_stream_id: int
  ) -> int:
      """Return the highest stream_id (at most `from_stream_id`) recorded in
      device_lists_outbound_last_success for this destination and user, or 0
      when there is no such row."""

      def _last_success_txn(txn: LoggingTransaction) -> int:
          txn.execute(
              """
              SELECT coalesce(max(stream_id), 0) as stream_id
              FROM device_lists_outbound_last_success
              WHERE destination = ? AND user_id = ? AND stream_id <= ?
              """,
              (destination, user_id, from_stream_id),
          )
          # The aggregate always yields exactly one row.
          first_row = txn.fetchall()[0]
          return first_row[0]

      return await self.db_pool.runInteraction(
          "get_last_device_update_for_remote_user", _last_success_txn
      )
  async def mark_as_sent_devices_by_remote(
      self, destination: str, stream_id: int
  ) -> None:
      """Record that device updates up to and including `stream_id` were
      successfully delivered to `destination`. The work is done in
      `_mark_as_sent_devices_by_remote_txn`."""
      txn_func = self._mark_as_sent_devices_by_remote_txn
      await self.db_pool.runInteraction(
          "mark_as_sent_devices_by_remote", txn_func, destination, stream_id
      )
  def _mark_as_sent_devices_by_remote_txn(
      self, txn: LoggingTransaction, destination: str, stream_id: int
  ) -> None:
      """Record successful delivery to `destination` up to `stream_id`.

      For each user with outbound pokes to `destination` at or below
      `stream_id`, the highest such poke's stream_id is upserted into
      device_lists_outbound_last_success. Those pokes are then deleted.
      """
      txn.execute(
          """
          SELECT user_id, coalesce(max(o.stream_id), 0)
          FROM device_lists_outbound_pokes as o
          WHERE destination = ? AND o.stream_id <= ?
          GROUP BY user_id
          """,
          (destination, stream_id),
      )
      latest_per_user = txn.fetchall()

      self.db_pool.simple_upsert_many_txn(
          txn=txn,
          table="device_lists_outbound_last_success",
          key_names=("destination", "user_id"),
          key_values=[(destination, poked_user) for poked_user, _ in latest_per_user],
          value_names=("stream_id",),
          value_values=((latest_id,) for _, latest_id in latest_per_user),
      )

      # The pokes have been delivered, so they are no longer needed.
      txn.execute(
          """
          DELETE FROM device_lists_outbound_pokes
          WHERE destination = ? AND stream_id <= ?
          """,
          (destination, stream_id),
      )
  613. async def add_user_signature_change_to_streams(
  614. self, from_user_id: str, user_ids: List[str]
  615. ) -> int:
  616. """Persist that a user has made new signatures
  617. Args:
  618. from_user_id: the user who made the signatures
  619. user_ids: the users who were signed
  620. Returns:
  621. The new stream ID.
  622. """
  623. async with self._device_list_id_gen.get_next() as stream_id:
  624. await self.db_pool.runInteraction(
  625. "add_user_sig_change_to_streams",
  626. self._add_user_signature_change_txn,
  627. from_user_id,
  628. user_ids,
  629. stream_id,
  630. )
  631. return stream_id
  632. def _add_user_signature_change_txn(
  633. self,
  634. txn: LoggingTransaction,
  635. from_user_id: str,
  636. user_ids: List[str],
  637. stream_id: int,
  638. ) -> None:
  639. txn.call_after(
  640. self._user_signature_stream_cache.entity_has_changed,
  641. from_user_id,
  642. stream_id,
  643. )
  644. self.db_pool.simple_insert_txn(
  645. txn,
  646. "user_signature_stream",
  647. values={
  648. "stream_id": stream_id,
  649. "from_user_id": from_user_id,
  650. "user_ids": json_encoder.encode(user_ids),
  651. },
  652. )
  653. @trace
  654. @cancellable
  655. async def get_user_devices_from_cache(
  656. self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
  657. ) -> Tuple[Set[str], Dict[str, Mapping[str, JsonDict]]]:
  658. """Get the devices (and keys if any) for remote users from the cache.
  659. Args:
  660. user_ids: users which should have all device IDs returned
  661. user_and_device_ids: List of (user_id, device_ids)
  662. Returns:
  663. A tuple of (user_ids_not_in_cache, results_map), where
  664. user_ids_not_in_cache is a set of user_ids and results_map is a
  665. mapping of user_id -> device_id -> device_info.
  666. """
  667. unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids}
  668. user_map = await self.get_device_list_last_stream_id_for_remotes(
  669. list(unique_user_ids)
  670. )
  671. # We go and check if any of the users need to have their device lists
  672. # resynced. If they do then we remove them from the cached list.
  673. users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
  674. unique_user_ids
  675. )
  676. user_ids_in_cache = {
  677. user_id for user_id, stream_id in user_map.items() if stream_id
  678. } - users_needing_resync
  679. user_ids_not_in_cache = unique_user_ids - user_ids_in_cache
  680. # First fetch all the users which all devices are to be returned.
  681. results: Dict[str, Mapping[str, JsonDict]] = {}
  682. for user_id in user_ids:
  683. if user_id in user_ids_in_cache:
  684. results[user_id] = await self.get_cached_devices_for_user(user_id)
  685. # Then fetch all device-specific requests, but skip users we've already
  686. # fetched all devices for.
  687. device_specific_results: Dict[str, Dict[str, JsonDict]] = {}
  688. for user_id, device_id in user_and_device_ids:
  689. if user_id in user_ids_in_cache and user_id not in user_ids:
  690. device = await self._get_cached_user_device(user_id, device_id)
  691. device_specific_results.setdefault(user_id, {})[device_id] = device
  692. results.update(device_specific_results)
  693. set_tag("in_cache", str(results))
  694. set_tag("not_in_cache", str(user_ids_not_in_cache))
  695. return user_ids_not_in_cache, results
  696. @cached(num_args=2, tree=True)
  697. async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
  698. content = await self.db_pool.simple_select_one_onecol(
  699. table="device_lists_remote_cache",
  700. keyvalues={"user_id": user_id, "device_id": device_id},
  701. retcol="content",
  702. desc="_get_cached_user_device",
  703. )
  704. return db_to_json(content)
  705. @cached()
  706. async def get_cached_devices_for_user(self, user_id: str) -> Mapping[str, JsonDict]:
  707. devices = await self.db_pool.simple_select_list(
  708. table="device_lists_remote_cache",
  709. keyvalues={"user_id": user_id},
  710. retcols=("device_id", "content"),
  711. desc="get_cached_devices_for_user",
  712. )
  713. return {
  714. device["device_id"]: db_to_json(device["content"]) for device in devices
  715. }
  716. def get_cached_device_list_changes(
  717. self,
  718. from_key: int,
  719. ) -> AllEntitiesChangedResult:
  720. """Get set of users whose devices have changed since `from_key`, or None
  721. if that information is not in our cache.
  722. """
  723. return self._device_list_stream_cache.get_all_entities_changed(from_key)
  724. @cancellable
  725. async def get_all_devices_changed(
  726. self,
  727. from_key: int,
  728. to_key: int,
  729. ) -> Set[str]:
  730. """Get all users whose devices have changed in the given range.
  731. Args:
  732. from_key: The minimum device lists stream token to query device list
  733. changes for, exclusive.
  734. to_key: The maximum device lists stream token to query device list
  735. changes for, inclusive.
  736. Returns:
  737. The set of user_ids whose devices have changed since `from_key`
  738. (exclusive) until `to_key` (inclusive).
  739. """
  740. result = self._device_list_stream_cache.get_all_entities_changed(from_key)
  741. if result.hit:
  742. # We know which users might have changed devices.
  743. if not result.entities:
  744. # If no users then we can return early.
  745. return set()
  746. # Otherwise we need to filter down the list
  747. return await self.get_users_whose_devices_changed(
  748. from_key, result.entities, to_key
  749. )
  750. # If the cache didn't tell us anything, we just need to query the full
  751. # range.
  752. sql = """
  753. SELECT DISTINCT user_id FROM device_lists_stream
  754. WHERE ? < stream_id AND stream_id <= ?
  755. """
  756. rows = await self.db_pool.execute(
  757. "get_all_devices_changed",
  758. None,
  759. sql,
  760. from_key,
  761. to_key,
  762. )
  763. return {u for u, in rows}
  764. @cancellable
  765. async def get_users_whose_devices_changed(
  766. self,
  767. from_key: int,
  768. user_ids: Collection[str],
  769. to_key: Optional[int] = None,
  770. ) -> Set[str]:
  771. """Get set of users whose devices have changed since `from_key` that
  772. are in the given list of user_ids.
  773. Args:
  774. from_key: The minimum device lists stream token to query device list changes for,
  775. exclusive.
  776. user_ids: If provided, only check if these users have changed their device lists.
  777. Otherwise changes from all users are returned.
  778. to_key: The maximum device lists stream token to query device list changes for,
  779. inclusive.
  780. Returns:
  781. The set of user_ids whose devices have changed since `from_key` (exclusive)
  782. until `to_key` (inclusive).
  783. """
  784. # Get set of users who *may* have changed. Users not in the returned
  785. # list have definitely not changed.
  786. user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
  787. user_ids, from_key
  788. )
  789. # If an empty set was returned, there's nothing to do.
  790. if not user_ids_to_check:
  791. return set()
  792. if to_key is None:
  793. to_key = self._device_list_id_gen.get_current_token()
  794. def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
  795. sql = """
  796. SELECT DISTINCT user_id FROM device_lists_stream
  797. WHERE ? < stream_id AND stream_id <= ? AND %s
  798. """
  799. changes: Set[str] = set()
  800. # Query device changes with a batch of users at a time
  801. for chunk in batch_iter(user_ids_to_check, 100):
  802. clause, args = make_in_list_sql_clause(
  803. txn.database_engine, "user_id", chunk
  804. )
  805. txn.execute(sql % (clause,), [from_key, to_key] + args)
  806. changes.update(user_id for user_id, in txn)
  807. return changes
  808. return await self.db_pool.runInteraction(
  809. "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
  810. )
  811. async def get_users_whose_signatures_changed(
  812. self, user_id: str, from_key: int
  813. ) -> Set[str]:
  814. """Get the users who have new cross-signing signatures made by `user_id` since
  815. `from_key`.
  816. Args:
  817. user_id: the user who made the signatures
  818. from_key: The device lists stream token
  819. Returns:
  820. A set of user IDs with updated signatures.
  821. """
  822. if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
  823. sql = """
  824. SELECT DISTINCT user_ids FROM user_signature_stream
  825. WHERE from_user_id = ? AND stream_id > ?
  826. """
  827. rows = await self.db_pool.execute(
  828. "get_users_whose_signatures_changed", None, sql, user_id, from_key
  829. )
  830. return {user for row in rows for user in db_to_json(row[0])}
  831. else:
  832. return set()
  833. async def get_all_device_list_changes_for_remotes(
  834. self, instance_name: str, last_id: int, current_id: int, limit: int
  835. ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
  836. """Get updates for device lists replication stream.
  837. Args:
  838. instance_name: The writer we want to fetch updates from. Unused
  839. here since there is only ever one writer.
  840. last_id: The token to fetch updates from. Exclusive.
  841. current_id: The token to fetch updates up to. Inclusive.
  842. limit: The requested limit for the number of rows to return. The
  843. function may return more or fewer rows.
  844. Returns:
  845. A tuple consisting of: the updates, a token to use to fetch
  846. subsequent updates, and whether we returned fewer rows than exists
  847. between the requested tokens due to the limit.
  848. The token returned can be used in a subsequent call to this
  849. function to get further updates.
  850. The updates are a list of 2-tuples of stream ID and the row data
  851. """
  852. if last_id == current_id:
  853. return [], current_id, False
  854. def _get_all_device_list_changes_for_remotes(
  855. txn: Cursor,
  856. ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
  857. # This query Does The Right Thing where it'll correctly apply the
  858. # bounds to the inner queries.
  859. sql = """
  860. SELECT stream_id, entity FROM (
  861. SELECT stream_id, user_id AS entity FROM device_lists_stream
  862. UNION ALL
  863. SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
  864. ) AS e
  865. WHERE ? < stream_id AND stream_id <= ?
  866. ORDER BY stream_id ASC
  867. LIMIT ?
  868. """
  869. txn.execute(sql, (last_id, current_id, limit))
  870. updates = [(row[0], row[1:]) for row in txn]
  871. limited = False
  872. upto_token = current_id
  873. if len(updates) >= limit:
  874. upto_token = updates[-1][0]
  875. limited = True
  876. return updates, upto_token, limited
  877. return await self.db_pool.runInteraction(
  878. "get_all_device_list_changes_for_remotes",
  879. _get_all_device_list_changes_for_remotes,
  880. )
  881. @cached(max_entries=10000)
  882. async def get_device_list_last_stream_id_for_remote(
  883. self, user_id: str
  884. ) -> Optional[str]:
  885. """Get the last stream_id we got for a user. May be None if we haven't
  886. got any information for them.
  887. """
  888. return await self.db_pool.simple_select_one_onecol(
  889. table="device_lists_remote_extremeties",
  890. keyvalues={"user_id": user_id},
  891. retcol="stream_id",
  892. desc="get_device_list_last_stream_id_for_remote",
  893. allow_none=True,
  894. )
  895. @cachedList(
  896. cached_method_name="get_device_list_last_stream_id_for_remote",
  897. list_name="user_ids",
  898. )
  899. async def get_device_list_last_stream_id_for_remotes(
  900. self, user_ids: Iterable[str]
  901. ) -> Dict[str, Optional[str]]:
  902. rows = await self.db_pool.simple_select_many_batch(
  903. table="device_lists_remote_extremeties",
  904. column="user_id",
  905. iterable=user_ids,
  906. retcols=("user_id", "stream_id"),
  907. desc="get_device_list_last_stream_id_for_remotes",
  908. )
  909. results: Dict[str, Optional[str]] = {user_id: None for user_id in user_ids}
  910. results.update({row["user_id"]: row["stream_id"] for row in rows})
  911. return results
  912. async def get_user_ids_requiring_device_list_resync(
  913. self,
  914. user_ids: Optional[Collection[str]] = None,
  915. ) -> Set[str]:
  916. """Given a list of remote users return the list of users that we
  917. should resync the device lists for. If None is given instead of a list,
  918. return every user that we should resync the device lists for.
  919. Returns:
  920. The IDs of users whose device lists need resync.
  921. """
  922. if user_ids:
  923. rows = await self.db_pool.simple_select_many_batch(
  924. table="device_lists_remote_resync",
  925. column="user_id",
  926. iterable=user_ids,
  927. retcols=("user_id",),
  928. desc="get_user_ids_requiring_device_list_resync_with_iterable",
  929. )
  930. else:
  931. rows = await self.db_pool.simple_select_list(
  932. table="device_lists_remote_resync",
  933. keyvalues=None,
  934. retcols=("user_id",),
  935. desc="get_user_ids_requiring_device_list_resync",
  936. )
  937. return {row["user_id"] for row in rows}
  938. async def mark_remote_users_device_caches_as_stale(
  939. self, user_ids: StrCollection
  940. ) -> None:
  941. """Records that the server has reason to believe the cache of the devices
  942. for the remote users is out of date.
  943. """
  944. def _mark_remote_users_device_caches_as_stale_txn(
  945. txn: LoggingTransaction,
  946. ) -> None:
  947. # TODO add insertion_values support to simple_upsert_many and use
  948. # that!
  949. for user_id in user_ids:
  950. self.db_pool.simple_upsert_txn(
  951. txn,
  952. table="device_lists_remote_resync",
  953. keyvalues={"user_id": user_id},
  954. values={},
  955. insertion_values={"added_ts": self._clock.time_msec()},
  956. )
  957. await self.db_pool.runInteraction(
  958. "mark_remote_users_device_caches_as_stale",
  959. _mark_remote_users_device_caches_as_stale_txn,
  960. )
  961. async def mark_remote_user_device_cache_as_valid(self, user_id: str) -> None:
  962. # Remove the database entry that says we need to resync devices, after a resync
  963. await self.db_pool.simple_delete(
  964. table="device_lists_remote_resync",
  965. keyvalues={"user_id": user_id},
  966. desc="mark_remote_user_device_cache_as_valid",
  967. )
  968. async def handle_potentially_left_users(self, user_ids: Set[str]) -> None:
  969. """Given a set of remote users check if the server still shares a room with
  970. them. If not then mark those users' device cache as stale.
  971. """
  972. if not user_ids:
  973. return
  974. await self.db_pool.runInteraction(
  975. "_handle_potentially_left_users",
  976. self.handle_potentially_left_users_txn,
  977. user_ids,
  978. )
  979. def handle_potentially_left_users_txn(
  980. self,
  981. txn: LoggingTransaction,
  982. user_ids: Set[str],
  983. ) -> None:
  984. """Given a set of remote users check if the server still shares a room with
  985. them. If not then mark those users' device cache as stale.
  986. """
  987. if not user_ids:
  988. return
  989. joined_users = self.get_users_server_still_shares_room_with_txn(txn, user_ids)
  990. left_users = user_ids - joined_users
  991. for user_id in left_users:
  992. self.mark_remote_user_device_list_as_unsubscribed_txn(txn, user_id)
  993. async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
  994. """Mark that we no longer track device lists for remote user."""
  995. await self.db_pool.runInteraction(
  996. "mark_remote_user_device_list_as_unsubscribed",
  997. self.mark_remote_user_device_list_as_unsubscribed_txn,
  998. user_id,
  999. )
  1000. def mark_remote_user_device_list_as_unsubscribed_txn(
  1001. self,
  1002. txn: LoggingTransaction,
  1003. user_id: str,
  1004. ) -> None:
  1005. self.db_pool.simple_delete_txn(
  1006. txn,
  1007. table="device_lists_remote_extremeties",
  1008. keyvalues={"user_id": user_id},
  1009. )
  1010. self._invalidate_cache_and_stream(
  1011. txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
  1012. )
  1013. async def get_dehydrated_device(
  1014. self, user_id: str
  1015. ) -> Optional[Tuple[str, JsonDict]]:
  1016. """Retrieve the information for a dehydrated device.
  1017. Args:
  1018. user_id: the user whose dehydrated device we are looking for
  1019. Returns:
  1020. a tuple whose first item is the device ID, and the second item is
  1021. the dehydrated device information
  1022. """
  1023. # FIXME: make sure device ID still exists in devices table
  1024. row = await self.db_pool.simple_select_one(
  1025. table="dehydrated_devices",
  1026. keyvalues={"user_id": user_id},
  1027. retcols=["device_id", "device_data"],
  1028. allow_none=True,
  1029. )
  1030. return (
  1031. (row["device_id"], json_decoder.decode(row["device_data"])) if row else None
  1032. )
  1033. def _store_dehydrated_device_txn(
  1034. self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
  1035. ) -> Optional[str]:
  1036. old_device_id = self.db_pool.simple_select_one_onecol_txn(
  1037. txn,
  1038. table="dehydrated_devices",
  1039. keyvalues={"user_id": user_id},
  1040. retcol="device_id",
  1041. allow_none=True,
  1042. )
  1043. self.db_pool.simple_upsert_txn(
  1044. txn,
  1045. table="dehydrated_devices",
  1046. keyvalues={"user_id": user_id},
  1047. values={"device_id": device_id, "device_data": device_data},
  1048. )
  1049. return old_device_id
  1050. async def store_dehydrated_device(
  1051. self, user_id: str, device_id: str, device_data: JsonDict
  1052. ) -> Optional[str]:
  1053. """Store a dehydrated device for a user.
  1054. Args:
  1055. user_id: the user that we are storing the device for
  1056. device_id: the ID of the dehydrated device
  1057. device_data: the dehydrated device information
  1058. Returns:
  1059. device id of the user's previous dehydrated device, if any
  1060. """
  1061. return await self.db_pool.runInteraction(
  1062. "store_dehydrated_device_txn",
  1063. self._store_dehydrated_device_txn,
  1064. user_id,
  1065. device_id,
  1066. json_encoder.encode(device_data),
  1067. )
  1068. async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
  1069. """Remove a dehydrated device.
  1070. Args:
  1071. user_id: the user that the dehydrated device belongs to
  1072. device_id: the ID of the dehydrated device
  1073. """
  1074. count = await self.db_pool.simple_delete(
  1075. "dehydrated_devices",
  1076. {"user_id": user_id, "device_id": device_id},
  1077. desc="remove_dehydrated_device",
  1078. )
  1079. return count >= 1
  1080. @wrap_as_background_process("prune_old_outbound_device_pokes")
  1081. async def _prune_old_outbound_device_pokes(
  1082. self, prune_age: int = 24 * 60 * 60 * 1000
  1083. ) -> None:
  1084. """Delete old entries out of the device_lists_outbound_pokes to ensure
  1085. that we don't fill up due to dead servers.
  1086. Normally, we try to send device updates as a delta since a previous known point:
  1087. this is done by setting the prev_id in the m.device_list_update EDU. However,
  1088. for that to work, we have to have a complete record of each change to
  1089. each device, which can add up to quite a lot of data.
  1090. An alternative mechanism is that, if the remote server sees that it has missed
  1091. an entry in the stream_id sequence for a given user, it will request a full
  1092. list of that user's devices. Hence, we can reduce the amount of data we have to
  1093. store (and transmit in some future transaction), by clearing almost everything
  1094. for a given destination out of the database, and having the remote server
  1095. resync.
  1096. All we need to do is make sure we keep at least one row for each
  1097. (user, destination) pair, to remind us to send a m.device_list_update EDU for
  1098. that user when the destination comes back. It doesn't matter which device
  1099. we keep.
  1100. """
  1101. yesterday = self._clock.time_msec() - prune_age
  1102. def _prune_txn(txn: LoggingTransaction) -> None:
  1103. # look for (user, destination) pairs which have an update older than
  1104. # the cutoff.
  1105. #
  1106. # For each pair, we also need to know the most recent stream_id, and
  1107. # an arbitrary device_id at that stream_id.
  1108. select_sql = """
  1109. SELECT
  1110. dlop1.destination,
  1111. dlop1.user_id,
  1112. MAX(dlop1.stream_id) AS stream_id,
  1113. (SELECT MIN(dlop2.device_id) AS device_id FROM
  1114. device_lists_outbound_pokes dlop2
  1115. WHERE dlop2.destination = dlop1.destination AND
  1116. dlop2.user_id=dlop1.user_id AND
  1117. dlop2.stream_id=MAX(dlop1.stream_id)
  1118. )
  1119. FROM device_lists_outbound_pokes dlop1
  1120. GROUP BY destination, user_id
  1121. HAVING min(ts) < ? AND count(*) > 1
  1122. """
  1123. txn.execute(select_sql, (yesterday,))
  1124. rows = txn.fetchall()
  1125. if not rows:
  1126. return
  1127. logger.info(
  1128. "Pruning old outbound device list updates for %i users/destinations: %s",
  1129. len(rows),
  1130. shortstr((row[0], row[1]) for row in rows),
  1131. )
  1132. # we want to keep the update with the highest stream_id for each user.
  1133. #
  1134. # there might be more than one update (with different device_ids) with the
  1135. # same stream_id, so we also delete all but one rows with the max stream id.
  1136. delete_sql = """
  1137. DELETE FROM device_lists_outbound_pokes
  1138. WHERE destination = ? AND user_id = ? AND (
  1139. stream_id < ? OR
  1140. (stream_id = ? AND device_id != ?)
  1141. )
  1142. """
  1143. count = 0
  1144. for destination, user_id, stream_id, device_id in rows:
  1145. txn.execute(
  1146. delete_sql, (destination, user_id, stream_id, stream_id, device_id)
  1147. )
  1148. count += txn.rowcount
  1149. # Since we've deleted unsent deltas, we need to remove the entry
  1150. # of last successful sent so that the prev_ids are correctly set.
  1151. sql = """
  1152. DELETE FROM device_lists_outbound_last_success
  1153. WHERE destination = ? AND user_id = ?
  1154. """
  1155. txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
  1156. logger.info("Pruned %d device list outbound pokes", count)
  1157. await self.db_pool.runInteraction(
  1158. "_prune_old_outbound_device_pokes",
  1159. _prune_txn,
  1160. )
  1161. async def get_local_devices_not_accessed_since(
  1162. self, since_ms: int
  1163. ) -> Dict[str, List[str]]:
  1164. """Retrieves local devices that haven't been accessed since a given date.
  1165. Args:
  1166. since_ms: the timestamp to select on, every device with a last access date
  1167. from before that time is returned.
  1168. Returns:
  1169. A dictionary with an entry for each user with at least one device matching
  1170. the request, which value is a list of the device ID(s) for the corresponding
  1171. device(s).
  1172. """
  1173. def get_devices_not_accessed_since_txn(
  1174. txn: LoggingTransaction,
  1175. ) -> List[Dict[str, str]]:
  1176. sql = """
  1177. SELECT user_id, device_id
  1178. FROM devices WHERE last_seen < ? AND hidden = FALSE
  1179. """
  1180. txn.execute(sql, (since_ms,))
  1181. return self.db_pool.cursor_to_dict(txn)
  1182. rows = await self.db_pool.runInteraction(
  1183. "get_devices_not_accessed_since",
  1184. get_devices_not_accessed_since_txn,
  1185. )
  1186. devices: Dict[str, List[str]] = {}
  1187. for row in rows:
  1188. # Remote devices are never stale from our point of view.
  1189. if self.hs.is_mine_id(row["user_id"]):
  1190. user_devices = devices.setdefault(row["user_id"], [])
  1191. user_devices.append(row["device_id"])
  1192. return devices
  1193. @cached()
  1194. async def _get_min_device_lists_changes_in_room(self) -> int:
  1195. """Returns the minimum stream ID that we have entries for
  1196. `device_lists_changes_in_room`
  1197. """
  1198. return await self.db_pool.simple_select_one_onecol(
  1199. table="device_lists_changes_in_room",
  1200. keyvalues={},
  1201. retcol="COALESCE(MIN(stream_id), 0)",
  1202. desc="get_min_device_lists_changes_in_room",
  1203. )
  1204. @cancellable
  1205. async def get_device_list_changes_in_rooms(
  1206. self, room_ids: Collection[str], from_id: int
  1207. ) -> Optional[Set[str]]:
  1208. """Return the set of users whose devices have changed in the given rooms
  1209. since the given stream ID.
  1210. Returns None if the given stream ID is too old.
  1211. """
  1212. if not room_ids:
  1213. return set()
  1214. min_stream_id = await self._get_min_device_lists_changes_in_room()
  1215. if min_stream_id > from_id:
  1216. return None
  1217. sql = """
  1218. SELECT DISTINCT user_id FROM device_lists_changes_in_room
  1219. WHERE {clause} AND stream_id >= ?
  1220. """
  1221. def _get_device_list_changes_in_rooms_txn(
  1222. txn: LoggingTransaction,
  1223. clause: str,
  1224. args: List[Any],
  1225. ) -> Set[str]:
  1226. txn.execute(sql.format(clause=clause), args)
  1227. return {user_id for user_id, in txn}
  1228. changes = set()
  1229. for chunk in batch_iter(room_ids, 1000):
  1230. clause, args = make_in_list_sql_clause(
  1231. self.database_engine, "room_id", chunk
  1232. )
  1233. args.append(from_id)
  1234. changes |= await self.db_pool.runInteraction(
  1235. "get_device_list_changes_in_rooms",
  1236. _get_device_list_changes_in_rooms_txn,
  1237. clause,
  1238. args,
  1239. )
  1240. return changes
  1241. async def get_device_list_changes_in_room(
  1242. self, room_id: str, min_stream_id: int
  1243. ) -> Collection[Tuple[str, str]]:
  1244. """Get all device list changes that happened in the room since the given
  1245. stream ID.
  1246. Returns:
  1247. Collection of user ID/device ID tuples of all devices that have
  1248. changed
  1249. """
  1250. sql = """
  1251. SELECT DISTINCT user_id, device_id FROM device_lists_changes_in_room
  1252. WHERE room_id = ? AND stream_id > ?
  1253. """
  1254. def get_device_list_changes_in_room_txn(
  1255. txn: LoggingTransaction,
  1256. ) -> Collection[Tuple[str, str]]:
  1257. txn.execute(sql, (room_id, min_stream_id))
  1258. return cast(Collection[Tuple[str, str]], txn.fetchall())
  1259. return await self.db_pool.runInteraction(
  1260. "get_device_list_changes_in_room",
  1261. get_device_list_changes_in_room_txn,
  1262. )
  1263. class DeviceBackgroundUpdateStore(SQLBaseStore):
  1264. def __init__(
  1265. self,
  1266. database: DatabasePool,
  1267. db_conn: LoggingDatabaseConnection,
  1268. hs: "HomeServer",
  1269. ):
  1270. super().__init__(database, db_conn, hs)
  1271. self.db_pool.updates.register_background_index_update(
  1272. "device_lists_stream_idx",
  1273. index_name="device_lists_stream_user_id",
  1274. table="device_lists_stream",
  1275. columns=["user_id", "device_id"],
  1276. )
  1277. # create a unique index on device_lists_remote_cache
  1278. self.db_pool.updates.register_background_index_update(
  1279. "device_lists_remote_cache_unique_idx",
  1280. index_name="device_lists_remote_cache_unique_id",
  1281. table="device_lists_remote_cache",
  1282. columns=["user_id", "device_id"],
  1283. unique=True,
  1284. )
  1285. # And one on device_lists_remote_extremeties
  1286. self.db_pool.updates.register_background_index_update(
  1287. "device_lists_remote_extremeties_unique_idx",
  1288. index_name="device_lists_remote_extremeties_unique_idx",
  1289. table="device_lists_remote_extremeties",
  1290. columns=["user_id"],
  1291. unique=True,
  1292. )
  1293. # once they complete, we can remove the old non-unique indexes.
  1294. self.db_pool.updates.register_background_update_handler(
  1295. DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
  1296. self._drop_device_list_streams_non_unique_indexes,
  1297. )
  1298. # clear out duplicate device list outbound pokes
  1299. self.db_pool.updates.register_background_update_handler(
  1300. BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
  1301. self._remove_duplicate_outbound_pokes,
  1302. )
  1303. self.db_pool.updates.register_background_index_update(
  1304. "device_lists_changes_in_room_by_room_index",
  1305. index_name="device_lists_changes_in_room_by_room_idx",
  1306. table="device_lists_changes_in_room",
  1307. columns=["room_id", "stream_id"],
  1308. )
  1309. async def _drop_device_list_streams_non_unique_indexes(
  1310. self, progress: JsonDict, batch_size: int
  1311. ) -> int:
  1312. def f(conn: LoggingDatabaseConnection) -> None:
  1313. txn = conn.cursor()
  1314. txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
  1315. txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
  1316. txn.close()
  1317. await self.db_pool.runWithConnection(f)
  1318. await self.db_pool.updates._end_background_update(
  1319. DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
  1320. )
  1321. return 1
  1322. async def _remove_duplicate_outbound_pokes(
  1323. self, progress: JsonDict, batch_size: int
  1324. ) -> int:
  1325. # for some reason, we have accumulated duplicate entries in
  1326. # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
  1327. # efficient.
  1328. #
  1329. # For each duplicate, we delete all the existing rows and put one back.
  1330. KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
  1331. last_row = progress.get(
  1332. "last_row",
  1333. {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
  1334. )
  1335. def _txn(txn: LoggingTransaction) -> int:
  1336. clause, args = make_tuple_comparison_clause(
  1337. [(x, last_row[x]) for x in KEY_COLS]
  1338. )
  1339. sql = """
  1340. SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
  1341. FROM device_lists_outbound_pokes
  1342. WHERE %s
  1343. GROUP BY %s
  1344. HAVING count(*) > 1
  1345. ORDER BY %s
  1346. LIMIT ?
  1347. """ % (
  1348. clause, # WHERE
  1349. ",".join(KEY_COLS), # GROUP BY
  1350. ",".join(KEY_COLS), # ORDER BY
  1351. )
  1352. txn.execute(sql, args + [batch_size])
  1353. rows = self.db_pool.cursor_to_dict(txn)
  1354. row = None
  1355. for row in rows:
  1356. self.db_pool.simple_delete_txn(
  1357. txn,
  1358. "device_lists_outbound_pokes",
  1359. {x: row[x] for x in KEY_COLS},
  1360. )
  1361. row["sent"] = False
  1362. self.db_pool.simple_insert_txn(
  1363. txn,
  1364. "device_lists_outbound_pokes",
  1365. row,
  1366. )
  1367. if row:
  1368. self.db_pool.updates._background_update_progress_txn(
  1369. txn,
  1370. BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES,
  1371. {"last_row": row},
  1372. )
  1373. return len(rows)
  1374. rows = await self.db_pool.runInteraction(
  1375. BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn
  1376. )
  1377. if not rows:
  1378. await self.db_pool.updates._end_background_update(
  1379. BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES
  1380. )
  1381. return rows
  1382. class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
  1383. # Because we have write access, this will be a StreamIdGenerator
  1384. # (see DeviceWorkerStore.__init__)
  1385. _device_list_id_gen: AbstractStreamIdGenerator
  1386. def __init__(
  1387. self,
  1388. database: DatabasePool,
  1389. db_conn: LoggingDatabaseConnection,
  1390. hs: "HomeServer",
  1391. ):
  1392. super().__init__(database, db_conn, hs)
  1393. # Map of (user_id, device_id) -> bool. If there is an entry that implies
  1394. # the device exists.
  1395. self.device_id_exists_cache: LruCache[
  1396. Tuple[str, str], Literal[True]
  1397. ] = LruCache(cache_name="device_id_exists", max_size=10000)
  1398. async def store_device(
  1399. self,
  1400. user_id: str,
  1401. device_id: str,
  1402. initial_device_display_name: Optional[str],
  1403. auth_provider_id: Optional[str] = None,
  1404. auth_provider_session_id: Optional[str] = None,
  1405. ) -> bool:
  1406. """Ensure the given device is known; add it to the store if not
  1407. Args:
  1408. user_id: id of user associated with the device
  1409. device_id: id of device
  1410. initial_device_display_name: initial displayname of the device.
  1411. Ignored if device exists.
  1412. auth_provider_id: The SSO IdP the user used, if any.
  1413. auth_provider_session_id: The session ID (sid) got from a OIDC login.
  1414. Returns:
  1415. Whether the device was inserted or an existing device existed with that ID.
  1416. Raises:
  1417. StoreError: if the device is already in use
  1418. """
  1419. key = (user_id, device_id)
  1420. if self.device_id_exists_cache.get(key, None):
  1421. return False
  1422. try:
  1423. inserted = await self.db_pool.simple_upsert(
  1424. "devices",
  1425. keyvalues={
  1426. "user_id": user_id,
  1427. "device_id": device_id,
  1428. },
  1429. values={},
  1430. insertion_values={
  1431. "display_name": initial_device_display_name,
  1432. "hidden": False,
  1433. },
  1434. desc="store_device",
  1435. )
  1436. if not inserted:
  1437. # if the device already exists, check if it's a real device, or
  1438. # if the device ID is reserved by something else
  1439. hidden = await self.db_pool.simple_select_one_onecol(
  1440. "devices",
  1441. keyvalues={"user_id": user_id, "device_id": device_id},
  1442. retcol="hidden",
  1443. )
  1444. if hidden:
  1445. raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
  1446. if auth_provider_id and auth_provider_session_id:
  1447. await self.db_pool.simple_insert(
  1448. "device_auth_providers",
  1449. values={
  1450. "user_id": user_id,
  1451. "device_id": device_id,
  1452. "auth_provider_id": auth_provider_id,
  1453. "auth_provider_session_id": auth_provider_session_id,
  1454. },
  1455. desc="store_device_auth_provider",
  1456. )
  1457. self.device_id_exists_cache.set(key, True)
  1458. return inserted
  1459. except StoreError:
  1460. raise
  1461. except Exception as e:
  1462. logger.error(
  1463. "store_device with device_id=%s(%r) user_id=%s(%r)"
  1464. " display_name=%s(%r) failed: %s",
  1465. type(device_id).__name__,
  1466. device_id,
  1467. type(user_id).__name__,
  1468. user_id,
  1469. type(initial_device_display_name).__name__,
  1470. initial_device_display_name,
  1471. e,
  1472. )
  1473. raise StoreError(500, "Problem storing device.")
  1474. async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
  1475. """Deletes several devices.
  1476. Args:
  1477. user_id: The ID of the user which owns the devices
  1478. device_ids: The IDs of the devices to delete
  1479. """
  1480. def _delete_devices_txn(txn: LoggingTransaction) -> None:
  1481. self.db_pool.simple_delete_many_txn(
  1482. txn,
  1483. table="devices",
  1484. column="device_id",
  1485. values=device_ids,
  1486. keyvalues={"user_id": user_id, "hidden": False},
  1487. )
  1488. self.db_pool.simple_delete_many_txn(
  1489. txn,
  1490. table="device_inbox",
  1491. column="device_id",
  1492. values=device_ids,
  1493. keyvalues={"user_id": user_id},
  1494. )
  1495. self.db_pool.simple_delete_many_txn(
  1496. txn,
  1497. table="device_auth_providers",
  1498. column="device_id",
  1499. values=device_ids,
  1500. keyvalues={"user_id": user_id},
  1501. )
  1502. await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
  1503. for device_id in device_ids:
  1504. self.device_id_exists_cache.invalidate((user_id, device_id))
  1505. async def update_device(
  1506. self, user_id: str, device_id: str, new_display_name: Optional[str] = None
  1507. ) -> None:
  1508. """Update a device. Only updates the device if it is not marked as
  1509. hidden.
  1510. Args:
  1511. user_id: The ID of the user which owns the device
  1512. device_id: The ID of the device to update
  1513. new_display_name: new displayname for device; None to leave unchanged
  1514. Raises:
  1515. StoreError: if the device is not found
  1516. """
  1517. updates = {}
  1518. if new_display_name is not None:
  1519. updates["display_name"] = new_display_name
  1520. if not updates:
  1521. return None
  1522. await self.db_pool.simple_update_one(
  1523. table="devices",
  1524. keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
  1525. updatevalues=updates,
  1526. desc="update_device",
  1527. )
  1528. async def update_remote_device_list_cache_entry(
  1529. self, user_id: str, device_id: str, content: JsonDict, stream_id: str
  1530. ) -> None:
  1531. """Updates a single device in the cache of a remote user's devicelist.
  1532. Note: assumes that we are the only thread that can be updating this user's
  1533. device list.
  1534. Args:
  1535. user_id: User to update device list for
  1536. device_id: ID of decivice being updated
  1537. content: new data on this device
  1538. stream_id: the version of the device list
  1539. """
  1540. await self.db_pool.runInteraction(
  1541. "update_remote_device_list_cache_entry",
  1542. self._update_remote_device_list_cache_entry_txn,
  1543. user_id,
  1544. device_id,
  1545. content,
  1546. stream_id,
  1547. )
  1548. def _update_remote_device_list_cache_entry_txn(
  1549. self,
  1550. txn: LoggingTransaction,
  1551. user_id: str,
  1552. device_id: str,
  1553. content: JsonDict,
  1554. stream_id: str,
  1555. ) -> None:
  1556. """Delete, update or insert a cache entry for this (user, device) pair."""
  1557. if content.get("deleted"):
  1558. self.db_pool.simple_delete_txn(
  1559. txn,
  1560. table="device_lists_remote_cache",
  1561. keyvalues={"user_id": user_id, "device_id": device_id},
  1562. )
  1563. txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
  1564. else:
  1565. self.db_pool.simple_upsert_txn(
  1566. txn,
  1567. table="device_lists_remote_cache",
  1568. keyvalues={"user_id": user_id, "device_id": device_id},
  1569. values={"content": json_encoder.encode(content)},
  1570. )
  1571. txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
  1572. txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
  1573. txn.call_after(
  1574. self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
  1575. )
  1576. self.db_pool.simple_upsert_txn(
  1577. txn,
  1578. table="device_lists_remote_extremeties",
  1579. keyvalues={"user_id": user_id},
  1580. values={"stream_id": stream_id},
  1581. )
  1582. async def update_remote_device_list_cache(
  1583. self, user_id: str, devices: List[dict], stream_id: int
  1584. ) -> None:
  1585. """Replace the entire cache of the remote user's devices.
  1586. Note: assumes that we are the only thread that can be updating this user's
  1587. device list.
  1588. Args:
  1589. user_id: User to update device list for
  1590. devices: list of device objects supplied over federation
  1591. stream_id: the version of the device list
  1592. """
  1593. await self.db_pool.runInteraction(
  1594. "update_remote_device_list_cache",
  1595. self._update_remote_device_list_cache_txn,
  1596. user_id,
  1597. devices,
  1598. stream_id,
  1599. )
  1600. def _update_remote_device_list_cache_txn(
  1601. self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
  1602. ) -> None:
  1603. """Replace the list of cached devices for this user with the given list."""
  1604. self.db_pool.simple_delete_txn(
  1605. txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
  1606. )
  1607. self.db_pool.simple_insert_many_txn(
  1608. txn,
  1609. table="device_lists_remote_cache",
  1610. keys=("user_id", "device_id", "content"),
  1611. values=[
  1612. (user_id, content["device_id"], json_encoder.encode(content))
  1613. for content in devices
  1614. ],
  1615. )
  1616. txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
  1617. txn.call_after(self._get_cached_user_device.invalidate, (user_id,))
  1618. txn.call_after(
  1619. self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
  1620. )
  1621. self.db_pool.simple_upsert_txn(
  1622. txn,
  1623. table="device_lists_remote_extremeties",
  1624. keyvalues={"user_id": user_id},
  1625. values={"stream_id": stream_id},
  1626. )
  1627. async def add_device_change_to_streams(
  1628. self,
  1629. user_id: str,
  1630. device_ids: Collection[str],
  1631. room_ids: Collection[str],
  1632. ) -> Optional[int]:
  1633. """Persist that a user's devices have been updated, and which hosts
  1634. (if any) should be poked.
  1635. Args:
  1636. user_id: The ID of the user whose device changed.
  1637. device_ids: The IDs of any changed devices. If empty, this function will
  1638. return None.
  1639. room_ids: The rooms that the user is in
  1640. Returns:
  1641. The maximum stream ID of device list updates that were added to the database, or
  1642. None if no updates were added.
  1643. """
  1644. if not device_ids:
  1645. return None
  1646. context = get_active_span_text_map()
  1647. def add_device_changes_txn(
  1648. txn: LoggingTransaction, stream_ids: List[int]
  1649. ) -> None:
  1650. self._add_device_change_to_stream_txn(
  1651. txn,
  1652. user_id,
  1653. device_ids,
  1654. stream_ids,
  1655. )
  1656. self._add_device_outbound_room_poke_txn(
  1657. txn,
  1658. user_id,
  1659. device_ids,
  1660. room_ids,
  1661. stream_ids,
  1662. context,
  1663. )
  1664. async with self._device_list_id_gen.get_next_mult(
  1665. len(device_ids)
  1666. ) as stream_ids:
  1667. await self.db_pool.runInteraction(
  1668. "add_device_change_to_stream",
  1669. add_device_changes_txn,
  1670. stream_ids,
  1671. )
  1672. return stream_ids[-1]
  1673. def _add_device_change_to_stream_txn(
  1674. self,
  1675. txn: LoggingTransaction,
  1676. user_id: str,
  1677. device_ids: Collection[str],
  1678. stream_ids: List[int],
  1679. ) -> None:
  1680. txn.call_after(
  1681. self._device_list_stream_cache.entity_has_changed,
  1682. user_id,
  1683. stream_ids[-1],
  1684. )
  1685. min_stream_id = stream_ids[0]
  1686. # Delete older entries in the table, as we really only care about
  1687. # when the latest change happened.
  1688. txn.execute_batch(
  1689. """
  1690. DELETE FROM device_lists_stream
  1691. WHERE user_id = ? AND device_id = ? AND stream_id < ?
  1692. """,
  1693. [(user_id, device_id, min_stream_id) for device_id in device_ids],
  1694. )
  1695. self.db_pool.simple_insert_many_txn(
  1696. txn,
  1697. table="device_lists_stream",
  1698. keys=("stream_id", "user_id", "device_id"),
  1699. values=[
  1700. (stream_id, user_id, device_id)
  1701. for stream_id, device_id in zip(stream_ids, device_ids)
  1702. ],
  1703. )
  1704. def _add_device_outbound_poke_to_stream_txn(
  1705. self,
  1706. txn: LoggingTransaction,
  1707. user_id: str,
  1708. device_id: str,
  1709. hosts: Collection[str],
  1710. stream_ids: List[int],
  1711. context: Optional[Dict[str, str]],
  1712. ) -> None:
  1713. for host in hosts:
  1714. txn.call_after(
  1715. self._device_list_federation_stream_cache.entity_has_changed,
  1716. host,
  1717. stream_ids[-1],
  1718. )
  1719. now = self._clock.time_msec()
  1720. stream_id_iterator = iter(stream_ids)
  1721. encoded_context = json_encoder.encode(context)
  1722. mark_sent = not self.hs.is_mine_id(user_id)
  1723. values = [
  1724. (
  1725. destination,
  1726. next(stream_id_iterator),
  1727. user_id,
  1728. device_id,
  1729. mark_sent,
  1730. now,
  1731. encoded_context if whitelisted_homeserver(destination) else "{}",
  1732. )
  1733. for destination in hosts
  1734. ]
  1735. self.db_pool.simple_insert_many_txn(
  1736. txn,
  1737. table="device_lists_outbound_pokes",
  1738. keys=(
  1739. "destination",
  1740. "stream_id",
  1741. "user_id",
  1742. "device_id",
  1743. "sent",
  1744. "ts",
  1745. "opentracing_context",
  1746. ),
  1747. values=values,
  1748. )
  1749. # debugging for https://github.com/matrix-org/synapse/issues/14251
  1750. if issue_8631_logger.isEnabledFor(logging.DEBUG):
  1751. issue_8631_logger.debug(
  1752. "Recorded outbound pokes for %s:%s with device stream ids %s",
  1753. user_id,
  1754. device_id,
  1755. {
  1756. stream_id: destination
  1757. for (destination, stream_id, _, _, _, _, _) in values
  1758. },
  1759. )
  1760. def _add_device_outbound_room_poke_txn(
  1761. self,
  1762. txn: LoggingTransaction,
  1763. user_id: str,
  1764. device_ids: Iterable[str],
  1765. room_ids: Collection[str],
  1766. stream_ids: List[int],
  1767. context: Dict[str, str],
  1768. ) -> None:
  1769. """Record the user in the room has updated their device."""
  1770. encoded_context = json_encoder.encode(context)
  1771. # The `device_lists_changes_in_room.stream_id` column matches the
  1772. # corresponding `stream_id` of the update in the `device_lists_stream`
  1773. # table, i.e. all rows persisted for the same device update will have
  1774. # the same `stream_id` (but different room IDs).
  1775. self.db_pool.simple_insert_many_txn(
  1776. txn,
  1777. table="device_lists_changes_in_room",
  1778. keys=(
  1779. "user_id",
  1780. "device_id",
  1781. "room_id",
  1782. "stream_id",
  1783. "converted_to_destinations",
  1784. "opentracing_context",
  1785. ),
  1786. values=[
  1787. (
  1788. user_id,
  1789. device_id,
  1790. room_id,
  1791. stream_id,
  1792. # We only need to calculate outbound pokes for local users
  1793. not self.hs.is_mine_id(user_id),
  1794. encoded_context,
  1795. )
  1796. for room_id in room_ids
  1797. for device_id, stream_id in zip(device_ids, stream_ids)
  1798. ],
  1799. )
  1800. async def get_uncoverted_outbound_room_pokes(
  1801. self, start_stream_id: int, start_room_id: str, limit: int = 10
  1802. ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
  1803. """Get device list changes by room that have not yet been handled and
  1804. written to `device_lists_outbound_pokes`.
  1805. Args:
  1806. start_stream_id: Together with `start_room_id`, indicates the position after
  1807. which to return device list changes.
  1808. start_room_id: Together with `start_stream_id`, indicates the position after
  1809. which to return device list changes.
  1810. limit: The maximum number of device list changes to return.
  1811. Returns:
  1812. A list of user ID, device ID, room ID, stream ID and optional opentracing
  1813. context, in order of ascending (stream ID, room ID).
  1814. """
  1815. sql = """
  1816. SELECT user_id, device_id, room_id, stream_id, opentracing_context
  1817. FROM device_lists_changes_in_room
  1818. WHERE
  1819. (stream_id, room_id) > (?, ?) AND
  1820. stream_id <= ? AND
  1821. NOT converted_to_destinations
  1822. ORDER BY stream_id ASC, room_id ASC
  1823. LIMIT ?
  1824. """
  1825. def get_uncoverted_outbound_room_pokes_txn(
  1826. txn: LoggingTransaction,
  1827. ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
  1828. txn.execute(
  1829. sql,
  1830. (
  1831. start_stream_id,
  1832. start_room_id,
  1833. # Avoid returning rows if there may be uncommitted device list
  1834. # changes with smaller stream IDs.
  1835. self._device_list_id_gen.get_current_token(),
  1836. limit,
  1837. ),
  1838. )
  1839. return [
  1840. (
  1841. user_id,
  1842. device_id,
  1843. room_id,
  1844. stream_id,
  1845. db_to_json(opentracing_context),
  1846. )
  1847. for user_id, device_id, room_id, stream_id, opentracing_context in txn
  1848. ]
  1849. return await self.db_pool.runInteraction(
  1850. "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
  1851. )
  1852. async def add_device_list_outbound_pokes(
  1853. self,
  1854. user_id: str,
  1855. device_id: str,
  1856. room_id: str,
  1857. hosts: Collection[str],
  1858. context: Optional[Dict[str, str]],
  1859. ) -> None:
  1860. """Queue the device update to be sent to the given set of hosts,
  1861. calculated from the room ID.
  1862. """
  1863. if not hosts:
  1864. return
  1865. def add_device_list_outbound_pokes_txn(
  1866. txn: LoggingTransaction, stream_ids: List[int]
  1867. ) -> None:
  1868. self._add_device_outbound_poke_to_stream_txn(
  1869. txn,
  1870. user_id=user_id,
  1871. device_id=device_id,
  1872. hosts=hosts,
  1873. stream_ids=stream_ids,
  1874. context=context,
  1875. )
  1876. async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
  1877. return await self.db_pool.runInteraction(
  1878. "add_device_list_outbound_pokes",
  1879. add_device_list_outbound_pokes_txn,
  1880. stream_ids,
  1881. )
  1882. async def add_remote_device_list_to_pending(
  1883. self, user_id: str, device_id: str
  1884. ) -> None:
  1885. """Add a device list update to the table tracking remote device list
  1886. updates during partial joins.
  1887. """
  1888. async with self._device_list_id_gen.get_next() as stream_id:
  1889. await self.db_pool.simple_upsert(
  1890. table="device_lists_remote_pending",
  1891. keyvalues={
  1892. "user_id": user_id,
  1893. "device_id": device_id,
  1894. },
  1895. values={"stream_id": stream_id},
  1896. desc="add_remote_device_list_to_pending",
  1897. )
  1898. async def get_pending_remote_device_list_updates_for_room(
  1899. self, room_id: str
  1900. ) -> Collection[Tuple[str, str]]:
  1901. """Get the set of remote device list updates from the pending table for
  1902. the room.
  1903. """
  1904. min_device_stream_id = await self.db_pool.simple_select_one_onecol(
  1905. table="partial_state_rooms",
  1906. keyvalues={
  1907. "room_id": room_id,
  1908. },
  1909. retcol="device_lists_stream_id",
  1910. desc="get_pending_remote_device_list_updates_for_room_device",
  1911. )
  1912. sql = """
  1913. SELECT user_id, device_id FROM device_lists_remote_pending AS d
  1914. INNER JOIN current_state_events AS c ON
  1915. type = 'm.room.member'
  1916. AND state_key = user_id
  1917. AND membership = 'join'
  1918. WHERE
  1919. room_id = ? AND stream_id > ?
  1920. """
  1921. def get_pending_remote_device_list_updates_for_room_txn(
  1922. txn: LoggingTransaction,
  1923. ) -> Collection[Tuple[str, str]]:
  1924. txn.execute(sql, (room_id, min_device_stream_id))
  1925. return cast(Collection[Tuple[str, str]], txn.fetchall())
  1926. return await self.db_pool.runInteraction(
  1927. "get_pending_remote_device_list_updates_for_room",
  1928. get_pending_remote_device_list_updates_for_room_txn,
  1929. )
  1930. async def get_device_change_last_converted_pos(self) -> Tuple[int, str]:
  1931. """
  1932. Get the position of the last row in `device_list_changes_in_room` that has been
  1933. converted to `device_lists_outbound_pokes`.
  1934. Rows with a strictly greater position where `converted_to_destinations` is
  1935. `FALSE` have not been converted.
  1936. """
  1937. row = await self.db_pool.simple_select_one(
  1938. table="device_lists_changes_converted_stream_position",
  1939. keyvalues={},
  1940. retcols=["stream_id", "room_id"],
  1941. desc="get_device_change_last_converted_pos",
  1942. )
  1943. return row["stream_id"], row["room_id"]
  1944. async def set_device_change_last_converted_pos(
  1945. self,
  1946. stream_id: int,
  1947. room_id: str,
  1948. ) -> None:
  1949. """
  1950. Set the position of the last row in `device_list_changes_in_room` that has been
  1951. converted to `device_lists_outbound_pokes`.
  1952. """
  1953. await self.db_pool.simple_update_one(
  1954. table="device_lists_changes_converted_stream_position",
  1955. keyvalues={},
  1956. updatevalues={"stream_id": stream_id, "room_id": room_id},
  1957. desc="set_device_change_last_converted_pos",
  1958. )