12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164 |
- # Copyright 2017 Vector Creations Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import logging
- import re
- import unicodedata
- from typing import (
- TYPE_CHECKING,
- Iterable,
- List,
- Mapping,
- Optional,
- Sequence,
- Set,
- Tuple,
- cast,
- )
- try:
- # Figure out if ICU support is available for searching users.
- import icu
- USE_ICU = True
- except ModuleNotFoundError:
- USE_ICU = False
- from typing_extensions import TypedDict
- from synapse.api.errors import StoreError
- from synapse.util.stringutils import non_null_str_or_none
- if TYPE_CHECKING:
- from synapse.server import HomeServer
- from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
- from synapse.storage.database import (
- DatabasePool,
- LoggingDatabaseConnection,
- LoggingTransaction,
- )
- from synapse.storage.databases.main.state import StateFilter
- from synapse.storage.databases.main.state_deltas import StateDeltasStore
- from synapse.storage.engines import PostgresEngine, Sqlite3Engine
- from synapse.types import (
- JsonDict,
- UserID,
- UserProfile,
- get_domain_from_id,
- get_localpart_from_id,
- )
- from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)

# Name prefix for the staging tables used while (re)building the user
# directory via the background updates below.
TEMP_TABLE = "_temp_populate_user_directory"
class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
    """Store providing the background updates that (re)populate the user directory."""

    # How many records do we calculate before sending it to
    # add_users_who_share_private_rooms?
    SHARE_PRIVATE_WORKING_SET = 500

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ) -> None:
        super().__init__(database, db_conn, hs)

        self.server_name: str = hs.hostname

        # Register the four phases of the directory rebuild, in the order
        # they were scheduled: create staging tables, process rooms,
        # process users, then clean up.
        update_handlers = {
            "populate_user_directory_createtables": self._populate_user_directory_createtables,
            "populate_user_directory_process_rooms": self._populate_user_directory_process_rooms,
            "populate_user_directory_process_users": self._populate_user_directory_process_users,
            "populate_user_directory_cleanup": self._populate_user_directory_cleanup,
        }
        for update_name, handler in update_handlers.items():
            self.db_pool.updates.register_background_update_handler(
                update_name, handler
            )
async def _populate_user_directory_createtables(
    self, progress: JsonDict, batch_size: int
) -> int:
    """First phase of the directory rebuild: create and seed staging tables.

    Creates three temporary tables: `_rooms` (each room plus its
    current-state event count), `_users` (every row of `users`) and
    `_position` (the stream position at the time of the snapshot). The
    later `process_rooms`/`process_users` phases drain the first two;
    `cleanup` consumes `_position` and drops all three.

    Args:
        progress: unused — this phase runs exactly once.
        batch_size: unused — all rows are staged in a single transaction.

    Returns:
        1, telling the background updater this phase is complete.
    """

    # Get all the rooms that we want to process.
    def _make_staging_area(txn: LoggingTransaction) -> None:
        sql = (
            "CREATE TABLE IF NOT EXISTS "
            + TEMP_TABLE
            + "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)"
        )
        txn.execute(sql)

        sql = (
            "CREATE TABLE IF NOT EXISTS "
            + TEMP_TABLE
            + "_position(position TEXT NOT NULL)"
        )
        txn.execute(sql)

        # Get rooms we want to process from the database, with the number
        # of current-state events in each (used later to batch by workload).
        sql = """
            SELECT room_id, count(*) FROM current_state_events
            GROUP BY room_id
        """
        txn.execute(sql)
        rooms = list(txn.fetchall())
        self.db_pool.simple_insert_many_txn(
            txn, TEMP_TABLE + "_rooms", keys=("room_id", "events"), values=rooms
        )
        del rooms

        sql = (
            "CREATE TABLE IF NOT EXISTS "
            + TEMP_TABLE
            + "_users(user_id TEXT NOT NULL)"
        )
        txn.execute(sql)

        txn.execute("SELECT name FROM users")
        users = list(txn.fetchall())

        self.db_pool.simple_insert_many_txn(
            txn, TEMP_TABLE + "_users", keys=("user_id",), values=users
        )

    # Capture the stream position *before* building the staging area so no
    # deltas are missed between the snapshot and later incremental updates.
    new_pos = await self.get_max_stream_id_in_current_state_deltas()
    await self.db_pool.runInteraction(
        "populate_user_directory_temp_build", _make_staging_area
    )
    await self.db_pool.simple_insert(
        TEMP_TABLE + "_position", {"position": new_pos}
    )

    await self.db_pool.updates._end_background_update(
        "populate_user_directory_createtables"
    )
    return 1
async def _populate_user_directory_cleanup(
    self,
    progress: JsonDict,
    batch_size: int,
) -> int:
    """
    Update the user directory stream position, then clean up the old tables.
    """
    # Advance the directory stream pointer to the position that was
    # recorded when the staging tables were first built.
    position = await self.db_pool.simple_select_one_onecol(
        TEMP_TABLE + "_position", {}, "position"
    )
    await self.update_user_directory_stream_pos(position)

    def _delete_staging_area(txn: LoggingTransaction) -> None:
        # Drop all three staging tables created by the createtables phase.
        for suffix in ("_rooms", "_users", "_position"):
            txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + suffix)

    await self.db_pool.runInteraction(
        "populate_user_directory_cleanup", _delete_staging_area
    )

    await self.db_pool.updates._end_background_update(
        "populate_user_directory_cleanup"
    )
    return 1
async def _populate_user_directory_process_rooms(
    self, progress: JsonDict, batch_size: int
) -> int:
    """
    Rescan the state of all rooms so we can track
      - who's in a public room;
      - which local users share a private room with other users (local
        and remote); and
      - who should be in the user_directory.

    Args:
        progress
        batch_size: Maximum number of state events to process per cycle.

    Returns:
        number of events processed.
    """
    # If we don't have a progress field, this is the first run: delete everything.
    if not progress:
        await self.delete_all_from_user_dir()

    def _get_next_batch(
        txn: LoggingTransaction,
    ) -> Optional[Sequence[Tuple[str, int]]]:
        # Only fetch 250 rooms, so we don't fetch too many at once, even
        # if those 250 rooms have less than batch_size state events.
        sql = """
            SELECT room_id, events FROM %s
            ORDER BY events DESC
            LIMIT 250
        """ % (
            TEMP_TABLE + "_rooms",
        )
        txn.execute(sql)
        rooms_to_work_on = cast(List[Tuple[str, int]], txn.fetchall())

        if not rooms_to_work_on:
            return None

        # Get how many are left to process, so we can give status on how
        # far we are in processing
        txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
        result = txn.fetchone()
        assert result is not None
        progress["remaining"] = result[0]

        return rooms_to_work_on

    rooms_to_work_on = await self.db_pool.runInteraction(
        "populate_user_directory_temp_read", _get_next_batch
    )

    # No more rooms -- complete the transaction.
    if not rooms_to_work_on:
        await self.db_pool.updates._end_background_update(
            "populate_user_directory_process_rooms"
        )
        return 1

    logger.debug(
        "Processing the next %d rooms of %d remaining"
        % (len(rooms_to_work_on), progress["remaining"])
    )

    processed_event_count = 0

    for room_id, event_count in rooms_to_work_on:
        # Skip rooms this server is no longer joined to.
        is_in_room = await self.is_host_joined(room_id, self.server_name)  # type: ignore[attr-defined]
        if is_in_room:
            users_with_profile = await self.get_users_in_room_with_profiles(room_id)  # type: ignore[attr-defined]
            # Throw away users excluded from the directory.
            users_with_profile = {
                user_id: profile
                for user_id, profile in users_with_profile.items()
                if not self.hs.is_mine_id(user_id)
                or await self.should_include_local_user_in_dir(user_id)
            }

            # Upsert a user_directory record for each remote user we see.
            for user_id, profile in users_with_profile.items():
                # Local users are processed separately in
                # `_populate_user_directory_users`; there we can read from
                # the `profiles` table to ensure we don't leak their per-room
                # profiles. It also means we write local users to this table
                # exactly once, rather than once for every room they're in.
                if self.hs.is_mine_id(user_id):
                    continue
                # TODO `users_with_profile` above reads from the `user_directory`
                # table, meaning that `profile` is bespoke to this room.
                # and this leaks remote users' per-room profiles to the user directory.
                await self.update_profile_in_user_dir(
                    user_id, profile.display_name, profile.avatar_url
                )

            # Now update the room sharing tables to include this room.
            is_public = await self.is_room_world_readable_or_publicly_joinable(
                room_id
            )
            if is_public:
                if users_with_profile:
                    await self.add_users_in_public_rooms(
                        room_id, users_with_profile.keys()
                    )
            else:
                to_insert = set()
                for user_id in users_with_profile:
                    # We want the set of pairs (L, M) where L and M are
                    # in `users_with_profile` and L is local.
                    # Do so by looking for the local user L first.
                    if not self.hs.is_mine_id(user_id):
                        continue
                    for other_user_id in users_with_profile:
                        if user_id == other_user_id:
                            continue
                        user_set = (user_id, other_user_id)
                        to_insert.add(user_set)

                        # If it gets too big, stop and write to the database
                        # to prevent storing too much in RAM.
                        if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET:
                            await self.add_users_who_share_private_room(
                                room_id, to_insert
                            )
                            to_insert.clear()

                if to_insert:
                    await self.add_users_who_share_private_room(room_id, to_insert)
                    to_insert.clear()

        # We've finished a room. Delete it from the table.
        await self.db_pool.simple_delete_one(
            TEMP_TABLE + "_rooms", {"room_id": room_id}
        )
        # Update the remaining counter.
        progress["remaining"] -= 1
        await self.db_pool.runInteraction(
            "populate_user_directory",
            self.db_pool.updates._background_update_progress_txn,
            "populate_user_directory_process_rooms",
            progress,
        )

        processed_event_count += event_count

        if processed_event_count > batch_size:
            # Don't process any more rooms, we've hit our batch size.
            return processed_event_count

    return processed_event_count
async def _populate_user_directory_process_users(
    self, progress: JsonDict, batch_size: int
) -> int:
    """
    Add all local users to the user directory.

    Drains the `_users` staging table in batches of `batch_size`, writing
    a `user_directory` entry (from the canonical `profiles` data) for each
    user that should be included.

    Returns:
        number of users processed this cycle (or 1 once the staging table
        is empty and the update is marked complete).
    """

    def _get_next_batch(txn: LoggingTransaction) -> Optional[List[str]]:
        sql = "SELECT user_id FROM %s LIMIT %s" % (
            TEMP_TABLE + "_users",
            str(batch_size),
        )
        txn.execute(sql)
        user_result = cast(List[Tuple[str]], txn.fetchall())

        if not user_result:
            return None

        users_to_work_on = [x[0] for x in user_result]

        # Get how many are left to process, so we can give status on how
        # far we are in processing
        sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
        txn.execute(sql)
        count_result = txn.fetchone()
        assert count_result is not None
        progress["remaining"] = count_result[0]

        return users_to_work_on

    users_to_work_on = await self.db_pool.runInteraction(
        "populate_user_directory_temp_read", _get_next_batch
    )

    # No more users -- complete the transaction.
    if not users_to_work_on:
        await self.db_pool.updates._end_background_update(
            "populate_user_directory_process_users"
        )
        return 1

    logger.debug(
        "Processing the next %d users of %d remaining"
        % (len(users_to_work_on), progress["remaining"])
    )

    for user_id in users_to_work_on:
        if await self.should_include_local_user_in_dir(user_id):
            profile = await self.get_profileinfo(get_localpart_from_id(user_id))  # type: ignore[attr-defined]
            await self.update_profile_in_user_dir(
                user_id, profile.display_name, profile.avatar_url
            )

        # We've finished processing a user. Delete it from the table.
        await self.db_pool.simple_delete_one(
            TEMP_TABLE + "_users", {"user_id": user_id}
        )
        # Update the remaining counter.
        progress["remaining"] -= 1
        await self.db_pool.runInteraction(
            "populate_user_directory",
            self.db_pool.updates._background_update_progress_txn,
            "populate_user_directory_process_users",
            progress,
        )

    return len(users_to_work_on)
async def should_include_local_user_in_dir(self, user: str) -> bool:
    """Certain classes of local user are omitted from the user directory.
    Is this user one of them?
    """
    # The appservice sender (the user defined by `sender_localpart` in the
    # appservice registration) is excluded even though it is technically
    # DM-able. In future this could become per-appservice configurable.
    if self.get_app_service_by_user_id(user) is not None:  # type: ignore[attr-defined]
        return False

    # Likewise, users within an appservice's user namespace are excluded
    # even though they could technically be DM-able.
    if self.get_if_app_services_interested_in_user(user):  # type: ignore[attr-defined]
        # TODO we might want to make this configurable for each app service
        return False

    # Support users are for diagnostics and should not appear in the user directory.
    if await self.is_support_user(user):  # type: ignore[attr-defined]
        return False

    # Deactivated users aren't contactable, so should not appear in the user directory.
    try:
        deactivated = await self.get_user_deactivated_status(user)  # type: ignore[attr-defined]
    except StoreError:
        # No such user in the users table. No need to do this when calling
        # is_support_user---that returns False if the user is missing.
        return False

    return not deactivated
async def is_room_world_readable_or_publicly_joinable(self, room_id: str) -> bool:
    """Check whether the room is world-readable or publicly joinable."""
    # Only the join rules and history visibility events are needed.
    state_filter = StateFilter.from_types(
        (
            (EventTypes.JoinRules, ""),
            (EventTypes.RoomHistoryVisibility, ""),
        )
    )
    # Getting the partial state is fine, as we're not looking at membership
    # events.
    current_state_ids = await self.get_partial_filtered_current_state_ids(  # type: ignore[attr-defined]
        room_id, state_filter
    )

    join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
    if join_rules_id:
        join_rule_ev = await self.get_event(join_rules_id, allow_none=True)  # type: ignore[attr-defined]
        if (
            join_rule_ev
            and join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC
        ):
            return True

    hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
    if hist_vis_id:
        hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)  # type: ignore[attr-defined]
        if (
            hist_vis_ev
            and hist_vis_ev.content.get("history_visibility")
            == HistoryVisibility.WORLD_READABLE
        ):
            return True

    return False
async def set_remote_user_profile_in_user_dir_stale(
    self, user_id: str, next_try_at_ms: int, retry_counter: int
) -> None:
    """
    Marks a remote user as having a possibly-stale user directory profile.

    Args:
        user_id: the remote user who may have a stale profile on this server.
        next_try_at_ms: timestamp in ms after which the user directory profile can be
            refreshed.
        retry_counter: number of failures in refreshing the profile so far. Used for
            exponential backoff calculations.
    """
    # Only remote users can be marked stale; local profiles are canonical here.
    assert not self.hs.is_mine_id(
        user_id
    ), "Can't mark a local user as a stale remote user."

    server_name = UserID.from_string(user_id).domain

    stale_values = {
        "next_try_at_ts": next_try_at_ms,
        "retry_counter": retry_counter,
        "user_server_name": server_name,
    }
    await self.db_pool.simple_upsert(
        table="user_directory_stale_remote_users",
        keyvalues={"user_id": user_id},
        values=stale_values,
        desc="set_remote_user_profile_in_user_dir_stale",
    )
async def clear_remote_user_profile_in_user_dir_stale(self, user_id: str) -> None:
    """
    Marks a remote user as no longer having a possibly-stale user directory profile.

    Args:
        user_id: the remote user who no longer has a stale profile on this server.
    """
    # Removing the row (if any) stops the refresher retrying this user.
    await self.db_pool.simple_delete(
        keyvalues={"user_id": user_id},
        table="user_directory_stale_remote_users",
        desc="clear_remote_user_profile_in_user_dir_stale",
    )
async def get_remote_servers_with_profiles_to_refresh(
    self, now_ts: int, limit: int
) -> List[str]:
    """
    Get a list of up to `limit` server names which have users whose
    locally-cached profiles we believe to be stale
    and are refreshable given the current time `now_ts` in milliseconds.
    """

    def _refreshable_servers_txn(
        txn: LoggingTransaction,
    ) -> List[str]:
        # Earliest-retryable server first; the name is a stable tie-break.
        sql = """
            SELECT user_server_name
            FROM user_directory_stale_remote_users
            WHERE next_try_at_ts < ?
            GROUP BY user_server_name
            ORDER BY MIN(next_try_at_ts), user_server_name
            LIMIT ?
        """
        txn.execute(sql, (now_ts, limit))
        return [server_name for (server_name,) in txn]

    return await self.db_pool.runInteraction(
        "get_remote_servers_with_profiles_to_refresh",
        _refreshable_servers_txn,
    )
async def get_remote_users_to_refresh_on_server(
    self, server_name: str, now_ts: int, limit: int
) -> List[Tuple[str, int, int]]:
    """
    Get a list of up to `limit` user IDs from the server `server_name`
    whose locally-cached profiles we believe to be stale
    and are refreshable given the current time `now_ts` in milliseconds.

    Returns:
        tuple of:
            - User ID
            - Retry counter (number of failures so far)
            - Time the retry is scheduled for, in milliseconds
    """

    def _refreshable_users_txn(
        txn: LoggingTransaction,
    ) -> List[Tuple[str, int, int]]:
        # Soonest-retryable users first, restricted to the given server.
        sql = """
            SELECT user_id, retry_counter, next_try_at_ts
            FROM user_directory_stale_remote_users
            WHERE user_server_name = ? AND next_try_at_ts < ?
            ORDER BY next_try_at_ts
            LIMIT ?
        """
        txn.execute(sql, (server_name, now_ts, limit))
        return cast(List[Tuple[str, int, int]], txn.fetchall())

    return await self.db_pool.runInteraction(
        "get_remote_users_to_refresh_on_server",
        _refreshable_users_txn,
    )
async def update_profile_in_user_dir(
    self, user_id: str, display_name: Optional[str], avatar_url: Optional[str]
) -> None:
    """
    Update or add a user's profile in the user directory.
    If the user is remote, the profile will be marked as not stale.

    Also refreshes the `user_directory_search` index entry for the user and
    invalidates the `get_user_in_directory` cache.
    """
    # If the display name or avatar URL are unexpected types, replace with None.
    display_name = non_null_str_or_none(display_name)
    avatar_url = non_null_str_or_none(avatar_url)

    def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None:
        self.db_pool.simple_upsert_txn(
            txn,
            table="user_directory",
            keyvalues={"user_id": user_id},
            values={"display_name": display_name, "avatar_url": avatar_url},
        )

        if not self.hs.is_mine_id(user_id):
            # Remote users: Make sure the profile is not marked as stale anymore.
            self.db_pool.simple_delete_txn(
                txn,
                table="user_directory_stale_remote_users",
                keyvalues={"user_id": user_id},
            )

        # The display name that goes into the database index.
        index_display_name = display_name
        if index_display_name is not None:
            index_display_name = _filter_text_for_index(index_display_name)

        if isinstance(self.database_engine, PostgresEngine):
            # We weight the localpart most highly, then display name and finally
            # server name
            sql = """
                INSERT INTO user_directory_search(user_id, vector)
                VALUES (?,
                    setweight(to_tsvector('simple', ?), 'A')
                    || setweight(to_tsvector('simple', ?), 'D')
                    || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
                ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
            """
            txn.execute(
                sql,
                (
                    user_id,
                    get_localpart_from_id(user_id),
                    get_domain_from_id(user_id),
                    index_display_name,
                ),
            )
        elif isinstance(self.database_engine, Sqlite3Engine):
            # SQLite has no weighted tsvector; index a plain concatenation
            # of user ID and (filtered) display name instead.
            value = (
                "%s %s" % (user_id, index_display_name)
                if index_display_name
                else user_id
            )
            self.db_pool.simple_upsert_txn(
                txn,
                table="user_directory_search",
                keyvalues={"user_id": user_id},
                values={"value": value},
            )
        else:
            # This should be unreachable.
            raise Exception("Unrecognized database engine")

        txn.call_after(self.get_user_in_directory.invalidate, (user_id,))

    await self.db_pool.runInteraction(
        "update_profile_in_user_dir", _update_profile_in_user_dir_txn
    )
async def add_users_who_share_private_room(
    self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
) -> None:
    """Insert entries into the users_who_share_private_rooms table. The first
    user should be a local user.

    Args:
        room_id
        user_id_tuples: iterable of 2-tuple of user IDs.
    """
    # Every column is part of the key; there are no non-key values.
    rows = [
        (local_user_id, other_user_id, room_id)
        for local_user_id, other_user_id in user_id_tuples
    ]
    await self.db_pool.simple_upsert_many(
        table="users_who_share_private_rooms",
        key_names=["user_id", "other_user_id", "room_id"],
        key_values=rows,
        value_names=(),
        value_values=(),
        desc="add_users_who_share_room",
    )
async def add_users_in_public_rooms(
    self, room_id: str, user_ids: Iterable[str]
) -> None:
    """Insert entries into the users_in_public_rooms table.

    Args:
        room_id
        user_ids
    """
    # Every column is part of the key; there are no non-key values.
    rows = [(member_id, room_id) for member_id in user_ids]
    await self.db_pool.simple_upsert_many(
        table="users_in_public_rooms",
        key_names=["user_id", "room_id"],
        key_values=rows,
        value_names=(),
        value_values=(),
        desc="add_users_in_public_rooms",
    )
async def delete_all_from_user_dir(self) -> None:
    """Delete the entire user directory"""

    def _delete_all_from_user_dir_txn(txn: LoggingTransaction) -> None:
        # SQLite doesn't support TRUNCATE; on Postgres TRUNCATE avoids the
        # table scan that DELETE FROM would do.
        if isinstance(self.database_engine, Sqlite3Engine):
            truncate = "DELETE FROM"
        else:
            truncate = "TRUNCATE"

        for table in (
            "user_directory",
            "user_directory_search",
            "users_in_public_rooms",
            "users_who_share_private_rooms",
        ):
            txn.execute(f"{truncate} {table}")

        txn.call_after(self.get_user_in_directory.invalidate_all)

    await self.db_pool.runInteraction(
        "delete_all_from_user_dir", _delete_all_from_user_dir_txn
    )
@cached()
async def get_user_in_directory(self, user_id: str) -> Optional[Mapping[str, str]]:
    """Fetch a user's directory entry, or None if they have none.

    Returns:
        A mapping with `display_name` and `avatar_url` keys, or None.

    Cached; the writers in this module (e.g. `update_profile_in_user_dir`)
    invalidate it via `txn.call_after`.
    """
    return await self.db_pool.simple_select_one(
        table="user_directory",
        keyvalues={"user_id": user_id},
        retcols=("display_name", "avatar_url"),
        allow_none=True,
        desc="get_user_in_directory",
    )
async def update_user_directory_stream_pos(self, stream_id: Optional[int]) -> None:
    """Record the stream position the user directory has caught up to."""
    # Empty keyvalues: presumably user_directory_stream_pos holds a single
    # row, which simple_update_one targets — TODO confirm schema.
    await self.db_pool.simple_update_one(
        keyvalues={},
        table="user_directory_stream_pos",
        updatevalues={"stream_id": stream_id},
        desc="update_user_directory_stream_pos",
    )
class SearchResult(TypedDict):
    """Shape of the value returned by `UserDirectoryStore.search_user_dir`."""

    # Whether more matches existed than the requested limit.
    limited: bool
    # Matching profiles, ordered best match first.
    results: List[UserProfile]
class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
    """Read/write store for the populated user directory."""

    # How many records do we calculate before sending it to
    # add_users_who_share_private_rooms?
    SHARE_PRIVATE_WORKING_SET = 500

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ) -> None:
        super().__init__(database, db_conn, hs)

        # Cache the search-related config options used by search_user_dir.
        userdirectory_config = hs.config.userdirectory
        self._prefer_local_users_in_search = (
            userdirectory_config.user_directory_search_prefer_local_users
        )
        self._server_name = hs.config.server.server_name
async def remove_from_user_dir(self, user_id: str) -> None:
    """Remove a user from the directory and all derived search/sharing tables."""

    def _remove_from_user_dir_txn(txn: LoggingTransaction) -> None:
        # The user may appear as either half of a private-room pairing, so
        # users_who_share_private_rooms is purged on both columns.
        deletions = (
            ("user_directory", {"user_id": user_id}),
            ("user_directory_search", {"user_id": user_id}),
            ("users_in_public_rooms", {"user_id": user_id}),
            ("users_who_share_private_rooms", {"user_id": user_id}),
            ("users_who_share_private_rooms", {"other_user_id": user_id}),
        )
        for table, keyvalues in deletions:
            self.db_pool.simple_delete_txn(txn, table=table, keyvalues=keyvalues)

        txn.call_after(self.get_user_in_directory.invalidate, (user_id,))

    await self.db_pool.runInteraction(
        "remove_from_user_dir", _remove_from_user_dir_txn
    )
async def get_users_in_dir_due_to_room(self, room_id: str) -> Set[str]:
    """Get all user_ids that are in the room directory because they're
    in the given room_id
    """
    # Members visible via the room being public...
    public_members = await self.db_pool.simple_select_onecol(
        table="users_in_public_rooms",
        keyvalues={"room_id": room_id},
        retcol="user_id",
        desc="get_users_in_dir_due_to_room",
    )

    # ...and users visible because they share this private room with a
    # local user.
    private_members = await self.db_pool.simple_select_onecol(
        table="users_who_share_private_rooms",
        keyvalues={"room_id": room_id},
        retcol="other_user_id",
        desc="get_users_in_dir_due_to_room",
    )

    return set(public_members) | set(private_members)
async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
    """
    Deletes entries in the users_who_share_*_rooms table. The first
    user should be a local user.

    Args:
        user_id
        room_id
    """

    def _remove_user_who_share_room_txn(txn: LoggingTransaction) -> None:
        # The user may be recorded as either half of a private-room pair,
        # and separately as a member of the public room.
        targets = (
            ("users_who_share_private_rooms", "user_id"),
            ("users_who_share_private_rooms", "other_user_id"),
            ("users_in_public_rooms", "user_id"),
        )
        for table, user_column in targets:
            self.db_pool.simple_delete_txn(
                txn,
                table=table,
                keyvalues={user_column: user_id, "room_id": room_id},
            )

    await self.db_pool.runInteraction(
        "remove_user_who_share_room", _remove_user_who_share_room_txn
    )
async def get_user_dir_rooms_user_is_in(self, user_id: str) -> List[str]:
    """
    Returns the rooms that a user is in.

    Args:
        user_id: Must be a local user

    Returns:
        List of room IDs
    """
    private_room_ids = await self.db_pool.simple_select_onecol(
        table="users_who_share_private_rooms",
        keyvalues={"user_id": user_id},
        retcol="room_id",
        desc="get_rooms_user_is_in",
    )

    public_room_ids = await self.db_pool.simple_select_onecol(
        table="users_in_public_rooms",
        keyvalues={"user_id": user_id},
        retcol="room_id",
        desc="get_rooms_user_is_in",
    )

    # De-duplicate across the two tables before returning.
    return list(set(public_room_ids) | set(private_room_ids))
async def get_user_directory_stream_pos(self) -> Optional[int]:
    """
    Get the stream ID of the user directory stream.

    Returns:
        The stream token or None if the initial background update hasn't happened yet.
    """
    stream_id = await self.db_pool.simple_select_one_onecol(
        table="user_directory_stream_pos",
        keyvalues={},
        retcol="stream_id",
        desc="get_user_directory_stream_pos",
    )
    return stream_id
async def search_user_dir(
    self, user_id: str, search_term: str, limit: int
) -> SearchResult:
    """Searches for users in directory

    Args:
        user_id: the user performing the search. When
            `user_directory_search_all_users` is enabled only this user is
            excluded from results; otherwise results are limited to users
            in a public room or sharing a private room with the searcher.
        search_term: free-text search query.
        limit: maximum number of results to return.

    Returns:
        dict of the form::

            {
                "limited": <bool>,  # whether there were more results or not
                "results": [  # Ordered by best match first
                    {
                        "user_id": <user_id>,
                        "display_name": <display_name>,
                        "avatar_url": <avatar_url>
                    }
                ]
            }
    """
    if self.hs.config.userdirectory.user_directory_search_all_users:
        join_args = (user_id,)
        where_clause = "user_id != ?"
    else:
        join_args = (user_id,)
        where_clause = """
            (
                EXISTS (select 1 from users_in_public_rooms WHERE user_id = t.user_id)
                OR EXISTS (
                    SELECT 1 FROM users_who_share_private_rooms
                    WHERE user_id = ? AND other_user_id = t.user_id
                )
            )
        """

    # We allow manipulating the ranking algorithm by injecting statements
    # based on config options.
    additional_ordering_statements = []
    ordering_arguments: Tuple[str, ...] = ()

    if isinstance(self.database_engine, PostgresEngine):
        full_query, exact_query, prefix_query = _parse_query_postgres(search_term)

        # If enabled, this config option will rank local users higher than those on
        # remote instances.
        if self._prefer_local_users_in_search:
            # This statement checks whether a given user's user ID contains a server name
            # that matches the local server
            statement = "* (CASE WHEN user_id LIKE ? THEN 2.0 ELSE 1.0 END)"
            additional_ordering_statements.append(statement)

            ordering_arguments += ("%:" + self._server_name,)

        # We order by rank and then if they have profile info
        # The ranking algorithm is hand tweaked for "best" results. Broadly
        # the idea is we give a higher weight to exact matches.
        # The array of numbers are the weights for the various part of the
        # search: (domain, _, display name, localpart)
        sql = """
            SELECT d.user_id AS user_id, display_name, avatar_url
            FROM user_directory_search as t
            INNER JOIN user_directory AS d USING (user_id)
            WHERE
                %(where_clause)s
                AND vector @@ to_tsquery('simple', ?)
            ORDER BY
                (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
                * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
                * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END)
                * (
                    3 * ts_rank_cd(
                        '{0.1, 0.1, 0.9, 1.0}',
                        vector,
                        to_tsquery('simple', ?),
                        8
                    )
                    + ts_rank_cd(
                        '{0.1, 0.1, 0.9, 1.0}',
                        vector,
                        to_tsquery('simple', ?),
                        8
                    )
                )
                %(order_case_statements)s
                DESC,
                display_name IS NULL,
                avatar_url IS NULL
            LIMIT ?
        """ % {
            "where_clause": where_clause,
            "order_case_statements": " ".join(additional_ordering_statements),
        }
        args = (
            join_args
            + (full_query, exact_query, prefix_query)
            + ordering_arguments
            + (limit + 1,)
        )
    elif isinstance(self.database_engine, Sqlite3Engine):
        search_query = _parse_query_sqlite(search_term)

        # If enabled, this config option will rank local users higher than those on
        # remote instances.
        if self._prefer_local_users_in_search:
            # This statement checks whether a given user's user ID contains a server name
            # that matches the local server
            #
            # Note that we need to include a comma at the end for valid SQL
            statement = "user_id LIKE ? DESC,"
            additional_ordering_statements.append(statement)

            ordering_arguments += ("%:" + self._server_name,)

        sql = """
            SELECT d.user_id AS user_id, display_name, avatar_url
            FROM user_directory_search as t
            INNER JOIN user_directory AS d USING (user_id)
            WHERE
                %(where_clause)s
                AND value MATCH ?
            ORDER BY
                rank(matchinfo(user_directory_search)) DESC,
                %(order_statements)s
                display_name IS NULL,
                avatar_url IS NULL
            LIMIT ?
        """ % {
            "where_clause": where_clause,
            "order_statements": " ".join(additional_ordering_statements),
        }
        args = join_args + (search_query,) + ordering_arguments + (limit + 1,)
    else:
        # This should be unreachable.
        raise Exception("Unrecognized database engine")

    results = cast(
        List[UserProfile],
        await self.db_pool.execute(
            "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
        ),
    )

    # We fetched limit + 1 rows; a surplus row means more matches exist.
    limited = len(results) > limit

    return {"limited": limited, "results": results[0:limit]}
- def _filter_text_for_index(text: str) -> str:
- """Transforms text before it is inserted into the user directory index, or searched
- for in the user directory index.
- Note that the user directory search table needs to be rebuilt whenever this function
- changes.
- """
- # Lowercase the text, to make searches case-insensitive.
- # This is necessary for both PostgreSQL and SQLite. PostgreSQL's
- # `to_tsquery/to_tsvector` functions don't lowercase non-ASCII characters when using
- # the "C" collation, while SQLite just doesn't lowercase non-ASCII characters at
- # all.
- text = text.lower()
- # Normalize the text. NFKC normalization has two effects:
- # 1. It canonicalizes the text, ie. maps all visually identical strings to the same
- # string. For example, ["e", "◌́"] is mapped to ["é"].
- # 2. It maps strings that are roughly equivalent to the same string.
- # For example, ["dž"] is mapped to ["d", "ž"], ["①"] to ["1"] and ["i⁹"] to
- # ["i", "9"].
- text = unicodedata.normalize("NFKC", text)
- # Note that nothing is done to make searches accent-insensitive.
- # That could be achieved by converting to NFKD form instead (with combining accents
- # split out) and filtering out combining accents using `unicodedata.combining(c)`.
- # The downside of this may be noisier search results, since search terms with
- # explicit accents will match characters with no accents, or completely different
- # accents.
- #
- # text = unicodedata.normalize("NFKD", text)
- # text = "".join([c for c in text if not unicodedata.combining(c)])
- return text
def _parse_query_sqlite(search_term: str) -> str:
    """Takes a plain unicode string from the user and converts it into a form
    that can be passed to the database.

    We use this so that we can add prefix matching, which isn't something that
    is supported by default.

    We specifically add both a prefix and a non-prefix matching term so that
    exact matches get ranked higher.
    """
    filtered = _filter_text_for_index(search_term)

    # Pull out the individual words, discarding any non-word characters.
    words = _parse_words(filtered)

    # One "(word* OR word)" clause per word, AND-ed together.
    clauses = [f"({word}* OR {word})" for word in words]
    return " & ".join(clauses)
def _parse_query_postgres(search_term: str) -> Tuple[str, str, str]:
    """Takes a plain unicode string from the user and converts it into a form
    that can be passed to the database.

    We use this so that we can add prefix matching, which isn't something that
    is supported by default.

    Returns a 3-tuple of tsquery strings: (prefix-or-exact, exact-only,
    prefix-only).
    """
    filtered = _filter_text_for_index(search_term)

    escaped_words = []
    for word in _parse_words(filtered):
        # Postgres tsvector and tsquery quoting rules:
        # words potentially containing punctuation should be quoted
        # and then existing quotes and backslashes should be doubled
        # See: https://www.postgresql.org/docs/current/datatype-textsearch.html#DATATYPE-TSQUERY
        doubled = word.replace("'", "''").replace("\\", "\\\\")
        escaped_words.append("'%s'" % (doubled,))

    # Exact matches are OR-ed with prefix matches per word so they rank higher.
    both = " & ".join(f"({w}:* | {w})" for w in escaped_words)
    exact = " & ".join(escaped_words)
    prefix = " & ".join(f"{w}:*" for w in escaped_words)

    return both, exact, prefix
def _parse_words(search_term: str) -> List[str]:
    """Split the provided search string into a list of its words.

    If support for ICU (International Components for Unicode) is available,
    use it. Otherwise, fall back to using a regex to detect word boundaries.
    This latter solution works well enough for most latin-based languages, but
    doesn't work as well with other languages.

    Args:
        search_term: The search string.

    Returns:
        A list of the words in the search string.
    """
    # Prefer the ICU tokenizer whenever the optional dependency is available.
    tokenize = _parse_words_with_icu if USE_ICU else _parse_words_with_regex
    return tokenize(search_term)
- def _parse_words_with_regex(search_term: str) -> List[str]:
- """
- Break down search term into words, when we don't have ICU available.
- See: `_parse_words`
- """
- return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
def _parse_words_with_icu(search_term: str) -> List[str]:
    """Break down the provided search string into its individual words using ICU
    (International Components for Unicode).

    Args:
        search_term: The search string.

    Returns:
        A list of the words in the search string.
    """
    results = []
    breaker = icu.BreakIterator.createWordInstance(icu.Locale.getDefault())
    breaker.setText(search_term)

    # Compile the word-detection pattern once, outside the loop, and use
    # `search` rather than `len(findall(...))`: we only need to know whether a
    # word character is present, not collect every match.
    word_pattern = re.compile(r"[\w\-]+", re.UNICODE)

    i = 0
    while True:
        j = breaker.nextBoundary()
        if j < 0:
            break

        result = search_term[i:j]
        i = j

        # libicu considers spaces and punctuation between words as words, but we don't
        # want to include those in results as they would result in syntax errors in SQL
        # queries (e.g. "foo bar" would result in the search query including "foo & &
        # bar").
        if word_pattern.search(result):
            results.append(result)

    return results
|