|
@@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes
|
|
|
from synapse.util.retryutils import get_retry_limiter
|
|
|
from synapse.util import unwrapFirstError
|
|
|
from synapse.util.async import ObservableDeferred
|
|
|
+from synapse.util.logcontext import (
|
|
|
+ preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
|
|
|
+ preserve_fn
|
|
|
+)
|
|
|
|
|
|
from twisted.internet import defer
|
|
|
|
|
@@ -142,40 +146,43 @@ class Keyring(object):
|
|
|
for server_name, _ in server_and_json
|
|
|
}
|
|
|
|
|
|
- # We want to wait for any previous lookups to complete before
|
|
|
- # proceeding.
|
|
|
- wait_on_deferred = self.wait_for_previous_lookups(
|
|
|
- [server_name for server_name, _ in server_and_json],
|
|
|
- server_to_deferred,
|
|
|
- )
|
|
|
+ with PreserveLoggingContext():
|
|
|
|
|
|
- # Actually start fetching keys.
|
|
|
- wait_on_deferred.addBoth(
|
|
|
- lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
|
|
|
- )
|
|
|
+ # We want to wait for any previous lookups to complete before
|
|
|
+ # proceeding.
|
|
|
+ wait_on_deferred = self.wait_for_previous_lookups(
|
|
|
+ [server_name for server_name, _ in server_and_json],
|
|
|
+ server_to_deferred,
|
|
|
+ )
|
|
|
|
|
|
- # When we've finished fetching all the keys for a given server_name,
|
|
|
- # resolve the deferred passed to `wait_for_previous_lookups` so that
|
|
|
- # any lookups waiting will proceed.
|
|
|
- server_to_gids = {}
|
|
|
+ # Actually start fetching keys.
|
|
|
+ wait_on_deferred.addBoth(
|
|
|
+ lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
|
|
|
+ )
|
|
|
+
|
|
|
+ # When we've finished fetching all the keys for a given server_name,
|
|
|
+ # resolve the deferred passed to `wait_for_previous_lookups` so that
|
|
|
+ # any lookups waiting will proceed.
|
|
|
+ server_to_gids = {}
|
|
|
|
|
|
- def remove_deferreds(res, server_name, group_id):
|
|
|
- server_to_gids[server_name].discard(group_id)
|
|
|
- if not server_to_gids[server_name]:
|
|
|
- d = server_to_deferred.pop(server_name, None)
|
|
|
- if d:
|
|
|
- d.callback(None)
|
|
|
- return res
|
|
|
+ def remove_deferreds(res, server_name, group_id):
|
|
|
+ server_to_gids[server_name].discard(group_id)
|
|
|
+ if not server_to_gids[server_name]:
|
|
|
+ d = server_to_deferred.pop(server_name, None)
|
|
|
+ if d:
|
|
|
+ d.callback(None)
|
|
|
+ return res
|
|
|
|
|
|
- for g_id, deferred in deferreds.items():
|
|
|
- server_name = group_id_to_group[g_id].server_name
|
|
|
- server_to_gids.setdefault(server_name, set()).add(g_id)
|
|
|
- deferred.addBoth(remove_deferreds, server_name, g_id)
|
|
|
+ for g_id, deferred in deferreds.items():
|
|
|
+ server_name = group_id_to_group[g_id].server_name
|
|
|
+ server_to_gids.setdefault(server_name, set()).add(g_id)
|
|
|
+ deferred.addBoth(remove_deferreds, server_name, g_id)
|
|
|
|
|
|
# Pass those keys to handle_key_deferred so that the json object
|
|
|
# signatures can be verified
|
|
|
return [
|
|
|
- handle_key_deferred(
|
|
|
+ preserve_context_over_fn(
|
|
|
+ handle_key_deferred,
|
|
|
group_id_to_group[g_id],
|
|
|
deferreds[g_id],
|
|
|
)
|
|
@@ -198,12 +205,13 @@ class Keyring(object):
|
|
|
if server_name in self.key_downloads
|
|
|
]
|
|
|
if wait_on:
|
|
|
- yield defer.DeferredList(wait_on)
|
|
|
+ with PreserveLoggingContext():
|
|
|
+ yield defer.DeferredList(wait_on)
|
|
|
else:
|
|
|
break
|
|
|
|
|
|
for server_name, deferred in server_to_deferred.items():
|
|
|
- d = ObservableDeferred(deferred)
|
|
|
+ d = ObservableDeferred(preserve_context_over_deferred(deferred))
|
|
|
self.key_downloads[server_name] = d
|
|
|
|
|
|
def rm(r, server_name):
|
|
@@ -244,12 +252,13 @@ class Keyring(object):
|
|
|
for group in group_id_to_group.values():
|
|
|
for key_id in group.key_ids:
|
|
|
if key_id in merged_results[group.server_name]:
|
|
|
- group_id_to_deferred[group.group_id].callback((
|
|
|
- group.group_id,
|
|
|
- group.server_name,
|
|
|
- key_id,
|
|
|
- merged_results[group.server_name][key_id],
|
|
|
- ))
|
|
|
+ with PreserveLoggingContext():
|
|
|
+ group_id_to_deferred[group.group_id].callback((
|
|
|
+ group.group_id,
|
|
|
+ group.server_name,
|
|
|
+ key_id,
|
|
|
+ merged_results[group.server_name][key_id],
|
|
|
+ ))
|
|
|
break
|
|
|
else:
|
|
|
missing_groups.setdefault(
|
|
@@ -504,7 +513,7 @@ class Keyring(object):
|
|
|
|
|
|
yield defer.gatherResults(
|
|
|
[
|
|
|
- self.store_keys(
|
|
|
+ preserve_fn(self.store_keys)(
|
|
|
server_name=key_server_name,
|
|
|
from_server=server_name,
|
|
|
verify_keys=verify_keys,
|
|
@@ -573,7 +582,7 @@ class Keyring(object):
|
|
|
|
|
|
yield defer.gatherResults(
|
|
|
[
|
|
|
- self.store.store_server_keys_json(
|
|
|
+ preserve_fn(self.store.store_server_keys_json)(
|
|
|
server_name=server_name,
|
|
|
key_id=key_id,
|
|
|
from_server=server_name,
|
|
@@ -675,7 +684,7 @@ class Keyring(object):
|
|
|
# TODO(markjh): Store whether the keys have expired.
|
|
|
yield defer.gatherResults(
|
|
|
[
|
|
|
- self.store.store_server_verify_key(
|
|
|
+ preserve_fn(self.store.store_server_verify_key)(
|
|
|
server_name, server_name, key.time_added, key
|
|
|
)
|
|
|
for key_id, key in verify_keys.items()
|