Browse Source

Refactor storing of server keys (#16261)

Erik Johnston 7 months ago
parent
commit
2b35626b6b

+ 1 - 0
changelog.d/16261.misc

@@ -0,0 +1 @@
+Simplify server key storage.

+ 7 - 28
synapse/crypto/keyring.py

@@ -23,12 +23,7 @@ from signedjson.key import (
     get_verify_key,
     is_signing_algorithm_supported,
 )
-from signedjson.sign import (
-    SignatureVerifyException,
-    encode_canonical_json,
-    signature_ids,
-    verify_signed_json,
-)
+from signedjson.sign import SignatureVerifyException, signature_ids, verify_signed_json
 from signedjson.types import VerifyKey
 from unpaddedbase64 import decode_base64
 
@@ -596,24 +591,12 @@ class BaseV2KeyFetcher(KeyFetcher):
                     verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
                 )
 
-        key_json_bytes = encode_canonical_json(response_json)
-
-        await make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(
-                        self.store.store_server_keys_json,
-                        server_name=server_name,
-                        key_id=key_id,
-                        from_server=from_server,
-                        ts_now_ms=time_added_ms,
-                        ts_expires_ms=ts_valid_until_ms,
-                        key_json_bytes=key_json_bytes,
-                    )
-                    for key_id in verify_keys
-                ],
-                consumeErrors=True,
-            ).addErrback(unwrapFirstError)
+        await self.store.store_server_keys_response(
+            server_name=server_name,
+            from_server=from_server,
+            ts_added_ms=time_added_ms,
+            verify_keys=verify_keys,
+            response_json=response_json,
         )
 
         return verify_keys
@@ -775,10 +758,6 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 
             keys.setdefault(server_name, {}).update(processed_response)
 
-        await self.store.store_server_signature_keys(
-            perspective_name, time_now_ms, added_keys
-        )
-
         return keys
 
     def _validate_perspectives_response(

+ 72 - 147
synapse/storage/databases/main/keys.py

@@ -16,14 +16,17 @@
 import itertools
 import json
 import logging
-from typing import Dict, Iterable, Mapping, Optional, Tuple
+from typing import Dict, Iterable, Optional, Tuple
 
+from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
 from unpaddedbase64 import decode_base64
 
+from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
 from synapse.storage.types import Cursor
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
 
@@ -36,162 +39,84 @@ db_binary_type = memoryview
 class KeyStore(CacheInvalidationWorkerStore):
     """Persistence for signature verification keys"""
 
-    @cached()
-    def _get_server_signature_key(
-        self, server_name_and_key_id: Tuple[str, str]
-    ) -> FetchKeyResult:
-        raise NotImplementedError()
-
-    @cachedList(
-        cached_method_name="_get_server_signature_key",
-        list_name="server_name_and_key_ids",
-    )
-    async def get_server_signature_keys(
-        self, server_name_and_key_ids: Iterable[Tuple[str, str]]
-    ) -> Dict[Tuple[str, str], FetchKeyResult]:
-        """
-        Args:
-            server_name_and_key_ids:
-                iterable of (server_name, key-id) tuples to fetch keys for
-
-        Returns:
-            A map from (server_name, key_id) -> FetchKeyResult, or None if the
-            key is unknown
-        """
-        keys = {}
-
-        def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
-            """Processes a batch of keys to fetch, and adds the result to `keys`."""
-
-            # batch_iter always returns tuples so it's safe to do len(batch)
-            sql = """
-            SELECT server_name, key_id, verify_key, ts_valid_until_ms
-            FROM server_signature_keys WHERE 1=0
-            """ + " OR (server_name=? AND key_id=?)" * len(
-                batch
-            )
-
-            txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
-
-            for row in txn:
-                server_name, key_id, key_bytes, ts_valid_until_ms = row
-
-                if ts_valid_until_ms is None:
-                    # Old keys may be stored with a ts_valid_until_ms of null,
-                    # in which case we treat this as if it was set to `0`, i.e.
-                    # it won't match key requests that define a minimum
-                    # `ts_valid_until_ms`.
-                    ts_valid_until_ms = 0
-
-                keys[(server_name, key_id)] = FetchKeyResult(
-                    verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
-                    valid_until_ts=ts_valid_until_ms,
-                )
-
-        def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
-            for batch in batch_iter(server_name_and_key_ids, 50):
-                _get_keys(txn, batch)
-            return keys
-
-        return await self.db_pool.runInteraction("get_server_signature_keys", _txn)
-
-    async def store_server_signature_keys(
+    async def store_server_keys_response(
         self,
+        server_name: str,
         from_server: str,
         ts_added_ms: int,
-        verify_keys: Mapping[Tuple[str, str], FetchKeyResult],
+        verify_keys: Dict[str, FetchKeyResult],
+        response_json: JsonDict,
     ) -> None:
-        """Stores NACL verification keys for remote servers.
+        """Stores the keys for the given server that we got from `from_server`.
+
         Args:
-            from_server: Where the verification keys were looked up
-            ts_added_ms: The time to record that the key was added
-            verify_keys:
-                keys to be stored. Each entry is a triplet of
-                (server_name, key_id, key).
+            server_name: The owner of the keys
+            from_server: Which server we got the keys from
+            ts_added_ms: When we're adding the keys
+            verify_keys: The decoded keys
+            response_json: The full *signed* response JSON that contains the keys.
         """
-        key_values = []
-        value_values = []
-        invalidations = []
-        for (server_name, key_id), fetch_result in verify_keys.items():
-            key_values.append((server_name, key_id))
-            value_values.append(
-                (
-                    from_server,
-                    ts_added_ms,
-                    fetch_result.valid_until_ts,
-                    db_binary_type(fetch_result.verify_key.encode()),
-                )
-            )
-            # invalidate takes a tuple corresponding to the params of
-            # _get_server_signature_key. _get_server_signature_key only takes one
-            # param, which is itself the 2-tuple (server_name, key_id).
-            invalidations.append((server_name, key_id))
 
-        await self.db_pool.simple_upsert_many(
-            table="server_signature_keys",
-            key_names=("server_name", "key_id"),
-            key_values=key_values,
-            value_names=(
-                "from_server",
-                "ts_added_ms",
-                "ts_valid_until_ms",
-                "verify_key",
-            ),
-            value_values=value_values,
-            desc="store_server_signature_keys",
-        )
+        key_json_bytes = encode_canonical_json(response_json)
+
+        def store_server_keys_response_txn(txn: LoggingTransaction) -> None:
+            self.db_pool.simple_upsert_many_txn(
+                txn,
+                table="server_signature_keys",
+                key_names=("server_name", "key_id"),
+                key_values=[(server_name, key_id) for key_id in verify_keys],
+                value_names=(
+                    "from_server",
+                    "ts_added_ms",
+                    "ts_valid_until_ms",
+                    "verify_key",
+                ),
+                value_values=[
+                    (
+                        from_server,
+                        ts_added_ms,
+                        fetch_result.valid_until_ts,
+                        db_binary_type(fetch_result.verify_key.encode()),
+                    )
+                    for fetch_result in verify_keys.values()
+                ],
+            )
 
-        invalidate = self._get_server_signature_key.invalidate
-        for i in invalidations:
-            invalidate((i,))
+            self.db_pool.simple_upsert_many_txn(
+                txn,
+                table="server_keys_json",
+                key_names=("server_name", "key_id", "from_server"),
+                key_values=[
+                    (server_name, key_id, from_server) for key_id in verify_keys
+                ],
+                value_names=(
+                    "ts_added_ms",
+                    "ts_valid_until_ms",
+                    "key_json",
+                ),
+                value_values=[
+                    (
+                        ts_added_ms,
+                        fetch_result.valid_until_ts,
+                        db_binary_type(key_json_bytes),
+                    )
+                    for fetch_result in verify_keys.values()
+                ],
+            )
 
-    async def store_server_keys_json(
-        self,
-        server_name: str,
-        key_id: str,
-        from_server: str,
-        ts_now_ms: int,
-        ts_expires_ms: int,
-        key_json_bytes: bytes,
-    ) -> None:
-        """Stores the JSON bytes for a set of keys from a server
-        The JSON should be signed by the originating server, the intermediate
-        server, and by this server. Updates the value for the
-        (server_name, key_id, from_server) triplet if one already existed.
-        Args:
-            server_name: The name of the server.
-            key_id: The identifier of the key this JSON is for.
-            from_server: The server this JSON was fetched from.
-            ts_now_ms: The time now in milliseconds.
-            ts_valid_until_ms: The time when this json stops being valid.
-            key_json_bytes: The encoded JSON.
-        """
-        await self.db_pool.simple_upsert(
-            table="server_keys_json",
-            keyvalues={
-                "server_name": server_name,
-                "key_id": key_id,
-                "from_server": from_server,
-            },
-            values={
-                "server_name": server_name,
-                "key_id": key_id,
-                "from_server": from_server,
-                "ts_added_ms": ts_now_ms,
-                "ts_valid_until_ms": ts_expires_ms,
-                "key_json": db_binary_type(key_json_bytes),
-            },
-            desc="store_server_keys_json",
-        )
+            # invalidate takes a tuple corresponding to the params of
+            # _get_server_keys_json. _get_server_keys_json only takes one
+            # param, which is itself the 2-tuple (server_name, key_id).
+            for key_id in verify_keys:
+                self._invalidate_cache_and_stream(
+                    txn, self._get_server_keys_json, ((server_name, key_id),)
+                )
+                self._invalidate_cache_and_stream(
+                    txn, self.get_server_key_json_for_remote, (server_name, key_id)
+                )
 
-        # invalidate takes a tuple corresponding to the params of
-        # _get_server_keys_json. _get_server_keys_json only takes one
-        # param, which is itself the 2-tuple (server_name, key_id).
-        await self.invalidate_cache_and_stream(
-            "_get_server_keys_json", ((server_name, key_id),)
-        )
-        await self.invalidate_cache_and_stream(
-            "get_server_key_json_for_remote", (server_name, key_id)
+        await self.db_pool.runInteraction(
+            "store_server_keys_response", store_server_keys_response_txn
         )
 
     @cached()

+ 13 - 40
tests/crypto/test_keyring.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import time
 from typing import Any, Dict, List, Optional, cast
-from unittest.mock import AsyncMock, Mock
+from unittest.mock import Mock
 
 import attr
 import canonicaljson
@@ -189,23 +189,24 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         kr = keyring.Keyring(self.hs)
 
         key1 = signedjson.key.generate_signing_key("1")
-        r = self.hs.get_datastores().main.store_server_keys_json(
+        r = self.hs.get_datastores().main.store_server_keys_response(
             "server9",
-            get_key_id(key1),
             from_server="test",
-            ts_now_ms=int(time.time() * 1000),
-            ts_expires_ms=1000,
+            ts_added_ms=int(time.time() * 1000),
+            verify_keys={
+                get_key_id(key1): FetchKeyResult(
+                    verify_key=get_verify_key(key1), valid_until_ts=1000
+                )
+            },
             # The entire response gets signed & stored, just include the bits we
             # care about.
-            key_json_bytes=canonicaljson.encode_canonical_json(
-                {
-                    "verify_keys": {
-                        get_key_id(key1): {
-                            "key": encode_verify_key_base64(get_verify_key(key1))
-                        }
+            response_json={
+                "verify_keys": {
+                    get_key_id(key1): {
+                        "key": encode_verify_key_base64(get_verify_key(key1))
                     }
                 }
-            ),
+            },
         )
         self.get_success(r)
 
@@ -285,34 +286,6 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
         self.get_success(d)
 
-    def test_verify_json_for_server_with_null_valid_until_ms(self) -> None:
-        """Tests that we correctly handle key requests for keys we've stored
-        with a null `ts_valid_until_ms`
-        """
-        mock_fetcher = Mock()
-        mock_fetcher.get_keys = AsyncMock(return_value={})
-
-        key1 = signedjson.key.generate_signing_key("1")
-        r = self.hs.get_datastores().main.store_server_signature_keys(
-            "server9",
-            int(time.time() * 1000),
-            # None is not a valid value in FetchKeyResult, but we're abusing this
-            # API to insert null values into the database. The nulls get converted
-            # to 0 when fetched in KeyStore.get_server_signature_keys.
-            {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)},  # type: ignore[arg-type]
-        )
-        self.get_success(r)
-
-        json1: JsonDict = {}
-        signedjson.sign.sign_json(json1, "server9", key1)
-
-        # should succeed on a signed object with a 0 minimum_valid_until_ms
-        d = self.hs.get_datastores().main.get_server_signature_keys(
-            [("server9", get_key_id(key1))]
-        )
-        result = self.get_success(d)
-        self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0)
-
     def test_verify_json_dedupes_key_requests(self) -> None:
         """Two requests for the same key should be deduped."""
         key1 = signedjson.key.generate_signing_key("1")

+ 0 - 137
tests/storage/test_keys.py

@@ -1,137 +0,0 @@
-# Copyright 2017 Vector Creations Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import signedjson.key
-import signedjson.types
-import unpaddedbase64
-
-from synapse.storage.keys import FetchKeyResult
-
-import tests.unittest
-
-
-def decode_verify_key_base64(
-    key_id: str, key_base64: str
-) -> signedjson.types.VerifyKey:
-    key_bytes = unpaddedbase64.decode_base64(key_base64)
-    return signedjson.key.decode_verify_key_bytes(key_id, key_bytes)
-
-
-KEY_1 = decode_verify_key_base64(
-    "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw"
-)
-KEY_2 = decode_verify_key_base64(
-    "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
-)
-
-
-class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
-    def test_get_server_signature_keys(self) -> None:
-        store = self.hs.get_datastores().main
-
-        key_id_1 = "ed25519:key1"
-        key_id_2 = "ed25519:KEY_ID_2"
-        self.get_success(
-            store.store_server_signature_keys(
-                "from_server",
-                10,
-                {
-                    ("server1", key_id_1): FetchKeyResult(KEY_1, 100),
-                    ("server1", key_id_2): FetchKeyResult(KEY_2, 200),
-                },
-            )
-        )
-
-        res = self.get_success(
-            store.get_server_signature_keys(
-                [
-                    ("server1", key_id_1),
-                    ("server1", key_id_2),
-                    ("server1", "ed25519:key3"),
-                ]
-            )
-        )
-
-        self.assertEqual(len(res.keys()), 3)
-        res1 = res[("server1", key_id_1)]
-        self.assertEqual(res1.verify_key, KEY_1)
-        self.assertEqual(res1.verify_key.version, "key1")
-        self.assertEqual(res1.valid_until_ts, 100)
-
-        res2 = res[("server1", key_id_2)]
-        self.assertEqual(res2.verify_key, KEY_2)
-        # version comes from the ID it was stored with
-        self.assertEqual(res2.verify_key.version, "KEY_ID_2")
-        self.assertEqual(res2.valid_until_ts, 200)
-
-        # non-existent result gives None
-        self.assertIsNone(res[("server1", "ed25519:key3")])
-
-    def test_cache(self) -> None:
-        """Check that updates correctly invalidate the cache."""
-
-        store = self.hs.get_datastores().main
-
-        key_id_1 = "ed25519:key1"
-        key_id_2 = "ed25519:key2"
-
-        self.get_success(
-            store.store_server_signature_keys(
-                "from_server",
-                0,
-                {
-                    ("srv1", key_id_1): FetchKeyResult(KEY_1, 100),
-                    ("srv1", key_id_2): FetchKeyResult(KEY_2, 200),
-                },
-            )
-        )
-
-        res = self.get_success(
-            store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
-        )
-        self.assertEqual(len(res.keys()), 2)
-
-        res1 = res[("srv1", key_id_1)]
-        self.assertEqual(res1.verify_key, KEY_1)
-        self.assertEqual(res1.valid_until_ts, 100)
-
-        res2 = res[("srv1", key_id_2)]
-        self.assertEqual(res2.verify_key, KEY_2)
-        self.assertEqual(res2.valid_until_ts, 200)
-
-        # we should be able to look up the same thing again without a db hit
-        res = self.get_success(store.get_server_signature_keys([("srv1", key_id_1)]))
-        self.assertEqual(len(res.keys()), 1)
-        self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
-
-        new_key_2 = signedjson.key.get_verify_key(
-            signedjson.key.generate_signing_key("key2")
-        )
-        d = store.store_server_signature_keys(
-            "from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)}
-        )
-        self.get_success(d)
-
-        res = self.get_success(
-            store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)])
-        )
-        self.assertEqual(len(res.keys()), 2)
-
-        res1 = res[("srv1", key_id_1)]
-        self.assertEqual(res1.verify_key, KEY_1)
-        self.assertEqual(res1.valid_until_ts, 100)
-
-        res2 = res[("srv1", key_id_2)]
-        self.assertEqual(res2.verify_key, new_key_2)
-        self.assertEqual(res2.valid_until_ts, 300)

+ 13 - 13
tests/unittest.py

@@ -70,6 +70,7 @@ from synapse.logging.context import (
 )
 from synapse.rest import RegisterServletsFunc
 from synapse.server import HomeServer
+from synapse.storage.keys import FetchKeyResult
 from synapse.types import JsonDict, Requester, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
@@ -858,23 +859,22 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
         verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
 
         self.get_success(
-            hs.get_datastores().main.store_server_keys_json(
+            hs.get_datastores().main.store_server_keys_response(
                 self.OTHER_SERVER_NAME,
-                verify_key_id,
                 from_server=self.OTHER_SERVER_NAME,
-                ts_now_ms=clock.time_msec(),
-                ts_expires_ms=clock.time_msec() + 10000,
-                key_json_bytes=canonicaljson.encode_canonical_json(
-                    {
-                        "verify_keys": {
-                            verify_key_id: {
-                                "key": signedjson.key.encode_verify_key_base64(
-                                    verify_key
-                                )
-                            }
+                ts_added_ms=clock.time_msec(),
+                verify_keys={
+                    verify_key_id: FetchKeyResult(
+                        verify_key=verify_key, valid_until_ts=clock.time_msec() + 10000
+                    ),
+                },
+                response_json={
+                    "verify_keys": {
+                        verify_key_id: {
+                            "key": signedjson.key.encode_verify_key_base64(verify_key)
                         }
                     }
-                ),
+                },
             )
         )