_base.py 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. import sys
  17. import threading
  18. import time
  19. from six import PY2, iteritems, iterkeys, itervalues
  20. from six.moves import builtins, intern, range
  21. from canonicaljson import json
  22. from prometheus_client import Histogram
  23. from twisted.internet import defer
  24. from synapse.api.errors import StoreError
  25. from synapse.storage.engines import PostgresEngine
  26. from synapse.util.caches.descriptors import Cache
  27. from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
  28. from synapse.util.stringutils import exception_to_unicode
  29. logger = logging.getLogger(__name__)
  30. try:
  31. MAX_TXN_ID = sys.maxint - 1
  32. except AttributeError:
  33. # python 3 does not have a maximum int value
  34. MAX_TXN_ID = 2**63 - 1
  35. sql_logger = logging.getLogger("synapse.storage.SQL")
  36. transaction_logger = logging.getLogger("synapse.storage.txn")
  37. perf_logger = logging.getLogger("synapse.storage.TIME")
  38. sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
  39. sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
  40. sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
  41. class LoggingTransaction(object):
  42. """An object that almost-transparently proxies for the 'txn' object
  43. passed to the constructor. Adds logging and metrics to the .execute()
  44. method."""
  45. __slots__ = [
  46. "txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
  47. ]
  48. def __init__(self, txn, name, database_engine, after_callbacks,
  49. exception_callbacks):
  50. object.__setattr__(self, "txn", txn)
  51. object.__setattr__(self, "name", name)
  52. object.__setattr__(self, "database_engine", database_engine)
  53. object.__setattr__(self, "after_callbacks", after_callbacks)
  54. object.__setattr__(self, "exception_callbacks", exception_callbacks)
  55. def call_after(self, callback, *args, **kwargs):
  56. """Call the given callback on the main twisted thread after the
  57. transaction has finished. Used to invalidate the caches on the
  58. correct thread.
  59. """
  60. self.after_callbacks.append((callback, args, kwargs))
  61. def call_on_exception(self, callback, *args, **kwargs):
  62. self.exception_callbacks.append((callback, args, kwargs))
  63. def __getattr__(self, name):
  64. return getattr(self.txn, name)
  65. def __setattr__(self, name, value):
  66. setattr(self.txn, name, value)
  67. def __iter__(self):
  68. return self.txn.__iter__()
  69. def execute(self, sql, *args):
  70. self._do_execute(self.txn.execute, sql, *args)
  71. def executemany(self, sql, *args):
  72. self._do_execute(self.txn.executemany, sql, *args)
  73. def _make_sql_one_line(self, sql):
  74. "Strip newlines out of SQL so that the loggers in the DB are on one line"
  75. return " ".join(l.strip() for l in sql.splitlines() if l.strip())
  76. def _do_execute(self, func, sql, *args):
  77. sql = self._make_sql_one_line(sql)
  78. # TODO(paul): Maybe use 'info' and 'debug' for values?
  79. sql_logger.debug("[SQL] {%s} %s", self.name, sql)
  80. sql = self.database_engine.convert_param_style(sql)
  81. if args:
  82. try:
  83. sql_logger.debug(
  84. "[SQL values] {%s} %r",
  85. self.name, args[0]
  86. )
  87. except Exception:
  88. # Don't let logging failures stop SQL from working
  89. pass
  90. start = time.time()
  91. try:
  92. return func(
  93. sql, *args
  94. )
  95. except Exception as e:
  96. logger.debug("[SQL FAIL] {%s} %s", self.name, e)
  97. raise
  98. finally:
  99. secs = time.time() - start
  100. sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
  101. sql_query_timer.labels(sql.split()[0]).observe(secs)
  102. class PerformanceCounters(object):
  103. def __init__(self):
  104. self.current_counters = {}
  105. self.previous_counters = {}
  106. def update(self, key, start_time, end_time=None):
  107. if end_time is None:
  108. end_time = time.time()
  109. duration = end_time - start_time
  110. count, cum_time = self.current_counters.get(key, (0, 0))
  111. count += 1
  112. cum_time += duration
  113. self.current_counters[key] = (count, cum_time)
  114. return end_time
  115. def interval(self, interval_duration, limit=3):
  116. counters = []
  117. for name, (count, cum_time) in iteritems(self.current_counters):
  118. prev_count, prev_time = self.previous_counters.get(name, (0, 0))
  119. counters.append((
  120. (cum_time - prev_time) / interval_duration,
  121. count - prev_count,
  122. name
  123. ))
  124. self.previous_counters = dict(self.current_counters)
  125. counters.sort(reverse=True)
  126. top_n_counters = ", ".join(
  127. "%s(%d): %.3f%%" % (name, count, 100 * ratio)
  128. for ratio, count, name in counters[:limit]
  129. )
  130. return top_n_counters
  131. class SQLBaseStore(object):
  132. _TXN_ID = 0
  133. def __init__(self, db_conn, hs):
  134. self.hs = hs
  135. self._clock = hs.get_clock()
  136. self._db_pool = hs.get_db_pool()
  137. self._previous_txn_total_time = 0
  138. self._current_txn_total_time = 0
  139. self._previous_loop_ts = 0
  140. # TODO(paul): These can eventually be removed once the metrics code
  141. # is running in mainline, and we have some nice monitoring frontends
  142. # to watch it
  143. self._txn_perf_counters = PerformanceCounters()
  144. self._get_event_counters = PerformanceCounters()
  145. self._get_event_cache = Cache("*getEvent*", keylen=3,
  146. max_entries=hs.config.event_cache_size)
  147. self._event_fetch_lock = threading.Condition()
  148. self._event_fetch_list = []
  149. self._event_fetch_ongoing = 0
  150. self._pending_ds = []
  151. self.database_engine = hs.database_engine
  152. def start_profiling(self):
  153. self._previous_loop_ts = self._clock.time_msec()
  154. def loop():
  155. curr = self._current_txn_total_time
  156. prev = self._previous_txn_total_time
  157. self._previous_txn_total_time = curr
  158. time_now = self._clock.time_msec()
  159. time_then = self._previous_loop_ts
  160. self._previous_loop_ts = time_now
  161. ratio = (curr - prev) / (time_now - time_then)
  162. top_three_counters = self._txn_perf_counters.interval(
  163. time_now - time_then, limit=3
  164. )
  165. top_3_event_counters = self._get_event_counters.interval(
  166. time_now - time_then, limit=3
  167. )
  168. perf_logger.info(
  169. "Total database time: %.3f%% {%s} {%s}",
  170. ratio * 100, top_three_counters, top_3_event_counters
  171. )
  172. self._clock.looping_call(loop, 10000)
  173. def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
  174. func, *args, **kwargs):
  175. start = time.time()
  176. txn_id = self._TXN_ID
  177. # We don't really need these to be unique, so lets stop it from
  178. # growing really large.
  179. self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
  180. name = "%s-%x" % (desc, txn_id, )
  181. transaction_logger.debug("[TXN START] {%s}", name)
  182. try:
  183. i = 0
  184. N = 5
  185. while True:
  186. try:
  187. txn = conn.cursor()
  188. txn = LoggingTransaction(
  189. txn, name, self.database_engine, after_callbacks,
  190. exception_callbacks,
  191. )
  192. r = func(txn, *args, **kwargs)
  193. conn.commit()
  194. return r
  195. except self.database_engine.module.OperationalError as e:
  196. # This can happen if the database disappears mid
  197. # transaction.
  198. logger.warning(
  199. "[TXN OPERROR] {%s} %s %d/%d",
  200. name, exception_to_unicode(e), i, N
  201. )
  202. if i < N:
  203. i += 1
  204. try:
  205. conn.rollback()
  206. except self.database_engine.module.Error as e1:
  207. logger.warning(
  208. "[TXN EROLL] {%s} %s",
  209. name, exception_to_unicode(e1),
  210. )
  211. continue
  212. raise
  213. except self.database_engine.module.DatabaseError as e:
  214. if self.database_engine.is_deadlock(e):
  215. logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
  216. if i < N:
  217. i += 1
  218. try:
  219. conn.rollback()
  220. except self.database_engine.module.Error as e1:
  221. logger.warning(
  222. "[TXN EROLL] {%s} %s",
  223. name, exception_to_unicode(e1),
  224. )
  225. continue
  226. raise
  227. except Exception as e:
  228. logger.debug("[TXN FAIL] {%s} %s", name, e)
  229. raise
  230. finally:
  231. end = time.time()
  232. duration = end - start
  233. LoggingContext.current_context().add_database_transaction(duration)
  234. transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
  235. self._current_txn_total_time += duration
  236. self._txn_perf_counters.update(desc, start, end)
  237. sql_txn_timer.labels(desc).observe(duration)
  238. @defer.inlineCallbacks
  239. def runInteraction(self, desc, func, *args, **kwargs):
  240. """Starts a transaction on the database and runs a given function
  241. Arguments:
  242. desc (str): description of the transaction, for logging and metrics
  243. func (func): callback function, which will be called with a
  244. database transaction (twisted.enterprise.adbapi.Transaction) as
  245. its first argument, followed by `args` and `kwargs`.
  246. args (list): positional args to pass to `func`
  247. kwargs (dict): named args to pass to `func`
  248. Returns:
  249. Deferred: The result of func
  250. """
  251. after_callbacks = []
  252. exception_callbacks = []
  253. if LoggingContext.current_context() == LoggingContext.sentinel:
  254. logger.warn(
  255. "Starting db txn '%s' from sentinel context",
  256. desc,
  257. )
  258. try:
  259. result = yield self.runWithConnection(
  260. self._new_transaction,
  261. desc, after_callbacks, exception_callbacks, func,
  262. *args, **kwargs
  263. )
  264. for after_callback, after_args, after_kwargs in after_callbacks:
  265. after_callback(*after_args, **after_kwargs)
  266. except: # noqa: E722, as we reraise the exception this is fine.
  267. for after_callback, after_args, after_kwargs in exception_callbacks:
  268. after_callback(*after_args, **after_kwargs)
  269. raise
  270. defer.returnValue(result)
  271. @defer.inlineCallbacks
  272. def runWithConnection(self, func, *args, **kwargs):
  273. """Wraps the .runWithConnection() method on the underlying db_pool.
  274. Arguments:
  275. func (func): callback function, which will be called with a
  276. database connection (twisted.enterprise.adbapi.Connection) as
  277. its first argument, followed by `args` and `kwargs`.
  278. args (list): positional args to pass to `func`
  279. kwargs (dict): named args to pass to `func`
  280. Returns:
  281. Deferred: The result of func
  282. """
  283. parent_context = LoggingContext.current_context()
  284. if parent_context == LoggingContext.sentinel:
  285. logger.warn(
  286. "Starting db connection from sentinel context: metrics will be lost",
  287. )
  288. parent_context = None
  289. start_time = time.time()
  290. def inner_func(conn, *args, **kwargs):
  291. with LoggingContext("runWithConnection", parent_context) as context:
  292. sched_duration_sec = time.time() - start_time
  293. sql_scheduling_timer.observe(sched_duration_sec)
  294. context.add_database_scheduled(sched_duration_sec)
  295. if self.database_engine.is_connection_closed(conn):
  296. logger.debug("Reconnecting closed database connection")
  297. conn.reconnect()
  298. return func(conn, *args, **kwargs)
  299. with PreserveLoggingContext():
  300. result = yield self._db_pool.runWithConnection(
  301. inner_func, *args, **kwargs
  302. )
  303. defer.returnValue(result)
  304. @staticmethod
  305. def cursor_to_dict(cursor):
  306. """Converts a SQL cursor into an list of dicts.
  307. Args:
  308. cursor : The DBAPI cursor which has executed a query.
  309. Returns:
  310. A list of dicts where the key is the column header.
  311. """
  312. col_headers = list(intern(str(column[0])) for column in cursor.description)
  313. results = list(
  314. dict(zip(col_headers, row)) for row in cursor
  315. )
  316. return results
  317. def _execute(self, desc, decoder, query, *args):
  318. """Runs a single query for a result set.
  319. Args:
  320. decoder - The function which can resolve the cursor results to
  321. something meaningful.
  322. query - The query string to execute
  323. *args - Query args.
  324. Returns:
  325. The result of decoder(results)
  326. """
  327. def interaction(txn):
  328. txn.execute(query, args)
  329. if decoder:
  330. return decoder(txn)
  331. else:
  332. return txn.fetchall()
  333. return self.runInteraction(desc, interaction)
  334. # "Simple" SQL API methods that operate on a single table with no JOINs,
  335. # no complex WHERE clauses, just a dict of values for columns.
  336. @defer.inlineCallbacks
  337. def _simple_insert(self, table, values, or_ignore=False,
  338. desc="_simple_insert"):
  339. """Executes an INSERT query on the named table.
  340. Args:
  341. table : string giving the table name
  342. values : dict of new column names and values for them
  343. Returns:
  344. bool: Whether the row was inserted or not. Only useful when
  345. `or_ignore` is True
  346. """
  347. try:
  348. yield self.runInteraction(
  349. desc,
  350. self._simple_insert_txn, table, values,
  351. )
  352. except self.database_engine.module.IntegrityError:
  353. # We have to do or_ignore flag at this layer, since we can't reuse
  354. # a cursor after we receive an error from the db.
  355. if not or_ignore:
  356. raise
  357. defer.returnValue(False)
  358. defer.returnValue(True)
  359. @staticmethod
  360. def _simple_insert_txn(txn, table, values):
  361. keys, vals = zip(*values.items())
  362. sql = "INSERT INTO %s (%s) VALUES(%s)" % (
  363. table,
  364. ", ".join(k for k in keys),
  365. ", ".join("?" for _ in keys)
  366. )
  367. txn.execute(sql, vals)
  368. def _simple_insert_many(self, table, values, desc):
  369. return self.runInteraction(
  370. desc, self._simple_insert_many_txn, table, values
  371. )
  372. @staticmethod
  373. def _simple_insert_many_txn(txn, table, values):
  374. if not values:
  375. return
  376. # This is a *slight* abomination to get a list of tuples of key names
  377. # and a list of tuples of value names.
  378. #
  379. # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
  380. # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
  381. #
  382. # The sort is to ensure that we don't rely on dictionary iteration
  383. # order.
  384. keys, vals = zip(*[
  385. zip(
  386. *(sorted(i.items(), key=lambda kv: kv[0]))
  387. )
  388. for i in values
  389. if i
  390. ])
  391. for k in keys:
  392. if k != keys[0]:
  393. raise RuntimeError(
  394. "All items must have the same keys"
  395. )
  396. sql = "INSERT INTO %s (%s) VALUES(%s)" % (
  397. table,
  398. ", ".join(k for k in keys[0]),
  399. ", ".join("?" for _ in keys[0])
  400. )
  401. txn.executemany(sql, vals)
  402. @defer.inlineCallbacks
  403. def _simple_upsert(self, table, keyvalues, values,
  404. insertion_values={}, desc="_simple_upsert", lock=True):
  405. """
  406. `lock` should generally be set to True (the default), but can be set
  407. to False if either of the following are true:
  408. * there is a UNIQUE INDEX on the key columns. In this case a conflict
  409. will cause an IntegrityError in which case this function will retry
  410. the update.
  411. * we somehow know that we are the only thread which will be updating
  412. this table.
  413. Args:
  414. table (str): The table to upsert into
  415. keyvalues (dict): The unique key tables and their new values
  416. values (dict): The nonunique columns and their new values
  417. insertion_values (dict): additional key/values to use only when
  418. inserting
  419. lock (bool): True to lock the table when doing the upsert.
  420. Returns:
  421. Deferred(bool): True if a new entry was created, False if an
  422. existing one was updated.
  423. """
  424. attempts = 0
  425. while True:
  426. try:
  427. result = yield self.runInteraction(
  428. desc,
  429. self._simple_upsert_txn, table, keyvalues, values, insertion_values,
  430. lock=lock
  431. )
  432. defer.returnValue(result)
  433. except self.database_engine.module.IntegrityError as e:
  434. attempts += 1
  435. if attempts >= 5:
  436. # don't retry forever, because things other than races
  437. # can cause IntegrityErrors
  438. raise
  439. # presumably we raced with another transaction: let's retry.
  440. logger.warn(
  441. "IntegrityError when upserting into %s; retrying: %s",
  442. table, e
  443. )
  444. def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
  445. lock=True):
  446. # We need to lock the table :(, unless we're *really* careful
  447. if lock:
  448. self.database_engine.lock_table(txn, table)
  449. def _getwhere(key):
  450. # If the value we're passing in is None (aka NULL), we need to use
  451. # IS, not =, as NULL = NULL equals NULL (False).
  452. if keyvalues[key] is None:
  453. return "%s IS ?" % (key,)
  454. else:
  455. return "%s = ?" % (key,)
  456. # First try to update.
  457. sql = "UPDATE %s SET %s WHERE %s" % (
  458. table,
  459. ", ".join("%s = ?" % (k,) for k in values),
  460. " AND ".join(_getwhere(k) for k in keyvalues)
  461. )
  462. sqlargs = list(values.values()) + list(keyvalues.values())
  463. txn.execute(sql, sqlargs)
  464. if txn.rowcount > 0:
  465. # successfully updated at least one row.
  466. return False
  467. # We didn't update any rows so insert a new one
  468. allvalues = {}
  469. allvalues.update(keyvalues)
  470. allvalues.update(values)
  471. allvalues.update(insertion_values)
  472. sql = "INSERT INTO %s (%s) VALUES (%s)" % (
  473. table,
  474. ", ".join(k for k in allvalues),
  475. ", ".join("?" for _ in allvalues)
  476. )
  477. txn.execute(sql, list(allvalues.values()))
  478. # successfully inserted
  479. return True
  480. def _simple_select_one(self, table, keyvalues, retcols,
  481. allow_none=False, desc="_simple_select_one"):
  482. """Executes a SELECT query on the named table, which is expected to
  483. return a single row, returning multiple columns from it.
  484. Args:
  485. table : string giving the table name
  486. keyvalues : dict of column names and values to select the row with
  487. retcols : list of strings giving the names of the columns to return
  488. allow_none : If true, return None instead of failing if the SELECT
  489. statement returns no rows
  490. """
  491. return self.runInteraction(
  492. desc,
  493. self._simple_select_one_txn,
  494. table, keyvalues, retcols, allow_none,
  495. )
  496. def _simple_select_one_onecol(self, table, keyvalues, retcol,
  497. allow_none=False,
  498. desc="_simple_select_one_onecol"):
  499. """Executes a SELECT query on the named table, which is expected to
  500. return a single row, returning a single column from it.
  501. Args:
  502. table : string giving the table name
  503. keyvalues : dict of column names and values to select the row with
  504. retcol : string giving the name of the column to return
  505. """
  506. return self.runInteraction(
  507. desc,
  508. self._simple_select_one_onecol_txn,
  509. table, keyvalues, retcol, allow_none=allow_none,
  510. )
  511. @classmethod
  512. def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
  513. allow_none=False):
  514. ret = cls._simple_select_onecol_txn(
  515. txn,
  516. table=table,
  517. keyvalues=keyvalues,
  518. retcol=retcol,
  519. )
  520. if ret:
  521. return ret[0]
  522. else:
  523. if allow_none:
  524. return None
  525. else:
  526. raise StoreError(404, "No row found")
  527. @staticmethod
  528. def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
  529. sql = (
  530. "SELECT %(retcol)s FROM %(table)s"
  531. ) % {
  532. "retcol": retcol,
  533. "table": table,
  534. }
  535. if keyvalues:
  536. sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
  537. txn.execute(sql, list(keyvalues.values()))
  538. else:
  539. txn.execute(sql)
  540. return [r[0] for r in txn]
  541. def _simple_select_onecol(self, table, keyvalues, retcol,
  542. desc="_simple_select_onecol"):
  543. """Executes a SELECT query on the named table, which returns a list
  544. comprising of the values of the named column from the selected rows.
  545. Args:
  546. table (str): table name
  547. keyvalues (dict|None): column names and values to select the rows with
  548. retcol (str): column whos value we wish to retrieve.
  549. Returns:
  550. Deferred: Results in a list
  551. """
  552. return self.runInteraction(
  553. desc,
  554. self._simple_select_onecol_txn,
  555. table, keyvalues, retcol
  556. )
  557. def _simple_select_list(self, table, keyvalues, retcols,
  558. desc="_simple_select_list"):
  559. """Executes a SELECT query on the named table, which may return zero or
  560. more rows, returning the result as a list of dicts.
  561. Args:
  562. table (str): the table name
  563. keyvalues (dict[str, Any] | None):
  564. column names and values to select the rows with, or None to not
  565. apply a WHERE clause.
  566. retcols (iterable[str]): the names of the columns to return
  567. Returns:
  568. defer.Deferred: resolves to list[dict[str, Any]]
  569. """
  570. return self.runInteraction(
  571. desc,
  572. self._simple_select_list_txn,
  573. table, keyvalues, retcols
  574. )
  575. @classmethod
  576. def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
  577. """Executes a SELECT query on the named table, which may return zero or
  578. more rows, returning the result as a list of dicts.
  579. Args:
  580. txn : Transaction object
  581. table (str): the table name
  582. keyvalues (dict[str, T] | None):
  583. column names and values to select the rows with, or None to not
  584. apply a WHERE clause.
  585. retcols (iterable[str]): the names of the columns to return
  586. """
  587. if keyvalues:
  588. sql = "SELECT %s FROM %s WHERE %s" % (
  589. ", ".join(retcols),
  590. table,
  591. " AND ".join("%s = ?" % (k, ) for k in keyvalues)
  592. )
  593. txn.execute(sql, list(keyvalues.values()))
  594. else:
  595. sql = "SELECT %s FROM %s" % (
  596. ", ".join(retcols),
  597. table
  598. )
  599. txn.execute(sql)
  600. return cls.cursor_to_dict(txn)
  601. @defer.inlineCallbacks
  602. def _simple_select_many_batch(self, table, column, iterable, retcols,
  603. keyvalues={}, desc="_simple_select_many_batch",
  604. batch_size=100):
  605. """Executes a SELECT query on the named table, which may return zero or
  606. more rows, returning the result as a list of dicts.
  607. Filters rows by if value of `column` is in `iterable`.
  608. Args:
  609. table : string giving the table name
  610. column : column name to test for inclusion against `iterable`
  611. iterable : list
  612. keyvalues : dict of column names and values to select the rows with
  613. retcols : list of strings giving the names of the columns to return
  614. """
  615. results = []
  616. if not iterable:
  617. defer.returnValue(results)
  618. # iterables can not be sliced, so convert it to a list first
  619. it_list = list(iterable)
  620. chunks = [
  621. it_list[i:i + batch_size]
  622. for i in range(0, len(it_list), batch_size)
  623. ]
  624. for chunk in chunks:
  625. rows = yield self.runInteraction(
  626. desc,
  627. self._simple_select_many_txn,
  628. table, column, chunk, keyvalues, retcols
  629. )
  630. results.extend(rows)
  631. defer.returnValue(results)
  632. @classmethod
  633. def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
  634. """Executes a SELECT query on the named table, which may return zero or
  635. more rows, returning the result as a list of dicts.
  636. Filters rows by if value of `column` is in `iterable`.
  637. Args:
  638. txn : Transaction object
  639. table : string giving the table name
  640. column : column name to test for inclusion against `iterable`
  641. iterable : list
  642. keyvalues : dict of column names and values to select the rows with
  643. retcols : list of strings giving the names of the columns to return
  644. """
  645. if not iterable:
  646. return []
  647. sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
  648. clauses = []
  649. values = []
  650. clauses.append(
  651. "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
  652. )
  653. values.extend(iterable)
  654. for key, value in iteritems(keyvalues):
  655. clauses.append("%s = ?" % (key,))
  656. values.append(value)
  657. if clauses:
  658. sql = "%s WHERE %s" % (
  659. sql,
  660. " AND ".join(clauses),
  661. )
  662. txn.execute(sql, values)
  663. return cls.cursor_to_dict(txn)
  664. def _simple_update(self, table, keyvalues, updatevalues, desc):
  665. return self.runInteraction(
  666. desc,
  667. self._simple_update_txn,
  668. table, keyvalues, updatevalues,
  669. )
  670. @staticmethod
  671. def _simple_update_txn(txn, table, keyvalues, updatevalues):
  672. if keyvalues:
  673. where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
  674. else:
  675. where = ""
  676. update_sql = "UPDATE %s SET %s %s" % (
  677. table,
  678. ", ".join("%s = ?" % (k,) for k in updatevalues),
  679. where,
  680. )
  681. txn.execute(
  682. update_sql,
  683. list(updatevalues.values()) + list(keyvalues.values())
  684. )
  685. return txn.rowcount
  686. def _simple_update_one(self, table, keyvalues, updatevalues,
  687. desc="_simple_update_one"):
  688. """Executes an UPDATE query on the named table, setting new values for
  689. columns in a row matching the key values.
  690. Args:
  691. table : string giving the table name
  692. keyvalues : dict of column names and values to select the row with
  693. updatevalues : dict giving column names and values to update
  694. retcols : optional list of column names to return
  695. If present, retcols gives a list of column names on which to perform
  696. a SELECT statement *before* performing the UPDATE statement. The values
  697. of these will be returned in a dict.
  698. These are performed within the same transaction, allowing an atomic
  699. get-and-set. This can be used to implement compare-and-set by putting
  700. the update column in the 'keyvalues' dict as well.
  701. """
  702. return self.runInteraction(
  703. desc,
  704. self._simple_update_one_txn,
  705. table, keyvalues, updatevalues,
  706. )
  707. @classmethod
  708. def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
  709. rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
  710. if rowcount == 0:
  711. raise StoreError(404, "No row found (%s)" % (table,))
  712. if rowcount > 1:
  713. raise StoreError(500, "More than one row matched (%s)" % (table,))
  714. @staticmethod
  715. def _simple_select_one_txn(txn, table, keyvalues, retcols,
  716. allow_none=False):
  717. select_sql = "SELECT %s FROM %s WHERE %s" % (
  718. ", ".join(retcols),
  719. table,
  720. " AND ".join("%s = ?" % (k,) for k in keyvalues)
  721. )
  722. txn.execute(select_sql, list(keyvalues.values()))
  723. row = txn.fetchone()
  724. if not row:
  725. if allow_none:
  726. return None
  727. raise StoreError(404, "No row found (%s)" % (table,))
  728. if txn.rowcount > 1:
  729. raise StoreError(500, "More than one row matched (%s)" % (table,))
  730. return dict(zip(retcols, row))
  731. def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
  732. """Executes a DELETE query on the named table, expecting to delete a
  733. single row.
  734. Args:
  735. table : string giving the table name
  736. keyvalues : dict of column names and values to select the row with
  737. """
  738. return self.runInteraction(
  739. desc, self._simple_delete_one_txn, table, keyvalues
  740. )
  741. @staticmethod
  742. def _simple_delete_one_txn(txn, table, keyvalues):
  743. """Executes a DELETE query on the named table, expecting to delete a
  744. single row.
  745. Args:
  746. table : string giving the table name
  747. keyvalues : dict of column names and values to select the row with
  748. """
  749. sql = "DELETE FROM %s WHERE %s" % (
  750. table,
  751. " AND ".join("%s = ?" % (k, ) for k in keyvalues)
  752. )
  753. txn.execute(sql, list(keyvalues.values()))
  754. if txn.rowcount == 0:
  755. raise StoreError(404, "No row found (%s)" % (table,))
  756. if txn.rowcount > 1:
  757. raise StoreError(500, "More than one row matched (%s)" % (table,))
  758. def _simple_delete(self, table, keyvalues, desc):
  759. return self.runInteraction(
  760. desc, self._simple_delete_txn, table, keyvalues
  761. )
  762. @staticmethod
  763. def _simple_delete_txn(txn, table, keyvalues):
  764. sql = "DELETE FROM %s WHERE %s" % (
  765. table,
  766. " AND ".join("%s = ?" % (k, ) for k in keyvalues)
  767. )
  768. return txn.execute(sql, list(keyvalues.values()))
  769. def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
  770. return self.runInteraction(
  771. desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
  772. )
  773. @staticmethod
  774. def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
  775. """Executes a DELETE query on the named table.
  776. Filters rows by if value of `column` is in `iterable`.
  777. Args:
  778. txn : Transaction object
  779. table : string giving the table name
  780. column : column name to test for inclusion against `iterable`
  781. iterable : list
  782. keyvalues : dict of column names and values to select the rows with
  783. """
  784. if not iterable:
  785. return
  786. sql = "DELETE FROM %s" % table
  787. clauses = []
  788. values = []
  789. clauses.append(
  790. "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
  791. )
  792. values.extend(iterable)
  793. for key, value in iteritems(keyvalues):
  794. clauses.append("%s = ?" % (key,))
  795. values.append(value)
  796. if clauses:
  797. sql = "%s WHERE %s" % (
  798. sql,
  799. " AND ".join(clauses),
  800. )
  801. return txn.execute(sql, values)
  802. def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
  803. max_value, limit=100000):
  804. # Fetch a mapping of room_id -> max stream position for "recent" rooms.
  805. # It doesn't really matter how many we get, the StreamChangeCache will
  806. # do the right thing to ensure it respects the max size of cache.
  807. sql = (
  808. "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
  809. " WHERE %(stream)s > ? - %(limit)s"
  810. " GROUP BY %(entity)s"
  811. ) % {
  812. "table": table,
  813. "entity": entity_column,
  814. "stream": stream_column,
  815. "limit": limit,
  816. }
  817. sql = self.database_engine.convert_param_style(sql)
  818. txn = db_conn.cursor()
  819. txn.execute(sql, (int(max_value),))
  820. cache = {
  821. row[0]: int(row[1])
  822. for row in txn
  823. }
  824. txn.close()
  825. if cache:
  826. min_val = min(itervalues(cache))
  827. else:
  828. min_val = max_value
  829. return cache, min_val
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
    """Invalidates the cache and adds it to the cache stream so slaves
    will know to invalidate their caches.

    This should only be used to invalidate caches where slaves won't
    otherwise know from other replication streams that the cache should
    be invalidated.

    Args:
        txn : Transaction object; must provide `call_after` and
            `call_on_exception`
        cache_func : the cached function; its `invalidate` attribute is
            scheduled with `keys`, and its `__name__` is what is written
            to the stream table
        keys : the cache key to invalidate; stored in the stream row as
            `list(keys)`
    """
    # Schedule the local invalidation through the transaction's call_after
    # hook (presumably run once the transaction has finished successfully;
    # confirm against the transaction class).
    txn.call_after(cache_func.invalidate, keys)

    # The stream row is only written when the engine is Postgres.
    if isinstance(self.database_engine, PostgresEngine):
        # get_next() returns a context manager which is designed to wrap
        # the transaction. However, we want to only get an ID when we want
        # to use it, here, so we need to call __enter__ manually, and have
        # __exit__ called after the transaction finishes.
        ctx = self._cache_id_gen.get_next()
        stream_id = ctx.__enter__()
        # __exit__ is registered on both the exception hook and the
        # call_after hook so the ID is released whichever way the
        # transaction ends.
        txn.call_on_exception(ctx.__exit__, None, None, None)
        txn.call_after(ctx.__exit__, None, None, None)
        # Tell the notifier that there is new replication data.
        txn.call_after(self.hs.get_notifier().on_new_replication_data)

        self._simple_insert_txn(
            txn,
            table="cache_invalidation_stream",
            values={
                "stream_id": stream_id,
                "cache_func": cache_func.__name__,
                "keys": list(keys),
                "invalidation_ts": self.clock.time_msec(),
            }
        )
  858. def get_all_updated_caches(self, last_id, current_id, limit):
  859. if last_id == current_id:
  860. return defer.succeed([])
  861. def get_all_updated_caches_txn(txn):
  862. # We purposefully don't bound by the current token, as we want to
  863. # send across cache invalidations as quickly as possible. Cache
  864. # invalidations are idempotent, so duplicates are fine.
  865. sql = (
  866. "SELECT stream_id, cache_func, keys, invalidation_ts"
  867. " FROM cache_invalidation_stream"
  868. " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
  869. )
  870. txn.execute(sql, (last_id, limit,))
  871. return txn.fetchall()
  872. return self.runInteraction(
  873. "get_all_updated_caches", get_all_updated_caches_txn
  874. )
  875. def get_cache_stream_token(self):
  876. if self._cache_id_gen:
  877. return self._cache_id_gen.get_current_token()
  878. else:
  879. return 0
  880. def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols,
  881. desc="_simple_select_list_paginate"):
  882. """Executes a SELECT query on the named table with start and limit,
  883. of row numbers, which may return zero or number of rows from start to limit,
  884. returning the result as a list of dicts.
  885. Args:
  886. table (str): the table name
  887. keyvalues (dict[str, Any] | None):
  888. column names and values to select the rows with, or None to not
  889. apply a WHERE clause.
  890. retcols (iterable[str]): the names of the columns to return
  891. order (str): order the select by this column
  892. start (int): start number to begin the query from
  893. limit (int): number of rows to reterive
  894. Returns:
  895. defer.Deferred: resolves to list[dict[str, Any]]
  896. """
  897. return self.runInteraction(
  898. desc,
  899. self._simple_select_list_paginate_txn,
  900. table, keyvalues, pagevalues, retcols
  901. )
  902. @classmethod
  903. def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols):
  904. """Executes a SELECT query on the named table with start and limit,
  905. of row numbers, which may return zero or number of rows from start to limit,
  906. returning the result as a list of dicts.
  907. Args:
  908. txn : Transaction object
  909. table (str): the table name
  910. keyvalues (dict[str, T] | None):
  911. column names and values to select the rows with, or None to not
  912. apply a WHERE clause.
  913. pagevalues ([]):
  914. order (str): order the select by this column
  915. start (int): start number to begin the query from
  916. limit (int): number of rows to reterive
  917. retcols (iterable[str]): the names of the columns to return
  918. Returns:
  919. defer.Deferred: resolves to list[dict[str, Any]]
  920. """
  921. if keyvalues:
  922. sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % (
  923. ", ".join(retcols),
  924. table,
  925. " AND ".join("%s = ?" % (k,) for k in keyvalues),
  926. " ? ASC LIMIT ? OFFSET ?"
  927. )
  928. txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
  929. else:
  930. sql = "SELECT %s FROM %s ORDER BY %s" % (
  931. ", ".join(retcols),
  932. table,
  933. " ? ASC LIMIT ? OFFSET ?"
  934. )
  935. txn.execute(sql, pagevalues)
  936. return cls.cursor_to_dict(txn)
  937. @defer.inlineCallbacks
  938. def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols,
  939. desc="get_user_list_paginate"):
  940. """Get a list of users from start row to a limit number of rows. This will
  941. return a json object with users and total number of users in users list.
  942. Args:
  943. table (str): the table name
  944. keyvalues (dict[str, Any] | None):
  945. column names and values to select the rows with, or None to not
  946. apply a WHERE clause.
  947. pagevalues ([]):
  948. order (str): order the select by this column
  949. start (int): start number to begin the query from
  950. limit (int): number of rows to reterive
  951. retcols (iterable[str]): the names of the columns to return
  952. Returns:
  953. defer.Deferred: resolves to json object {list[dict[str, Any]], count}
  954. """
  955. users = yield self.runInteraction(
  956. desc,
  957. self._simple_select_list_paginate_txn,
  958. table, keyvalues, pagevalues, retcols
  959. )
  960. count = yield self.runInteraction(
  961. desc,
  962. self.get_user_count_txn
  963. )
  964. retval = {
  965. "users": users,
  966. "total": count
  967. }
  968. defer.returnValue(retval)
  969. def get_user_count_txn(self, txn):
  970. """Get a total number of registered users in the users list.
  971. Args:
  972. txn : Transaction object
  973. Returns:
  974. int : number of users
  975. """
  976. sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
  977. txn.execute(sql_count)
  978. return txn.fetchone()[0]
  979. def _simple_search_list(self, table, term, col, retcols,
  980. desc="_simple_search_list"):
  981. """Executes a SELECT query on the named table, which may return zero or
  982. more rows, returning the result as a list of dicts.
  983. Args:
  984. table (str): the table name
  985. term (str | None):
  986. term for searching the table matched to a column.
  987. col (str): column to query term should be matched to
  988. retcols (iterable[str]): the names of the columns to return
  989. Returns:
  990. defer.Deferred: resolves to list[dict[str, Any]] or None
  991. """
  992. return self.runInteraction(
  993. desc,
  994. self._simple_search_list_txn,
  995. table, term, col, retcols
  996. )
  997. @classmethod
  998. def _simple_search_list_txn(cls, txn, table, term, col, retcols):
  999. """Executes a SELECT query on the named table, which may return zero or
  1000. more rows, returning the result as a list of dicts.
  1001. Args:
  1002. txn : Transaction object
  1003. table (str): the table name
  1004. term (str | None):
  1005. term for searching the table matched to a column.
  1006. col (str): column to query term should be matched to
  1007. retcols (iterable[str]): the names of the columns to return
  1008. Returns:
  1009. defer.Deferred: resolves to list[dict[str, Any]] or None
  1010. """
  1011. if term:
  1012. sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (
  1013. ", ".join(retcols),
  1014. table,
  1015. col
  1016. )
  1017. termvalues = ["%%" + term + "%%"]
  1018. txn.execute(sql, termvalues)
  1019. else:
  1020. return 0
  1021. return cls.cursor_to_dict(txn)
  1022. class _RollbackButIsFineException(Exception):
  1023. """ This exception is used to rollback a transaction without implying
  1024. something went wrong.
  1025. """
  1026. pass
  1027. def db_to_json(db_content):
  1028. """
  1029. Take some data from a database row and return a JSON-decoded object.
  1030. Args:
  1031. db_content (memoryview|buffer|bytes|bytearray|unicode)
  1032. """
  1033. # psycopg2 on Python 3 returns memoryview objects, which we need to
  1034. # cast to bytes to decode
  1035. if isinstance(db_content, memoryview):
  1036. db_content = db_content.tobytes()
  1037. # psycopg2 on Python 2 returns buffer objects, which we need to cast to
  1038. # bytes to decode
  1039. if PY2 and isinstance(db_content, builtins.buffer):
  1040. db_content = bytes(db_content)
  1041. # Decode it to a Unicode string before feeding it to json.loads, so we
  1042. # consistenty get a Unicode-containing object out.
  1043. if isinstance(db_content, (bytes, bytearray)):
  1044. db_content = db_content.decode('utf8')
  1045. try:
  1046. return json.loads(db_content)
  1047. except Exception:
  1048. logging.warning("Tried to decode '%r' as JSON and failed", db_content)
  1049. raise