1
0

casefold_db.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
#!/usr/bin/env python
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. import argparse
  16. import json
  17. import logging
  18. import os
  19. import sqlite3
  20. import sys
  21. import time
  22. from typing import Any, Dict, List, Optional, Tuple
  23. import attr
  24. import signedjson.sign
  25. from sydent.config import SydentConfig
  26. from sydent.sydent import Sydent
  27. from sydent.util import json_decoder
  28. from sydent.util.emailutils import EmailSendException, sendEmail
  29. from sydent.util.hash import sha256_and_url_safe_base64
  30. from tests.utils import ResolvingMemoryReactorClock
  31. logger = logging.getLogger("casefold_db")
  32. # Maximum number of attempts to send an email.
  33. MAX_ATTEMPTS_FOR_EMAIL = 5
  34. @attr.s(auto_attribs=True)
  35. class UpdateDelta:
  36. """A row to update in the local_threepid_associations table."""
  37. address: str
  38. mxid: str
  39. lookup_hash: str
  40. @attr.s(auto_attribs=True)
  41. class DeleteDelta:
  42. """A row to delete from the local_threepid_associations table."""
  43. address: str
  44. mxid: str
  45. @attr.s(auto_attribs=True)
  46. class Delta:
  47. """Delta to apply to the local_threepid_associations table for a single
  48. case-insensitive email address.
  49. """
  50. to_update: UpdateDelta
  51. to_delete: Optional[List[DeleteDelta]] = None
  52. class CantSendEmailException(Exception):
  53. """Raised when we didn't succeed to send an email after MAX_ATTEMPTS_FOR_EMAIL
  54. attempts.
  55. """
  56. pass
  57. def calculate_lookup_hash(sydent: Sydent, address: str) -> str:
  58. pepper = sydent.threepidBinder.hashing_store.get_lookup_pepper()
  59. if pepper is None:
  60. raise RuntimeError(
  61. "No lookup pepper found; Sydent should have generated one on startup."
  62. )
  63. combo = "%s %s %s" % (address, "email", pepper)
  64. lookup_hash = sha256_and_url_safe_base64(combo)
  65. return lookup_hash
  66. def sendEmailWithBackoff(
  67. sydent: Sydent,
  68. address: str,
  69. mxid: str,
  70. test: bool = False,
  71. ) -> None:
  72. """Send an email with exponential backoff - that way we don't stop sending halfway
  73. through if the SMTP server rejects our email (e.g. because of rate limiting).
  74. Setting test to True disables the backoff.
  75. Raises a CantSendEmailException if no email could be sent after MAX_ATTEMPTS_FOR_EMAIL
  76. attempts.
  77. """
  78. # Disable backoff if we're running tests.
  79. backoff = 1 if not test else 0
  80. for i in range(MAX_ATTEMPTS_FOR_EMAIL):
  81. try:
  82. template_file = sydent.get_branded_template(
  83. None,
  84. "migration_template.eml",
  85. )
  86. sendEmail(
  87. sydent,
  88. template_file,
  89. address,
  90. {"mxid": mxid},
  91. log_send_errors=False,
  92. )
  93. logger.info("Sent email to %s" % address)
  94. return
  95. except EmailSendException:
  96. logger.info(
  97. "Failed to send email to %s (attempt %d/%d)"
  98. % (address, i + 1, MAX_ATTEMPTS_FOR_EMAIL)
  99. )
  100. time.sleep(backoff)
  101. backoff *= 2
  102. raise CantSendEmailException()
  103. def update_local_associations(
  104. sydent: Sydent,
  105. db: sqlite3.Connection,
  106. send_email: bool,
  107. dry_run: bool,
  108. test: bool = False,
  109. ) -> None:
  110. """Update the DB table local_threepid_associations so that all stored
  111. emails are casefolded, and any duplicate mxid's associated with the
  112. given email are deleted.
  113. Setting dry_run to True means that the script is being run in dry-run mode
  114. by the user, i.e. it will run but will not send any email nor update the database.
  115. Setting test to True means that the function is being called as part of an automated
  116. test, and therefore we shouldn't backoff when sending emails.
  117. :return: None
  118. """
  119. logger.info("Processing rows in local_threepid_associations")
  120. res = db.execute(
  121. "SELECT address, mxid FROM local_threepid_associations WHERE medium = 'email'"
  122. "ORDER BY ts DESC"
  123. )
  124. # a dict that associates an email address with correspoinding mxids and lookup hashes
  125. associations: Dict[str, List[Tuple[str, str, str]]] = {}
  126. logger.info("Computing new hashes and signatures for local_threepid_associations")
  127. # iterate through selected associations, casefold email, rehash it, and add to
  128. # associations dict
  129. for address, mxid in res.fetchall():
  130. casefold_address = address.casefold()
  131. # rehash email since hashes are case-sensitive
  132. lookup_hash = calculate_lookup_hash(sydent, casefold_address)
  133. if casefold_address in associations:
  134. associations[casefold_address].append((address, mxid, lookup_hash))
  135. else:
  136. associations[casefold_address] = [(address, mxid, lookup_hash)]
  137. # Deltas to apply to the database, associated with the casefolded address they're for.
  138. deltas: Dict[str, Delta] = {}
  139. # Iterate through the results, to build the deltas.
  140. for casefold_address, assoc_tuples in associations.items():
  141. # If the row is already in the right state and there's no duplicate, don't compute
  142. # a delta for it.
  143. if len(assoc_tuples) == 1 and assoc_tuples[0][0] == casefold_address:
  144. continue
  145. deltas[casefold_address] = Delta(
  146. to_update=UpdateDelta(
  147. address=assoc_tuples[0][0],
  148. mxid=assoc_tuples[0][1],
  149. lookup_hash=assoc_tuples[0][2],
  150. )
  151. )
  152. if len(assoc_tuples) > 1:
  153. # Iterate over all associations except for the first one, since we've already
  154. # processed it.
  155. deltas[casefold_address].to_delete = []
  156. for address, mxid, _ in assoc_tuples[1:]:
  157. deltas[casefold_address].to_delete.append(
  158. DeleteDelta(
  159. address=address,
  160. mxid=mxid,
  161. )
  162. )
  163. logger.info(f"{len(deltas)} rows to update in local_threepid_associations")
  164. # Apply the deltas
  165. for casefolded_address, delta in deltas.items():
  166. if not test:
  167. log_msg = f"Updating {casefolded_address}"
  168. if delta.to_delete is not None:
  169. log_msg += (
  170. f" and deleting {len(delta.to_delete)} rows associated with it"
  171. )
  172. logger.info(log_msg)
  173. try:
  174. # Delete each association, and send an email mentioning the affected MXID.
  175. if delta.to_delete is not None and not dry_run:
  176. for to_delete in delta.to_delete:
  177. if send_email and to_delete.mxid != delta.to_update.mxid:
  178. # If the MXID is one that will still be associated with this
  179. # email address after this run, don't send an email for it.
  180. sendEmailWithBackoff(
  181. sydent,
  182. to_delete.address,
  183. to_delete.mxid,
  184. test=test,
  185. )
  186. logger.debug(
  187. "Deleting %s from table local_threepid_associations",
  188. to_delete.address,
  189. )
  190. cur = db.cursor()
  191. cur.execute(
  192. "DELETE FROM local_threepid_associations WHERE medium = 'email' AND address = ?",
  193. (to_delete.address,),
  194. )
  195. db.commit()
  196. # Update the row now that there's no duplicate.
  197. if not dry_run:
  198. logger.debug(
  199. "Updating table local threepid associations setting address to %s, "
  200. "lookup_hash to %s, where medium = email and address = %s and mxid = %s",
  201. casefolded_address,
  202. delta.to_update.lookup_hash,
  203. delta.to_update.address,
  204. delta.to_update.mxid,
  205. )
  206. cur = db.cursor()
  207. cur.execute(
  208. "UPDATE local_threepid_associations SET address = ?, lookup_hash = ? WHERE medium = 'email' AND address = ? AND mxid = ?",
  209. (
  210. casefolded_address,
  211. delta.to_update.lookup_hash,
  212. delta.to_update.address,
  213. delta.to_update.mxid,
  214. ),
  215. )
  216. db.commit()
  217. except CantSendEmailException:
  218. # If we failed because we couldn't send an email move on to the next address
  219. # to de-duplicate.
  220. # We catch this error here rather than when sending the email because we want
  221. # to avoid deleting rows we can't warn users about, and we don't want to
  222. # proceed with the subsequent update because there might still be duplicates
  223. # in the database (since we haven't deleted everything we wanted to delete).
  224. logger.warn("Failed to send email to %s; skipping!", to_delete.address)
  225. continue
  226. def update_global_associations(
  227. sydent: Sydent,
  228. db: sqlite3.Connection,
  229. dry_run: bool,
  230. ) -> None:
  231. """Update the DB table global_threepid_associations so that all stored
  232. emails are casefolded, the signed association is re-signed and any duplicate
  233. mxid's associated with the given email are deleted.
  234. Setting dry_run to True means that the script is being run in dry-run mode
  235. by the user, i.e. it will run but will not send any email nor update the database.
  236. :return: None
  237. """
  238. logger.info("Processing rows in global_threepid_associations")
  239. # get every row where the local server is origin server and medium is email
  240. origin_server = sydent.config.general.server_name
  241. medium = "email"
  242. res = db.execute(
  243. "SELECT address, mxid, sgAssoc FROM global_threepid_associations WHERE medium = ?"
  244. "AND originServer = ? ORDER BY ts DESC",
  245. (medium, origin_server),
  246. )
  247. # dict that stores email address with mxid, email address, lookup hash, and
  248. # signed association
  249. associations: Dict[str, List[Tuple[str, str, str, str]]] = {}
  250. logger.info("Computing new hashes and signatures for global_threepid_associations")
  251. # iterate through selected associations, casefold email, rehash it, re-sign the
  252. # associations and add to associations dict
  253. for address, mxid, sg_assoc in res.fetchall():
  254. casefold_address = address.casefold()
  255. # rehash the email since hash functions are case-sensitive
  256. lookup_hash = calculate_lookup_hash(sydent, casefold_address)
  257. # update signed associations with new casefolded address and re-sign
  258. sg_assoc = json_decoder.decode(sg_assoc)
  259. sg_assoc["address"] = address.casefold()
  260. sg_assoc = json.dumps(
  261. signedjson.sign.sign_json(
  262. sg_assoc, sydent.config.general.server_name, sydent.keyring.ed25519
  263. )
  264. )
  265. if casefold_address in associations:
  266. associations[casefold_address].append(
  267. (address, mxid, lookup_hash, sg_assoc)
  268. )
  269. else:
  270. associations[casefold_address] = [(address, mxid, lookup_hash, sg_assoc)]
  271. # list of arguments to update db with
  272. db_update_args: List[Tuple[Any, str, str, str, str]] = []
  273. # list of mxids to delete
  274. to_delete: List[Tuple[str]] = []
  275. for casefold_address, assoc_tuples in associations.items():
  276. # If the row is already in the right state and there's no duplicate, don't compute
  277. # a delta for it.
  278. if len(assoc_tuples) == 1 and assoc_tuples[0][0] == casefold_address:
  279. continue
  280. db_update_args.append(
  281. (
  282. casefold_address,
  283. assoc_tuples[0][2],
  284. assoc_tuples[0][3],
  285. assoc_tuples[0][0],
  286. assoc_tuples[0][1],
  287. )
  288. )
  289. if len(assoc_tuples) > 1:
  290. # Iterate over all associations except for the first one, since we've already
  291. # processed it.
  292. for address, mxid, _, _ in assoc_tuples[1:]:
  293. to_delete.append((address,))
  294. logger.info(
  295. f"{len(to_delete)} rows to delete, {len(db_update_args)} rows to update in global_threepid_associations"
  296. )
  297. if not dry_run:
  298. cur = db.cursor()
  299. if len(to_delete) > 0:
  300. cur.executemany(
  301. "DELETE FROM global_threepid_associations WHERE medium = 'email' AND address = ?",
  302. to_delete,
  303. )
  304. logger.info(
  305. f"{len(to_delete)} rows deleted from global_threepid_associations"
  306. )
  307. if len(db_update_args) > 0:
  308. cur.executemany(
  309. "UPDATE global_threepid_associations SET address = ?, lookup_hash = ?, sgAssoc = ? WHERE medium = 'email' AND address = ? AND mxid = ?",
  310. db_update_args,
  311. )
  312. logger.info(
  313. f"{len(db_update_args)} rows updated in global_threepid_associations"
  314. )
  315. db.commit()
  316. if __name__ == "__main__":
  317. parser = argparse.ArgumentParser(description="Casefold email addresses in database")
  318. parser.add_argument(
  319. "--no-email", action="store_true", help="run script but do not send emails"
  320. )
  321. parser.add_argument(
  322. "--dry-run",
  323. action="store_true",
  324. help="run script but do not send emails or alter database",
  325. )
  326. parser.add_argument("config_path", help="path to the sydent configuration file")
  327. args = parser.parse_args()
  328. # Set up logging.
  329. log_format = "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s" " - %(message)s"
  330. formatter = logging.Formatter(log_format)
  331. handler = logging.StreamHandler()
  332. handler.setFormatter(formatter)
  333. logger.setLevel(logging.INFO)
  334. logger.addHandler(handler)
  335. # if the path the user gives us doesn't work, find it for them
  336. if not os.path.exists(args.config_path):
  337. logger.error(f"The config file '{args.config_path}' does not exist.")
  338. sys.exit(1)
  339. sydent_config = SydentConfig()
  340. sydent_config.parse_config_file(args.config_path)
  341. reactor = ResolvingMemoryReactorClock()
  342. sydent = Sydent(sydent_config, reactor, False)
  343. update_global_associations(sydent, sydent.db, dry_run=args.dry_run)
  344. update_local_associations(
  345. sydent,
  346. sydent.db,
  347. send_email=not args.no_email,
  348. dry_run=args.dry_run,
  349. )