Browse Source

Add type hints to the crypto module. (#8999)

Patrick Cloke 3 years ago
parent
commit
1c9a850562

+ 1 - 0
changelog.d/8999.misc

@@ -0,0 +1 @@
+Add type hints to the crypto module.

+ 2 - 0
mypy.ini

@@ -17,6 +17,7 @@ files =
   synapse/api,
   synapse/appservice,
   synapse/config,
+  synapse/crypto,
   synapse/event_auth.py,
   synapse/events/builder.py,
   synapse/events/validator.py,
@@ -75,6 +76,7 @@ files =
   synapse/storage/background_updates.py,
   synapse/storage/databases/main/appservice.py,
   synapse/storage/databases/main/events.py,
+  synapse/storage/databases/main/keys.py,
   synapse/storage/databases/main/pusher.py,
   synapse/storage/databases/main/registration.py,
   synapse/storage/databases/main/stream.py,

+ 1 - 1
synapse/crypto/context_factory.py

@@ -227,7 +227,7 @@ class ConnectionVerifier:
 
     # This code is based on twisted.internet.ssl.ClientTLSOptions.
 
-    def __init__(self, hostname: bytes, verify_certs):
+    def __init__(self, hostname: bytes, verify_certs: bool):
         self._verify_certs = verify_certs
 
         _decoded = hostname.decode("ascii")

+ 18 - 11
synapse/crypto/event_signing.py

@@ -18,7 +18,7 @@
 import collections.abc
 import hashlib
 import logging
-from typing import Dict
+from typing import Any, Callable, Dict, Tuple
 
 from canonicaljson import encode_canonical_json
 from signedjson.sign import sign_json
@@ -27,13 +27,18 @@ from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
+from synapse.events import EventBase
 from synapse.events.utils import prune_event, prune_event_dict
 from synapse.types import JsonDict
 
 logger = logging.getLogger(__name__)
 
+Hasher = Callable[[bytes], "hashlib._Hash"]
 
-def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
+
+def check_event_content_hash(
+    event: EventBase, hash_algorithm: Hasher = hashlib.sha256
+) -> bool:
     """Check whether the hash for this PDU matches the contents"""
     name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
     logger.debug(
@@ -67,18 +72,19 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
     return message_hash_bytes == expected_hash
 
 
-def compute_content_hash(event_dict, hash_algorithm):
+def compute_content_hash(
+    event_dict: Dict[str, Any], hash_algorithm: Hasher
+) -> Tuple[str, bytes]:
     """Compute the content hash of an event, which is the hash of the
     unredacted event.
 
     Args:
-        event_dict (dict): The unredacted event as a dict
+        event_dict: The unredacted event as a dict
         hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
             to hash the event
 
     Returns:
-        tuple[str, bytes]: A tuple of the name of hash and the hash as raw
-        bytes.
+        A tuple of the name of hash and the hash as raw bytes.
     """
     event_dict = dict(event_dict)
     event_dict.pop("age_ts", None)
@@ -94,18 +100,19 @@ def compute_content_hash(event_dict, hash_algorithm):
     return hashed.name, hashed.digest()
 
 
-def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
+def compute_event_reference_hash(
+    event, hash_algorithm: Hasher = hashlib.sha256
+) -> Tuple[str, bytes]:
     """Computes the event reference hash. This is the hash of the redacted
     event.
 
     Args:
-        event (FrozenEvent)
+        event
         hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
             to hash the event
 
     Returns:
-        tuple[str, bytes]: A tuple of the name of hash and the hash as raw
-        bytes.
+        A tuple of the name of hash and the hash as raw bytes.
     """
     tmp_event = prune_event(event)
     event_dict = tmp_event.get_pdu_json()
@@ -156,7 +163,7 @@ def add_hashes_and_signatures(
     event_dict: JsonDict,
     signature_name: str,
     signing_key: SigningKey,
-):
+) -> None:
     """Add content hash and sign the event
 
     Args:

+ 120 - 86
synapse/crypto/keyring.py

@@ -14,9 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import abc
 import logging
 import urllib
 from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 import attr
 from signedjson.key import (
@@ -40,6 +42,7 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.config.key import TrustedKeyServer
 from synapse.logging.context import (
     PreserveLoggingContext,
     make_deferred_yieldable,
@@ -47,11 +50,15 @@ from synapse.logging.context import (
     run_in_background,
 )
 from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import yieldable_gather_results
 from synapse.util.metrics import Measure
 from synapse.util.retryutils import NotRetryingDestination
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -61,16 +68,17 @@ class VerifyJsonRequest:
     A request to verify a JSON object.
 
     Attributes:
-        server_name(str): The name of the server to verify against.
-
-        key_ids(set[str]): The set of key_ids to that could be used to verify the
-            JSON object
+        server_name: The name of the server to verify against.
 
-        json_object(dict): The JSON object to verify.
+        json_object: The JSON object to verify.
 
-        minimum_valid_until_ts (int): time at which we require the signing key to
+        minimum_valid_until_ts: time at which we require the signing key to
             be valid. (0 implies we don't care)
 
+        request_name: The name of the request.
+
+        key_ids: The set of key_ids to that could be used to verify the JSON object
+
         key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
             A deferred (server_name, key_id, verify_key) tuple that resolves when
             a verify key has been fetched. The deferreds' callbacks are run with no
@@ -80,12 +88,12 @@ class VerifyJsonRequest:
             errbacks with an M_UNAUTHORIZED SynapseError.
     """
 
-    server_name = attr.ib()
-    json_object = attr.ib()
-    minimum_valid_until_ts = attr.ib()
-    request_name = attr.ib()
-    key_ids = attr.ib(init=False)
-    key_ready = attr.ib(default=attr.Factory(defer.Deferred))
+    server_name = attr.ib(type=str)
+    json_object = attr.ib(type=JsonDict)
+    minimum_valid_until_ts = attr.ib(type=int)
+    request_name = attr.ib(type=str)
+    key_ids = attr.ib(init=False, type=List[str])
+    key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
 
     def __attrs_post_init__(self):
         self.key_ids = signature_ids(self.json_object, self.server_name)
@@ -96,7 +104,9 @@ class KeyLookupError(ValueError):
 
 
 class Keyring:
-    def __init__(self, hs, key_fetchers=None):
+    def __init__(
+        self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
+    ):
         self.clock = hs.get_clock()
 
         if key_fetchers is None:
@@ -112,22 +122,26 @@ class Keyring:
         # completes.
         #
         # These are regular, logcontext-agnostic Deferreds.
-        self.key_downloads = {}
+        self.key_downloads = {}  # type: Dict[str, defer.Deferred]
 
     def verify_json_for_server(
-        self, server_name, json_object, validity_time, request_name
-    ):
+        self,
+        server_name: str,
+        json_object: JsonDict,
+        validity_time: int,
+        request_name: str,
+    ) -> defer.Deferred:
         """Verify that a JSON object has been signed by a given server
 
         Args:
-            server_name (str): name of the server which must have signed this object
+            server_name: name of the server which must have signed this object
 
-            json_object (dict): object to be checked
+            json_object: object to be checked
 
-            validity_time (int): timestamp at which we require the signing key to
+            validity_time: timestamp at which we require the signing key to
                 be valid. (0 implies we don't care)
 
-            request_name (str): an identifier for this json object (eg, an event id)
+            request_name: an identifier for this json object (eg, an event id)
                 for logging.
 
         Returns:
@@ -138,12 +152,14 @@ class Keyring:
         requests = (req,)
         return make_deferred_yieldable(self._verify_objects(requests)[0])
 
-    def verify_json_objects_for_server(self, server_and_json):
+    def verify_json_objects_for_server(
+        self, server_and_json: Iterable[Tuple[str, dict, int, str]]
+    ) -> List[defer.Deferred]:
         """Bulk verifies signatures of json objects, bulk fetching keys as
         necessary.
 
         Args:
-            server_and_json (iterable[Tuple[str, dict, int, str]):
+            server_and_json:
                 Iterable of (server_name, json_object, validity_time, request_name)
                 tuples.
 
@@ -164,13 +180,14 @@ class Keyring:
             for server_name, json_object, validity_time, request_name in server_and_json
         )
 
-    def _verify_objects(self, verify_requests):
+    def _verify_objects(
+        self, verify_requests: Iterable[VerifyJsonRequest]
+    ) -> List[defer.Deferred]:
         """Does the work of verify_json_[objects_]for_server
 
 
         Args:
-            verify_requests (iterable[VerifyJsonRequest]):
-                Iterable of verification requests.
+            verify_requests: Iterable of verification requests.
 
         Returns:
             List<Deferred[None]>: for each input item, a deferred indicating success
@@ -182,7 +199,7 @@ class Keyring:
         key_lookups = []
         handle = preserve_fn(_handle_key_deferred)
 
-        def process(verify_request):
+        def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
             """Process an entry in the request list
 
             Adds a key request to key_lookups, and returns a deferred which
@@ -222,18 +239,20 @@ class Keyring:
 
         return results
 
-    async def _start_key_lookups(self, verify_requests):
+    async def _start_key_lookups(
+        self, verify_requests: List[VerifyJsonRequest]
+    ) -> None:
         """Sets off the key fetches for each verify request
 
         Once each fetch completes, verify_request.key_ready will be resolved.
 
         Args:
-            verify_requests (List[VerifyJsonRequest]):
+            verify_requests:
         """
 
         try:
             # map from server name to a set of outstanding request ids
-            server_to_request_ids = {}
+            server_to_request_ids = {}  # type: Dict[str, Set[int]]
 
             for verify_request in verify_requests:
                 server_name = verify_request.server_name
@@ -275,11 +294,11 @@ class Keyring:
         except Exception:
             logger.exception("Error starting key lookups")
 
-    async def wait_for_previous_lookups(self, server_names) -> None:
+    async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
         """Waits for any previous key lookups for the given servers to finish.
 
         Args:
-            server_names (Iterable[str]): list of servers which we want to look up
+            server_names: list of servers which we want to look up
 
         Returns:
             Resolves once all key lookups for the given servers have
@@ -304,7 +323,7 @@ class Keyring:
 
             loop_count += 1
 
-    def _get_server_verify_keys(self, verify_requests):
+    def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
         """Tries to find at least one key for each verify request
 
         For each verify_request, verify_request.key_ready is called back with
@@ -312,7 +331,7 @@ class Keyring:
         with a SynapseError if none of the keys are found.
 
         Args:
-            verify_requests (list[VerifyJsonRequest]): list of verify requests
+            verify_requests: list of verify requests
         """
 
         remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@@ -366,17 +385,19 @@ class Keyring:
 
         run_in_background(do_iterations)
 
-    async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+    async def _attempt_key_fetches_with_fetcher(
+        self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
+    ):
         """Use a key fetcher to attempt to satisfy some key requests
 
         Args:
-            fetcher (KeyFetcher): fetcher to use to fetch the keys
-            remaining_requests (set[VerifyJsonRequest]): outstanding key requests.
+            fetcher: fetcher to use to fetch the keys
+            remaining_requests: outstanding key requests.
                 Any successfully-completed requests will be removed from the list.
         """
-        # dict[str, dict[str, int]]: keys to fetch.
+        # The keys to fetch.
         # server_name -> key_id -> min_valid_ts
-        missing_keys = defaultdict(dict)
+        missing_keys = defaultdict(dict)  # type: Dict[str, Dict[str, int]]
 
         for verify_request in remaining_requests:
             # any completed requests should already have been removed
@@ -438,16 +459,18 @@ class Keyring:
         remaining_requests.difference_update(completed)
 
 
-class KeyFetcher:
-    async def get_keys(self, keys_to_fetch):
+class KeyFetcher(metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    async def get_keys(
+        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
-            keys_to_fetch (dict[str, dict[str, int]]):
+            keys_to_fetch:
                 the keys to be fetched. server_name -> key_id -> min_valid_ts
 
         Returns:
-            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
-                map from server_name -> key_id -> FetchKeyResult
+            Map from server_name -> key_id -> FetchKeyResult
         """
         raise NotImplementedError
 
@@ -455,31 +478,35 @@ class KeyFetcher:
 class StoreKeyFetcher(KeyFetcher):
     """KeyFetcher impl which fetches keys from our data store"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
-    async def get_keys(self, keys_to_fetch):
+    async def get_keys(
+        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """see KeyFetcher.get_keys"""
 
-        keys_to_fetch = (
+        key_ids_to_fetch = (
             (server_name, key_id)
             for server_name, keys_for_server in keys_to_fetch.items()
             for key_id in keys_for_server.keys()
         )
 
-        res = await self.store.get_server_verify_keys(keys_to_fetch)
-        keys = {}
+        res = await self.store.get_server_verify_keys(key_ids_to_fetch)
+        keys = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
         for (server_name, key_id), key in res.items():
             keys.setdefault(server_name, {})[key_id] = key
         return keys
 
 
-class BaseV2KeyFetcher:
-    def __init__(self, hs):
+class BaseV2KeyFetcher(KeyFetcher):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.config = hs.get_config()
 
-    async def process_v2_response(self, from_server, response_json, time_added_ms):
+    async def process_v2_response(
+        self, from_server: str, response_json: JsonDict, time_added_ms: int
+    ) -> Dict[str, FetchKeyResult]:
         """Parse a 'Server Keys' structure from the result of a /key request
 
         This is used to parse either the entirety of the response from
@@ -493,16 +520,16 @@ class BaseV2KeyFetcher:
         to /_matrix/key/v2/query.
 
         Args:
-            from_server (str): the name of the server producing this result: either
+            from_server: the name of the server producing this result: either
                 the origin server for a /_matrix/key/v2/server request, or the notary
                 for a /_matrix/key/v2/query.
 
-            response_json (dict): the json-decoded Server Keys response object
+            response_json: the json-decoded Server Keys response object
 
-            time_added_ms (int): the timestamp to record in server_keys_json
+            time_added_ms: the timestamp to record in server_keys_json
 
         Returns:
-            Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+            Map from key_id to result object
         """
         ts_valid_until_ms = response_json["valid_until_ts"]
 
@@ -575,21 +602,22 @@ class BaseV2KeyFetcher:
 class PerspectivesKeyFetcher(BaseV2KeyFetcher):
     """KeyFetcher impl which fetches keys from the "perspectives" servers"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.clock = hs.get_clock()
         self.client = hs.get_federation_http_client()
         self.key_servers = self.config.key_servers
 
-    async def get_keys(self, keys_to_fetch):
+    async def get_keys(
+        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """see KeyFetcher.get_keys"""
 
-        async def get_key(key_server):
+        async def get_key(key_server: TrustedKeyServer) -> Dict:
             try:
-                result = await self.get_server_verify_key_v2_indirect(
+                return await self.get_server_verify_key_v2_indirect(
                     keys_to_fetch, key_server
                 )
-                return result
             except KeyLookupError as e:
                 logger.warning(
                     "Key lookup failed from %r: %s", key_server.server_name, e
@@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
             ).addErrback(unwrapFirstError)
         )
 
-        union_of_keys = {}
+        union_of_keys = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
         for result in results:
             for server_name, keys in result.items():
                 union_of_keys.setdefault(server_name, {}).update(keys)
 
         return union_of_keys
 
-    async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
+    async def get_server_verify_key_v2_indirect(
+        self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
-            keys_to_fetch (dict[str, dict[str, int]]):
+            keys_to_fetch:
                 the keys to be fetched. server_name -> key_id -> min_valid_ts
 
-            key_server (synapse.config.key.TrustedKeyServer): notary server to query for
-                the keys
+            key_server: notary server to query for the keys
 
         Returns:
-            dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
-                from server_name -> key_id -> FetchKeyResult
+            Map from server_name -> key_id -> FetchKeyResult
 
         Raises:
             KeyLookupError if there was an error processing the entire response from
@@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         except HttpResponseException as e:
             raise KeyLookupError("Remote server returned an error: %s" % (e,))
 
-        keys = {}
-        added_keys = []
+        keys = {}  # type: Dict[str, Dict[str, FetchKeyResult]]
+        added_keys = []  # type: List[Tuple[str, str, FetchKeyResult]]
 
         time_now_ms = self.clock.time_msec()
 
+        assert isinstance(query_response, dict)
         for response in query_response["server_keys"]:
             # do this first, so that we can give useful errors thereafter
             server_name = response.get("server_name")
@@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 
         return keys
 
-    def _validate_perspectives_response(self, key_server, response):
+    def _validate_perspectives_response(
+        self, key_server: TrustedKeyServer, response: JsonDict
+    ) -> None:
         """Optionally check the signature on the result of a /key/query request
 
         Args:
-            key_server (synapse.config.key.TrustedKeyServer): the notary server that
-                produced this result
+            key_server: the notary server that produced this result
 
-            response (dict): the json-decoded Server Keys response object
+            response: the json-decoded Server Keys response object
         """
         perspective_name = key_server.server_name
         perspective_keys = key_server.verify_keys
@@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 class ServerKeyFetcher(BaseV2KeyFetcher):
     """KeyFetcher impl which fetches keys from the origin servers"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.clock = hs.get_clock()
         self.client = hs.get_federation_http_client()
 
-    async def get_keys(self, keys_to_fetch):
+    async def get_keys(
+        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
-            keys_to_fetch (dict[str, iterable[str]]):
+            keys_to_fetch:
                 the keys to be fetched. server_name -> key_ids
 
         Returns:
-            dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
-                map from server_name -> key_id -> FetchKeyResult
+            Map from server_name -> key_id -> FetchKeyResult
         """
 
         results = {}
 
-        async def get_key(key_to_fetch_item):
+        async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
             server_name, key_ids = key_to_fetch_item
             try:
                 keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
@@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         await yieldable_gather_results(get_key, keys_to_fetch.items())
         return results
 
-    async def get_server_verify_key_v2_direct(self, server_name, key_ids):
+    async def get_server_verify_key_v2_direct(
+        self, server_name: str, key_ids: Iterable[str]
+    ) -> Dict[str, FetchKeyResult]:
         """
 
         Args:
-            server_name (str):
-            key_ids (iterable[str]):
+            server_name:
+            key_ids:
 
         Returns:
-            dict[str, FetchKeyResult]: map from key ID to lookup result
+            Map from key ID to lookup result
 
         Raises:
             KeyLookupError if there was a problem making the lookup
         """
-        keys = {}  # type: dict[str, FetchKeyResult]
+        keys = {}  # type: Dict[str, FetchKeyResult]
 
         for requested_key_id in key_ids:
             # we may have found this key as a side-effect of asking for another.
@@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             except HttpResponseException as e:
                 raise KeyLookupError("Remote server returned an error: %s" % (e,))
 
+            assert isinstance(response, dict)
             if response["server_name"] != server_name:
                 raise KeyLookupError(
                     "Expected a response for server %r not %r"
@@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         return keys
 
 
-async def _handle_key_deferred(verify_request) -> None:
+async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
     """Waits for the key to become available, and then performs a verification
 
     Args:
-        verify_request (VerifyJsonRequest):
+        verify_request:
 
     Raises:
         SynapseError if there was a problem performing the verification

+ 1 - 1
synapse/federation/transport/server.py

@@ -144,7 +144,7 @@ class Authenticator:
         ):
             raise FederationDeniedError(origin)
 
-        if not json_request["signatures"]:
+        if origin is None or not json_request["signatures"]:
             raise NoAuthenticationError(
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED
             )

+ 5 - 4
synapse/rest/key/v2/remote_key_resource.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Set
+from typing import Dict
 
 from signedjson.sign import sign_json
 
@@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource):
 
         time_now_ms = self.clock.time_msec()
 
-        cache_misses = {}  # type: Dict[str, Set[str]]
+        # Note that the value is unused.
+        cache_misses = {}  # type: Dict[str, Dict[str, int]]
         for (server_name, key_id, from_server), results in cached.items():
             results = [(result["ts_added_ms"], result) for result in results]
 
             if not results and key_id is not None:
-                cache_misses.setdefault(server_name, set()).add(key_id)
+                cache_misses.setdefault(server_name, {})[key_id] = 0
                 continue
 
             if key_id is not None:
@@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource):
                     )
 
                 if miss:
-                    cache_misses.setdefault(server_name, set()).add(key_id)
+                    cache_misses.setdefault(server_name, {})[key_id] = 0
                 # Cast to bytes since postgresql returns a memoryview.
                 json_results.add(bytes(most_recent_result["key_json"]))
             else:

+ 5 - 5
synapse/storage/databases/main/keys.py

@@ -22,6 +22,7 @@ from signedjson.key import decode_verify_key_bytes
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.keys import FetchKeyResult
+from synapse.storage.types import Cursor
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
 
@@ -44,7 +45,7 @@ class KeyStore(SQLBaseStore):
     )
     async def get_server_verify_keys(
         self, server_name_and_key_ids: Iterable[Tuple[str, str]]
-    ) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
+    ) -> Dict[Tuple[str, str], FetchKeyResult]:
         """
         Args:
             server_name_and_key_ids:
@@ -56,7 +57,7 @@ class KeyStore(SQLBaseStore):
         """
         keys = {}
 
-        def _get_keys(txn, batch):
+        def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str]]) -> None:
             """Processes a batch of keys to fetch, and adds the result to `keys`."""
 
             # batch_iter always returns tuples so it's safe to do len(batch)
@@ -77,13 +78,12 @@ class KeyStore(SQLBaseStore):
                     # `ts_valid_until_ms`.
                     ts_valid_until_ms = 0
 
-                res = FetchKeyResult(
+                keys[(server_name, key_id)] = FetchKeyResult(
                     verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
                     valid_until_ts=ts_valid_until_ms,
                 )
-                keys[(server_name, key_id)] = res
 
-        def _txn(txn):
+        def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
             for batch in batch_iter(server_name_and_key_ids, 50):
                 _get_keys(txn, batch)
             return keys

+ 5 - 5
tests/crypto/test_keyring.py

@@ -75,7 +75,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         return val
 
     def test_verify_json_objects_for_server_awaits_previous_requests(self):
-        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher = Mock()
         mock_fetcher.get_keys = Mock()
         kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
 
@@ -195,7 +195,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """Tests that we correctly handle key requests for keys we've stored
         with a null `ts_valid_until_ms`
         """
-        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher = Mock()
         mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
 
         kr = keyring.Keyring(
@@ -249,7 +249,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 }
             }
 
-        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher = Mock()
         mock_fetcher.get_keys = Mock(side_effect=get_keys)
         kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
 
@@ -288,9 +288,9 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 }
             }
 
-        mock_fetcher1 = keyring.KeyFetcher()
+        mock_fetcher1 = Mock()
         mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
-        mock_fetcher2 = keyring.KeyFetcher()
+        mock_fetcher2 = Mock()
         mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
         kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))