casefold_db.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. #!/usr/bin/env python
  2. # Copyright 2021 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import argparse
  16. import json
  17. import logging
  18. import os
  19. import sqlite3
  20. import sys
  21. import time
  22. from typing import Any, Dict, List, Optional, Tuple
  23. import attr
  24. import signedjson.sign
  25. from sydent.config import SydentConfig
  26. from sydent.sydent import Sydent
  27. from sydent.util import json_decoder
  28. from sydent.util.emailutils import EmailSendException, sendEmail
  29. from sydent.util.hash import sha256_and_url_safe_base64
  30. from tests.utils import ResolvingMemoryReactorClock
  31. logger = logging.getLogger("casefold_db")
  32. # Maximum number of attempts to send an email.
  33. MAX_ATTEMPTS_FOR_EMAIL = 5
  34. @attr.s(auto_attribs=True)
  35. class UpdateDelta:
  36. """A row to update in the local_threepid_associations table."""
  37. address: str
  38. mxid: str
  39. lookup_hash: str
  40. @attr.s(auto_attribs=True)
  41. class DeleteDelta:
  42. """A row to delete from the local_threepid_associations table."""
  43. address: str
  44. mxid: str
  45. @attr.s(auto_attribs=True)
  46. class Delta:
  47. """Delta to apply to the local_threepid_associations table for a single
  48. case-insensitive email address.
  49. """
  50. to_update: UpdateDelta
  51. to_delete: Optional[List[DeleteDelta]] = None
  52. class CantSendEmailException(Exception):
  53. """Raised when we didn't succeed to send an email after MAX_ATTEMPTS_FOR_EMAIL
  54. attempts.
  55. """
  56. pass
  57. def calculate_lookup_hash(sydent, address):
  58. cur = sydent.db.cursor()
  59. pepper_result = cur.execute("SELECT lookup_pepper from hashing_metadata")
  60. pepper = pepper_result.fetchone()[0]
  61. combo = "%s %s %s" % (address, "email", pepper)
  62. lookup_hash = sha256_and_url_safe_base64(combo)
  63. return lookup_hash
  64. def sendEmailWithBackoff(
  65. sydent: Sydent,
  66. address: str,
  67. mxid: str,
  68. test: bool = False,
  69. ) -> None:
  70. """Send an email with exponential backoff - that way we don't stop sending halfway
  71. through if the SMTP server rejects our email (e.g. because of rate limiting).
  72. Setting test to True disables the backoff.
  73. Raises a CantSendEmailException if no email could be sent after MAX_ATTEMPTS_FOR_EMAIL
  74. attempts.
  75. """
  76. # Disable backoff if we're running tests.
  77. backoff = 1 if not test else 0
  78. for i in range(MAX_ATTEMPTS_FOR_EMAIL):
  79. try:
  80. template_file = sydent.get_branded_template(
  81. None,
  82. "migration_template.eml",
  83. )
  84. sendEmail(
  85. sydent,
  86. template_file,
  87. address,
  88. {"mxid": mxid},
  89. log_send_errors=False,
  90. )
  91. logger.info("Sent email to %s" % address)
  92. return
  93. except EmailSendException:
  94. logger.info(
  95. "Failed to send email to %s (attempt %d/%d)"
  96. % (address, i + 1, MAX_ATTEMPTS_FOR_EMAIL)
  97. )
  98. time.sleep(backoff)
  99. backoff *= 2
  100. raise CantSendEmailException()
  101. def update_local_associations(
  102. sydent: Sydent,
  103. db: sqlite3.Connection,
  104. send_email: bool,
  105. dry_run: bool,
  106. test: bool = False,
  107. ) -> None:
  108. """Update the DB table local_threepid_associations so that all stored
  109. emails are casefolded, and any duplicate mxid's associated with the
  110. given email are deleted.
  111. Setting dry_run to True means that the script is being run in dry-run mode
  112. by the user, i.e. it will run but will not send any email nor update the database.
  113. Setting test to True means that the function is being called as part of an automated
  114. test, and therefore we shouldn't backoff when sending emails.
  115. :return: None
  116. """
  117. logger.info("Processing rows in local_threepid_associations")
  118. res = db.execute(
  119. "SELECT address, mxid FROM local_threepid_associations WHERE medium = 'email'"
  120. "ORDER BY ts DESC"
  121. )
  122. # a dict that associates an email address with correspoinding mxids and lookup hashes
  123. associations: Dict[str, List[Tuple[str, str, str]]] = {}
  124. logger.info("Computing new hashes and signatures for local_threepid_associations")
  125. # iterate through selected associations, casefold email, rehash it, and add to
  126. # associations dict
  127. for address, mxid in res.fetchall():
  128. casefold_address = address.casefold()
  129. # rehash email since hashes are case-sensitive
  130. lookup_hash = calculate_lookup_hash(sydent, casefold_address)
  131. if casefold_address in associations:
  132. associations[casefold_address].append((address, mxid, lookup_hash))
  133. else:
  134. associations[casefold_address] = [(address, mxid, lookup_hash)]
  135. # Deltas to apply to the database, associated with the casefolded address they're for.
  136. deltas: Dict[str, Delta] = {}
  137. # Iterate through the results, to build the deltas.
  138. for casefold_address, assoc_tuples in associations.items():
  139. # If the row is already in the right state and there's no duplicate, don't compute
  140. # a delta for it.
  141. if len(assoc_tuples) == 1 and assoc_tuples[0][0] == casefold_address:
  142. continue
  143. deltas[casefold_address] = Delta(
  144. to_update=UpdateDelta(
  145. address=assoc_tuples[0][0],
  146. mxid=assoc_tuples[0][1],
  147. lookup_hash=assoc_tuples[0][2],
  148. )
  149. )
  150. if len(assoc_tuples) > 1:
  151. # Iterate over all associations except for the first one, since we've already
  152. # processed it.
  153. deltas[casefold_address].to_delete = []
  154. for address, mxid, _ in assoc_tuples[1:]:
  155. deltas[casefold_address].to_delete.append(
  156. DeleteDelta(
  157. address=address,
  158. mxid=mxid,
  159. )
  160. )
  161. logger.info(f"{len(deltas)} rows to update in local_threepid_associations")
  162. # Apply the deltas
  163. for casefolded_address, delta in deltas.items():
  164. if not test:
  165. log_msg = f"Updating {casefolded_address}"
  166. if delta.to_delete is not None:
  167. log_msg += (
  168. f" and deleting {len(delta.to_delete)} rows associated with it"
  169. )
  170. logger.info(log_msg)
  171. try:
  172. # Delete each association, and send an email mentioning the affected MXID.
  173. if delta.to_delete is not None:
  174. for to_delete in delta.to_delete:
  175. if send_email and not dry_run:
  176. # If the MXID is one that will still be associated with this
  177. # email address after this run, don't send an email for it.
  178. if to_delete.mxid == delta.to_update.mxid:
  179. continue
  180. sendEmailWithBackoff(
  181. sydent,
  182. to_delete.address,
  183. to_delete.mxid,
  184. test=test,
  185. )
  186. if not dry_run:
  187. cur = db.cursor()
  188. cur.execute(
  189. "DELETE FROM local_threepid_associations WHERE medium = 'email' AND address = ?",
  190. (to_delete.address,),
  191. )
  192. db.commit()
  193. # Update the row now that there's no duplicate.
  194. if not dry_run:
  195. cur = db.cursor()
  196. cur.execute(
  197. "UPDATE local_threepid_associations SET address = ?, lookup_hash = ? WHERE medium = 'email' AND address = ? AND mxid = ?",
  198. (
  199. casefolded_address,
  200. delta.to_update.lookup_hash,
  201. delta.to_update.address,
  202. delta.to_update.mxid,
  203. ),
  204. )
  205. db.commit()
  206. except CantSendEmailException:
  207. # If we failed because we couldn't send an email move on to the next address
  208. # to de-duplicate.
  209. # We catch this error here rather than when sending the email because we want
  210. # to avoid deleting rows we can't warn users about, and we don't want to
  211. # proceed with the subsequent update because there might still be duplicates
  212. # in the database (since we haven't deleted everything we wanted to delete).
  213. continue
  214. def update_global_associations(
  215. sydent: Sydent,
  216. db: sqlite3.Connection,
  217. dry_run: bool,
  218. ) -> None:
  219. """Update the DB table global_threepid_associations so that all stored
  220. emails are casefolded, the signed association is re-signed and any duplicate
  221. mxid's associated with the given email are deleted.
  222. Setting dry_run to True means that the script is being run in dry-run mode
  223. by the user, i.e. it will run but will not send any email nor update the database.
  224. :return: None
  225. """
  226. logger.info("Processing rows in global_threepid_associations")
  227. # get every row where the local server is origin server and medium is email
  228. origin_server = sydent.config.general.server_name
  229. medium = "email"
  230. res = db.execute(
  231. "SELECT address, mxid, sgAssoc FROM global_threepid_associations WHERE medium = ?"
  232. "AND originServer = ? ORDER BY ts DESC",
  233. (medium, origin_server),
  234. )
  235. # dict that stores email address with mxid, email address, lookup hash, and
  236. # signed association
  237. associations: Dict[str, List[Tuple[str, str, str, str]]] = {}
  238. logger.info("Computing new hashes and signatures for global_threepid_associations")
  239. # iterate through selected associations, casefold email, rehash it, re-sign the
  240. # associations and add to associations dict
  241. for address, mxid, sg_assoc in res.fetchall():
  242. casefold_address = address.casefold()
  243. # rehash the email since hash functions are case-sensitive
  244. lookup_hash = calculate_lookup_hash(sydent, casefold_address)
  245. # update signed associations with new casefolded address and re-sign
  246. sg_assoc = json_decoder.decode(sg_assoc)
  247. sg_assoc["address"] = address.casefold()
  248. sg_assoc = json.dumps(
  249. signedjson.sign.sign_json(
  250. sg_assoc, sydent.config.general.server_name, sydent.keyring.ed25519
  251. )
  252. )
  253. if casefold_address in associations:
  254. associations[casefold_address].append(
  255. (address, mxid, lookup_hash, sg_assoc)
  256. )
  257. else:
  258. associations[casefold_address] = [(address, mxid, lookup_hash, sg_assoc)]
  259. # list of arguments to update db with
  260. db_update_args: List[Tuple[Any, str, str, str, str]] = []
  261. # list of mxids to delete
  262. to_delete: List[Tuple[str]] = []
  263. for casefold_address, assoc_tuples in associations.items():
  264. # If the row is already in the right state and there's no duplicate, don't compute
  265. # a delta for it.
  266. if len(assoc_tuples) == 1 and assoc_tuples[0][0] == casefold_address:
  267. continue
  268. db_update_args.append(
  269. (
  270. casefold_address,
  271. assoc_tuples[0][2],
  272. assoc_tuples[0][3],
  273. assoc_tuples[0][0],
  274. assoc_tuples[0][1],
  275. )
  276. )
  277. if len(assoc_tuples) > 1:
  278. # Iterate over all associations except for the first one, since we've already
  279. # processed it.
  280. for address, mxid, _, _ in assoc_tuples[1:]:
  281. to_delete.append((address,))
  282. logger.info(
  283. f"{len(to_delete)} rows to delete, {len(db_update_args)} rows to update in global_threepid_associations"
  284. )
  285. if not dry_run:
  286. cur = db.cursor()
  287. if len(to_delete) > 0:
  288. cur.executemany(
  289. "DELETE FROM global_threepid_associations WHERE medium = 'email' AND address = ?",
  290. to_delete,
  291. )
  292. logger.info(
  293. f"{len(to_delete)} rows deleted from global_threepid_associations"
  294. )
  295. if len(db_update_args) > 0:
  296. cur.executemany(
  297. "UPDATE global_threepid_associations SET address = ?, lookup_hash = ?, sgAssoc = ? WHERE medium = 'email' AND address = ? AND mxid = ?",
  298. db_update_args,
  299. )
  300. logger.info(
  301. f"{len(db_update_args)} rows updated in global_threepid_associations"
  302. )
  303. db.commit()
  304. if __name__ == "__main__":
  305. parser = argparse.ArgumentParser(description="Casefold email addresses in database")
  306. parser.add_argument(
  307. "--no-email", action="store_true", help="run script but do not send emails"
  308. )
  309. parser.add_argument(
  310. "--dry-run",
  311. action="store_true",
  312. help="run script but do not send emails or alter database",
  313. )
  314. parser.add_argument("config_path", help="path to the sydent configuration file")
  315. args = parser.parse_args()
  316. # Set up logging.
  317. log_format = "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s" " - %(message)s"
  318. formatter = logging.Formatter(log_format)
  319. handler = logging.StreamHandler()
  320. handler.setFormatter(formatter)
  321. logger.setLevel(logging.INFO)
  322. logger.addHandler(handler)
  323. # if the path the user gives us doesn't work, find it for them
  324. if not os.path.exists(args.config_path):
  325. logger.error(f"The config file '{args.config_path}' does not exist.")
  326. sys.exit(1)
  327. sydent_config = SydentConfig()
  328. sydent_config.parse_config_file(args.config_path, skip_logging_setup=True)
  329. reactor = ResolvingMemoryReactorClock()
  330. sydent = Sydent(sydent_config, reactor, False)
  331. update_global_associations(sydent, sydent.db, dry_run=args.dry_run)
  332. update_local_associations(
  333. sydent,
  334. sydent.db,
  335. send_email=not args.no_email,
  336. dry_run=args.dry_run,
  337. )