synapse_port_db 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2015, 2016 OpenMarket Ltd
  4. # Copyright 2018 New Vector Ltd
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import argparse
  18. import curses
  19. import logging
  20. import sys
  21. import time
  22. import traceback
  23. from six import string_types
  24. import yaml
  25. from twisted.enterprise import adbapi
  26. from twisted.internet import defer, reactor
  27. from synapse.storage._base import LoggingTransaction, SQLBaseStore
  28. from synapse.storage.engines import create_engine
  29. from synapse.storage.prepare_database import prepare_database
  30. logger = logging.getLogger("synapse_port_db")
  31. BOOLEAN_COLUMNS = {
  32. "events": ["processed", "outlier", "contains_url"],
  33. "rooms": ["is_public"],
  34. "event_edges": ["is_state"],
  35. "presence_list": ["accepted"],
  36. "presence_stream": ["currently_active"],
  37. "public_room_list_stream": ["visibility"],
  38. "device_lists_outbound_pokes": ["sent"],
  39. "users_who_share_rooms": ["share_private"],
  40. "groups": ["is_public"],
  41. "group_rooms": ["is_public"],
  42. "group_users": ["is_public", "is_admin"],
  43. "group_summary_rooms": ["is_public"],
  44. "group_room_categories": ["is_public"],
  45. "group_summary_users": ["is_public"],
  46. "group_roles": ["is_public"],
  47. "local_group_membership": ["is_publicised", "is_admin"],
  48. "e2e_room_keys": ["is_verified"],
  49. }
  50. APPEND_ONLY_TABLES = [
  51. "event_reference_hashes",
  52. "events",
  53. "event_json",
  54. "state_events",
  55. "room_memberships",
  56. "topics",
  57. "room_names",
  58. "rooms",
  59. "local_media_repository",
  60. "local_media_repository_thumbnails",
  61. "remote_media_cache",
  62. "remote_media_cache_thumbnails",
  63. "redactions",
  64. "event_edges",
  65. "event_auth",
  66. "received_transactions",
  67. "sent_transactions",
  68. "transaction_id_to_pdu",
  69. "users",
  70. "state_groups",
  71. "state_groups_state",
  72. "event_to_state_groups",
  73. "rejections",
  74. "event_search",
  75. "presence_stream",
  76. "push_rules_stream",
  77. "ex_outlier_stream",
  78. "cache_invalidation_stream",
  79. "public_room_list_stream",
  80. "state_group_edges",
  81. "stream_ordering_to_exterm",
  82. ]
  83. end_error_exec_info = None
  84. class Store(object):
  85. """This object is used to pull out some of the convenience API from the
  86. Storage layer.
  87. *All* database interactions should go through this object.
  88. """
  89. def __init__(self, db_pool, engine):
  90. self.db_pool = db_pool
  91. self.database_engine = engine
  92. _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
  93. _simple_insert = SQLBaseStore.__dict__["_simple_insert"]
  94. _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
  95. _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
  96. _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
  97. _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
  98. _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
  99. _simple_select_one_onecol_txn = SQLBaseStore.__dict__[
  100. "_simple_select_one_onecol_txn"
  101. ]
  102. _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
  103. _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
  104. _simple_update_txn = SQLBaseStore.__dict__["_simple_update_txn"]
  105. def runInteraction(self, desc, func, *args, **kwargs):
  106. def r(conn):
  107. try:
  108. i = 0
  109. N = 5
  110. while True:
  111. try:
  112. txn = conn.cursor()
  113. return func(
  114. LoggingTransaction(txn, desc, self.database_engine, [], []),
  115. *args,
  116. **kwargs
  117. )
  118. except self.database_engine.module.DatabaseError as e:
  119. if self.database_engine.is_deadlock(e):
  120. logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
  121. if i < N:
  122. i += 1
  123. conn.rollback()
  124. continue
  125. raise
  126. except Exception as e:
  127. logger.debug("[TXN FAIL] {%s} %s", desc, e)
  128. raise
  129. return self.db_pool.runWithConnection(r)
  130. def execute(self, f, *args, **kwargs):
  131. return self.runInteraction(f.__name__, f, *args, **kwargs)
  132. def execute_sql(self, sql, *args):
  133. def r(txn):
  134. txn.execute(sql, args)
  135. return txn.fetchall()
  136. return self.runInteraction("execute_sql", r)
  137. def insert_many_txn(self, txn, table, headers, rows):
  138. sql = "INSERT INTO %s (%s) VALUES (%s)" % (
  139. table,
  140. ", ".join(k for k in headers),
  141. ", ".join("%s" for _ in headers),
  142. )
  143. try:
  144. txn.executemany(sql, rows)
  145. except Exception:
  146. logger.exception("Failed to insert: %s", table)
  147. raise
  148. class Porter(object):
  149. def __init__(self, **kwargs):
  150. self.__dict__.update(kwargs)
  151. @defer.inlineCallbacks
  152. def setup_table(self, table):
  153. if table in APPEND_ONLY_TABLES:
  154. # It's safe to just carry on inserting.
  155. row = yield self.postgres_store._simple_select_one(
  156. table="port_from_sqlite3",
  157. keyvalues={"table_name": table},
  158. retcols=("forward_rowid", "backward_rowid"),
  159. allow_none=True,
  160. )
  161. total_to_port = None
  162. if row is None:
  163. if table == "sent_transactions":
  164. forward_chunk, already_ported, total_to_port = (
  165. yield self._setup_sent_transactions()
  166. )
  167. backward_chunk = 0
  168. else:
  169. yield self.postgres_store._simple_insert(
  170. table="port_from_sqlite3",
  171. values={
  172. "table_name": table,
  173. "forward_rowid": 1,
  174. "backward_rowid": 0,
  175. },
  176. )
  177. forward_chunk = 1
  178. backward_chunk = 0
  179. already_ported = 0
  180. else:
  181. forward_chunk = row["forward_rowid"]
  182. backward_chunk = row["backward_rowid"]
  183. if total_to_port is None:
  184. already_ported, total_to_port = yield self._get_total_count_to_port(
  185. table, forward_chunk, backward_chunk
  186. )
  187. else:
  188. def delete_all(txn):
  189. txn.execute(
  190. "DELETE FROM port_from_sqlite3 WHERE table_name = %s", (table,)
  191. )
  192. txn.execute("TRUNCATE %s CASCADE" % (table,))
  193. yield self.postgres_store.execute(delete_all)
  194. yield self.postgres_store._simple_insert(
  195. table="port_from_sqlite3",
  196. values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
  197. )
  198. forward_chunk = 1
  199. backward_chunk = 0
  200. already_ported, total_to_port = yield self._get_total_count_to_port(
  201. table, forward_chunk, backward_chunk
  202. )
  203. defer.returnValue(
  204. (table, already_ported, total_to_port, forward_chunk, backward_chunk)
  205. )
  206. @defer.inlineCallbacks
  207. def handle_table(
  208. self, table, postgres_size, table_size, forward_chunk, backward_chunk
  209. ):
  210. logger.info(
  211. "Table %s: %i/%i (rows %i-%i) already ported",
  212. table,
  213. postgres_size,
  214. table_size,
  215. backward_chunk + 1,
  216. forward_chunk - 1,
  217. )
  218. if not table_size:
  219. return
  220. self.progress.add_table(table, postgres_size, table_size)
  221. if table == "event_search":
  222. yield self.handle_search_table(
  223. postgres_size, table_size, forward_chunk, backward_chunk
  224. )
  225. return
  226. if table in (
  227. "user_directory",
  228. "user_directory_search",
  229. "users_who_share_rooms",
  230. "users_in_pubic_room",
  231. ):
  232. # We don't port these tables, as they're a faff and we can regenreate
  233. # them anyway.
  234. self.progress.update(table, table_size) # Mark table as done
  235. return
  236. if table == "user_directory_stream_pos":
  237. # We need to make sure there is a single row, `(X, null), as that is
  238. # what synapse expects to be there.
  239. yield self.postgres_store._simple_insert(
  240. table=table, values={"stream_id": None}
  241. )
  242. self.progress.update(table, table_size) # Mark table as done
  243. return
  244. forward_select = (
  245. "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,)
  246. )
  247. backward_select = (
  248. "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" % (table,)
  249. )
  250. do_forward = [True]
  251. do_backward = [True]
  252. while True:
  253. def r(txn):
  254. forward_rows = []
  255. backward_rows = []
  256. if do_forward[0]:
  257. txn.execute(forward_select, (forward_chunk, self.batch_size))
  258. forward_rows = txn.fetchall()
  259. if not forward_rows:
  260. do_forward[0] = False
  261. if do_backward[0]:
  262. txn.execute(backward_select, (backward_chunk, self.batch_size))
  263. backward_rows = txn.fetchall()
  264. if not backward_rows:
  265. do_backward[0] = False
  266. if forward_rows or backward_rows:
  267. headers = [column[0] for column in txn.description]
  268. else:
  269. headers = None
  270. return headers, forward_rows, backward_rows
  271. headers, frows, brows = yield self.sqlite_store.runInteraction("select", r)
  272. if frows or brows:
  273. if frows:
  274. forward_chunk = max(row[0] for row in frows) + 1
  275. if brows:
  276. backward_chunk = min(row[0] for row in brows) - 1
  277. rows = frows + brows
  278. rows = self._convert_rows(table, headers, rows)
  279. def insert(txn):
  280. self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
  281. self.postgres_store._simple_update_one_txn(
  282. txn,
  283. table="port_from_sqlite3",
  284. keyvalues={"table_name": table},
  285. updatevalues={
  286. "forward_rowid": forward_chunk,
  287. "backward_rowid": backward_chunk,
  288. },
  289. )
  290. yield self.postgres_store.execute(insert)
  291. postgres_size += len(rows)
  292. self.progress.update(table, postgres_size)
  293. else:
  294. return
  295. @defer.inlineCallbacks
  296. def handle_search_table(
  297. self, postgres_size, table_size, forward_chunk, backward_chunk
  298. ):
  299. select = (
  300. "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
  301. " FROM event_search as es"
  302. " INNER JOIN events AS e USING (event_id, room_id)"
  303. " WHERE es.rowid >= ?"
  304. " ORDER BY es.rowid LIMIT ?"
  305. )
  306. while True:
  307. def r(txn):
  308. txn.execute(select, (forward_chunk, self.batch_size))
  309. rows = txn.fetchall()
  310. headers = [column[0] for column in txn.description]
  311. return headers, rows
  312. headers, rows = yield self.sqlite_store.runInteraction("select", r)
  313. if rows:
  314. forward_chunk = rows[-1][0] + 1
  315. # We have to treat event_search differently since it has a
  316. # different structure in the two different databases.
  317. def insert(txn):
  318. sql = (
  319. "INSERT INTO event_search (event_id, room_id, key,"
  320. " sender, vector, origin_server_ts, stream_ordering)"
  321. " VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
  322. )
  323. rows_dict = []
  324. for row in rows:
  325. d = dict(zip(headers, row))
  326. if "\0" in d['value']:
  327. logger.warn('dropping search row %s', d)
  328. else:
  329. rows_dict.append(d)
  330. txn.executemany(
  331. sql,
  332. [
  333. (
  334. row["event_id"],
  335. row["room_id"],
  336. row["key"],
  337. row["sender"],
  338. row["value"],
  339. row["origin_server_ts"],
  340. row["stream_ordering"],
  341. )
  342. for row in rows_dict
  343. ],
  344. )
  345. self.postgres_store._simple_update_one_txn(
  346. txn,
  347. table="port_from_sqlite3",
  348. keyvalues={"table_name": "event_search"},
  349. updatevalues={
  350. "forward_rowid": forward_chunk,
  351. "backward_rowid": backward_chunk,
  352. },
  353. )
  354. yield self.postgres_store.execute(insert)
  355. postgres_size += len(rows)
  356. self.progress.update("event_search", postgres_size)
  357. else:
  358. return
  359. def setup_db(self, db_config, database_engine):
  360. db_conn = database_engine.module.connect(
  361. **{
  362. k: v
  363. for k, v in db_config.get("args", {}).items()
  364. if not k.startswith("cp_")
  365. }
  366. )
  367. prepare_database(db_conn, database_engine, config=None)
  368. db_conn.commit()
  369. @defer.inlineCallbacks
  370. def run(self):
  371. try:
  372. sqlite_db_pool = adbapi.ConnectionPool(
  373. self.sqlite_config["name"], **self.sqlite_config["args"]
  374. )
  375. postgres_db_pool = adbapi.ConnectionPool(
  376. self.postgres_config["name"], **self.postgres_config["args"]
  377. )
  378. sqlite_engine = create_engine(sqlite_config)
  379. postgres_engine = create_engine(postgres_config)
  380. self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
  381. self.postgres_store = Store(postgres_db_pool, postgres_engine)
  382. yield self.postgres_store.execute(postgres_engine.check_database)
  383. # Step 1. Set up databases.
  384. self.progress.set_state("Preparing SQLite3")
  385. self.setup_db(sqlite_config, sqlite_engine)
  386. self.progress.set_state("Preparing PostgreSQL")
  387. self.setup_db(postgres_config, postgres_engine)
  388. self.progress.set_state("Creating port tables")
  389. def create_port_table(txn):
  390. txn.execute(
  391. "CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
  392. " table_name varchar(100) NOT NULL UNIQUE,"
  393. " forward_rowid bigint NOT NULL,"
  394. " backward_rowid bigint NOT NULL"
  395. ")"
  396. )
  397. # The old port script created a table with just a "rowid" column.
  398. # We want people to be able to rerun this script from an old port
  399. # so that they can pick up any missing events that were not
  400. # ported across.
  401. def alter_table(txn):
  402. txn.execute(
  403. "ALTER TABLE IF EXISTS port_from_sqlite3"
  404. " RENAME rowid TO forward_rowid"
  405. )
  406. txn.execute(
  407. "ALTER TABLE IF EXISTS port_from_sqlite3"
  408. " ADD backward_rowid bigint NOT NULL DEFAULT 0"
  409. )
  410. try:
  411. yield self.postgres_store.runInteraction("alter_table", alter_table)
  412. except Exception:
  413. # On Error Resume Next
  414. pass
  415. yield self.postgres_store.runInteraction(
  416. "create_port_table", create_port_table
  417. )
  418. # Step 2. Get tables.
  419. self.progress.set_state("Fetching tables")
  420. sqlite_tables = yield self.sqlite_store._simple_select_onecol(
  421. table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
  422. )
  423. postgres_tables = yield self.postgres_store._simple_select_onecol(
  424. table="information_schema.tables",
  425. keyvalues={},
  426. retcol="distinct table_name",
  427. )
  428. tables = set(sqlite_tables) & set(postgres_tables)
  429. logger.info("Found %d tables", len(tables))
  430. # Step 3. Figure out what still needs copying
  431. self.progress.set_state("Checking on port progress")
  432. setup_res = yield defer.gatherResults(
  433. [
  434. self.setup_table(table)
  435. for table in tables
  436. if table not in ["schema_version", "applied_schema_deltas"]
  437. and not table.startswith("sqlite_")
  438. ],
  439. consumeErrors=True,
  440. )
  441. # Step 4. Do the copying.
  442. self.progress.set_state("Copying to postgres")
  443. yield defer.gatherResults(
  444. [self.handle_table(*res) for res in setup_res], consumeErrors=True
  445. )
  446. # Step 5. Do final post-processing
  447. yield self._setup_state_group_id_seq()
  448. self.progress.done()
  449. except Exception:
  450. global end_error_exec_info
  451. end_error_exec_info = sys.exc_info()
  452. logger.exception("")
  453. finally:
  454. reactor.stop()
  455. def _convert_rows(self, table, headers, rows):
  456. bool_col_names = BOOLEAN_COLUMNS.get(table, [])
  457. bool_cols = [i for i, h in enumerate(headers) if h in bool_col_names]
  458. class BadValueException(Exception):
  459. pass
  460. def conv(j, col):
  461. if j in bool_cols:
  462. return bool(col)
  463. elif isinstance(col, string_types) and "\0" in col:
  464. logger.warn(
  465. "DROPPING ROW: NUL value in table %s col %s: %r",
  466. table,
  467. headers[j],
  468. col,
  469. )
  470. raise BadValueException()
  471. return col
  472. outrows = []
  473. for i, row in enumerate(rows):
  474. try:
  475. outrows.append(
  476. tuple(conv(j, col) for j, col in enumerate(row) if j > 0)
  477. )
  478. except BadValueException:
  479. pass
  480. return outrows
  481. @defer.inlineCallbacks
  482. def _setup_sent_transactions(self):
  483. # Only save things from the last day
  484. yesterday = int(time.time() * 1000) - 86400000
  485. # And save the max transaction id from each destination
  486. select = (
  487. "SELECT rowid, * FROM sent_transactions WHERE rowid IN ("
  488. "SELECT max(rowid) FROM sent_transactions"
  489. " GROUP BY destination"
  490. ")"
  491. )
  492. def r(txn):
  493. txn.execute(select)
  494. rows = txn.fetchall()
  495. headers = [column[0] for column in txn.description]
  496. ts_ind = headers.index('ts')
  497. return headers, [r for r in rows if r[ts_ind] < yesterday]
  498. headers, rows = yield self.sqlite_store.runInteraction("select", r)
  499. rows = self._convert_rows("sent_transactions", headers, rows)
  500. inserted_rows = len(rows)
  501. if inserted_rows:
  502. max_inserted_rowid = max(r[0] for r in rows)
  503. def insert(txn):
  504. self.postgres_store.insert_many_txn(
  505. txn, "sent_transactions", headers[1:], rows
  506. )
  507. yield self.postgres_store.execute(insert)
  508. else:
  509. max_inserted_rowid = 0
  510. def get_start_id(txn):
  511. txn.execute(
  512. "SELECT rowid FROM sent_transactions WHERE ts >= ?"
  513. " ORDER BY rowid ASC LIMIT 1",
  514. (yesterday,),
  515. )
  516. rows = txn.fetchall()
  517. if rows:
  518. return rows[0][0]
  519. else:
  520. return 1
  521. next_chunk = yield self.sqlite_store.execute(get_start_id)
  522. next_chunk = max(max_inserted_rowid + 1, next_chunk)
  523. yield self.postgres_store._simple_insert(
  524. table="port_from_sqlite3",
  525. values={
  526. "table_name": "sent_transactions",
  527. "forward_rowid": next_chunk,
  528. "backward_rowid": 0,
  529. },
  530. )
  531. def get_sent_table_size(txn):
  532. txn.execute(
  533. "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
  534. )
  535. size, = txn.fetchone()
  536. return int(size)
  537. remaining_count = yield self.sqlite_store.execute(get_sent_table_size)
  538. total_count = remaining_count + inserted_rows
  539. defer.returnValue((next_chunk, inserted_rows, total_count))
  540. @defer.inlineCallbacks
  541. def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
  542. frows = yield self.sqlite_store.execute_sql(
  543. "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
  544. )
  545. brows = yield self.sqlite_store.execute_sql(
  546. "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
  547. )
  548. defer.returnValue(frows[0][0] + brows[0][0])
  549. @defer.inlineCallbacks
  550. def _get_already_ported_count(self, table):
  551. rows = yield self.postgres_store.execute_sql(
  552. "SELECT count(*) FROM %s" % (table,)
  553. )
  554. defer.returnValue(rows[0][0])
  555. @defer.inlineCallbacks
  556. def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
  557. remaining, done = yield defer.gatherResults(
  558. [
  559. self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
  560. self._get_already_ported_count(table),
  561. ],
  562. consumeErrors=True,
  563. )
  564. remaining = int(remaining) if remaining else 0
  565. done = int(done) if done else 0
  566. defer.returnValue((done, remaining + done))
  567. def _setup_state_group_id_seq(self):
  568. def r(txn):
  569. txn.execute("SELECT MAX(id) FROM state_groups")
  570. next_id = txn.fetchone()[0] + 1
  571. txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
  572. return self.postgres_store.runInteraction("setup_state_group_id_seq", r)
##############################################
# The following is just UI code
##############################################
  576. class Progress(object):
  577. """Used to report progress of the port
  578. """
  579. def __init__(self):
  580. self.tables = {}
  581. self.start_time = int(time.time())
  582. def add_table(self, table, cur, size):
  583. self.tables[table] = {
  584. "start": cur,
  585. "num_done": cur,
  586. "total": size,
  587. "perc": int(cur * 100 / size),
  588. }
  589. def update(self, table, num_done):
  590. data = self.tables[table]
  591. data["num_done"] = num_done
  592. data["perc"] = int(num_done * 100 / data["total"])
  593. def done(self):
  594. pass
  595. class CursesProgress(Progress):
  596. """Reports progress to a curses window
  597. """
  598. def __init__(self, stdscr):
  599. self.stdscr = stdscr
  600. curses.use_default_colors()
  601. curses.curs_set(0)
  602. curses.init_pair(1, curses.COLOR_RED, -1)
  603. curses.init_pair(2, curses.COLOR_GREEN, -1)
  604. self.last_update = 0
  605. self.finished = False
  606. self.total_processed = 0
  607. self.total_remaining = 0
  608. super(CursesProgress, self).__init__()
  609. def update(self, table, num_done):
  610. super(CursesProgress, self).update(table, num_done)
  611. self.total_processed = 0
  612. self.total_remaining = 0
  613. for table, data in self.tables.items():
  614. self.total_processed += data["num_done"] - data["start"]
  615. self.total_remaining += data["total"] - data["num_done"]
  616. self.render()
  617. def render(self, force=False):
  618. now = time.time()
  619. if not force and now - self.last_update < 0.2:
  620. # reactor.callLater(1, self.render)
  621. return
  622. self.stdscr.clear()
  623. rows, cols = self.stdscr.getmaxyx()
  624. duration = int(now) - int(self.start_time)
  625. minutes, seconds = divmod(duration, 60)
  626. duration_str = '%02dm %02ds' % (minutes, seconds)
  627. if self.finished:
  628. status = "Time spent: %s (Done!)" % (duration_str,)
  629. else:
  630. if self.total_processed > 0:
  631. left = float(self.total_remaining) / self.total_processed
  632. est_remaining = (int(now) - self.start_time) * left
  633. est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
  634. else:
  635. est_remaining_str = "Unknown"
  636. status = "Time spent: %s (est. remaining: %s)" % (
  637. duration_str,
  638. est_remaining_str,
  639. )
  640. self.stdscr.addstr(0, 0, status, curses.A_BOLD)
  641. max_len = max([len(t) for t in self.tables.keys()])
  642. left_margin = 5
  643. middle_space = 1
  644. items = self.tables.items()
  645. items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
  646. for i, (table, data) in enumerate(items):
  647. if i + 2 >= rows:
  648. break
  649. perc = data["perc"]
  650. color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
  651. self.stdscr.addstr(
  652. i + 2, left_margin + max_len - len(table), table, curses.A_BOLD | color
  653. )
  654. size = 20
  655. progress = "[%s%s]" % (
  656. "#" * int(perc * size / 100),
  657. " " * (size - int(perc * size / 100)),
  658. )
  659. self.stdscr.addstr(
  660. i + 2,
  661. left_margin + max_len + middle_space,
  662. "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
  663. )
  664. if self.finished:
  665. self.stdscr.addstr(rows - 1, 0, "Press any key to exit...")
  666. self.stdscr.refresh()
  667. self.last_update = time.time()
  668. def done(self):
  669. self.finished = True
  670. self.render(True)
  671. self.stdscr.getch()
  672. def set_state(self, state):
  673. self.stdscr.clear()
  674. self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
  675. self.stdscr.refresh()
  676. class TerminalProgress(Progress):
  677. """Just prints progress to the terminal
  678. """
  679. def update(self, table, num_done):
  680. super(TerminalProgress, self).update(table, num_done)
  681. data = self.tables[table]
  682. print(
  683. "%s: %d%% (%d/%d)" % (table, data["perc"], data["num_done"], data["total"])
  684. )
  685. def set_state(self, state):
  686. print(state + "...")
##############################################
##############################################
  689. if __name__ == "__main__":
  690. parser = argparse.ArgumentParser(
  691. description="A script to port an existing synapse SQLite database to"
  692. " a new PostgreSQL database."
  693. )
  694. parser.add_argument("-v", action='store_true')
  695. parser.add_argument(
  696. "--sqlite-database",
  697. required=True,
  698. help="The snapshot of the SQLite database file. This must not be"
  699. " currently used by a running synapse server",
  700. )
  701. parser.add_argument(
  702. "--postgres-config",
  703. type=argparse.FileType('r'),
  704. required=True,
  705. help="The database config file for the PostgreSQL database",
  706. )
  707. parser.add_argument(
  708. "--curses", action='store_true', help="display a curses based progress UI"
  709. )
  710. parser.add_argument(
  711. "--batch-size",
  712. type=int,
  713. default=1000,
  714. help="The number of rows to select from the SQLite table each"
  715. " iteration [default=1000]",
  716. )
  717. args = parser.parse_args()
  718. logging_config = {
  719. "level": logging.DEBUG if args.v else logging.INFO,
  720. "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s",
  721. }
  722. if args.curses:
  723. logging_config["filename"] = "port-synapse.log"
  724. logging.basicConfig(**logging_config)
  725. sqlite_config = {
  726. "name": "sqlite3",
  727. "args": {
  728. "database": args.sqlite_database,
  729. "cp_min": 1,
  730. "cp_max": 1,
  731. "check_same_thread": False,
  732. },
  733. }
  734. postgres_config = yaml.safe_load(args.postgres_config)
  735. if "database" in postgres_config:
  736. postgres_config = postgres_config["database"]
  737. if "name" not in postgres_config:
  738. sys.stderr.write("Malformed database config: no 'name'")
  739. sys.exit(2)
  740. if postgres_config["name"] != "psycopg2":
  741. sys.stderr.write("Database must use 'psycopg2' connector.")
  742. sys.exit(3)
  743. def start(stdscr=None):
  744. if stdscr:
  745. progress = CursesProgress(stdscr)
  746. else:
  747. progress = TerminalProgress()
  748. porter = Porter(
  749. sqlite_config=sqlite_config,
  750. postgres_config=postgres_config,
  751. progress=progress,
  752. batch_size=args.batch_size,
  753. )
  754. reactor.callWhenRunning(porter.run)
  755. reactor.run()
  756. if args.curses:
  757. curses.wrapper(start)
  758. else:
  759. start()
  760. if end_error_exec_info:
  761. exc_type, exc_value, exc_traceback = end_error_exec_info
  762. traceback.print_exception(exc_type, exc_value, exc_traceback)