- # Copyright 2015, 2016 OpenMarket Ltd
- # Copyright 2021 The Matrix.org Foundation C.I.C.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import TYPE_CHECKING, Optional, Tuple, Union, cast
- from canonicaljson import encode_canonical_json
- from synapse.api.errors import Codes, StoreError, SynapseError
- from synapse.storage._base import SQLBaseStore, db_to_json
- from synapse.storage.database import (
- DatabasePool,
- LoggingDatabaseConnection,
- LoggingTransaction,
- )
- from synapse.storage.engines import PostgresEngine
- from synapse.types import JsonDict, UserID
- from synapse.util.caches.descriptors import cached
- if TYPE_CHECKING:
- from synapse.server import HomeServer
- class FilteringWorkerStore(SQLBaseStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
- self.server_name: str = hs.hostname
- self.database_engine = database.engine
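- # Build a unique index over (full_user_id, filter_id) in the background.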
- self.db_pool.updates.register_background_index_update(
- "full_users_filters_unique_idx",
- index_name="full_users_unique_idx",
- table="user_filters",
- columns=["full_user_id, filter_id"],
- unique=True,
- )
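- # Background job that backfills full_user_id from the localpart stored in
- # the user_id column.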
- self.db_pool.updates.register_background_update_handler(
- "populate_full_user_id_user_filters",
- self.populate_full_user_id_user_filters,
- )
- async def populate_full_user_id_user_filters(
- self, progress: JsonDict, batch_size: int
- ) -> int:
- """
- Background update to populate the column `full_user_id` of the table
- user_filters from entries in the column `user_id` (the user's localpart) of the same table
- """
- lower_bound_id = progress.get("lower_bound_id", "")
- def _get_last_id(txn: LoggingTransaction) -> Optional[str]:
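- # Find the user_id 1000 rows beyond the current lower bound to use as this
- # batch's upper bound. None means fewer than 1000 rows remain, so the
- # final batch handler runs instead.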
- sql = """
- SELECT user_id FROM user_filters
- WHERE user_id > ?
- ORDER BY user_id
- LIMIT 1 OFFSET 1000
- """
- txn.execute(sql, (lower_bound_id,))
- res = txn.fetchone()
- if res:
- upper_bound_id = res[0]
- return upper_bound_id
- else:
- return None
- def _process_batch(
- txn: LoggingTransaction, lower_bound_id: str, upper_bound_id: str
- ) -> None:
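- # Rewrite the bare localpart in user_id into a full Matrix ID of the form
- # '@localpart:server_name' for every not-yet-populated row in this batch.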
- sql = """
- UPDATE user_filters
- SET full_user_id = '@' || user_id || ?
- WHERE ? < user_id AND user_id <= ? AND full_user_id IS NULL
- """
- txn.execute(sql, (f":{self.server_name}", lower_bound_id, upper_bound_id))
- def _final_batch(txn: LoggingTransaction, lower_bound_id: str) -> None:
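- # Handle whatever rows remain above the last upper bound, then validate the
- # full_user_id_not_null constraint on Postgres now that every row has a value.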
- sql = """
- UPDATE user_filters
- SET full_user_id = '@' || user_id || ?
- WHERE ? < user_id AND full_user_id IS NULL
- """
- txn.execute(
- sql,
- (
- f":{self.server_name}",
- lower_bound_id,
- ),
- )
- if isinstance(self.database_engine, PostgresEngine):
- sql = """
- ALTER TABLE user_filters VALIDATE CONSTRAINT full_user_id_not_null
- """
- txn.execute(sql)
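- # Work out the upper bound of the next batch; None means we are on the
- # last (short) batch and the update can be marked as finished.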
- upper_bound_id = await self.db_pool.runInteraction(
- "populate_full_user_id_user_filters", _get_last_id
- )
- if upper_bound_id is None:
- await self.db_pool.runInteraction(
- "populate_full_user_id_user_filters", _final_batch, lower_bound_id
- )
- await self.db_pool.updates._end_background_update(
- "populate_full_user_id_user_filters"
- )
- return 1
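- # Otherwise process this batch and record its upper bound as the lower
- # bound for the next iteration.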
- await self.db_pool.runInteraction(
- "populate_full_user_id_user_filters",
- _process_batch,
- lower_bound_id,
- upper_bound_id,
- )
- progress["lower_bound_id"] = upper_bound_id
- await self.db_pool.runInteraction(
- "populate_full_user_id_user_filters",
- self.db_pool.updates._background_update_progress_txn,
- "populate_full_user_id_user_filters",
- progress,
- )
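- # The return value feeds the background update controller, which uses it
- # to pace subsequent batches.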
- return 50
- @cached(num_args=2)
- async def get_user_filter(
- self, user_id: UserID, filter_id: Union[int, str]
- ) -> JsonDict:
- # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
- # with a coherent error message rather than 500 M_UNKNOWN.
- try:
- int(filter_id)
- except ValueError:
- raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
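- # Look up the stored filter JSON for this (full_user_id, filter_id);
- # allow_none=False means a missing row raises rather than returning None.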
- def_json = await self.db_pool.simple_select_one_onecol(
- table="user_filters",
- keyvalues={"full_user_id": user_id.to_string(), "filter_id": filter_id},
- retcol="filter_json",
- allow_none=False,
- desc="get_user_filter",
- )
- return db_to_json(def_json)
- async def add_user_filter(self, user_id: UserID, user_filter: JsonDict) -> int:
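- # Canonical JSON gives a stable byte encoding, so an identical filter
- # submitted twice matches the filter_json equality check below.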
- def_json = encode_canonical_json(user_filter)
- # Need an atomic transaction to SELECT the maximal ID so far then
- # INSERT a new one
- def _do_txn(txn: LoggingTransaction) -> int:
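- # If this user already has an identical filter stored, return its existing
- # ID rather than inserting a duplicate.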
- sql = (
- "SELECT filter_id FROM user_filters "
- "WHERE full_user_id = ? AND filter_json = ?"
- )
- txn.execute(sql, (user_id.to_string(), bytearray(def_json)))
- filter_id_response = txn.fetchone()
- if filter_id_response is not None:
- return filter_id_response[0]
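- # No match: allocate the next filter_id for this user, starting at 0.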
- sql = "SELECT MAX(filter_id) FROM user_filters WHERE full_user_id = ?"
- txn.execute(sql, (user_id.to_string(),))
- max_id = cast(Tuple[Optional[int]], txn.fetchone())[0]
- if max_id is None:
- filter_id = 0
- else:
- filter_id = max_id + 1
- sql = (
- "INSERT INTO user_filters (full_user_id, user_id, filter_id, filter_json)"
- "VALUES(?, ?, ?, ?)"
- )
- txn.execute(
- sql,
- (
- user_id.to_string(),
- user_id.localpart,
- filter_id,
- bytearray(def_json),
- ),
- )
- return filter_id
- attempts = 0
- while True:
- # Try a few times.
- # This is technically needed if a user tries to create two filters at once,
- # leading to two concurrent transactions.
- # The failure case would be:
- # - SELECT filter_id ... filter_json = ? → both transactions return no rows
- # - SELECT MAX(filter_id) ... → both transactions return e.g. 5
- # - INSERT INTO ... → both transactions insert filter_id = 6
- # One of the transactions will commit. The other will get a unique key
- # constraint violation error (IntegrityError). This is not the same as a
- # serialisability violation, which would be automatically retried by
- # `runInteraction`.
- try:
- return await self.db_pool.runInteraction("add_user_filter", _do_txn)
- except self.db_pool.engine.module.IntegrityError:
- attempts += 1
- if attempts >= 5:
- raise StoreError(500, "Couldn't generate a filter ID.")
|