# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    FrozenSet,
    Iterable,
    List,
    Mapping,
    Optional,
    Tuple,
    cast,
)

from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
    AbstractStreamIdGenerator,
    MultiWriterIdGenerator,
    StreamIdGenerator,
)
from synapse.types import JsonDict, JsonMapping
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)
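
# This store manages global (per-user) account data in the `account_data`
# table and per-room account data in `room_account_data`, with every write
# stamped with a `stream_id` drawn from a stream shared with
# `room_tags_revisions`. The `m.ignored_user_list` account data type is
# additionally denormalised into the `ignored_users` table so that "who
# ignores whom" can be answered cheaply in both directions.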

class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._can_write_to_account_data = (
            self._instance_name in hs.config.worker.writers.account_data
        )

        self._account_data_id_gen: AbstractStreamIdGenerator

        if isinstance(database.engine, PostgresEngine):
            self._account_data_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                notifier=hs.get_replication_notifier(),
                stream_name="account_data",
                instance_name=self._instance_name,
                tables=[
                    ("room_account_data", "instance_name", "stream_id"),
                    ("room_tags_revisions", "instance_name", "stream_id"),
                    ("account_data", "instance_name", "stream_id"),
                ],
                sequence_name="account_data_sequence",
                writers=hs.config.worker.writers.account_data,
            )
        else:
            # Multiple writers are not supported for SQLite.
            #
            # We shouldn't be running in worker mode with SQLite, but it's
            # useful to support it for unit tests.
            self._account_data_id_gen = StreamIdGenerator(
                db_conn,
                hs.get_replication_notifier(),
                "room_account_data",
                "stream_id",
                extra_tables=[
                    ("account_data", "stream_id"),
                    ("room_tags_revisions", "stream_id"),
                ],
                is_writer=self._instance_name in hs.config.worker.writers.account_data,
            )

        account_max = self.get_max_account_data_stream_id()
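        # Tracks, per user, the most recent stream position at which their
        # account data (or room tags) changed. This lets the get_updated_*
        # methods below short-circuit "has anything changed since X?" checks
        # without hitting the database.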
        self._account_data_stream_cache = StreamChangeCache(
            "AccountDataAndTagsChangeCache", account_max
        )

        self.db_pool.updates.register_background_index_update(
            update_name="room_account_data_index_room_id",
            index_name="room_account_data_room_id",
            table="room_account_data",
            columns=("room_id",),
        )

        self.db_pool.updates.register_background_update_handler(
            "delete_account_data_for_deactivated_users",
            self._delete_account_data_for_deactivated_users,
        )

    def get_max_account_data_stream_id(self) -> int:
        """Get the current max stream ID for the account data stream.

        Returns:
            The maximum stream ID.
        """
        return self._account_data_id_gen.get_current_token()
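
    # The getters below are cached; every write path in this class therefore
    # invalidates (or prefills) the matching entries, and
    # process_replication_rows does the same for writes made on other workers.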
    @cached()
    async def get_global_account_data_for_user(
        self, user_id: str
    ) -> Mapping[str, JsonMapping]:
        """
        Get all the global client account_data for a user.

        If experimental MSC3391 support is enabled, any entries with an empty
        content body are excluded, as this means they have been deleted.

        Args:
            user_id: The user to get the account_data for.

        Returns:
            The global account_data.
        """

        def get_global_account_data_for_user(
            txn: LoggingTransaction,
        ) -> Dict[str, JsonDict]:
            # The `content != '{}'` condition below prevents us from using
            # `simple_select_list_txn` here, as it doesn't support conditions
            # other than 'equals'.
            sql = """
                SELECT account_data_type, content FROM account_data
                WHERE user_id = ?
            """

            # If experimental MSC3391 support is enabled, then account data entries
            # with an empty content are considered "deleted". So skip adding them to
            # the results.
            if self.hs.config.experimental.msc3391_enabled:
                sql += " AND content != '{}'"

            txn.execute(sql, (user_id,))

            return {
                account_data_type: db_to_json(content)
                for account_data_type, content in txn
            }

        return await self.db_pool.runInteraction(
            "get_global_account_data_for_user", get_global_account_data_for_user
        )

    @cached()
    async def get_room_account_data_for_user(
        self, user_id: str
    ) -> Mapping[str, Mapping[str, JsonMapping]]:
        """
        Get all of the per-room client account_data for a user.

        If experimental MSC3391 support is enabled, any entries with an empty
        content body are excluded, as this means they have been deleted.

        Args:
            user_id: The user to get the account_data for.

        Returns:
            A dict mapping from room_id string to per-room account_data dicts.
        """

        def get_room_account_data_for_user_txn(
            txn: LoggingTransaction,
        ) -> Dict[str, Dict[str, JsonDict]]:
            # The `content != '{}'` condition below prevents us from using
            # `simple_select_list_txn` here, as it doesn't support conditions
            # other than 'equals'.
            sql = """
                SELECT room_id, account_data_type, content FROM room_account_data
                WHERE user_id = ?
            """

            # If experimental MSC3391 support is enabled, then account data entries
            # with an empty content are considered "deleted". So skip adding them to
            # the results.
            if self.hs.config.experimental.msc3391_enabled:
                sql += " AND content != '{}'"

            txn.execute(sql, (user_id,))

            by_room: Dict[str, Dict[str, JsonDict]] = {}
            for room_id, account_data_type, content in txn:
                room_data = by_room.setdefault(room_id, {})
                room_data[account_data_type] = db_to_json(content)

            return by_room

        return await self.db_pool.runInteraction(
            "get_room_account_data_for_user_txn", get_room_account_data_for_user_txn
        )

    @cached(num_args=2, max_entries=5000, tree=True)
    async def get_global_account_data_by_type_for_user(
        self, user_id: str, data_type: str
    ) -> Optional[JsonMapping]:
        """
        Get a single piece of global account_data for a user.

        Args:
            user_id: The user to get the account_data for.
            data_type: The type of account_data to get.

        Returns:
            The account data, or None if there is none of the given type.
        """
        result = await self.db_pool.simple_select_one_onecol(
            table="account_data",
            keyvalues={"user_id": user_id, "account_data_type": data_type},
            retcol="content",
            desc="get_global_account_data_by_type_for_user",
            allow_none=True,
        )

        if result:
            return db_to_json(result)
        else:
            return None

    async def get_latest_stream_id_for_global_account_data_by_type_for_user(
        self, user_id: str, data_type: str
    ) -> Optional[int]:
        """
        Returns:
            The stream ID of the account data,
            or None if there is no such account data.
        """

        def get_latest_stream_id_for_global_account_data_by_type_for_user_txn(
            txn: LoggingTransaction,
        ) -> Optional[int]:
            sql = """
                SELECT stream_id FROM account_data
                WHERE user_id = ? AND account_data_type = ?
                ORDER BY stream_id DESC
                LIMIT 1
            """
            txn.execute(sql, (user_id, data_type))

            row = txn.fetchone()
            if row:
                return row[0]
            else:
                return None

        return await self.db_pool.runInteraction(
            "get_latest_stream_id_for_global_account_data_by_type_for_user",
            get_latest_stream_id_for_global_account_data_by_type_for_user_txn,
        )

    @cached(num_args=2, tree=True)
    async def get_account_data_for_room(
        self, user_id: str, room_id: str
    ) -> Mapping[str, JsonMapping]:
        """Get all the client account_data for a user for a room.

        Args:
            user_id: The user to get the account_data for.
            room_id: The room to get the account_data for.

        Returns:
            A dict of the room account_data.
        """

        def get_account_data_for_room_txn(
            txn: LoggingTransaction,
        ) -> Dict[str, JsonDict]:
            rows = self.db_pool.simple_select_list_txn(
                txn,
                "room_account_data",
                {"user_id": user_id, "room_id": room_id},
                ["account_data_type", "content"],
            )

            return {
                row["account_data_type"]: db_to_json(row["content"]) for row in rows
            }

        return await self.db_pool.runInteraction(
            "get_account_data_for_room", get_account_data_for_room_txn
        )

    @cached(num_args=3, max_entries=5000, tree=True)
    async def get_account_data_for_room_and_type(
        self, user_id: str, room_id: str, account_data_type: str
    ) -> Optional[JsonMapping]:
        """Get the client account_data of given type for a user for a room.

        Args:
            user_id: The user to get the account_data for.
            room_id: The room to get the account_data for.
            account_data_type: The account data type to get.

        Returns:
            The room account_data for that type, or None if there isn't any set.
        """

        def get_account_data_for_room_and_type_txn(
            txn: LoggingTransaction,
        ) -> Optional[JsonDict]:
            content_json = self.db_pool.simple_select_one_onecol_txn(
                txn,
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                retcol="content",
                allow_none=True,
            )

            return db_to_json(content_json) if content_json else None

        return await self.db_pool.runInteraction(
            "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
        )

    async def get_updated_global_account_data(
        self, last_id: int, current_id: int, limit: int
    ) -> List[Tuple[int, str, str]]:
        """Get the global account_data that has changed, for the account_data stream.

        Args:
            last_id: the last stream_id from the previous batch.
            current_id: the maximum stream_id to return up to.
            limit: the maximum number of rows to return.

        Returns:
            A list of tuples of stream_id int, user_id string,
            and type string.
        """
        if last_id == current_id:
            return []

        def get_updated_global_account_data_txn(
            txn: LoggingTransaction,
        ) -> List[Tuple[int, str, str]]:
            sql = (
                "SELECT stream_id, user_id, account_data_type"
                " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_id, current_id, limit))
            return cast(List[Tuple[int, str, str]], txn.fetchall())

        return await self.db_pool.runInteraction(
            "get_updated_global_account_data", get_updated_global_account_data_txn
        )

    async def get_updated_room_account_data(
        self, last_id: int, current_id: int, limit: int
    ) -> List[Tuple[int, str, str, str]]:
        """Get the room account_data that has changed, for the account_data stream.

        Args:
            last_id: the last stream_id from the previous batch.
            current_id: the maximum stream_id to return up to.
            limit: the maximum number of rows to return.

        Returns:
            A list of tuples of stream_id int, user_id string,
            room_id string and type string.
        """
        if last_id == current_id:
            return []

        def get_updated_room_account_data_txn(
            txn: LoggingTransaction,
        ) -> List[Tuple[int, str, str, str]]:
            sql = (
                "SELECT stream_id, user_id, room_id, account_data_type"
                " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_id, current_id, limit))
            return cast(List[Tuple[int, str, str, str]], txn.fetchall())

        return await self.db_pool.runInteraction(
            "get_updated_room_account_data", get_updated_room_account_data_txn
        )

    async def get_updated_global_account_data_for_user(
        self, user_id: str, stream_id: int
    ) -> Mapping[str, JsonMapping]:
        """Get all the global account_data that's changed for a user.

        Args:
            user_id: The user to get the account_data for.
            stream_id: The point in the stream since which to get updates.

        Returns:
            A dict of global account_data.
        """

        def get_updated_global_account_data_for_user(
            txn: LoggingTransaction,
        ) -> Dict[str, JsonDict]:
            sql = """
                SELECT account_data_type, content FROM account_data
                WHERE user_id = ? AND stream_id > ?
            """
            txn.execute(sql, (user_id, stream_id))

            return {row[0]: db_to_json(row[1]) for row in txn}

        changed = self._account_data_stream_cache.has_entity_changed(
            user_id, int(stream_id)
        )
        if not changed:
            return {}

        return await self.db_pool.runInteraction(
            "get_updated_global_account_data_for_user",
            get_updated_global_account_data_for_user,
        )

    async def get_updated_room_account_data_for_user(
        self, user_id: str, stream_id: int
    ) -> Dict[str, Dict[str, JsonDict]]:
        """Get all the room account_data that's changed for a user.

        Args:
            user_id: The user to get the account_data for.
            stream_id: The point in the stream since which to get updates.

        Returns:
            A dict mapping from room_id string to per room account_data dicts.
        """

        def get_updated_room_account_data_for_user_txn(
            txn: LoggingTransaction,
        ) -> Dict[str, Dict[str, JsonDict]]:
            sql = """
                SELECT room_id, account_data_type, content FROM room_account_data
                WHERE user_id = ? AND stream_id > ?
            """
            txn.execute(sql, (user_id, stream_id))

            account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
            for row in txn:
                room_account_data = account_data_by_room.setdefault(row[0], {})
                room_account_data[row[1]] = db_to_json(row[2])

            return account_data_by_room

        changed = self._account_data_stream_cache.has_entity_changed(
            user_id, int(stream_id)
        )
        if not changed:
            return {}

        return await self.db_pool.runInteraction(
            "get_updated_room_account_data_for_user",
            get_updated_room_account_data_for_user_txn,
        )

    @cached(max_entries=5000, iterable=True)
    async def ignored_by(self, user_id: str) -> FrozenSet[str]:
        """
        Get users which ignore the given user.

        Args:
            user_id: The user ID which might be ignored.

        Returns:
            The user IDs which ignore the given user.
        """
        return frozenset(
            await self.db_pool.simple_select_onecol(
                table="ignored_users",
                keyvalues={"ignored_user_id": user_id},
                retcol="ignorer_user_id",
                desc="ignored_by",
            )
        )

    @cached(max_entries=5000, iterable=True)
    async def ignored_users(self, user_id: str) -> FrozenSet[str]:
        """
        Get users which the given user ignores.

        Args:
            user_id: The user ID which is making the request.

        Returns:
            The user IDs which are ignored by the given user.
        """
        return frozenset(
            await self.db_pool.simple_select_onecol(
                table="ignored_users",
                keyvalues={"ignorer_user_id": user_id},
                retcol="ignored_user_id",
                desc="ignored_users",
            )
        )
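
    # Replication: when an account_data writer persists a change, the row is
    # sent down the account_data stream and every other worker invalidates its
    # local caches for the affected user here.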
    def process_replication_rows(
        self,
        stream_name: str,
        instance_name: str,
        token: int,
        rows: Iterable[Any],
    ) -> None:
        if stream_name == AccountDataStream.NAME:
            for row in rows:
                if not row.room_id:
                    self.get_global_account_data_by_type_for_user.invalidate(
                        (row.user_id, row.data_type)
                    )
                self.get_global_account_data_for_user.invalidate((row.user_id,))
                self.get_room_account_data_for_user.invalidate((row.user_id,))
                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
                self.get_account_data_for_room_and_type.invalidate(
                    (row.user_id, row.room_id, row.data_type)
                )
                self._account_data_stream_cache.entity_has_changed(row.user_id, token)

        super().process_replication_rows(stream_name, instance_name, token, rows)

    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        if stream_name == AccountDataStream.NAME:
            self._account_data_id_gen.advance(instance_name, token)
        super().process_replication_position(stream_name, instance_name, token)

    async def add_account_data_to_room(
        self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
    ) -> int:
        """Add some account_data to a room for a user.

        Args:
            user_id: The user to add the account_data for.
            room_id: The room to add the account_data to.
            account_data_type: The type of account_data to add.
            content: A json object to associate with the account_data.

        Returns:
            The maximum stream ID.
        """
        assert self._can_write_to_account_data

        content_json = json_encoder.encode(content)

        async with self._account_data_id_gen.get_next() as next_id:
            await self.db_pool.simple_upsert(
                desc="add_room_account_data",
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                values={"stream_id": next_id, "content": content_json},
            )

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_room_account_data_for_user.invalidate((user_id,))
            self.get_account_data_for_room.invalidate((user_id, room_id))
            self.get_account_data_for_room_and_type.prefill(
                (user_id, room_id, account_data_type), content
            )

        return self._account_data_id_gen.get_current_token()
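
    # Deletions below are implemented as updates: rather than removing a row,
    # the remove_* methods overwrite its content with '{}' and stamp it with a
    # fresh stream_id, so the "deletion" still flows down the account data
    # stream to sync clients and other workers.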
    async def remove_account_data_for_room(
        self, user_id: str, room_id: str, account_data_type: str
    ) -> int:
        """Delete the room account data for the user of a given type.

        Args:
            user_id: The user to remove account_data for.
            room_id: The room ID to scope the request to.
            account_data_type: The account data type to delete.

        Returns:
            The maximum stream position.
        """
        assert self._can_write_to_account_data

        def _remove_account_data_for_room_txn(
            txn: LoggingTransaction, next_id: int
        ) -> bool:
            """
            Args:
                txn: The transaction object.
                next_id: The stream_id to update any existing rows to.

            Returns:
                True if an entry in room_account_data had its content set to '{}',
                otherwise False. This informs callers of whether there actually was an
                existing room account data entry to delete, or if the call was a no-op.
            """
            # We can't use `simple_update` as it doesn't have the ability to specify
            # where clauses other than '=', which we need for `content != '{}'` below.
            sql = """
                UPDATE room_account_data
                SET stream_id = ?, content = '{}'
                WHERE user_id = ?
                AND room_id = ?
                AND account_data_type = ?
                AND content != '{}'
            """
            txn.execute(
                sql,
                (next_id, user_id, room_id, account_data_type),
            )

            # Return true if any rows were updated.
            return txn.rowcount != 0

        async with self._account_data_id_gen.get_next() as next_id:
            row_updated = await self.db_pool.runInteraction(
                "remove_account_data_for_room",
                _remove_account_data_for_room_txn,
                next_id,
            )

            if row_updated:
                self._account_data_stream_cache.entity_has_changed(user_id, next_id)
                self.get_room_account_data_for_user.invalidate((user_id,))
                self.get_account_data_for_room.invalidate((user_id, room_id))
                self.get_account_data_for_room_and_type.prefill(
                    (user_id, room_id, account_data_type), {}
                )

        return self._account_data_id_gen.get_current_token()

    async def add_account_data_for_user(
        self, user_id: str, account_data_type: str, content: JsonDict
    ) -> int:
        """Add some global account_data for a user.

        Args:
            user_id: The user to add the account_data for.
            account_data_type: The type of account_data to add.
            content: A json object to associate with the account_data.

        Returns:
            The maximum stream ID.
        """
        assert self._can_write_to_account_data

        async with self._account_data_id_gen.get_next() as next_id:
            await self.db_pool.runInteraction(
                "add_user_account_data",
                self._add_account_data_for_user,
                next_id,
                user_id,
                account_data_type,
                content,
            )

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_global_account_data_for_user.invalidate((user_id,))
            self.get_global_account_data_by_type_for_user.invalidate(
                (user_id, account_data_type)
            )

        return self._account_data_id_gen.get_current_token()

    def _add_account_data_for_user(
        self,
        txn: LoggingTransaction,
        next_id: int,
        user_id: str,
        account_data_type: str,
        content: JsonDict,
    ) -> None:
        content_json = json_encoder.encode(content)

        self.db_pool.simple_upsert_txn(
            txn,
            table="account_data",
            keyvalues={"user_id": user_id, "account_data_type": account_data_type},
            values={"stream_id": next_id, "content": content_json},
        )

        # Ignored users get denormalized into a separate table as an optimisation.
        if account_data_type != AccountDataTypes.IGNORED_USER_LIST:
            return

        # Insert / delete to sync the list of ignored users.
        previously_ignored_users = set(
            self.db_pool.simple_select_onecol_txn(
                txn,
                table="ignored_users",
                keyvalues={"ignorer_user_id": user_id},
                retcol="ignored_user_id",
            )
        )

        # If the data is invalid, no one is ignored.
        ignored_users_content = content.get("ignored_users", {})
        if isinstance(ignored_users_content, dict):
            currently_ignored_users = set(ignored_users_content)
        else:
            currently_ignored_users = set()

        # If the data has not changed, nothing to do.
        if previously_ignored_users == currently_ignored_users:
            return

        # Delete entries which are no longer ignored.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ignored_users",
            column="ignored_user_id",
            values=previously_ignored_users - currently_ignored_users,
            keyvalues={"ignorer_user_id": user_id},
        )

        # Add entries which are newly ignored.
        self.db_pool.simple_insert_many_txn(
            txn,
            table="ignored_users",
            keys=("ignorer_user_id", "ignored_user_id"),
            values=[
                (user_id, u) for u in currently_ignored_users - previously_ignored_users
            ],
        )

        # Invalidate the cache for any ignored users which were added or removed.
        for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
            self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
        self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))

    async def remove_account_data_for_user(
        self,
        user_id: str,
        account_data_type: str,
    ) -> int:
        """
        Delete a single piece of user account data by type.

        A "delete" is performed by updating a potentially existing row in the
        "account_data" database table for (user_id, account_data_type) and
        setting its content to "{}".

        Args:
            user_id: The user ID to modify the account data of.
            account_data_type: The type to remove.

        Returns:
            The maximum stream position.
        """
        assert self._can_write_to_account_data

        def _remove_account_data_for_user_txn(
            txn: LoggingTransaction, next_id: int
        ) -> bool:
            """
            Args:
                txn: The transaction object.
                next_id: The stream_id to update any existing rows to.

            Returns:
                True if an entry in account_data had its content set to '{}', otherwise
                False. This informs callers of whether there actually was an existing
                account data entry to delete, or if the call was a no-op.
            """
            # We can't use `simple_update` as it doesn't have the ability to specify
            # where clauses other than '=', which we need for `content != '{}'` below.
            sql = """
                UPDATE account_data
                SET stream_id = ?, content = '{}'
                WHERE user_id = ?
                AND account_data_type = ?
                AND content != '{}'
            """
            txn.execute(sql, (next_id, user_id, account_data_type))
            if txn.rowcount == 0:
                # We didn't update any rows. This means that there was no matching
                # account data entry to delete in the first place.
                return False

            # Ignored users get denormalized into a separate table as an optimisation.
            if account_data_type == AccountDataTypes.IGNORED_USER_LIST:
                # If this method was called with the ignored users account data type,
                # we simply delete all ignored users.

                # First pull all the users that this user ignores.
                previously_ignored_users = set(
                    self.db_pool.simple_select_onecol_txn(
                        txn,
                        table="ignored_users",
                        keyvalues={"ignorer_user_id": user_id},
                        retcol="ignored_user_id",
                    )
                )

                # Then delete them from the database.
                self.db_pool.simple_delete_txn(
                    txn,
                    table="ignored_users",
                    keyvalues={"ignorer_user_id": user_id},
                )

                # Invalidate the cache for ignored users which were removed.
                for ignored_user_id in previously_ignored_users:
                    self._invalidate_cache_and_stream(
                        txn, self.ignored_by, (ignored_user_id,)
                    )

                # Invalidate for this user the cache tracking ignored users.
                self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))

            return True

        async with self._account_data_id_gen.get_next() as next_id:
            row_updated = await self.db_pool.runInteraction(
                "remove_account_data_for_user",
                _remove_account_data_for_user_txn,
                next_id,
            )

            if row_updated:
                self._account_data_stream_cache.entity_has_changed(user_id, next_id)
                self.get_global_account_data_for_user.invalidate((user_id,))
                self.get_global_account_data_by_type_for_user.prefill(
                    (user_id, account_data_type), {}
                )

        return self._account_data_id_gen.get_current_token()

    async def purge_account_data_for_user(self, user_id: str) -> None:
        """
        Removes ALL the account data for a user.
        Intended to be used upon user deactivation.

        Also purges the user from the ignored_users cache table
        and the push_rules cache tables.
        """
        await self.db_pool.runInteraction(
            "purge_account_data_for_user_txn",
            self._purge_account_data_for_user_txn,
            user_id,
        )

    def _purge_account_data_for_user_txn(
        self, txn: LoggingTransaction, user_id: str
    ) -> None:
        """
        See `purge_account_data_for_user`.
        """
        # Purge from the primary account_data tables.
        self.db_pool.simple_delete_txn(
            txn, table="account_data", keyvalues={"user_id": user_id}
        )

        self.db_pool.simple_delete_txn(
            txn, table="room_account_data", keyvalues={"user_id": user_id}
        )

        # Purge from ignored_users where this user is the ignorer.
        # N.B. We don't purge where this user is the ignoree, because that
        # interferes with other users' account data.
        # It's also not this user's data to delete!
        self.db_pool.simple_delete_txn(
            txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}
        )

        # Remove the push rules
        self.db_pool.simple_delete_txn(
            txn, table="push_rules", keyvalues={"user_name": user_id}
        )
        self.db_pool.simple_delete_txn(
            txn, table="push_rules_enable", keyvalues={"user_name": user_id}
        )
        self.db_pool.simple_delete_txn(
            txn, table="push_rules_stream", keyvalues={"user_id": user_id}
        )

        # Invalidate caches as appropriate
        self._invalidate_cache_and_stream(
            txn, self.get_account_data_for_room_and_type, (user_id,)
        )
        self._invalidate_cache_and_stream(
            txn, self.get_global_account_data_for_user, (user_id,)
        )
        self._invalidate_cache_and_stream(
            txn, self.get_room_account_data_for_user, (user_id,)
        )
        self._invalidate_cache_and_stream(
            txn, self.get_global_account_data_by_type_for_user, (user_id,)
        )
        self._invalidate_cache_and_stream(
            txn, self.get_account_data_for_room, (user_id,)
        )
        self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))

        # This user might be contained in the ignored_by cache for other users,
        # so we have to invalidate it all.
        self._invalidate_all_cache_and_stream(txn, self.ignored_by)

    async def _delete_account_data_for_deactivated_users(
        self, progress: dict, batch_size: int
    ) -> int:
        """
        Retroactively purges account data for users that have already been deactivated.
        Gets run as a background update caused by a schema delta.
        """

        last_user: str = progress.get("last_user", "")
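
        # The scan below walks the users table in name order, batch_size rows
        # at a time, persisting the last processed name as its progress so the
        # background update can resume where it left off.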
        def _delete_account_data_for_deactivated_users_txn(
            txn: LoggingTransaction,
        ) -> int:
            sql = """
                SELECT name FROM users
                WHERE deactivated = ? and name > ?
                ORDER BY name ASC
                LIMIT ?
            """
            txn.execute(sql, (1, last_user, batch_size))

            users = [row[0] for row in txn]

            for user in users:
                self._purge_account_data_for_user_txn(txn, user_id=user)

            if users:
                self.db_pool.updates._background_update_progress_txn(
                    txn,
                    "delete_account_data_for_deactivated_users",
                    {"last_user": users[-1]},
                )

            return len(users)

        number_deleted = await self.db_pool.runInteraction(
            "_delete_account_data_for_deactivated_users",
            _delete_account_data_for_deactivated_users_txn,
        )

        if number_deleted < batch_size:
            await self.db_pool.updates._end_background_update(
                "delete_account_data_for_deactivated_users"
            )

        return number_deleted


class AccountDataStore(AccountDataWorkerStore):
    pass
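

# A minimal usage sketch (illustrative only; not part of the module): on an
# instance listed in `hs.config.worker.writers.account_data`, a handler could
# persist and then delete a hypothetical custom account data type roughly as
# follows, assuming `store` is the homeserver's AccountDataStore:
#
#     max_id = await store.add_account_data_for_user(
#         "@alice:example.com", "im.example.settings", {"theme": "dark"}
#     )
#     await store.remove_account_data_for_user(
#         "@alice:example.com", "im.example.settings"
#     )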