profile.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. # Copyright 2014-2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING, Optional
  15. from synapse.api.errors import StoreError
  16. from synapse.storage._base import SQLBaseStore
  17. from synapse.storage.database import (
  18. DatabasePool,
  19. LoggingDatabaseConnection,
  20. LoggingTransaction,
  21. )
  22. from synapse.storage.databases.main.roommember import ProfileInfo
  23. from synapse.storage.engines import PostgresEngine
  24. from synapse.types import JsonDict, UserID
  25. if TYPE_CHECKING:
  26. from synapse.server import HomeServer
  27. class ProfileWorkerStore(SQLBaseStore):
  28. def __init__(
  29. self,
  30. database: DatabasePool,
  31. db_conn: LoggingDatabaseConnection,
  32. hs: "HomeServer",
  33. ):
  34. super().__init__(database, db_conn, hs)
  35. self.server_name: str = hs.hostname
  36. self.database_engine = database.engine
  37. self.db_pool.updates.register_background_index_update(
  38. "profiles_full_user_id_key_idx",
  39. index_name="profiles_full_user_id_key",
  40. table="profiles",
  41. columns=["full_user_id"],
  42. unique=True,
  43. )
  44. self.db_pool.updates.register_background_update_handler(
  45. "populate_full_user_id_profiles", self.populate_full_user_id_profiles
  46. )
  47. async def populate_full_user_id_profiles(
  48. self, progress: JsonDict, batch_size: int
  49. ) -> int:
  50. """
  51. Background update to populate the column `full_user_id` of the table
  52. profiles from entries in the column `user_local_part` of the same table
  53. """
  54. lower_bound_id = progress.get("lower_bound_id", "")
  55. def _get_last_id(txn: LoggingTransaction) -> Optional[str]:
  56. sql = """
  57. SELECT user_id FROM profiles
  58. WHERE user_id > ?
  59. ORDER BY user_id
  60. LIMIT 1 OFFSET 1000
  61. """
  62. txn.execute(sql, (lower_bound_id,))
  63. res = txn.fetchone()
  64. if res:
  65. upper_bound_id = res[0]
  66. return upper_bound_id
  67. else:
  68. return None
  69. def _process_batch(
  70. txn: LoggingTransaction, lower_bound_id: str, upper_bound_id: str
  71. ) -> None:
  72. sql = """
  73. UPDATE profiles
  74. SET full_user_id = '@' || user_id || ?
  75. WHERE ? < user_id AND user_id <= ? AND full_user_id IS NULL
  76. """
  77. txn.execute(sql, (f":{self.server_name}", lower_bound_id, upper_bound_id))
  78. def _final_batch(txn: LoggingTransaction, lower_bound_id: str) -> None:
  79. sql = """
  80. UPDATE profiles
  81. SET full_user_id = '@' || user_id || ?
  82. WHERE ? < user_id AND full_user_id IS NULL
  83. """
  84. txn.execute(
  85. sql,
  86. (
  87. f":{self.server_name}",
  88. lower_bound_id,
  89. ),
  90. )
  91. if isinstance(self.database_engine, PostgresEngine):
  92. sql = """
  93. ALTER TABLE profiles VALIDATE CONSTRAINT full_user_id_not_null
  94. """
  95. txn.execute(sql)
  96. upper_bound_id = await self.db_pool.runInteraction(
  97. "populate_full_user_id_profiles", _get_last_id
  98. )
  99. if upper_bound_id is None:
  100. await self.db_pool.runInteraction(
  101. "populate_full_user_id_profiles", _final_batch, lower_bound_id
  102. )
  103. await self.db_pool.updates._end_background_update(
  104. "populate_full_user_id_profiles"
  105. )
  106. return 1
  107. await self.db_pool.runInteraction(
  108. "populate_full_user_id_profiles",
  109. _process_batch,
  110. lower_bound_id,
  111. upper_bound_id,
  112. )
  113. progress["lower_bound_id"] = upper_bound_id
  114. await self.db_pool.runInteraction(
  115. "populate_full_user_id_profiles",
  116. self.db_pool.updates._background_update_progress_txn,
  117. "populate_full_user_id_profiles",
  118. progress,
  119. )
  120. return 50
  121. async def get_profileinfo(self, user_id: UserID) -> ProfileInfo:
  122. try:
  123. profile = await self.db_pool.simple_select_one(
  124. table="profiles",
  125. keyvalues={"full_user_id": user_id.to_string()},
  126. retcols=("displayname", "avatar_url"),
  127. desc="get_profileinfo",
  128. )
  129. except StoreError as e:
  130. if e.code == 404:
  131. # no match
  132. return ProfileInfo(None, None)
  133. else:
  134. raise
  135. return ProfileInfo(
  136. avatar_url=profile["avatar_url"], display_name=profile["displayname"]
  137. )
  138. async def get_profile_displayname(self, user_id: UserID) -> Optional[str]:
  139. return await self.db_pool.simple_select_one_onecol(
  140. table="profiles",
  141. keyvalues={"full_user_id": user_id.to_string()},
  142. retcol="displayname",
  143. desc="get_profile_displayname",
  144. )
  145. async def get_profile_avatar_url(self, user_id: UserID) -> Optional[str]:
  146. return await self.db_pool.simple_select_one_onecol(
  147. table="profiles",
  148. keyvalues={"full_user_id": user_id.to_string()},
  149. retcol="avatar_url",
  150. desc="get_profile_avatar_url",
  151. )
  152. async def create_profile(self, user_id: UserID) -> None:
  153. user_localpart = user_id.localpart
  154. await self.db_pool.simple_insert(
  155. table="profiles",
  156. values={"user_id": user_localpart, "full_user_id": user_id.to_string()},
  157. desc="create_profile",
  158. )
  159. async def set_profile_displayname(
  160. self, user_id: UserID, new_displayname: Optional[str]
  161. ) -> None:
  162. """
  163. Set the display name of a user.
  164. Args:
  165. user_id: The user's ID.
  166. new_displayname: The new display name. If this is None, the user's display
  167. name is removed.
  168. """
  169. user_localpart = user_id.localpart
  170. await self.db_pool.simple_upsert(
  171. table="profiles",
  172. keyvalues={"user_id": user_localpart},
  173. values={
  174. "displayname": new_displayname,
  175. "full_user_id": user_id.to_string(),
  176. },
  177. desc="set_profile_displayname",
  178. )
  179. async def set_profile_avatar_url(
  180. self, user_id: UserID, new_avatar_url: Optional[str]
  181. ) -> None:
  182. """
  183. Set the avatar of a user.
  184. Args:
  185. user_id: The user's ID.
  186. new_avatar_url: The new avatar URL. If this is None, the user's avatar is
  187. removed.
  188. """
  189. user_localpart = user_id.localpart
  190. await self.db_pool.simple_upsert(
  191. table="profiles",
  192. keyvalues={"user_id": user_localpart},
  193. values={"avatar_url": new_avatar_url, "full_user_id": user_id.to_string()},
  194. desc="set_profile_avatar_url",
  195. )
  196. class ProfileStore(ProfileWorkerStore):
  197. pass