devices.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. import ujson as json
  17. from twisted.internet import defer
  18. from synapse.api.errors import StoreError
  19. from ._base import SQLBaseStore
  20. from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
# Module-level logger, named after this module, used by DeviceStore below.
logger = logging.getLogger(__name__)
  22. class DeviceStore(SQLBaseStore):
  23. def __init__(self, hs):
  24. super(DeviceStore, self).__init__(hs)
  25. self._clock.looping_call(
  26. self._prune_old_outbound_device_pokes, 60 * 60 * 1000
  27. )
  28. self.register_background_index_update(
  29. "device_lists_stream_idx",
  30. index_name="device_lists_stream_user_id",
  31. table="device_lists_stream",
  32. columns=["user_id", "device_id"],
  33. )
  34. @defer.inlineCallbacks
  35. def store_device(self, user_id, device_id,
  36. initial_device_display_name):
  37. """Ensure the given device is known; add it to the store if not
  38. Args:
  39. user_id (str): id of user associated with the device
  40. device_id (str): id of device
  41. initial_device_display_name (str): initial displayname of the
  42. device. Ignored if device exists.
  43. Returns:
  44. defer.Deferred: boolean whether the device was inserted or an
  45. existing device existed with that ID.
  46. """
  47. try:
  48. inserted = yield self._simple_insert(
  49. "devices",
  50. values={
  51. "user_id": user_id,
  52. "device_id": device_id,
  53. "display_name": initial_device_display_name
  54. },
  55. desc="store_device",
  56. or_ignore=True,
  57. )
  58. defer.returnValue(inserted)
  59. except Exception as e:
  60. logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
  61. " display_name=%s(%r) failed: %s",
  62. type(device_id).__name__, device_id,
  63. type(user_id).__name__, user_id,
  64. type(initial_device_display_name).__name__,
  65. initial_device_display_name, e)
  66. raise StoreError(500, "Problem storing device.")
  67. def get_device(self, user_id, device_id):
  68. """Retrieve a device.
  69. Args:
  70. user_id (str): The ID of the user which owns the device
  71. device_id (str): The ID of the device to retrieve
  72. Returns:
  73. defer.Deferred for a dict containing the device information
  74. Raises:
  75. StoreError: if the device is not found
  76. """
  77. return self._simple_select_one(
  78. table="devices",
  79. keyvalues={"user_id": user_id, "device_id": device_id},
  80. retcols=("user_id", "device_id", "display_name"),
  81. desc="get_device",
  82. )
  83. def delete_device(self, user_id, device_id):
  84. """Delete a device.
  85. Args:
  86. user_id (str): The ID of the user which owns the device
  87. device_id (str): The ID of the device to delete
  88. Returns:
  89. defer.Deferred
  90. """
  91. return self._simple_delete_one(
  92. table="devices",
  93. keyvalues={"user_id": user_id, "device_id": device_id},
  94. desc="delete_device",
  95. )
  96. def delete_devices(self, user_id, device_ids):
  97. """Deletes several devices.
  98. Args:
  99. user_id (str): The ID of the user which owns the devices
  100. device_ids (list): The IDs of the devices to delete
  101. Returns:
  102. defer.Deferred
  103. """
  104. return self._simple_delete_many(
  105. table="devices",
  106. column="device_id",
  107. iterable=device_ids,
  108. keyvalues={"user_id": user_id},
  109. desc="delete_devices",
  110. )
  111. def update_device(self, user_id, device_id, new_display_name=None):
  112. """Update a device.
  113. Args:
  114. user_id (str): The ID of the user which owns the device
  115. device_id (str): The ID of the device to update
  116. new_display_name (str|None): new displayname for device; None
  117. to leave unchanged
  118. Raises:
  119. StoreError: if the device is not found
  120. Returns:
  121. defer.Deferred
  122. """
  123. updates = {}
  124. if new_display_name is not None:
  125. updates["display_name"] = new_display_name
  126. if not updates:
  127. return defer.succeed(None)
  128. return self._simple_update_one(
  129. table="devices",
  130. keyvalues={"user_id": user_id, "device_id": device_id},
  131. updatevalues=updates,
  132. desc="update_device",
  133. )
  134. @defer.inlineCallbacks
  135. def get_devices_by_user(self, user_id):
  136. """Retrieve all of a user's registered devices.
  137. Args:
  138. user_id (str):
  139. Returns:
  140. defer.Deferred: resolves to a dict from device_id to a dict
  141. containing "device_id", "user_id" and "display_name" for each
  142. device.
  143. """
  144. devices = yield self._simple_select_list(
  145. table="devices",
  146. keyvalues={"user_id": user_id},
  147. retcols=("user_id", "device_id", "display_name"),
  148. desc="get_devices_by_user"
  149. )
  150. defer.returnValue({d["device_id"]: d for d in devices})
  151. @cached(max_entries=10000)
  152. def get_device_list_last_stream_id_for_remote(self, user_id):
  153. """Get the last stream_id we got for a user. May be None if we haven't
  154. got any information for them.
  155. """
  156. return self._simple_select_one_onecol(
  157. table="device_lists_remote_extremeties",
  158. keyvalues={"user_id": user_id},
  159. retcol="stream_id",
  160. desc="get_device_list_remote_extremity",
  161. allow_none=True,
  162. )
  163. @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
  164. list_name="user_ids", inlineCallbacks=True)
  165. def get_device_list_last_stream_id_for_remotes(self, user_ids):
  166. rows = yield self._simple_select_many_batch(
  167. table="device_lists_remote_extremeties",
  168. column="user_id",
  169. iterable=user_ids,
  170. retcols=("user_id", "stream_id",),
  171. desc="get_user_devices_from_cache",
  172. )
  173. results = {user_id: None for user_id in user_ids}
  174. results.update({
  175. row["user_id"]: row["stream_id"] for row in rows
  176. })
  177. defer.returnValue(results)
  178. @defer.inlineCallbacks
  179. def mark_remote_user_device_list_as_unsubscribed(self, user_id):
  180. """Mark that we no longer track device lists for remote user.
  181. """
  182. yield self._simple_delete(
  183. table="device_lists_remote_extremeties",
  184. keyvalues={
  185. "user_id": user_id,
  186. },
  187. desc="mark_remote_user_device_list_as_unsubscribed",
  188. )
  189. self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
  190. def update_remote_device_list_cache_entry(self, user_id, device_id, content,
  191. stream_id):
  192. """Updates a single user's device in the cache.
  193. """
  194. return self.runInteraction(
  195. "update_remote_device_list_cache_entry",
  196. self._update_remote_device_list_cache_entry_txn,
  197. user_id, device_id, content, stream_id,
  198. )
  199. def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
  200. content, stream_id):
  201. self._simple_upsert_txn(
  202. txn,
  203. table="device_lists_remote_cache",
  204. keyvalues={
  205. "user_id": user_id,
  206. "device_id": device_id,
  207. },
  208. values={
  209. "content": json.dumps(content),
  210. }
  211. )
  212. txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
  213. txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
  214. txn.call_after(
  215. self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
  216. )
  217. self._simple_upsert_txn(
  218. txn,
  219. table="device_lists_remote_extremeties",
  220. keyvalues={
  221. "user_id": user_id,
  222. },
  223. values={
  224. "stream_id": stream_id,
  225. }
  226. )
  227. def update_remote_device_list_cache(self, user_id, devices, stream_id):
  228. """Replace the cache of the remote user's devices.
  229. """
  230. return self.runInteraction(
  231. "update_remote_device_list_cache",
  232. self._update_remote_device_list_cache_txn,
  233. user_id, devices, stream_id,
  234. )
  235. def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
  236. stream_id):
  237. self._simple_delete_txn(
  238. txn,
  239. table="device_lists_remote_cache",
  240. keyvalues={
  241. "user_id": user_id,
  242. },
  243. )
  244. self._simple_insert_many_txn(
  245. txn,
  246. table="device_lists_remote_cache",
  247. values=[
  248. {
  249. "user_id": user_id,
  250. "device_id": content["device_id"],
  251. "content": json.dumps(content),
  252. }
  253. for content in devices
  254. ]
  255. )
  256. txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
  257. txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
  258. txn.call_after(
  259. self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
  260. )
  261. self._simple_upsert_txn(
  262. txn,
  263. table="device_lists_remote_extremeties",
  264. keyvalues={
  265. "user_id": user_id,
  266. },
  267. values={
  268. "stream_id": stream_id,
  269. }
  270. )
  271. def get_devices_by_remote(self, destination, from_stream_id):
  272. """Get stream of updates to send to remote servers
  273. Returns:
  274. (int, list[dict]): current stream id and list of updates
  275. """
  276. now_stream_id = self._device_list_id_gen.get_current_token()
  277. has_changed = self._device_list_federation_stream_cache.has_entity_changed(
  278. destination, int(from_stream_id)
  279. )
  280. if not has_changed:
  281. return (now_stream_id, [])
  282. return self.runInteraction(
  283. "get_devices_by_remote", self._get_devices_by_remote_txn,
  284. destination, from_stream_id, now_stream_id,
  285. )
  286. def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
  287. now_stream_id):
  288. sql = """
  289. SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
  290. WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
  291. GROUP BY user_id, device_id
  292. """
  293. txn.execute(
  294. sql, (destination, from_stream_id, now_stream_id, False)
  295. )
  296. rows = txn.fetchall()
  297. if not rows:
  298. return (now_stream_id, [])
  299. # maps (user_id, device_id) -> stream_id
  300. query_map = {(r[0], r[1]): r[2] for r in rows}
  301. devices = self._get_e2e_device_keys_txn(
  302. txn, query_map.keys(), include_all_devices=True
  303. )
  304. prev_sent_id_sql = """
  305. SELECT coalesce(max(stream_id), 0) as stream_id
  306. FROM device_lists_outbound_pokes
  307. WHERE destination = ? AND user_id = ? AND stream_id <= ?
  308. """
  309. results = []
  310. for user_id, user_devices in devices.iteritems():
  311. # The prev_id for the first row is always the last row before
  312. # `from_stream_id`
  313. txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
  314. rows = txn.fetchall()
  315. prev_id = rows[0][0]
  316. for device_id, device in user_devices.iteritems():
  317. stream_id = query_map[(user_id, device_id)]
  318. result = {
  319. "user_id": user_id,
  320. "device_id": device_id,
  321. "prev_id": [prev_id] if prev_id else [],
  322. "stream_id": stream_id,
  323. }
  324. prev_id = stream_id
  325. key_json = device.get("key_json", None)
  326. if key_json:
  327. result["keys"] = json.loads(key_json)
  328. device_display_name = device.get("device_display_name", None)
  329. if device_display_name:
  330. result["device_display_name"] = device_display_name
  331. results.append(result)
  332. return (now_stream_id, results)
  333. @defer.inlineCallbacks
  334. def get_user_devices_from_cache(self, query_list):
  335. """Get the devices (and keys if any) for remote users from the cache.
  336. Args:
  337. query_list(list): List of (user_id, device_ids), if device_ids is
  338. falsey then return all device ids for that user.
  339. Returns:
  340. (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
  341. a set of user_ids and results_map is a mapping of
  342. user_id -> device_id -> device_info
  343. """
  344. user_ids = set(user_id for user_id, _ in query_list)
  345. user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
  346. user_ids_in_cache = set(
  347. user_id for user_id, stream_id in user_map.items() if stream_id
  348. )
  349. user_ids_not_in_cache = user_ids - user_ids_in_cache
  350. results = {}
  351. for user_id, device_id in query_list:
  352. if user_id not in user_ids_in_cache:
  353. continue
  354. if device_id:
  355. device = yield self._get_cached_user_device(user_id, device_id)
  356. results.setdefault(user_id, {})[device_id] = device
  357. else:
  358. results[user_id] = yield self._get_cached_devices_for_user(user_id)
  359. defer.returnValue((user_ids_not_in_cache, results))
  360. @cachedInlineCallbacks(num_args=2, tree=True)
  361. def _get_cached_user_device(self, user_id, device_id):
  362. content = yield self._simple_select_one_onecol(
  363. table="device_lists_remote_cache",
  364. keyvalues={
  365. "user_id": user_id,
  366. "device_id": device_id,
  367. },
  368. retcol="content",
  369. desc="_get_cached_user_device",
  370. )
  371. defer.returnValue(json.loads(content))
  372. @cachedInlineCallbacks()
  373. def _get_cached_devices_for_user(self, user_id):
  374. devices = yield self._simple_select_list(
  375. table="device_lists_remote_cache",
  376. keyvalues={
  377. "user_id": user_id,
  378. },
  379. retcols=("device_id", "content"),
  380. desc="_get_cached_devices_for_user",
  381. )
  382. defer.returnValue({
  383. device["device_id"]: json.loads(device["content"])
  384. for device in devices
  385. })
  386. def get_devices_with_keys_by_user(self, user_id):
  387. """Get all devices (with any device keys) for a user
  388. Returns:
  389. (stream_id, devices)
  390. """
  391. return self.runInteraction(
  392. "get_devices_with_keys_by_user",
  393. self._get_devices_with_keys_by_user_txn, user_id,
  394. )
  395. def _get_devices_with_keys_by_user_txn(self, txn, user_id):
  396. now_stream_id = self._device_list_id_gen.get_current_token()
  397. devices = self._get_e2e_device_keys_txn(
  398. txn, [(user_id, None)], include_all_devices=True
  399. )
  400. if devices:
  401. user_devices = devices[user_id]
  402. results = []
  403. for device_id, device in user_devices.iteritems():
  404. result = {
  405. "device_id": device_id,
  406. }
  407. key_json = device.get("key_json", None)
  408. if key_json:
  409. result["keys"] = json.loads(key_json)
  410. device_display_name = device.get("device_display_name", None)
  411. if device_display_name:
  412. result["device_display_name"] = device_display_name
  413. results.append(result)
  414. return now_stream_id, results
  415. return now_stream_id, []
  416. def mark_as_sent_devices_by_remote(self, destination, stream_id):
  417. """Mark that updates have successfully been sent to the destination.
  418. """
  419. return self.runInteraction(
  420. "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
  421. destination, stream_id,
  422. )
  423. def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
  424. # First we DELETE all rows such that only the latest row for each
  425. # (destination, user_id is left. We do this by selecting first and
  426. # deleting.
  427. sql = """
  428. SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
  429. WHERE destination = ? AND stream_id <= ?
  430. GROUP BY user_id
  431. HAVING count(*) > 1
  432. """
  433. txn.execute(sql, (destination, stream_id,))
  434. rows = txn.fetchall()
  435. sql = """
  436. DELETE FROM device_lists_outbound_pokes
  437. WHERE destination = ? AND user_id = ? AND stream_id < ?
  438. """
  439. txn.executemany(
  440. sql, ((destination, row[0], row[1],) for row in rows)
  441. )
  442. # Mark everything that is left as sent
  443. sql = """
  444. UPDATE device_lists_outbound_pokes SET sent = ?
  445. WHERE destination = ? AND stream_id <= ?
  446. """
  447. txn.execute(sql, (True, destination, stream_id,))
  448. @defer.inlineCallbacks
  449. def get_user_whose_devices_changed(self, from_key):
  450. """Get set of users whose devices have changed since `from_key`.
  451. """
  452. from_key = int(from_key)
  453. changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
  454. if changed is not None:
  455. defer.returnValue(set(changed))
  456. sql = """
  457. SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
  458. """
  459. rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
  460. defer.returnValue(set(row[0] for row in rows))
  461. def get_all_device_list_changes_for_remotes(self, from_key):
  462. """Return a list of `(stream_id, user_id, destination)` which is the
  463. combined list of changes to devices, and which destinations need to be
  464. poked. `destination` may be None if no destinations need to be poked.
  465. """
  466. sql = """
  467. SELECT stream_id, user_id, destination FROM device_lists_stream
  468. LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
  469. WHERE stream_id > ?
  470. """
  471. return self._execute(
  472. "get_all_device_list_changes_for_remotes", None,
  473. sql, from_key,
  474. )
  475. @defer.inlineCallbacks
  476. def add_device_change_to_streams(self, user_id, device_ids, hosts):
  477. """Persist that a user's devices have been updated, and which hosts
  478. (if any) should be poked.
  479. """
  480. with self._device_list_id_gen.get_next() as stream_id:
  481. yield self.runInteraction(
  482. "add_device_change_to_streams", self._add_device_change_txn,
  483. user_id, device_ids, hosts, stream_id,
  484. )
  485. defer.returnValue(stream_id)
  486. def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
  487. now = self._clock.time_msec()
  488. txn.call_after(
  489. self._device_list_stream_cache.entity_has_changed,
  490. user_id, stream_id,
  491. )
  492. for host in hosts:
  493. txn.call_after(
  494. self._device_list_federation_stream_cache.entity_has_changed,
  495. host, stream_id,
  496. )
  497. # Delete older entries in the table, as we really only care about
  498. # when the latest change happened.
  499. txn.executemany(
  500. """
  501. DELETE FROM device_lists_stream
  502. WHERE user_id = ? AND device_id = ? AND stream_id < ?
  503. """,
  504. [(user_id, device_id, stream_id) for device_id in device_ids]
  505. )
  506. self._simple_insert_many_txn(
  507. txn,
  508. table="device_lists_stream",
  509. values=[
  510. {
  511. "stream_id": stream_id,
  512. "user_id": user_id,
  513. "device_id": device_id,
  514. }
  515. for device_id in device_ids
  516. ]
  517. )
  518. self._simple_insert_many_txn(
  519. txn,
  520. table="device_lists_outbound_pokes",
  521. values=[
  522. {
  523. "destination": destination,
  524. "stream_id": stream_id,
  525. "user_id": user_id,
  526. "device_id": device_id,
  527. "sent": False,
  528. "ts": now,
  529. }
  530. for destination in hosts
  531. for device_id in device_ids
  532. ]
  533. )
  534. def get_device_stream_token(self):
  535. return self._device_list_id_gen.get_current_token()
  536. def _prune_old_outbound_device_pokes(self):
  537. """Delete old entries out of the device_lists_outbound_pokes to ensure
  538. that we don't fill up due to dead servers. We keep one entry per
  539. (destination, user_id) tuple to ensure that the prev_ids remain correct
  540. if the server does come back.
  541. """
  542. yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
  543. def _prune_txn(txn):
  544. select_sql = """
  545. SELECT destination, user_id, max(stream_id) as stream_id
  546. FROM device_lists_outbound_pokes
  547. GROUP BY destination, user_id
  548. HAVING min(ts) < ? AND count(*) > 1
  549. """
  550. txn.execute(select_sql, (yesterday,))
  551. rows = txn.fetchall()
  552. if not rows:
  553. return
  554. delete_sql = """
  555. DELETE FROM device_lists_outbound_pokes
  556. WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
  557. """
  558. txn.executemany(
  559. delete_sql,
  560. (
  561. (yesterday, row[0], row[1], row[2])
  562. for row in rows
  563. )
  564. )
  565. logger.info("Pruned %d device list outbound pokes", txn.rowcount)
  566. return self.runInteraction(
  567. "_prune_old_outbound_device_pokes", _prune_txn
  568. )