user_directory.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2017 Vector Creations Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. import re
  17. from typing import Any, Dict, Iterable, Optional, Set, Tuple
  18. from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules
  19. from synapse.storage.database import DatabasePool
  20. from synapse.storage.databases.main.state import StateFilter
  21. from synapse.storage.databases.main.state_deltas import StateDeltasStore
  22. from synapse.storage.engines import PostgresEngine, Sqlite3Engine
  23. from synapse.types import get_domain_from_id, get_localpart_from_id
  24. from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)

# Name prefix for the staging tables ("_rooms", "_users", "_position") that the
# user directory background updates below create, read from and finally drop.
TEMP_TABLE = "_temp_populate_user_directory"
  27. class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
  28. # How many records do we calculate before sending it to
  29. # add_users_who_share_private_rooms?
  30. SHARE_PRIVATE_WORKING_SET = 500
  31. def __init__(self, database: DatabasePool, db_conn, hs):
  32. super().__init__(database, db_conn, hs)
  33. self.server_name = hs.hostname
  34. self.db_pool.updates.register_background_update_handler(
  35. "populate_user_directory_createtables",
  36. self._populate_user_directory_createtables,
  37. )
  38. self.db_pool.updates.register_background_update_handler(
  39. "populate_user_directory_process_rooms",
  40. self._populate_user_directory_process_rooms,
  41. )
  42. self.db_pool.updates.register_background_update_handler(
  43. "populate_user_directory_process_users",
  44. self._populate_user_directory_process_users,
  45. )
  46. self.db_pool.updates.register_background_update_handler(
  47. "populate_user_directory_cleanup", self._populate_user_directory_cleanup
  48. )
  49. async def _populate_user_directory_createtables(self, progress, batch_size):
  50. # Get all the rooms that we want to process.
  51. def _make_staging_area(txn):
  52. sql = (
  53. "CREATE TABLE IF NOT EXISTS "
  54. + TEMP_TABLE
  55. + "_rooms(room_id TEXT NOT NULL, events BIGINT NOT NULL)"
  56. )
  57. txn.execute(sql)
  58. sql = (
  59. "CREATE TABLE IF NOT EXISTS "
  60. + TEMP_TABLE
  61. + "_position(position TEXT NOT NULL)"
  62. )
  63. txn.execute(sql)
  64. # Get rooms we want to process from the database
  65. sql = """
  66. SELECT room_id, count(*) FROM current_state_events
  67. GROUP BY room_id
  68. """
  69. txn.execute(sql)
  70. rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
  71. self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
  72. del rooms
  73. # If search all users is on, get all the users we want to add.
  74. if self.hs.config.user_directory_search_all_users:
  75. sql = (
  76. "CREATE TABLE IF NOT EXISTS "
  77. + TEMP_TABLE
  78. + "_users(user_id TEXT NOT NULL)"
  79. )
  80. txn.execute(sql)
  81. txn.execute("SELECT name FROM users")
  82. users = [{"user_id": x[0]} for x in txn.fetchall()]
  83. self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
  84. new_pos = await self.get_max_stream_id_in_current_state_deltas()
  85. await self.db_pool.runInteraction(
  86. "populate_user_directory_temp_build", _make_staging_area
  87. )
  88. await self.db_pool.simple_insert(
  89. TEMP_TABLE + "_position", {"position": new_pos}
  90. )
  91. await self.db_pool.updates._end_background_update(
  92. "populate_user_directory_createtables"
  93. )
  94. return 1
  95. async def _populate_user_directory_cleanup(self, progress, batch_size):
  96. """
  97. Update the user directory stream position, then clean up the old tables.
  98. """
  99. position = await self.db_pool.simple_select_one_onecol(
  100. TEMP_TABLE + "_position", None, "position"
  101. )
  102. await self.update_user_directory_stream_pos(position)
  103. def _delete_staging_area(txn):
  104. txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
  105. txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
  106. txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
  107. await self.db_pool.runInteraction(
  108. "populate_user_directory_cleanup", _delete_staging_area
  109. )
  110. await self.db_pool.updates._end_background_update(
  111. "populate_user_directory_cleanup"
  112. )
  113. return 1
  114. async def _populate_user_directory_process_rooms(self, progress, batch_size):
  115. """
  116. Args:
  117. progress (dict)
  118. batch_size (int): Maximum number of state events to process
  119. per cycle.
  120. """
  121. state = self.hs.get_state_handler()
  122. # If we don't have progress filed, delete everything.
  123. if not progress:
  124. await self.delete_all_from_user_dir()
  125. def _get_next_batch(txn):
  126. # Only fetch 250 rooms, so we don't fetch too many at once, even
  127. # if those 250 rooms have less than batch_size state events.
  128. sql = """
  129. SELECT room_id, events FROM %s
  130. ORDER BY events DESC
  131. LIMIT 250
  132. """ % (
  133. TEMP_TABLE + "_rooms",
  134. )
  135. txn.execute(sql)
  136. rooms_to_work_on = txn.fetchall()
  137. if not rooms_to_work_on:
  138. return None
  139. # Get how many are left to process, so we can give status on how
  140. # far we are in processing
  141. txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
  142. progress["remaining"] = txn.fetchone()[0]
  143. return rooms_to_work_on
  144. rooms_to_work_on = await self.db_pool.runInteraction(
  145. "populate_user_directory_temp_read", _get_next_batch
  146. )
  147. # No more rooms -- complete the transaction.
  148. if not rooms_to_work_on:
  149. await self.db_pool.updates._end_background_update(
  150. "populate_user_directory_process_rooms"
  151. )
  152. return 1
  153. logger.debug(
  154. "Processing the next %d rooms of %d remaining"
  155. % (len(rooms_to_work_on), progress["remaining"])
  156. )
  157. processed_event_count = 0
  158. for room_id, event_count in rooms_to_work_on:
  159. is_in_room = await self.is_host_joined(room_id, self.server_name)
  160. if is_in_room:
  161. is_public = await self.is_room_world_readable_or_publicly_joinable(
  162. room_id
  163. )
  164. users_with_profile = await state.get_current_users_in_room(room_id)
  165. user_ids = set(users_with_profile)
  166. # Update each user in the user directory.
  167. for user_id, profile in users_with_profile.items():
  168. await self.update_profile_in_user_dir(
  169. user_id, profile.display_name, profile.avatar_url
  170. )
  171. to_insert = set()
  172. if is_public:
  173. for user_id in user_ids:
  174. if self.get_if_app_services_interested_in_user(user_id):
  175. continue
  176. to_insert.add(user_id)
  177. if to_insert:
  178. await self.add_users_in_public_rooms(room_id, to_insert)
  179. to_insert.clear()
  180. else:
  181. for user_id in user_ids:
  182. if not self.hs.is_mine_id(user_id):
  183. continue
  184. if self.get_if_app_services_interested_in_user(user_id):
  185. continue
  186. for other_user_id in user_ids:
  187. if user_id == other_user_id:
  188. continue
  189. user_set = (user_id, other_user_id)
  190. to_insert.add(user_set)
  191. # If it gets too big, stop and write to the database
  192. # to prevent storing too much in RAM.
  193. if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET:
  194. await self.add_users_who_share_private_room(
  195. room_id, to_insert
  196. )
  197. to_insert.clear()
  198. if to_insert:
  199. await self.add_users_who_share_private_room(room_id, to_insert)
  200. to_insert.clear()
  201. # We've finished a room. Delete it from the table.
  202. await self.db_pool.simple_delete_one(
  203. TEMP_TABLE + "_rooms", {"room_id": room_id}
  204. )
  205. # Update the remaining counter.
  206. progress["remaining"] -= 1
  207. await self.db_pool.runInteraction(
  208. "populate_user_directory",
  209. self.db_pool.updates._background_update_progress_txn,
  210. "populate_user_directory_process_rooms",
  211. progress,
  212. )
  213. processed_event_count += event_count
  214. if processed_event_count > batch_size:
  215. # Don't process any more rooms, we've hit our batch size.
  216. return processed_event_count
  217. return processed_event_count
  218. async def _populate_user_directory_process_users(self, progress, batch_size):
  219. """
  220. If search_all_users is enabled, add all of the users to the user directory.
  221. """
  222. if not self.hs.config.user_directory_search_all_users:
  223. await self.db_pool.updates._end_background_update(
  224. "populate_user_directory_process_users"
  225. )
  226. return 1
  227. def _get_next_batch(txn):
  228. sql = "SELECT user_id FROM %s LIMIT %s" % (
  229. TEMP_TABLE + "_users",
  230. str(batch_size),
  231. )
  232. txn.execute(sql)
  233. users_to_work_on = txn.fetchall()
  234. if not users_to_work_on:
  235. return None
  236. users_to_work_on = [x[0] for x in users_to_work_on]
  237. # Get how many are left to process, so we can give status on how
  238. # far we are in processing
  239. sql = "SELECT COUNT(*) FROM " + TEMP_TABLE + "_users"
  240. txn.execute(sql)
  241. progress["remaining"] = txn.fetchone()[0]
  242. return users_to_work_on
  243. users_to_work_on = await self.db_pool.runInteraction(
  244. "populate_user_directory_temp_read", _get_next_batch
  245. )
  246. # No more users -- complete the transaction.
  247. if not users_to_work_on:
  248. await self.db_pool.updates._end_background_update(
  249. "populate_user_directory_process_users"
  250. )
  251. return 1
  252. logger.debug(
  253. "Processing the next %d users of %d remaining"
  254. % (len(users_to_work_on), progress["remaining"])
  255. )
  256. for user_id in users_to_work_on:
  257. profile = await self.get_profileinfo(get_localpart_from_id(user_id))
  258. await self.update_profile_in_user_dir(
  259. user_id, profile.display_name, profile.avatar_url
  260. )
  261. # We've finished processing a user. Delete it from the table.
  262. await self.db_pool.simple_delete_one(
  263. TEMP_TABLE + "_users", {"user_id": user_id}
  264. )
  265. # Update the remaining counter.
  266. progress["remaining"] -= 1
  267. await self.db_pool.runInteraction(
  268. "populate_user_directory",
  269. self.db_pool.updates._background_update_progress_txn,
  270. "populate_user_directory_process_users",
  271. progress,
  272. )
  273. return len(users_to_work_on)
  274. async def is_room_world_readable_or_publicly_joinable(self, room_id):
  275. """Check if the room is either world_readable or publically joinable"""
  276. # Create a state filter that only queries join and history state event
  277. types_to_filter = (
  278. (EventTypes.JoinRules, ""),
  279. (EventTypes.RoomHistoryVisibility, ""),
  280. )
  281. current_state_ids = await self.get_filtered_current_state_ids(
  282. room_id, StateFilter.from_types(types_to_filter)
  283. )
  284. join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
  285. if join_rules_id:
  286. join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
  287. if join_rule_ev:
  288. if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
  289. return True
  290. hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
  291. if hist_vis_id:
  292. hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
  293. if hist_vis_ev:
  294. if (
  295. hist_vis_ev.content.get("history_visibility")
  296. == HistoryVisibility.WORLD_READABLE
  297. ):
  298. return True
  299. return False
  300. async def update_profile_in_user_dir(
  301. self, user_id: str, display_name: str, avatar_url: str
  302. ) -> None:
  303. """
  304. Update or add a user's profile in the user directory.
  305. """
  306. # If the display name or avatar URL are unexpected types, overwrite them.
  307. if not isinstance(display_name, str):
  308. display_name = None
  309. if not isinstance(avatar_url, str):
  310. avatar_url = None
  311. def _update_profile_in_user_dir_txn(txn):
  312. new_entry = self.db_pool.simple_upsert_txn(
  313. txn,
  314. table="user_directory",
  315. keyvalues={"user_id": user_id},
  316. values={"display_name": display_name, "avatar_url": avatar_url},
  317. lock=False, # We're only inserter
  318. )
  319. if isinstance(self.database_engine, PostgresEngine):
  320. # We weight the localpart most highly, then display name and finally
  321. # server name
  322. if self.database_engine.can_native_upsert:
  323. sql = """
  324. INSERT INTO user_directory_search(user_id, vector)
  325. VALUES (?,
  326. setweight(to_tsvector('simple', ?), 'A')
  327. || setweight(to_tsvector('simple', ?), 'D')
  328. || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
  329. ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
  330. """
  331. txn.execute(
  332. sql,
  333. (
  334. user_id,
  335. get_localpart_from_id(user_id),
  336. get_domain_from_id(user_id),
  337. display_name,
  338. ),
  339. )
  340. else:
  341. # TODO: Remove this code after we've bumped the minimum version
  342. # of postgres to always support upserts, so we can get rid of
  343. # `new_entry` usage
  344. if new_entry is True:
  345. sql = """
  346. INSERT INTO user_directory_search(user_id, vector)
  347. VALUES (?,
  348. setweight(to_tsvector('simple', ?), 'A')
  349. || setweight(to_tsvector('simple', ?), 'D')
  350. || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
  351. )
  352. """
  353. txn.execute(
  354. sql,
  355. (
  356. user_id,
  357. get_localpart_from_id(user_id),
  358. get_domain_from_id(user_id),
  359. display_name,
  360. ),
  361. )
  362. elif new_entry is False:
  363. sql = """
  364. UPDATE user_directory_search
  365. SET vector = setweight(to_tsvector('simple', ?), 'A')
  366. || setweight(to_tsvector('simple', ?), 'D')
  367. || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
  368. WHERE user_id = ?
  369. """
  370. txn.execute(
  371. sql,
  372. (
  373. get_localpart_from_id(user_id),
  374. get_domain_from_id(user_id),
  375. display_name,
  376. user_id,
  377. ),
  378. )
  379. else:
  380. raise RuntimeError(
  381. "upsert returned None when 'can_native_upsert' is False"
  382. )
  383. elif isinstance(self.database_engine, Sqlite3Engine):
  384. value = "%s %s" % (user_id, display_name) if display_name else user_id
  385. self.db_pool.simple_upsert_txn(
  386. txn,
  387. table="user_directory_search",
  388. keyvalues={"user_id": user_id},
  389. values={"value": value},
  390. lock=False, # We're only inserter
  391. )
  392. else:
  393. # This should be unreachable.
  394. raise Exception("Unrecognized database engine")
  395. txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
  396. await self.db_pool.runInteraction(
  397. "update_profile_in_user_dir", _update_profile_in_user_dir_txn
  398. )
  399. async def add_users_who_share_private_room(
  400. self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]
  401. ) -> None:
  402. """Insert entries into the users_who_share_private_rooms table. The first
  403. user should be a local user.
  404. Args:
  405. room_id
  406. user_id_tuples: iterable of 2-tuple of user IDs.
  407. """
  408. await self.db_pool.simple_upsert_many(
  409. table="users_who_share_private_rooms",
  410. key_names=["user_id", "other_user_id", "room_id"],
  411. key_values=[
  412. (user_id, other_user_id, room_id)
  413. for user_id, other_user_id in user_id_tuples
  414. ],
  415. value_names=(),
  416. value_values=None,
  417. desc="add_users_who_share_room",
  418. )
  419. async def add_users_in_public_rooms(
  420. self, room_id: str, user_ids: Iterable[str]
  421. ) -> None:
  422. """Insert entries into the users_in_public_rooms table.
  423. Args:
  424. room_id
  425. user_ids
  426. """
  427. await self.db_pool.simple_upsert_many(
  428. table="users_in_public_rooms",
  429. key_names=["user_id", "room_id"],
  430. key_values=[(user_id, room_id) for user_id in user_ids],
  431. value_names=(),
  432. value_values=None,
  433. desc="add_users_in_public_rooms",
  434. )
  435. async def delete_all_from_user_dir(self) -> None:
  436. """Delete the entire user directory"""
  437. def _delete_all_from_user_dir_txn(txn):
  438. txn.execute("DELETE FROM user_directory")
  439. txn.execute("DELETE FROM user_directory_search")
  440. txn.execute("DELETE FROM users_in_public_rooms")
  441. txn.execute("DELETE FROM users_who_share_private_rooms")
  442. txn.call_after(self.get_user_in_directory.invalidate_all)
  443. await self.db_pool.runInteraction(
  444. "delete_all_from_user_dir", _delete_all_from_user_dir_txn
  445. )
  446. @cached()
  447. async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
  448. return await self.db_pool.simple_select_one(
  449. table="user_directory",
  450. keyvalues={"user_id": user_id},
  451. retcols=("display_name", "avatar_url"),
  452. allow_none=True,
  453. desc="get_user_in_directory",
  454. )
  455. async def update_user_directory_stream_pos(self, stream_id: int) -> None:
  456. await self.db_pool.simple_update_one(
  457. table="user_directory_stream_pos",
  458. keyvalues={},
  459. updatevalues={"stream_id": stream_id},
  460. desc="update_user_directory_stream_pos",
  461. )
  462. class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
  463. # How many records do we calculate before sending it to
  464. # add_users_who_share_private_rooms?
  465. SHARE_PRIVATE_WORKING_SET = 500
  466. def __init__(self, database: DatabasePool, db_conn, hs):
  467. super().__init__(database, db_conn, hs)
  468. self._prefer_local_users_in_search = (
  469. hs.config.user_directory_search_prefer_local_users
  470. )
  471. self._server_name = hs.config.server_name
  472. async def remove_from_user_dir(self, user_id: str) -> None:
  473. def _remove_from_user_dir_txn(txn):
  474. self.db_pool.simple_delete_txn(
  475. txn, table="user_directory", keyvalues={"user_id": user_id}
  476. )
  477. self.db_pool.simple_delete_txn(
  478. txn, table="user_directory_search", keyvalues={"user_id": user_id}
  479. )
  480. self.db_pool.simple_delete_txn(
  481. txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
  482. )
  483. self.db_pool.simple_delete_txn(
  484. txn,
  485. table="users_who_share_private_rooms",
  486. keyvalues={"user_id": user_id},
  487. )
  488. self.db_pool.simple_delete_txn(
  489. txn,
  490. table="users_who_share_private_rooms",
  491. keyvalues={"other_user_id": user_id},
  492. )
  493. txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
  494. await self.db_pool.runInteraction(
  495. "remove_from_user_dir", _remove_from_user_dir_txn
  496. )
  497. async def get_users_in_dir_due_to_room(self, room_id):
  498. """Get all user_ids that are in the room directory because they're
  499. in the given room_id
  500. """
  501. user_ids_share_pub = await self.db_pool.simple_select_onecol(
  502. table="users_in_public_rooms",
  503. keyvalues={"room_id": room_id},
  504. retcol="user_id",
  505. desc="get_users_in_dir_due_to_room",
  506. )
  507. user_ids_share_priv = await self.db_pool.simple_select_onecol(
  508. table="users_who_share_private_rooms",
  509. keyvalues={"room_id": room_id},
  510. retcol="other_user_id",
  511. desc="get_users_in_dir_due_to_room",
  512. )
  513. user_ids = set(user_ids_share_pub)
  514. user_ids.update(user_ids_share_priv)
  515. return user_ids
  516. async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None:
  517. """
  518. Deletes entries in the users_who_share_*_rooms table. The first
  519. user should be a local user.
  520. Args:
  521. user_id
  522. room_id
  523. """
  524. def _remove_user_who_share_room_txn(txn):
  525. self.db_pool.simple_delete_txn(
  526. txn,
  527. table="users_who_share_private_rooms",
  528. keyvalues={"user_id": user_id, "room_id": room_id},
  529. )
  530. self.db_pool.simple_delete_txn(
  531. txn,
  532. table="users_who_share_private_rooms",
  533. keyvalues={"other_user_id": user_id, "room_id": room_id},
  534. )
  535. self.db_pool.simple_delete_txn(
  536. txn,
  537. table="users_in_public_rooms",
  538. keyvalues={"user_id": user_id, "room_id": room_id},
  539. )
  540. await self.db_pool.runInteraction(
  541. "remove_user_who_share_room", _remove_user_who_share_room_txn
  542. )
  543. async def get_user_dir_rooms_user_is_in(self, user_id):
  544. """
  545. Returns the rooms that a user is in.
  546. Args:
  547. user_id(str): Must be a local user
  548. Returns:
  549. list: user_id
  550. """
  551. rows = await self.db_pool.simple_select_onecol(
  552. table="users_who_share_private_rooms",
  553. keyvalues={"user_id": user_id},
  554. retcol="room_id",
  555. desc="get_rooms_user_is_in",
  556. )
  557. pub_rows = await self.db_pool.simple_select_onecol(
  558. table="users_in_public_rooms",
  559. keyvalues={"user_id": user_id},
  560. retcol="room_id",
  561. desc="get_rooms_user_is_in",
  562. )
  563. users = set(pub_rows)
  564. users.update(rows)
  565. return list(users)
  566. async def get_shared_rooms_for_users(
  567. self, user_id: str, other_user_id: str
  568. ) -> Set[str]:
  569. """
  570. Returns the rooms that a local user shares with another local or remote user.
  571. Args:
  572. user_id: The MXID of a local user
  573. other_user_id: The MXID of the other user
  574. Returns:
  575. A set of room ID's that the users share.
  576. """
  577. def _get_shared_rooms_for_users_txn(txn):
  578. txn.execute(
  579. """
  580. SELECT p1.room_id
  581. FROM users_in_public_rooms as p1
  582. INNER JOIN users_in_public_rooms as p2
  583. ON p1.room_id = p2.room_id
  584. AND p1.user_id = ?
  585. AND p2.user_id = ?
  586. UNION
  587. SELECT room_id
  588. FROM users_who_share_private_rooms
  589. WHERE
  590. user_id = ?
  591. AND other_user_id = ?
  592. """,
  593. (user_id, other_user_id, user_id, other_user_id),
  594. )
  595. rows = self.db_pool.cursor_to_dict(txn)
  596. return rows
  597. rows = await self.db_pool.runInteraction(
  598. "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn
  599. )
  600. return {row["room_id"] for row in rows}
  601. async def get_user_directory_stream_pos(self) -> Optional[int]:
  602. """
  603. Get the stream ID of the user directory stream.
  604. Returns:
  605. The stream token or None if the initial background update hasn't happened yet.
  606. """
  607. return await self.db_pool.simple_select_one_onecol(
  608. table="user_directory_stream_pos",
  609. keyvalues={},
  610. retcol="stream_id",
  611. desc="get_user_directory_stream_pos",
  612. )
  613. async def search_user_dir(self, user_id, search_term, limit):
  614. """Searches for users in directory
  615. Returns:
  616. dict of the form::
  617. {
  618. "limited": <bool>, # whether there were more results or not
  619. "results": [ # Ordered by best match first
  620. {
  621. "user_id": <user_id>,
  622. "display_name": <display_name>,
  623. "avatar_url": <avatar_url>
  624. }
  625. ]
  626. }
  627. """
  628. if self.hs.config.user_directory_search_all_users:
  629. join_args = (user_id,)
  630. where_clause = "user_id != ?"
  631. else:
  632. join_args = (user_id,)
  633. where_clause = """
  634. (
  635. EXISTS (select 1 from users_in_public_rooms WHERE user_id = t.user_id)
  636. OR EXISTS (
  637. SELECT 1 FROM users_who_share_private_rooms
  638. WHERE user_id = ? AND other_user_id = t.user_id
  639. )
  640. )
  641. """
  642. # We allow manipulating the ranking algorithm by injecting statements
  643. # based on config options.
  644. additional_ordering_statements = []
  645. ordering_arguments = ()
  646. if isinstance(self.database_engine, PostgresEngine):
  647. full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
  648. # If enabled, this config option will rank local users higher than those on
  649. # remote instances.
  650. if self._prefer_local_users_in_search:
  651. # This statement checks whether a given user's user ID contains a server name
  652. # that matches the local server
  653. statement = "* (CASE WHEN user_id LIKE ? THEN 2.0 ELSE 1.0 END)"
  654. additional_ordering_statements.append(statement)
  655. ordering_arguments += ("%:" + self._server_name,)
  656. # We order by rank and then if they have profile info
  657. # The ranking algorithm is hand tweaked for "best" results. Broadly
  658. # the idea is we give a higher weight to exact matches.
  659. # The array of numbers are the weights for the various part of the
  660. # search: (domain, _, display name, localpart)
  661. sql = """
  662. SELECT d.user_id AS user_id, display_name, avatar_url
  663. FROM user_directory_search as t
  664. INNER JOIN user_directory AS d USING (user_id)
  665. WHERE
  666. %(where_clause)s
  667. AND vector @@ to_tsquery('simple', ?)
  668. ORDER BY
  669. (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
  670. * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
  671. * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END)
  672. * (
  673. 3 * ts_rank_cd(
  674. '{0.1, 0.1, 0.9, 1.0}',
  675. vector,
  676. to_tsquery('simple', ?),
  677. 8
  678. )
  679. + ts_rank_cd(
  680. '{0.1, 0.1, 0.9, 1.0}',
  681. vector,
  682. to_tsquery('simple', ?),
  683. 8
  684. )
  685. )
  686. %(order_case_statements)s
  687. DESC,
  688. display_name IS NULL,
  689. avatar_url IS NULL
  690. LIMIT ?
  691. """ % {
  692. "where_clause": where_clause,
  693. "order_case_statements": " ".join(additional_ordering_statements),
  694. }
  695. args = (
  696. join_args
  697. + (full_query, exact_query, prefix_query)
  698. + ordering_arguments
  699. + (limit + 1,)
  700. )
  701. elif isinstance(self.database_engine, Sqlite3Engine):
  702. search_query = _parse_query_sqlite(search_term)
  703. # If enabled, this config option will rank local users higher than those on
  704. # remote instances.
  705. if self._prefer_local_users_in_search:
  706. # This statement checks whether a given user's user ID contains a server name
  707. # that matches the local server
  708. #
  709. # Note that we need to include a comma at the end for valid SQL
  710. statement = "user_id LIKE ? DESC,"
  711. additional_ordering_statements.append(statement)
  712. ordering_arguments += ("%:" + self._server_name,)
  713. sql = """
  714. SELECT d.user_id AS user_id, display_name, avatar_url
  715. FROM user_directory_search as t
  716. INNER JOIN user_directory AS d USING (user_id)
  717. WHERE
  718. %(where_clause)s
  719. AND value MATCH ?
  720. ORDER BY
  721. rank(matchinfo(user_directory_search)) DESC,
  722. %(order_statements)s
  723. display_name IS NULL,
  724. avatar_url IS NULL
  725. LIMIT ?
  726. """ % {
  727. "where_clause": where_clause,
  728. "order_statements": " ".join(additional_ordering_statements),
  729. }
  730. args = join_args + (search_query,) + ordering_arguments + (limit + 1,)
  731. else:
  732. # This should be unreachable.
  733. raise Exception("Unrecognized database engine")
  734. results = await self.db_pool.execute(
  735. "search_user_dir", self.db_pool.cursor_to_dict, sql, *args
  736. )
  737. limited = len(results) > limit
  738. return {"limited": limited, "results": results}
  739. def _parse_query_sqlite(search_term):
  740. """Takes a plain unicode string from the user and converts it into a form
  741. that can be passed to database.
  742. We use this so that we can add prefix matching, which isn't something
  743. that is supported by default.
  744. We specifically add both a prefix and non prefix matching term so that
  745. exact matches get ranked higher.
  746. """
  747. # Pull out the individual words, discarding any non-word characters.
  748. results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
  749. return " & ".join("(%s* OR %s)" % (result, result) for result in results)
  750. def _parse_query_postgres(search_term):
  751. """Takes a plain unicode string from the user and converts it into a form
  752. that can be passed to database.
  753. We use this so that we can add prefix matching, which isn't something
  754. that is supported by default.
  755. """
  756. # Pull out the individual words, discarding any non-word characters.
  757. results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
  758. both = " & ".join("(%s:* | %s)" % (result, result) for result in results)
  759. exact = " & ".join("%s" % (result,) for result in results)
  760. prefix = " & ".join("%s:*" % (result,) for result in results)
  761. return both, exact, prefix