# test_casefold_migration.py
  1. import json
  2. import os.path
  3. from unittest.mock import patch
  4. from twisted.trial import unittest
  5. from scripts.casefold_db import (
  6. calculate_lookup_hash,
  7. update_global_associations,
  8. update_local_associations,
  9. )
  10. from sydent.util import json_decoder
  11. from sydent.util.emailutils import sendEmail
  12. from tests.utils import make_sydent
  class MigrationTestCase(unittest.TestCase):
      """Exercise the casefold migration helpers from ``scripts.casefold_db``.

      Every test starts from the fixture built in ``setUp``: a Sydent instance
      whose local and global association tables each hold ten lower-case
      ``bobN@example.com`` e-mail addresses (N = 0..9) plus five upper-case
      ``BOBN@example.com`` addresses (N = 0..4) that casefold onto the first
      five.
      """

      # "not_after" value shared by every seeded association.
      _NOT_AFTER = 99999999999

      def create_signedassoc(self, medium, address, mxid, ts, not_before, not_after):
          """Build the dict that is JSON-encoded into the ``sgAssoc`` column.

          :return: a dict whose keys ``medium``, ``address``, ``mxid``, ``ts``,
              ``not_before`` and ``not_after`` hold the arguments of the same
              name.
          """
          return {
              "medium": medium,
              "address": address,
              "mxid": mxid,
              "ts": ts,
              "not_before": not_before,
              "not_after": not_after,
          }

      def setUp(self):
          """Create a Sydent instance and seed both association tables."""
          config = {
              "general": {
                  "templates.path": os.path.join(
                      os.path.dirname(os.path.dirname(__file__)), "res"
                  ),
              },
              "crypto": {
                  "ed25519.signingkey": "ed25519 0 FJi1Rnpj3/otydngacrwddFvwz/dTDsBv62uZDN2fZM"
              },
          }
          self.sydent = make_sydent(test_config=config)

          self._seed_local_associations()
          self._seed_global_associations()

      def _seed_local_associations(self):
          """Insert ten plain and five casefold-conflicting local associations."""
          rows = []
          # (address pattern, mxid pattern, how many); plain rows go in first,
          # followed by the upper-case addresses that collide after casefolding.
          for address_fmt, mxid_fmt, count in (
              ("bob%d@example.com", "@bob%d:example.com", 10),
              ("BOB%d@example.com", "@otherbob%d:example.com", 5),
          ):
              for i in range(count):
                  address = address_fmt % i
                  rows.append(
                      (
                          "email",
                          address,
                          calculate_lookup_hash(self.sydent, address),
                          mxid_fmt % i,
                          i * 10000,
                          0,
                          self._NOT_AFTER,
                      )
                  )

          cur = self.sydent.db.cursor()
          cur.executemany(
              "INSERT INTO local_threepid_associations "
              "(medium, address, lookup_hash, mxid, ts, notBefore, notAfter) "
              "VALUES (?, ?, ?, ?, ?, ?, ?)",
              rows,
          )
          self.sydent.db.commit()

      def _seed_global_associations(self):
          """Insert ten plain and five casefold-conflicting global associations."""
          origin_server = self.sydent.config.general.server_name

          rows = []
          # (address pattern, mxid pattern, how many, ts offset, originId offset)
          for address_fmt, mxid_fmt, count, ts_offset, origin_id_offset in (
              ("bob%d@example.com", "@bob%d:example.com", 10, 0, 0),
              ("BOB%d@example.com", "@BOB%d:example.com", 5, 1, 10),
          ):
              for i in range(count):
                  address = address_fmt % i
                  mxid = mxid_fmt % i
                  ts = 10000 * i
                  rows.append(
                      (
                          "email",
                          address,
                          calculate_lookup_hash(self.sydent, address),
                          mxid,
                          # Conflicting rows are stored one tick later than the
                          # plain row with the same index.
                          ts + ts_offset,
                          0,
                          self._NOT_AFTER,
                          origin_server,
                          i + origin_id_offset,
                          # The signed blob carries the un-offset timestamp.
                          json.dumps(
                              self.create_signedassoc(
                                  "email", address, mxid, ts, 0, self._NOT_AFTER
                              )
                          ),
                      )
                  )

          cur = self.sydent.db.cursor()
          cur.executemany(
              "INSERT INTO global_threepid_associations "
              "(medium, address, lookup_hash, mxid, ts, notBefore, notAfter, originServer, originId, sgAssoc) "
              "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
              rows,
          )
          self.sydent.db.commit()

      def test_migration_email(self):
          """The migration template renders and is handed to ``sendmail``."""
          with patch("sydent.util.emailutils.smtplib") as smtplib:
              # self.sydent.config.email.template is deprecated
              if self.sydent.config.email.template is None:
                  templateFile = self.sydent.get_branded_template(
                      None,
                      "migration_template.eml",
                  )
              else:
                  templateFile = self.sydent.config.email.template

              sendEmail(
                  self.sydent,
                  templateFile,
                  "bob@example.com",
                  {
                      "mxid": "@bob:example.com",
                      "subject_header_value": "MatrixID Deletion",
                  },
              )

          smtp = smtplib.SMTP.return_value

          # Check that an email was sent *before* looking at its arguments:
          # call_args is None when sendmail was never called, and indexing it
          # would raise TypeError instead of a readable assertion failure.
          smtp.sendmail.assert_called()

          email_contents = smtp.sendmail.call_args[0][2].decode("utf-8")
          self.assertIn("In the past", email_contents)

      def test_local_db_migration(self):
          """Local migration emails affected users and casefolds the table."""
          with patch("sydent.util.emailutils.smtplib") as smtplib:
              update_local_associations(
                  self.sydent,
                  self.sydent.db,
                  send_email=True,
                  dry_run=False,
                  test=True,
              )

          # test 5 emails were sent
          smtp = smtplib.SMTP.return_value
          self.assertEqual(smtp.sendmail.call_count, 5)

          # don't send emails to people who weren't affected: none of the
          # addresses without a casefold conflict may appear among the
          # positional arguments of any recorded sendmail call
          unaffected = ["bob%d@example.com" % i for i in range(5, 10)]
          for call in smtp.sendmail.call_args_list:
              for address in unaffected:
                  self.assertNotIn(address, call[0])

          # make sure someone who is affected gets email
          self.assertIn("bob4@example.com", smtp.sendmail.call_args_list[0][0])

          cur = self.sydent.db.cursor()
          res = cur.execute("SELECT * FROM local_threepid_associations")
          db_state = res.fetchall()

          # five addresses should have been deleted
          self.assertEqual(len(db_state), 10)

          # iterate through db and make sure all addresses are casefolded and hash matches casefolded address
          for row in db_state:
              casefolded = row[2].casefold()
              self.assertEqual(row[2], casefolded)
              # NOTE(review): both sides hash the same string once the assertion
              # above holds, so the stored lookup_hash column is not inspected
              # here. Confirm its column index before tightening this check.
              self.assertEqual(
                  calculate_lookup_hash(self.sydent, row[2]),
                  calculate_lookup_hash(self.sydent, casefolded),
              )

      def test_global_db_migration(self):
          """Global migration casefolds addresses and keeps sgAssoc in sync."""
          update_global_associations(
              self.sydent,
              self.sydent.db,
              dry_run=False,
          )

          cur = self.sydent.db.cursor()
          res = cur.execute("SELECT * FROM global_threepid_associations")
          db_state = res.fetchall()

          # five addresses should have been deleted
          self.assertEqual(len(db_state), 10)

          # iterate through db and make sure all addresses are casefolded and hash matches casefolded address
          # and make sure the casefolded address matches the address in sgAssoc
          for row in db_state:
              casefolded = row[2].casefold()
              self.assertEqual(row[2], casefolded)
              # NOTE(review): both sides hash the same string once the assertion
              # above holds, so the stored lookup_hash column is not inspected
              # here. Confirm its column index before tightening this check.
              self.assertEqual(
                  calculate_lookup_hash(self.sydent, row[2]),
                  calculate_lookup_hash(self.sydent, casefolded),
              )
              sgassoc = json_decoder.decode(row[9])
              self.assertEqual(row[2], sgassoc["address"])

      def test_local_no_email_does_not_send_email(self):
          """With ``send_email=False`` the local migration sends nothing."""
          with patch("sydent.util.emailutils.smtplib") as smtplib:
              update_local_associations(
                  self.sydent,
                  self.sydent.db,
                  send_email=False,
                  dry_run=False,
                  test=True,
              )

          smtp = smtplib.SMTP.return_value

          # test no emails were sent
          self.assertEqual(smtp.sendmail.call_count, 0)

      def test_dry_run_does_nothing(self):
          """A dry run leaves both tables untouched and sends no email."""
          # No explicit self.setUp() here: the test runner has already built the
          # fixture for this test, so calling it again would only create a
          # second fixture and discard the first.
          cur = self.sydent.db.cursor()

          # grab a snapshot of global table before running script
          res1 = cur.execute("SELECT mxid FROM global_threepid_associations")
          list1 = res1.fetchall()

          with patch("sydent.util.emailutils.smtplib") as smtplib:
              update_global_associations(
                  self.sydent,
                  self.sydent.db,
                  dry_run=True,
              )

          # test no emails were sent
          smtp = smtplib.SMTP.return_value
          self.assertEqual(smtp.sendmail.call_count, 0)

          res2 = cur.execute("SELECT mxid FROM global_threepid_associations")
          list2 = res2.fetchall()

          self.assertEqual(list1, list2)

          # grab a snapshot of local table db before running script
          res3 = cur.execute("SELECT mxid FROM local_threepid_associations")
          list3 = res3.fetchall()

          with patch("sydent.util.emailutils.smtplib") as smtplib:
              update_local_associations(
                  self.sydent,
                  self.sydent.db,
                  send_email=True,
                  dry_run=True,
                  test=True,
              )

          # test no emails were sent
          smtp = smtplib.SMTP.return_value
          self.assertEqual(smtp.sendmail.call_count, 0)

          res4 = cur.execute("SELECT mxid FROM local_threepid_associations")
          list4 = res4.fetchall()
          self.assertEqual(list3, list4)
  287. self.assertEqual(list3, list4)