tags.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from ._base import SQLBaseStore
  16. from synapse.util.caches.descriptors import cached
  17. from twisted.internet import defer
  18. import ujson as json
  19. import logging
  20. logger = logging.getLogger(__name__)
  21. class TagsStore(SQLBaseStore):
  22. def get_max_account_data_stream_id(self):
  23. """Get the current max stream id for the private user data stream
  24. Returns:
  25. A deferred int.
  26. """
  27. return self._account_data_id_gen.get_current_token()
  28. @cached()
  29. def get_tags_for_user(self, user_id):
  30. """Get all the tags for a user.
  31. Args:
  32. user_id(str): The user to get the tags for.
  33. Returns:
  34. A deferred dict mapping from room_id strings to dicts mapping from
  35. tag strings to tag content.
  36. """
  37. deferred = self._simple_select_list(
  38. "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
  39. )
  40. @deferred.addCallback
  41. def tags_by_room(rows):
  42. tags_by_room = {}
  43. for row in rows:
  44. room_tags = tags_by_room.setdefault(row["room_id"], {})
  45. room_tags[row["tag"]] = json.loads(row["content"])
  46. return tags_by_room
  47. return deferred
  48. @defer.inlineCallbacks
  49. def get_all_updated_tags(self, last_id, current_id, limit):
  50. """Get all the client tags that have changed on the server
  51. Args:
  52. last_id(int): The position to fetch from.
  53. current_id(int): The position to fetch up to.
  54. Returns:
  55. A deferred list of tuples of stream_id int, user_id string,
  56. room_id string, tag string and content string.
  57. """
  58. if last_id == current_id:
  59. defer.returnValue([])
  60. def get_all_updated_tags_txn(txn):
  61. sql = (
  62. "SELECT stream_id, user_id, room_id"
  63. " FROM room_tags_revisions as r"
  64. " WHERE ? < stream_id AND stream_id <= ?"
  65. " ORDER BY stream_id ASC LIMIT ?"
  66. )
  67. txn.execute(sql, (last_id, current_id, limit))
  68. return txn.fetchall()
  69. tag_ids = yield self.runInteraction(
  70. "get_all_updated_tags", get_all_updated_tags_txn
  71. )
  72. def get_tag_content(txn, tag_ids):
  73. sql = (
  74. "SELECT tag, content"
  75. " FROM room_tags"
  76. " WHERE user_id=? AND room_id=?"
  77. )
  78. results = []
  79. for stream_id, user_id, room_id in tag_ids:
  80. txn.execute(sql, (user_id, room_id))
  81. tags = []
  82. for tag, content in txn.fetchall():
  83. tags.append(json.dumps(tag) + ":" + content)
  84. tag_json = "{" + ",".join(tags) + "}"
  85. results.append((stream_id, user_id, room_id, tag_json))
  86. return results
  87. batch_size = 50
  88. results = []
  89. for i in xrange(0, len(tag_ids), batch_size):
  90. tags = yield self.runInteraction(
  91. "get_all_updated_tag_content",
  92. get_tag_content,
  93. tag_ids[i:i + batch_size],
  94. )
  95. results.extend(tags)
  96. defer.returnValue(results)
  97. @defer.inlineCallbacks
  98. def get_updated_tags(self, user_id, stream_id):
  99. """Get all the tags for the rooms where the tags have changed since the
  100. given version
  101. Args:
  102. user_id(str): The user to get the tags for.
  103. stream_id(int): The earliest update to get for the user.
  104. Returns:
  105. A deferred dict mapping from room_id strings to lists of tag
  106. strings for all the rooms that changed since the stream_id token.
  107. """
  108. def get_updated_tags_txn(txn):
  109. sql = (
  110. "SELECT room_id from room_tags_revisions"
  111. " WHERE user_id = ? AND stream_id > ?"
  112. )
  113. txn.execute(sql, (user_id, stream_id))
  114. room_ids = [row[0] for row in txn.fetchall()]
  115. return room_ids
  116. changed = self._account_data_stream_cache.has_entity_changed(
  117. user_id, int(stream_id)
  118. )
  119. if not changed:
  120. defer.returnValue({})
  121. room_ids = yield self.runInteraction(
  122. "get_updated_tags", get_updated_tags_txn
  123. )
  124. results = {}
  125. if room_ids:
  126. tags_by_room = yield self.get_tags_for_user(user_id)
  127. for room_id in room_ids:
  128. results[room_id] = tags_by_room.get(room_id, {})
  129. defer.returnValue(results)
  130. def get_tags_for_room(self, user_id, room_id):
  131. """Get all the tags for the given room
  132. Args:
  133. user_id(str): The user to get tags for
  134. room_id(str): The room to get tags for
  135. Returns:
  136. A deferred list of string tags.
  137. """
  138. return self._simple_select_list(
  139. table="room_tags",
  140. keyvalues={"user_id": user_id, "room_id": room_id},
  141. retcols=("tag", "content"),
  142. desc="get_tags_for_room",
  143. ).addCallback(lambda rows: {
  144. row["tag"]: json.loads(row["content"]) for row in rows
  145. })
  146. @defer.inlineCallbacks
  147. def add_tag_to_room(self, user_id, room_id, tag, content):
  148. """Add a tag to a room for a user.
  149. Args:
  150. user_id(str): The user to add a tag for.
  151. room_id(str): The room to add a tag for.
  152. tag(str): The tag name to add.
  153. content(dict): A json object to associate with the tag.
  154. Returns:
  155. A deferred that completes once the tag has been added.
  156. """
  157. content_json = json.dumps(content)
  158. def add_tag_txn(txn, next_id):
  159. self._simple_upsert_txn(
  160. txn,
  161. table="room_tags",
  162. keyvalues={
  163. "user_id": user_id,
  164. "room_id": room_id,
  165. "tag": tag,
  166. },
  167. values={
  168. "content": content_json,
  169. }
  170. )
  171. self._update_revision_txn(txn, user_id, room_id, next_id)
  172. with self._account_data_id_gen.get_next() as next_id:
  173. yield self.runInteraction("add_tag", add_tag_txn, next_id)
  174. self.get_tags_for_user.invalidate((user_id,))
  175. result = self._account_data_id_gen.get_current_token()
  176. defer.returnValue(result)
  177. @defer.inlineCallbacks
  178. def remove_tag_from_room(self, user_id, room_id, tag):
  179. """Remove a tag from a room for a user.
  180. Returns:
  181. A deferred that completes once the tag has been removed
  182. """
  183. def remove_tag_txn(txn, next_id):
  184. sql = (
  185. "DELETE FROM room_tags "
  186. " WHERE user_id = ? AND room_id = ? AND tag = ?"
  187. )
  188. txn.execute(sql, (user_id, room_id, tag))
  189. self._update_revision_txn(txn, user_id, room_id, next_id)
  190. with self._account_data_id_gen.get_next() as next_id:
  191. yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
  192. self.get_tags_for_user.invalidate((user_id,))
  193. result = self._account_data_id_gen.get_current_token()
  194. defer.returnValue(result)
  195. def _update_revision_txn(self, txn, user_id, room_id, next_id):
  196. """Update the latest revision of the tags for the given user and room.
  197. Args:
  198. txn: The database cursor
  199. user_id(str): The ID of the user.
  200. room_id(str): The ID of the room.
  201. next_id(int): The the revision to advance to.
  202. """
  203. txn.call_after(
  204. self._account_data_stream_cache.entity_has_changed,
  205. user_id, next_id
  206. )
  207. update_max_id_sql = (
  208. "UPDATE account_data_max_stream_id"
  209. " SET stream_id = ?"
  210. " WHERE stream_id < ?"
  211. )
  212. txn.execute(update_max_id_sql, (next_id, next_id))
  213. update_sql = (
  214. "UPDATE room_tags_revisions"
  215. " SET stream_id = ?"
  216. " WHERE user_id = ?"
  217. " AND room_id = ?"
  218. )
  219. txn.execute(update_sql, (next_id, user_id, room_id))
  220. if txn.rowcount == 0:
  221. insert_sql = (
  222. "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)"
  223. " VALUES (?, ?, ?)"
  224. )
  225. try:
  226. txn.execute(insert_sql, (user_id, room_id, next_id))
  227. except self.database_engine.module.IntegrityError:
  228. # Ignore insertion errors. It doesn't matter if the row wasn't
  229. # inserted because if two updates happend concurrently the one
  230. # with the higher stream_id will not be reported to a client
  231. # unless the previous update has completed. It doesn't matter
  232. # which stream_id ends up in the table, as long as it is higher
  233. # than the id that the client has.
  234. pass