database.py 83 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263226422652266226722682269227022712272227322742275227622772278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386
  1. # Copyright 2014-2016 OpenMarket Ltd
  2. # Copyright 2017-2018 New Vector Ltd
  3. # Copyright 2019 The Matrix.org Foundation C.I.C.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import inspect
  17. import logging
  18. import time
  19. import types
  20. from collections import defaultdict
  21. from sys import intern
  22. from time import monotonic as monotonic_time
  23. from typing import (
  24. TYPE_CHECKING,
  25. Any,
  26. Callable,
  27. Collection,
  28. Dict,
  29. Iterable,
  30. Iterator,
  31. List,
  32. Optional,
  33. Tuple,
  34. Type,
  35. TypeVar,
  36. cast,
  37. overload,
  38. )
  39. import attr
  40. from prometheus_client import Histogram
  41. from typing_extensions import Concatenate, Literal, ParamSpec
  42. from twisted.enterprise import adbapi
  43. from twisted.internet.interfaces import IReactorCore
  44. from synapse.api.errors import StoreError
  45. from synapse.config.database import DatabaseConnectionConfig
  46. from synapse.logging import opentracing
  47. from synapse.logging.context import (
  48. LoggingContext,
  49. current_context,
  50. make_deferred_yieldable,
  51. )
  52. from synapse.metrics import register_threadpool
  53. from synapse.metrics.background_process_metrics import run_as_background_process
  54. from synapse.storage.background_updates import BackgroundUpdater
  55. from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
  56. from synapse.storage.types import Connection, Cursor
  57. from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
  58. from synapse.util.iterutils import batch_iter
  59. if TYPE_CHECKING:
  60. from synapse.server import HomeServer
  61. # python 3 does not have a maximum int value
  62. MAX_TXN_ID = 2**63 - 1
  63. logger = logging.getLogger(__name__)
  64. sql_logger = logging.getLogger("synapse.storage.SQL")
  65. transaction_logger = logging.getLogger("synapse.storage.txn")
  66. perf_logger = logging.getLogger("synapse.storage.TIME")
  67. sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
  68. sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
  69. sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
  70. # Unique indexes which have been added in background updates. Maps from table name
  71. # to the name of the background update which added the unique index to that table.
  72. #
  73. # This is used by the upsert logic to figure out which tables are safe to do a proper
  74. # UPSERT on: until the relevant background update has completed, we
  75. # have to emulate an upsert by locking the table.
  76. #
  77. UNIQUE_INDEX_BACKGROUND_UPDATES = {
  78. "user_ips": "user_ips_device_unique_index",
  79. "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
  80. "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
  81. "event_search": "event_search_event_id_idx",
  82. "local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx",
  83. "remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx",
  84. "event_push_summary": "event_push_summary_unique_index",
  85. }
  86. def make_pool(
  87. reactor: IReactorCore,
  88. db_config: DatabaseConnectionConfig,
  89. engine: BaseDatabaseEngine,
  90. ) -> adbapi.ConnectionPool:
  91. """Get the connection pool for the database."""
  92. # By default enable `cp_reconnect`. We need to fiddle with db_args in case
  93. # someone has explicitly set `cp_reconnect`.
  94. db_args = dict(db_config.config.get("args", {}))
  95. db_args.setdefault("cp_reconnect", True)
  96. def _on_new_connection(conn: Connection) -> None:
  97. # Ensure we have a logging context so we can correctly track queries,
  98. # etc.
  99. with LoggingContext("db.on_new_connection"):
  100. engine.on_new_connection(
  101. LoggingDatabaseConnection(conn, engine, "on_new_connection")
  102. )
  103. connection_pool = adbapi.ConnectionPool(
  104. db_config.config["name"],
  105. cp_reactor=reactor,
  106. cp_openfun=_on_new_connection,
  107. **db_args,
  108. )
  109. register_threadpool(f"database-{db_config.name}", connection_pool.threadpool)
  110. return connection_pool
  111. def make_conn(
  112. db_config: DatabaseConnectionConfig,
  113. engine: BaseDatabaseEngine,
  114. default_txn_name: str,
  115. ) -> "LoggingDatabaseConnection":
  116. """Make a new connection to the database and return it.
  117. Returns:
  118. Connection
  119. """
  120. db_params = {
  121. k: v
  122. for k, v in db_config.config.get("args", {}).items()
  123. if not k.startswith("cp_")
  124. }
  125. native_db_conn = engine.module.connect(**db_params)
  126. db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name)
  127. engine.on_new_connection(db_conn)
  128. return db_conn
  129. @attr.s(slots=True, auto_attribs=True)
  130. class LoggingDatabaseConnection:
  131. """A wrapper around a database connection that returns `LoggingTransaction`
  132. as its cursor class.
  133. This is mainly used on startup to ensure that queries get logged correctly
  134. """
  135. conn: Connection
  136. engine: BaseDatabaseEngine
  137. default_txn_name: str
  138. def cursor(
  139. self,
  140. *,
  141. txn_name: Optional[str] = None,
  142. after_callbacks: Optional[List["_CallbackListEntry"]] = None,
  143. exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
  144. ) -> "LoggingTransaction":
  145. if not txn_name:
  146. txn_name = self.default_txn_name
  147. return LoggingTransaction(
  148. self.conn.cursor(),
  149. name=txn_name,
  150. database_engine=self.engine,
  151. after_callbacks=after_callbacks,
  152. exception_callbacks=exception_callbacks,
  153. )
  154. def close(self) -> None:
  155. self.conn.close()
  156. def commit(self) -> None:
  157. self.conn.commit()
  158. def rollback(self) -> None:
  159. self.conn.rollback()
  160. def __enter__(self) -> "LoggingDatabaseConnection":
  161. self.conn.__enter__()
  162. return self
  163. def __exit__(
  164. self,
  165. exc_type: Optional[Type[BaseException]],
  166. exc_value: Optional[BaseException],
  167. traceback: Optional[types.TracebackType],
  168. ) -> Optional[bool]:
  169. return self.conn.__exit__(exc_type, exc_value, traceback)
  170. # Proxy through any unknown lookups to the DB conn class.
  171. def __getattr__(self, name: str) -> Any:
  172. return getattr(self.conn, name)
  173. # The type of entry which goes on our after_callbacks and exception_callbacks lists.
  174. _CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
  175. P = ParamSpec("P")
  176. R = TypeVar("R")
  177. class LoggingTransaction:
  178. """An object that almost-transparently proxies for the 'txn' object
  179. passed to the constructor. Adds logging and metrics to the .execute()
  180. method.
  181. Args:
  182. txn: The database transaction object to wrap.
  183. name: The name of this transactions for logging.
  184. database_engine
  185. after_callbacks: A list that callbacks will be appended to
  186. that have been added by `call_after` which should be run on
  187. successful completion of the transaction. None indicates that no
  188. callbacks should be allowed to be scheduled to run.
  189. exception_callbacks: A list that callbacks will be appended
  190. to that have been added by `call_on_exception` which should be run
  191. if transaction ends with an error. None indicates that no callbacks
  192. should be allowed to be scheduled to run.
  193. """
  194. __slots__ = [
  195. "txn",
  196. "name",
  197. "database_engine",
  198. "after_callbacks",
  199. "exception_callbacks",
  200. ]
  201. def __init__(
  202. self,
  203. txn: Cursor,
  204. name: str,
  205. database_engine: BaseDatabaseEngine,
  206. after_callbacks: Optional[List[_CallbackListEntry]] = None,
  207. exception_callbacks: Optional[List[_CallbackListEntry]] = None,
  208. ):
  209. self.txn = txn
  210. self.name = name
  211. self.database_engine = database_engine
  212. self.after_callbacks = after_callbacks
  213. self.exception_callbacks = exception_callbacks
  214. def call_after(
  215. self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
  216. ) -> None:
  217. """Call the given callback on the main twisted thread after the transaction has
  218. finished.
  219. Mostly used to invalidate the caches on the correct thread.
  220. Note that transactions may be retried a few times if they encounter database
  221. errors such as serialization failures. Callbacks given to `call_after`
  222. will accumulate across transaction attempts and will _all_ be called once a
  223. transaction attempt succeeds, regardless of whether previous transaction
  224. attempts failed. Otherwise, if all transaction attempts fail, all
  225. `call_on_exception` callbacks will be run instead.
  226. """
  227. # if self.after_callbacks is None, that means that whatever constructed the
  228. # LoggingTransaction isn't expecting there to be any callbacks; assert that
  229. # is not the case.
  230. assert self.after_callbacks is not None
  231. # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
  232. self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
  233. def call_on_exception(
  234. self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
  235. ) -> None:
  236. """Call the given callback on the main twisted thread after the transaction has
  237. failed.
  238. Note that transactions may be retried a few times if they encounter database
  239. errors such as serialization failures. Callbacks given to `call_on_exception`
  240. will accumulate across transaction attempts and will _all_ be called once the
  241. final transaction attempt fails. No `call_on_exception` callbacks will be run
  242. if any transaction attempt succeeds.
  243. """
  244. # if self.exception_callbacks is None, that means that whatever constructed the
  245. # LoggingTransaction isn't expecting there to be any callbacks; assert that
  246. # is not the case.
  247. assert self.exception_callbacks is not None
  248. # type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
  249. self.exception_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
  250. def fetchone(self) -> Optional[Tuple]:
  251. return self.txn.fetchone()
  252. def fetchmany(self, size: Optional[int] = None) -> List[Tuple]:
  253. return self.txn.fetchmany(size=size)
  254. def fetchall(self) -> List[Tuple]:
  255. return self.txn.fetchall()
  256. def __iter__(self) -> Iterator[Tuple]:
  257. return self.txn.__iter__()
  258. @property
  259. def rowcount(self) -> int:
  260. return self.txn.rowcount
  261. @property
  262. def description(self) -> Any:
  263. return self.txn.description
  264. def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
  265. """Similar to `executemany`, except `txn.rowcount` will not be correct
  266. afterwards.
  267. More efficient than `executemany` on PostgreSQL
  268. """
  269. if isinstance(self.database_engine, PostgresEngine):
  270. from psycopg2.extras import execute_batch
  271. self._do_execute(
  272. lambda the_sql: execute_batch(self.txn, the_sql, args), sql
  273. )
  274. else:
  275. self.executemany(sql, args)
  276. def execute_values(
  277. self, sql: str, values: Iterable[Iterable[Any]], fetch: bool = True
  278. ) -> List[Tuple]:
  279. """Corresponds to psycopg2.extras.execute_values. Only available when
  280. using postgres.
  281. The `fetch` parameter must be set to False if the query does not return
  282. rows (e.g. INSERTs).
  283. """
  284. assert isinstance(self.database_engine, PostgresEngine)
  285. from psycopg2.extras import execute_values
  286. return self._do_execute(
  287. lambda the_sql: execute_values(self.txn, the_sql, values, fetch=fetch),
  288. sql,
  289. )
  290. def execute(self, sql: str, *args: Any) -> None:
  291. self._do_execute(self.txn.execute, sql, *args)
  292. def executemany(self, sql: str, *args: Any) -> None:
  293. self._do_execute(self.txn.executemany, sql, *args)
  294. def _make_sql_one_line(self, sql: str) -> str:
  295. "Strip newlines out of SQL so that the loggers in the DB are on one line"
  296. return " ".join(line.strip() for line in sql.splitlines() if line.strip())
  297. def _do_execute(
  298. self,
  299. func: Callable[Concatenate[str, P], R],
  300. sql: str,
  301. *args: P.args,
  302. **kwargs: P.kwargs,
  303. ) -> R:
  304. # Generate a one-line version of the SQL to better log it.
  305. one_line_sql = self._make_sql_one_line(sql)
  306. # TODO(paul): Maybe use 'info' and 'debug' for values?
  307. sql_logger.debug("[SQL] {%s} %s", self.name, one_line_sql)
  308. sql = self.database_engine.convert_param_style(sql)
  309. if args:
  310. try:
  311. # The type-ignore should be redundant once mypy releases a version with
  312. # https://github.com/python/mypy/pull/12668. (`args` might be empty,
  313. # (but we'll catch the index error if so.)
  314. sql_logger.debug("[SQL values] {%s} %r", self.name, args[0]) # type: ignore[index]
  315. except Exception:
  316. # Don't let logging failures stop SQL from working
  317. pass
  318. start = time.time()
  319. try:
  320. with opentracing.start_active_span(
  321. "db.query",
  322. tags={
  323. opentracing.tags.DATABASE_TYPE: "sql",
  324. opentracing.tags.DATABASE_STATEMENT: one_line_sql,
  325. },
  326. ):
  327. return func(sql, *args, **kwargs)
  328. except Exception as e:
  329. sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
  330. raise
  331. finally:
  332. secs = time.time() - start
  333. sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
  334. sql_query_timer.labels(sql.split()[0]).observe(secs)
  335. def close(self) -> None:
  336. self.txn.close()
  337. def __enter__(self) -> "LoggingTransaction":
  338. return self
  339. def __exit__(
  340. self,
  341. exc_type: Optional[Type[BaseException]],
  342. exc_value: Optional[BaseException],
  343. traceback: Optional[types.TracebackType],
  344. ) -> None:
  345. self.close()
  346. class PerformanceCounters:
  347. def __init__(self) -> None:
  348. self.current_counters: Dict[str, Tuple[int, float]] = {}
  349. self.previous_counters: Dict[str, Tuple[int, float]] = {}
  350. def update(self, key: str, duration_secs: float) -> None:
  351. count, cum_time = self.current_counters.get(key, (0, 0.0))
  352. count += 1
  353. cum_time += duration_secs
  354. self.current_counters[key] = (count, cum_time)
  355. def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
  356. counters = []
  357. for name, (count, cum_time) in self.current_counters.items():
  358. prev_count, prev_time = self.previous_counters.get(name, (0, 0))
  359. counters.append(
  360. (
  361. (cum_time - prev_time) / interval_duration_secs,
  362. count - prev_count,
  363. name,
  364. )
  365. )
  366. self.previous_counters = dict(self.current_counters)
  367. counters.sort(reverse=True)
  368. top_n_counters = ", ".join(
  369. "%s(%d): %.3f%%" % (name, count, 100 * ratio)
  370. for ratio, count, name in counters[:limit]
  371. )
  372. return top_n_counters
  373. class DatabasePool:
  374. """Wraps a single physical database and connection pool.
  375. A single database may be used by multiple data stores.
  376. """
  377. _TXN_ID = 0
  378. def __init__(
  379. self,
  380. hs: "HomeServer",
  381. database_config: DatabaseConnectionConfig,
  382. engine: BaseDatabaseEngine,
  383. ):
  384. self.hs = hs
  385. self._clock = hs.get_clock()
  386. self._txn_limit = database_config.config.get("txn_limit", 0)
  387. self._database_config = database_config
  388. self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
  389. self.updates = BackgroundUpdater(hs, self)
  390. self._previous_txn_total_time = 0.0
  391. self._current_txn_total_time = 0.0
  392. self._previous_loop_ts = 0.0
  393. # Transaction counter: key is the twisted thread id, value is the current count
  394. self._txn_counters: Dict[int, int] = defaultdict(int)
  395. # TODO(paul): These can eventually be removed once the metrics code
  396. # is running in mainline, and we have some nice monitoring frontends
  397. # to watch it
  398. self._txn_perf_counters = PerformanceCounters()
  399. self.engine = engine
  400. # A set of tables that are not safe to use native upserts in.
  401. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
  402. # We add the user_directory_search table to the blacklist on SQLite
  403. # because the existing search table does not have an index, making it
  404. # unsafe to use native upserts.
  405. if isinstance(self.engine, Sqlite3Engine):
  406. self._unsafe_to_upsert_tables.add("user_directory_search")
  407. if self.engine.can_native_upsert:
  408. # Check ASAP (and then later, every 1s) to see if we have finished
  409. # background updates of tables that aren't safe to update.
  410. self._clock.call_later(
  411. 0.0,
  412. run_as_background_process,
  413. "upsert_safety_check",
  414. self._check_safe_to_upsert,
  415. )
  416. def name(self) -> str:
  417. "Return the name of this database"
  418. return self._database_config.name
  419. def is_running(self) -> bool:
  420. """Is the database pool currently running"""
  421. return self._db_pool.running
  422. async def _check_safe_to_upsert(self) -> None:
  423. """
  424. Is it safe to use native UPSERT?
  425. If there are background updates, we will need to wait, as they may be
  426. the addition of indexes that set the UNIQUE constraint that we require.
  427. If the background updates have not completed, wait 15 sec and check again.
  428. """
  429. updates = await self.simple_select_list(
  430. "background_updates",
  431. keyvalues=None,
  432. retcols=["update_name"],
  433. desc="check_background_updates",
  434. )
  435. updates = [x["update_name"] for x in updates]
  436. for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
  437. if update_name not in updates:
  438. logger.debug("Now safe to upsert in %s", table)
  439. self._unsafe_to_upsert_tables.discard(table)
  440. # If there's any updates still running, reschedule to run.
  441. if updates:
  442. self._clock.call_later(
  443. 15.0,
  444. run_as_background_process,
  445. "upsert_safety_check",
  446. self._check_safe_to_upsert,
  447. )
  448. def start_profiling(self) -> None:
  449. self._previous_loop_ts = monotonic_time()
  450. def loop() -> None:
  451. curr = self._current_txn_total_time
  452. prev = self._previous_txn_total_time
  453. self._previous_txn_total_time = curr
  454. time_now = monotonic_time()
  455. time_then = self._previous_loop_ts
  456. self._previous_loop_ts = time_now
  457. duration = time_now - time_then
  458. ratio = (curr - prev) / duration
  459. top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
  460. perf_logger.debug(
  461. "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
  462. )
  463. self._clock.looping_call(loop, 10000)
  464. def new_transaction(
  465. self,
  466. conn: LoggingDatabaseConnection,
  467. desc: str,
  468. after_callbacks: List[_CallbackListEntry],
  469. exception_callbacks: List[_CallbackListEntry],
  470. func: Callable[Concatenate[LoggingTransaction, P], R],
  471. *args: P.args,
  472. **kwargs: P.kwargs,
  473. ) -> R:
  474. """Start a new database transaction with the given connection.
  475. Note: The given func may be called multiple times under certain
  476. failure modes. This is normally fine when in a standard transaction,
  477. but care must be taken if the connection is in `autocommit` mode that
  478. the function will correctly handle being aborted and retried half way
  479. through its execution.
  480. Similarly, the arguments to `func` (`args`, `kwargs`) should not be generators,
  481. since they could be evaluated multiple times (which would produce an empty
  482. result on the second or subsequent evaluation). Likewise, the closure of `func`
  483. must not reference any generators. This method attempts to detect such usage
  484. and will log an error.
  485. Args:
  486. conn
  487. desc
  488. after_callbacks
  489. exception_callbacks
  490. func
  491. *args
  492. **kwargs
  493. """
  494. # Robustness check: ensure that none of the arguments are generators, since that
  495. # will fail if we have to repeat the transaction.
  496. # For now, we just log an error, and hope that it works on the first attempt.
  497. # TODO: raise an exception.
  498. # Type-ignore Mypy doesn't yet consider ParamSpec.args to be iterable; see
  499. # https://github.com/python/mypy/pull/12668
  500. for i, arg in enumerate(args): # type: ignore[arg-type, var-annotated]
  501. if inspect.isgenerator(arg):
  502. logger.error(
  503. "Programming error: generator passed to new_transaction as "
  504. "argument %i to function %s",
  505. i,
  506. func,
  507. )
  508. # Type-ignore Mypy doesn't yet consider ParamSpec.args to be a mapping; see
  509. # https://github.com/python/mypy/pull/12668
  510. for name, val in kwargs.items(): # type: ignore[attr-defined]
  511. if inspect.isgenerator(val):
  512. logger.error(
  513. "Programming error: generator passed to new_transaction as "
  514. "argument %s to function %s",
  515. name,
  516. func,
  517. )
  518. # also check variables referenced in func's closure
  519. if inspect.isfunction(func):
  520. f = cast(types.FunctionType, func)
  521. if f.__closure__:
  522. for i, cell in enumerate(f.__closure__):
  523. if inspect.isgenerator(cell.cell_contents):
  524. logger.error(
  525. "Programming error: function %s references generator %s "
  526. "via its closure",
  527. f,
  528. f.__code__.co_freevars[i],
  529. )
  530. start = monotonic_time()
  531. txn_id = self._TXN_ID
  532. # We don't really need these to be unique, so lets stop it from
  533. # growing really large.
  534. self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
  535. name = "%s-%x" % (desc, txn_id)
  536. transaction_logger.debug("[TXN START] {%s}", name)
  537. try:
  538. i = 0
  539. N = 5
  540. while True:
  541. cursor = conn.cursor(
  542. txn_name=name,
  543. after_callbacks=after_callbacks,
  544. exception_callbacks=exception_callbacks,
  545. )
  546. try:
  547. with opentracing.start_active_span(
  548. "db.txn",
  549. tags={
  550. opentracing.SynapseTags.DB_TXN_DESC: desc,
  551. opentracing.SynapseTags.DB_TXN_ID: name,
  552. },
  553. ):
  554. r = func(cursor, *args, **kwargs)
  555. opentracing.log_kv({"message": "commit"})
  556. conn.commit()
  557. return r
  558. except self.engine.module.OperationalError as e:
  559. # This can happen if the database disappears mid
  560. # transaction.
  561. transaction_logger.warning(
  562. "[TXN OPERROR] {%s} %s %d/%d",
  563. name,
  564. e,
  565. i,
  566. N,
  567. )
  568. if i < N:
  569. i += 1
  570. try:
  571. with opentracing.start_active_span("db.rollback"):
  572. conn.rollback()
  573. except self.engine.module.Error as e1:
  574. transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
  575. continue
  576. raise
  577. except self.engine.module.DatabaseError as e:
  578. if self.engine.is_deadlock(e):
  579. transaction_logger.warning(
  580. "[TXN DEADLOCK] {%s} %d/%d", name, i, N
  581. )
  582. if i < N:
  583. i += 1
  584. try:
  585. with opentracing.start_active_span("db.rollback"):
  586. conn.rollback()
  587. except self.engine.module.Error as e1:
  588. transaction_logger.warning(
  589. "[TXN EROLL] {%s} %s",
  590. name,
  591. e1,
  592. )
  593. continue
  594. raise
  595. finally:
  596. # we're either about to retry with a new cursor, or we're about to
  597. # release the connection. Once we release the connection, it could
  598. # get used for another query, which might do a conn.rollback().
  599. #
  600. # In the latter case, even though that probably wouldn't affect the
  601. # results of this transaction, python's sqlite will reset all
  602. # statements on the connection [1], which will make our cursor
  603. # invalid [2].
  604. #
  605. # In any case, continuing to read rows after commit()ing seems
  606. # dubious from the PoV of ACID transactional semantics
  607. # (sqlite explicitly says that once you commit, you may see rows
  608. # from subsequent updates.)
  609. #
  610. # In psycopg2, cursors are essentially a client-side fabrication -
  611. # all the data is transferred to the client side when the statement
  612. # finishes executing - so in theory we could go on streaming results
  613. # from the cursor, but attempting to do so would make us
  614. # incompatible with sqlite, so let's make sure we're not doing that
  615. # by closing the cursor.
  616. #
  617. # (*named* cursors in psycopg2 are different and are proper server-
  618. # side things, but (a) we don't use them and (b) they are implicitly
  619. # closed by ending the transaction anyway.)
  620. #
  621. # In short, if we haven't finished with the cursor yet, that's a
  622. # problem waiting to bite us.
  623. #
  624. # TL;DR: we're done with the cursor, so we can close it.
  625. #
  626. # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
  627. # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
  628. cursor.close()
  629. except Exception as e:
  630. transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
  631. raise
  632. finally:
  633. end = monotonic_time()
  634. duration = end - start
  635. current_context().add_database_transaction(duration)
  636. transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
  637. self._current_txn_total_time += duration
  638. self._txn_perf_counters.update(desc, duration)
  639. sql_txn_timer.labels(desc).observe(duration)
  640. async def runInteraction(
  641. self,
  642. desc: str,
  643. func: Callable[..., R],
  644. *args: Any,
  645. db_autocommit: bool = False,
  646. isolation_level: Optional[int] = None,
  647. **kwargs: Any,
  648. ) -> R:
  649. """Starts a transaction on the database and runs a given function
  650. Arguments:
  651. desc: description of the transaction, for logging and metrics
  652. func: callback function, which will be called with a
  653. database transaction (twisted.enterprise.adbapi.Transaction) as
  654. its first argument, followed by `args` and `kwargs`.
  655. db_autocommit: Whether to run the function in "autocommit" mode,
  656. i.e. outside of a transaction. This is useful for transactions
  657. that are only a single query.
  658. Currently, this is only implemented for Postgres. SQLite will still
  659. run the function inside a transaction.
  660. WARNING: This means that if func fails half way through then
  661. the changes will *not* be rolled back. `func` may also get
  662. called multiple times if the transaction is retried, so must
  663. correctly handle that case.
  664. isolation_level: Set the server isolation level for this transaction.
  665. args: positional args to pass to `func`
  666. kwargs: named args to pass to `func`
  667. Returns:
  668. The result of func
  669. """
  670. async def _runInteraction() -> R:
  671. after_callbacks: List[_CallbackListEntry] = []
  672. exception_callbacks: List[_CallbackListEntry] = []
  673. if not current_context():
  674. logger.warning("Starting db txn '%s' from sentinel context", desc)
  675. try:
  676. with opentracing.start_active_span(f"db.{desc}"):
  677. result = await self.runWithConnection(
  678. self.new_transaction,
  679. desc,
  680. after_callbacks,
  681. exception_callbacks,
  682. func,
  683. *args,
  684. db_autocommit=db_autocommit,
  685. isolation_level=isolation_level,
  686. **kwargs,
  687. )
  688. for after_callback, after_args, after_kwargs in after_callbacks:
  689. await maybe_awaitable(after_callback(*after_args, **after_kwargs))
  690. return cast(R, result)
  691. except Exception:
  692. for exception_callback, after_args, after_kwargs in exception_callbacks:
  693. await maybe_awaitable(
  694. exception_callback(*after_args, **after_kwargs)
  695. )
  696. raise
  697. # To handle cancellation, we ensure that `after_callback`s and
  698. # `exception_callback`s are always run, since the transaction will complete
  699. # on another thread regardless of cancellation.
  700. #
  701. # We also wait until everything above is done before releasing the
  702. # `CancelledError`, so that logging contexts won't get used after they have been
  703. # finished.
  704. return await delay_cancellation(_runInteraction())
  705. async def runWithConnection(
  706. self,
  707. func: Callable[..., R],
  708. *args: Any,
  709. db_autocommit: bool = False,
  710. isolation_level: Optional[int] = None,
  711. **kwargs: Any,
  712. ) -> R:
  713. """Wraps the .runWithConnection() method on the underlying db_pool.
  714. Arguments:
  715. func: callback function, which will be called with a
  716. database connection (twisted.enterprise.adbapi.Connection) as
  717. its first argument, followed by `args` and `kwargs`.
  718. args: positional args to pass to `func`
  719. db_autocommit: Whether to run the function in "autocommit" mode,
  720. i.e. outside of a transaction. This is useful for transaction
  721. that are only a single query. Currently only affects postgres.
  722. isolation_level: Set the server isolation level for this transaction.
  723. kwargs: named args to pass to `func`
  724. Returns:
  725. The result of func
  726. """
  727. curr_context = current_context()
  728. if not curr_context:
  729. logger.warning(
  730. "Starting db connection from sentinel context: metrics will be lost"
  731. )
  732. parent_context = None
  733. else:
  734. assert isinstance(curr_context, LoggingContext)
  735. parent_context = curr_context
  736. start_time = monotonic_time()
  737. def inner_func(conn, *args, **kwargs):
  738. # We shouldn't be in a transaction. If we are then something
  739. # somewhere hasn't committed after doing work. (This is likely only
  740. # possible during startup, as `run*` will ensure changes are
  741. # committed/rolled back before putting the connection back in the
  742. # pool).
  743. assert not self.engine.in_transaction(conn)
  744. with LoggingContext(
  745. str(curr_context), parent_context=parent_context
  746. ) as context:
  747. with opentracing.start_active_span(
  748. operation_name="db.connection",
  749. ):
  750. sched_duration_sec = monotonic_time() - start_time
  751. sql_scheduling_timer.observe(sched_duration_sec)
  752. context.add_database_scheduled(sched_duration_sec)
  753. if self._txn_limit > 0:
  754. tid = self._db_pool.threadID()
  755. self._txn_counters[tid] += 1
  756. if self._txn_counters[tid] > self._txn_limit:
  757. logger.debug(
  758. "Reconnecting database connection over transaction limit"
  759. )
  760. conn.reconnect()
  761. opentracing.log_kv(
  762. {"message": "reconnected due to txn limit"}
  763. )
  764. self._txn_counters[tid] = 1
  765. if self.engine.is_connection_closed(conn):
  766. logger.debug("Reconnecting closed database connection")
  767. conn.reconnect()
  768. opentracing.log_kv({"message": "reconnected"})
  769. if self._txn_limit > 0:
  770. self._txn_counters[tid] = 1
  771. try:
  772. if db_autocommit:
  773. self.engine.attempt_to_set_autocommit(conn, True)
  774. if isolation_level is not None:
  775. self.engine.attempt_to_set_isolation_level(
  776. conn, isolation_level
  777. )
  778. db_conn = LoggingDatabaseConnection(
  779. conn, self.engine, "runWithConnection"
  780. )
  781. return func(db_conn, *args, **kwargs)
  782. finally:
  783. if db_autocommit:
  784. self.engine.attempt_to_set_autocommit(conn, False)
  785. if isolation_level:
  786. self.engine.attempt_to_set_isolation_level(conn, None)
  787. return await make_deferred_yieldable(
  788. self._db_pool.runWithConnection(inner_func, *args, **kwargs)
  789. )
  790. @staticmethod
  791. def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
  792. """Converts a SQL cursor into an list of dicts.
  793. Args:
  794. cursor: The DBAPI cursor which has executed a query.
  795. Returns:
  796. A list of dicts where the key is the column header.
  797. """
  798. assert cursor.description is not None, "cursor.description was None"
  799. col_headers = [intern(str(column[0])) for column in cursor.description]
  800. results = [dict(zip(col_headers, row)) for row in cursor]
  801. return results
  802. @overload
  803. async def execute(
  804. self, desc: str, decoder: Literal[None], query: str, *args: Any
  805. ) -> List[Tuple[Any, ...]]:
  806. ...
  807. @overload
  808. async def execute(
  809. self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
  810. ) -> R:
  811. ...
  812. async def execute(
  813. self,
  814. desc: str,
  815. decoder: Optional[Callable[[Cursor], R]],
  816. query: str,
  817. *args: Any,
  818. ) -> R:
  819. """Runs a single query for a result set.
  820. Args:
  821. desc: description of the transaction, for logging and metrics
  822. decoder - The function which can resolve the cursor results to
  823. something meaningful.
  824. query - The query string to execute
  825. *args - Query args.
  826. Returns:
  827. The result of decoder(results)
  828. """
  829. def interaction(txn):
  830. txn.execute(query, args)
  831. if decoder:
  832. return decoder(txn)
  833. else:
  834. return txn.fetchall()
  835. return await self.runInteraction(desc, interaction)
  836. # "Simple" SQL API methods that operate on a single table with no JOINs,
  837. # no complex WHERE clauses, just a dict of values for columns.
  838. async def simple_insert(
  839. self,
  840. table: str,
  841. values: Dict[str, Any],
  842. desc: str = "simple_insert",
  843. ) -> None:
  844. """Executes an INSERT query on the named table.
  845. Args:
  846. table: string giving the table name
  847. values: dict of new column names and values for them
  848. desc: description of the transaction, for logging and metrics
  849. """
  850. await self.runInteraction(desc, self.simple_insert_txn, table, values)
  851. @staticmethod
  852. def simple_insert_txn(
  853. txn: LoggingTransaction, table: str, values: Dict[str, Any]
  854. ) -> None:
  855. keys, vals = zip(*values.items())
  856. sql = "INSERT INTO %s (%s) VALUES(%s)" % (
  857. table,
  858. ", ".join(k for k in keys),
  859. ", ".join("?" for _ in keys),
  860. )
  861. txn.execute(sql, vals)
  862. async def simple_insert_many(
  863. self,
  864. table: str,
  865. keys: Collection[str],
  866. values: Collection[Collection[Any]],
  867. desc: str,
  868. ) -> None:
  869. """Executes an INSERT query on the named table.
  870. The input is given as a list of rows, where each row is a list of values.
  871. (Actually any iterable is fine.)
  872. Args:
  873. table: string giving the table name
  874. keys: list of column names
  875. values: for each row, a list of values in the same order as `keys`
  876. desc: description of the transaction, for logging and metrics
  877. """
  878. await self.runInteraction(
  879. desc, self.simple_insert_many_txn, table, keys, values
  880. )
  881. @staticmethod
  882. def simple_insert_many_txn(
  883. txn: LoggingTransaction,
  884. table: str,
  885. keys: Collection[str],
  886. values: Iterable[Iterable[Any]],
  887. ) -> None:
  888. """Executes an INSERT query on the named table.
  889. The input is given as a list of rows, where each row is a list of values.
  890. (Actually any iterable is fine.)
  891. Args:
  892. txn: The transaction to use.
  893. table: string giving the table name
  894. keys: list of column names
  895. values: for each row, a list of values in the same order as `keys`
  896. """
  897. if isinstance(txn.database_engine, PostgresEngine):
  898. # We use `execute_values` as it can be a lot faster than `execute_batch`,
  899. # but it's only available on postgres.
  900. sql = "INSERT INTO %s (%s) VALUES ?" % (
  901. table,
  902. ", ".join(k for k in keys),
  903. )
  904. txn.execute_values(sql, values, fetch=False)
  905. else:
  906. sql = "INSERT INTO %s (%s) VALUES(%s)" % (
  907. table,
  908. ", ".join(k for k in keys),
  909. ", ".join("?" for _ in keys),
  910. )
  911. txn.execute_batch(sql, values)
  912. async def simple_upsert(
  913. self,
  914. table: str,
  915. keyvalues: Dict[str, Any],
  916. values: Dict[str, Any],
  917. insertion_values: Optional[Dict[str, Any]] = None,
  918. desc: str = "simple_upsert",
  919. lock: bool = True,
  920. ) -> bool:
  921. """
  922. `lock` should generally be set to True (the default), but can be set
  923. to False if either of the following are true:
  924. 1. there is a UNIQUE INDEX on the key columns. In this case a conflict
  925. will cause an IntegrityError in which case this function will retry
  926. the update.
  927. 2. we somehow know that we are the only thread which will be updating
  928. this table.
  929. As an additional note, this parameter only matters for old SQLite versions
  930. because we will use native upserts otherwise.
  931. Args:
  932. table: The table to upsert into
  933. keyvalues: The unique key columns and their new values
  934. values: The nonunique columns and their new values
  935. insertion_values: additional key/values to use only when inserting
  936. desc: description of the transaction, for logging and metrics
  937. lock: True to lock the table when doing the upsert.
  938. Returns:
  939. Returns True if a row was inserted or updated (i.e. if `values` is
  940. not empty then this always returns True)
  941. """
  942. insertion_values = insertion_values or {}
  943. attempts = 0
  944. while True:
  945. try:
  946. # We can autocommit if we are going to use native upserts
  947. autocommit = (
  948. self.engine.can_native_upsert
  949. and table not in self._unsafe_to_upsert_tables
  950. )
  951. return await self.runInteraction(
  952. desc,
  953. self.simple_upsert_txn,
  954. table,
  955. keyvalues,
  956. values,
  957. insertion_values,
  958. lock=lock,
  959. db_autocommit=autocommit,
  960. )
  961. except self.engine.module.IntegrityError as e:
  962. attempts += 1
  963. if attempts >= 5:
  964. # don't retry forever, because things other than races
  965. # can cause IntegrityErrors
  966. raise
  967. # presumably we raced with another transaction: let's retry.
  968. logger.warning(
  969. "IntegrityError when upserting into %s; retrying: %s", table, e
  970. )
  971. def simple_upsert_txn(
  972. self,
  973. txn: LoggingTransaction,
  974. table: str,
  975. keyvalues: Dict[str, Any],
  976. values: Dict[str, Any],
  977. insertion_values: Optional[Dict[str, Any]] = None,
  978. lock: bool = True,
  979. ) -> bool:
  980. """
  981. Pick the UPSERT method which works best on the platform. Either the
  982. native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
  983. Args:
  984. txn: The transaction to use.
  985. table: The table to upsert into
  986. keyvalues: The unique key tables and their new values
  987. values: The nonunique columns and their new values
  988. insertion_values: additional key/values to use only when inserting
  989. lock: True to lock the table when doing the upsert.
  990. Returns:
  991. Returns True if a row was inserted or updated (i.e. if `values` is
  992. not empty then this always returns True)
  993. """
  994. insertion_values = insertion_values or {}
  995. if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
  996. return self.simple_upsert_txn_native_upsert(
  997. txn, table, keyvalues, values, insertion_values=insertion_values
  998. )
  999. else:
  1000. return self.simple_upsert_txn_emulated(
  1001. txn,
  1002. table,
  1003. keyvalues,
  1004. values,
  1005. insertion_values=insertion_values,
  1006. lock=lock,
  1007. )
  1008. def simple_upsert_txn_emulated(
  1009. self,
  1010. txn: LoggingTransaction,
  1011. table: str,
  1012. keyvalues: Dict[str, Any],
  1013. values: Dict[str, Any],
  1014. insertion_values: Optional[Dict[str, Any]] = None,
  1015. lock: bool = True,
  1016. ) -> bool:
  1017. """
  1018. Args:
  1019. table: The table to upsert into
  1020. keyvalues: The unique key tables and their new values
  1021. values: The nonunique columns and their new values
  1022. insertion_values: additional key/values to use only when inserting
  1023. lock: True to lock the table when doing the upsert.
  1024. Returns:
  1025. Returns True if a row was inserted or updated (i.e. if `values` is
  1026. not empty then this always returns True)
  1027. """
  1028. insertion_values = insertion_values or {}
  1029. # We need to lock the table :(, unless we're *really* careful
  1030. if lock:
  1031. self.engine.lock_table(txn, table)
  1032. def _getwhere(key: str) -> str:
  1033. # If the value we're passing in is None (aka NULL), we need to use
  1034. # IS, not =, as NULL = NULL equals NULL (False).
  1035. if keyvalues[key] is None:
  1036. return "%s IS ?" % (key,)
  1037. else:
  1038. return "%s = ?" % (key,)
  1039. if not values:
  1040. # If `values` is empty, then all of the values we care about are in
  1041. # the unique key, so there is nothing to UPDATE. We can just do a
  1042. # SELECT instead to see if it exists.
  1043. sql = "SELECT 1 FROM %s WHERE %s" % (
  1044. table,
  1045. " AND ".join(_getwhere(k) for k in keyvalues),
  1046. )
  1047. sqlargs = list(keyvalues.values())
  1048. txn.execute(sql, sqlargs)
  1049. if txn.fetchall():
  1050. # We have an existing record.
  1051. return False
  1052. else:
  1053. # First try to update.
  1054. sql = "UPDATE %s SET %s WHERE %s" % (
  1055. table,
  1056. ", ".join("%s = ?" % (k,) for k in values),
  1057. " AND ".join(_getwhere(k) for k in keyvalues),
  1058. )
  1059. sqlargs = list(values.values()) + list(keyvalues.values())
  1060. txn.execute(sql, sqlargs)
  1061. if txn.rowcount > 0:
  1062. return True
  1063. # We didn't find any existing rows, so insert a new one
  1064. allvalues: Dict[str, Any] = {}
  1065. allvalues.update(keyvalues)
  1066. allvalues.update(values)
  1067. allvalues.update(insertion_values)
  1068. sql = "INSERT INTO %s (%s) VALUES (%s)" % (
  1069. table,
  1070. ", ".join(k for k in allvalues),
  1071. ", ".join("?" for _ in allvalues),
  1072. )
  1073. txn.execute(sql, list(allvalues.values()))
  1074. # successfully inserted
  1075. return True
  1076. def simple_upsert_txn_native_upsert(
  1077. self,
  1078. txn: LoggingTransaction,
  1079. table: str,
  1080. keyvalues: Dict[str, Any],
  1081. values: Dict[str, Any],
  1082. insertion_values: Optional[Dict[str, Any]] = None,
  1083. ) -> bool:
  1084. """
  1085. Use the native UPSERT functionality in PostgreSQL.
  1086. Args:
  1087. table: The table to upsert into
  1088. keyvalues: The unique key tables and their new values
  1089. values: The nonunique columns and their new values
  1090. insertion_values: additional key/values to use only when inserting
  1091. Returns:
  1092. Returns True if a row was inserted or updated (i.e. if `values` is
  1093. not empty then this always returns True)
  1094. """
  1095. allvalues: Dict[str, Any] = {}
  1096. allvalues.update(keyvalues)
  1097. allvalues.update(insertion_values or {})
  1098. if not values:
  1099. latter = "NOTHING"
  1100. else:
  1101. allvalues.update(values)
  1102. latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
  1103. sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
  1104. table,
  1105. ", ".join(k for k in allvalues),
  1106. ", ".join("?" for _ in allvalues),
  1107. ", ".join(k for k in keyvalues),
  1108. latter,
  1109. )
  1110. txn.execute(sql, list(allvalues.values()))
  1111. return bool(txn.rowcount)
  1112. async def simple_upsert_many(
  1113. self,
  1114. table: str,
  1115. key_names: Collection[str],
  1116. key_values: Collection[Collection[Any]],
  1117. value_names: Collection[str],
  1118. value_values: Collection[Collection[Any]],
  1119. desc: str,
  1120. lock: bool = True,
  1121. ) -> None:
  1122. """
  1123. Upsert, many times.
  1124. Args:
  1125. table: The table to upsert into
  1126. key_names: The key column names.
  1127. key_values: A list of each row's key column values.
  1128. value_names: The value column names
  1129. value_values: A list of each row's value column values.
  1130. Ignored if value_names is empty.
  1131. lock: True to lock the table when doing the upsert. Unused if the database engine
  1132. supports native upserts.
  1133. """
  1134. # We can autocommit if we are going to use native upserts
  1135. autocommit = (
  1136. self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables
  1137. )
  1138. await self.runInteraction(
  1139. desc,
  1140. self.simple_upsert_many_txn,
  1141. table,
  1142. key_names,
  1143. key_values,
  1144. value_names,
  1145. value_values,
  1146. lock=lock,
  1147. db_autocommit=autocommit,
  1148. )
  1149. def simple_upsert_many_txn(
  1150. self,
  1151. txn: LoggingTransaction,
  1152. table: str,
  1153. key_names: Collection[str],
  1154. key_values: Collection[Iterable[Any]],
  1155. value_names: Collection[str],
  1156. value_values: Iterable[Iterable[Any]],
  1157. lock: bool = True,
  1158. ) -> None:
  1159. """
  1160. Upsert, many times.
  1161. Args:
  1162. table: The table to upsert into
  1163. key_names: The key column names.
  1164. key_values: A list of each row's key column values.
  1165. value_names: The value column names
  1166. value_values: A list of each row's value column values.
  1167. Ignored if value_names is empty.
  1168. lock: True to lock the table when doing the upsert. Unused if the database engine
  1169. supports native upserts.
  1170. """
  1171. if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
  1172. return self.simple_upsert_many_txn_native_upsert(
  1173. txn, table, key_names, key_values, value_names, value_values
  1174. )
  1175. else:
  1176. return self.simple_upsert_many_txn_emulated(
  1177. txn, table, key_names, key_values, value_names, value_values, lock=lock
  1178. )
  1179. def simple_upsert_many_txn_emulated(
  1180. self,
  1181. txn: LoggingTransaction,
  1182. table: str,
  1183. key_names: Iterable[str],
  1184. key_values: Collection[Iterable[Any]],
  1185. value_names: Collection[str],
  1186. value_values: Iterable[Iterable[Any]],
  1187. lock: bool = True,
  1188. ) -> None:
  1189. """
  1190. Upsert, many times, but without native UPSERT support or batching.
  1191. Args:
  1192. table: The table to upsert into
  1193. key_names: The key column names.
  1194. key_values: A list of each row's key column values.
  1195. value_names: The value column names
  1196. value_values: A list of each row's value column values.
  1197. Ignored if value_names is empty.
  1198. lock: True to lock the table when doing the upsert.
  1199. """
  1200. # No value columns, therefore make a blank list so that the following
  1201. # zip() works correctly.
  1202. if not value_names:
  1203. value_values = [() for x in range(len(key_values))]
  1204. if lock:
  1205. # Lock the table just once, to prevent it being done once per row.
  1206. # Note that, according to Postgres' documentation, once obtained,
  1207. # the lock is held for the remainder of the current transaction.
  1208. self.engine.lock_table(txn, "user_ips")
  1209. for keyv, valv in zip(key_values, value_values):
  1210. _keys = {x: y for x, y in zip(key_names, keyv)}
  1211. _vals = {x: y for x, y in zip(value_names, valv)}
  1212. self.simple_upsert_txn_emulated(txn, table, _keys, _vals, lock=False)
  1213. def simple_upsert_many_txn_native_upsert(
  1214. self,
  1215. txn: LoggingTransaction,
  1216. table: str,
  1217. key_names: Collection[str],
  1218. key_values: Collection[Iterable[Any]],
  1219. value_names: Collection[str],
  1220. value_values: Iterable[Iterable[Any]],
  1221. ) -> None:
  1222. """
  1223. Upsert, many times, using batching where possible.
  1224. Args:
  1225. table: The table to upsert into
  1226. key_names: The key column names.
  1227. key_values: A list of each row's key column values.
  1228. value_names: The value column names
  1229. value_values: A list of each row's value column values.
  1230. Ignored if value_names is empty.
  1231. """
  1232. allnames: List[str] = []
  1233. allnames.extend(key_names)
  1234. allnames.extend(value_names)
  1235. if not value_names:
  1236. # No value columns, therefore make a blank list so that the
  1237. # following zip() works correctly.
  1238. latter = "NOTHING"
  1239. value_values = [() for x in range(len(key_values))]
  1240. else:
  1241. latter = "UPDATE SET " + ", ".join(
  1242. k + "=EXCLUDED." + k for k in value_names
  1243. )
  1244. args = []
  1245. for x, y in zip(key_values, value_values):
  1246. args.append(tuple(x) + tuple(y))
  1247. if isinstance(txn.database_engine, PostgresEngine):
  1248. # We use `execute_values` as it can be a lot faster than `execute_batch`,
  1249. # but it's only available on postgres.
  1250. sql = "INSERT INTO %s (%s) VALUES ? ON CONFLICT (%s) DO %s" % (
  1251. table,
  1252. ", ".join(k for k in allnames),
  1253. ", ".join(key_names),
  1254. latter,
  1255. )
  1256. txn.execute_values(sql, args, fetch=False)
  1257. else:
  1258. sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
  1259. table,
  1260. ", ".join(k for k in allnames),
  1261. ", ".join("?" for _ in allnames),
  1262. ", ".join(key_names),
  1263. latter,
  1264. )
  1265. return txn.execute_batch(sql, args)
  1266. @overload
  1267. async def simple_select_one(
  1268. self,
  1269. table: str,
  1270. keyvalues: Dict[str, Any],
  1271. retcols: Collection[str],
  1272. allow_none: Literal[False] = False,
  1273. desc: str = "simple_select_one",
  1274. ) -> Dict[str, Any]:
  1275. ...
  1276. @overload
  1277. async def simple_select_one(
  1278. self,
  1279. table: str,
  1280. keyvalues: Dict[str, Any],
  1281. retcols: Collection[str],
  1282. allow_none: Literal[True] = True,
  1283. desc: str = "simple_select_one",
  1284. ) -> Optional[Dict[str, Any]]:
  1285. ...
  1286. async def simple_select_one(
  1287. self,
  1288. table: str,
  1289. keyvalues: Dict[str, Any],
  1290. retcols: Collection[str],
  1291. allow_none: bool = False,
  1292. desc: str = "simple_select_one",
  1293. ) -> Optional[Dict[str, Any]]:
  1294. """Executes a SELECT query on the named table, which is expected to
  1295. return a single row, returning multiple columns from it.
  1296. Args:
  1297. table: string giving the table name
  1298. keyvalues: dict of column names and values to select the row with
  1299. retcols: list of strings giving the names of the columns to return
  1300. allow_none: If true, return None instead of failing if the SELECT
  1301. statement returns no rows
  1302. desc: description of the transaction, for logging and metrics
  1303. """
  1304. return await self.runInteraction(
  1305. desc,
  1306. self.simple_select_one_txn,
  1307. table,
  1308. keyvalues,
  1309. retcols,
  1310. allow_none,
  1311. db_autocommit=True,
  1312. )
  1313. @overload
  1314. async def simple_select_one_onecol(
  1315. self,
  1316. table: str,
  1317. keyvalues: Dict[str, Any],
  1318. retcol: str,
  1319. allow_none: Literal[False] = False,
  1320. desc: str = "simple_select_one_onecol",
  1321. ) -> Any:
  1322. ...
  1323. @overload
  1324. async def simple_select_one_onecol(
  1325. self,
  1326. table: str,
  1327. keyvalues: Dict[str, Any],
  1328. retcol: str,
  1329. allow_none: Literal[True] = True,
  1330. desc: str = "simple_select_one_onecol",
  1331. ) -> Optional[Any]:
  1332. ...
  1333. async def simple_select_one_onecol(
  1334. self,
  1335. table: str,
  1336. keyvalues: Dict[str, Any],
  1337. retcol: str,
  1338. allow_none: bool = False,
  1339. desc: str = "simple_select_one_onecol",
  1340. ) -> Optional[Any]:
  1341. """Executes a SELECT query on the named table, which is expected to
  1342. return a single row, returning a single column from it.
  1343. Args:
  1344. table: string giving the table name
  1345. keyvalues: dict of column names and values to select the row with
  1346. retcol: string giving the name of the column to return
  1347. allow_none: If true, return None instead of failing if the SELECT
  1348. statement returns no rows
  1349. desc: description of the transaction, for logging and metrics
  1350. """
  1351. return await self.runInteraction(
  1352. desc,
  1353. self.simple_select_one_onecol_txn,
  1354. table,
  1355. keyvalues,
  1356. retcol,
  1357. allow_none=allow_none,
  1358. db_autocommit=True,
  1359. )
  1360. @overload
  1361. @classmethod
  1362. def simple_select_one_onecol_txn(
  1363. cls,
  1364. txn: LoggingTransaction,
  1365. table: str,
  1366. keyvalues: Dict[str, Any],
  1367. retcol: str,
  1368. allow_none: Literal[False] = False,
  1369. ) -> Any:
  1370. ...
  1371. @overload
  1372. @classmethod
  1373. def simple_select_one_onecol_txn(
  1374. cls,
  1375. txn: LoggingTransaction,
  1376. table: str,
  1377. keyvalues: Dict[str, Any],
  1378. retcol: str,
  1379. allow_none: Literal[True] = True,
  1380. ) -> Optional[Any]:
  1381. ...
  1382. @classmethod
  1383. def simple_select_one_onecol_txn(
  1384. cls,
  1385. txn: LoggingTransaction,
  1386. table: str,
  1387. keyvalues: Dict[str, Any],
  1388. retcol: str,
  1389. allow_none: bool = False,
  1390. ) -> Optional[Any]:
  1391. ret = cls.simple_select_onecol_txn(
  1392. txn, table=table, keyvalues=keyvalues, retcol=retcol
  1393. )
  1394. if ret:
  1395. return ret[0]
  1396. else:
  1397. if allow_none:
  1398. return None
  1399. else:
  1400. raise StoreError(404, "No row found")
  1401. @staticmethod
  1402. def simple_select_onecol_txn(
  1403. txn: LoggingTransaction,
  1404. table: str,
  1405. keyvalues: Dict[str, Any],
  1406. retcol: str,
  1407. ) -> List[Any]:
  1408. sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
  1409. if keyvalues:
  1410. sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
  1411. txn.execute(sql, list(keyvalues.values()))
  1412. else:
  1413. txn.execute(sql)
  1414. return [r[0] for r in txn]
  1415. async def simple_select_onecol(
  1416. self,
  1417. table: str,
  1418. keyvalues: Optional[Dict[str, Any]],
  1419. retcol: str,
  1420. desc: str = "simple_select_onecol",
  1421. ) -> List[Any]:
  1422. """Executes a SELECT query on the named table, which returns a list
  1423. comprising of the values of the named column from the selected rows.
  1424. Args:
  1425. table: table name
  1426. keyvalues: column names and values to select the rows with
  1427. retcol: column whos value we wish to retrieve.
  1428. desc: description of the transaction, for logging and metrics
  1429. Returns:
  1430. Results in a list
  1431. """
  1432. return await self.runInteraction(
  1433. desc,
  1434. self.simple_select_onecol_txn,
  1435. table,
  1436. keyvalues,
  1437. retcol,
  1438. db_autocommit=True,
  1439. )
  1440. async def simple_select_list(
  1441. self,
  1442. table: str,
  1443. keyvalues: Optional[Dict[str, Any]],
  1444. retcols: Collection[str],
  1445. desc: str = "simple_select_list",
  1446. ) -> List[Dict[str, Any]]:
  1447. """Executes a SELECT query on the named table, which may return zero or
  1448. more rows, returning the result as a list of dicts.
  1449. Args:
  1450. table: the table name
  1451. keyvalues:
  1452. column names and values to select the rows with, or None to not
  1453. apply a WHERE clause.
  1454. retcols: the names of the columns to return
  1455. desc: description of the transaction, for logging and metrics
  1456. Returns:
  1457. A list of dictionaries.
  1458. """
  1459. return await self.runInteraction(
  1460. desc,
  1461. self.simple_select_list_txn,
  1462. table,
  1463. keyvalues,
  1464. retcols,
  1465. db_autocommit=True,
  1466. )
  1467. @classmethod
  1468. def simple_select_list_txn(
  1469. cls,
  1470. txn: LoggingTransaction,
  1471. table: str,
  1472. keyvalues: Optional[Dict[str, Any]],
  1473. retcols: Iterable[str],
  1474. ) -> List[Dict[str, Any]]:
  1475. """Executes a SELECT query on the named table, which may return zero or
  1476. more rows, returning the result as a list of dicts.
  1477. Args:
  1478. txn: Transaction object
  1479. table: the table name
  1480. keyvalues:
  1481. column names and values to select the rows with, or None to not
  1482. apply a WHERE clause.
  1483. retcols: the names of the columns to return
  1484. """
  1485. if keyvalues:
  1486. sql = "SELECT %s FROM %s WHERE %s" % (
  1487. ", ".join(retcols),
  1488. table,
  1489. " AND ".join("%s = ?" % (k,) for k in keyvalues),
  1490. )
  1491. txn.execute(sql, list(keyvalues.values()))
  1492. else:
  1493. sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
  1494. txn.execute(sql)
  1495. return cls.cursor_to_dict(txn)
  1496. async def simple_select_many_batch(
  1497. self,
  1498. table: str,
  1499. column: str,
  1500. iterable: Iterable[Any],
  1501. retcols: Collection[str],
  1502. keyvalues: Optional[Dict[str, Any]] = None,
  1503. desc: str = "simple_select_many_batch",
  1504. batch_size: int = 100,
  1505. ) -> List[Any]:
  1506. """Executes a SELECT query on the named table, which may return zero or
  1507. more rows, returning the result as a list of dicts.
  1508. Filters rows by whether the value of `column` is in `iterable`.
  1509. Args:
  1510. table: string giving the table name
  1511. column: column name to test for inclusion against `iterable`
  1512. iterable: list
  1513. retcols: list of strings giving the names of the columns to return
  1514. keyvalues: dict of column names and values to select the rows with
  1515. desc: description of the transaction, for logging and metrics
  1516. batch_size: the number of rows for each select query
  1517. """
  1518. keyvalues = keyvalues or {}
  1519. results: List[Dict[str, Any]] = []
  1520. for chunk in batch_iter(iterable, batch_size):
  1521. rows = await self.runInteraction(
  1522. desc,
  1523. self.simple_select_many_txn,
  1524. table,
  1525. column,
  1526. chunk,
  1527. keyvalues,
  1528. retcols,
  1529. db_autocommit=True,
  1530. )
  1531. results.extend(rows)
  1532. return results
  1533. @classmethod
  1534. def simple_select_many_txn(
  1535. cls,
  1536. txn: LoggingTransaction,
  1537. table: str,
  1538. column: str,
  1539. iterable: Collection[Any],
  1540. keyvalues: Dict[str, Any],
  1541. retcols: Iterable[str],
  1542. ) -> List[Dict[str, Any]]:
  1543. """Executes a SELECT query on the named table, which may return zero or
  1544. more rows, returning the result as a list of dicts.
  1545. Filters rows by whether the value of `column` is in `iterable`.
  1546. Args:
  1547. txn: Transaction object
  1548. table: string giving the table name
  1549. column: column name to test for inclusion against `iterable`
  1550. iterable: list
  1551. keyvalues: dict of column names and values to select the rows with
  1552. retcols: list of strings giving the names of the columns to return
  1553. """
  1554. if not iterable:
  1555. return []
  1556. clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
  1557. clauses = [clause]
  1558. for key, value in keyvalues.items():
  1559. clauses.append("%s = ?" % (key,))
  1560. values.append(value)
  1561. sql = "SELECT %s FROM %s WHERE %s" % (
  1562. ", ".join(retcols),
  1563. table,
  1564. " AND ".join(clauses),
  1565. )
  1566. txn.execute(sql, values)
  1567. return cls.cursor_to_dict(txn)
  1568. async def simple_update(
  1569. self,
  1570. table: str,
  1571. keyvalues: Dict[str, Any],
  1572. updatevalues: Dict[str, Any],
  1573. desc: str,
  1574. ) -> int:
  1575. return await self.runInteraction(
  1576. desc, self.simple_update_txn, table, keyvalues, updatevalues
  1577. )
  1578. @staticmethod
  1579. def simple_update_txn(
  1580. txn: LoggingTransaction,
  1581. table: str,
  1582. keyvalues: Dict[str, Any],
  1583. updatevalues: Dict[str, Any],
  1584. ) -> int:
  1585. if keyvalues:
  1586. where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
  1587. else:
  1588. where = ""
  1589. update_sql = "UPDATE %s SET %s %s" % (
  1590. table,
  1591. ", ".join("%s = ?" % (k,) for k in updatevalues),
  1592. where,
  1593. )
  1594. txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
  1595. return txn.rowcount
  1596. async def simple_update_many(
  1597. self,
  1598. table: str,
  1599. key_names: Collection[str],
  1600. key_values: Collection[Iterable[Any]],
  1601. value_names: Collection[str],
  1602. value_values: Iterable[Iterable[Any]],
  1603. desc: str,
  1604. ) -> None:
  1605. """
  1606. Update, many times, using batching where possible.
  1607. If the keys don't match anything, nothing will be updated.
  1608. Args:
  1609. table: The table to update
  1610. key_names: The key column names.
  1611. key_values: A list of each row's key column values.
  1612. value_names: The names of value columns to update.
  1613. value_values: A list of each row's value column values.
  1614. """
  1615. await self.runInteraction(
  1616. desc,
  1617. self.simple_update_many_txn,
  1618. table,
  1619. key_names,
  1620. key_values,
  1621. value_names,
  1622. value_values,
  1623. )
  1624. @staticmethod
  1625. def simple_update_many_txn(
  1626. txn: LoggingTransaction,
  1627. table: str,
  1628. key_names: Collection[str],
  1629. key_values: Collection[Iterable[Any]],
  1630. value_names: Collection[str],
  1631. value_values: Collection[Iterable[Any]],
  1632. ) -> None:
  1633. """
  1634. Update, many times, using batching where possible.
  1635. If the keys don't match anything, nothing will be updated.
  1636. Args:
  1637. table: The table to update
  1638. key_names: The key column names.
  1639. key_values: A list of each row's key column values.
  1640. value_names: The names of value columns to update.
  1641. value_values: A list of each row's value column values.
  1642. """
  1643. if len(value_values) != len(key_values):
  1644. raise ValueError(
  1645. f"{len(key_values)} key rows and {len(value_values)} value rows: should be the same number."
  1646. )
  1647. # List of tuples of (value values, then key values)
  1648. # (This matches the order needed for the query)
  1649. args = [tuple(x) + tuple(y) for x, y in zip(value_values, key_values)]
  1650. for ks, vs in zip(key_values, value_values):
  1651. args.append(tuple(vs) + tuple(ks))
  1652. # 'col1 = ?, col2 = ?, ...'
  1653. set_clause = ", ".join(f"{n} = ?" for n in value_names)
  1654. if key_names:
  1655. # 'WHERE col3 = ? AND col4 = ? AND col5 = ?'
  1656. where_clause = "WHERE " + (" AND ".join(f"{n} = ?" for n in key_names))
  1657. else:
  1658. where_clause = ""
  1659. # UPDATE mytable SET col1 = ?, col2 = ? WHERE col3 = ? AND col4 = ?
  1660. sql = f"""
  1661. UPDATE {table} SET {set_clause} {where_clause}
  1662. """
  1663. txn.execute_batch(sql, args)
  1664. async def simple_update_one(
  1665. self,
  1666. table: str,
  1667. keyvalues: Dict[str, Any],
  1668. updatevalues: Dict[str, Any],
  1669. desc: str = "simple_update_one",
  1670. ) -> None:
  1671. """Executes an UPDATE query on the named table, setting new values for
  1672. columns in a row matching the key values.
  1673. Args:
  1674. table: string giving the table name
  1675. keyvalues: dict of column names and values to select the row with
  1676. updatevalues: dict giving column names and values to update
  1677. desc: description of the transaction, for logging and metrics
  1678. """
  1679. await self.runInteraction(
  1680. desc,
  1681. self.simple_update_one_txn,
  1682. table,
  1683. keyvalues,
  1684. updatevalues,
  1685. db_autocommit=True,
  1686. )
  1687. @classmethod
  1688. def simple_update_one_txn(
  1689. cls,
  1690. txn: LoggingTransaction,
  1691. table: str,
  1692. keyvalues: Dict[str, Any],
  1693. updatevalues: Dict[str, Any],
  1694. ) -> None:
  1695. rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
  1696. if rowcount == 0:
  1697. raise StoreError(404, "No row found (%s)" % (table,))
  1698. if rowcount > 1:
  1699. raise StoreError(500, "More than one row matched (%s)" % (table,))
  1700. # Ideally we could use the overload decorator here to specify that the
  1701. # return type is only optional if allow_none is True, but this does not work
  1702. # when you call a static method from an instance.
  1703. # See https://github.com/python/mypy/issues/7781
  1704. @staticmethod
  1705. def simple_select_one_txn(
  1706. txn: LoggingTransaction,
  1707. table: str,
  1708. keyvalues: Dict[str, Any],
  1709. retcols: Collection[str],
  1710. allow_none: bool = False,
  1711. ) -> Optional[Dict[str, Any]]:
  1712. select_sql = "SELECT %s FROM %s WHERE %s" % (
  1713. ", ".join(retcols),
  1714. table,
  1715. " AND ".join("%s = ?" % (k,) for k in keyvalues),
  1716. )
  1717. txn.execute(select_sql, list(keyvalues.values()))
  1718. row = txn.fetchone()
  1719. if not row:
  1720. if allow_none:
  1721. return None
  1722. raise StoreError(404, "No row found (%s)" % (table,))
  1723. if txn.rowcount > 1:
  1724. raise StoreError(500, "More than one row matched (%s)" % (table,))
  1725. return dict(zip(retcols, row))
  1726. async def simple_delete_one(
  1727. self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
  1728. ) -> None:
  1729. """Executes a DELETE query on the named table, expecting to delete a
  1730. single row.
  1731. Args:
  1732. table: string giving the table name
  1733. keyvalues: dict of column names and values to select the row with
  1734. desc: description of the transaction, for logging and metrics
  1735. """
  1736. await self.runInteraction(
  1737. desc,
  1738. self.simple_delete_one_txn,
  1739. table,
  1740. keyvalues,
  1741. db_autocommit=True,
  1742. )
  1743. @staticmethod
  1744. def simple_delete_one_txn(
  1745. txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
  1746. ) -> None:
  1747. """Executes a DELETE query on the named table, expecting to delete a
  1748. single row.
  1749. Args:
  1750. table: string giving the table name
  1751. keyvalues: dict of column names and values to select the row with
  1752. """
  1753. sql = "DELETE FROM %s WHERE %s" % (
  1754. table,
  1755. " AND ".join("%s = ?" % (k,) for k in keyvalues),
  1756. )
  1757. txn.execute(sql, list(keyvalues.values()))
  1758. if txn.rowcount == 0:
  1759. raise StoreError(404, "No row found (%s)" % (table,))
  1760. if txn.rowcount > 1:
  1761. raise StoreError(500, "More than one row matched (%s)" % (table,))
  1762. async def simple_delete(
  1763. self, table: str, keyvalues: Dict[str, Any], desc: str
  1764. ) -> int:
  1765. """Executes a DELETE query on the named table.
  1766. Filters rows by the key-value pairs.
  1767. Args:
  1768. table: string giving the table name
  1769. keyvalues: dict of column names and values to select the row with
  1770. desc: description of the transaction, for logging and metrics
  1771. Returns:
  1772. The number of deleted rows.
  1773. """
  1774. return await self.runInteraction(
  1775. desc, self.simple_delete_txn, table, keyvalues, db_autocommit=True
  1776. )
  1777. @staticmethod
  1778. def simple_delete_txn(
  1779. txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
  1780. ) -> int:
  1781. """Executes a DELETE query on the named table.
  1782. Filters rows by the key-value pairs.
  1783. Args:
  1784. table: string giving the table name
  1785. keyvalues: dict of column names and values to select the row with
  1786. Returns:
  1787. The number of deleted rows.
  1788. """
  1789. sql = "DELETE FROM %s WHERE %s" % (
  1790. table,
  1791. " AND ".join("%s = ?" % (k,) for k in keyvalues),
  1792. )
  1793. txn.execute(sql, list(keyvalues.values()))
  1794. return txn.rowcount
  1795. async def simple_delete_many(
  1796. self,
  1797. table: str,
  1798. column: str,
  1799. iterable: Collection[Any],
  1800. keyvalues: Dict[str, Any],
  1801. desc: str,
  1802. ) -> int:
  1803. """Executes a DELETE query on the named table.
  1804. Filters rows by if value of `column` is in `iterable`.
  1805. Args:
  1806. table: string giving the table name
  1807. column: column name to test for inclusion against `iterable`
  1808. iterable: list of values to match against `column`. NB cannot be a generator
  1809. as it may be evaluated multiple times.
  1810. keyvalues: dict of column names and values to select the rows with
  1811. desc: description of the transaction, for logging and metrics
  1812. Returns:
  1813. Number rows deleted
  1814. """
  1815. return await self.runInteraction(
  1816. desc,
  1817. self.simple_delete_many_txn,
  1818. table,
  1819. column,
  1820. iterable,
  1821. keyvalues,
  1822. db_autocommit=True,
  1823. )
  1824. @staticmethod
  1825. def simple_delete_many_txn(
  1826. txn: LoggingTransaction,
  1827. table: str,
  1828. column: str,
  1829. values: Collection[Any],
  1830. keyvalues: Dict[str, Any],
  1831. ) -> int:
  1832. """Executes a DELETE query on the named table.
  1833. Deletes the rows:
  1834. - whose value of `column` is in `values`; AND
  1835. - that match extra column-value pairs specified in `keyvalues`.
  1836. Args:
  1837. txn: Transaction object
  1838. table: string giving the table name
  1839. column: column name to test for inclusion against `values`
  1840. values: values of `column` which choose rows to delete
  1841. keyvalues: dict of extra column names and values to select the rows
  1842. with. They will be ANDed together with the main predicate.
  1843. Returns:
  1844. Number rows deleted
  1845. """
  1846. if not values:
  1847. return 0
  1848. sql = "DELETE FROM %s" % table
  1849. clause, values = make_in_list_sql_clause(txn.database_engine, column, values)
  1850. clauses = [clause]
  1851. for key, value in keyvalues.items():
  1852. clauses.append("%s = ?" % (key,))
  1853. values.append(value)
  1854. if clauses:
  1855. sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
  1856. txn.execute(sql, values)
  1857. return txn.rowcount
  1858. def get_cache_dict(
  1859. self,
  1860. db_conn: LoggingDatabaseConnection,
  1861. table: str,
  1862. entity_column: str,
  1863. stream_column: str,
  1864. max_value: int,
  1865. limit: int = 100000,
  1866. ) -> Tuple[Dict[Any, int], int]:
  1867. """Gets roughly the last N changes in the given stream table as a
  1868. map from entity to the stream ID of the most recent change.
  1869. Also returns the minimum stream ID.
  1870. """
  1871. # This may return many rows for the same entity, but the `limit` is only
  1872. # a suggestion so we don't care that much.
  1873. #
  1874. # Note: Some stream tables can have multiple rows with the same stream
  1875. # ID. Instead of handling this with complicated SQL, we instead simply
  1876. # add one to the returned minimum stream ID to ensure correctness.
  1877. sql = f"""
  1878. SELECT {entity_column}, {stream_column}
  1879. FROM {table}
  1880. ORDER BY {stream_column} DESC
  1881. LIMIT ?
  1882. """
  1883. txn = db_conn.cursor(txn_name="get_cache_dict")
  1884. txn.execute(sql, (limit,))
  1885. # The rows come out in reverse stream ID order, so we want to keep the
  1886. # stream ID of the first row for each entity.
  1887. cache: Dict[Any, int] = {}
  1888. for row in txn:
  1889. cache.setdefault(row[0], int(row[1]))
  1890. txn.close()
  1891. if cache:
  1892. # We add one here as we don't know if we have all rows for the
  1893. # minimum stream ID.
  1894. min_val = min(cache.values()) + 1
  1895. else:
  1896. min_val = max_value
  1897. return cache, min_val
  1898. @classmethod
  1899. def simple_select_list_paginate_txn(
  1900. cls,
  1901. txn: LoggingTransaction,
  1902. table: str,
  1903. orderby: str,
  1904. start: int,
  1905. limit: int,
  1906. retcols: Iterable[str],
  1907. filters: Optional[Dict[str, Any]] = None,
  1908. keyvalues: Optional[Dict[str, Any]] = None,
  1909. exclude_keyvalues: Optional[Dict[str, Any]] = None,
  1910. order_direction: str = "ASC",
  1911. ) -> List[Dict[str, Any]]:
  1912. """
  1913. Executes a SELECT query on the named table with start and limit,
  1914. of row numbers, which may return zero or number of rows from start to limit,
  1915. returning the result as a list of dicts.
  1916. Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
  1917. select attributes with exact matches. All constraints are joined together
  1918. using 'AND'.
  1919. Args:
  1920. txn: Transaction object
  1921. table: the table name
  1922. orderby: Column to order the results by.
  1923. start: Index to begin the query at.
  1924. limit: Number of results to return.
  1925. retcols: the names of the columns to return
  1926. filters:
  1927. column names and values to filter the rows with, or None to not
  1928. apply a WHERE ? LIKE ? clause.
  1929. keyvalues:
  1930. column names and values to select the rows with, or None to not
  1931. apply a WHERE key = value clause.
  1932. exclude_keyvalues:
  1933. column names and values to exclude rows with, or None to not
  1934. apply a WHERE key != value clause.
  1935. order_direction: Whether the results should be ordered "ASC" or "DESC".
  1936. Returns:
  1937. The result as a list of dictionaries.
  1938. """
  1939. if order_direction not in ["ASC", "DESC"]:
  1940. raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
  1941. where_clause = "WHERE " if filters or keyvalues or exclude_keyvalues else ""
  1942. arg_list: List[Any] = []
  1943. if filters:
  1944. where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
  1945. arg_list += list(filters.values())
  1946. where_clause += " AND " if filters and keyvalues else ""
  1947. if keyvalues:
  1948. where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
  1949. arg_list += list(keyvalues.values())
  1950. if exclude_keyvalues:
  1951. where_clause += " AND ".join("%s != ?" % (k,) for k in exclude_keyvalues)
  1952. arg_list += list(exclude_keyvalues.values())
  1953. sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
  1954. ", ".join(retcols),
  1955. table,
  1956. where_clause,
  1957. orderby,
  1958. order_direction,
  1959. )
  1960. txn.execute(sql, arg_list + [limit, start])
  1961. return cls.cursor_to_dict(txn)
  1962. async def simple_search_list(
  1963. self,
  1964. table: str,
  1965. term: Optional[str],
  1966. col: str,
  1967. retcols: Collection[str],
  1968. desc: str = "simple_search_list",
  1969. ) -> Optional[List[Dict[str, Any]]]:
  1970. """Executes a SELECT query on the named table, which may return zero or
  1971. more rows, returning the result as a list of dicts.
  1972. Args:
  1973. table: the table name
  1974. term: term for searching the table matched to a column.
  1975. col: column to query term should be matched to
  1976. retcols: the names of the columns to return
  1977. Returns:
  1978. A list of dictionaries or None.
  1979. """
  1980. return await self.runInteraction(
  1981. desc,
  1982. self.simple_search_list_txn,
  1983. table,
  1984. term,
  1985. col,
  1986. retcols,
  1987. db_autocommit=True,
  1988. )
  1989. @classmethod
  1990. def simple_search_list_txn(
  1991. cls,
  1992. txn: LoggingTransaction,
  1993. table: str,
  1994. term: Optional[str],
  1995. col: str,
  1996. retcols: Iterable[str],
  1997. ) -> Optional[List[Dict[str, Any]]]:
  1998. """Executes a SELECT query on the named table, which may return zero or
  1999. more rows, returning the result as a list of dicts.
  2000. Args:
  2001. txn: Transaction object
  2002. table: the table name
  2003. term: term for searching the table matched to a column.
  2004. col: column to query term should be matched to
  2005. retcols: the names of the columns to return
  2006. Returns:
  2007. None if no term is given, otherwise a list of dictionaries.
  2008. """
  2009. if term:
  2010. sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
  2011. termvalues = ["%%" + term + "%%"]
  2012. txn.execute(sql, termvalues)
  2013. else:
  2014. return None
  2015. return cls.cursor_to_dict(txn)
  2016. def make_in_list_sql_clause(
  2017. database_engine: BaseDatabaseEngine, column: str, iterable: Collection[Any]
  2018. ) -> Tuple[str, list]:
  2019. """Returns an SQL clause that checks the given column is in the iterable.
  2020. On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
  2021. it expands to `column = ANY(?)`. While both DBs support the `IN` form,
  2022. using the `ANY` form on postgres means that it views queries with
  2023. different length iterables as the same, helping the query stats.
  2024. Args:
  2025. database_engine
  2026. column: Name of the column
  2027. iterable: The values to check the column against.
  2028. Returns:
  2029. A tuple of SQL query and the args
  2030. """
  2031. if database_engine.supports_using_any_list:
  2032. # This should hopefully be faster, but also makes postgres query
  2033. # stats easier to understand.
  2034. return "%s = ANY(?)" % (column,), [list(iterable)]
  2035. else:
  2036. return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
  2037. KV = TypeVar("KV")
  2038. def make_tuple_comparison_clause(keys: List[Tuple[str, KV]]) -> Tuple[str, List[KV]]:
  2039. """Returns a tuple comparison SQL clause
  2040. Builds a SQL clause that looks like "(a, b) > (?, ?)"
  2041. Args:
  2042. keys: A set of (column, value) pairs to be compared.
  2043. Returns:
  2044. A tuple of SQL query and the args
  2045. """
  2046. return (
  2047. "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
  2048. [k[1] for k in keys],
  2049. )