end_to_end_keys.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2015, 2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from twisted.internet import defer
  16. from synapse.util.caches.descriptors import cached
  17. from canonicaljson import encode_canonical_json, json
  18. from ._base import SQLBaseStore
  19. from six import iteritems
  20. class EndToEndKeyStore(SQLBaseStore):
  21. def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
  22. """Stores device keys for a device. Returns whether there was a change
  23. or the keys were already in the database.
  24. """
  25. def _set_e2e_device_keys_txn(txn):
  26. old_key_json = self._simple_select_one_onecol_txn(
  27. txn,
  28. table="e2e_device_keys_json",
  29. keyvalues={
  30. "user_id": user_id,
  31. "device_id": device_id,
  32. },
  33. retcol="key_json",
  34. allow_none=True,
  35. )
  36. new_key_json = encode_canonical_json(device_keys)
  37. if old_key_json == new_key_json:
  38. return False
  39. self._simple_upsert_txn(
  40. txn,
  41. table="e2e_device_keys_json",
  42. keyvalues={
  43. "user_id": user_id,
  44. "device_id": device_id,
  45. },
  46. values={
  47. "ts_added_ms": time_now,
  48. "key_json": new_key_json,
  49. }
  50. )
  51. return True
  52. return self.runInteraction(
  53. "set_e2e_device_keys", _set_e2e_device_keys_txn
  54. )
  55. @defer.inlineCallbacks
  56. def get_e2e_device_keys(self, query_list, include_all_devices=False):
  57. """Fetch a list of device keys.
  58. Args:
  59. query_list(list): List of pairs of user_ids and device_ids.
  60. include_all_devices (bool): whether to include entries for devices
  61. that don't have device keys
  62. Returns:
  63. Dict mapping from user-id to dict mapping from device_id to
  64. dict containing "key_json", "device_display_name".
  65. """
  66. if not query_list:
  67. defer.returnValue({})
  68. results = yield self.runInteraction(
  69. "get_e2e_device_keys", self._get_e2e_device_keys_txn,
  70. query_list, include_all_devices,
  71. )
  72. for user_id, device_keys in iteritems(results):
  73. for device_id, device_info in iteritems(device_keys):
  74. device_info["keys"] = json.loads(device_info.pop("key_json"))
  75. defer.returnValue(results)
  76. def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
  77. query_clauses = []
  78. query_params = []
  79. for (user_id, device_id) in query_list:
  80. query_clause = "user_id = ?"
  81. query_params.append(user_id)
  82. if device_id is not None:
  83. query_clause += " AND device_id = ?"
  84. query_params.append(device_id)
  85. query_clauses.append(query_clause)
  86. sql = (
  87. "SELECT user_id, device_id, "
  88. " d.display_name AS device_display_name, "
  89. " k.key_json"
  90. " FROM devices d"
  91. " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
  92. " WHERE %s"
  93. ) % (
  94. "LEFT" if include_all_devices else "INNER",
  95. " OR ".join("(" + q + ")" for q in query_clauses)
  96. )
  97. txn.execute(sql, query_params)
  98. rows = self.cursor_to_dict(txn)
  99. result = {}
  100. for row in rows:
  101. result.setdefault(row["user_id"], {})[row["device_id"]] = row
  102. return result
  103. @defer.inlineCallbacks
  104. def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
  105. """Retrieve a number of one-time keys for a user
  106. Args:
  107. user_id(str): id of user to get keys for
  108. device_id(str): id of device to get keys for
  109. key_ids(list[str]): list of key ids (excluding algorithm) to
  110. retrieve
  111. Returns:
  112. deferred resolving to Dict[(str, str), str]: map from (algorithm,
  113. key_id) to json string for key
  114. """
  115. rows = yield self._simple_select_many_batch(
  116. table="e2e_one_time_keys_json",
  117. column="key_id",
  118. iterable=key_ids,
  119. retcols=("algorithm", "key_id", "key_json",),
  120. keyvalues={
  121. "user_id": user_id,
  122. "device_id": device_id,
  123. },
  124. desc="add_e2e_one_time_keys_check",
  125. )
  126. defer.returnValue({
  127. (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
  128. })
  129. @defer.inlineCallbacks
  130. def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
  131. """Insert some new one time keys for a device. Errors if any of the
  132. keys already exist.
  133. Args:
  134. user_id(str): id of user to get keys for
  135. device_id(str): id of device to get keys for
  136. time_now(long): insertion time to record (ms since epoch)
  137. new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
  138. (algorithm, key_id, key json)
  139. """
  140. def _add_e2e_one_time_keys(txn):
  141. # We are protected from race between lookup and insertion due to
  142. # a unique constraint. If there is a race of two calls to
  143. # `add_e2e_one_time_keys` then they'll conflict and we will only
  144. # insert one set.
  145. self._simple_insert_many_txn(
  146. txn, table="e2e_one_time_keys_json",
  147. values=[
  148. {
  149. "user_id": user_id,
  150. "device_id": device_id,
  151. "algorithm": algorithm,
  152. "key_id": key_id,
  153. "ts_added_ms": time_now,
  154. "key_json": json_bytes,
  155. }
  156. for algorithm, key_id, json_bytes in new_keys
  157. ],
  158. )
  159. self._invalidate_cache_and_stream(
  160. txn, self.count_e2e_one_time_keys, (user_id, device_id,)
  161. )
  162. yield self.runInteraction(
  163. "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
  164. )
  165. @cached(max_entries=10000)
  166. def count_e2e_one_time_keys(self, user_id, device_id):
  167. """ Count the number of one time keys the server has for a device
  168. Returns:
  169. Dict mapping from algorithm to number of keys for that algorithm.
  170. """
  171. def _count_e2e_one_time_keys(txn):
  172. sql = (
  173. "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
  174. " WHERE user_id = ? AND device_id = ?"
  175. " GROUP BY algorithm"
  176. )
  177. txn.execute(sql, (user_id, device_id))
  178. result = {}
  179. for algorithm, key_count in txn:
  180. result[algorithm] = key_count
  181. return result
  182. return self.runInteraction(
  183. "count_e2e_one_time_keys", _count_e2e_one_time_keys
  184. )
  185. def claim_e2e_one_time_keys(self, query_list):
  186. """Take a list of one time keys out of the database"""
  187. def _claim_e2e_one_time_keys(txn):
  188. sql = (
  189. "SELECT key_id, key_json FROM e2e_one_time_keys_json"
  190. " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
  191. " LIMIT 1"
  192. )
  193. result = {}
  194. delete = []
  195. for user_id, device_id, algorithm in query_list:
  196. user_result = result.setdefault(user_id, {})
  197. device_result = user_result.setdefault(device_id, {})
  198. txn.execute(sql, (user_id, device_id, algorithm))
  199. for key_id, key_json in txn:
  200. device_result[algorithm + ":" + key_id] = key_json
  201. delete.append((user_id, device_id, algorithm, key_id))
  202. sql = (
  203. "DELETE FROM e2e_one_time_keys_json"
  204. " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
  205. " AND key_id = ?"
  206. )
  207. for user_id, device_id, algorithm, key_id in delete:
  208. txn.execute(sql, (user_id, device_id, algorithm, key_id))
  209. self._invalidate_cache_and_stream(
  210. txn, self.count_e2e_one_time_keys, (user_id, device_id,)
  211. )
  212. return result
  213. return self.runInteraction(
  214. "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
  215. )
  216. def delete_e2e_keys_by_device(self, user_id, device_id):
  217. def delete_e2e_keys_by_device_txn(txn):
  218. self._simple_delete_txn(
  219. txn,
  220. table="e2e_device_keys_json",
  221. keyvalues={"user_id": user_id, "device_id": device_id},
  222. )
  223. self._simple_delete_txn(
  224. txn,
  225. table="e2e_one_time_keys_json",
  226. keyvalues={"user_id": user_id, "device_id": device_id},
  227. )
  228. self._invalidate_cache_and_stream(
  229. txn, self.count_e2e_one_time_keys, (user_id, device_id,)
  230. )
  231. return self.runInteraction(
  232. "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
  233. )