Browse Source

Get `sydent.replication` to pass `mypy --strict` (#430)

Co-authored-by: reivilibre <oliverw@matrix.org>
David Robertson 2 years ago
parent
commit
d96b25e096

+ 1 - 0
changelog.d/430.misc

@@ -0,0 +1 @@
+Make `sydent.replication` pass `mypy --strict`.

+ 1 - 0
pyproject.toml

@@ -52,6 +52,7 @@ files = [
     "sydent/config",
     "sydent/db",
     "sydent/hs_federation",
+    "sydent/replication",
     "sydent/sms",
     "sydent/terms",
     "sydent/threepid",

+ 0 - 0
stubs/twisted/web/__init__.pyi


+ 4 - 0
stubs/twisted/web/client.pyi

@@ -0,0 +1,4 @@
+from twisted.internet.defer import Deferred
+from twisted.web.iweb import IResponse
+
+def readBody(response: IResponse) -> Deferred[bytes]: ...

+ 4 - 1
sydent/db/peers.py

@@ -111,7 +111,10 @@ class PeerStore:
         return peers
 
     def setLastSentVersionAndPokeSucceeded(
-        self, peerName: str, lastSentVersion: int, lastPokeSucceeded: int
+        self,
+        peerName: str,
+        lastSentVersion: Optional[int],
+        lastPokeSucceeded: Optional[int],
     ) -> None:
         """
         Sets the ID of the last association sent to a given peer and the time of the

+ 8 - 3
sydent/db/threepid_associations.py

@@ -22,6 +22,11 @@ from sydent.util import time_msec
 if TYPE_CHECKING:
     from sydent.sydent import Sydent
 
+# Key: id from associations db table
+# Value: an association dict. Roughly speaking, a signed
+# version of sydent.db.TheepidAssociation.
+SignedAssociations = Dict[int, Dict[str, Any]]
+
 
 logger = logging.getLogger(__name__)
 
@@ -111,7 +116,7 @@ class LocalAssociationStore:
 
     def getSignedAssociationsAfterId(
         self, afterId: Optional[int], limit: Optional[int] = None
-    ) -> Tuple[Dict[int, Dict[str, Any]], Optional[int]]:
+    ) -> Tuple[SignedAssociations, Optional[int]]:
         """Get associations after a given ID, and sign them before returning
 
         :param afterId: The ID to return results after (not inclusive)
@@ -308,7 +313,7 @@ class GlobalAssociationStore:
     def addAssociation(
         self,
         assoc: ThreepidAssociation,
-        rawSgAssoc: Dict[str, Any],
+        rawSgAssoc: str,
         originServer: str,
         originId: int,
         commit: bool = True,
@@ -319,7 +324,7 @@ class GlobalAssociationStore:
         this function.
 
         :param assoc: The association to add as a high level object.
-        :param rawSgAssoc: The original raw bytes of the signed association.
+        :param rawSgAssoc: The original raw text of the signed association.
         :param originServer: The name of the server the association was created on.
         :param originId: The ID of the association on the server the association was
             created on.

+ 4 - 2
sydent/http/httpsclient.py

@@ -21,7 +21,7 @@ from twisted.internet.defer import Deferred
 from twisted.internet.ssl import optionsForClientTLS
 from twisted.web.client import Agent, FileBodyProducer
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IPolicyForHTTPS
+from twisted.web.iweb import IPolicyForHTTPS, IResponse
 from zope.interface import implementer
 
 from sydent.types import JsonDict
@@ -51,7 +51,9 @@ class ReplicationHttpsClient:
             #                                                      trustRoot=self.sydent.sslComponents.trustRoot)
             self.agent = Agent(self.sydent.reactor, SydentPolicyForHTTPS(self.sydent))
 
-    def postJson(self, uri: str, jsonObject: JsonDict) -> Optional[Deferred]:
+    def postJson(
+        self, uri: str, jsonObject: JsonDict
+    ) -> Optional["Deferred[IResponse]"]:
         """
         Sends an POST request over HTTPS.
 

+ 50 - 27
sydent/replication/peer.py

@@ -16,25 +16,29 @@
 import binascii
 import json
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from abc import abstractmethod
+from typing import TYPE_CHECKING, Dict, Generic, Optional, Sequence, TypeVar
 
 import signedjson.key
 import signedjson.sign
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
 from twisted.web.client import readBody
 from twisted.web.iweb import IResponse
 from unpaddedbase64 import decode_base64
 
 from sydent.config import ConfigError
 from sydent.db.hashing_metadata import HashingMetadataStore
-from sydent.db.threepid_associations import GlobalAssociationStore
+from sydent.db.threepid_associations import GlobalAssociationStore, SignedAssociations
 from sydent.threepid import threePidAssocFromDict
 from sydent.types import JsonDict
 from sydent.util import json_decoder
 from sydent.util.hash import sha256_and_url_safe_base64
 from sydent.util.stringutils import normalise_address
 
+PushUpdateReturn = TypeVar("PushUpdateReturn")
+
 if TYPE_CHECKING:
     from sydent.sydent import Sydent
 
@@ -43,21 +47,27 @@ logger = logging.getLogger(__name__)
 SIGNING_KEY_ALGORITHM = "ed25519"
 
 
-class Peer:
-    def __init__(self, servername, pubkeys):
+class Peer(Generic[PushUpdateReturn]):
+    def __init__(self, servername: str, pubkeys: Dict[str, str]):
+        """
+        :param server_name: The peer's server name.
+        :param pubkeys: The peer's public keys in a Dict[key_id, key_b64]
+        """
         self.servername = servername
         self.pubkeys = pubkeys
         self.is_being_pushed_to = False
 
-    def pushUpdates(self, sgAssocs) -> "Deferred":
+    @abstractmethod
+    def pushUpdates(self, sgAssocs: SignedAssociations) -> "Deferred[PushUpdateReturn]":
         """
-        :param sgAssocs: Sequence of (originId, sgAssoc) tuples where originId is the id on the creating server and
-                        sgAssoc is the json object of the signed association
+        :param sgAssocs: Map from originId to sgAssoc,  where originId is the id
+                         on the creating server and sgAssoc is the json object
+                         of the signed association
         """
-        pass
+        ...
 
 
-class LocalPeer(Peer):
+class LocalPeer(Peer[bool]):
     """
     The local peer (ourselves: essentially copying from the local associations table to the global one)
     """
@@ -71,14 +81,14 @@ class LocalPeer(Peer):
         lastId = globalAssocStore.lastIdFromServer(self.servername)
         self.lastId = lastId if lastId is not None else -1
 
-    def pushUpdates(self, sgAssocs: Dict[int, Dict[str, Any]]) -> "Deferred":
+    def pushUpdates(self, sgAssocs: SignedAssociations) -> "Deferred[bool]":
         """
         Saves the given associations in the global associations store. Only stores an
         association if its ID is greater than the last seen ID.
 
         :param sgAssocs: The associations to save.
 
-        :return: True
+        :return: A deferred that succeeds with the value `True`.
         """
         globalAssocStore = GlobalAssociationStore(self.sydent)
         for localId in sgAssocs:
@@ -90,11 +100,14 @@ class LocalPeer(Peer):
 
                 if assocObj.mxid is not None:
                     # Assign a lookup_hash to this association
+                    pepper = self.hashing_store.get_lookup_pepper()
+                    if not pepper:
+                        raise RuntimeError("No lookup_pepper in the database.")
                     str_to_hash = " ".join(
                         [
                             assocObj.address,
                             assocObj.medium,
-                            self.hashing_store.get_lookup_pepper(),
+                            pepper,
                         ],
                     )
                     assocObj.lookup_hash = sha256_and_url_safe_base64(str_to_hash)
@@ -116,7 +129,7 @@ class LocalPeer(Peer):
         return d
 
 
-class RemotePeer(Peer):
+class RemotePeer(Peer[IResponse]):
     def __init__(
         self,
         sydent: "Sydent",
@@ -199,15 +212,16 @@ class RemotePeer(Peer):
             or len(key_ids) == 0
             or not key_ids[0].startswith(SIGNING_KEY_ALGORITHM + ":")
         ):
-            e = NoMatchingSignatureException()
-            e.foundSigs = assoc["signatures"].keys()
-            e.requiredServername = self.servername
+            e = NoMatchingSignatureException(
+                foundSigs=assoc["signatures"].keys(),
+                requiredServername=self.servername,
+            )
             raise e
 
         # Verify the JSON
         signedjson.sign.verify_signed_json(assoc, self.servername, self.verify_key)
 
-    def pushUpdates(self, sgAssocs: Dict[int, Dict[str, Any]]) -> "Deferred":
+    def pushUpdates(self, sgAssocs: SignedAssociations) -> "Deferred[IResponse]":
         """
         Pushes the given associations to the peer.
 
@@ -220,13 +234,15 @@ class RemotePeer(Peer):
         reqDeferred = self.sydent.replicationHttpsClient.postJson(
             self.replication_url, body
         )
+        if reqDeferred is None:
+            raise RuntimeError(f"Unable to push sgAssocs to {self.replication_url}")
 
         # XXX: We'll also need to prune the deleted associations out of the
         # local associations table once they've been replicated to all peers
         # (ie. remove the record we kept in order to propagate the deletion to
         # other peers).
 
-        updateDeferred = defer.Deferred()
+        updateDeferred: "Deferred[IResponse]" = defer.Deferred()
 
         reqDeferred.addCallback(self._pushSuccess, updateDeferred=updateDeferred)
         reqDeferred.addErrback(self._pushFailed, updateDeferred=updateDeferred)
@@ -236,7 +252,7 @@ class RemotePeer(Peer):
     def _pushSuccess(
         self,
         result: "IResponse",
-        updateDeferred: "Deferred",
+        updateDeferred: "Deferred[IResponse]",
     ) -> None:
         """
         Processes a successful push request. If the request resulted in a status code
@@ -253,7 +269,9 @@ class RemotePeer(Peer):
             d.addCallback(self._failedPushBodyRead, updateDeferred=updateDeferred)
             d.addErrback(self._pushFailed, updateDeferred=updateDeferred)
 
-    def _failedPushBodyRead(self, body: bytes, updateDeferred: "Deferred") -> None:
+    def _failedPushBodyRead(
+        self, body: bytes, updateDeferred: "Deferred[IResponse]"
+    ) -> None:
         """
         Processes a response body from a failed push request, then calls the error
         callback of the provided deferred.
@@ -262,21 +280,19 @@ class RemotePeer(Peer):
         :param updateDeferred: The deferred to call the error callback of.
         """
         errObj = json_decoder.decode(body.decode("utf8"))
-        e = RemotePeerError()
-        e.errorDict = errObj
+        e = RemotePeerError(errObj)
         updateDeferred.errback(e)
 
     def _pushFailed(
         self,
-        failure,
-        updateDeferred: "Deferred",
+        failure: Failure,
+        updateDeferred: "Deferred[object]",
     ) -> None:
         """
         Processes a failed push request, by calling the error callback of the given
         deferred with it.
 
         :param failure: The failure to process.
-        :type failure: twisted.python.failure.Failure
         :param updateDeferred: The deferred to call the error callback of.
         """
         updateDeferred.errback(failure)
@@ -288,7 +304,11 @@ class NoSignaturesException(Exception):
 
 
 class NoMatchingSignatureException(Exception):
-    def __str__(self):
+    def __init__(self, foundSigs: Sequence[str], requiredServername: str):
+        self.foundSigs = foundSigs
+        self.requiredServername = requiredServername
+
+    def __str__(self) -> str:
         return "Found signatures: %s, required server name: %s" % (
             self.foundSigs,
             self.requiredServername,
@@ -296,5 +316,8 @@ class NoMatchingSignatureException(Exception):
 
 
 class RemotePeerError(Exception):
-    def __str__(self):
+    def __init__(self, errorDict: JsonDict):
+        self.errorDict = errorDict
+
+    def __str__(self) -> str:
         return repr(self.errorDict)

+ 3 - 4
sydent/replication/pusher.py

@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Tuple
 
 import twisted.internet.reactor
 import twisted.internet.task
@@ -41,7 +41,7 @@ class Pusher:
         self.peerStore = PeerStore(self.sydent)
         self.local_assoc_store = LocalAssociationStore(self.sydent)
 
-    def setup(self):
+    def setup(self) -> None:
         cb = twisted.internet.task.LoopingCall(Pusher.scheduledPush, self)
         cb.clock = self.sydent.reactor
         cb.start(10.0)
@@ -62,12 +62,11 @@ class Pusher:
 
         localPeer.pushUpdates(signedAssocs)
 
-    def scheduledPush(self):
+    def scheduledPush(self) -> "defer.Deferred[List[Tuple[bool, None]]]":
         """Push pending updates to all known remote peers. To be called regularly.
 
         :returns a deferred.DeferredList of defers, one per peer we're pushing to that will
         resolve when pushing to that peer has completed, successfully or otherwise
-        :rtype deferred.DeferredList
         """
         peers = self.peerStore.getAllPeers()
 

+ 3 - 1
sydent/sydent.py

@@ -182,7 +182,9 @@ class Sydent:
 
         self.clientApiHttpServer = ClientApiHttpServer(self)
         self.replicationHttpsServer = ReplicationHttpsServer(self)
-        self.replicationHttpsClient = ReplicationHttpsClient(self)
+        self.replicationHttpsClient: ReplicationHttpsClient = ReplicationHttpsClient(
+            self
+        )
 
         self.pusher: Pusher = Pusher(self)