test_casefold_migration.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. import json
  2. import os.path
  3. from unittest.mock import patch
  4. from twisted.trial import unittest
  5. from scripts.casefold_db import (
  6. calculate_lookup_hash,
  7. update_global_associations,
  8. update_local_associations,
  9. )
  10. from sydent.util import json_decoder
  11. from sydent.util.emailutils import sendEmail
  12. from tests.utils import make_sydent
  13. class MigrationTestCase(unittest.TestCase):
  14. def create_signedassoc(self, medium, address, mxid, ts, not_before, not_after):
  15. return {
  16. "medium": medium,
  17. "address": address,
  18. "mxid": mxid,
  19. "ts": ts,
  20. "not_before": not_before,
  21. "not_after": not_after,
  22. }
  23. def setUp(self):
  24. # Create a new sydent
  25. config = {
  26. "general": {
  27. "templates.path": os.path.join(
  28. os.path.dirname(os.path.dirname(__file__)), "res"
  29. ),
  30. },
  31. "crypto": {
  32. "ed25519.signingkey": "ed25519 0 FJi1Rnpj3/otydngacrwddFvwz/dTDsBv62uZDN2fZM"
  33. },
  34. }
  35. self.sydent = make_sydent(test_config=config)
  36. # create some local associations
  37. associations = []
  38. for i in range(10):
  39. address = "bob%d@example.com" % i
  40. associations.append(
  41. {
  42. "medium": "email",
  43. "address": address,
  44. "lookup_hash": calculate_lookup_hash(self.sydent, address),
  45. "mxid": "@bob%d:example.com" % i,
  46. "ts": (i * 10000),
  47. "not_before": 0,
  48. "not_after": 99999999999,
  49. }
  50. )
  51. # create some casefold-conflicting associations
  52. for i in range(5):
  53. address = "BOB%d@example.com" % i
  54. associations.append(
  55. {
  56. "medium": "email",
  57. "address": address,
  58. "lookup_hash": calculate_lookup_hash(self.sydent, address),
  59. "mxid": "@BOB%d:example.com" % i,
  60. "ts": (i * 10000),
  61. "not_before": 0,
  62. "not_after": 99999999999,
  63. }
  64. )
  65. # add all associations to db
  66. cur = self.sydent.db.cursor()
  67. cur.executemany(
  68. "INSERT INTO local_threepid_associations "
  69. "(medium, address, lookup_hash, mxid, ts, notBefore, notAfter) "
  70. "VALUES (?, ?, ?, ?, ?, ?, ?)",
  71. [
  72. (
  73. assoc["medium"],
  74. assoc["address"],
  75. assoc["lookup_hash"],
  76. assoc["mxid"],
  77. assoc["ts"],
  78. assoc["not_before"],
  79. assoc["not_after"],
  80. )
  81. for assoc in associations
  82. ],
  83. )
  84. self.sydent.db.commit()
  85. # create some global associations
  86. associations = []
  87. originServer = self.sydent.config.general.server_name
  88. for i in range(10):
  89. address = "bob%d@example.com" % i
  90. mxid = "@bob%d:example.com" % i
  91. ts = 10000 * i
  92. associations.append(
  93. {
  94. "medium": "email",
  95. "address": address,
  96. "lookup_hash": calculate_lookup_hash(self.sydent, address),
  97. "mxid": mxid,
  98. "ts": ts,
  99. "not_before": 0,
  100. "not_after": 99999999999,
  101. "originServer": originServer,
  102. "originId": i,
  103. "sgAssoc": json.dumps(
  104. self.create_signedassoc(
  105. "email", address, mxid, ts, 0, 99999999999
  106. )
  107. ),
  108. }
  109. )
  110. # create some casefold-conflicting associations
  111. for i in range(5):
  112. address = "BOB%d@example.com" % i
  113. mxid = "@BOB%d:example.com" % i
  114. ts = 10000 * i
  115. associations.append(
  116. {
  117. "medium": "email",
  118. "address": address,
  119. "lookup_hash": calculate_lookup_hash(self.sydent, address),
  120. "mxid": mxid,
  121. "ts": ts + 1,
  122. "not_before": 0,
  123. "not_after": 99999999999,
  124. "originServer": originServer,
  125. "originId": i + 10,
  126. "sgAssoc": json.dumps(
  127. self.create_signedassoc(
  128. "email", address, mxid, ts, 0, 99999999999
  129. )
  130. ),
  131. }
  132. )
  133. # add all associations to db
  134. cur = self.sydent.db.cursor()
  135. cur.executemany(
  136. "INSERT INTO global_threepid_associations "
  137. "(medium, address, lookup_hash, mxid, ts, notBefore, notAfter, originServer, originId, sgAssoc) "
  138. "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
  139. [
  140. (
  141. assoc["medium"],
  142. assoc["address"],
  143. assoc["lookup_hash"],
  144. assoc["mxid"],
  145. assoc["ts"],
  146. assoc["not_before"],
  147. assoc["not_after"],
  148. assoc["originServer"],
  149. assoc["originId"],
  150. assoc["sgAssoc"],
  151. )
  152. for assoc in associations
  153. ],
  154. )
  155. self.sydent.db.commit()
  156. def test_migration_email(self):
  157. with patch("sydent.util.emailutils.smtplib") as smtplib:
  158. if self.sydent.config.email.template is None:
  159. templateFile = self.sydent.get_branded_template(
  160. None,
  161. "migration_template.eml",
  162. ("email", "email.template"),
  163. )
  164. else:
  165. templateFile = self.sydent.config.email.template
  166. sendEmail(
  167. self.sydent,
  168. templateFile,
  169. "bob@example.com",
  170. {
  171. "mxid": "@bob:example.com",
  172. "subject_header_value": "MatrixID Deletion",
  173. },
  174. )
  175. smtp = smtplib.SMTP.return_value
  176. email_contents = smtp.sendmail.call_args[0][2].decode("utf-8")
  177. self.assertIn("In the past", email_contents)
  178. # test email was sent
  179. smtp.sendmail.assert_called()
  180. def test_local_db_migration(self):
  181. with patch("sydent.util.emailutils.smtplib") as smtplib:
  182. update_local_associations(
  183. self.sydent, self.sydent.db, send_email=True, dry_run=False
  184. )
  185. # test 5 emails were sent
  186. smtp = smtplib.SMTP.return_value
  187. self.assertEqual(smtp.sendmail.call_count, 5)
  188. # don't send emails to people who weren't affected
  189. self.assertNotIn(
  190. smtp.sendmail.call_args_list,
  191. [
  192. "bob5@example.com",
  193. "bob6@example.com",
  194. "bob7@example.com",
  195. "bob8@example.com",
  196. "bob9@example.com",
  197. ],
  198. )
  199. # make sure someone who is affected gets email
  200. self.assertIn("bob4@example.com", smtp.sendmail.call_args_list[0][0])
  201. cur = self.sydent.db.cursor()
  202. res = cur.execute("SELECT * FROM local_threepid_associations")
  203. db_state = res.fetchall()
  204. # five addresses should have been deleted
  205. self.assertEqual(len(db_state), 10)
  206. # iterate through db and make sure all addresses are casefolded and hash matches casefolded address
  207. for row in db_state:
  208. casefolded = row[2].casefold()
  209. self.assertEqual(row[2], casefolded)
  210. self.assertEqual(
  211. calculate_lookup_hash(self.sydent, row[2]),
  212. calculate_lookup_hash(self.sydent, casefolded),
  213. )
  214. def test_global_db_migration(self):
  215. update_global_associations(
  216. self.sydent, self.sydent.db, send_email=True, dry_run=False
  217. )
  218. cur = self.sydent.db.cursor()
  219. res = cur.execute("SELECT * FROM global_threepid_associations")
  220. db_state = res.fetchall()
  221. # five addresses should have been deleted
  222. self.assertEqual(len(db_state), 10)
  223. # iterate through db and make sure all addresses are casefolded and hash matches casefolded address
  224. # and make sure the casefolded address matches the address in sgAssoc
  225. for row in db_state:
  226. casefolded = row[2].casefold()
  227. self.assertEqual(row[2], casefolded)
  228. self.assertEqual(
  229. calculate_lookup_hash(self.sydent, row[2]),
  230. calculate_lookup_hash(self.sydent, casefolded),
  231. )
  232. sgassoc = json_decoder.decode(row[9])
  233. self.assertEqual(row[2], sgassoc["address"])
  234. def test_local_no_email_does_not_send_email(self):
  235. with patch("sydent.util.emailutils.smtplib") as smtplib:
  236. update_local_associations(
  237. self.sydent, self.sydent.db, send_email=False, dry_run=False
  238. )
  239. smtp = smtplib.SMTP.return_value
  240. # test no emails were sent
  241. self.assertEqual(smtp.sendmail.call_count, 0)
  242. def test_dry_run_does_nothing(self):
  243. # reset DB
  244. self.setUp()
  245. cur = self.sydent.db.cursor()
  246. # grab a snapshot of global table before running script
  247. res1 = cur.execute("SELECT mxid FROM global_threepid_associations")
  248. list1 = res1.fetchall()
  249. with patch("sydent.util.emailutils.smtplib") as smtplib:
  250. update_global_associations(
  251. self.sydent, self.sydent.db, send_email=True, dry_run=True
  252. )
  253. # test no emails were sent
  254. smtp = smtplib.SMTP.return_value
  255. self.assertEqual(smtp.sendmail.call_count, 0)
  256. res2 = cur.execute("SELECT mxid FROM global_threepid_associations")
  257. list2 = res2.fetchall()
  258. self.assertEqual(list1, list2)
  259. # grab a snapshot of local table db before running script
  260. res3 = cur.execute("SELECT mxid FROM local_threepid_associations")
  261. list3 = res3.fetchall()
  262. with patch("sydent.util.emailutils.smtplib") as smtplib:
  263. update_local_associations(
  264. self.sydent, self.sydent.db, send_email=True, dry_run=True
  265. )
  266. # test no emails were sent
  267. smtp = smtplib.SMTP.return_value
  268. self.assertEqual(smtp.sendmail.call_count, 0)
  269. res4 = cur.execute("SELECT mxid FROM local_threepid_associations")
  270. list4 = res4.fetchall()
  271. self.assertEqual(list3, list4)