123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226 |
- # Copyright 2014-2016 OpenMarket Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import TYPE_CHECKING, Optional
- from synapse.api.errors import StoreError
- from synapse.storage._base import SQLBaseStore
- from synapse.storage.database import (
- DatabasePool,
- LoggingDatabaseConnection,
- LoggingTransaction,
- )
- from synapse.storage.databases.main.roommember import ProfileInfo
- from synapse.storage.engines import PostgresEngine
- from synapse.types import JsonDict, UserID
- if TYPE_CHECKING:
- from synapse.server import HomeServer
class ProfileWorkerStore(SQLBaseStore):
    """Storage methods for the `profiles` table.

    Reads look rows up by the `full_user_id` column. Writes store the user's
    localpart in `user_id` and the complete user ID in `full_user_id`.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self.server_name: str = hs.hostname
        self.database_engine = database.engine

        background_updates = self.db_pool.updates

        # Unique index on `profiles.full_user_id`, built as a background update.
        background_updates.register_background_index_update(
            "profiles_full_user_id_key_idx",
            index_name="profiles_full_user_id_key",
            table="profiles",
            columns=["full_user_id"],
            unique=True,
        )

        # Handler that backfills `full_user_id` for rows that lack it.
        background_updates.register_background_update_handler(
            "populate_full_user_id_profiles", self.populate_full_user_id_profiles
        )

    async def populate_full_user_id_profiles(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """
        Background update which fills in the `full_user_id` column of the
        `profiles` table as '@' + `user_id` + ':' + this server's name, for
        rows where `full_user_id` is still NULL.

        Rows are walked in `user_id` order. Each run handles the range between
        the stored lower bound and the `user_id` found after skipping 1000
        rows above that bound. Once no such `user_id` exists, the remaining
        rows are updated in one go, the `full_user_id_not_null` constraint is
        validated (Postgres only) and the background update is ended.

        Args:
            progress: Persisted state of the update. Its "lower_bound_id"
                entry (default "") is the exclusive lower bound of the next
                batch, and is overwritten here after each batch.
            batch_size: Not used by this implementation.

        Returns:
            50 after a regular batch, 1 after the final batch.
        """
        # NOTE(review): `batch_size` is ignored and the return values are fixed
        # constants, while each batch is bounded by the hard-coded OFFSET 1000
        # below - confirm this is intentional.
        update_name = "populate_full_user_id_profiles"
        server_suffix = f":{self.server_name}"
        start_after = progress.get("lower_bound_id", "")

        def _find_batch_end(txn: LoggingTransaction) -> Optional[str]:
            # The inclusive upper bound of this batch, or None when fewer than
            # 1001 rows remain above `start_after`.
            txn.execute(
                """
                SELECT user_id FROM profiles
                WHERE user_id > ?
                ORDER BY user_id
                LIMIT 1 OFFSET 1000
                """,
                (start_after,),
            )
            row = txn.fetchone()
            return row[0] if row else None

        def _fill_range(txn: LoggingTransaction, after: str, up_to: str) -> None:
            # Populate rows with `after` < user_id <= `up_to`.
            txn.execute(
                """
                UPDATE profiles
                SET full_user_id = '@' || user_id || ?
                WHERE ? < user_id AND user_id <= ? AND full_user_id IS NULL
                """,
                (server_suffix, after, up_to),
            )

        def _fill_remainder(txn: LoggingTransaction, after: str) -> None:
            # Populate every remaining row above `after`.
            txn.execute(
                """
                UPDATE profiles
                SET full_user_id = '@' || user_id || ?
                WHERE ? < user_id AND full_user_id IS NULL
                """,
                (server_suffix, after),
            )

            # The constraint is only validated when running on Postgres.
            if isinstance(self.database_engine, PostgresEngine):
                txn.execute(
                    """
                    ALTER TABLE profiles VALIDATE CONSTRAINT full_user_id_not_null
                    """
                )

        batch_end = await self.db_pool.runInteraction(update_name, _find_batch_end)

        if batch_end is not None:
            await self.db_pool.runInteraction(
                update_name, _fill_range, start_after, batch_end
            )

            # Remember where the next run should resume from.
            progress["lower_bound_id"] = batch_end
            await self.db_pool.runInteraction(
                update_name,
                self.db_pool.updates._background_update_progress_txn,
                update_name,
                progress,
            )
            return 50

        # No upper bound was found: finish off whatever is left and stop.
        await self.db_pool.runInteraction(update_name, _fill_remainder, start_after)
        await self.db_pool.updates._end_background_update(update_name)
        return 1

    async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
        """
        Fetch the display name and avatar URL stored for a user.

        Args:
            user_id: The user's ID.

        Returns:
            The user's profile info; both fields are None when the lookup
            fails with a 404 `StoreError`. Any other `StoreError` is re-raised.
        """
        try:
            row = await self.db_pool.simple_select_one(
                table="profiles",
                keyvalues={"full_user_id": user_id.to_string()},
                retcols=("displayname", "avatar_url"),
                desc="get_profileinfo",
            )
        except StoreError as e:
            if e.code != 404:
                raise
            # no match
            return ProfileInfo(None, None)

        return ProfileInfo(
            avatar_url=row["avatar_url"], display_name=row["displayname"]
        )

    async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
        """
        Fetch the `displayname` column of the user's row in `profiles`.

        Args:
            user_id: The user's ID.
        """
        # NOTE(review): a missing row is not handled here, unlike in
        # `get_profileinfo` - presumably the error propagates; verify callers.
        displayname = await self.db_pool.simple_select_one_onecol(
            table="profiles",
            keyvalues={"full_user_id": user_id.to_string()},
            retcol="displayname",
            desc="get_profile_displayname",
        )
        return displayname

    async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]:
        """
        Fetch the `avatar_url` column of the user's row in `profiles`.

        Args:
            user_id: The user's ID.
        """
        # NOTE(review): a missing row is not handled here, unlike in
        # `get_profileinfo` - presumably the error propagates; verify callers.
        avatar_url = await self.db_pool.simple_select_one_onecol(
            table="profiles",
            keyvalues={"full_user_id": user_id.to_string()},
            retcol="avatar_url",
            desc="get_profile_avatar_url",
        )
        return avatar_url

    async def create_profile(self, user_id: UserID) -> None:
        """
        Insert a row for the user into `profiles`.

        Args:
            user_id: The user's ID. Its localpart goes into `user_id` and its
                string form into `full_user_id`.
        """
        new_row = {
            "user_id": user_id.localpart,
            "full_user_id": user_id.to_string(),
        }
        await self.db_pool.simple_insert(
            table="profiles",
            values=new_row,
            desc="create_profile",
        )

    async def set_profile_displayname(
        self, user_id: UserID, new_displayname: Optional[str]
    ) -> None:
        """
        Store a new display name for a user, upserting the row keyed by the
        user's localpart.

        Args:
            user_id: The user's ID.
            new_displayname: The display name to store. Passing None removes
                the user's display name.
        """
        lookup = {"user_id": user_id.localpart}
        updated_columns = {
            "displayname": new_displayname,
            "full_user_id": user_id.to_string(),
        }
        await self.db_pool.simple_upsert(
            table="profiles",
            keyvalues=lookup,
            values=updated_columns,
            desc="set_profile_displayname",
        )

    async def set_profile_avatar_url(
        self, user_id: UserID, new_avatar_url: Optional[str]
    ) -> None:
        """
        Store a new avatar URL for a user, upserting the row keyed by the
        user's localpart.

        Args:
            user_id: The user's ID.
            new_avatar_url: The avatar URL to store. Passing None removes the
                user's avatar.
        """
        lookup = {"user_id": user_id.localpart}
        updated_columns = {
            "avatar_url": new_avatar_url,
            "full_user_id": user_id.to_string(),
        }
        await self.db_pool.simple_upsert(
            table="profiles",
            keyvalues=lookup,
            values=updated_columns,
            desc="set_profile_avatar_url",
        )
class ProfileStore(ProfileWorkerStore):
    """Profile store which adds nothing on top of `ProfileWorkerStore`."""
|