_base.py

# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
from synapse.util.caches import intern_dict
import synapse.metrics

from twisted.internet import defer

import sys
import time
import threading
import os


CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))

logger = logging.getLogger(__name__)

sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")

metrics = synapse.metrics.get_metrics_for("synapse.storage")

sql_scheduling_timer = metrics.register_distribution("schedule_time")

sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])


class LoggingTransaction(object):
    """An object that almost-transparently proxies for the 'txn' object
    passed to the constructor. Adds logging and metrics to the .execute()
    method."""
    __slots__ = ["txn", "name", "database_engine", "after_callbacks"]

    def __init__(self, txn, name, database_engine, after_callbacks):
        object.__setattr__(self, "txn", txn)
        object.__setattr__(self, "name", name)
        object.__setattr__(self, "database_engine", database_engine)
        object.__setattr__(self, "after_callbacks", after_callbacks)

    def call_after(self, callback, *args):
        """Call the given callback on the main twisted thread after the
        transaction has finished. Used to invalidate the caches on the
        correct thread.
        """
        self.after_callbacks.append((callback, args))

    def __getattr__(self, name):
        return getattr(self.txn, name)

    def __setattr__(self, name, value):
        setattr(self.txn, name, value)

    def execute(self, sql, *args):
        self._do_execute(self.txn.execute, sql, *args)

    def executemany(self, sql, *args):
        self._do_execute(self.txn.executemany, sql, *args)

    def _do_execute(self, func, sql, *args):
        # TODO(paul): Maybe use 'info' and 'debug' for values?
        sql_logger.debug("[SQL] {%s} %s", self.name, sql)

        sql = self.database_engine.convert_param_style(sql)

        if args:
            try:
                sql_logger.debug(
                    "[SQL values] {%s} %r",
                    self.name, args[0]
                )
            except:
                # Don't let logging failures stop SQL from working
                pass

        start = time.time() * 1000

        try:
            return func(
                sql, *args
            )
        except Exception as e:
            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
            raise
        finally:
            msecs = (time.time() * 1000) - start
            sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
            sql_query_timer.inc_by(msecs, sql.split()[0])
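
# A minimal usage sketch (not part of the original module): store methods that
# run inside a transaction typically use LoggingTransaction.call_after() to
# defer cache invalidation until the transaction has committed. The table,
# cache and method names below are hypothetical.
#
#     def _mark_read_txn(self, txn, user_id, room_id):
#         txn.execute(
#             "UPDATE receipts SET read = 1 WHERE user_id = ? AND room_id = ?",
#             (user_id, room_id),
#         )
#         # Invalidate on the main twisted thread, only once the txn commits.
#         txn.call_after(self.get_unread_count.invalidate, (user_id, room_id))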


class PerformanceCounters(object):
    def __init__(self):
        self.current_counters = {}
        self.previous_counters = {}

    def update(self, key, start_time, end_time=None):
        if end_time is None:
            end_time = time.time() * 1000
        duration = end_time - start_time
        count, cum_time = self.current_counters.get(key, (0, 0))
        count += 1
        cum_time += duration
        self.current_counters[key] = (count, cum_time)
        return end_time

    def interval(self, interval_duration, limit=3):
        counters = []
        for name, (count, cum_time) in self.current_counters.items():
            prev_count, prev_time = self.previous_counters.get(name, (0, 0))
            counters.append((
                (cum_time - prev_time) / interval_duration,
                count - prev_count,
                name
            ))

        self.previous_counters = dict(self.current_counters)

        counters.sort(reverse=True)

        top_n_counters = ", ".join(
            "%s(%d): %.3f%%" % (name, count, 100 * ratio)
            for ratio, count, name in counters[:limit]
        )

        return top_n_counters
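
# Illustrative sketch (assumed numbers, not from the original code): update()
# accumulates (count, cumulative_ms) per key, and interval() reports the top
# entries as a share of the elapsed interval.
#
#     counters = PerformanceCounters()
#     start = time.time() * 1000
#     # ... do roughly 50ms of work ...
#     counters.update("get_event", start)
#     # Over a 10000ms interval this might render as:
#     #   "get_event(1): 0.500%"
#     print counters.interval(10000, limit=3)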


class SQLBaseStore(object):
    _TXN_ID = 0

    def __init__(self, hs):
        self.hs = hs
        self._clock = hs.get_clock()
        self._db_pool = hs.get_db_pool()

        self._previous_txn_total_time = 0
        self._current_txn_total_time = 0
        self._previous_loop_ts = 0

        # TODO(paul): These can eventually be removed once the metrics code
        #   is running in mainline, and we have some nice monitoring frontends
        #   to watch it
        self._txn_perf_counters = PerformanceCounters()
        self._get_event_counters = PerformanceCounters()

        self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                      max_entries=hs.config.event_cache_size)

        self._state_group_cache = DictionaryCache(
            "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR
        )

        self._event_fetch_lock = threading.Condition()
        self._event_fetch_list = []
        self._event_fetch_ongoing = 0

        self._pending_ds = []

        self.database_engine = hs.database_engine

    def start_profiling(self):
        self._previous_loop_ts = self._clock.time_msec()

        def loop():
            curr = self._current_txn_total_time
            prev = self._previous_txn_total_time
            self._previous_txn_total_time = curr

            time_now = self._clock.time_msec()
            time_then = self._previous_loop_ts
            self._previous_loop_ts = time_now

            ratio = (curr - prev) / (time_now - time_then)

            top_three_counters = self._txn_perf_counters.interval(
                time_now - time_then, limit=3
            )

            top_3_event_counters = self._get_event_counters.interval(
                time_now - time_then, limit=3
            )

            perf_logger.info(
                "Total database time: %.3f%% {%s} {%s}",
                ratio * 100, top_three_counters, top_3_event_counters
            )

        self._clock.looping_call(loop, 10000)

    def _new_transaction(self, conn, desc, after_callbacks, logging_context,
                         func, *args, **kwargs):
        start = time.time() * 1000
        txn_id = self._TXN_ID

        # We don't really need these to be unique, so lets stop it from
        # growing really large.
        self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)

        name = "%s-%x" % (desc, txn_id, )

        transaction_logger.debug("[TXN START] {%s}", name)

        try:
            i = 0
            N = 5
            while True:
                try:
                    txn = conn.cursor()
                    txn = LoggingTransaction(
                        txn, name, self.database_engine, after_callbacks
                    )
                    r = func(txn, *args, **kwargs)
                    conn.commit()
                    return r
                except self.database_engine.module.OperationalError as e:
                    # This can happen if the database disappears mid
                    # transaction.
                    logger.warn(
                        "[TXN OPERROR] {%s} %s %d/%d",
                        name, e, i, N
                    )
                    if i < N:
                        i += 1
                        try:
                            conn.rollback()
                        except self.database_engine.module.Error as e1:
                            logger.warn(
                                "[TXN EROLL] {%s} %s",
                                name, e1,
                            )
                        continue
                    raise
                except self.database_engine.module.DatabaseError as e:
                    if self.database_engine.is_deadlock(e):
                        logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
                        if i < N:
                            i += 1
                            try:
                                conn.rollback()
                            except self.database_engine.module.Error as e1:
                                logger.warn(
                                    "[TXN EROLL] {%s} %s",
                                    name, e1,
                                )
                            continue
                    raise
        except Exception as e:
            logger.debug("[TXN FAIL] {%s} %s", name, e)
            raise
        finally:
            end = time.time() * 1000
            duration = end - start

            if logging_context is not None:
                logging_context.add_database_transaction(duration)

            transaction_logger.debug("[TXN END] {%s} %f", name, duration)

            self._current_txn_total_time += duration
            self._txn_perf_counters.update(desc, start, end)
            sql_txn_timer.inc_by(duration, desc)

    @defer.inlineCallbacks
    def runInteraction(self, desc, func, *args, **kwargs):
        """Wraps the .runInteraction() method on the underlying db_pool."""
        current_context = LoggingContext.current_context()

        start_time = time.time() * 1000

        after_callbacks = []

        def inner_func(conn, *args, **kwargs):
            with LoggingContext("runInteraction") as context:
                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)

                if self.database_engine.is_connection_closed(conn):
                    logger.debug("Reconnecting closed database connection")
                    conn.reconnect()

                current_context.copy_to(context)
                return self._new_transaction(
                    conn, desc, after_callbacks, current_context,
                    func, *args, **kwargs
                )

        with PreserveLoggingContext():
            result = yield self._db_pool.runWithConnection(
                inner_func, *args, **kwargs
            )

        for after_callback, after_args in after_callbacks:
            after_callback(*after_args)
        defer.returnValue(result)
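
    # Hypothetical usage sketch (the table and method names below are
    # illustrative, not from this file): callers pass a function that receives
    # a LoggingTransaction and runs entirely inside one database transaction.
    #
    #     def get_user_version(self, user_id):
    #         def txn_func(txn):
    #             txn.execute(
    #                 "SELECT version FROM users WHERE user_id = ?", (user_id,)
    #             )
    #             row = txn.fetchone()
    #             return row[0] if row else None
    #         return self.runInteraction("get_user_version", txn_func)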

    @defer.inlineCallbacks
    def runWithConnection(self, func, *args, **kwargs):
        """Wraps the .runWithConnection() method on the underlying db_pool."""
        current_context = LoggingContext.current_context()

        start_time = time.time() * 1000

        def inner_func(conn, *args, **kwargs):
            with LoggingContext("runWithConnection") as context:
                sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)

                if self.database_engine.is_connection_closed(conn):
                    logger.debug("Reconnecting closed database connection")
                    conn.reconnect()

                current_context.copy_to(context)

                return func(conn, *args, **kwargs)

        with PreserveLoggingContext():
            result = yield self._db_pool.runWithConnection(
                inner_func, *args, **kwargs
            )

        defer.returnValue(result)

    @staticmethod
    def cursor_to_dict(cursor):
        """Converts a SQL cursor into a list of dicts.

        Args:
            cursor : The DBAPI cursor which has executed a query.
        Returns:
            A list of dicts where the key is the column header.
        """
        col_headers = list(column[0] for column in cursor.description)
        results = list(
            intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall()
        )
        return results
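
    # Worked example (hypothetical data): for a cursor that has executed
    # "SELECT user_id, version FROM users", cursor_to_dict() would return
    # something like:
    #
    #     [{"user_id": "@alice:example.com", "version": 3},
    #      {"user_id": "@bob:example.com", "version": 1}]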

    def _execute(self, desc, decoder, query, *args):
        """Runs a single query for a result set.

        Args:
            decoder - The function which can resolve the cursor results to
                something meaningful.
            query - The query string to execute
            *args - Query args.
        Returns:
            The result of decoder(results)
        """
        def interaction(txn):
            txn.execute(query, args)
            if decoder:
                return decoder(txn)
            else:
                return txn.fetchall()

        return self.runInteraction(desc, interaction)

    # "Simple" SQL API methods that operate on a single table with no JOINs,
    # no complex WHERE clauses, just a dict of values for columns.

    @defer.inlineCallbacks
    def _simple_insert(self, table, values, or_ignore=False,
                       desc="_simple_insert"):
        """Executes an INSERT query on the named table.

        Args:
            table : string giving the table name
            values : dict of new column names and values for them
        """
        try:
            yield self.runInteraction(
                desc,
                self._simple_insert_txn, table, values,
            )
        except self.database_engine.module.IntegrityError:
            # We have to handle the or_ignore flag at this layer, since we
            # can't reuse a cursor after we receive an error from the db.
            if not or_ignore:
                raise

    @staticmethod
    def _simple_insert_txn(txn, table, values):
        keys, vals = zip(*values.items())

        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
            table,
            ", ".join(k for k in keys),
            ", ".join("?" for _ in keys)
        )

        txn.execute(sql, vals)
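
    # Sketch of the SQL this generates (hypothetical table and values):
    # _simple_insert_txn(txn, "profiles", {"user_id": "@a:hs", "displayname": "A"})
    # executes roughly:
    #
    #     INSERT INTO profiles (user_id, displayname) VALUES(?, ?)
    #
    # with ("@a:hs", "A") as the bound parameters; the '?' placeholders are
    # rewritten by convert_param_style for the engine in use.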

    @staticmethod
    def _simple_insert_many_txn(txn, table, values):
        if not values:
            return

        # This is a *slight* abomination to get a list of tuples of key names
        # and a list of tuples of value names.
        #
        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
        #
        # The sort is to ensure that we don't rely on dictionary iteration
        # order.
        keys, vals = zip(*[
            zip(
                *(sorted(i.items(), key=lambda kv: kv[0]))
            )
            for i in values
            if i
        ])

        for k in keys:
            if k != keys[0]:
                raise RuntimeError(
                    "All items must have the same keys"
                )

        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
            table,
            ", ".join(k for k in keys[0]),
            ", ".join("?" for _ in keys[0])
        )

        txn.executemany(sql, vals)

    def _simple_upsert(self, table, keyvalues, values,
                       insertion_values={}, desc="_simple_upsert", lock=True):
        """
        Args:
            table (str): The table to upsert into
            keyvalues (dict): The unique key columns and their new values
            values (dict): The nonunique columns and their new values
            insertion_values (dict): key/values to use when inserting
        Returns:
            Deferred(bool): True if a new entry was created, False if an
                existing one was updated.
        """
        return self.runInteraction(
            desc,
            self._simple_upsert_txn, table, keyvalues, values, insertion_values,
            lock
        )

    def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
                           lock=True):
        # We need to lock the table :(, unless we're *really* careful
        if lock:
            self.database_engine.lock_table(txn, table)

        # Try to update
        sql = "UPDATE %s SET %s WHERE %s" % (
            table,
            ", ".join("%s = ?" % (k,) for k in values),
            " AND ".join("%s = ?" % (k,) for k in keyvalues)
        )
        sqlargs = values.values() + keyvalues.values()
        logger.debug(
            "[SQL] %s Args=%s",
            sql, sqlargs,
        )

        txn.execute(sql, sqlargs)

        if txn.rowcount == 0:
            # We didn't update any rows, so insert a new one
            allvalues = {}
            allvalues.update(keyvalues)
            allvalues.update(values)
            allvalues.update(insertion_values)

            sql = "INSERT INTO %s (%s) VALUES (%s)" % (
                table,
                ", ".join(k for k in allvalues),
                ", ".join("?" for _ in allvalues)
            )
            logger.debug(
                "[SQL] %s Args=%s",
                sql, allvalues.values(),
            )
            txn.execute(sql, allvalues.values())

            return True
        else:
            return False
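
    # Usage sketch (hypothetical table and columns): upsert an "enabled" flag
    # keyed on (user_name, rule_id). The method first issues an UPDATE; only
    # if no row matched does it INSERT keyvalues + values + insertion_values.
    #
    #     yield self._simple_upsert(
    #         table="push_rules_enable",
    #         keyvalues={"user_name": user_id, "rule_id": rule_id},
    #         values={"enabled": 1},
    #     )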

    def _simple_select_one(self, table, keyvalues, retcols,
                           allow_none=False, desc="_simple_select_one"):
        """Executes a SELECT query on the named table, which is expected to
        return a single row, returning multiple columns from it.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            retcols : list of strings giving the names of the columns to return
            allow_none : If true, return None instead of failing if the SELECT
                statement returns no rows
        """
        return self.runInteraction(
            desc,
            self._simple_select_one_txn,
            table, keyvalues, retcols, allow_none,
        )

    def _simple_select_one_onecol(self, table, keyvalues, retcol,
                                  allow_none=False,
                                  desc="_simple_select_one_onecol"):
        """Executes a SELECT query on the named table, which is expected to
        return a single row, returning a single column from it.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            retcol : string giving the name of the column to return
        """
        return self.runInteraction(
            desc,
            self._simple_select_one_onecol_txn,
            table, keyvalues, retcol, allow_none=allow_none,
        )

    @classmethod
    def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
                                      allow_none=False):
        ret = cls._simple_select_onecol_txn(
            txn,
            table=table,
            keyvalues=keyvalues,
            retcol=retcol,
        )

        if ret:
            return ret[0]
        else:
            if allow_none:
                return None
            else:
                raise StoreError(404, "No row found")

    @staticmethod
    def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
        sql = (
            "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
        ) % {
            "retcol": retcol,
            "table": table,
            "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
        }

        txn.execute(sql, keyvalues.values())

        return [r[0] for r in txn.fetchall()]

    def _simple_select_onecol(self, table, keyvalues, retcol,
                              desc="_simple_select_onecol"):
        """Executes a SELECT query on the named table, which returns a list
        comprising the values of the named column from the selected rows.

        Args:
            table (str): table name
            keyvalues (dict): column names and values to select the rows with
            retcol (str): column whose value we wish to retrieve.

        Returns:
            Deferred: Results in a list
        """
        return self.runInteraction(
            desc,
            self._simple_select_onecol_txn,
            table, keyvalues, retcol
        )

    def _simple_select_list(self, table, keyvalues, retcols,
                            desc="_simple_select_list"):
        """Executes a SELECT query on the named table, which may return zero or
        more rows, returning the result as a list of dicts.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the rows with,
                or None to not apply a WHERE clause.
            retcols : list of strings giving the names of the columns to return
        """
        return self.runInteraction(
            desc,
            self._simple_select_list_txn,
            table, keyvalues, retcols
        )

    @classmethod
    def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
        """Executes a SELECT query on the named table, which may return zero or
        more rows, returning the result as a list of dicts.

        Args:
            txn : Transaction object
            table : string giving the table name
            keyvalues : dict of column names and values to select the rows with
            retcols : list of strings giving the names of the columns to return
        """
        if keyvalues:
            sql = "SELECT %s FROM %s WHERE %s" % (
                ", ".join(retcols),
                table,
                " AND ".join("%s = ?" % (k, ) for k in keyvalues)
            )
            txn.execute(sql, keyvalues.values())
        else:
            sql = "SELECT %s FROM %s" % (
                ", ".join(retcols),
                table
            )
            txn.execute(sql)

        return cls.cursor_to_dict(txn)
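
    # Sketch of a call (hypothetical table and columns):
    #
    #     rows = yield self._simple_select_list(
    #         table="room_aliases",
    #         keyvalues={"room_id": room_id},
    #         retcols=["room_alias"],
    #     )
    #
    # which runs "SELECT room_alias FROM room_aliases WHERE room_id = ?" and
    # resolves to a list of dicts, e.g. [{"room_alias": "#foo:example.com"}].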

    @defer.inlineCallbacks
    def _simple_select_many_batch(self, table, column, iterable, retcols,
                                  keyvalues={}, desc="_simple_select_many_batch",
                                  batch_size=100):
        """Executes a SELECT query on the named table, which may return zero or
        more rows, returning the result as a list of dicts.

        Filters rows by whether the value of `column` is in `iterable`.

        Args:
            table : string giving the table name
            column : column name to test for inclusion against `iterable`
            iterable : list
            keyvalues : dict of column names and values to select the rows with
            retcols : list of strings giving the names of the columns to return
        """
        results = []

        if not iterable:
            defer.returnValue(results)

        chunks = [
            iterable[i:i + batch_size]
            for i in xrange(0, len(iterable), batch_size)
        ]
        for chunk in chunks:
            rows = yield self.runInteraction(
                desc,
                self._simple_select_many_txn,
                table, column, chunk, keyvalues, retcols
            )

            results.extend(rows)

        defer.returnValue(results)
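
    # Batching sketch (assumed sizes): with batch_size=100, a 250-element
    # iterable is split into chunks of 100, 100 and 50, and each chunk becomes
    # one "... WHERE column IN (?, ?, ...)" query via _simple_select_many_txn;
    # the per-chunk results are concatenated before being returned.
    #
    #     rows = yield self._simple_select_many_batch(
    #         table="events",            # hypothetical table
    #         column="event_id",
    #         iterable=event_ids,        # e.g. a list of 250 event IDs
    #         retcols=["event_id", "type"],
    #     )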

    @classmethod
    def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
        """Executes a SELECT query on the named table, which may return zero or
        more rows, returning the result as a list of dicts.

        Filters rows by whether the value of `column` is in `iterable`.

        Args:
            txn : Transaction object
            table : string giving the table name
            column : column name to test for inclusion against `iterable`
            iterable : list
            keyvalues : dict of column names and values to select the rows with
            retcols : list of strings giving the names of the columns to return
        """
        if not iterable:
            return []

        sql = "SELECT %s FROM %s" % (", ".join(retcols), table)

        clauses = []
        values = []
        clauses.append(
            "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
        )
        values.extend(iterable)

        for key, value in keyvalues.items():
            clauses.append("%s = ?" % (key,))
            values.append(value)

        if clauses:
            sql = "%s WHERE %s" % (
                sql,
                " AND ".join(clauses),
            )

        txn.execute(sql, values)
        return cls.cursor_to_dict(txn)

    def _simple_update_one(self, table, keyvalues, updatevalues,
                           desc="_simple_update_one"):
        """Executes an UPDATE query on the named table, setting new values for
        columns in a row matching the key values.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
            updatevalues : dict giving column names and values to update
            retcols : optional list of column names to return

        If present, retcols gives a list of column names on which to perform
        a SELECT statement *before* performing the UPDATE statement. The values
        of these will be returned in a dict.

        These are performed within the same transaction, allowing an atomic
        get-and-set.  This can be used to implement compare-and-set by putting
        the update column in the 'keyvalues' dict as well.
        """
        return self.runInteraction(
            desc,
            self._simple_update_one_txn,
            table, keyvalues, updatevalues,
        )

    @staticmethod
    def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
        update_sql = "UPDATE %s SET %s WHERE %s" % (
            table,
            ", ".join("%s = ?" % (k,) for k in updatevalues),
            " AND ".join("%s = ?" % (k,) for k in keyvalues)
        )

        txn.execute(
            update_sql,
            updatevalues.values() + keyvalues.values()
        )

        if txn.rowcount == 0:
            raise StoreError(404, "No row found")
        if txn.rowcount > 1:
            raise StoreError(500, "More than one row matched")

    @staticmethod
    def _simple_select_one_txn(txn, table, keyvalues, retcols,
                               allow_none=False):
        select_sql = "SELECT %s FROM %s WHERE %s" % (
            ", ".join(retcols),
            table,
            " AND ".join("%s = ?" % (k,) for k in keyvalues)
        )

        txn.execute(select_sql, keyvalues.values())

        row = txn.fetchone()
        if not row:
            if allow_none:
                return None
            raise StoreError(404, "No row found")
        if txn.rowcount > 1:
            raise StoreError(500, "More than one row matched")

        return dict(zip(retcols, row))

    def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
        """Executes a DELETE query on the named table, expecting to delete a
        single row.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
        """
        return self.runInteraction(
            desc, self._simple_delete_one_txn, table, keyvalues
        )

    @staticmethod
    def _simple_delete_one_txn(txn, table, keyvalues):
        """Executes a DELETE query on the named table, expecting to delete a
        single row.

        Args:
            table : string giving the table name
            keyvalues : dict of column names and values to select the row with
        """
        sql = "DELETE FROM %s WHERE %s" % (
            table,
            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
        )

        txn.execute(sql, keyvalues.values())
        if txn.rowcount == 0:
            raise StoreError(404, "No row found")
        if txn.rowcount > 1:
            raise StoreError(500, "more than one row matched")

    @staticmethod
    def _simple_delete_txn(txn, table, keyvalues):
        sql = "DELETE FROM %s WHERE %s" % (
            table,
            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
        )

        return txn.execute(sql, keyvalues.values())

    def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
                        max_value):
        # Fetch a mapping of room_id -> max stream position for "recent" rooms.
        # It doesn't really matter how many we get, the StreamChangeCache will
        # do the right thing to ensure it respects the max size of cache.
        sql = (
            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
            " WHERE %(stream)s > ? - 100000"
            " GROUP BY %(entity)s"
        ) % {
            "table": table,
            "entity": entity_column,
            "stream": stream_column,
        }

        sql = self.database_engine.convert_param_style(sql)

        txn = db_conn.cursor()
        txn.execute(sql, (int(max_value),))
        rows = txn.fetchall()
        txn.close()

        cache = {
            row[0]: int(row[1])
            for row in rows
        }

        if cache:
            min_val = min(cache.values())
        else:
            min_val = max_value

        return cache, min_val
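
    # Illustrative result (hypothetical data): for table="events",
    # entity_column="room_id", stream_column="stream_ordering" and
    # max_value=5000, this might return
    #
    #     ({"!room1:hs": 4998, "!room2:hs": 4321}, 4321)
    #
    # i.e. a per-entity map of the latest stream position seen recently, plus
    # the smallest value in that map (or max_value if the map is empty), which
    # the caller typically uses to prime a StreamChangeCache.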


class _RollbackButIsFineException(Exception):
    """ This exception is used to rollback a transaction without implying
    something went wrong.
    """
    pass