account_data.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. # Copyright 2015, 2016 OpenMarket Ltd
  2. # Copyright 2021 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. import random
  17. from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple
  18. from synapse.replication.http.account_data import (
  19. ReplicationAddTagRestServlet,
  20. ReplicationRemoveTagRestServlet,
  21. ReplicationRoomAccountDataRestServlet,
  22. ReplicationUserAccountDataRestServlet,
  23. )
  24. from synapse.streams import EventSource
  25. from synapse.types import JsonDict, StreamKeyType, UserID
  26. if TYPE_CHECKING:
  27. from synapse.server import HomeServer
  28. logger = logging.getLogger(__name__)
# Signature of a module callback that is awaited after account data has been
# written. It is called positionally with (user_id, room_id, account_data_type,
# content); room_id is None when the change concerns global (non-room) data.
ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[
    [str, Optional[str], str, JsonDict], Awaitable
]
  32. class AccountDataHandler:
  33. def __init__(self, hs: "HomeServer"):
  34. self._store = hs.get_datastores().main
  35. self._instance_name = hs.get_instance_name()
  36. self._notifier = hs.get_notifier()
  37. self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs)
  38. self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs)
  39. self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs)
  40. self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
  41. self._account_data_writers = hs.config.worker.writers.account_data
  42. self._on_account_data_updated_callbacks: List[
  43. ON_ACCOUNT_DATA_UPDATED_CALLBACK
  44. ] = []
  45. def register_module_callbacks(
  46. self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None
  47. ) -> None:
  48. """Register callbacks from modules."""
  49. if on_account_data_updated is not None:
  50. self._on_account_data_updated_callbacks.append(on_account_data_updated)
  51. async def _notify_modules(
  52. self,
  53. user_id: str,
  54. room_id: Optional[str],
  55. account_data_type: str,
  56. content: JsonDict,
  57. ) -> None:
  58. """Notifies modules about new account data changes.
  59. A change can be either a new account data type being added, or the content
  60. associated with a type being changed. Account data for a given type is removed by
  61. changing the associated content to an empty dictionary.
  62. Note that this is not called when the tags associated with a room change.
  63. Args:
  64. user_id: The user whose account data is changing.
  65. room_id: The ID of the room the account data change concerns, if any.
  66. account_data_type: The type of the account data.
  67. content: The content that is now associated with this type.
  68. """
  69. for callback in self._on_account_data_updated_callbacks:
  70. try:
  71. await callback(user_id, room_id, account_data_type, content)
  72. except Exception as e:
  73. logger.exception("Failed to run module callback %s: %s", callback, e)
  74. async def add_account_data_to_room(
  75. self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
  76. ) -> int:
  77. """Add some account_data to a room for a user.
  78. Args:
  79. user_id: The user to add a tag for.
  80. room_id: The room to add a tag for.
  81. account_data_type: The type of account_data to add.
  82. content: A json object to associate with the tag.
  83. Returns:
  84. The maximum stream ID.
  85. """
  86. if self._instance_name in self._account_data_writers:
  87. max_stream_id = await self._store.add_account_data_to_room(
  88. user_id, room_id, account_data_type, content
  89. )
  90. self._notifier.on_new_event(
  91. StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
  92. )
  93. await self._notify_modules(user_id, room_id, account_data_type, content)
  94. return max_stream_id
  95. else:
  96. response = await self._room_data_client(
  97. instance_name=random.choice(self._account_data_writers),
  98. user_id=user_id,
  99. room_id=room_id,
  100. account_data_type=account_data_type,
  101. content=content,
  102. )
  103. return response["max_stream_id"]
  104. async def add_account_data_for_user(
  105. self, user_id: str, account_data_type: str, content: JsonDict
  106. ) -> int:
  107. """Add some global account_data for a user.
  108. Args:
  109. user_id: The user to add a tag for.
  110. account_data_type: The type of account_data to add.
  111. content: A json object to associate with the tag.
  112. Returns:
  113. The maximum stream ID.
  114. """
  115. if self._instance_name in self._account_data_writers:
  116. max_stream_id = await self._store.add_account_data_for_user(
  117. user_id, account_data_type, content
  118. )
  119. self._notifier.on_new_event(
  120. StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
  121. )
  122. await self._notify_modules(user_id, None, account_data_type, content)
  123. return max_stream_id
  124. else:
  125. response = await self._user_data_client(
  126. instance_name=random.choice(self._account_data_writers),
  127. user_id=user_id,
  128. account_data_type=account_data_type,
  129. content=content,
  130. )
  131. return response["max_stream_id"]
  132. async def add_tag_to_room(
  133. self, user_id: str, room_id: str, tag: str, content: JsonDict
  134. ) -> int:
  135. """Add a tag to a room for a user.
  136. Args:
  137. user_id: The user to add a tag for.
  138. room_id: The room to add a tag for.
  139. tag: The tag name to add.
  140. content: A json object to associate with the tag.
  141. Returns:
  142. The next account data ID.
  143. """
  144. if self._instance_name in self._account_data_writers:
  145. max_stream_id = await self._store.add_tag_to_room(
  146. user_id, room_id, tag, content
  147. )
  148. self._notifier.on_new_event(
  149. StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
  150. )
  151. return max_stream_id
  152. else:
  153. response = await self._add_tag_client(
  154. instance_name=random.choice(self._account_data_writers),
  155. user_id=user_id,
  156. room_id=room_id,
  157. tag=tag,
  158. content=content,
  159. )
  160. return response["max_stream_id"]
  161. async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int:
  162. """Remove a tag from a room for a user.
  163. Returns:
  164. The next account data ID.
  165. """
  166. if self._instance_name in self._account_data_writers:
  167. max_stream_id = await self._store.remove_tag_from_room(
  168. user_id, room_id, tag
  169. )
  170. self._notifier.on_new_event(
  171. StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
  172. )
  173. return max_stream_id
  174. else:
  175. response = await self._remove_tag_client(
  176. instance_name=random.choice(self._account_data_writers),
  177. user_id=user_id,
  178. room_id=room_id,
  179. tag=tag,
  180. )
  181. return response["max_stream_id"]
  182. class AccountDataEventSource(EventSource[int, JsonDict]):
  183. def __init__(self, hs: "HomeServer"):
  184. self.store = hs.get_datastores().main
  185. def get_current_key(self, direction: str = "f") -> int:
  186. return self.store.get_max_account_data_stream_id()
  187. async def get_new_events(
  188. self,
  189. user: UserID,
  190. from_key: int,
  191. limit: Optional[int],
  192. room_ids: Collection[str],
  193. is_guest: bool,
  194. explicit_room_id: Optional[str] = None,
  195. ) -> Tuple[List[JsonDict], int]:
  196. user_id = user.to_string()
  197. last_stream_id = from_key
  198. current_stream_id = self.store.get_max_account_data_stream_id()
  199. results = []
  200. tags = await self.store.get_updated_tags(user_id, last_stream_id)
  201. for room_id, room_tags in tags.items():
  202. results.append(
  203. {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
  204. )
  205. (
  206. account_data,
  207. room_account_data,
  208. ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
  209. for account_data_type, content in account_data.items():
  210. results.append({"type": account_data_type, "content": content})
  211. for room_id, account_data in room_account_data.items():
  212. for account_data_type, content in account_data.items():
  213. results.append(
  214. {"type": account_data_type, "content": content, "room_id": room_id}
  215. )
  216. return results, current_stream_id