123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266 |
- # -*- coding: utf-8 -*-
- # Copyright 2015, 2016 OpenMarket Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from twisted.internet import defer
- from synapse.util.caches.descriptors import cached
- from canonicaljson import encode_canonical_json, json
- from ._base import SQLBaseStore
- from six import iteritems
class EndToEndKeyStore(SQLBaseStore):
    """Database access for end-to-end encryption keys.

    Device keys are stored in the ``e2e_device_keys_json`` table and one-time
    keys in ``e2e_one_time_keys_json``; rows in both are addressed by
    (user_id, device_id).
    """

    def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
        """Stores device keys for a device.

        The keys are serialised as canonical JSON and only written when they
        differ from the JSON already stored for the device.

        Args:
            user_id (str): id of the user owning the device
            device_id (str): id of the device
            time_now (int): insertion time to record (ms since epoch)
            device_keys (dict): the device keys to store

        Returns:
            Deferred[bool]: True if the keys were inserted or changed, False
            if identical keys were already in the database.
        """
        def _set_e2e_device_keys_txn(txn):
            old_key_json = self._simple_select_one_onecol_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id,
                },
                retcol="key_json",
                allow_none=True,
            )

            new_key_json = encode_canonical_json(device_keys)
            # NOTE(review): presumably encode_canonical_json returns bytes; if
            # the database driver hands back text instead, this comparison is
            # never equal and the row is rewritten every time -- confirm.
            if old_key_json == new_key_json:
                return False

            self._simple_upsert_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={
                    "user_id": user_id,
                    "device_id": device_id,
                },
                values={
                    "ts_added_ms": time_now,
                    "key_json": new_key_json,
                }
            )

            return True

        return self.runInteraction(
            "set_e2e_device_keys", _set_e2e_device_keys_txn
        )

    @defer.inlineCallbacks
    def get_e2e_device_keys(self, query_list, include_all_devices=False):
        """Fetch a list of device keys.

        Args:
            query_list (list): List of pairs of user_ids and device_ids. A
                device_id of None selects every device of that user.
            include_all_devices (bool): whether to include entries for devices
                that don't have device keys

        Returns:
            Deferred[dict]: map from user_id to a dict mapping device_id to the
            device's row dict (including "device_display_name"), where the raw
            "key_json" column has been replaced by "keys": the decoded key
            JSON, or None for a device with no stored keys (which can only
            happen when include_all_devices is True).
        """
        if not query_list:
            defer.returnValue({})

        results = yield self.runInteraction(
            "get_e2e_device_keys", self._get_e2e_device_keys_txn,
            query_list, include_all_devices,
        )

        for device_keys in results.values():
            for device_info in device_keys.values():
                key_json = device_info.pop("key_json")
                # With include_all_devices the LEFT JOIN yields a NULL
                # key_json for devices without keys; json.loads(None) would
                # raise TypeError, so report those devices as keys=None.
                device_info["keys"] = (
                    json.loads(key_json) if key_json is not None else None
                )

        defer.returnValue(results)

    def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
        """Look up devices, joined with their stored keys.

        Args:
            txn: the database transaction to run the query on
            query_list (iterable): (user_id, device_id) pairs; a device_id of
                None matches every device of that user.
            include_all_devices (bool): if True, LEFT JOIN so devices without
                keys are returned with a NULL key_json; otherwise only
                devices that have keys are returned.

        Returns:
            dict: user_id -> device_id -> row dict with the selected columns
            ("user_id", "device_id", "device_display_name", "key_json").
        """
        query_clauses = []
        query_params = []

        for (user_id, device_id) in query_list:
            query_clause = "user_id = ?"
            query_params.append(user_id)

            if device_id is not None:
                query_clause += " AND device_id = ?"
                query_params.append(device_id)

            query_clauses.append(query_clause)

        # With no queries the OR-list below would be empty, leaving a bare
        # "WHERE" and producing invalid SQL.
        if not query_clauses:
            return {}

        sql = (
            "SELECT user_id, device_id, "
            "    d.display_name AS device_display_name, "
            "    k.key_json"
            " FROM devices d"
            "    %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
            " WHERE %s"
        ) % (
            "LEFT" if include_all_devices else "INNER",
            " OR ".join("(" + q + ")" for q in query_clauses)
        )

        txn.execute(sql, query_params)
        rows = self.cursor_to_dict(txn)

        result = {}
        for row in rows:
            result.setdefault(row["user_id"], {})[row["device_id"]] = row

        return result

    @defer.inlineCallbacks
    def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
        """Retrieve a number of one-time keys for a user

        Args:
            user_id(str): id of user to get keys for
            device_id(str): id of device to get keys for
            key_ids(list[str]): list of key ids (excluding algorithm) to
                retrieve

        Returns:
            deferred resolving to Dict[(str, str), str]: map from (algorithm,
            key_id) to json string for key
        """
        rows = yield self._simple_select_many_batch(
            table="e2e_one_time_keys_json",
            column="key_id",
            iterable=key_ids,
            retcols=("algorithm", "key_id", "key_json",),
            keyvalues={
                "user_id": user_id,
                "device_id": device_id,
            },
            desc="add_e2e_one_time_keys_check",
        )

        defer.returnValue({
            (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
        })

    @defer.inlineCallbacks
    def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
        """Insert some new one time keys for a device. Errors if any of the
        keys already exist.

        Args:
            user_id(str): id of user to get keys for
            device_id(str): id of device to get keys for
            time_now(int): insertion time to record (ms since epoch)
            new_keys(iterable[(str, str, str)]): keys to add - each a tuple of
                (algorithm, key_id, key json)
        """
        def _add_e2e_one_time_keys(txn):
            # We are protected from race between lookup and insertion due to
            # a unique constraint. If there is a race of two calls to
            # `add_e2e_one_time_keys` then they'll conflict and we will only
            # insert one set.
            self._simple_insert_many_txn(
                txn, table="e2e_one_time_keys_json",
                values=[
                    {
                        "user_id": user_id,
                        "device_id": device_id,
                        "algorithm": algorithm,
                        "key_id": key_id,
                        "ts_added_ms": time_now,
                        "key_json": json_bytes,
                    }
                    for algorithm, key_id, json_bytes in new_keys
                ],
            )
            # The per-device key count has changed; drop the cached value.
            self._invalidate_cache_and_stream(
                txn, self.count_e2e_one_time_keys, (user_id, device_id,)
            )

        yield self.runInteraction(
            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
        )

    @cached(max_entries=10000)
    def count_e2e_one_time_keys(self, user_id, device_id):
        """ Count the number of one time keys the server has for a device

        Args:
            user_id(str): id of the user owning the device
            device_id(str): id of the device

        Returns:
            Deferred[dict]: Dict mapping from algorithm to number of keys for
            that algorithm. Algorithms with no keys are absent.
        """
        def _count_e2e_one_time_keys(txn):
            sql = (
                "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
                " WHERE user_id = ? AND device_id = ?"
                " GROUP BY algorithm"
            )
            txn.execute(sql, (user_id, device_id))
            return {algorithm: key_count for algorithm, key_count in txn}

        return self.runInteraction(
            "count_e2e_one_time_keys", _count_e2e_one_time_keys
        )

    def claim_e2e_one_time_keys(self, query_list):
        """Take a list of one time keys out of the database.

        For each query at most one matching key is selected, and every
        selected key is then deleted from the table.

        Args:
            query_list (iterable): (user_id, device_id, algorithm) tuples.

        Returns:
            Deferred[dict]: user_id -> device_id -> {"<algorithm>:<key_id>":
            key_json}. Every queried user/device gets an entry, which is
            empty if no matching key was found.
        """
        def _claim_e2e_one_time_keys(txn):
            select_sql = (
                "SELECT key_id, key_json FROM e2e_one_time_keys_json"
                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
                " LIMIT 1"
            )
            result = {}
            delete = []
            for user_id, device_id, algorithm in query_list:
                user_result = result.setdefault(user_id, {})
                device_result = user_result.setdefault(device_id, {})
                txn.execute(select_sql, (user_id, device_id, algorithm))
                for key_id, key_json in txn:
                    device_result[algorithm + ":" + key_id] = key_json
                    delete.append((user_id, device_id, algorithm, key_id))

            # NOTE(review): the SELECT above and this DELETE are separate
            # statements; whether two concurrent claims can both receive the
            # same key depends on the transaction isolation level -- confirm.
            delete_sql = (
                "DELETE FROM e2e_one_time_keys_json"
                " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
                " AND key_id = ?"
            )
            for user_id, device_id, algorithm, key_id in delete:
                txn.execute(
                    delete_sql, (user_id, device_id, algorithm, key_id)
                )
                # The per-device key count has changed; drop the cached value.
                self._invalidate_cache_and_stream(
                    txn, self.count_e2e_one_time_keys, (user_id, device_id,)
                )
            return result

        return self.runInteraction(
            "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
        )

    def delete_e2e_keys_by_device(self, user_id, device_id):
        """Delete the device keys and all one-time keys stored for a device.

        Args:
            user_id (str): id of the user owning the device
            device_id (str): id of the device

        Returns:
            Deferred: completes once both tables have been cleared and the
            one-time key count cache invalidated.
        """
        def delete_e2e_keys_by_device_txn(txn):
            self._simple_delete_txn(
                txn,
                table="e2e_device_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
            )
            self._simple_delete_txn(
                txn,
                table="e2e_one_time_keys_json",
                keyvalues={"user_id": user_id, "device_id": device_id},
            )
            # The device now has no one-time keys; drop the cached count.
            self._invalidate_cache_and_stream(
                txn, self.count_e2e_one_time_keys, (user_id, device_id,)
            )

        return self.runInteraction(
            "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
        )
|