test_casefold_migration.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. import json
  2. import os.path
  3. from unittest.mock import patch
  4. from twisted.trial import unittest
  5. from scripts.casefold_db import (
  6. calculate_lookup_hash,
  7. update_global_associations,
  8. update_local_associations,
  9. )
  10. from sydent.util import json_decoder
  11. from sydent.util.emailutils import sendEmail
  12. from tests.utils import make_sydent
  13. class MigrationTestCase(unittest.TestCase):
  14. def create_signedassoc(self, medium, address, mxid, ts, not_before, not_after):
  15. return {
  16. "medium": medium,
  17. "address": address,
  18. "mxid": mxid,
  19. "ts": ts,
  20. "not_before": not_before,
  21. "not_after": not_after,
  22. }
  23. def setUp(self):
  24. # Create a new sydent
  25. config = {
  26. "general": {
  27. "templates.path": os.path.join(
  28. os.path.dirname(os.path.dirname(__file__)), "res"
  29. ),
  30. },
  31. "crypto": {
  32. "ed25519.signingkey": "ed25519 0 FJi1Rnpj3/otydngacrwddFvwz/dTDsBv62uZDN2fZM"
  33. },
  34. }
  35. self.sydent = make_sydent(test_config=config)
  36. # create some local associations
  37. associations = []
  38. for i in range(10):
  39. address = "bob%d@example.com" % i
  40. associations.append(
  41. {
  42. "medium": "email",
  43. "address": address,
  44. "lookup_hash": calculate_lookup_hash(self.sydent, address),
  45. "mxid": "@bob%d:example.com" % i,
  46. "ts": (i * 10000),
  47. "not_before": 0,
  48. "not_after": 99999999999,
  49. }
  50. )
  51. # create some casefold-conflicting associations
  52. for i in range(5):
  53. address = "BOB%d@example.com" % i
  54. associations.append(
  55. {
  56. "medium": "email",
  57. "address": address,
  58. "lookup_hash": calculate_lookup_hash(self.sydent, address),
  59. "mxid": "@otherbob%d:example.com" % i,
  60. "ts": (i * 10000),
  61. "not_before": 0,
  62. "not_after": 99999999999,
  63. }
  64. )
  65. associations.append(
  66. {
  67. "medium": "email",
  68. "address": "BoB4@example.com",
  69. "lookup_hash": calculate_lookup_hash(self.sydent, "BoB4@example.com"),
  70. "mxid": "@otherbob4:example.com",
  71. "ts": 42000,
  72. "not_before": 0,
  73. "not_after": 99999999999,
  74. }
  75. )
  76. # add all associations to db
  77. cur = self.sydent.db.cursor()
  78. cur.executemany(
  79. "INSERT INTO local_threepid_associations "
  80. "(medium, address, lookup_hash, mxid, ts, notBefore, notAfter) "
  81. "VALUES (?, ?, ?, ?, ?, ?, ?)",
  82. [
  83. (
  84. assoc["medium"],
  85. assoc["address"],
  86. assoc["lookup_hash"],
  87. assoc["mxid"],
  88. assoc["ts"],
  89. assoc["not_before"],
  90. assoc["not_after"],
  91. )
  92. for assoc in associations
  93. ],
  94. )
  95. self.sydent.db.commit()
  96. # create some global associations
  97. associations = []
  98. originServer = self.sydent.config.general.server_name
  99. for i in range(10):
  100. address = "bob%d@example.com" % i
  101. mxid = "@bob%d:example.com" % i
  102. ts = 10000 * i
  103. associations.append(
  104. {
  105. "medium": "email",
  106. "address": address,
  107. "lookup_hash": calculate_lookup_hash(self.sydent, address),
  108. "mxid": mxid,
  109. "ts": ts,
  110. "not_before": 0,
  111. "not_after": 99999999999,
  112. "originServer": originServer,
  113. "originId": i,
  114. "sgAssoc": json.dumps(
  115. self.create_signedassoc(
  116. "email", address, mxid, ts, 0, 99999999999
  117. )
  118. ),
  119. }
  120. )
  121. # create some casefold-conflicting associations
  122. for i in range(5):
  123. address = "BOB%d@example.com" % i
  124. mxid = "@BOB%d:example.com" % i
  125. ts = 10000 * i
  126. associations.append(
  127. {
  128. "medium": "email",
  129. "address": address,
  130. "lookup_hash": calculate_lookup_hash(self.sydent, address),
  131. "mxid": mxid,
  132. "ts": ts + 1,
  133. "not_before": 0,
  134. "not_after": 99999999999,
  135. "originServer": originServer,
  136. "originId": i + 10,
  137. "sgAssoc": json.dumps(
  138. self.create_signedassoc(
  139. "email", address, mxid, ts, 0, 99999999999
  140. )
  141. ),
  142. }
  143. )
  144. # add all associations to db
  145. cur = self.sydent.db.cursor()
  146. cur.executemany(
  147. "INSERT INTO global_threepid_associations "
  148. "(medium, address, lookup_hash, mxid, ts, notBefore, notAfter, originServer, originId, sgAssoc) "
  149. "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
  150. [
  151. (
  152. assoc["medium"],
  153. assoc["address"],
  154. assoc["lookup_hash"],
  155. assoc["mxid"],
  156. assoc["ts"],
  157. assoc["not_before"],
  158. assoc["not_after"],
  159. assoc["originServer"],
  160. assoc["originId"],
  161. assoc["sgAssoc"],
  162. )
  163. for assoc in associations
  164. ],
  165. )
  166. self.sydent.db.commit()
  167. def test_migration_email(self):
  168. with patch("sydent.util.emailutils.smtplib") as smtplib:
  169. # self.sydent.config.email.template is deprecated
  170. if self.sydent.config.email.template is None:
  171. templateFile = self.sydent.get_branded_template(
  172. None,
  173. "migration_template.eml",
  174. )
  175. else:
  176. templateFile = self.sydent.config.email.template
  177. sendEmail(
  178. self.sydent,
  179. templateFile,
  180. "bob@example.com",
  181. {
  182. "mxid": "@bob:example.com",
  183. "subject_header_value": "MatrixID Deletion",
  184. },
  185. )
  186. smtp = smtplib.SMTP.return_value
  187. email_contents = smtp.sendmail.call_args[0][2].decode("utf-8")
  188. self.assertIn("In the past", email_contents)
  189. # test email was sent
  190. smtp.sendmail.assert_called()
  191. def test_local_db_migration(self):
  192. with patch("sydent.util.emailutils.smtplib") as smtplib:
  193. update_local_associations(
  194. self.sydent,
  195. self.sydent.db,
  196. send_email=True,
  197. dry_run=False,
  198. test=True,
  199. )
  200. # test 5 emails were sent
  201. smtp = smtplib.SMTP.return_value
  202. self.assertEqual(smtp.sendmail.call_count, 5)
  203. # don't send emails to people who weren't affected
  204. self.assertNotIn(
  205. smtp.sendmail.call_args_list,
  206. [
  207. "bob5@example.com",
  208. "bob6@example.com",
  209. "bob7@example.com",
  210. "bob8@example.com",
  211. "bob9@example.com",
  212. ],
  213. )
  214. # make sure someone who is affected gets email
  215. self.assertIn("bob4@example.com", smtp.sendmail.call_args_list[0][0])
  216. cur = self.sydent.db.cursor()
  217. res = cur.execute("SELECT * FROM local_threepid_associations")
  218. db_state = res.fetchall()
  219. # five addresses should have been deleted
  220. self.assertEqual(len(db_state), 10)
  221. # iterate through db and make sure all addresses are casefolded and hash matches casefolded address
  222. for row in db_state:
  223. casefolded = row[2].casefold()
  224. self.assertEqual(row[2], casefolded)
  225. self.assertEqual(
  226. calculate_lookup_hash(self.sydent, row[2]),
  227. calculate_lookup_hash(self.sydent, casefolded),
  228. )
  229. def test_global_db_migration(self):
  230. update_global_associations(
  231. self.sydent,
  232. self.sydent.db,
  233. dry_run=False,
  234. )
  235. cur = self.sydent.db.cursor()
  236. res = cur.execute("SELECT * FROM global_threepid_associations")
  237. db_state = res.fetchall()
  238. # five addresses should have been deleted
  239. self.assertEqual(len(db_state), 10)
  240. # iterate through db and make sure all addresses are casefolded and hash matches casefolded address
  241. # and make sure the casefolded address matches the address in sgAssoc
  242. for row in db_state:
  243. casefolded = row[2].casefold()
  244. self.assertEqual(row[2], casefolded)
  245. self.assertEqual(
  246. calculate_lookup_hash(self.sydent, row[2]),
  247. calculate_lookup_hash(self.sydent, casefolded),
  248. )
  249. sgassoc = json_decoder.decode(row[9])
  250. self.assertEqual(row[2], sgassoc["address"])
  251. def test_local_no_email_does_not_send_email(self):
  252. with patch("sydent.util.emailutils.smtplib") as smtplib:
  253. update_local_associations(
  254. self.sydent,
  255. self.sydent.db,
  256. send_email=False,
  257. dry_run=False,
  258. test=True,
  259. )
  260. smtp = smtplib.SMTP.return_value
  261. # test no emails were sent
  262. self.assertEqual(smtp.sendmail.call_count, 0)
  263. def test_dry_run_does_nothing(self):
  264. # reset DB
  265. self.setUp()
  266. cur = self.sydent.db.cursor()
  267. # grab a snapshot of global table before running script
  268. res1 = cur.execute("SELECT mxid FROM global_threepid_associations")
  269. list1 = res1.fetchall()
  270. with patch("sydent.util.emailutils.smtplib") as smtplib:
  271. update_global_associations(
  272. self.sydent,
  273. self.sydent.db,
  274. dry_run=True,
  275. )
  276. # test no emails were sent
  277. smtp = smtplib.SMTP.return_value
  278. self.assertEqual(smtp.sendmail.call_count, 0)
  279. res2 = cur.execute("SELECT mxid FROM global_threepid_associations")
  280. list2 = res2.fetchall()
  281. self.assertEqual(list1, list2)
  282. # grab a snapshot of local table db before running script
  283. res3 = cur.execute("SELECT mxid FROM local_threepid_associations")
  284. list3 = res3.fetchall()
  285. with patch("sydent.util.emailutils.smtplib") as smtplib:
  286. update_local_associations(
  287. self.sydent,
  288. self.sydent.db,
  289. send_email=True,
  290. dry_run=True,
  291. test=True,
  292. )
  293. # test no emails were sent
  294. smtp = smtplib.SMTP.return_value
  295. self.assertEqual(smtp.sendmail.call_count, 0)
  296. res4 = cur.execute("SELECT mxid FROM local_threepid_associations")
  297. list4 = res4.fetchall()
  298. self.assertEqual(list3, list4)