port_from_sqlite_to_postgres.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2015 OpenMarket Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from twisted.internet import defer, reactor
  17. from twisted.enterprise import adbapi
  18. from synapse.storage._base import LoggingTransaction, SQLBaseStore
  19. from synapse.storage.engines import create_engine
  20. import argparse
  21. import curses
  22. import logging
  23. import sys
  24. import time
  25. import traceback
  26. import yaml
  27. logger = logging.getLogger("port_from_sqlite_to_postgres")
  28. BOOLEAN_COLUMNS = {
  29. "events": ["processed", "outlier"],
  30. "rooms": ["is_public"],
  31. "event_edges": ["is_state"],
  32. "presence_list": ["accepted"],
  33. }
  34. APPEND_ONLY_TABLES = [
  35. "event_content_hashes",
  36. "event_reference_hashes",
  37. "event_signatures",
  38. "event_edge_hashes",
  39. "events",
  40. "event_json",
  41. "state_events",
  42. "room_memberships",
  43. "feedback",
  44. "topics",
  45. "room_names",
  46. "rooms",
  47. "local_media_repository",
  48. "local_media_repository_thumbnails",
  49. "remote_media_cache",
  50. "remote_media_cache_thumbnails",
  51. "redactions",
  52. "event_edges",
  53. "event_auth",
  54. "received_transactions",
  55. "sent_transactions",
  56. "transaction_id_to_pdu",
  57. "users",
  58. "state_groups",
  59. "state_groups_state",
  60. "event_to_state_groups",
  61. "rejections",
  62. ]
  63. end_error_exec_info = None
  64. class Store(object):
  65. """This object is used to pull out some of the convenience API from the
  66. Storage layer.
  67. *All* database interactions should go through this object.
  68. """
  69. def __init__(self, db_pool, engine):
  70. self.db_pool = db_pool
  71. self.database_engine = engine
  72. _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
  73. _simple_insert = SQLBaseStore.__dict__["_simple_insert"]
  74. _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
  75. _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
  76. _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
  77. _simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
  78. _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
  79. _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
  80. _execute_and_decode = SQLBaseStore.__dict__["_execute_and_decode"]
  81. def runInteraction(self, desc, func, *args, **kwargs):
  82. def r(conn):
  83. try:
  84. i = 0
  85. N = 5
  86. while True:
  87. try:
  88. txn = conn.cursor()
  89. return func(
  90. LoggingTransaction(txn, desc, self.database_engine, []),
  91. *args, **kwargs
  92. )
  93. except self.database_engine.module.DatabaseError as e:
  94. if self.database_engine.is_deadlock(e):
  95. logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
  96. if i < N:
  97. i += 1
  98. conn.rollback()
  99. continue
  100. raise
  101. except Exception as e:
  102. logger.debug("[TXN FAIL] {%s} %s", desc, e)
  103. raise
  104. return self.db_pool.runWithConnection(r)
  105. def execute(self, f, *args, **kwargs):
  106. return self.runInteraction(f.__name__, f, *args, **kwargs)
  107. def execute_sql(self, sql, *args):
  108. def r(txn):
  109. txn.execute(sql, args)
  110. return txn.fetchall()
  111. return self.runInteraction("execute_sql", r)
  112. def insert_many_txn(self, txn, table, headers, rows):
  113. sql = "INSERT INTO %s (%s) VALUES (%s)" % (
  114. table,
  115. ", ".join(k for k in headers),
  116. ", ".join("%s" for _ in headers)
  117. )
  118. try:
  119. txn.executemany(sql, rows)
  120. except:
  121. logger.exception(
  122. "Failed to insert: %s",
  123. table,
  124. )
  125. raise
  126. class Porter(object):
  127. def __init__(self, **kwargs):
  128. self.__dict__.update(kwargs)
  129. @defer.inlineCallbacks
  130. def setup_table(self, table):
  131. if table in APPEND_ONLY_TABLES:
  132. # It's safe to just carry on inserting.
  133. next_chunk = yield self.postgres_store._simple_select_one_onecol(
  134. table="port_from_sqlite3",
  135. keyvalues={"table_name": table},
  136. retcol="rowid",
  137. allow_none=True,
  138. )
  139. total_to_port = None
  140. if next_chunk is None:
  141. if table == "sent_transactions":
  142. next_chunk, already_ported, total_to_port = (
  143. yield self._setup_sent_transactions()
  144. )
  145. else:
  146. yield self.postgres_store._simple_insert(
  147. table="port_from_sqlite3",
  148. values={"table_name": table, "rowid": 1}
  149. )
  150. next_chunk = 1
  151. already_ported = 0
  152. if total_to_port is None:
  153. already_ported, total_to_port = yield self._get_total_count_to_port(
  154. table, next_chunk
  155. )
  156. else:
  157. def delete_all(txn):
  158. txn.execute(
  159. "DELETE FROM port_from_sqlite3 WHERE table_name = %s",
  160. (table,)
  161. )
  162. txn.execute("TRUNCATE %s CASCADE" % (table,))
  163. yield self.postgres_store.execute(delete_all)
  164. yield self.postgres_store._simple_insert(
  165. table="port_from_sqlite3",
  166. values={"table_name": table, "rowid": 0}
  167. )
  168. next_chunk = 1
  169. already_ported, total_to_port = yield self._get_total_count_to_port(
  170. table, next_chunk
  171. )
  172. defer.returnValue((table, already_ported, total_to_port, next_chunk))
  173. @defer.inlineCallbacks
  174. def handle_table(self, table, postgres_size, table_size, next_chunk):
  175. if not table_size:
  176. return
  177. self.progress.add_table(table, postgres_size, table_size)
  178. select = (
  179. "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
  180. % (table,)
  181. )
  182. while True:
  183. def r(txn):
  184. txn.execute(select, (next_chunk, self.batch_size,))
  185. rows = txn.fetchall()
  186. headers = [column[0] for column in txn.description]
  187. return headers, rows
  188. headers, rows = yield self.sqlite_store.runInteraction("select", r)
  189. if rows:
  190. next_chunk = rows[-1][0] + 1
  191. self._convert_rows(table, headers, rows)
  192. def insert(txn):
  193. self.postgres_store.insert_many_txn(
  194. txn, table, headers[1:], rows
  195. )
  196. self.postgres_store._simple_update_one_txn(
  197. txn,
  198. table="port_from_sqlite3",
  199. keyvalues={"table_name": table},
  200. updatevalues={"rowid": next_chunk},
  201. )
  202. yield self.postgres_store.execute(insert)
  203. postgres_size += len(rows)
  204. self.progress.update(table, postgres_size)
  205. else:
  206. return
  207. def setup_db(self, db_config, database_engine):
  208. db_conn = database_engine.module.connect(
  209. **{
  210. k: v for k, v in db_config.get("args", {}).items()
  211. if not k.startswith("cp_")
  212. }
  213. )
  214. database_engine.prepare_database(db_conn)
  215. db_conn.commit()
  216. @defer.inlineCallbacks
  217. def run(self):
  218. try:
  219. sqlite_db_pool = adbapi.ConnectionPool(
  220. self.sqlite_config["name"],
  221. **self.sqlite_config["args"]
  222. )
  223. postgres_db_pool = adbapi.ConnectionPool(
  224. self.postgres_config["name"],
  225. **self.postgres_config["args"]
  226. )
  227. sqlite_engine = create_engine("sqlite3")
  228. postgres_engine = create_engine("psycopg2")
  229. self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
  230. self.postgres_store = Store(postgres_db_pool, postgres_engine)
  231. yield self.postgres_store.execute(
  232. postgres_engine.check_database
  233. )
  234. # Step 1. Set up databases.
  235. self.progress.set_state("Preparing SQLite3")
  236. self.setup_db(sqlite_config, sqlite_engine)
  237. self.progress.set_state("Preparing PostgreSQL")
  238. self.setup_db(postgres_config, postgres_engine)
  239. # Step 2. Get tables.
  240. self.progress.set_state("Fetching tables")
  241. sqlite_tables = yield self.sqlite_store._simple_select_onecol(
  242. table="sqlite_master",
  243. keyvalues={
  244. "type": "table",
  245. },
  246. retcol="name",
  247. )
  248. postgres_tables = yield self.postgres_store._simple_select_onecol(
  249. table="information_schema.tables",
  250. keyvalues={
  251. "table_schema": "public",
  252. },
  253. retcol="distinct table_name",
  254. )
  255. tables = set(sqlite_tables) & set(postgres_tables)
  256. self.progress.set_state("Creating tables")
  257. logger.info("Found %d tables", len(tables))
  258. def create_port_table(txn):
  259. txn.execute(
  260. "CREATE TABLE port_from_sqlite3 ("
  261. " table_name varchar(100) NOT NULL UNIQUE,"
  262. " rowid bigint NOT NULL"
  263. ")"
  264. )
  265. try:
  266. yield self.postgres_store.runInteraction(
  267. "create_port_table", create_port_table
  268. )
  269. except Exception as e:
  270. logger.info("Failed to create port table: %s", e)
  271. self.progress.set_state("Setting up")
  272. # Set up tables.
  273. setup_res = yield defer.gatherResults(
  274. [
  275. self.setup_table(table)
  276. for table in tables
  277. if table not in ["schema_version", "applied_schema_deltas"]
  278. and not table.startswith("sqlite_")
  279. ],
  280. consumeErrors=True,
  281. )
  282. # Process tables.
  283. yield defer.gatherResults(
  284. [
  285. self.handle_table(*res)
  286. for res in setup_res
  287. ],
  288. consumeErrors=True,
  289. )
  290. self.progress.done()
  291. except:
  292. global end_error_exec_info
  293. end_error_exec_info = sys.exc_info()
  294. logger.exception("")
  295. finally:
  296. reactor.stop()
  297. def _convert_rows(self, table, headers, rows):
  298. bool_col_names = BOOLEAN_COLUMNS.get(table, [])
  299. bool_cols = [
  300. i for i, h in enumerate(headers) if h in bool_col_names
  301. ]
  302. def conv(j, col):
  303. if j in bool_cols:
  304. return bool(col)
  305. return col
  306. for i, row in enumerate(rows):
  307. rows[i] = tuple(
  308. conv(j, col)
  309. for j, col in enumerate(row)
  310. if j > 0
  311. )
  312. @defer.inlineCallbacks
  313. def _setup_sent_transactions(self):
  314. # Only save things from the last day
  315. yesterday = int(time.time()*1000) - 86400000
  316. # And save the max transaction id from each destination
  317. select = (
  318. "SELECT rowid, * FROM sent_transactions WHERE rowid IN ("
  319. "SELECT max(rowid) FROM sent_transactions"
  320. " GROUP BY destination"
  321. ")"
  322. )
  323. def r(txn):
  324. txn.execute(select)
  325. rows = txn.fetchall()
  326. headers = [column[0] for column in txn.description]
  327. ts_ind = headers.index('ts')
  328. return headers, [r for r in rows if r[ts_ind] < yesterday]
  329. headers, rows = yield self.sqlite_store.runInteraction(
  330. "select", r,
  331. )
  332. self._convert_rows("sent_transactions", headers, rows)
  333. inserted_rows = len(rows)
  334. max_inserted_rowid = max(r[0] for r in rows)
  335. def insert(txn):
  336. self.postgres_store.insert_many_txn(
  337. txn, "sent_transactions", headers[1:], rows
  338. )
  339. yield self.postgres_store.execute(insert)
  340. def get_start_id(txn):
  341. txn.execute(
  342. "SELECT rowid FROM sent_transactions WHERE ts >= ?"
  343. " ORDER BY rowid ASC LIMIT 1",
  344. (yesterday,)
  345. )
  346. rows = txn.fetchall()
  347. if rows:
  348. return rows[0][0]
  349. else:
  350. return 1
  351. next_chunk = yield self.sqlite_store.execute(get_start_id)
  352. next_chunk = max(max_inserted_rowid + 1, next_chunk)
  353. yield self.postgres_store._simple_insert(
  354. table="port_from_sqlite3",
  355. values={"table_name": "sent_transactions", "rowid": next_chunk}
  356. )
  357. def get_sent_table_size(txn):
  358. txn.execute(
  359. "SELECT count(*) FROM sent_transactions"
  360. " WHERE ts >= ?",
  361. (yesterday,)
  362. )
  363. size, = txn.fetchone()
  364. return int(size)
  365. remaining_count = yield self.sqlite_store.execute(
  366. get_sent_table_size
  367. )
  368. total_count = remaining_count + inserted_rows
  369. defer.returnValue((next_chunk, inserted_rows, total_count))
  370. @defer.inlineCallbacks
  371. def _get_remaining_count_to_port(self, table, next_chunk):
  372. rows = yield self.sqlite_store.execute_sql(
  373. "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
  374. next_chunk,
  375. )
  376. defer.returnValue(rows[0][0])
  377. @defer.inlineCallbacks
  378. def _get_already_ported_count(self, table):
  379. rows = yield self.postgres_store.execute_sql(
  380. "SELECT count(*) FROM %s" % (table,),
  381. )
  382. defer.returnValue(rows[0][0])
  383. @defer.inlineCallbacks
  384. def _get_total_count_to_port(self, table, next_chunk):
  385. remaining, done = yield defer.gatherResults(
  386. [
  387. self._get_remaining_count_to_port(table, next_chunk),
  388. self._get_already_ported_count(table),
  389. ],
  390. consumeErrors=True,
  391. )
  392. remaining = int(remaining) if remaining else 0
  393. done = int(done) if done else 0
  394. defer.returnValue((done, remaining + done))
  395. ##############################################
  396. ###### The following is simply UI stuff ######
  397. ##############################################
  398. class Progress(object):
  399. """Used to report progress of the port
  400. """
  401. def __init__(self):
  402. self.tables = {}
  403. self.start_time = int(time.time())
  404. def add_table(self, table, cur, size):
  405. self.tables[table] = {
  406. "start": cur,
  407. "num_done": cur,
  408. "total": size,
  409. "perc": int(cur * 100 / size),
  410. }
  411. def update(self, table, num_done):
  412. data = self.tables[table]
  413. data["num_done"] = num_done
  414. data["perc"] = int(num_done * 100 / data["total"])
  415. def done(self):
  416. pass
  417. class CursesProgress(Progress):
  418. """Reports progress to a curses window
  419. """
  420. def __init__(self, stdscr):
  421. self.stdscr = stdscr
  422. curses.use_default_colors()
  423. curses.curs_set(0)
  424. curses.init_pair(1, curses.COLOR_RED, -1)
  425. curses.init_pair(2, curses.COLOR_GREEN, -1)
  426. self.last_update = 0
  427. self.finished = False
  428. self.total_processed = 0
  429. self.total_remaining = 0
  430. super(CursesProgress, self).__init__()
  431. def update(self, table, num_done):
  432. super(CursesProgress, self).update(table, num_done)
  433. self.total_processed = 0
  434. self.total_remaining = 0
  435. for table, data in self.tables.items():
  436. self.total_processed += data["num_done"] - data["start"]
  437. self.total_remaining += data["total"] - data["num_done"]
  438. self.render()
  439. def render(self, force=False):
  440. now = time.time()
  441. if not force and now - self.last_update < 0.2:
  442. # reactor.callLater(1, self.render)
  443. return
  444. self.stdscr.clear()
  445. rows, cols = self.stdscr.getmaxyx()
  446. duration = int(now) - int(self.start_time)
  447. minutes, seconds = divmod(duration, 60)
  448. duration_str = '%02dm %02ds' % (minutes, seconds,)
  449. if self.finished:
  450. status = "Time spent: %s (Done!)" % (duration_str,)
  451. else:
  452. if self.total_processed > 0:
  453. left = float(self.total_remaining) / self.total_processed
  454. est_remaining = (int(now) - self.start_time) * left
  455. est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
  456. else:
  457. est_remaining_str = "Unknown"
  458. status = (
  459. "Time spent: %s (est. remaining: %s)"
  460. % (duration_str, est_remaining_str,)
  461. )
  462. self.stdscr.addstr(
  463. 0, 0,
  464. status,
  465. curses.A_BOLD,
  466. )
  467. max_len = max([len(t) for t in self.tables.keys()])
  468. left_margin = 5
  469. middle_space = 1
  470. items = self.tables.items()
  471. items.sort(
  472. key=lambda i: (i[1]["perc"], i[0]),
  473. )
  474. for i, (table, data) in enumerate(items):
  475. if i + 2 >= rows:
  476. break
  477. perc = data["perc"]
  478. color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
  479. self.stdscr.addstr(
  480. i+2, left_margin + max_len - len(table),
  481. table,
  482. curses.A_BOLD | color,
  483. )
  484. size = 20
  485. progress = "[%s%s]" % (
  486. "#" * int(perc*size/100),
  487. " " * (size - int(perc*size/100)),
  488. )
  489. self.stdscr.addstr(
  490. i+2, left_margin + max_len + middle_space,
  491. "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
  492. )
  493. if self.finished:
  494. self.stdscr.addstr(
  495. rows-1, 0,
  496. "Press any key to exit...",
  497. )
  498. self.stdscr.refresh()
  499. self.last_update = time.time()
  500. def done(self):
  501. self.finished = True
  502. self.render(True)
  503. self.stdscr.getch()
  504. def set_state(self, state):
  505. self.stdscr.clear()
  506. self.stdscr.addstr(
  507. 0, 0,
  508. state + "...",
  509. curses.A_BOLD,
  510. )
  511. self.stdscr.refresh()
  512. class TerminalProgress(Progress):
  513. """Just prints progress to the terminal
  514. """
  515. def update(self, table, num_done):
  516. super(TerminalProgress, self).update(table, num_done)
  517. data = self.tables[table]
  518. print "%s: %d%% (%d/%d)" % (
  519. table, data["perc"],
  520. data["num_done"], data["total"],
  521. )
  522. def set_state(self, state):
  523. print state + "..."
  524. ##############################################
  525. ##############################################
  526. if __name__ == "__main__":
  527. parser = argparse.ArgumentParser(
  528. description="A script to port an existing synapse SQLite database to"
  529. " a new PostgreSQL database."
  530. )
  531. parser.add_argument("-v", action='store_true')
  532. parser.add_argument(
  533. "--sqlite-database", required=True,
  534. help="The snapshot of the SQLite database file. This must not be"
  535. " currently used by a running synapse server"
  536. )
  537. parser.add_argument(
  538. "--postgres-config", type=argparse.FileType('r'), required=True,
  539. help="The database config file for the PostgreSQL database"
  540. )
  541. parser.add_argument(
  542. "--curses", action='store_true',
  543. help="display a curses based progress UI"
  544. )
  545. parser.add_argument(
  546. "--batch-size", type=int, default=1000,
  547. help="The number of rows to select from the SQLite table each"
  548. " iteration [default=1000]",
  549. )
  550. args = parser.parse_args()
  551. logging_config = {
  552. "level": logging.DEBUG if args.v else logging.INFO,
  553. "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
  554. }
  555. if args.curses:
  556. logging_config["filename"] = "port-synapse.log"
  557. logging.basicConfig(**logging_config)
  558. sqlite_config = {
  559. "name": "sqlite3",
  560. "args": {
  561. "database": args.sqlite_database,
  562. "cp_min": 1,
  563. "cp_max": 1,
  564. "check_same_thread": False,
  565. },
  566. }
  567. postgres_config = yaml.safe_load(args.postgres_config)
  568. if "database" in postgres_config:
  569. postgres_config = postgres_config["database"]
  570. if "name" not in postgres_config:
  571. sys.stderr.write("Malformed database config: no 'name'")
  572. sys.exit(2)
  573. if postgres_config["name"] != "psycopg2":
  574. sys.stderr.write("Database must use 'psycopg2' connector.")
  575. sys.exit(3)
  576. def start(stdscr=None):
  577. if stdscr:
  578. progress = CursesProgress(stdscr)
  579. else:
  580. progress = TerminalProgress()
  581. porter = Porter(
  582. sqlite_config=sqlite_config,
  583. postgres_config=postgres_config,
  584. progress=progress,
  585. batch_size=args.batch_size,
  586. )
  587. reactor.callWhenRunning(porter.run)
  588. reactor.run()
  589. if args.curses:
  590. curses.wrapper(start)
  591. else:
  592. start()
  593. if end_error_exec_info:
  594. exc_type, exc_value, exc_traceback = end_error_exec_info
  595. traceback.print_exception(exc_type, exc_value, exc_traceback)