account_data.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from ._base import SQLBaseStore
  16. from twisted.internet import defer
  17. import ujson as json
  18. import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  20. class AccountDataStore(SQLBaseStore):
  21. def get_account_data_for_user(self, user_id):
  22. """Get all the client account_data for a user.
  23. Args:
  24. user_id(str): The user to get the account_data for.
  25. Returns:
  26. A deferred pair of a dict of global account_data and a dict
  27. mapping from room_id string to per room account_data dicts.
  28. """
  29. def get_account_data_for_user_txn(txn):
  30. rows = self._simple_select_list_txn(
  31. txn, "account_data", {"user_id": user_id},
  32. ["account_data_type", "content"]
  33. )
  34. global_account_data = {
  35. row["account_data_type"]: json.loads(row["content"]) for row in rows
  36. }
  37. rows = self._simple_select_list_txn(
  38. txn, "room_account_data", {"user_id": user_id},
  39. ["room_id", "account_data_type", "content"]
  40. )
  41. by_room = {}
  42. for row in rows:
  43. room_data = by_room.setdefault(row["room_id"], {})
  44. room_data[row["account_data_type"]] = json.loads(row["content"])
  45. return (global_account_data, by_room)
  46. return self.runInteraction(
  47. "get_account_data_for_user", get_account_data_for_user_txn
  48. )
  49. def get_account_data_for_room(self, user_id, room_id):
  50. """Get all the client account_data for a user for a room.
  51. Args:
  52. user_id(str): The user to get the account_data for.
  53. room_id(str): The room to get the account_data for.
  54. Returns:
  55. A deferred dict of the room account_data
  56. """
  57. def get_account_data_for_room_txn(txn):
  58. rows = self._simple_select_list_txn(
  59. txn, "room_account_data", {"user_id": user_id, "room_id": room_id},
  60. ["account_data_type", "content"]
  61. )
  62. return {
  63. row["account_data_type"]: json.loads(row["content"]) for row in rows
  64. }
  65. return self.runInteraction(
  66. "get_account_data_for_room", get_account_data_for_room_txn
  67. )
  68. def get_all_updated_account_data(self, last_global_id, last_room_id,
  69. current_id, limit):
  70. """Get all the client account_data that has changed on the server
  71. Args:
  72. last_global_id(int): The position to fetch from for top level data
  73. last_room_id(int): The position to fetch from for per room data
  74. current_id(int): The position to fetch up to.
  75. Returns:
  76. A deferred pair of lists of tuples of stream_id int, user_id string,
  77. room_id string, type string, and content string.
  78. """
  79. def get_updated_account_data_txn(txn):
  80. sql = (
  81. "SELECT stream_id, user_id, account_data_type, content"
  82. " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
  83. " ORDER BY stream_id ASC LIMIT ?"
  84. )
  85. txn.execute(sql, (last_global_id, current_id, limit))
  86. global_results = txn.fetchall()
  87. sql = (
  88. "SELECT stream_id, user_id, room_id, account_data_type, content"
  89. " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
  90. " ORDER BY stream_id ASC LIMIT ?"
  91. )
  92. txn.execute(sql, (last_room_id, current_id, limit))
  93. room_results = txn.fetchall()
  94. return (global_results, room_results)
  95. return self.runInteraction(
  96. "get_all_updated_account_data_txn", get_updated_account_data_txn
  97. )
  98. def get_updated_account_data_for_user(self, user_id, stream_id):
  99. """Get all the client account_data for a that's changed for a user
  100. Args:
  101. user_id(str): The user to get the account_data for.
  102. stream_id(int): The point in the stream since which to get updates
  103. Returns:
  104. A deferred pair of a dict of global account_data and a dict
  105. mapping from room_id string to per room account_data dicts.
  106. """
  107. def get_updated_account_data_for_user_txn(txn):
  108. sql = (
  109. "SELECT account_data_type, content FROM account_data"
  110. " WHERE user_id = ? AND stream_id > ?"
  111. )
  112. txn.execute(sql, (user_id, stream_id))
  113. global_account_data = {
  114. row[0]: json.loads(row[1]) for row in txn.fetchall()
  115. }
  116. sql = (
  117. "SELECT room_id, account_data_type, content FROM room_account_data"
  118. " WHERE user_id = ? AND stream_id > ?"
  119. )
  120. txn.execute(sql, (user_id, stream_id))
  121. account_data_by_room = {}
  122. for row in txn.fetchall():
  123. room_account_data = account_data_by_room.setdefault(row[0], {})
  124. room_account_data[row[1]] = json.loads(row[2])
  125. return (global_account_data, account_data_by_room)
  126. changed = self._account_data_stream_cache.has_entity_changed(
  127. user_id, int(stream_id)
  128. )
  129. if not changed:
  130. return ({}, {})
  131. return self.runInteraction(
  132. "get_updated_account_data_for_user", get_updated_account_data_for_user_txn
  133. )
  134. @defer.inlineCallbacks
  135. def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
  136. """Add some account_data to a room for a user.
  137. Args:
  138. user_id(str): The user to add a tag for.
  139. room_id(str): The room to add a tag for.
  140. account_data_type(str): The type of account_data to add.
  141. content(dict): A json object to associate with the tag.
  142. Returns:
  143. A deferred that completes once the account_data has been added.
  144. """
  145. content_json = json.dumps(content)
  146. def add_account_data_txn(txn, next_id):
  147. self._simple_upsert_txn(
  148. txn,
  149. table="room_account_data",
  150. keyvalues={
  151. "user_id": user_id,
  152. "room_id": room_id,
  153. "account_data_type": account_data_type,
  154. },
  155. values={
  156. "stream_id": next_id,
  157. "content": content_json,
  158. }
  159. )
  160. txn.call_after(
  161. self._account_data_stream_cache.entity_has_changed,
  162. user_id, next_id,
  163. )
  164. self._update_max_stream_id(txn, next_id)
  165. with self._account_data_id_gen.get_next() as next_id:
  166. yield self.runInteraction(
  167. "add_room_account_data", add_account_data_txn, next_id
  168. )
  169. result = self._account_data_id_gen.get_current_token()
  170. defer.returnValue(result)
  171. @defer.inlineCallbacks
  172. def add_account_data_for_user(self, user_id, account_data_type, content):
  173. """Add some account_data to a room for a user.
  174. Args:
  175. user_id(str): The user to add a tag for.
  176. account_data_type(str): The type of account_data to add.
  177. content(dict): A json object to associate with the tag.
  178. Returns:
  179. A deferred that completes once the account_data has been added.
  180. """
  181. content_json = json.dumps(content)
  182. def add_account_data_txn(txn, next_id):
  183. self._simple_upsert_txn(
  184. txn,
  185. table="account_data",
  186. keyvalues={
  187. "user_id": user_id,
  188. "account_data_type": account_data_type,
  189. },
  190. values={
  191. "stream_id": next_id,
  192. "content": content_json,
  193. }
  194. )
  195. txn.call_after(
  196. self._account_data_stream_cache.entity_has_changed,
  197. user_id, next_id,
  198. )
  199. self._update_max_stream_id(txn, next_id)
  200. with self._account_data_id_gen.get_next() as next_id:
  201. yield self.runInteraction(
  202. "add_user_account_data", add_account_data_txn, next_id
  203. )
  204. result = self._account_data_id_gen.get_current_token()
  205. defer.returnValue(result)
  206. def _update_max_stream_id(self, txn, next_id):
  207. """Update the max stream_id
  208. Args:
  209. txn: The database cursor
  210. next_id(int): The the revision to advance to.
  211. """
  212. update_max_id_sql = (
  213. "UPDATE account_data_max_stream_id"
  214. " SET stream_id = ?"
  215. " WHERE stream_id < ?"
  216. )
  217. txn.execute(update_max_id_sql, (next_id, next_id))