test_casefold_migration.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345
  1. # Copyright 2021 Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. from unittest.mock import patch
  16. from twisted.trial import unittest
  17. from scripts.casefold_db import (
  18. calculate_lookup_hash,
  19. update_global_associations,
  20. update_local_associations,
  21. )
  22. from sydent.util import json_decoder
  23. from sydent.util.emailutils import sendEmail
  24. from tests.utils import make_sydent
  25. class MigrationTestCase(unittest.TestCase):
  26. def create_signedassoc(self, medium, address, mxid, ts, not_before, not_after):
  27. return {
  28. "medium": medium,
  29. "address": address,
  30. "mxid": mxid,
  31. "ts": ts,
  32. "not_before": not_before,
  33. "not_after": not_after,
  34. }
  35. def setUp(self):
  36. # Create a new sydent
  37. self.sydent = make_sydent()
  38. # create some local associations
  39. associations = []
  40. for i in range(10):
  41. address = "bob%d@example.com" % i
  42. associations.append(
  43. {
  44. "medium": "email",
  45. "address": address,
  46. "lookup_hash": calculate_lookup_hash(self.sydent, address),
  47. "mxid": "@bob%d:example.com" % i,
  48. "ts": (i * 10000),
  49. "not_before": 0,
  50. "not_after": 99999999999,
  51. }
  52. )
  53. # create some casefold-conflicting associations
  54. for i in range(5):
  55. address = "BOB%d@example.com" % i
  56. associations.append(
  57. {
  58. "medium": "email",
  59. "address": address,
  60. "lookup_hash": calculate_lookup_hash(self.sydent, address),
  61. "mxid": "@otherbob%d:example.com" % i,
  62. "ts": (i * 10000),
  63. "not_before": 0,
  64. "not_after": 99999999999,
  65. }
  66. )
  67. associations.append(
  68. {
  69. "medium": "email",
  70. "address": "BoB4@example.com",
  71. "lookup_hash": calculate_lookup_hash(self.sydent, "BoB4@example.com"),
  72. "mxid": "@otherbob4:example.com",
  73. "ts": 42000,
  74. "not_before": 0,
  75. "not_after": 99999999999,
  76. }
  77. )
  78. # add all associations to db
  79. cur = self.sydent.db.cursor()
  80. cur.executemany(
  81. "INSERT INTO local_threepid_associations "
  82. "(medium, address, lookup_hash, mxid, ts, notBefore, notAfter) "
  83. "VALUES (?, ?, ?, ?, ?, ?, ?)",
  84. [
  85. (
  86. assoc["medium"],
  87. assoc["address"],
  88. assoc["lookup_hash"],
  89. assoc["mxid"],
  90. assoc["ts"],
  91. assoc["not_before"],
  92. assoc["not_after"],
  93. )
  94. for assoc in associations
  95. ],
  96. )
  97. self.sydent.db.commit()
  98. # create some global associations
  99. associations = []
  100. originServer = self.sydent.config.general.server_name
  101. for i in range(10):
  102. address = "bob%d@example.com" % i
  103. mxid = "@bob%d:example.com" % i
  104. ts = 10000 * i
  105. associations.append(
  106. {
  107. "medium": "email",
  108. "address": address,
  109. "lookup_hash": calculate_lookup_hash(self.sydent, address),
  110. "mxid": mxid,
  111. "ts": ts,
  112. "not_before": 0,
  113. "not_after": 99999999999,
  114. "originServer": originServer,
  115. "originId": i,
  116. "sgAssoc": json.dumps(
  117. self.create_signedassoc(
  118. "email", address, mxid, ts, 0, 99999999999
  119. )
  120. ),
  121. }
  122. )
  123. # create some casefold-conflicting associations
  124. for i in range(5):
  125. address = "BOB%d@example.com" % i
  126. mxid = "@BOB%d:example.com" % i
  127. ts = 10000 * i
  128. associations.append(
  129. {
  130. "medium": "email",
  131. "address": address,
  132. "lookup_hash": calculate_lookup_hash(self.sydent, address),
  133. "mxid": mxid,
  134. "ts": ts + 1,
  135. "not_before": 0,
  136. "not_after": 99999999999,
  137. "originServer": originServer,
  138. "originId": i + 10,
  139. "sgAssoc": json.dumps(
  140. self.create_signedassoc(
  141. "email", address, mxid, ts, 0, 99999999999
  142. )
  143. ),
  144. }
  145. )
  146. # add all associations to db
  147. cur = self.sydent.db.cursor()
  148. cur.executemany(
  149. "INSERT INTO global_threepid_associations "
  150. "(medium, address, lookup_hash, mxid, ts, notBefore, notAfter, originServer, originId, sgAssoc) "
  151. "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
  152. [
  153. (
  154. assoc["medium"],
  155. assoc["address"],
  156. assoc["lookup_hash"],
  157. assoc["mxid"],
  158. assoc["ts"],
  159. assoc["not_before"],
  160. assoc["not_after"],
  161. assoc["originServer"],
  162. assoc["originId"],
  163. assoc["sgAssoc"],
  164. )
  165. for assoc in associations
  166. ],
  167. )
  168. self.sydent.db.commit()
  169. def test_migration_email(self):
  170. with patch("sydent.util.emailutils.smtplib") as smtplib:
  171. # self.sydent.config.email.template is deprecated
  172. if self.sydent.config.email.template is None:
  173. templateFile = self.sydent.get_branded_template(
  174. None,
  175. "migration_template.eml",
  176. )
  177. else:
  178. templateFile = self.sydent.config.email.template
  179. sendEmail(
  180. self.sydent,
  181. templateFile,
  182. "bob@example.com",
  183. {
  184. "mxid": "@bob:example.com",
  185. "subject_header_value": "MatrixID Deletion",
  186. },
  187. )
  188. smtp = smtplib.SMTP.return_value
  189. email_contents = smtp.sendmail.call_args[0][2].decode("utf-8")
  190. self.assertIn("In the past", email_contents)
  191. # test email was sent
  192. smtp.sendmail.assert_called()
  193. def test_local_db_migration(self):
  194. with patch("sydent.util.emailutils.smtplib") as smtplib:
  195. update_local_associations(
  196. self.sydent,
  197. self.sydent.db,
  198. send_email=True,
  199. dry_run=False,
  200. test=True,
  201. )
  202. # test 5 emails were sent
  203. smtp = smtplib.SMTP.return_value
  204. self.assertEqual(smtp.sendmail.call_count, 5)
  205. # don't send emails to people who weren't affected
  206. self.assertNotIn(
  207. smtp.sendmail.call_args_list,
  208. [
  209. "bob5@example.com",
  210. "bob6@example.com",
  211. "bob7@example.com",
  212. "bob8@example.com",
  213. "bob9@example.com",
  214. ],
  215. )
  216. # make sure someone who is affected gets email
  217. self.assertIn("bob4@example.com", smtp.sendmail.call_args_list[0][0])
  218. cur = self.sydent.db.cursor()
  219. res = cur.execute("SELECT * FROM local_threepid_associations")
  220. db_state = res.fetchall()
  221. # five addresses should have been deleted
  222. self.assertEqual(len(db_state), 10)
  223. # iterate through db and make sure all addresses are casefolded and hash matches casefolded address
  224. for row in db_state:
  225. casefolded = row[2].casefold()
  226. self.assertEqual(row[2], casefolded)
  227. self.assertEqual(
  228. calculate_lookup_hash(self.sydent, row[2]),
  229. calculate_lookup_hash(self.sydent, casefolded),
  230. )
  231. def test_global_db_migration(self):
  232. update_global_associations(
  233. self.sydent,
  234. self.sydent.db,
  235. dry_run=False,
  236. )
  237. cur = self.sydent.db.cursor()
  238. res = cur.execute("SELECT * FROM global_threepid_associations")
  239. db_state = res.fetchall()
  240. # five addresses should have been deleted
  241. self.assertEqual(len(db_state), 10)
  242. # iterate through db and make sure all addresses are casefolded and hash matches casefolded address
  243. # and make sure the casefolded address matches the address in sgAssoc
  244. for row in db_state:
  245. casefolded = row[2].casefold()
  246. self.assertEqual(row[2], casefolded)
  247. self.assertEqual(
  248. calculate_lookup_hash(self.sydent, row[2]),
  249. calculate_lookup_hash(self.sydent, casefolded),
  250. )
  251. sgassoc = json_decoder.decode(row[9])
  252. self.assertEqual(row[2], sgassoc["address"])
  253. def test_local_no_email_does_not_send_email(self):
  254. with patch("sydent.util.emailutils.smtplib") as smtplib:
  255. update_local_associations(
  256. self.sydent,
  257. self.sydent.db,
  258. send_email=False,
  259. dry_run=False,
  260. test=True,
  261. )
  262. smtp = smtplib.SMTP.return_value
  263. # test no emails were sent
  264. self.assertEqual(smtp.sendmail.call_count, 0)
  265. def test_dry_run_does_nothing(self):
  266. # reset DB
  267. self.setUp()
  268. cur = self.sydent.db.cursor()
  269. # grab a snapshot of global table before running script
  270. res1 = cur.execute("SELECT mxid FROM global_threepid_associations")
  271. list1 = res1.fetchall()
  272. with patch("sydent.util.emailutils.smtplib") as smtplib:
  273. update_global_associations(
  274. self.sydent,
  275. self.sydent.db,
  276. dry_run=True,
  277. )
  278. # test no emails were sent
  279. smtp = smtplib.SMTP.return_value
  280. self.assertEqual(smtp.sendmail.call_count, 0)
  281. res2 = cur.execute("SELECT mxid FROM global_threepid_associations")
  282. list2 = res2.fetchall()
  283. self.assertEqual(list1, list2)
  284. # grab a snapshot of local table db before running script
  285. res3 = cur.execute("SELECT mxid FROM local_threepid_associations")
  286. list3 = res3.fetchall()
  287. with patch("sydent.util.emailutils.smtplib") as smtplib:
  288. update_local_associations(
  289. self.sydent,
  290. self.sydent.db,
  291. send_email=True,
  292. dry_run=True,
  293. test=True,
  294. )
  295. # test no emails were sent
  296. smtp = smtplib.SMTP.return_value
  297. self.assertEqual(smtp.sendmail.call_count, 0)
  298. res4 = cur.execute("SELECT mxid FROM local_threepid_associations")
  299. list4 = res4.fetchall()
  300. self.assertEqual(list3, list4)