tags.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. # Copyright 2018 New Vector Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import logging
  17. from canonicaljson import json
  18. from twisted.internet import defer
  19. from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
  20. from synapse.util.caches.descriptors import cached
  21. logger = logging.getLogger(__name__)
  22. class TagsWorkerStore(AccountDataWorkerStore):
  23. @cached()
  24. def get_tags_for_user(self, user_id):
  25. """Get all the tags for a user.
  26. Args:
  27. user_id(str): The user to get the tags for.
  28. Returns:
  29. A deferred dict mapping from room_id strings to dicts mapping from
  30. tag strings to tag content.
  31. """
  32. deferred = self.db.simple_select_list(
  33. "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
  34. )
  35. @deferred.addCallback
  36. def tags_by_room(rows):
  37. tags_by_room = {}
  38. for row in rows:
  39. room_tags = tags_by_room.setdefault(row["room_id"], {})
  40. room_tags[row["tag"]] = json.loads(row["content"])
  41. return tags_by_room
  42. return deferred
  43. @defer.inlineCallbacks
  44. def get_all_updated_tags(self, last_id, current_id, limit):
  45. """Get all the client tags that have changed on the server
  46. Args:
  47. last_id(int): The position to fetch from.
  48. current_id(int): The position to fetch up to.
  49. Returns:
  50. A deferred list of tuples of stream_id int, user_id string,
  51. room_id string, tag string and content string.
  52. """
  53. if last_id == current_id:
  54. return []
  55. def get_all_updated_tags_txn(txn):
  56. sql = (
  57. "SELECT stream_id, user_id, room_id"
  58. " FROM room_tags_revisions as r"
  59. " WHERE ? < stream_id AND stream_id <= ?"
  60. " ORDER BY stream_id ASC LIMIT ?"
  61. )
  62. txn.execute(sql, (last_id, current_id, limit))
  63. return txn.fetchall()
  64. tag_ids = yield self.db.runInteraction(
  65. "get_all_updated_tags", get_all_updated_tags_txn
  66. )
  67. def get_tag_content(txn, tag_ids):
  68. sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
  69. results = []
  70. for stream_id, user_id, room_id in tag_ids:
  71. txn.execute(sql, (user_id, room_id))
  72. tags = []
  73. for tag, content in txn:
  74. tags.append(json.dumps(tag) + ":" + content)
  75. tag_json = "{" + ",".join(tags) + "}"
  76. results.append((stream_id, user_id, room_id, tag_json))
  77. return results
  78. batch_size = 50
  79. results = []
  80. for i in range(0, len(tag_ids), batch_size):
  81. tags = yield self.db.runInteraction(
  82. "get_all_updated_tag_content",
  83. get_tag_content,
  84. tag_ids[i : i + batch_size],
  85. )
  86. results.extend(tags)
  87. return results
  88. @defer.inlineCallbacks
  89. def get_updated_tags(self, user_id, stream_id):
  90. """Get all the tags for the rooms where the tags have changed since the
  91. given version
  92. Args:
  93. user_id(str): The user to get the tags for.
  94. stream_id(int): The earliest update to get for the user.
  95. Returns:
  96. A deferred dict mapping from room_id strings to lists of tag
  97. strings for all the rooms that changed since the stream_id token.
  98. """
  99. def get_updated_tags_txn(txn):
  100. sql = (
  101. "SELECT room_id from room_tags_revisions"
  102. " WHERE user_id = ? AND stream_id > ?"
  103. )
  104. txn.execute(sql, (user_id, stream_id))
  105. room_ids = [row[0] for row in txn]
  106. return room_ids
  107. changed = self._account_data_stream_cache.has_entity_changed(
  108. user_id, int(stream_id)
  109. )
  110. if not changed:
  111. return {}
  112. room_ids = yield self.db.runInteraction(
  113. "get_updated_tags", get_updated_tags_txn
  114. )
  115. results = {}
  116. if room_ids:
  117. tags_by_room = yield self.get_tags_for_user(user_id)
  118. for room_id in room_ids:
  119. results[room_id] = tags_by_room.get(room_id, {})
  120. return results
  121. def get_tags_for_room(self, user_id, room_id):
  122. """Get all the tags for the given room
  123. Args:
  124. user_id(str): The user to get tags for
  125. room_id(str): The room to get tags for
  126. Returns:
  127. A deferred list of string tags.
  128. """
  129. return self.db.simple_select_list(
  130. table="room_tags",
  131. keyvalues={"user_id": user_id, "room_id": room_id},
  132. retcols=("tag", "content"),
  133. desc="get_tags_for_room",
  134. ).addCallback(
  135. lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows}
  136. )
  137. class TagsStore(TagsWorkerStore):
  138. @defer.inlineCallbacks
  139. def add_tag_to_room(self, user_id, room_id, tag, content):
  140. """Add a tag to a room for a user.
  141. Args:
  142. user_id(str): The user to add a tag for.
  143. room_id(str): The room to add a tag for.
  144. tag(str): The tag name to add.
  145. content(dict): A json object to associate with the tag.
  146. Returns:
  147. A deferred that completes once the tag has been added.
  148. """
  149. content_json = json.dumps(content)
  150. def add_tag_txn(txn, next_id):
  151. self.db.simple_upsert_txn(
  152. txn,
  153. table="room_tags",
  154. keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
  155. values={"content": content_json},
  156. )
  157. self._update_revision_txn(txn, user_id, room_id, next_id)
  158. with self._account_data_id_gen.get_next() as next_id:
  159. yield self.db.runInteraction("add_tag", add_tag_txn, next_id)
  160. self.get_tags_for_user.invalidate((user_id,))
  161. result = self._account_data_id_gen.get_current_token()
  162. return result
  163. @defer.inlineCallbacks
  164. def remove_tag_from_room(self, user_id, room_id, tag):
  165. """Remove a tag from a room for a user.
  166. Returns:
  167. A deferred that completes once the tag has been removed
  168. """
  169. def remove_tag_txn(txn, next_id):
  170. sql = (
  171. "DELETE FROM room_tags "
  172. " WHERE user_id = ? AND room_id = ? AND tag = ?"
  173. )
  174. txn.execute(sql, (user_id, room_id, tag))
  175. self._update_revision_txn(txn, user_id, room_id, next_id)
  176. with self._account_data_id_gen.get_next() as next_id:
  177. yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id)
  178. self.get_tags_for_user.invalidate((user_id,))
  179. result = self._account_data_id_gen.get_current_token()
  180. return result
  181. def _update_revision_txn(self, txn, user_id, room_id, next_id):
  182. """Update the latest revision of the tags for the given user and room.
  183. Args:
  184. txn: The database cursor
  185. user_id(str): The ID of the user.
  186. room_id(str): The ID of the room.
  187. next_id(int): The the revision to advance to.
  188. """
  189. txn.call_after(
  190. self._account_data_stream_cache.entity_has_changed, user_id, next_id
  191. )
  192. # Note: This is only here for backwards compat to allow admins to
  193. # roll back to a previous Synapse version. Next time we update the
  194. # database version we can remove this table.
  195. update_max_id_sql = (
  196. "UPDATE account_data_max_stream_id"
  197. " SET stream_id = ?"
  198. " WHERE stream_id < ?"
  199. )
  200. txn.execute(update_max_id_sql, (next_id, next_id))
  201. update_sql = (
  202. "UPDATE room_tags_revisions"
  203. " SET stream_id = ?"
  204. " WHERE user_id = ?"
  205. " AND room_id = ?"
  206. )
  207. txn.execute(update_sql, (next_id, user_id, room_id))
  208. if txn.rowcount == 0:
  209. insert_sql = (
  210. "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)"
  211. " VALUES (?, ?, ?)"
  212. )
  213. try:
  214. txn.execute(insert_sql, (user_id, room_id, next_id))
  215. except self.database_engine.module.IntegrityError:
  216. # Ignore insertion errors. It doesn't matter if the row wasn't
  217. # inserted because if two updates happend concurrently the one
  218. # with the higher stream_id will not be reported to a client
  219. # unless the previous update has completed. It doesn't matter
  220. # which stream_id ends up in the table, as long as it is higher
  221. # than the id that the client has.
  222. pass