[DINSIC] Replicate 3PID invites (#108)

* WIP: Replicate 3pid invites

* allow lastSent to be null

* Import copy

* last sent prefix

* SQL demands spaces

* Only commit relevant ID information

* Don't check token IDs for incoming eph. keys. Add 400s

* Don't only check the first origin ID

* Fix string formatting

* Fix missing var

* Return 200 on having received duplicated data

* public keys

* fix typo

* Prevent last sent IDs being reset on not sending anything

* Prevent more than 100 items being replicated at a time

* Properly verify replicated data [should go on master too]

* Change logging back to INFO

* Ensure replication doesn't happen one way, and then the other

* Doc and method name updates

* Add param/return info

* Correct docstring

* More docstring corrections

* Remove conflicts

* Remove erroneous dependency

* And another

* Add types to docstrings

* Add more types to docstrings

* Address comments

* Fix assoc signing

* Fix replication servlet's signature handling, add comments

* Fix example in generate-key

* Make it clearer that signedAssocs is separated

* Clarify signing method for peer

* Clarify signing method for peer

(cherry picked from commit 616bc050a4551a61ec89c32cd1af06fef45d6e7f)

* Don't assume we have the private/signing key of a peer

* Don't assume we have the private/signing key of a peer

(cherry picked from commit 87be14f5f6e342e00d940468016fbc4a65247b4d)

* Revert "Don't assume we have the private/signing key of a peer"

This reverts commit 87be14f5f6e342e00d940468016fbc4a65247b4d.

* Revert "Clarify signing method for peer"

This reverts commit 616bc050a4551a61ec89c32cd1af06fef45d6e7f.

* Change signing key algorithm to global var

* Clean up remaining signing nonsense

* Cleaner way of pushing associations

* Remove unnecessary methods

* Rename replicationpushservlet back to replication

* Reduce indent, fix comments, move token/key storage logic

* Add logging for push connections sending more than the max allowed entries

* Fix verify_key sourcing

Andrew Morgan, 5 years ago (commit 4b5e36e6b3)

+ 3 - 0
scripts/show-verify-key

@@ -1,5 +1,8 @@
 #!/usr/bin/env python
 
+# Run example
+# ./scripts/show-verify-key ed25519 0 xjKlejLcLyTQg7Fxy/XGopUeF3W/l3/cRgpFe+edi0E
+
 # Use this to generate a signing key and verify key for use in sydent
 # configurations.
 

+ 1 - 0
setup.py

@@ -48,6 +48,7 @@ setup(
         "phonenumbers",
         "pyopenssl",
 
+        "pynacl",
         "pyyaml",
         "netaddr",
     ],

+ 215 - 10
sydent/db/invite_tokens.py

@@ -20,16 +20,55 @@ class JoinTokenStore(object):
     def __init__(self, sydent):
         self.sydent = sydent
 
-    def storeToken(self, medium, address, roomId, sender, token):
+    def storeToken(self, medium, address, roomId, sender, token, originServer=None, originId=None, commit=True):
+        """Stores an invite token.
+        
+        :param medium: The medium of the token.
+        :type medium: str
+        :param address: The address of the token.
+        :type address: str
+        :param roomId: The room ID this token is tied to.
+        :type roomId: str
+        :param sender: The sender of the invite.
+        :type sender: str
+        :param token: The token itself.
+        :type token: str
+        :param originServer: The server this invite originated from (if
+            coming from replication).
+        :type originServer: str, None
+        :param originId: The id of the token in the DB of originServer. Used
+            for determining if we've already received a token or not.
+        :type originId: int, None
+        :param commit: Whether DB changes should be committed by this
+            function (or an external one).
+        :type commit: bool
+        """
+        if originId and originServer:
+            # Check if we've already seen this token from this server
+            last_processed_id = self.getLastTokenIdFromServer(originServer)
+            if int(originId) <= int(last_processed_id):
+                logger.info("We have already seen token ID %s from %s. Ignoring.", originId, originServer)
+                return
+
         cur = self.sydent.db.cursor()
 
         cur.execute("INSERT INTO invite_tokens"
-                    " ('medium', 'address', 'room_id', 'sender', 'token', 'received_ts')"
-                    " VALUES (?, ?, ?, ?, ?, ?)",
-                    (medium, address, roomId, sender, token, int(time.time())))
-        self.sydent.db.commit()
+                    " ('medium', 'address', 'room_id', 'sender', 'token', 'received_ts', 'origin_server', 'origin_id')"
+                    " VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
+                    (medium, address, roomId, sender, token, int(time.time()), originServer, originId))
+        if commit:
+            self.sydent.db.commit()
 
     def getTokens(self, medium, address):
+        """Retrieve the invite token(s) for a given 3PID medium and address.
+        
+        :param medium: The medium of the 3PID.
+        :type medium: str
+        :param address: The address of the 3PID.
+        :type address: str
+        :returns a list of invite tokens, or an empty list if none are found.
+        :rtype: list[Dict[str, str]]
+        """
         cur = self.sydent.db.cursor()
 
         res = cur.execute(
@@ -53,7 +92,75 @@ class JoinTokenStore(object):
 
         return ret
 
+    def getInviteTokensAfterId(self, afterId, limit):
+        """Retrieves at most `limit` invite tokens after a given DB id.
+        
+        :param afterId: A database id to act as an offset. Tokens after this
+            id are returned.
+        :type afterId: int
+        :param limit: Maximum number of database rows to return.
+        :type limit: int, None
+        :returns a tuple consisting of a dict of invite tokens (keyed by each
+            token's DB id) and the maximum DB id that was extracted, or
+            ({}, None) if no tokens are found.
+        :rtype: Tuple[Dict[int, Dict], int|None]
+        """
+        cur = self.sydent.db.cursor()
+        res = cur.execute(
+            "SELECT id, medium, address, room_id, sender, token FROM invite_tokens"
+            " WHERE id > ? AND origin_id IS NULL LIMIT ?",
+            (afterId, limit,)
+        )
+        rows = res.fetchall()
+
+        # Dict of "id": {content}
+        invite_tokens = {}
+
+        maxId = None
+
+        for row in rows:
+            maxId, medium, address, room_id, sender, token = row
+            invite_tokens[maxId] = {
+                "origin_id": maxId,
+                "medium": medium,
+                "address": address,
+                "room_id": room_id,
+                "sender": sender,
+                "token": token,
+            }
+
+        return (invite_tokens, maxId)
+
+    def getLastTokenIdFromServer(self, server):
+        """Returns the DB id of the last invite token that was received from
+        the given server.
+
+        :param server: The name of the origin server.
+        :type server: str
+        :returns a database id marking the last known invite token received
+            from the given server. Returns 0 if no tokens have been received from
+            this server.
+        :rtype: int
+        """
+        cur = self.sydent.db.cursor()
+        res = cur.execute("select max(origin_id), count(origin_id) from invite_tokens"
+                          " where origin_server = ?", (server,))
+        row = res.fetchone()
+
+        if row[1] == 0:
+            return 0
+
+        return row[0]
+
+
     def markTokensAsSent(self, medium, address):
+        """Mark invite tokens as sent.
+        
+        :param medium: The medium of the token.
+        :type medium: str
+        :param address: The address of the token.
+        :type address: str
+        """
         cur = self.sydent.db.cursor()
 
         cur.execute(
@@ -62,17 +169,51 @@ class JoinTokenStore(object):
         )
         self.sydent.db.commit()
 
-    def storeEphemeralPublicKey(self, publicKey):
+    def storeEphemeralPublicKey(self, publicKey, persistenceTs=None, originServer=None, originId=None, commit=True):
+        """Stores an ephemeral public key in the database.
+        
+        :param publicKey: the ephemeral public key to store.
+        :type publicKey: str
+        :param persistenceTs: The time the key was persisted, in seconds (defaults to now).
+        :type persistenceTs: int
+        :param originServer: the server this key was received from (if
+            retrieved through replication).
+        :type originServer: str
+        :param originId: The id of the key in the DB of originServer. Used
+            for determining if we've already received a key or not.
+        :type originId: int
+        :param commit: Whether DB changes should be committed by this
+            function (or an external one).
+        :type commit: bool
+        """
+        if originId and originServer:
+            # Check if we've already seen this key from this server
+            last_processed_id = self.getLastEphemeralPublicKeyIdFromServer(originServer)
+            if int(originId) <= int(last_processed_id):
+                logger.info("We have already seen key ID %s from %s. Ignoring.", originId, originServer)
+                return
+
+        if not persistenceTs:
+            persistenceTs = int(time.time())
         cur = self.sydent.db.cursor()
         cur.execute(
             "INSERT INTO ephemeral_public_keys"
-            " (public_key, persistence_ts)"
-            " VALUES (?, ?)",
-            (publicKey, int(time.time()))
+            " (public_key, persistence_ts, origin_server, origin_id)"
+            " VALUES (?, ?, ?, ?)",
+            (publicKey, persistenceTs, originServer, originId)
         )
-        self.sydent.db.commit()
+        if commit:
+            self.sydent.db.commit()
 
     def validateEphemeralPublicKey(self, publicKey):
+        """Mark an ephemeral public key as validated.
+        
+        :param publicKey: An ephemeral public key.
+        :type publicKey: str
+        :returns True if the key was found and marked as validated, False
+            otherwise.
+        :rtype: bool
+        """
         cur = self.sydent.db.cursor()
         cur.execute(
             "UPDATE ephemeral_public_keys"
@@ -83,7 +224,71 @@ class JoinTokenStore(object):
         self.sydent.db.commit()
         return cur.rowcount > 0
 
+    def getEphemeralPublicKeysAfterId(self, afterId, limit):
+        """Retrieves at most `limit` ephemeral public keys after a given DB id.
+        
+        :param afterId: A database id to act as an offset. Keys after this id
+            are returned.
+        :type afterId: int
+        :param limit: Maximum number of database rows to return.
+        :type limit: int
+        :returns a tuple consisting of a dict of ephemeral public keys (keyed
+            by each key's DB id) and the maximum table id that was extracted,
+            or ({}, None) if no keys are found.
+        :rtype: Tuple[Dict[int, Dict], int|None]
+        """
+        cur = self.sydent.db.cursor()
+        res = cur.execute(
+            "SELECT id, public_key, verify_count, persistence_ts FROM ephemeral_public_keys"
+            " WHERE id > ? AND origin_id IS NULL LIMIT ?",
+            (afterId, limit,)
+        )
+        rows = res.fetchall()
+
+        # Dict of "id": {content}
+        ephemeral_keys = {}
+
+        maxId = None
+
+        for row in rows:
+            maxId, public_key, verify_count, persistence_ts = row
+            ephemeral_keys[maxId] = {
+                "public_key": public_key,
+                "verify_count": verify_count,
+                "persistence_ts": persistence_ts,
+            }
+
+        return (ephemeral_keys, maxId)
+
+    def getLastEphemeralPublicKeyIdFromServer(self, server):
+        """Returns the DB id of the last ephemeral public key that was
+        received from the given server.
+
+        :param server: The name of the origin server.
+        :type server: str
+        :returns the last known DB id received from the given server, or 0 if
+            none have been received.
+        :rtype: int
+        """
+        cur = self.sydent.db.cursor()
+        res = cur.execute("select max(origin_id),count(origin_id) from ephemeral_public_keys"
+                          " where origin_server = ?", (server,))
+        row = res.fetchone()
+
+        if not row or row[1] == 0:
+            return 0
+
+        return row[0]
+
     def getSenderForToken(self, token):
+        """Returns the sender for a given invite token.
+        
+        :param token: The invite token.
+        :type token: str
+        :returns the sender of a given invite token or None if there isn't
+            one.
+        :rtype: str, None
+        """
         cur = self.sydent.db.cursor()
         res = cur.execute(
             "SELECT sender FROM invite_tokens WHERE token = ?",

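To show how the new replication parameters above are meant to be used together, here is a minimal sketch (the sydent instance, peer name and replicated_tokens dict are placeholders) of ingesting a batch of replicated tokens: each row is stored with commit=False and a single commit covers the whole batch, mirroring what the replication servlet further down does.

    from sydent.db.invite_tokens import JoinTokenStore

    store = JoinTokenStore(sydent)  # `sydent` is a placeholder for the running Sydent instance
    for origin_id, tok in replicated_tokens.items():  # e.g. {"12": {"medium": "email", ...}}
        store.storeToken(
            tok["medium"], tok["address"], tok["room_id"], tok["sender"], tok["token"],
            originServer="peer.example.com",  # invented peer name
            originId=origin_id,
            commit=False,  # defer committing until the whole batch is stored
        )
    sydent.db.commit()  # commit the batch atomically
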
+ 6 - 2
sydent/db/invite_tokens.sql

@@ -22,7 +22,9 @@ CREATE TABLE IF NOT EXISTS invite_tokens (
     sender varchar(256) not null,
     token varchar(256) not null,
     received_ts bigint, -- When the invite was received by us from the homeserver
-    sent_ts bigint -- When the token was sent by us to the user
+    sent_ts bigint, -- When the token was sent by us to the user
+    origin_id integer, -- original id in homeserver's DB that this was replicated from (if applicable)
+    origin_server text -- homeserver this was replicated from (if applicable)
 );
 CREATE INDEX IF NOT EXISTS invite_token_medium_address on invite_tokens(medium, address);
 CREATE INDEX IF NOT EXISTS invite_token_token on invite_tokens(token);
@@ -31,7 +33,9 @@ CREATE TABLE IF NOT EXISTS ephemeral_public_keys(
     id integer primary key,
     public_key varchar(256) not null,
     verify_count bigint default 0,
-    persistence_ts bigint
+    persistence_ts bigint,
+    origin_server text, -- homeserver this was replicated from (if applicable)
+    origin_id integer -- original id in homeserver's DB that this was replicated from (if applicable)
 );
 
 CREATE UNIQUE INDEX IF NOT EXISTS ephemeral_public_keys_index on ephemeral_public_keys(public_key);

+ 57 - 16
sydent/db/peers.py

@@ -23,26 +23,34 @@ class PeerStore:
 
     def getPeerByName(self, name):
         cur = self.sydent.db.cursor()
-        res = cur.execute("select p.name, p.port, p.lastSentVersion, p.shadow, pk.alg, pk.key from peers p, peer_pubkeys pk "
+        res = cur.execute("select p.name, p.port, "
+                          "p.lastSentAssocsId, p.lastSentInviteTokensId, p.lastSentEphemeralKeysId, "
+                          "p.shadow, pk.alg, pk.key from peers p, peer_pubkeys pk "
                           "where p.name = ? and pk.peername = p.name and p.active = 1", (name,))
 
         serverName = None
         port = None
-        lastSentVer = None
+        lastSentAssocsId = None
+        lastSentInviteTokensId = None
+        lastSentEphemeralKeysId = None
         pubkeys = {}
 
         for row in res.fetchall():
             serverName = row[0]
             port = row[1]
-            lastSentVer = row[2]
-            shadow = row[3]
-            pubkeys[row[4]] = row[5]
+            lastSentAssocsId = row[2]
+            lastSentInviteTokensId = row[3]
+            lastSentEphemeralKeysId = row[4]
+            shadow = row[5]
+            pubkeys[row[6]] = row[7]
 
         if len(pubkeys) == 0:
             return None
 
         p = RemotePeer(self.sydent, serverName, pubkeys)
-        p.lastSentVersion = lastSentVer
+        p.lastSentAssocsId = lastSentAssocsId
+        p.lastSentInviteTokensId = lastSentInviteTokensId
+        p.lastSentEphemeralKeysId = lastSentEphemeralKeysId
         p.shadow = True if shadow else False
         if port:
             p.port = port
@@ -51,34 +59,44 @@ class PeerStore:
 
     def getAllPeers(self):
         cur = self.sydent.db.cursor()
-        res = cur.execute("select p.name, p.port, p.lastSentVersion, p.shadow, pk.alg, pk.key from peers p, peer_pubkeys pk "
+        res = cur.execute("select p.name, p.port, "
+                          "p.lastSentAssocsId, p.lastSentInviteTokensId, p.lastSentEphemeralKeysId, "
+                          "p.shadow, pk.alg, pk.key from peers p, peer_pubkeys pk "
                           "where pk.peername = p.name and p.active = 1")
 
         peers = []
 
         peername = None
         port = None
-        lastSentVer = None
+        lastSentAssocsId = 0
+        lastSentInviteTokensId = 0
+        lastSentEphemeralKeysId = 0
         pubkeys = {}
 
         for row in res.fetchall():
             if row[0] != peername:
                 if len(pubkeys) > 0:
                     p = RemotePeer(self.sydent, peername, pubkeys)
-                    p.lastSentVersion = lastSentVer
+                    p.lastSentAssocsId = lastSentAssocsId
+                    p.lastSentInviteTokensId = lastSentInviteTokensId
+                    p.lastSentEphemeralKeysId = lastSentEphemeralKeysId
                     if port:
                         p.port = port
                     peers.append(p)
                     pubkeys = {}
                 peername = row[0]
                 port = row[1]
-                lastSentVer = row[2]
-                shadow = row[3]
-            pubkeys[row[4]] = row[5]
+                lastSentAssocsId = row[2]
+                lastSentInviteTokensId = row[3]
+                lastSentEphemeralKeysId = row[4]
+                shadow = row[5]
+            pubkeys[row[6]] = row[7]
 
         if len(pubkeys) > 0:
             p = RemotePeer(self.sydent, peername, pubkeys)
-            p.lastSentVersion = lastSentVer
+            p.lastSentAssocsId = lastSentAssocsId
+            p.lastSentInviteTokensId = lastSentInviteTokensId
+            p.lastSentEphemeralKeysId = lastSentEphemeralKeysId
             p.shadow = True if shadow else False
             if port:
                 p.port = port
@@ -87,8 +105,31 @@ class PeerStore:
 
         return peers
 
-    def setLastSentVersionAndPokeSucceeded(self, peerName, lastSentVersion, lastPokeSucceeded):
+    def setLastSentIdAndPokeSucceeded(self, peerName, ids, lastPokeSucceeded):
+        """Set last successful replication of data to this peer.
+
+        If an id for a replicated database table is None, the last sent value
+        will not be updated.
+
+        :param peerName: The name of the peer.
+        :type peerName: str
+        :param ids: A Dictionary of ids that represent the last database
+            table ids that were replicated to this peer.
+        :type ids: Dict[str, int]
+        :param lastPokeSucceeded: The time at which the last successful
+            replication attempt completed (even if no data actually needed to
+            be replicated).
+        :type lastPokeSucceeded: int
+        """
+
         cur = self.sydent.db.cursor()
-        res = cur.execute("update peers set lastSentVersion = ?, lastPokeSucceededAt = ? "
-                          "where name = ?", (lastSentVersion, lastPokeSucceeded, peerName))
+        if ids["sg_assocs"]:
+            cur.execute("update peers set lastSentAssocsId = ?, lastPokeSucceededAt = ? "
+                        "where name = ?", (ids["sg_assocs"], lastPokeSucceeded, peerName))
+        if ids["invite_tokens"]:
+            cur.execute("update peers set lastSentInviteTokensId = ?, lastPokeSucceededAt = ? "
+                        "where name = ?", (ids["invite_tokens"], lastPokeSucceeded, peerName))
+        if ids["ephemeral_public_keys"]:
+            cur.execute("update peers set lastSentEphemeralKeysId = ?, lastPokeSucceededAt = ? "
+                        "where name = ?", (ids["ephemeral_public_keys"], lastPokeSucceeded, peerName))
         self.sydent.db.commit()

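As a sketch of how the pusher is expected to record progress with this method (values invented; the real call site is in pusher.py further down): a None id leaves the corresponding lastSent column untouched, which is what stops last-sent ids being reset when nothing was sent.

    from sydent.db.peers import PeerStore
    from sydent.util import time_msec

    peerStore = PeerStore(sydent)  # `sydent` is a placeholder for the running instance
    ids = {
        "sg_assocs": 120,               # highest association id included in this push
        "invite_tokens": 45,            # highest invite token id included
        "ephemeral_public_keys": None,  # nothing pushed; lastSentEphemeralKeysId stays as-is
    }
    peerStore.setLastSentIdAndPokeSucceeded("peer.example.com", ids, time_msec())
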
+ 3 - 1
sydent/db/peers.sql

@@ -18,7 +18,9 @@ CREATE TABLE IF NOT EXISTS peers (
 	id integer primary key,
 	name varchar(255) not null,
 	port integer default null,
-	lastSentVersion integer,
+	lastSentAssocsId integer,
+	lastSentInviteTokensId integer,
+	lastSentEphemeralKeysId integer,
 	lastPokeSucceededAt integer,
 	active integer not null default 0,
 	shadow integer not null default 0

+ 11 - 3
sydent/db/threepid_associations.py

@@ -178,9 +178,17 @@ class GlobalAssociationStore:
 
     def addAssociation(self, assoc, rawSgAssoc, originServer, originId, commit=True):
         """
-        :param assoc: (sydent.threepid.GlobalThreepidAssociation) The association to add as a high level object
-        :param sgAssoc The original raw bytes of the signed association
-        :return:
+        :param assoc: The association to add as a high level object.
+        :type assoc: sydent.threepid.GlobalThreepidAssociation
+        :param rawSgAssoc: The original raw string of the signed association (in JSON format).
+        :type rawSgAssoc: str
+        :param originServer: The name of the server this association originated from.
+        :type originServer: str
+        :param originId: The DB table id of the association on the origin server.
+        :type originId: int
+        :param commit: Whether this function should commit to the DB after
+            completing insertion.
+        :type commit: bool
         """
         cur = self.sydent.db.cursor()
         res = cur.execute("insert or ignore into global_threepid_associations "

+ 3 - 3
sydent/hs_federation/verifier.py

@@ -88,11 +88,11 @@ class Verifier(object):
         to do perspectives checks.
 
         :param acceptable_server_names: If provided and not None,
-        only signatures from servers in this list will be accepted.
-        :type acceptable_server_names: list of strings
+            only signatures from servers in this list will be accepted.
+        :type acceptable_server_names: list[str]
 
         :return a tuple of the server name and key name that was
-        successfully verified. If the json cannot be verified,
+            successfully verified. If the json cannot be verified,
         raises SignatureVerifyException.
         """
         if 'signatures' not in signed_json:

+ 3 - 2
sydent/http/httpclient.py

@@ -130,9 +130,10 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
     :param reactor: Twisted reactor.
     :param destination: The name of the server to connect to.
     :type destination: bytes
-    :param ssl_context_factory: Factory which generates SSL contexts to use for TLS.
+    :param ssl_context_factory: Factory which generates SSL contexts to use
+        for TLS.
     :type ssl_context_factory: twisted.internet.ssl.ContextFactory
-    :param timeout (int): connection timeout in seconds
+    :param timeout: connection timeout in seconds
     :type timeout: int
     """
 

+ 126 - 40
sydent/http/servlets/replication.py

@@ -16,22 +16,51 @@
 
 import twisted.python.log
 from twisted.web.resource import Resource
+from twisted.web import server
+from twisted.internet import defer
 from sydent.http.servlets import jsonwrap
 from sydent.threepid import threePidAssocFromDict
 from sydent.db.peers import PeerStore
 from sydent.db.threepid_associations import GlobalAssociationStore
+from sydent.db.invite_tokens import JoinTokenStore
+from sydent.replication.peer import NoMatchingSignatureException, NoSignaturesException, RemotePeerError
+from signedjson.sign import SignatureVerifyException
 
 import logging
 import json
 
 logger = logging.getLogger(__name__)
 
+MAX_SG_ASSOCS_LIMIT = 100
+MAX_INVITE_TOKENS_LIMIT = 100
+MAX_EPHEMERAL_PUBLIC_KEYS_LIMIT = 100
+
 class ReplicationPushServlet(Resource):
     def __init__(self, sydent):
         self.sydent = sydent
 
-    @jsonwrap
     def render_POST(self, request):
+        self._async_render_POST(request)
+        return server.NOT_DONE_YET
+
+    @defer.inlineCallbacks
+    def _async_render_POST(self, request):
+        """Verify and store replicated information from trusted peer identity servers.
+
+        To prevent data sent from erroneous servers from being stored, we
+        initially verify that the sender's certificate contains a commonName
+        that we trust. This is checked against the peers stored in the local
+        DB. Data is then ingested.
+
+        Replicated associations must each be individually signed by the
+        signing key of the remote peer, which we verify using the verify key
+        stored in the local DB.
+
+        Other data does not need to be signed.
+
+        :param request: The HTTPS request.
+        """
+
         peerCert = request.transport.getPeerCertificate()
         peerCertCn = peerCert.get_subject().commonName
 
@@ -42,7 +71,9 @@ class ReplicationPushServlet(Resource):
         if not peer:
             logger.warn("Got connection from %s but no peer found by that name", peerCertCn)
             request.setResponseCode(403)
-            return {'errcode': 'M_UNKNOWN_PEER', 'error': 'This peer is not known to this server'}
+            request.write(json.dumps({'errcode': 'M_UNKNOWN_PEER', 'error': 'This peer is not known to this server'}))
+            request.finish()
+            return
 
         logger.info("Push connection made from peer %s", peer.servername)
 
@@ -50,57 +81,112 @@ class ReplicationPushServlet(Resource):
                 request.requestHeaders.getRawHeaders('Content-Type')[0] != 'application/json':
             logger.warn("Peer %s made push connection with non-JSON content (type: %s)",
                         peer.servername, request.requestHeaders.getRawHeaders('Content-Type')[0])
-            return {'errcode': 'M_NOT_JSON', 'error': 'This endpoint expects JSON'}
+            request.setResponseCode(400)
+            request.write(json.dumps({'errcode': 'M_NOT_JSON', 'error': 'This endpoint expects JSON'}))
+            request.finish()
+            return
 
         try:
             inJson = json.load(request.content)
         except ValueError:
             logger.warn("Peer %s made push connection with malformed JSON", peer.servername)
-            return {'errcode': 'M_BAD_JSON', 'error': 'Malformed JSON'}
-
-        if 'sgAssocs' not in inJson:
-            logger.warn("Peer %s made push connection with no 'sgAssocs' key in JSON", peer.servername)
-            return {'errcode': 'M_BAD_JSON', 'error': 'No "sgAssocs" key in JSON'}
+            request.setResponseCode(400)
+            request.write(json.dumps({'errcode': 'M_BAD_JSON', 'error': 'Malformed JSON'}))
+            request.finish()
+            return
 
-        failedIds = []
+        # Ensure there is data we are able to process
+        if 'sg_assocs' not in inJson and 'invite_tokens' not in inJson and 'ephemeral_public_keys' not in inJson:
+            logger.warn("Peer %s made push connection with no 'sg_assocs', 'invite_tokens' or 'ephemeral_public_keys' keys in JSON", peer.servername)
+            request.setResponseCode(400)
+            request.write(json.dumps({'errcode': 'M_BAD_JSON', 'error': 'No "sg_assocs", "invite_tokens" or "ephemeral_public_keys" key in JSON'}))
+            request.finish()
+            return
+
+        # Process signed associations
+        sg_assocs = inJson.get('sg_assocs', {})
+        if len(sg_assocs) > MAX_SG_ASSOCS_LIMIT:
+            logger.warn("Peer %s made push with 'sg_assocs' field containing %d entries, which is greater than the maximum %d", peer.servername, len(sg_assocs), MAX_SG_ASSOCS_LIMIT)
+            request.setResponseCode(400)
+            request.write(json.dumps({'errcode': 'M_BAD_JSON', 'error': '"sg_assocs" has more than %d keys' % MAX_SG_ASSOCS_LIMIT}))
+            request.finish()
+            return
 
         globalAssocsStore = GlobalAssociationStore(self.sydent)
 
-        for originId,sgAssoc in inJson['sgAssocs'].items():
+        # Check that each pushed association is signed by this trusted peer
+        for originId, sgAssoc in sg_assocs.items():
             try:
-                peer.verifySignedAssociation(sgAssoc)
+                yield peer.verifySignedAssociation(sgAssoc)
                 logger.debug("Signed association from %s with origin ID %s verified", peer.servername, originId)
+            except (NoSignaturesException, NoMatchingSignatureException, RemotePeerError, SignatureVerifyException):
+                self.sydent.db.rollback()
+                logger.warn("Failed to verify signed association from %s with origin ID %s", peer.servername, originId)
+                request.setResponseCode(400)
+                request.write(json.dumps({'errcode': 'M_VERIFICATION_FAILED', 'error': 'Signature verification failed'}))
+                request.finish()
+                return
+            except Exception:
+                self.sydent.db.rollback()
+                logger.error("Failed to verify signed association from %s with origin ID %s", peer.servername, originId)
+                request.setResponseCode(500)
+                request.write(json.dumps({'errcode': 'M_INTERNAL_SERVER_ERROR', 'error': 'Signature verification failed'}))
+                request.finish()
+                return
+
+            assocObj = threePidAssocFromDict(sgAssoc)
+
+            if assocObj.mxid is not None:
+                # Add the association components and the original signed
+                # object (as assocs must be signed when requested by clients)
+                globalAssocsStore.addAssociation(assocObj, json.dumps(sgAssoc), peer.servername, originId, commit=False)
+            else:
+                logger.info("Incoming deletion: removing associations for %s / %s", assocObj.medium, assocObj.address)
+                globalAssocsStore.removeAssociation(assocObj.medium, assocObj.address)
+
+            logger.info("Stored association with origin ID %s from %s", originId, peer.servername)
+
+            # if this is an association that matches one of our invite_tokens then we should call the onBind callback
+            # at this point, in order to tell the inviting HS that someone out there has just bound the 3PID.
+            self.sydent.threepidBinder.notifyPendingInvites(assocObj)
+
+        tokensStore = JoinTokenStore(self.sydent)
+
+        # Process any invite tokens
+
+        invite_tokens = inJson.get('invite_tokens', {})
+        if len(invite_tokens) > MAX_INVITE_TOKENS_LIMIT:
+            self.sydent.db.rollback()
+            logger.warn("Peer %s made push with 'invite_tokens' field containing %d entries, which is greater than the maximum %d", peer.servername, len(invite_tokens), MAX_INVITE_TOKENS_LIMIT)
+            request.setResponseCode(400)
+            request.write(json.dumps({'errcode': 'M_BAD_JSON', 'error': '"invite_tokens" has more than %d keys' % MAX_INVITE_TOKENS_LIMIT}))
+            request.finish()
+            return
 
-                # Don't bother adding if one has already failed: we add all of them or none so we're only going to
-                # roll back the transaction anyway (but we continue to try & verify the rest so we can give a
-                # complete list of the ones that don't verify)
-                if len(failedIds) > 0:
-                    continue
-
-                assocObj = threePidAssocFromDict(sgAssoc)
-
-                if assocObj.mxid is not None:
-                    globalAssocsStore.addAssociation(assocObj, json.dumps(sgAssoc), peer.servername, originId, commit=False)
-                else:
-                    logger.info("Incoming deletion: removing associations for %s / %s", assocObj.medium, assocObj.address)
-                    globalAssocsStore.removeAssociation(assocObj.medium, assocObj.address)
-                logger.info("Stored association origin ID %s from %s", originId, peer.servername)
-
-                # if this is an association that matches one of our invite_tokens then we should call the onBind callback
-                # at this point, in order to tell the inviting HS that someone out there has just bound the 3PID.
-                self.sydent.threepidBinder.notifyPendingInvites(assocObj)
+        for originId, inviteToken in invite_tokens.items():
+            tokensStore.storeToken(inviteToken['medium'], inviteToken['address'], inviteToken['room_id'],
+                                inviteToken['sender'], inviteToken['token'],
+                                originServer=peer.servername, originId=originId, commit=False)
+            logger.info("Stored invite token with origin ID %s from %s", originId, peer.servername)
 
-            except:
-                failedIds.append(originId)
-                logger.warn("Failed to verify signed association from %s with origin ID %s",
-                            peer.servername, originId)
-                twisted.python.log.err()
+        # Process any ephemeral public keys
 
-        if len(failedIds) > 0:
+        ephemeral_public_keys = inJson.get('ephemeral_public_keys', {})
+        if len(ephemeral_public_keys) > MAX_EPHEMERAL_PUBLIC_KEYS_LIMIT:
             self.sydent.db.rollback()
+            logger.warn("Peer %s made push with 'ephemeral_public_keys' field containing %d entries, which is greater than the maximum %d", peer.servername, len(ephemeral_public_keys), MAX_EPHEMERAL_PUBLIC_KEYS_LIMIT)
             request.setResponseCode(400)
-            return {'errcode': 'M_VERIFICATION_FAILED', 'error': 'Verification failed for one or more associations',
-                    'failed_ids':failedIds}
-        else:
-            self.sydent.db.commit()
-            return {'success':True}
+            request.write(json.dumps({'errcode': 'M_BAD_JSON', 'error': '"ephemeral_public_keys" has more than %d keys' % MAX_EPHEMERAL_PUBLIC_KEYS_LIMIT}))
+            request.finish()
+            return
+
+        for originId, ephemeralKey in ephemeral_public_keys.items():
+            tokensStore.storeEphemeralPublicKey(
+                ephemeralKey['public_key'], persistenceTs=ephemeralKey['persistence_ts'],
+                originServer=peer.servername, originId=originId, commit=False)
+            logger.info("Stored ephemeral key with origin ID %s from %s", originId, peer.servername)
+
+        self.sydent.db.commit()
+        request.write(json.dumps({'success':True}))
+        request.finish()
+        return

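For reference, a sketch of the JSON body this servlet now accepts on /_matrix/identity/replicate/v1/push (all values invented): each top-level key is optional, keyed by the sending server's DB ids, and capped at 100 entries per push.

    push_body = {
        "sg_assocs": {
            # each value is an association signed by the sending peer
            "42": {"medium": "email", "address": "alice@example.com", "mxid": "@alice:hs.example.com",
                   "ts": 1548000000000, "not_before": 0, "not_after": 9999999999999,
                   "signatures": {"peer.example.com": {"ed25519:0": "<base64 signature>"}}},
        },
        "invite_tokens": {
            "12": {"medium": "email", "address": "bob@example.com", "room_id": "!abc:hs.example.com",
                   "sender": "@carol:hs.example.com", "token": "<opaque token>"},
        },
        "ephemeral_public_keys": {
            "7": {"public_key": "<base64 key>", "verify_count": 0, "persistence_ts": 1548000000},
        },
    }
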
+ 16 - 15
sydent/replication/peer.py

@@ -16,6 +16,7 @@
 
 from sydent.db.threepid_associations import GlobalAssociationStore
 from sydent.threepid import threePidAssocFromDict
+from unpaddedbase64 import decode_base64
 
 import signedjson.sign
 import signedjson.key
@@ -23,6 +24,8 @@ import signedjson.key
 import logging
 import json
 
+import nacl
+
 import twisted.internet.reactor
 from twisted.internet import defer
 from twisted.web.client import readBody
@@ -38,14 +41,6 @@ class Peer(object):
         self.pubkeys = pubkeys
         self.shadow = False
 
-    def pushUpdates(self, sgAssocs):
-        """
-        :param sgAssocs: Sequence of (originId, (sgAssoc, shadowSgAssoc)) tuples where originId
-            is the id on the creating server and sgAssoc is the json object of the signed association
-        :return a deferred
-        """
-        pass
-
 
 class LocalPeer(Peer):
     """
@@ -61,6 +56,7 @@ class LocalPeer(Peer):
             self.lastId = -1
 
     def pushUpdates(self, sgAssocs):
+        """Push updates from local associations table to the global one."""
         globalAssocStore = GlobalAssociationStore(self.sydent)
         for localId in sgAssocs:
             if localId > self.lastId:
@@ -97,7 +93,8 @@ class RemotePeer(Peer):
         self.port = 1001
 
         # Get verify key for this peer
-        self.verify_key = self.pubkeys[SIGNING_KEY_ALGORITHM]
+        key_bytes = decode_base64(self.pubkeys[SIGNING_KEY_ALGORITHM])
+        self.verify_key = signedjson.key.decode_verify_key_bytes(SIGNING_KEY_ALGORITHM + ":", key_bytes)
 
         # Attach metadata
         self.verify_key.alg = SIGNING_KEY_ALGORITHM
@@ -122,16 +119,20 @@ class RemotePeer(Peer):
         # Verify the JSON
         signedjson.sign.verify_signed_json(assoc, self.servername, self.verify_key)
 
-    def pushUpdates(self, sgAssocs):
-        if self.shadow:
-            body = {'sgAssocs': { k: v[1] for k, v in sgAssocs.items()}}
-        else:
-            body = {'sgAssocs': { k: v[0] for k, v in sgAssocs.items()}}
+    def pushUpdates(self, data):
+        """Push updates to a remote peer.
+
+        :param data: A dictionary containing any of the keys `sg_assocs`,
+            `invite_tokens` and `ephemeral_public_keys`.
+        :type data: Dict
+        :returns a deferred.
+        :rtype: Deferred
+        """
 
         reqDeferred = self.sydent.replicationHttpsClient.postJson(self.servername,
                                                                   self.port,
                                                                   '/_matrix/identity/replicate/v1/push',
-                                                                  body)
+                                                                  data)
 
         # XXX: We'll also need to prune the deleted associations out of the
         # local associations table once they've been replicated to all peers

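A standalone sketch of the verify-key handling this change introduces for RemotePeer (the server name is invented and the key is the example value from show-verify-key above): the peer's base64 verify key is decoded into a signedjson VerifyKey and each replicated association is checked against it.

    import signedjson.key
    import signedjson.sign
    from unpaddedbase64 import decode_base64

    SIGNING_KEY_ALGORITHM = "ed25519"

    # Decode the peer's published base64 verify key into a VerifyKey object
    key_bytes = decode_base64("xjKlejLcLyTQg7Fxy/XGopUeF3W/l3/cRgpFe+edi0E")
    verify_key = signedjson.key.decode_verify_key_bytes(SIGNING_KEY_ALGORITHM + ":", key_bytes)

    # sg_assoc is a placeholder for a signed association dict received from the peer;
    # this raises SignatureVerifyException if it is not signed by peer.example.com
    signedjson.sign.verify_signed_json(sg_assoc, "peer.example.com", verify_key)
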
+ 59 - 27
sydent/replication/pusher.py

@@ -23,11 +23,15 @@ import twisted.internet.task
 from sydent.util import time_msec
 from sydent.replication.peer import LocalPeer
 from sydent.db.threepid_associations import LocalAssociationStore
+from sydent.db.invite_tokens import JoinTokenStore
 from sydent.db.peers import PeerStore
 from sydent.threepid.signer import Signer
 
 logger = logging.getLogger(__name__)
 
+EPHEMERAL_PUBLIC_KEYS_PUSH_LIMIT = 100
+INVITE_TOKENS_PUSH_LIMIT = 100
+ASSOCIATIONS_PUSH_LIMIT = 100
 
 class Pusher:
     def __init__(self, sydent):
@@ -39,7 +43,23 @@ class Pusher:
         cb = twisted.internet.task.LoopingCall(Pusher.scheduledPush, self)
         cb.start(10.0)
 
-    def getSignedAssociationsAfterId(self, afterId, limit):
+    def getSignedAssociationsAfterId(self, afterId, limit, shadow=False):
+        """Return at most `limit` associations from the database after a given
+        DB table id.
+
+        :param afterId: A database id to act as an offset. Rows after this id
+            are returned.
+        :type afterId: int
+        :param limit: Maximum number of database rows to return.
+        :type limit: int
+        :param shadow: Whether these associations are intended for a shadow
+            server.
+        :type shadow: bool
+        :returns a tuple with the first item being a dict of associations,
+            and the second being the maximum table id of the returned
+            associations.
+        :rtype: Tuple[Dict[int, Dict], int|None]
+        """
         assocs = {}
 
         localAssocStore = LocalAssociationStore(self.sydent)
@@ -47,22 +67,16 @@ class Pusher:
 
         signer = Signer(self.sydent)
 
-        for localId in localAssocs:
-            sgAssoc = signer.signedThreePidAssociation(localAssocs[localId])
-            shadowSgAssoc = None
-
-            if self.sydent.shadow_hs_master and self.sydent.shadow_hs_slave:
-                shadowAssoc = copy.deepcopy(localAssocs[localId])
-
+        for localId, assoc in localAssocs.items():
+            if shadow and self.sydent.shadow_hs_master and self.sydent.shadow_hs_slave:
                 # mxid is null if 3pid has been unbound
-                if shadowAssoc.mxid:
-                    shadowAssoc.mxid = shadowAssoc.mxid.replace(
+                if assoc.mxid:
+                    assoc.mxid = assoc.mxid.replace(
                         ":" + self.sydent.shadow_hs_master,
                         ":" + self.sydent.shadow_hs_slave
                     )
-                shadowSgAssoc = signer.signedThreePidAssociation(shadowAssoc)
 
-            assocs[localId] = (sgAssoc, shadowSgAssoc)
+            assocs[localId] = signer.signedThreePidAssociation(assoc)
 
         return (assocs, maxId)
 
@@ -80,42 +94,60 @@ class Pusher:
         localPeer.pushUpdates(signedAssocs)
 
     def scheduledPush(self):
+        """Push pending updates to a remote peer. To be called regularly."""
         if self.pushing:
             return
         self.pushing = True
 
         updateDeferred = None
 
+        join_token_store = JoinTokenStore(self.sydent)
+
         try:
             peers = self.peerStore.getAllPeers()
 
             for p in peers:
-                if p.lastSentVersion:
-                    logger.debug("Looking for update after %d to push to %s", p.lastSentVersion, p.servername)
-                else:
-                    logger.debug("Looking for update to push to %s", p.servername)
-                (signedAssocTuples, maxId) = self.getSignedAssociationsAfterId(p.lastSentVersion, 100)
-                logger.debug("%d updates to push to %s", len(signedAssocTuples), p.servername)
-                if len(signedAssocTuples) > 0:
-                    logger.info("Pushing %d updates to %s", len(signedAssocTuples), p.servername)
-                    updateDeferred = p.pushUpdates(signedAssocTuples)
-                    updateDeferred.addCallback(self._pushSucceeded, peer=p, maxId=maxId)
+                logger.debug("Looking for updates to push to %s", p.servername)
+
+                # Dictionary for holding all data to push
+                push_data = {}
+
+                # Dictionary for holding all the ids of db tables we've successfully replicated up to
+                ids = {}
+                total_updates = 0
+
+                # Push associations
+                (push_data["sg_assocs"], ids["sg_assocs"]) = self.getSignedAssociationsAfterId(p.lastSentAssocsId, ASSOCIATIONS_PUSH_LIMIT, p.shadow)
+                total_updates += len(push_data["sg_assocs"])
+
+                # Push invite tokens and ephemeral public keys
+                (push_data["invite_tokens"], ids["invite_tokens"]) = join_token_store.getInviteTokensAfterId(p.lastSentInviteTokensId, INVITE_TOKENS_PUSH_LIMIT)
+                (push_data["ephemeral_public_keys"], ids["ephemeral_public_keys"]) = join_token_store.getEphemeralPublicKeysAfterId(p.lastSentEphemeralKeysId, EPHEMERAL_PUBLIC_KEYS_PUSH_LIMIT)
+                total_updates += len(push_data["invite_tokens"]) + len(push_data["ephemeral_public_keys"])
+
+                logger.debug("%d updates to push to %s", total_updates, p.servername)
+                if total_updates:
+                    logger.info("Pushing %d updates to %s:%d", total_updates, p.servername, p.port)
+                    updateDeferred = p.pushUpdates(push_data)
+                    updateDeferred.addCallback(self._pushSucceeded, peer=p, ids=ids)
                     updateDeferred.addErrback(self._pushFailed, peer=p)
                     break
         finally:
             if not updateDeferred:
                 self.pushing = False
 
-    def _pushSucceeded(self, result, peer, maxId):
-        logger.info("Pushed updates up to %d to %s with result %d %s",
-                    maxId, peer.servername, result.code, result.phrase)
+    def _pushSucceeded(self, result, peer, ids):
+        """To be called after a successful push to a remote peer."""
+        logger.info("Pushed updates to %s with result %d %s",
+                    peer.servername, result.code, result.phrase)
 
-        self.peerStore.setLastSentVersionAndPokeSucceeded(peer.servername, maxId, time_msec())
+        self.peerStore.setLastSentIdAndPokeSucceeded(peer.servername, ids, time_msec())
 
         self.pushing = False
         self.scheduledPush()
 
     def _pushFailed(self, failure, peer):
-        logger.info("Failed to push updates to %s: %s", peer.servername, failure)
+        """To be called after an unsuccessful push to a remote peer."""
+        logger.info("Failed to push updates to %s:%s: %s", peer.servername, peer.port, failure)
         self.pushing = False
         return None

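A worked example of the shadow-server mxid rewriting now done inside getSignedAssociationsAfterId (hostnames invented): when pushing to a shadow peer, the master homeserver's name in the mxid is swapped for the slave's before the association is signed.

    # Assumed config values, for illustration only
    shadow_hs_master = "hs.example.com"
    shadow_hs_slave = "shadow-hs.example.com"

    mxid = "@alice:hs.example.com"
    mxid = mxid.replace(":" + shadow_hs_master, ":" + shadow_hs_slave)
    # mxid is now "@alice:shadow-hs.example.com"
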
+ 14 - 6
sydent/threepid/__init__.py

@@ -21,12 +21,20 @@ def threePidAssocFromDict(d):
 class ThreepidAssociation:
     def __init__(self, medium, address, mxid, ts, not_before, not_after):
         """
-        :param medium: The medium of the 3pid (eg. email)
-        :param address: The identifier (eg. email address)
-        :param mxid: The matrix ID the 3pid is associated with
-        :param ts: The creation timestamp of this association, ms
-        :param not_before: The timestamp, in ms, at which this association becomes valid
-        :param not_after: The timestamp, in ms, at which this association ceases to be valid
+        :param medium: The medium of the 3pid (eg. email).
+        :type medium: str
+        :param address: The identifier (eg. email address).
+        :type address: str
+        :param mxid: The matrix ID the 3pid is associated with.
+        :type mxid: str
+        :param ts: The creation timestamp of this association, ms.
+        :type ts: int
+        :param not_before: The timestamp, in ms, at which this association
+            becomes valid.
+        :type not_before: int
+        :param not_after: The timestamp, in ms, at which this association
+            ceases to be valid.
+        :type not_after: int
         """
         self.medium = medium
         self.address = address