tags.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. # Copyright 2018 New Vector Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import logging
  17. from six.moves import range
  18. from canonicaljson import json
  19. from twisted.internet import defer
  20. from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
  21. from synapse.util.caches.descriptors import cached
  22. logger = logging.getLogger(__name__)
  23. class TagsWorkerStore(AccountDataWorkerStore):
  24. @cached()
  25. def get_tags_for_user(self, user_id):
  26. """Get all the tags for a user.
  27. Args:
  28. user_id(str): The user to get the tags for.
  29. Returns:
  30. A deferred dict mapping from room_id strings to dicts mapping from
  31. tag strings to tag content.
  32. """
  33. deferred = self.db.simple_select_list(
  34. "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
  35. )
  36. @deferred.addCallback
  37. def tags_by_room(rows):
  38. tags_by_room = {}
  39. for row in rows:
  40. room_tags = tags_by_room.setdefault(row["room_id"], {})
  41. room_tags[row["tag"]] = json.loads(row["content"])
  42. return tags_by_room
  43. return deferred
  44. @defer.inlineCallbacks
  45. def get_all_updated_tags(self, last_id, current_id, limit):
  46. """Get all the client tags that have changed on the server
  47. Args:
  48. last_id(int): The position to fetch from.
  49. current_id(int): The position to fetch up to.
  50. Returns:
  51. A deferred list of tuples of stream_id int, user_id string,
  52. room_id string, tag string and content string.
  53. """
  54. if last_id == current_id:
  55. return []
  56. def get_all_updated_tags_txn(txn):
  57. sql = (
  58. "SELECT stream_id, user_id, room_id"
  59. " FROM room_tags_revisions as r"
  60. " WHERE ? < stream_id AND stream_id <= ?"
  61. " ORDER BY stream_id ASC LIMIT ?"
  62. )
  63. txn.execute(sql, (last_id, current_id, limit))
  64. return txn.fetchall()
  65. tag_ids = yield self.db.runInteraction(
  66. "get_all_updated_tags", get_all_updated_tags_txn
  67. )
  68. def get_tag_content(txn, tag_ids):
  69. sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
  70. results = []
  71. for stream_id, user_id, room_id in tag_ids:
  72. txn.execute(sql, (user_id, room_id))
  73. tags = []
  74. for tag, content in txn:
  75. tags.append(json.dumps(tag) + ":" + content)
  76. tag_json = "{" + ",".join(tags) + "}"
  77. results.append((stream_id, user_id, room_id, tag_json))
  78. return results
  79. batch_size = 50
  80. results = []
  81. for i in range(0, len(tag_ids), batch_size):
  82. tags = yield self.db.runInteraction(
  83. "get_all_updated_tag_content",
  84. get_tag_content,
  85. tag_ids[i : i + batch_size],
  86. )
  87. results.extend(tags)
  88. return results
  89. @defer.inlineCallbacks
  90. def get_updated_tags(self, user_id, stream_id):
  91. """Get all the tags for the rooms where the tags have changed since the
  92. given version
  93. Args:
  94. user_id(str): The user to get the tags for.
  95. stream_id(int): The earliest update to get for the user.
  96. Returns:
  97. A deferred dict mapping from room_id strings to lists of tag
  98. strings for all the rooms that changed since the stream_id token.
  99. """
  100. def get_updated_tags_txn(txn):
  101. sql = (
  102. "SELECT room_id from room_tags_revisions"
  103. " WHERE user_id = ? AND stream_id > ?"
  104. )
  105. txn.execute(sql, (user_id, stream_id))
  106. room_ids = [row[0] for row in txn]
  107. return room_ids
  108. changed = self._account_data_stream_cache.has_entity_changed(
  109. user_id, int(stream_id)
  110. )
  111. if not changed:
  112. return {}
  113. room_ids = yield self.db.runInteraction(
  114. "get_updated_tags", get_updated_tags_txn
  115. )
  116. results = {}
  117. if room_ids:
  118. tags_by_room = yield self.get_tags_for_user(user_id)
  119. for room_id in room_ids:
  120. results[room_id] = tags_by_room.get(room_id, {})
  121. return results
  122. def get_tags_for_room(self, user_id, room_id):
  123. """Get all the tags for the given room
  124. Args:
  125. user_id(str): The user to get tags for
  126. room_id(str): The room to get tags for
  127. Returns:
  128. A deferred list of string tags.
  129. """
  130. return self.db.simple_select_list(
  131. table="room_tags",
  132. keyvalues={"user_id": user_id, "room_id": room_id},
  133. retcols=("tag", "content"),
  134. desc="get_tags_for_room",
  135. ).addCallback(
  136. lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows}
  137. )
  138. class TagsStore(TagsWorkerStore):
  139. @defer.inlineCallbacks
  140. def add_tag_to_room(self, user_id, room_id, tag, content):
  141. """Add a tag to a room for a user.
  142. Args:
  143. user_id(str): The user to add a tag for.
  144. room_id(str): The room to add a tag for.
  145. tag(str): The tag name to add.
  146. content(dict): A json object to associate with the tag.
  147. Returns:
  148. A deferred that completes once the tag has been added.
  149. """
  150. content_json = json.dumps(content)
  151. def add_tag_txn(txn, next_id):
  152. self.db.simple_upsert_txn(
  153. txn,
  154. table="room_tags",
  155. keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
  156. values={"content": content_json},
  157. )
  158. self._update_revision_txn(txn, user_id, room_id, next_id)
  159. with self._account_data_id_gen.get_next() as next_id:
  160. yield self.db.runInteraction("add_tag", add_tag_txn, next_id)
  161. self.get_tags_for_user.invalidate((user_id,))
  162. result = self._account_data_id_gen.get_current_token()
  163. return result
  164. @defer.inlineCallbacks
  165. def remove_tag_from_room(self, user_id, room_id, tag):
  166. """Remove a tag from a room for a user.
  167. Returns:
  168. A deferred that completes once the tag has been removed
  169. """
  170. def remove_tag_txn(txn, next_id):
  171. sql = (
  172. "DELETE FROM room_tags "
  173. " WHERE user_id = ? AND room_id = ? AND tag = ?"
  174. )
  175. txn.execute(sql, (user_id, room_id, tag))
  176. self._update_revision_txn(txn, user_id, room_id, next_id)
  177. with self._account_data_id_gen.get_next() as next_id:
  178. yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id)
  179. self.get_tags_for_user.invalidate((user_id,))
  180. result = self._account_data_id_gen.get_current_token()
  181. return result
  182. def _update_revision_txn(self, txn, user_id, room_id, next_id):
  183. """Update the latest revision of the tags for the given user and room.
  184. Args:
  185. txn: The database cursor
  186. user_id(str): The ID of the user.
  187. room_id(str): The ID of the room.
  188. next_id(int): The the revision to advance to.
  189. """
  190. txn.call_after(
  191. self._account_data_stream_cache.entity_has_changed, user_id, next_id
  192. )
  193. update_max_id_sql = (
  194. "UPDATE account_data_max_stream_id"
  195. " SET stream_id = ?"
  196. " WHERE stream_id < ?"
  197. )
  198. txn.execute(update_max_id_sql, (next_id, next_id))
  199. update_sql = (
  200. "UPDATE room_tags_revisions"
  201. " SET stream_id = ?"
  202. " WHERE user_id = ?"
  203. " AND room_id = ?"
  204. )
  205. txn.execute(update_sql, (next_id, user_id, room_id))
  206. if txn.rowcount == 0:
  207. insert_sql = (
  208. "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)"
  209. " VALUES (?, ?, ?)"
  210. )
  211. try:
  212. txn.execute(insert_sql, (user_id, room_id, next_id))
  213. except self.database_engine.module.IntegrityError:
  214. # Ignore insertion errors. It doesn't matter if the row wasn't
  215. # inserted because if two updates happend concurrently the one
  216. # with the higher stream_id will not be reported to a client
  217. # unless the previous update has completed. It doesn't matter
  218. # which stream_id ends up in the table, as long as it is higher
  219. # than the id that the client has.
  220. pass