# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    FrozenSet,
    Iterable,
    List,
    Optional,
    Tuple,
    cast,
)

from synapse.api.constants import AccountDataTypes
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage._base import db_to_json
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
    AbstractStreamIdGenerator,
    AbstractStreamIdTracker,
    MultiWriterIdGenerator,
    StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        # `_can_write_to_account_data` indicates whether the current worker is allowed
        # to write account data. A value of `True` implies that `_account_data_id_gen`
        # is an `AbstractStreamIdGenerator` and not just a tracker.
        self._account_data_id_gen: AbstractStreamIdTracker
        self._can_write_to_account_data = (
            self._instance_name in hs.config.worker.writers.account_data
        )

        if isinstance(database.engine, PostgresEngine):
            self._account_data_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="account_data",
                instance_name=self._instance_name,
                tables=[
                    ("room_account_data", "instance_name", "stream_id"),
                    ("room_tags_revisions", "instance_name", "stream_id"),
                    ("account_data", "instance_name", "stream_id"),
                ],
                sequence_name="account_data_sequence",
                writers=hs.config.worker.writers.account_data,
            )
        else:
            # We shouldn't be running in worker mode with SQLite, but it's useful
            # to support it for unit tests.
            #
            # If this process is the writer then we need a `StreamIdGenerator`
            # that can allocate new stream IDs; otherwise `is_writer=False` gives
            # us a read-only tracker that gets updated over replication.
            # (Multiple writers are not supported for SQLite.)
            self._account_data_id_gen = StreamIdGenerator(
                db_conn,
                "room_account_data",
                "stream_id",
                extra_tables=[("room_tags_revisions", "stream_id")],
                is_writer=self._instance_name in hs.config.worker.writers.account_data,
            )

        account_max = self.get_max_account_data_stream_id()
        self._account_data_stream_cache = StreamChangeCache(
            "AccountDataAndTagsChangeCache", account_max
        )

        self.db_pool.updates.register_background_update_handler(
            "delete_account_data_for_deactivated_users",
            self._delete_account_data_for_deactivated_users,
        )

    def get_max_account_data_stream_id(self) -> int:
        """Get the current max stream ID for the account data stream.

        Returns:
            The maximum stream ID.
        """
        return self._account_data_id_gen.get_current_token()

    @cached()
    async def get_account_data_for_user(
        self, user_id: str
    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
        """
        Get all the client account_data for a user.

        If experimental MSC3391 support is enabled, any entries with an empty
        content body are excluded, as this means they have been deleted.

        Args:
            user_id: The user to get the account_data for.

        Returns:
            A 2-tuple of a dict of global account_data and a dict mapping from
            room_id string to per-room account_data dicts.
        """

        def get_account_data_for_user_txn(
            txn: LoggingTransaction,
        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
            # The `content != '{}'` condition below prevents us from using
            # `simple_select_list_txn` here, as it doesn't support conditions
            # other than 'equals'.
            sql = """
                SELECT account_data_type, content FROM account_data
                WHERE user_id = ?
            """

            # If experimental MSC3391 support is enabled, then account data entries
            # with an empty content are considered "deleted". So skip adding them to
            # the results.
            if self.hs.config.experimental.msc3391_enabled:
                sql += " AND content != '{}'"

            txn.execute(sql, (user_id,))
            rows = self.db_pool.cursor_to_dict(txn)

            global_account_data = {
                row["account_data_type"]: db_to_json(row["content"]) for row in rows
            }

            # The `content != '{}'` condition below prevents us from using
            # `simple_select_list_txn` here, as it doesn't support conditions
            # other than 'equals'.
            sql = """
                SELECT room_id, account_data_type, content FROM room_account_data
                WHERE user_id = ?
            """

            # If experimental MSC3391 support is enabled, then account data entries
            # with an empty content are considered "deleted". So skip adding them to
            # the results.
            if self.hs.config.experimental.msc3391_enabled:
                sql += " AND content != '{}'"

            txn.execute(sql, (user_id,))
            rows = self.db_pool.cursor_to_dict(txn)

            by_room: Dict[str, Dict[str, JsonDict]] = {}
            for row in rows:
                room_data = by_room.setdefault(row["room_id"], {})
                room_data[row["account_data_type"]] = db_to_json(row["content"])

            return global_account_data, by_room

        return await self.db_pool.runInteraction(
            "get_account_data_for_user", get_account_data_for_user_txn
        )
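
    # Illustrative usage sketch (not part of this module; `store` and the user
    # ID below are made up): a caller holding the main datastore can fetch all
    # of a user's account data in one call.
    #
    #     store = hs.get_datastores().main
    #     global_data, by_room = await store.get_account_data_for_user(
    #         "@alice:example.com"
    #     )
    #     direct_rooms = global_data.get("m.direct", {})
    #     per_room = by_room.get("!room:example.com", {})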

    @cached(num_args=2, max_entries=5000, tree=True)
    async def get_global_account_data_by_type_for_user(
        self, user_id: str, data_type: str
    ) -> Optional[JsonDict]:
        """
        Args:
            user_id: The user to get the account_data for.
            data_type: The type of account_data to get.

        Returns:
            The account data, or None if there is none of the given type.
        """
        result = await self.db_pool.simple_select_one_onecol(
            table="account_data",
            keyvalues={"user_id": user_id, "account_data_type": data_type},
            retcol="content",
            desc="get_global_account_data_by_type_for_user",
            allow_none=True,
        )

        if result:
            return db_to_json(result)
        else:
            return None

    @cached(num_args=2, tree=True)
    async def get_account_data_for_room(
        self, user_id: str, room_id: str
    ) -> Dict[str, JsonDict]:
        """Get all the client account_data for a user for a room.

        Args:
            user_id: The user to get the account_data for.
            room_id: The room to get the account_data for.

        Returns:
            A dict of the room account_data.
        """

        def get_account_data_for_room_txn(
            txn: LoggingTransaction,
        ) -> Dict[str, JsonDict]:
            rows = self.db_pool.simple_select_list_txn(
                txn,
                "room_account_data",
                {"user_id": user_id, "room_id": room_id},
                ["account_data_type", "content"],
            )

            return {
                row["account_data_type"]: db_to_json(row["content"]) for row in rows
            }

        return await self.db_pool.runInteraction(
            "get_account_data_for_room", get_account_data_for_room_txn
        )

    @cached(num_args=3, max_entries=5000, tree=True)
    async def get_account_data_for_room_and_type(
        self, user_id: str, room_id: str, account_data_type: str
    ) -> Optional[JsonDict]:
        """Get the client account_data of given type for a user for a room.

        Args:
            user_id: The user to get the account_data for.
            room_id: The room to get the account_data for.
            account_data_type: The account data type to get.

        Returns:
            The room account_data for that type, or None if there isn't any set.
        """

        def get_account_data_for_room_and_type_txn(
            txn: LoggingTransaction,
        ) -> Optional[JsonDict]:
            content_json = self.db_pool.simple_select_one_onecol_txn(
                txn,
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                retcol="content",
                allow_none=True,
            )

            return db_to_json(content_json) if content_json else None

        return await self.db_pool.runInteraction(
            "get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
        )

    async def get_updated_global_account_data(
        self, last_id: int, current_id: int, limit: int
    ) -> List[Tuple[int, str, str]]:
        """Get the global account_data that has changed, for the account_data stream.

        Args:
            last_id: the last stream_id from the previous batch.
            current_id: the maximum stream_id to return up to
            limit: the maximum number of rows to return

        Returns:
            A list of tuples of stream_id int, user_id string,
            and type string.
        """
        if last_id == current_id:
            return []

        def get_updated_global_account_data_txn(
            txn: LoggingTransaction,
        ) -> List[Tuple[int, str, str]]:
            sql = (
                "SELECT stream_id, user_id, account_data_type"
                " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_id, current_id, limit))
            return cast(List[Tuple[int, str, str]], txn.fetchall())

        return await self.db_pool.runInteraction(
            "get_updated_global_account_data", get_updated_global_account_data_txn
        )

    async def get_updated_room_account_data(
        self, last_id: int, current_id: int, limit: int
    ) -> List[Tuple[int, str, str, str]]:
        """Get the room account_data that has changed, for the account_data stream.

        Args:
            last_id: the last stream_id from the previous batch.
            current_id: the maximum stream_id to return up to
            limit: the maximum number of rows to return

        Returns:
            A list of tuples of stream_id int, user_id string,
            room_id string and type string.
        """
        if last_id == current_id:
            return []

        def get_updated_room_account_data_txn(
            txn: LoggingTransaction,
        ) -> List[Tuple[int, str, str, str]]:
            sql = (
                "SELECT stream_id, user_id, room_id, account_data_type"
                " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC LIMIT ?"
            )
            txn.execute(sql, (last_id, current_id, limit))
            return cast(List[Tuple[int, str, str, str]], txn.fetchall())

        return await self.db_pool.runInteraction(
            "get_updated_room_account_data", get_updated_room_account_data_txn
        )

    async def get_updated_account_data_for_user(
        self, user_id: str, stream_id: int
    ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
        """Get all the client account_data that has changed for a user.

        Args:
            user_id: The user to get the account_data for.
            stream_id: The point in the stream since which to get updates

        Returns:
            A 2-tuple of a dict of global account_data and a dict mapping from
            room_id string to per-room account_data dicts.
        """

        def get_updated_account_data_for_user_txn(
            txn: LoggingTransaction,
        ) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]:
            sql = (
                "SELECT account_data_type, content FROM account_data"
                " WHERE user_id = ? AND stream_id > ?"
            )
            txn.execute(sql, (user_id, stream_id))

            global_account_data = {row[0]: db_to_json(row[1]) for row in txn}

            sql = (
                "SELECT room_id, account_data_type, content FROM room_account_data"
                " WHERE user_id = ? AND stream_id > ?"
            )
            txn.execute(sql, (user_id, stream_id))

            account_data_by_room: Dict[str, Dict[str, JsonDict]] = {}
            for row in txn:
                room_account_data = account_data_by_room.setdefault(row[0], {})
                room_account_data[row[1]] = db_to_json(row[2])

            return global_account_data, account_data_by_room

        changed = self._account_data_stream_cache.has_entity_changed(
            user_id, int(stream_id)
        )
        if not changed:
            return {}, {}

        return await self.db_pool.runInteraction(
            "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
        )
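
    # Hedged sketch of an incremental flow (assumed caller; the user ID is
    # invented): remember the stream token from a previous pass, then ask only
    # for changes since that point. The StreamChangeCache check above makes the
    # common "nothing changed" case cheap.
    #
    #     since_token = store.get_max_account_data_stream_id()
    #     # ... later, after some writes have happened ...
    #     global_changes, room_changes = await store.get_updated_account_data_for_user(
    #         "@alice:example.com", since_token
    #     )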

    @cached(max_entries=5000, iterable=True)
    async def ignored_by(self, user_id: str) -> FrozenSet[str]:
        """
        Get users which ignore the given user.

        Args:
            user_id: The user ID which might be ignored.

        Returns:
            The user IDs which ignore the given user.
        """
        return frozenset(
            await self.db_pool.simple_select_onecol(
                table="ignored_users",
                keyvalues={"ignored_user_id": user_id},
                retcol="ignorer_user_id",
                desc="ignored_by",
            )
        )

    @cached(max_entries=5000, iterable=True)
    async def ignored_users(self, user_id: str) -> FrozenSet[str]:
        """
        Get users which the given user ignores.

        Args:
            user_id: The user ID which is making the request.

        Returns:
            The user IDs which are ignored by the given user.
        """
        return frozenset(
            await self.db_pool.simple_select_onecol(
                table="ignored_users",
                keyvalues={"ignorer_user_id": user_id},
                retcol="ignored_user_id",
                desc="ignored_users",
            )
        )

    def process_replication_rows(
        self,
        stream_name: str,
        instance_name: str,
        token: int,
        rows: Iterable[Any],
    ) -> None:
        if stream_name == TagAccountDataStream.NAME:
            self._account_data_id_gen.advance(instance_name, token)
        elif stream_name == AccountDataStream.NAME:
            self._account_data_id_gen.advance(instance_name, token)
            for row in rows:
                if not row.room_id:
                    self.get_global_account_data_by_type_for_user.invalidate(
                        (row.user_id, row.data_type)
                    )
                self.get_account_data_for_user.invalidate((row.user_id,))
                self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
                self.get_account_data_for_room_and_type.invalidate(
                    (row.user_id, row.room_id, row.data_type)
                )
                self._account_data_stream_cache.entity_has_changed(row.user_id, token)

        super().process_replication_rows(stream_name, instance_name, token, rows)

    async def add_account_data_to_room(
        self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
    ) -> int:
        """Add some account_data to a room for a user.

        Args:
            user_id: The user to add the account_data for.
            room_id: The room to add the account_data to.
            account_data_type: The type of account_data to add.
            content: A json object to associate with the account_data.

        Returns:
            The maximum stream ID.
        """
        assert self._can_write_to_account_data
        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

        content_json = json_encoder.encode(content)

        async with self._account_data_id_gen.get_next() as next_id:
            await self.db_pool.simple_upsert(
                desc="add_room_account_data",
                table="room_account_data",
                keyvalues={
                    "user_id": user_id,
                    "room_id": room_id,
                    "account_data_type": account_data_type,
                },
                values={"stream_id": next_id, "content": content_json},
            )

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_account_data_for_room.invalidate((user_id, room_id))
            self.get_account_data_for_room_and_type.prefill(
                (user_id, room_id, account_data_type), content
            )

        return self._account_data_id_gen.get_current_token()
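
    # Illustrative usage sketch (hypothetical caller; identifiers invented):
    # per-room account data such as a read marker could be written like this.
    #
    #     max_stream_id = await store.add_account_data_to_room(
    #         user_id="@alice:example.com",
    #         room_id="!room:example.com",
    #         account_data_type="m.fully_read",
    #         content={"event_id": "$an_event_id"},
    #     )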

    async def remove_account_data_for_room(
        self, user_id: str, room_id: str, account_data_type: str
    ) -> Optional[int]:
        """Delete the room account data for the user of a given type.

        Args:
            user_id: The user to remove account_data for.
            room_id: The room ID to scope the request to.
            account_data_type: The account data type to delete.

        Returns:
            The maximum stream position, or None if there was no matching room account
            data to delete.
        """
        assert self._can_write_to_account_data
        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

        def _remove_account_data_for_room_txn(
            txn: LoggingTransaction, next_id: int
        ) -> bool:
            """
            Args:
                txn: The transaction object.
                next_id: The stream_id to update any existing rows to.

            Returns:
                True if an entry in room_account_data had its content set to '{}',
                otherwise False. This informs callers of whether there actually was an
                existing room account data entry to delete, or if the call was a no-op.
            """
            # We can't use `simple_update` as it doesn't have the ability to specify
            # where clauses other than '=', which we need for `content != '{}'` below.
            sql = """
                UPDATE room_account_data
                SET stream_id = ?, content = '{}'
                WHERE user_id = ?
                    AND room_id = ?
                    AND account_data_type = ?
                    AND content != '{}'
            """
            txn.execute(
                sql,
                (next_id, user_id, room_id, account_data_type),
            )
            # Return true if any rows were updated.
            return txn.rowcount != 0

        async with self._account_data_id_gen.get_next() as next_id:
            row_updated = await self.db_pool.runInteraction(
                "remove_account_data_for_room",
                _remove_account_data_for_room_txn,
                next_id,
            )

            if not row_updated:
                return None

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_account_data_for_room.invalidate((user_id, room_id))
            self.get_account_data_for_room_and_type.prefill(
                (user_id, room_id, account_data_type), {}
            )

        return self._account_data_id_gen.get_current_token()

    async def add_account_data_for_user(
        self, user_id: str, account_data_type: str, content: JsonDict
    ) -> int:
        """Add some global account_data for a user.

        Args:
            user_id: The user to add the account_data for.
            account_data_type: The type of account_data to add.
            content: A json object to associate with the account_data.

        Returns:
            The maximum stream ID.
        """
        assert self._can_write_to_account_data
        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

        async with self._account_data_id_gen.get_next() as next_id:
            await self.db_pool.runInteraction(
                "add_user_account_data",
                self._add_account_data_for_user,
                next_id,
                user_id,
                account_data_type,
                content,
            )

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_global_account_data_by_type_for_user.invalidate(
                (user_id, account_data_type)
            )

        return self._account_data_id_gen.get_current_token()

    def _add_account_data_for_user(
        self,
        txn: LoggingTransaction,
        next_id: int,
        user_id: str,
        account_data_type: str,
        content: JsonDict,
    ) -> None:
        content_json = json_encoder.encode(content)

        self.db_pool.simple_upsert_txn(
            txn,
            table="account_data",
            keyvalues={"user_id": user_id, "account_data_type": account_data_type},
            values={"stream_id": next_id, "content": content_json},
        )

        # Ignored users get denormalized into a separate table as an optimisation.
        if account_data_type != AccountDataTypes.IGNORED_USER_LIST:
            return

        # Insert / delete to sync the list of ignored users.
        previously_ignored_users = set(
            self.db_pool.simple_select_onecol_txn(
                txn,
                table="ignored_users",
                keyvalues={"ignorer_user_id": user_id},
                retcol="ignored_user_id",
            )
        )

        # If the data is invalid, no one is ignored.
        ignored_users_content = content.get("ignored_users", {})
        if isinstance(ignored_users_content, dict):
            currently_ignored_users = set(ignored_users_content)
        else:
            currently_ignored_users = set()

        # If the data has not changed, nothing to do.
        if previously_ignored_users == currently_ignored_users:
            return

        # Delete entries which are no longer ignored.
        self.db_pool.simple_delete_many_txn(
            txn,
            table="ignored_users",
            column="ignored_user_id",
            values=previously_ignored_users - currently_ignored_users,
            keyvalues={"ignorer_user_id": user_id},
        )

        # Add entries which are newly ignored.
        self.db_pool.simple_insert_many_txn(
            txn,
            table="ignored_users",
            keys=("ignorer_user_id", "ignored_user_id"),
            values=[
                (user_id, u) for u in currently_ignored_users - previously_ignored_users
            ],
        )

        # Invalidate the cache for any ignored users which were added or removed.
        for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
            self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
        self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))
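
    # For reference, a minimal sketch (values invented) of the content shape the
    # method above special-cases: an `m.ignored_user_list` body maps each ignored
    # user ID to an options object, empty in current spec versions.
    #
    #     content = {
    #         "ignored_users": {
    #             "@spammer:example.com": {},
    #             "@troll:example.com": {},
    #         }
    #     }
    #
    # The set differences above then yield exactly the `ignored_users` rows to
    # delete (previously - currently) and to insert (currently - previously).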

    async def remove_account_data_for_user(
        self,
        user_id: str,
        account_data_type: str,
    ) -> Optional[int]:
        """
        Delete a single piece of user account data by type.

        A "delete" is performed by updating a potentially existing row in the
        "account_data" database table for (user_id, account_data_type) and
        setting its content to "{}".

        Args:
            user_id: The user ID to modify the account data of.
            account_data_type: The type to remove.

        Returns:
            The maximum stream position, or None if there was no matching account data
            to delete.
        """
        assert self._can_write_to_account_data
        assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator)

        def _remove_account_data_for_user_txn(
            txn: LoggingTransaction, next_id: int
        ) -> bool:
            """
            Args:
                txn: The transaction object.
                next_id: The stream_id to update any existing rows to.

            Returns:
                True if an entry in account_data had its content set to '{}', otherwise
                False. This informs callers of whether there actually was an existing
                account data entry to delete, or if the call was a no-op.
            """
            # We can't use `simple_update` as it doesn't have the ability to specify
            # where clauses other than '=', which we need for `content != '{}'` below.
            sql = """
                UPDATE account_data
                SET stream_id = ?, content = '{}'
                WHERE user_id = ?
                    AND account_data_type = ?
                    AND content != '{}'
            """
            txn.execute(sql, (next_id, user_id, account_data_type))
            if txn.rowcount == 0:
                # We didn't update any rows. This means that there was no matching
                # account data entry to delete in the first place.
                return False

            # Ignored users get denormalized into a separate table as an optimisation.
            if account_data_type == AccountDataTypes.IGNORED_USER_LIST:
                # If this method was called with the ignored users account data type,
                # we simply delete all ignored users.

                # First pull all the users that this user ignores.
                previously_ignored_users = set(
                    self.db_pool.simple_select_onecol_txn(
                        txn,
                        table="ignored_users",
                        keyvalues={"ignorer_user_id": user_id},
                        retcol="ignored_user_id",
                    )
                )

                # Then delete them from the database.
                self.db_pool.simple_delete_txn(
                    txn,
                    table="ignored_users",
                    keyvalues={"ignorer_user_id": user_id},
                )

                # Invalidate the cache for ignored users which were removed.
                for ignored_user_id in previously_ignored_users:
                    self._invalidate_cache_and_stream(
                        txn, self.ignored_by, (ignored_user_id,)
                    )

                # Invalidate for this user the cache tracking ignored users.
                self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,))

            return True

        async with self._account_data_id_gen.get_next() as next_id:
            row_updated = await self.db_pool.runInteraction(
                "remove_account_data_for_user",
                _remove_account_data_for_user_txn,
                next_id,
            )

            if not row_updated:
                return None

            self._account_data_stream_cache.entity_has_changed(user_id, next_id)
            self.get_account_data_for_user.invalidate((user_id,))
            self.get_global_account_data_by_type_for_user.prefill(
                (user_id, account_data_type), {}
            )

        return self._account_data_id_gen.get_current_token()
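
    # Note on delete semantics: the remove_* methods above tombstone a row by
    # setting its content to '{}' under a fresh stream_id rather than deleting
    # it, so the "deletion" still flows down the account data stream. A hedged
    # sketch (assumed caller, invented user ID) of consuming the return value:
    #
    #     max_stream_id = await store.remove_account_data_for_user(
    #         "@alice:example.com", "m.ignored_user_list"
    #     )
    #     if max_stream_id is None:
    #         pass  # nothing matched; the request was a no-op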

    async def purge_account_data_for_user(self, user_id: str) -> None:
        """
        Removes ALL the account data for a user.
        Intended to be used upon user deactivation.

        Also purges the user from the ignored_users cache table
        and the push_rules cache tables.
        """

        await self.db_pool.runInteraction(
            "purge_account_data_for_user_txn",
            self._purge_account_data_for_user_txn,
            user_id,
        )

    def _purge_account_data_for_user_txn(
        self, txn: LoggingTransaction, user_id: str
    ) -> None:
        """
        See `purge_account_data_for_user`.
        """
        # Purge from the primary account_data tables.
        self.db_pool.simple_delete_txn(
            txn, table="account_data", keyvalues={"user_id": user_id}
        )

        self.db_pool.simple_delete_txn(
            txn, table="room_account_data", keyvalues={"user_id": user_id}
        )

        # Purge from ignored_users where this user is the ignorer.
        # N.B. We don't purge where this user is the ignoree, because that
        # interferes with other users' account data.
        # It's also not this user's data to delete!
        self.db_pool.simple_delete_txn(
            txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}
        )

        # Remove the push rules.
        self.db_pool.simple_delete_txn(
            txn, table="push_rules", keyvalues={"user_name": user_id}
        )
        self.db_pool.simple_delete_txn(
            txn, table="push_rules_enable", keyvalues={"user_name": user_id}
        )
        self.db_pool.simple_delete_txn(
            txn, table="push_rules_stream", keyvalues={"user_id": user_id}
        )

        # Invalidate caches as appropriate.
        self._invalidate_cache_and_stream(
            txn, self.get_account_data_for_room_and_type, (user_id,)
        )
        self._invalidate_cache_and_stream(
            txn, self.get_account_data_for_user, (user_id,)
        )
        self._invalidate_cache_and_stream(
            txn, self.get_global_account_data_by_type_for_user, (user_id,)
        )
        self._invalidate_cache_and_stream(
            txn, self.get_account_data_for_room, (user_id,)
        )
        self._invalidate_cache_and_stream(txn, self.get_push_rules_for_user, (user_id,))

        # This user might be contained in the ignored_by cache for other users,
        # so we have to invalidate it all.
        self._invalidate_all_cache_and_stream(txn, self.ignored_by)

    async def _delete_account_data_for_deactivated_users(
        self, progress: dict, batch_size: int
    ) -> int:
        """
        Retroactively purges account data for users that have already been deactivated.
        Gets run as a background update caused by a schema delta.
        """

        last_user: str = progress.get("last_user", "")

        def _delete_account_data_for_deactivated_users_txn(
            txn: LoggingTransaction,
        ) -> int:
            sql = """
                SELECT name FROM users
                WHERE deactivated = ? and name > ?
                ORDER BY name ASC
                LIMIT ?
            """
            txn.execute(sql, (1, last_user, batch_size))

            users = [row[0] for row in txn]

            for user in users:
                self._purge_account_data_for_user_txn(txn, user_id=user)

            if users:
                self.db_pool.updates._background_update_progress_txn(
                    txn,
                    "delete_account_data_for_deactivated_users",
                    {"last_user": users[-1]},
                )

            return len(users)

        number_deleted = await self.db_pool.runInteraction(
            "_delete_account_data_for_deactivated_users",
            _delete_account_data_for_deactivated_users_txn,
        )

        if number_deleted < batch_size:
            await self.db_pool.updates._end_background_update(
                "delete_account_data_for_deactivated_users"
            )

        return number_deleted


class AccountDataStore(AccountDataWorkerStore):
    pass