end_to_end_keys.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2015, 2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from twisted.internet import defer
  16. from synapse.util.caches.descriptors import cached
  17. from canonicaljson import encode_canonical_json
  18. import simplejson as json
  19. from ._base import SQLBaseStore
  20. from six import iteritems
  21. class EndToEndKeyStore(SQLBaseStore):
  22. def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
  23. """Stores device keys for a device. Returns whether there was a change
  24. or the keys were already in the database.
  25. """
  26. def _set_e2e_device_keys_txn(txn):
  27. old_key_json = self._simple_select_one_onecol_txn(
  28. txn,
  29. table="e2e_device_keys_json",
  30. keyvalues={
  31. "user_id": user_id,
  32. "device_id": device_id,
  33. },
  34. retcol="key_json",
  35. allow_none=True,
  36. )
  37. new_key_json = encode_canonical_json(device_keys)
  38. if old_key_json == new_key_json:
  39. return False
  40. self._simple_upsert_txn(
  41. txn,
  42. table="e2e_device_keys_json",
  43. keyvalues={
  44. "user_id": user_id,
  45. "device_id": device_id,
  46. },
  47. values={
  48. "ts_added_ms": time_now,
  49. "key_json": new_key_json,
  50. }
  51. )
  52. return True
  53. return self.runInteraction(
  54. "set_e2e_device_keys", _set_e2e_device_keys_txn
  55. )
  56. @defer.inlineCallbacks
  57. def get_e2e_device_keys(self, query_list, include_all_devices=False):
  58. """Fetch a list of device keys.
  59. Args:
  60. query_list(list): List of pairs of user_ids and device_ids.
  61. include_all_devices (bool): whether to include entries for devices
  62. that don't have device keys
  63. Returns:
  64. Dict mapping from user-id to dict mapping from device_id to
  65. dict containing "key_json", "device_display_name".
  66. """
  67. if not query_list:
  68. defer.returnValue({})
  69. results = yield self.runInteraction(
  70. "get_e2e_device_keys", self._get_e2e_device_keys_txn,
  71. query_list, include_all_devices,
  72. )
  73. for user_id, device_keys in iteritems(results):
  74. for device_id, device_info in iteritems(device_keys):
  75. device_info["keys"] = json.loads(device_info.pop("key_json"))
  76. defer.returnValue(results)
  77. def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
  78. query_clauses = []
  79. query_params = []
  80. for (user_id, device_id) in query_list:
  81. query_clause = "user_id = ?"
  82. query_params.append(user_id)
  83. if device_id is not None:
  84. query_clause += " AND device_id = ?"
  85. query_params.append(device_id)
  86. query_clauses.append(query_clause)
  87. sql = (
  88. "SELECT user_id, device_id, "
  89. " d.display_name AS device_display_name, "
  90. " k.key_json"
  91. " FROM devices d"
  92. " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
  93. " WHERE %s"
  94. ) % (
  95. "LEFT" if include_all_devices else "INNER",
  96. " OR ".join("(" + q + ")" for q in query_clauses)
  97. )
  98. txn.execute(sql, query_params)
  99. rows = self.cursor_to_dict(txn)
  100. result = {}
  101. for row in rows:
  102. result.setdefault(row["user_id"], {})[row["device_id"]] = row
  103. return result
  104. @defer.inlineCallbacks
  105. def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
  106. """Retrieve a number of one-time keys for a user
  107. Args:
  108. user_id(str): id of user to get keys for
  109. device_id(str): id of device to get keys for
  110. key_ids(list[str]): list of key ids (excluding algorithm) to
  111. retrieve
  112. Returns:
  113. deferred resolving to Dict[(str, str), str]: map from (algorithm,
  114. key_id) to json string for key
  115. """
  116. rows = yield self._simple_select_many_batch(
  117. table="e2e_one_time_keys_json",
  118. column="key_id",
  119. iterable=key_ids,
  120. retcols=("algorithm", "key_id", "key_json",),
  121. keyvalues={
  122. "user_id": user_id,
  123. "device_id": device_id,
  124. },
  125. desc="add_e2e_one_time_keys_check",
  126. )
  127. defer.returnValue({
  128. (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
  129. })
  130. @defer.inlineCallbacks
  131. def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
  132. """Insert some new one time keys for a device. Errors if any of the
  133. keys already exist.
  134. Args:
  135. user_id(str): id of user to get keys for
  136. device_id(str): id of device to get keys for
  137. time_now(long): insertion time to record (ms since epoch)
  138. new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
  139. (algorithm, key_id, key json)
  140. """
  141. def _add_e2e_one_time_keys(txn):
  142. # We are protected from race between lookup and insertion due to
  143. # a unique constraint. If there is a race of two calls to
  144. # `add_e2e_one_time_keys` then they'll conflict and we will only
  145. # insert one set.
  146. self._simple_insert_many_txn(
  147. txn, table="e2e_one_time_keys_json",
  148. values=[
  149. {
  150. "user_id": user_id,
  151. "device_id": device_id,
  152. "algorithm": algorithm,
  153. "key_id": key_id,
  154. "ts_added_ms": time_now,
  155. "key_json": json_bytes,
  156. }
  157. for algorithm, key_id, json_bytes in new_keys
  158. ],
  159. )
  160. self._invalidate_cache_and_stream(
  161. txn, self.count_e2e_one_time_keys, (user_id, device_id,)
  162. )
  163. yield self.runInteraction(
  164. "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
  165. )
  166. @cached(max_entries=10000)
  167. def count_e2e_one_time_keys(self, user_id, device_id):
  168. """ Count the number of one time keys the server has for a device
  169. Returns:
  170. Dict mapping from algorithm to number of keys for that algorithm.
  171. """
  172. def _count_e2e_one_time_keys(txn):
  173. sql = (
  174. "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
  175. " WHERE user_id = ? AND device_id = ?"
  176. " GROUP BY algorithm"
  177. )
  178. txn.execute(sql, (user_id, device_id))
  179. result = {}
  180. for algorithm, key_count in txn:
  181. result[algorithm] = key_count
  182. return result
  183. return self.runInteraction(
  184. "count_e2e_one_time_keys", _count_e2e_one_time_keys
  185. )
  186. def claim_e2e_one_time_keys(self, query_list):
  187. """Take a list of one time keys out of the database"""
  188. def _claim_e2e_one_time_keys(txn):
  189. sql = (
  190. "SELECT key_id, key_json FROM e2e_one_time_keys_json"
  191. " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
  192. " LIMIT 1"
  193. )
  194. result = {}
  195. delete = []
  196. for user_id, device_id, algorithm in query_list:
  197. user_result = result.setdefault(user_id, {})
  198. device_result = user_result.setdefault(device_id, {})
  199. txn.execute(sql, (user_id, device_id, algorithm))
  200. for key_id, key_json in txn:
  201. device_result[algorithm + ":" + key_id] = key_json
  202. delete.append((user_id, device_id, algorithm, key_id))
  203. sql = (
  204. "DELETE FROM e2e_one_time_keys_json"
  205. " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
  206. " AND key_id = ?"
  207. )
  208. for user_id, device_id, algorithm, key_id in delete:
  209. txn.execute(sql, (user_id, device_id, algorithm, key_id))
  210. self._invalidate_cache_and_stream(
  211. txn, self.count_e2e_one_time_keys, (user_id, device_id,)
  212. )
  213. return result
  214. return self.runInteraction(
  215. "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
  216. )
  217. def delete_e2e_keys_by_device(self, user_id, device_id):
  218. def delete_e2e_keys_by_device_txn(txn):
  219. self._simple_delete_txn(
  220. txn,
  221. table="e2e_device_keys_json",
  222. keyvalues={"user_id": user_id, "device_id": device_id},
  223. )
  224. self._simple_delete_txn(
  225. txn,
  226. table="e2e_one_time_keys_json",
  227. keyvalues={"user_id": user_id, "device_id": device_id},
  228. )
  229. self._invalidate_cache_and_stream(
  230. txn, self.count_e2e_one_time_keys, (user_id, device_id,)
  231. )
  232. return self.runInteraction(
  233. "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
  234. )