|
@@ -16,25 +16,29 @@
|
|
|
import binascii
|
|
|
import json
|
|
|
import logging
|
|
|
-from typing import TYPE_CHECKING, Any, Dict, Optional
|
|
|
+from abc import abstractmethod
|
|
|
+from typing import TYPE_CHECKING, Dict, Generic, Optional, Sequence, TypeVar
|
|
|
|
|
|
import signedjson.key
|
|
|
import signedjson.sign
|
|
|
from twisted.internet import defer
|
|
|
from twisted.internet.defer import Deferred
|
|
|
+from twisted.python.failure import Failure
|
|
|
from twisted.web.client import readBody
|
|
|
from twisted.web.iweb import IResponse
|
|
|
from unpaddedbase64 import decode_base64
|
|
|
|
|
|
from sydent.config import ConfigError
|
|
|
from sydent.db.hashing_metadata import HashingMetadataStore
|
|
|
-from sydent.db.threepid_associations import GlobalAssociationStore
|
|
|
+from sydent.db.threepid_associations import GlobalAssociationStore, SignedAssociations
|
|
|
from sydent.threepid import threePidAssocFromDict
|
|
|
from sydent.types import JsonDict
|
|
|
from sydent.util import json_decoder
|
|
|
from sydent.util.hash import sha256_and_url_safe_base64
|
|
|
from sydent.util.stringutils import normalise_address
|
|
|
|
|
|
+PushUpdateReturn = TypeVar("PushUpdateReturn")
|
|
|
+
|
|
|
if TYPE_CHECKING:
|
|
|
from sydent.sydent import Sydent
|
|
|
|
|
@@ -43,21 +47,27 @@ logger = logging.getLogger(__name__)
|
|
|
SIGNING_KEY_ALGORITHM = "ed25519"
|
|
|
|
|
|
|
|
|
-class Peer:
|
|
|
- def __init__(self, servername, pubkeys):
|
|
|
+class Peer(Generic[PushUpdateReturn]):
|
|
|
+ def __init__(self, servername: str, pubkeys: Dict[str, str]):
|
|
|
+ """
|
|
|
+ :param server_name: The peer's server name.
|
|
|
+ :param pubkeys: The peer's public keys in a Dict[key_id, key_b64]
|
|
|
+ """
|
|
|
self.servername = servername
|
|
|
self.pubkeys = pubkeys
|
|
|
self.is_being_pushed_to = False
|
|
|
|
|
|
- def pushUpdates(self, sgAssocs) -> "Deferred":
|
|
|
+ @abstractmethod
|
|
|
+ def pushUpdates(self, sgAssocs: SignedAssociations) -> "Deferred[PushUpdateReturn]":
|
|
|
"""
|
|
|
- :param sgAssocs: Sequence of (originId, sgAssoc) tuples where originId is the id on the creating server and
|
|
|
- sgAssoc is the json object of the signed association
|
|
|
+ :param sgAssocs: Map from originId to sgAssoc, where originId is the id
|
|
|
+ on the creating server and sgAssoc is the json object
|
|
|
+ of the signed association
|
|
|
"""
|
|
|
- pass
|
|
|
+ ...
|
|
|
|
|
|
|
|
|
-class LocalPeer(Peer):
|
|
|
+class LocalPeer(Peer[bool]):
|
|
|
"""
|
|
|
The local peer (ourselves: essentially copying from the local associations table to the global one)
|
|
|
"""
|
|
@@ -71,14 +81,14 @@ class LocalPeer(Peer):
|
|
|
lastId = globalAssocStore.lastIdFromServer(self.servername)
|
|
|
self.lastId = lastId if lastId is not None else -1
|
|
|
|
|
|
- def pushUpdates(self, sgAssocs: Dict[int, Dict[str, Any]]) -> "Deferred":
|
|
|
+ def pushUpdates(self, sgAssocs: SignedAssociations) -> "Deferred[bool]":
|
|
|
"""
|
|
|
Saves the given associations in the global associations store. Only stores an
|
|
|
association if its ID is greater than the last seen ID.
|
|
|
|
|
|
:param sgAssocs: The associations to save.
|
|
|
|
|
|
- :return: True
|
|
|
+ :return: A deferred that succeeds with the value `True`.
|
|
|
"""
|
|
|
globalAssocStore = GlobalAssociationStore(self.sydent)
|
|
|
for localId in sgAssocs:
|
|
@@ -90,11 +100,14 @@ class LocalPeer(Peer):
|
|
|
|
|
|
if assocObj.mxid is not None:
|
|
|
# Assign a lookup_hash to this association
|
|
|
+ pepper = self.hashing_store.get_lookup_pepper()
|
|
|
+ if not pepper:
|
|
|
+ raise RuntimeError("No lookup_pepper in the database.")
|
|
|
str_to_hash = " ".join(
|
|
|
[
|
|
|
assocObj.address,
|
|
|
assocObj.medium,
|
|
|
- self.hashing_store.get_lookup_pepper(),
|
|
|
+ pepper,
|
|
|
],
|
|
|
)
|
|
|
assocObj.lookup_hash = sha256_and_url_safe_base64(str_to_hash)
|
|
@@ -116,7 +129,7 @@ class LocalPeer(Peer):
|
|
|
return d
|
|
|
|
|
|
|
|
|
-class RemotePeer(Peer):
|
|
|
+class RemotePeer(Peer[IResponse]):
|
|
|
def __init__(
|
|
|
self,
|
|
|
sydent: "Sydent",
|
|
@@ -199,15 +212,16 @@ class RemotePeer(Peer):
|
|
|
or len(key_ids) == 0
|
|
|
or not key_ids[0].startswith(SIGNING_KEY_ALGORITHM + ":")
|
|
|
):
|
|
|
- e = NoMatchingSignatureException()
|
|
|
- e.foundSigs = assoc["signatures"].keys()
|
|
|
- e.requiredServername = self.servername
|
|
|
+ e = NoMatchingSignatureException(
|
|
|
+ foundSigs=assoc["signatures"].keys(),
|
|
|
+ requiredServername=self.servername,
|
|
|
+ )
|
|
|
raise e
|
|
|
|
|
|
# Verify the JSON
|
|
|
signedjson.sign.verify_signed_json(assoc, self.servername, self.verify_key)
|
|
|
|
|
|
- def pushUpdates(self, sgAssocs: Dict[int, Dict[str, Any]]) -> "Deferred":
|
|
|
+ def pushUpdates(self, sgAssocs: SignedAssociations) -> "Deferred[IResponse]":
|
|
|
"""
|
|
|
Pushes the given associations to the peer.
|
|
|
|
|
@@ -220,13 +234,15 @@ class RemotePeer(Peer):
|
|
|
reqDeferred = self.sydent.replicationHttpsClient.postJson(
|
|
|
self.replication_url, body
|
|
|
)
|
|
|
+ if reqDeferred is None:
|
|
|
+ raise RuntimeError(f"Unable to push sgAssocs to {self.replication_url}")
|
|
|
|
|
|
# XXX: We'll also need to prune the deleted associations out of the
|
|
|
# local associations table once they've been replicated to all peers
|
|
|
# (ie. remove the record we kept in order to propagate the deletion to
|
|
|
# other peers).
|
|
|
|
|
|
- updateDeferred = defer.Deferred()
|
|
|
+ updateDeferred: "Deferred[IResponse]" = defer.Deferred()
|
|
|
|
|
|
reqDeferred.addCallback(self._pushSuccess, updateDeferred=updateDeferred)
|
|
|
reqDeferred.addErrback(self._pushFailed, updateDeferred=updateDeferred)
|
|
@@ -236,7 +252,7 @@ class RemotePeer(Peer):
|
|
|
def _pushSuccess(
|
|
|
self,
|
|
|
result: "IResponse",
|
|
|
- updateDeferred: "Deferred",
|
|
|
+ updateDeferred: "Deferred[IResponse]",
|
|
|
) -> None:
|
|
|
"""
|
|
|
Processes a successful push request. If the request resulted in a status code
|
|
@@ -253,7 +269,9 @@ class RemotePeer(Peer):
|
|
|
d.addCallback(self._failedPushBodyRead, updateDeferred=updateDeferred)
|
|
|
d.addErrback(self._pushFailed, updateDeferred=updateDeferred)
|
|
|
|
|
|
- def _failedPushBodyRead(self, body: bytes, updateDeferred: "Deferred") -> None:
|
|
|
+ def _failedPushBodyRead(
|
|
|
+ self, body: bytes, updateDeferred: "Deferred[IResponse]"
|
|
|
+ ) -> None:
|
|
|
"""
|
|
|
Processes a response body from a failed push request, then calls the error
|
|
|
callback of the provided deferred.
|
|
@@ -262,21 +280,19 @@ class RemotePeer(Peer):
|
|
|
:param updateDeferred: The deferred to call the error callback of.
|
|
|
"""
|
|
|
errObj = json_decoder.decode(body.decode("utf8"))
|
|
|
- e = RemotePeerError()
|
|
|
- e.errorDict = errObj
|
|
|
+ e = RemotePeerError(errObj)
|
|
|
updateDeferred.errback(e)
|
|
|
|
|
|
def _pushFailed(
|
|
|
self,
|
|
|
- failure,
|
|
|
- updateDeferred: "Deferred",
|
|
|
+ failure: Failure,
|
|
|
+ updateDeferred: "Deferred[object]",
|
|
|
) -> None:
|
|
|
"""
|
|
|
Processes a failed push request, by calling the error callback of the given
|
|
|
deferred with it.
|
|
|
|
|
|
:param failure: The failure to process.
|
|
|
- :type failure: twisted.python.failure.Failure
|
|
|
:param updateDeferred: The deferred to call the error callback of.
|
|
|
"""
|
|
|
updateDeferred.errback(failure)
|
|
@@ -288,7 +304,11 @@ class NoSignaturesException(Exception):
|
|
|
|
|
|
|
|
|
class NoMatchingSignatureException(Exception):
|
|
|
- def __str__(self):
|
|
|
+ def __init__(self, foundSigs: Sequence[str], requiredServername: str):
|
|
|
+ self.foundSigs = foundSigs
|
|
|
+ self.requiredServername = requiredServername
|
|
|
+
|
|
|
+ def __str__(self) -> str:
|
|
|
return "Found signatures: %s, required server name: %s" % (
|
|
|
self.foundSigs,
|
|
|
self.requiredServername,
|
|
@@ -296,5 +316,8 @@ class NoMatchingSignatureException(Exception):
|
|
|
|
|
|
|
|
|
class RemotePeerError(Exception):
|
|
|
- def __str__(self):
|
|
|
+ def __init__(self, errorDict: JsonDict):
|
|
|
+ self.errorDict = errorDict
|
|
|
+
|
|
|
+ def __str__(self) -> str:
|
|
|
return repr(self.errorDict)
|