filtering.py

# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast

from canonicaljson import encode_canonical_json

from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, UserID
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
    from synapse.server import HomeServer


class FilteringWorkerStore(SQLBaseStore):
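    """Storage for per-user filters, backed by the `user_filters` table."""
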
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)
        self.server_name: str = hs.hostname
        self.database_engine = database.engine

        self.db_pool.updates.register_background_index_update(
            "full_users_filters_unique_idx",
            index_name="full_users_unique_idx",
            table="user_filters",
            columns=["full_user_id, filter_id"],
            unique=True,
        )

        self.db_pool.updates.register_background_update_handler(
            "populate_full_user_id_user_filters",
            self.populate_full_user_id_user_filters,
        )

    async def populate_full_user_id_user_filters(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """
        Background update to populate the column `full_user_id` of the table
        `user_filters` from entries in the column `user_id` of the same table.
        """
        lower_bound_id = progress.get("lower_bound_id", "")

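        # Find the user_id 1000 rows beyond the current lower bound; that
        # row's user_id becomes the (inclusive) upper bound of the next batch.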
        def _get_last_id(txn: LoggingTransaction) -> Optional[str]:
            sql = """
                SELECT user_id FROM user_filters
                WHERE user_id > ?
                ORDER BY user_id
                LIMIT 1 OFFSET 1000
            """
            txn.execute(sql, (lower_bound_id,))
            res = txn.fetchone()
            if res:
                upper_bound_id = res[0]
                return upper_bound_id
            else:
                return None
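
        # Fill in full_user_id for every not-yet-populated row whose user_id
        # falls in (lower_bound_id, upper_bound_id].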
        def _process_batch(
            txn: LoggingTransaction, lower_bound_id: str, upper_bound_id: str
        ) -> None:
            sql = """
                UPDATE user_filters
                SET full_user_id = '@' || user_id || ?
                WHERE ? < user_id AND user_id <= ? AND full_user_id IS NULL
            """
            txn.execute(sql, (f":{self.server_name}", lower_bound_id, upper_bound_id))
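
        # As _process_batch, but with no upper bound: backfills every remaining row.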
        def _final_batch(txn: LoggingTransaction, lower_bound_id: str) -> None:
            sql = """
                UPDATE user_filters
                SET full_user_id = '@' || user_id || ?
                WHERE ? < user_id AND full_user_id IS NULL
            """
            txn.execute(
                sql,
                (
                    f":{self.server_name}",
                    lower_bound_id,
                ),
            )
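
            # Every row now has full_user_id populated, so the NOT NULL
            # constraint (presumably added as NOT VALID in an earlier schema
            # delta) can be validated on Postgres.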
            if isinstance(self.database_engine, PostgresEngine):
                sql = """
                    ALTER TABLE user_filters VALIDATE CONSTRAINT full_user_id_not_null
                """
                txn.execute(sql)

        upper_bound_id = await self.db_pool.runInteraction(
            "populate_full_user_id_user_filters", _get_last_id
        )

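        # No row 1000 past the lower bound: everything left fits in one final
        # batch, after which the background update is complete.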
        if upper_bound_id is None:
            await self.db_pool.runInteraction(
                "populate_full_user_id_user_filters", _final_batch, lower_bound_id
            )

            await self.db_pool.updates._end_background_update(
                "populate_full_user_id_user_filters"
            )

            return 1

        await self.db_pool.runInteraction(
            "populate_full_user_id_user_filters",
            _process_batch,
            lower_bound_id,
            upper_bound_id,
        )

        progress["lower_bound_id"] = upper_bound_id

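        # Persist the new lower bound so the next iteration resumes after this batch.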
        await self.db_pool.runInteraction(
            "populate_full_user_id_user_filters",
            self.db_pool.updates._background_update_progress_txn,
            "populate_full_user_id_user_filters",
            progress,
        )

        return 50

    @cached(num_args=2)
    async def get_user_filter(
        self, user_id: UserID, filter_id: Union[int, str]
    ) -> JsonDict:
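        """Look up the given filter for a user.

        Raises:
            SynapseError: if the filter ID is not a valid integer.
            StoreError: if no filter with that ID exists for the user.
        """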
        # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
        # with a coherent error message rather than 500 M_UNKNOWN.
        try:
            int(filter_id)
        except ValueError:
            raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)

        def_json = await self.db_pool.simple_select_one_onecol(
            table="user_filters",
            keyvalues={"full_user_id": user_id.to_string(), "filter_id": filter_id},
            retcol="filter_json",
            allow_none=False,
            desc="get_user_filter",
        )

        return db_to_json(def_json)

    async def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> int:
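        """Store a filter for a user, reusing an identical existing filter if
        one is present, and return its filter ID.
        """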
        def_json = encode_canonical_json(user_filter)

        # Need an atomic transaction to SELECT the maximal ID so far then
        # INSERT a new one.
        def _do_txn(txn: LoggingTransaction) -> int:
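            # If an identical filter already exists for this user, reuse its
            # ID rather than storing a duplicate.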
            sql = (
                "SELECT filter_id FROM user_filters "
                "WHERE full_user_id = ? AND filter_json = ?"
            )
            txn.execute(sql, (user_id.to_string(), bytearray(def_json)))
            filter_id_response = txn.fetchone()
            if filter_id_response is not None:
                return filter_id_response[0]

            sql = "SELECT MAX(filter_id) FROM user_filters WHERE full_user_id = ?"
            txn.execute(sql, (user_id.to_string(),))
            max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
            if max_id is None:
                filter_id = 0
            else:
                filter_id = max_id + 1

            sql = (
                "INSERT INTO user_filters (full_user_id, user_id, filter_id, filter_json) "
                "VALUES(?, ?, ?, ?)"
            )
            txn.execute(
                sql,
                (
                    user_id.to_string(),
                    user_id.localpart,
                    filter_id,
                    bytearray(def_json),
                ),
            )

            return filter_id

        attempts = 0
        while True:
            # Try a few times.
            # This is technically needed if a user tries to create two filters at once,
            # leading to two concurrent transactions.
            # The failure case would be:
            # - SELECT filter_id ... filter_json = ? → both transactions return no rows
            # - SELECT MAX(filter_id) ... → both transactions return e.g. 5
            # - INSERT INTO ... → both transactions insert filter_id = 6
            # One of the transactions will commit. The other will get a unique key
            # constraint violation error (IntegrityError). This is not the same as a
            # serialisability violation, which would be automatically retried by
            # `runInteraction`.
            try:
                return await self.db_pool.runInteraction("add_user_filter", _do_txn)
            except self.db_pool.engine.module.IntegrityError:
                attempts += 1
                if attempts >= 5:
                    raise StoreError(500, "Couldn't generate a filter ID.")