Browse Source

Make the casefolding migration script more reliable (#379)

If the same MXID is associated with `Alice@example.com` and `alice@example.com`, we shouldn't send them an email.
While I was there, I also realised that the `mxid` variable for the template was set to the string `"mxid"` and the email subject was a bit obscure and not easy to change, so I've also fixed that.

I've also added exponential backoff when sending emails, so an error (e.g. a network blip, the SMTP server rate-limiting, etc.) does not mean we need to restart the script from scratch.
Brendan Abolivier 2 years ago
parent
commit
baba66f8b9

+ 1 - 0
changelog.d/379.misc

@@ -0,0 +1 @@
+Case-fold email addresses when binding to MXIDs or performing look-ups. Contributed by H. Shay.

+ 1 - 1
res/matrix-org/migration_template.eml.j2

@@ -2,7 +2,7 @@ Date: {{ date|safe }}
 From: {{ from|safe }}
 To: {{ to|safe }}
 Message-ID: {{ messageid|safe }}
-Subject: {{ subject_header_value|safe }}
+Subject: We have changed the way your Matrix account and email address are associated
 MIME-Version: 1.0
 Content-Type: multipart/alternative;
 	boundary="{{ multipart_boundary|safe }}"

+ 1 - 1
res/vector-im/migration_template.eml.j2

@@ -2,7 +2,7 @@ Date: {{ date|safe }}
 From: {{ from|safe }}
 To: {{ to|safe }}
 Message-ID: {{ messageid|safe }}
-Subject: {{ subject_header_value|safe }}
+Subject: We have changed the way your Matrix account and email address are associated
 MIME-Version: 1.0
 Content-Type: multipart/alternative;
 	boundary="{{ multipart_boundary|safe }}"

+ 207 - 67
scripts/casefold_db.py

@@ -18,16 +18,56 @@ import json
 import os
 import sqlite3
 import sys
-from typing import Any, Dict, List, Tuple
+import time
+from typing import Any, Dict, List, Optional, Tuple
 
+import attr
 import signedjson.sign
 
 from sydent.sydent import Sydent, parse_config_file
 from sydent.util import json_decoder
-from sydent.util.emailutils import sendEmail
+from sydent.util.emailutils import EmailSendException, sendEmail
 from sydent.util.hash import sha256_and_url_safe_base64
 from tests.utils import ResolvingMemoryReactorClock
 
+# Maximum number of attempts to send an email.
+MAX_ATTEMPTS_FOR_EMAIL = 5
+
+
+@attr.s(auto_attribs=True)
+class UpdateDelta:
+    """A row to update in the local_threepid_associations table."""
+
+    address: str
+    mxid: str
+    lookup_hash: str
+
+
+@attr.s(auto_attribs=True)
+class DeleteDelta:
+    """A row to delete from the local_threepid_associations table."""
+
+    address: str
+    mxid: str
+
+
+@attr.s(auto_attribs=True)
+class Delta:
+    """Delta to apply to the local_threepid_associations table for a single
+    case-insensitive email address.
+    """
+
+    to_update: UpdateDelta
+    to_delete: Optional[List[DeleteDelta]] = None
+
+
+class CantSendEmailException(Exception):
+    """Raised when we failed to send an email after MAX_ATTEMPTS_FOR_EMAIL
+    attempts.
+    """
+
+    pass
+
 
 def calculate_lookup_hash(sydent, address):
     cur = sydent.db.cursor()
@@ -38,18 +78,75 @@ def calculate_lookup_hash(sydent, address):
     return lookup_hash
 
 
+def sendEmailWithBackoff(
+    sydent: Sydent,
+    address: str,
+    mxid: str,
+    test: bool = False,
+) -> None:
+    """Send an email with exponential backoff - that way we don't stop sending halfway
+    through if the SMTP server rejects our email (e.g. because of rate limiting).
+
+    Setting test to True disables the logging.
+
+    Raises a CantSendEmailException if no email could be sent after MAX_ATTEMPTS_FOR_EMAIL
+    attempts.
+    """
+
+    # Disable backoff if we're running tests.
+    backoff = 1 if not test else 0
+
+    for i in range(MAX_ATTEMPTS_FOR_EMAIL):
+        try:
+            template_file = sydent.get_branded_template(
+                None,
+                "migration_template.eml",
+                ("email", "email.template"),
+            )
+
+            sendEmail(
+                sydent,
+                template_file,
+                address,
+                {"mxid": mxid},
+                log_send_errors=False,
+            )
+            if not test:
+                print("Sent email to %s" % address)
+
+            return
+        except EmailSendException:
+            if not test:
+                print(
+                    "Failed to send email to %s (attempt %d/%d)"
+                    % (address, i + 1, MAX_ATTEMPTS_FOR_EMAIL)
+                )
+
+            time.sleep(backoff)
+            backoff *= 2
+
+    raise CantSendEmailException()
+
+
 def update_local_associations(
-    sydent, db: sqlite3.Connection, send_email: bool, dry_run: bool
-):
+    sydent,
+    db: sqlite3.Connection,
+    send_email: bool,
+    dry_run: bool,
+    test: bool = False,
+) -> None:
     """Update the DB table local_threepid_associations so that all stored
     emails are casefolded, and any duplicate mxid's associated with the
     given email are deleted.
 
+    Setting dry_run to True means that the script is being run in dry-run mode
+    by the user, i.e. it will run but will not send any email nor update the database.
+    Setting test to True means that the function is being called as part of an automated
+    test, and therefore we should neither back off when sending emails nor log.
+
     :return: None
     """
-    cur = db.cursor()
-
-    res = cur.execute(
+    res = db.execute(
         "SELECT address, mxid FROM local_threepid_associations WHERE medium = 'email'"
         "ORDER BY ts DESC"
     )
@@ -70,81 +167,113 @@ def update_local_associations(
         else:
             associations[casefold_address] = [(address, mxid, lookup_hash)]
 
-    # list of arguments to update db with
-    db_update_args: List[Tuple[str, str, str, str]] = []
-
-    # list of mxids to delete
-    to_delete: List[Tuple[str]] = []
-
-    # list of mxids to send emails to letting them know the mxid has been deleted
-    mxids: List[Tuple[str, str]] = []
+    # Deltas to apply to the database, associated with the casefolded address they're for.
+    deltas: Dict[str, Delta] = {}
 
+    # Iterate through the results, to build the deltas.
     for casefold_address, assoc_tuples in associations.items():
-        db_update_args.append(
-            (
-                casefold_address,
-                assoc_tuples[0][2],
-                assoc_tuples[0][0],
-                assoc_tuples[0][1],
+        # If the row is already in the right state and there's no duplicate, don't compute
+        # a delta for it.
+        if len(assoc_tuples) == 1 and assoc_tuples[0][0] == casefold_address:
+            continue
+
+        deltas[casefold_address] = Delta(
+            to_update=UpdateDelta(
+                address=assoc_tuples[0][0],
+                mxid=assoc_tuples[0][1],
+                lookup_hash=assoc_tuples[0][2],
             )
         )
 
         if len(assoc_tuples) > 1:
             # Iterate over all associations except for the first one, since we've already
             # processed it.
+            deltas[casefold_address].to_delete = []
             for address, mxid, _ in assoc_tuples[1:]:
-                to_delete.append((address,))
-                mxids.append((mxid, address))
-
-    # iterate through the mxids and send email, let's only send one email per mxid
-    if send_email and not dry_run:
-        for mxid, address in mxids:
-            processed_mxids = []
-
-            if mxid in processed_mxids:
-                continue
-            else:
-                templateFile = sydent.get_branded_template(
-                    None,
-                    "migration_template.eml",
-                    ("email", "email.template"),
+                deltas[casefold_address].to_delete.append(
+                    DeleteDelta(
+                        address=address,
+                        mxid=mxid,
+                    )
                 )
 
-                sendEmail(
-                    sydent,
-                    templateFile,
-                    address,
-                    {"mxid": "mxid", "subject_header_value": "MatrixID Update"},
-                )
-                processed_mxids.append(mxid)
-
-    print(
-        f"{len(to_delete)} rows to delete, {len(db_update_args)} rows to update in local_threepid_associations"
-    )
-
-    if not dry_run:
-        if len(to_delete) > 0:
-            cur.executemany(
-                "DELETE FROM local_threepid_associations WHERE address = ?", to_delete
-            )
+    if not test:
+        print(f"{len(deltas)} rows to update in local_threepid_associations")
 
-        if len(db_update_args) > 0:
-            cur.executemany(
-                "UPDATE local_threepid_associations SET address = ?, lookup_hash = ? WHERE address = ? AND mxid = ?",
-                db_update_args,
-            )
+    # Apply the deltas
+    for casefolded_address, delta in deltas.items():
+        if not test:
+            log_msg = f"Updating {casefolded_address}"
+            if delta.to_delete is not None:
+                log_msg += (
+                    f" and deleting {len(delta.to_delete)} rows associated with it"
+                )
+            print(log_msg)
+
+        try:
+            # Delete each association, and send an email mentioning the affected MXID.
+            if delta.to_delete is not None:
+                for to_delete in delta.to_delete:
+                    if send_email and not dry_run:
+                        # If the MXID is one that will still be associated with this
+                        # email address after this run, don't send an email for it.
+                        if to_delete.mxid == delta.to_update.mxid:
+                            continue
+
+                        sendEmailWithBackoff(
+                            sydent,
+                            to_delete.address,
+                            to_delete.mxid,
+                            test=test,
+                        )
+
+                    if not dry_run:
+                        cur = db.cursor()
+                        cur.execute(
+                            "DELETE FROM local_threepid_associations WHERE address = ?",
+                            (to_delete.address,),
+                        )
+                        db.commit()
+
+            # Update the row now that there's no duplicate.
+            if not dry_run:
+                cur = db.cursor()
+                cur.execute(
+                    "UPDATE local_threepid_associations SET address = ?, lookup_hash = ? WHERE address = ? AND mxid = ?",
+                    (
+                        casefolded_address,
+                        delta.to_update.lookup_hash,
+                        delta.to_update.address,
+                        delta.to_update.mxid,
+                    ),
+                )
+                db.commit()
 
-        # We've finished updating the database, committing the transaction.
-        db.commit()
+        except CantSendEmailException:
+            # If we failed because we couldn't send an email, move on to the next address
+            # to de-duplicate.
+            # We catch this error here rather than when sending the email because we want
+            # to avoid deleting rows we can't warn users about, and we don't want to
+            # proceed with the subsequent update because there might still be duplicates
+            # in the database (since we haven't deleted everything we wanted to delete).
+            continue
 
 
 def update_global_associations(
-    sydent, db: sqlite3.Connection, send_email: bool, dry_run: bool
-):
+    sydent,
+    db: sqlite3.Connection,
+    dry_run: bool,
+    test: bool = False,
+) -> None:
     """Update the DB table global_threepid_associations so that all stored
     emails are casefolded, the signed association is re-signed and any duplicate
     mxid's associated with the given email are deleted.
 
+    Setting dry_run to True means that the script is being run in dry-run mode
+    by the user, i.e. it will run but will not send any email nor update the database.
+    Setting test to True means that the function is being called as part of an automated
+    test, and therefore we should suppress logs.
+
     :return: None
     """
 
@@ -194,6 +323,11 @@ def update_global_associations(
     to_delete: List[Tuple[str]] = []
 
     for casefold_address, assoc_tuples in associations.items():
+        # If the row is already in the right state and there's no duplicate, don't compute
+        # a delta for it.
+        if len(assoc_tuples) == 1 and assoc_tuples[0][0] == casefold_address:
+            continue
+
         db_update_args.append(
             (
                 casefold_address,
@@ -210,9 +344,10 @@ def update_global_associations(
             for address, mxid, _, _ in assoc_tuples[1:]:
                 to_delete.append((address,))
 
-    print(
-        f"{len(to_delete)} rows to delete, {len(db_update_args)} rows to update in global_threepid_associations"
-    )
+    if not test:
+        print(
+            f"{len(to_delete)} rows to delete, {len(db_update_args)} rows to update in global_threepid_associations"
+        )
 
     if not dry_run:
         if len(to_delete) > 0:
@@ -254,5 +389,10 @@ if __name__ == "__main__":
     reactor = ResolvingMemoryReactorClock()
     sydent = Sydent(config, reactor, False)
 
-    update_global_associations(sydent, sydent.db, not args.no_email, args.dry_run)
-    update_local_associations(sydent, sydent.db, not args.no_email, args.dry_run)
+    update_global_associations(sydent, sydent.db, dry_run=args.dry_run)
+    update_local_associations(
+        sydent,
+        sydent.db,
+        send_email=not args.no_email,
+        dry_run=args.dry_run,
+    )

+ 8 - 2
sydent/util/emailutils.py

@@ -34,7 +34,11 @@ logger = logging.getLogger(__name__)
 
 
 def sendEmail(
-    sydent: "Sydent", templateFile: str, mailTo: str, substitutions: Dict[str, str]
+    sydent: "Sydent",
+    templateFile: str,
+    mailTo: str,
+    substitutions: Dict[str, str],
+    log_send_errors: bool = True,
 ) -> None:
     """
     Sends an email with the given parameters.
@@ -45,6 +49,7 @@ def sendEmail(
         email.
     :param mailTo: The email address to send the email to.
     :param substitutions: The substitutions to use with the template.
+    :param log_send_errors: Whether to log errors happening when sending an email.
     """
     mailFrom = sydent.cfg.get("email", "email.from")
 
@@ -121,7 +126,8 @@ def sendEmail(
         smtp.sendmail(mailFrom, mailTo, mailString.encode("utf-8"))
         smtp.quit()
     except Exception as origException:
-        twisted.python.log.err()
+        if log_send_errors:
+            twisted.python.log.err()
         raise EmailSendException() from origException
 
 

+ 24 - 6
tests/test_casefold_migration.py

@@ -63,7 +63,7 @@ class MigrationTestCase(unittest.TestCase):
                     "medium": "email",
                     "address": address,
                     "lookup_hash": calculate_lookup_hash(self.sydent, address),
-                    "mxid": "@BOB%d:example.com" % i,
+                    "mxid": "@otherbob%d:example.com" % i,
                     "ts": (i * 10000),
                     "not_before": 0,
                     "not_after": 99999999999,
@@ -195,7 +195,11 @@ class MigrationTestCase(unittest.TestCase):
     def test_local_db_migration(self):
         with patch("sydent.util.emailutils.smtplib") as smtplib:
             update_local_associations(
-                self.sydent, self.sydent.db, send_email=True, dry_run=False
+                self.sydent,
+                self.sydent.db,
+                send_email=True,
+                dry_run=False,
+                test=True,
             )
 
         # test 5 emails were sent
@@ -236,7 +240,10 @@ class MigrationTestCase(unittest.TestCase):
 
     def test_global_db_migration(self):
         update_global_associations(
-            self.sydent, self.sydent.db, send_email=True, dry_run=False
+            self.sydent,
+            self.sydent.db,
+            dry_run=False,
+            test=True,
         )
 
         cur = self.sydent.db.cursor()
@@ -262,7 +269,11 @@ class MigrationTestCase(unittest.TestCase):
     def test_local_no_email_does_not_send_email(self):
         with patch("sydent.util.emailutils.smtplib") as smtplib:
             update_local_associations(
-                self.sydent, self.sydent.db, send_email=False, dry_run=False
+                self.sydent,
+                self.sydent.db,
+                send_email=False,
+                dry_run=False,
+                test=True,
             )
             smtp = smtplib.SMTP.return_value
 
@@ -281,7 +292,10 @@ class MigrationTestCase(unittest.TestCase):
 
         with patch("sydent.util.emailutils.smtplib") as smtplib:
             update_global_associations(
-                self.sydent, self.sydent.db, send_email=True, dry_run=True
+                self.sydent,
+                self.sydent.db,
+                dry_run=True,
+                test=True,
             )
 
         # test no emails were sent
@@ -299,7 +313,11 @@ class MigrationTestCase(unittest.TestCase):
 
         with patch("sydent.util.emailutils.smtplib") as smtplib:
             update_local_associations(
-                self.sydent, self.sydent.db, send_email=True, dry_run=True
+                self.sydent,
+                self.sydent.db,
+                send_email=True,
+                dry_run=True,
+                test=True,
             )
 
         # test no emails were sent