tags.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from ._base import SQLBaseStore
  16. from synapse.util.caches.descriptors import cached
  17. from twisted.internet import defer
  18. import ujson as json
  19. import logging
  20. logger = logging.getLogger(__name__)
  21. class TagsStore(SQLBaseStore):
  22. def get_max_account_data_stream_id(self):
  23. """Get the current max stream id for the private user data stream
  24. Returns:
  25. A deferred int.
  26. """
  27. return self._account_data_id_gen.get_current_token()
  28. @cached()
  29. def get_tags_for_user(self, user_id):
  30. """Get all the tags for a user.
  31. Args:
  32. user_id(str): The user to get the tags for.
  33. Returns:
  34. A deferred dict mapping from room_id strings to dicts mapping from
  35. tag strings to tag content.
  36. """
  37. deferred = self._simple_select_list(
  38. "room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
  39. )
  40. @deferred.addCallback
  41. def tags_by_room(rows):
  42. tags_by_room = {}
  43. for row in rows:
  44. room_tags = tags_by_room.setdefault(row["room_id"], {})
  45. room_tags[row["tag"]] = json.loads(row["content"])
  46. return tags_by_room
  47. return deferred
  48. @defer.inlineCallbacks
  49. def get_all_updated_tags(self, last_id, current_id, limit):
  50. """Get all the client tags that have changed on the server
  51. Args:
  52. last_id(int): The position to fetch from.
  53. current_id(int): The position to fetch up to.
  54. Returns:
  55. A deferred list of tuples of stream_id int, user_id string,
  56. room_id string, tag string and content string.
  57. """
  58. def get_all_updated_tags_txn(txn):
  59. sql = (
  60. "SELECT stream_id, user_id, room_id"
  61. " FROM room_tags_revisions as r"
  62. " WHERE ? < stream_id AND stream_id <= ?"
  63. " ORDER BY stream_id ASC LIMIT ?"
  64. )
  65. txn.execute(sql, (last_id, current_id, limit))
  66. return txn.fetchall()
  67. tag_ids = yield self.runInteraction(
  68. "get_all_updated_tags", get_all_updated_tags_txn
  69. )
  70. def get_tag_content(txn, tag_ids):
  71. sql = (
  72. "SELECT tag, content"
  73. " FROM room_tags"
  74. " WHERE user_id=? AND room_id=?"
  75. )
  76. results = []
  77. for stream_id, user_id, room_id in tag_ids:
  78. txn.execute(sql, (user_id, room_id))
  79. tags = []
  80. for tag, content in txn.fetchall():
  81. tags.append(json.dumps(tag) + ":" + content)
  82. tag_json = "{" + ",".join(tags) + "}"
  83. results.append((stream_id, user_id, room_id, tag_json))
  84. return results
  85. batch_size = 50
  86. results = []
  87. for i in xrange(0, len(tag_ids), batch_size):
  88. tags = yield self.runInteraction(
  89. "get_all_updated_tag_content",
  90. get_tag_content,
  91. tag_ids[i:i + batch_size],
  92. )
  93. results.extend(tags)
  94. defer.returnValue(results)
  95. @defer.inlineCallbacks
  96. def get_updated_tags(self, user_id, stream_id):
  97. """Get all the tags for the rooms where the tags have changed since the
  98. given version
  99. Args:
  100. user_id(str): The user to get the tags for.
  101. stream_id(int): The earliest update to get for the user.
  102. Returns:
  103. A deferred dict mapping from room_id strings to lists of tag
  104. strings for all the rooms that changed since the stream_id token.
  105. """
  106. def get_updated_tags_txn(txn):
  107. sql = (
  108. "SELECT room_id from room_tags_revisions"
  109. " WHERE user_id = ? AND stream_id > ?"
  110. )
  111. txn.execute(sql, (user_id, stream_id))
  112. room_ids = [row[0] for row in txn.fetchall()]
  113. return room_ids
  114. changed = self._account_data_stream_cache.has_entity_changed(
  115. user_id, int(stream_id)
  116. )
  117. if not changed:
  118. defer.returnValue({})
  119. room_ids = yield self.runInteraction(
  120. "get_updated_tags", get_updated_tags_txn
  121. )
  122. results = {}
  123. if room_ids:
  124. tags_by_room = yield self.get_tags_for_user(user_id)
  125. for room_id in room_ids:
  126. results[room_id] = tags_by_room.get(room_id, {})
  127. defer.returnValue(results)
  128. def get_tags_for_room(self, user_id, room_id):
  129. """Get all the tags for the given room
  130. Args:
  131. user_id(str): The user to get tags for
  132. room_id(str): The room to get tags for
  133. Returns:
  134. A deferred list of string tags.
  135. """
  136. return self._simple_select_list(
  137. table="room_tags",
  138. keyvalues={"user_id": user_id, "room_id": room_id},
  139. retcols=("tag", "content"),
  140. desc="get_tags_for_room",
  141. ).addCallback(lambda rows: {
  142. row["tag"]: json.loads(row["content"]) for row in rows
  143. })
  144. @defer.inlineCallbacks
  145. def add_tag_to_room(self, user_id, room_id, tag, content):
  146. """Add a tag to a room for a user.
  147. Args:
  148. user_id(str): The user to add a tag for.
  149. room_id(str): The room to add a tag for.
  150. tag(str): The tag name to add.
  151. content(dict): A json object to associate with the tag.
  152. Returns:
  153. A deferred that completes once the tag has been added.
  154. """
  155. content_json = json.dumps(content)
  156. def add_tag_txn(txn, next_id):
  157. self._simple_upsert_txn(
  158. txn,
  159. table="room_tags",
  160. keyvalues={
  161. "user_id": user_id,
  162. "room_id": room_id,
  163. "tag": tag,
  164. },
  165. values={
  166. "content": content_json,
  167. }
  168. )
  169. self._update_revision_txn(txn, user_id, room_id, next_id)
  170. with self._account_data_id_gen.get_next() as next_id:
  171. yield self.runInteraction("add_tag", add_tag_txn, next_id)
  172. self.get_tags_for_user.invalidate((user_id,))
  173. result = self._account_data_id_gen.get_current_token()
  174. defer.returnValue(result)
  175. @defer.inlineCallbacks
  176. def remove_tag_from_room(self, user_id, room_id, tag):
  177. """Remove a tag from a room for a user.
  178. Returns:
  179. A deferred that completes once the tag has been removed
  180. """
  181. def remove_tag_txn(txn, next_id):
  182. sql = (
  183. "DELETE FROM room_tags "
  184. " WHERE user_id = ? AND room_id = ? AND tag = ?"
  185. )
  186. txn.execute(sql, (user_id, room_id, tag))
  187. self._update_revision_txn(txn, user_id, room_id, next_id)
  188. with self._account_data_id_gen.get_next() as next_id:
  189. yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
  190. self.get_tags_for_user.invalidate((user_id,))
  191. result = self._account_data_id_gen.get_current_token()
  192. defer.returnValue(result)
  193. def _update_revision_txn(self, txn, user_id, room_id, next_id):
  194. """Update the latest revision of the tags for the given user and room.
  195. Args:
  196. txn: The database cursor
  197. user_id(str): The ID of the user.
  198. room_id(str): The ID of the room.
  199. next_id(int): The the revision to advance to.
  200. """
  201. txn.call_after(
  202. self._account_data_stream_cache.entity_has_changed,
  203. user_id, next_id
  204. )
  205. update_max_id_sql = (
  206. "UPDATE account_data_max_stream_id"
  207. " SET stream_id = ?"
  208. " WHERE stream_id < ?"
  209. )
  210. txn.execute(update_max_id_sql, (next_id, next_id))
  211. update_sql = (
  212. "UPDATE room_tags_revisions"
  213. " SET stream_id = ?"
  214. " WHERE user_id = ?"
  215. " AND room_id = ?"
  216. )
  217. txn.execute(update_sql, (next_id, user_id, room_id))
  218. if txn.rowcount == 0:
  219. insert_sql = (
  220. "INSERT INTO room_tags_revisions (user_id, room_id, stream_id)"
  221. " VALUES (?, ?, ?)"
  222. )
  223. try:
  224. txn.execute(insert_sql, (user_id, room_id, next_id))
  225. except self.database_engine.module.IntegrityError:
  226. # Ignore insertion errors. It doesn't matter if the row wasn't
  227. # inserted because if two updates happend concurrently the one
  228. # with the higher stream_id will not be reported to a client
  229. # unless the previous update has completed. It doesn't matter
  230. # which stream_id ends up in the table, as long as it is higher
  231. # than the id that the client has.
  232. pass