12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640 |
- # -*- coding: utf-8 -*-
- # Copyright 2014-2016 OpenMarket Ltd
- # Copyright 2017-2018 New Vector Ltd
- # Copyright 2019 The Matrix.org Foundation C.I.C.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import logging
- import time
- from time import monotonic as monotonic_time
- from typing import (
- Any,
- Callable,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Tuple,
- TypeVar,
- )
- from six import iteritems, iterkeys, itervalues
- from six.moves import intern, range
- from prometheus_client import Histogram
- from twisted.enterprise import adbapi
- from twisted.internet import defer
- from synapse.api.errors import StoreError
- from synapse.config.database import DatabaseConnectionConfig
- from synapse.logging.context import (
- LoggingContext,
- LoggingContextOrSentinel,
- current_context,
- make_deferred_yieldable,
- )
- from synapse.metrics.background_process_metrics import run_as_background_process
- from synapse.storage.background_updates import BackgroundUpdater
- from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
- from synapse.storage.types import Connection, Cursor
- from synapse.util.stringutils import exception_to_unicode
- logger = logging.getLogger(__name__)
- # Python 3 ints are unbounded, so we pick our own ceiling at which transaction ids wrap.
- MAX_TXN_ID = 2 ** 63 - 1
- sql_logger = logging.getLogger("synapse.storage.SQL")
- transaction_logger = logging.getLogger("synapse.storage.txn")
- perf_logger = logging.getLogger("synapse.storage.TIME")
- sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
- sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
- sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
- # Unique indexes which have been added in background updates. Maps from table name
- # to the name of the background update which added the unique index to that table.
- #
- # This is used by the upsert logic to figure out which tables are safe to do a proper
- # UPSERT on: until the relevant background update has completed, we
- # have to emulate an upsert by locking the table.
- #
- UNIQUE_INDEX_BACKGROUND_UPDATES = {
- "user_ips": "user_ips_device_unique_index",
- "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
- "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
- "event_search": "event_search_event_id_idx",
- }
def make_pool(
    reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> adbapi.ConnectionPool:
    """Build the twisted ``adbapi`` connection pool for a database.

    The "name" entry of the database config is passed as the pool's first
    argument, every entry under "args" is forwarded as a keyword argument,
    and ``engine.on_new_connection`` is installed as the pool's ``cp_openfun``.
    """
    settings = db_config.config
    module_name = settings["name"]
    on_open = engine.on_new_connection
    pool_kwargs = settings.get("args", {})
    return adbapi.ConnectionPool(
        module_name, cp_reactor=reactor, cp_openfun=on_open, **pool_kwargs
    )
def make_conn(
    db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
) -> Connection:
    """Open a single, stand-alone connection to the database.

    The "args" section of the database config is passed to the engine
    module's ``connect()``, minus any key starting with ``cp_``. The new
    connection is handed to ``engine.on_new_connection`` before being
    returned.

    Returns:
        Connection
    """
    connect_kwargs = {}
    for key, value in db_config.config.get("args", {}).items():
        if key.startswith("cp_"):
            continue
        connect_kwargs[key] = value

    connection = engine.module.connect(**connect_kwargs)
    engine.on_new_connection(connection)
    return connection
- # The type of entry which goes on our after_callbacks and exception_callbacks lists:
- # a (callback, args, kwargs) tuple, as appended by call_after / call_on_exception.
- # Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
- # that mypy sees the type but the runtime python doesn't.
- _CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
class LoggingTransaction:
    """An object that almost-transparently proxies for the 'txn' object
    passed to the constructor. Adds logging and metrics to the .execute()
    method.

    Args:
        txn: The database transaction object to wrap.
        name: The name of this transaction, used in log lines.
        database_engine: The engine in use; it converts the SQL parameter
            style and decides how `execute_batch` is carried out.
        after_callbacks: A list that callbacks will be appended to
            that have been added by `call_after` which should be run on
            successful completion of the transaction. None indicates that no
            callbacks should be allowed to be scheduled to run.
        exception_callbacks: A list that callbacks will be appended
            to that have been added by `call_on_exception` which should be run
            if transaction ends with an error. None indicates that no callbacks
            should be allowed to be scheduled to run.
    """

    __slots__ = [
        "txn",
        "name",
        "database_engine",
        "after_callbacks",
        "exception_callbacks",
    ]

    def __init__(
        self,
        txn: Cursor,
        name: str,
        database_engine: BaseDatabaseEngine,
        after_callbacks: Optional[List[_CallbackListEntry]] = None,
        exception_callbacks: Optional[List[_CallbackListEntry]] = None,
    ):
        self.txn = txn
        self.name = name
        self.database_engine = database_engine
        self.after_callbacks = after_callbacks
        self.exception_callbacks = exception_callbacks

    def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
        """Call the given callback on the main twisted thread after the
        transaction has finished. Used to invalidate the caches on the
        correct thread.
        """
        # if self.after_callbacks is None, that means that whatever constructed the
        # LoggingTransaction isn't expecting there to be any callbacks; assert that
        # is not the case.
        assert self.after_callbacks is not None
        self.after_callbacks.append((callback, args, kwargs))

    def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
        """Register a callback to be run if the transaction ends with an error."""
        # if self.exception_callbacks is None, that means that whatever constructed the
        # LoggingTransaction isn't expecting there to be any callbacks; assert that
        # is not the case.
        assert self.exception_callbacks is not None
        self.exception_callbacks.append((callback, args, kwargs))

    def fetchall(self) -> List[Tuple]:
        """Return all remaining rows from the wrapped cursor."""
        return self.txn.fetchall()

    def fetchone(self) -> Optional[Tuple]:
        """Return the wrapped cursor's next row (DB-API cursors give None
        once the result set is exhausted).
        """
        return self.txn.fetchone()

    def __iter__(self) -> Iterator[Tuple]:
        return self.txn.__iter__()

    @property
    def rowcount(self) -> int:
        return self.txn.rowcount

    @property
    def description(self) -> Any:
        return self.txn.description

    def execute_batch(self, sql, args):
        """Run `sql` once for each parameter set in `args`.

        On PostgreSQL this goes through psycopg2's `execute_batch`; on any
        other engine each parameter set is executed individually.
        """
        if isinstance(self.database_engine, PostgresEngine):
            from psycopg2.extras import execute_batch  # type: ignore

            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
        else:
            for val in args:
                self.execute(sql, val)

    def execute(self, sql: str, *args: Any):
        """Execute a statement on the wrapped cursor, with logging/metrics."""
        self._do_execute(self.txn.execute, sql, *args)

    def executemany(self, sql: str, *args: Any):
        """Call the wrapped cursor's `executemany`, with logging/metrics."""
        self._do_execute(self.txn.executemany, sql, *args)

    def _make_sql_one_line(self, sql):
        "Strip newlines out of SQL so that the loggers in the DB are on one line"
        stripped_lines = (line.strip() for line in sql.splitlines())
        return " ".join(line for line in stripped_lines if line)

    def _do_execute(self, func, sql, *args):
        """Run `func(sql, *args)`, logging the statement and timing it.

        The SQL is flattened to one line and converted to the engine's
        parameter style before being passed on. The elapsed time is logged
        and recorded in `sql_query_timer`, labelled with the statement's
        first word, whether or not `func` raises.
        """
        sql = self._make_sql_one_line(sql)

        # TODO(paul): Maybe use 'info' and 'debug' for values?
        sql_logger.debug("[SQL] {%s} %s", self.name, sql)

        sql = self.database_engine.convert_param_style(sql)
        if args:
            try:
                sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
            except Exception:
                # Don't let logging failures stop SQL from working
                pass

        # Durations are measured on the monotonic clock so that wall-clock
        # adjustments cannot produce negative or wildly wrong timings.
        start = monotonic_time()
        try:
            return func(sql, *args)
        except Exception as e:
            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
            raise
        finally:
            secs = monotonic_time() - start
            sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
            # A blank statement has no first word; fall back to an empty label
            # rather than raising here and masking the database's own error.
            words = sql.split(None, 1)
            sql_query_timer.labels(words[0] if words else "").observe(secs)

    def close(self):
        """Close the wrapped cursor."""
        self.txn.close()
class PerformanceCounters(object):
    """Per-key call counts and cumulative durations, reported per interval.

    `current_counters` maps a key to a running `(count, cumulative_seconds)`
    pair; `previous_counters` is the snapshot taken by the last call to
    `interval`, so each report covers only what happened since then.
    """

    def __init__(self):
        self.current_counters = {}
        self.previous_counters = {}

    def update(self, key, duration_secs):
        """Record one more occurrence of `key` lasting `duration_secs`."""
        count, cum_time = self.current_counters.get(key, (0, 0))
        self.current_counters[key] = (count + 1, cum_time + duration_secs)

    def interval(self, interval_duration_secs, limit=3):
        """Summarise activity since the previous call.

        Args:
            interval_duration_secs: length of the interval being reported,
                used to turn the time spent on each key into a ratio.
            limit: how many of the busiest keys to include.

        Returns:
            str: comma separated "name(count): percent%" entries, busiest
            first.
        """
        counters = []
        for name, (count, cum_time) in self.current_counters.items():
            prev_count, prev_time = self.previous_counters.get(name, (0, 0))
            counters.append(
                (
                    (cum_time - prev_time) / interval_duration_secs,
                    count - prev_count,
                    name,
                )
            )

        # Snapshot so that the next report starts from here.
        self.previous_counters = dict(self.current_counters)

        counters.sort(reverse=True)

        return ", ".join(
            "%s(%d): %.3f%%" % (name, count, 100 * ratio)
            for ratio, count, name in counters[:limit]
        )
- class Database(object):
- """Wraps a single physical database and connection pool.
- A single database may be used by multiple data stores.
- """
- _TXN_ID = 0
def __init__(
    self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
):
    """Create the connection pool and the bookkeeping that goes with it.

    Args:
        hs: provides `get_clock()` and `get_reactor()` (presumably the
            homeserver object).
        database_config: connection settings, passed on to `make_pool`.
        engine: the database engine in use.
    """
    self.hs = hs
    self._clock = hs.get_clock()
    self._database_config = database_config
    self._db_pool = make_pool(hs.get_reactor(), database_config, engine)

    self.updates = BackgroundUpdater(hs, self)

    self._previous_txn_total_time = 0.0
    self._current_txn_total_time = 0.0
    self._previous_loop_ts = 0.0

    # TODO(paul): These can eventually be removed once the metrics code
    #   is running in mainline, and we have some nice monitoring frontends
    #   to watch it
    self._txn_perf_counters = PerformanceCounters()

    self.engine = engine

    # A set of tables that are not safe to use native upserts in.
    self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES)

    # We add the user_directory_search table to the blacklist on SQLite
    # because the existing search table does not have an index, making it
    # unsafe to use native upserts.
    if isinstance(self.engine, Sqlite3Engine):
        self._unsafe_to_upsert_tables.add("user_directory_search")

    if self.engine.can_native_upsert:
        # Check ASAP (and then again every 15s for as long as background
        # updates remain) to see if we have finished background updates of
        # tables that aren't safe to update.
        self._clock.call_later(
            0.0,
            run_as_background_process,
            "upsert_safety_check",
            self._check_safe_to_upsert,
        )
def is_running(self):
    """Report whether the underlying connection pool is currently running."""
    pool = self._db_pool
    return pool.running
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
    """
    Work out which tables have become safe for native UPSERT.

    A table listed in UNIQUE_INDEX_BACKGROUND_UPDATES stays unsafe while the
    background update that adds its unique index is still listed in the
    `background_updates` table. Once that update is gone, the table is
    removed from the unsafe set.

    While any background update remains listed, the check is rescheduled to
    run again in 15 seconds.
    """
    rows = yield self.simple_select_list(
        "background_updates",
        keyvalues=None,
        retcols=["update_name"],
        desc="check_background_updates",
    )
    pending = [row["update_name"] for row in rows]

    for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
        if update_name in pending:
            continue
        logger.debug("Now safe to upsert in %s", table)
        self._unsafe_to_upsert_tables.discard(table)

    # Nothing left to wait for: do not reschedule.
    if not pending:
        return

    self._clock.call_later(
        15.0,
        run_as_background_process,
        "upsert_safety_check",
        self._check_safe_to_upsert,
    )
def start_profiling(self):
    """Start a looping call that logs how much time is spent in the database.

    On each run the call logs, to the performance logger, the share of the
    elapsed time taken by transactions, together with the three busiest
    transaction descriptions.
    """
    self._previous_loop_ts = monotonic_time()

    def loop():
        # Transaction time accumulated since the previous run.
        total_now = self._current_txn_total_time
        total_before = self._previous_txn_total_time
        self._previous_txn_total_time = total_now

        # Time elapsed since the previous run.
        now = monotonic_time()
        last_run = self._previous_loop_ts
        self._previous_loop_ts = now
        elapsed = now - last_run

        busy_ratio = (total_now - total_before) / elapsed
        busiest = self._txn_perf_counters.interval(elapsed, limit=3)

        perf_logger.debug(
            "Total database time: %.3f%% {%s}", busy_ratio * 100, busiest
        )

    self._clock.looping_call(loop, 10000)
def new_transaction(
    self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
):
    """Run `func` in a transaction on `conn`, retrying on transient failures.

    Args:
        conn: the database connection; used to create cursors and to
            commit / roll back.
        desc: description of the transaction, used in its log name and as
            the metrics label.
        after_callbacks: list handed to the LoggingTransaction for
            `call_after` callbacks.
        exception_callbacks: list handed to the LoggingTransaction for
            `call_on_exception` callbacks.
        func: called with the LoggingTransaction followed by `args` and
            `kwargs`.

    Returns:
        The result of `func`, once the transaction has been committed.

    An OperationalError, or a DatabaseError the engine reports as a deadlock,
    is retried with a fresh cursor after a rollback, up to 5 times. Anything
    else is re-raised.
    """
    started_at = monotonic_time()
    txn_id = self._TXN_ID

    # We don't really need these to be unique, so lets stop it from
    # growing really large.
    self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)

    name = "%s-%x" % (desc, txn_id)

    transaction_logger.debug("[TXN START] {%s}", name)

    def rollback_before_retry():
        # A failed rollback is only logged; we still go on to retry.
        try:
            conn.rollback()
        except self.engine.module.Error as rollback_err:
            logger.warning(
                "[TXN EROLL] {%s} %s", name, exception_to_unicode(rollback_err)
            )

    try:
        retries = 0
        max_retries = 5
        while True:
            cursor = LoggingTransaction(
                conn.cursor(),
                name,
                self.engine,
                after_callbacks,
                exception_callbacks,
            )
            try:
                result = func(cursor, *args, **kwargs)
                conn.commit()
                return result
            except self.engine.module.OperationalError as e:
                # This can happen if the database disappears mid
                # transaction.
                logger.warning(
                    "[TXN OPERROR] {%s} %s %d/%d",
                    name,
                    exception_to_unicode(e),
                    retries,
                    max_retries,
                )
                if retries >= max_retries:
                    raise
                retries += 1
                rollback_before_retry()
            except self.engine.module.DatabaseError as e:
                if not self.engine.is_deadlock(e):
                    raise
                logger.warning(
                    "[TXN DEADLOCK] {%s} %d/%d", name, retries, max_retries
                )
                if retries >= max_retries:
                    raise
                retries += 1
                rollback_before_retry()
            finally:
                # We are either about to retry with a new cursor, or about to
                # release the connection. Once released, the connection could
                # be used for another query, which might do a conn.rollback().
                #
                # In that case python's sqlite will reset all statements on
                # the connection [1], which will make our cursor invalid [2].
                # Reading rows after commit()ing is also dubious from the PoV
                # of ACID semantics (sqlite explicitly says that once you
                # commit, you may see rows from subsequent updates).
                #
                # psycopg2 cursors are essentially a client-side fabrication,
                # so in theory we could go on streaming results from them, but
                # doing so would make us incompatible with sqlite. (*Named*
                # psycopg2 cursors are proper server-side things, but we don't
                # use them, and they are implicitly closed by ending the
                # transaction anyway.)
                #
                # TL;DR: we're done with the cursor, so we can close it; if we
                # weren't done with it, that's a problem waiting to bite us.
                #
                # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
                # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
                cursor.close()
    except Exception as e:
        logger.debug("[TXN FAIL] {%s} %s", name, e)
        raise
    finally:
        duration = monotonic_time() - started_at

        current_context().add_database_transaction(duration)

        transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)

        self._current_txn_total_time += duration
        self._txn_perf_counters.update(desc, duration)
        sql_txn_timer.labels(desc).observe(duration)
@defer.inlineCallbacks
def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
    """Run `func` inside a database transaction.

    Callbacks registered through the transaction's `call_after` are run once
    the transaction has succeeded; those registered through
    `call_on_exception` are run if anything is raised, after which the
    exception propagates.

    Arguments:
        desc: description of the transaction, for logging and metrics
        func: callback function, which will be called with a
            database transaction (twisted.enterprise.adbapi.Transaction) as
            its first argument, followed by `args` and `kwargs`.

        args: positional args to pass to `func`
        kwargs: named args to pass to `func`

    Returns:
        Deferred: The result of func
    """
    on_success = []  # type: List[_CallbackListEntry]
    on_failure = []  # type: List[_CallbackListEntry]

    if not current_context():
        logger.warning("Starting db txn '%s' from sentinel context", desc)

    try:
        result = yield self.runWithConnection(
            self.new_transaction,
            desc,
            on_success,
            on_failure,
            func,
            *args,
            **kwargs
        )

        for callback, cb_args, cb_kwargs in on_success:
            callback(*cb_args, **cb_kwargs)
    except BaseException:
        # Deliberately catches everything: the exception is always re-raised.
        for callback, cb_args, cb_kwargs in on_failure:
            callback(*cb_args, **cb_kwargs)
        raise

    return result
@defer.inlineCallbacks
def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
    """Wraps the .runWithConnection() method on the underlying db_pool.

    `func` runs inside a new "runWithConnection" logging context parented on
    the caller's context. The time spent waiting to be scheduled is recorded,
    and a connection the engine reports as closed is reconnected first.

    Arguments:
        func: callback function, which will be called with a
            database connection (twisted.enterprise.adbapi.Connection) as
            its first argument, followed by `args` and `kwargs`.
        args: positional args to pass to `func`
        kwargs: named args to pass to `func`

    Returns:
        Deferred: The result of func
    """
    parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
    if not parent_context:
        logger.warning(
            "Starting db connection from sentinel context: metrics will be lost"
        )
        parent_context = None

    queued_at = monotonic_time()

    def inner_func(conn, *args, **kwargs):
        with LoggingContext("runWithConnection", parent_context) as context:
            # How long we waited between being queued and starting to run.
            waited_secs = monotonic_time() - queued_at
            sql_scheduling_timer.observe(waited_secs)
            context.add_database_scheduled(waited_secs)

            if self.engine.is_connection_closed(conn):
                logger.debug("Reconnecting closed database connection")
                conn.reconnect()

            return func(conn, *args, **kwargs)

    pool_deferred = self._db_pool.runWithConnection(inner_func, *args, **kwargs)
    result = yield make_deferred_yieldable(pool_deferred)

    return result
@staticmethod
def cursor_to_dict(cursor):
    """Turn the rows of an executed cursor into a list of dicts.

    Args:
        cursor : The DBAPI cursor which has executed a query.
    Returns:
        A list with one dict per row, keyed by the (interned) column names
        taken from `cursor.description`.
    """
    column_names = []
    for column in cursor.description:
        column_names.append(intern(str(column[0])))

    return [dict(zip(column_names, row)) for row in cursor]
def execute(self, desc, decoder, query, *args):
    """Run a single query in its own transaction and return its results.

    Args:
        desc - Description of the transaction, passed to runInteraction.
        decoder - The function which can resolve the cursor results to
            something meaningful. If falsy, the raw rows are returned.
        query - The query string to execute
        *args - Query args.
    Returns:
        The result of decoder(results)
    """

    def interaction(txn):
        txn.execute(query, args)
        if not decoder:
            return txn.fetchall()
        return decoder(txn)

    return self.runInteraction(desc, interaction)
- # "Simple" SQL API methods that operate on a single table with no JOINs,
- # no complex WHERE clauses, just a dict of values for columns.
@defer.inlineCallbacks
def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
    """Executes an INSERT query on the named table.

    Args:
        table : string giving the table name
        values : dict of new column names and values for them
        or_ignore : bool stating whether an exception should be raised
            when a conflicting row already exists. If True, False will be
            returned by the function instead
        desc : string giving a description of the transaction

    Returns:
        bool: Whether the row was inserted or not. Only useful when
        `or_ignore` is True
    """
    try:
        yield self.runInteraction(desc, self.simple_insert_txn, table, values)
    except self.engine.module.IntegrityError:
        # The or_ignore flag has to be handled at this layer, since a cursor
        # cannot be reused once the db has reported an error on it.
        if or_ignore:
            return False
        raise
    else:
        return True
@staticmethod
def simple_insert_txn(txn, table, values):
    """Insert a single row, given as a dict of column name to value."""
    columns, params = zip(*values.items())

    column_list = ", ".join(columns)
    placeholders = ", ".join("?" * len(columns))
    sql = "INSERT INTO %s (%s) VALUES(%s)" % (table, column_list, placeholders)

    txn.execute(sql, params)
def simple_insert_many(self, table, values, desc):
    """Insert several rows into `table` in one transaction.

    Args:
        table: name of the table to insert into
        values: the rows to insert, passed to `simple_insert_many_txn`
        desc: description of the transaction

    Returns:
        Whatever `runInteraction` returns for the transaction.
    """
    insert_rows = self.simple_insert_many_txn
    return self.runInteraction(desc, insert_rows, table, values)
@staticmethod
def simple_insert_many_txn(txn, table, values):
    """Insert many rows with a single `executemany`.

    Args:
        txn: the transaction to use
        table: name of the table to insert into
        values: a list of dicts, each mapping column name to value. Empty
            dicts are skipped; all remaining dicts must have the same keys.

    Raises:
        RuntimeError: if the rows do not all have the same set of keys.
    """
    if not values:
        return

    # Sort each row's items by column name so that we don't rely on
    # dictionary iteration order, e.g.
    #
    #   [{"b": 2, "a": 1}, {"a": 3, "b": 4}]
    #       => columns ("a", "b") and parameters ((1, 2), (3, 4))
    rows = [sorted(row.items(), key=lambda kv: kv[0]) for row in values if row]
    if not rows:
        # Every row was empty, so there is nothing to insert.
        return

    columns = tuple(name for name, _ in rows[0])
    for row in rows:
        if tuple(name for name, _ in row) != columns:
            raise RuntimeError("All items must have the same keys")

    sql = "INSERT INTO %s (%s) VALUES(%s)" % (
        table,
        ", ".join(columns),
        ", ".join("?" for _ in columns),
    )

    txn.executemany(sql, tuple(tuple(value for _, value in row) for row in rows))
@defer.inlineCallbacks
def simple_upsert(
    self,
    table,
    keyvalues,
    values,
    insertion_values=None,
    desc="simple_upsert",
    lock=True,
):
    """
    Insert a row, or update it if a row with the same key already exists.

    `lock` should generally be set to True (the default), but can be set
    to False if either of the following are true:

    * there is a UNIQUE INDEX on the key columns. In this case a conflict
      will cause an IntegrityError in which case this function will retry
      the update.

    * we somehow know that we are the only thread which will be updating
      this table.

    Args:
        table (str): The table to upsert into
        keyvalues (dict): The unique key columns and their new values
        values (dict): The nonunique columns and their new values
        insertion_values (dict|None): additional key/values to use only when
            inserting
        desc (str): description of the transaction
        lock (bool): True to lock the table when doing the upsert.
    Returns:
        Deferred(None or bool): Native upserts always return None. Emulated
        upserts return True if a new entry was created, False if an existing
        one was updated.
    Raises:
        IntegrityError: (from the engine's module) if the upsert still
        conflicts on the fifth attempt.
    """
    # Avoid a shared mutable default: build a fresh dict per call.
    if insertion_values is None:
        insertion_values = {}

    attempts = 0
    while True:
        try:
            result = yield self.runInteraction(
                desc,
                self.simple_upsert_txn,
                table,
                keyvalues,
                values,
                insertion_values,
                lock=lock,
            )
            return result
        except self.engine.module.IntegrityError as e:
            attempts += 1
            if attempts >= 5:
                # don't retry forever, because things other than races
                # can cause IntegrityErrors
                raise

            # presumably we raced with another transaction: let's retry.
            logger.warning(
                "IntegrityError when upserting into %s; retrying: %s", table, e
            )
def simple_upsert_txn(
    self, txn, table, keyvalues, values, insertion_values=None, lock=True
):
    """
    Pick the UPSERT method which works best on the platform. Either the
    native one (Pg9.5+, recent SQLites), or fall back to an emulated method.

    The native method is used only when the engine supports it and `table`
    is not currently marked as unsafe for native upserts.

    Args:
        txn: The transaction to use.
        table (str): The table to upsert into
        keyvalues (dict): The unique key tables and their new values
        values (dict): The nonunique columns and their new values
        insertion_values (dict|None): additional key/values to use only when
            inserting
        lock (bool): True to lock the table when doing the upsert.
    Returns:
        None or bool: Native upserts always return None. Emulated
        upserts return True if a new entry was created, False if an existing
        one was updated.
    """
    # Avoid a shared mutable default: build a fresh dict per call.
    if insertion_values is None:
        insertion_values = {}

    if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
        return self.simple_upsert_txn_native_upsert(
            txn, table, keyvalues, values, insertion_values=insertion_values
        )

    return self.simple_upsert_txn_emulated(
        txn,
        table,
        keyvalues,
        values,
        insertion_values=insertion_values,
        lock=lock,
    )
def simple_upsert_txn_emulated(
    self, txn, table, keyvalues, values, insertion_values=None, lock=True
):
    """
    Emulate an upsert: try an UPDATE (or a SELECT if there is nothing to
    update) and fall back to an INSERT if no existing row matched.

    Args:
        txn: The transaction to use.
        table (str): The table to upsert into
        keyvalues (dict): The unique key tables and their new values
        values (dict): The nonunique columns and their new values
        insertion_values (dict|None): additional key/values to use only when
            inserting
        lock (bool): True to lock the table when doing the upsert.
    Returns:
        bool: Return True if a new entry was created, False if an existing
        one was updated.
    """
    # Avoid a shared mutable default: build a fresh dict per call.
    if insertion_values is None:
        insertion_values = {}

    # We need to lock the table :(, unless we're *really* careful
    if lock:
        self.engine.lock_table(txn, table)

    def _getwhere(key):
        # If the value we're passing in is None (aka NULL), we need to use
        # IS, not =, as NULL = NULL equals NULL (False).
        if keyvalues[key] is None:
            return "%s IS ?" % (key,)
        else:
            return "%s = ?" % (key,)

    where_clause = " AND ".join(_getwhere(k) for k in keyvalues)

    if not values:
        # If `values` is empty, then all of the values we care about are in
        # the unique key, so there is nothing to UPDATE. We can just do a
        # SELECT instead to see if it exists.
        sql = "SELECT 1 FROM %s WHERE %s" % (table, where_clause)
        sqlargs = list(keyvalues.values())
        txn.execute(sql, sqlargs)
        if txn.fetchall():
            # We have an existing record.
            return False
    else:
        # First try to update.
        sql = "UPDATE %s SET %s WHERE %s" % (
            table,
            ", ".join("%s = ?" % (k,) for k in values),
            where_clause,
        )
        sqlargs = list(values.values()) + list(keyvalues.values())

        txn.execute(sql, sqlargs)
        if txn.rowcount > 0:
            # successfully updated at least one row.
            return False

    # We didn't find any existing rows, so insert a new one
    allvalues = {}  # type: Dict[str, Any]
    allvalues.update(keyvalues)
    allvalues.update(values)
    allvalues.update(insertion_values)

    sql = "INSERT INTO %s (%s) VALUES (%s)" % (
        table,
        ", ".join(k for k in allvalues),
        ", ".join("?" for _ in allvalues),
    )
    txn.execute(sql, list(allvalues.values()))
    # successfully inserted
    return True
def simple_upsert_txn_native_upsert(
    self, txn, table, keyvalues, values, insertion_values=None
):
    """
    Use the database's native UPSERT (INSERT ... ON CONFLICT) functionality.

    Args:
        txn: The transaction to use.
        table (str): The table to upsert into
        keyvalues (dict): The unique key tables and their new values
        values (dict): The nonunique columns and their new values. If empty,
            a conflicting row is left untouched (ON CONFLICT DO NOTHING).
        insertion_values (dict|None): additional key/values to use only when
            inserting
    Returns:
        None
    """
    allvalues = {}  # type: Dict[str, Any]
    allvalues.update(keyvalues)
    # `insertion_values` defaults to None rather than a shared mutable dict.
    if insertion_values:
        allvalues.update(insertion_values)

    if not values:
        latter = "NOTHING"
    else:
        allvalues.update(values)
        latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)

    sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
        table,
        ", ".join(k for k in allvalues),
        ", ".join("?" for _ in allvalues),
        ", ".join(k for k in keyvalues),
        latter,
    )
    txn.execute(sql, list(allvalues.values()))
def simple_upsert_many_txn(
    self, txn, table, key_names, key_values, value_names, value_values
):
    """
    Upsert, many times.

    Uses the native implementation when the engine supports native upserts
    and `table` is not marked as unsafe for them; otherwise uses the
    emulated implementation.

    Args:
        table (str): The table to upsert into
        key_names (list[str]): The key column names.
        key_values (list[list]): A list of each row's key column values.
        value_names (list[str]): The value column names. If empty, no
            values will be used, even if value_values is provided.
        value_values (list[list]): A list of each row's value column values.
    Returns:
        None
    """
    use_native = (
        self.engine.can_native_upsert
        and table not in self._unsafe_to_upsert_tables
    )
    if use_native:
        upsert_many = self.simple_upsert_many_txn_native_upsert
    else:
        upsert_many = self.simple_upsert_many_txn_emulated

    return upsert_many(txn, table, key_names, key_values, value_names, value_values)
def simple_upsert_many_txn_emulated(
    self, txn, table, key_names, key_values, value_names, value_values
):
    """
    Upsert, many times, but without native UPSERT support or batching.

    Calls ``simple_upsert_txn_emulated`` once per row.

    Args:
        txn: The transaction object to run on.
        table (str): The table to upsert into
        key_names (list[str]): The key column names.
        key_values (list[list]): A list of each row's key column values.
        value_names (list[str]): The value column names. If empty, no
            values will be used, even if value_values is provided.
        value_values (list[list]): A list of each row's value column values.
    Returns:
        None
    """
    if not value_names:
        # With no value columns, pair every key row with an empty value row
        # so the zip() below still yields one entry per key row.
        value_values = [()] * len(key_values)

    for key_row, value_row in zip(key_values, value_values):
        self.simple_upsert_txn_emulated(
            txn,
            table,
            dict(zip(key_names, key_row)),
            dict(zip(value_names, value_row)),
        )
def simple_upsert_many_txn_native_upsert(
    self, txn, table, key_names, key_values, value_names, value_values
):
    """
    Upsert, many times, using batching where possible.

    Builds a single ``INSERT ... ON CONFLICT`` statement and hands it, with
    one argument tuple per row, to ``txn.execute_batch``.

    Args:
        txn: The transaction object to run on.
        table (str): The table to upsert into
        key_names (list[str]): The key column names.
        key_values (list[list]): A list of each row's key column values.
        value_names (list[str]): The value column names. If empty, no
            values will be used, even if value_values is provided.
        value_values (list[list]): A list of each row's value column values.
    Returns:
        Whatever ``txn.execute_batch`` returns.
    """
    columns = list(key_names) + list(value_names)

    if value_names:
        on_conflict = "UPDATE SET " + ", ".join(
            name + "=EXCLUDED." + name for name in value_names
        )
    else:
        on_conflict = "NOTHING"
        # No value columns: use empty value rows so the zip() below still
        # produces one argument tuple per key row.
        value_values = [()] * len(key_values)

    sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
        table,
        ", ".join(columns),
        ", ".join("?" * len(columns)),
        ", ".join(key_names),
        on_conflict,
    )

    rows = [
        tuple(key_row) + tuple(value_row)
        for key_row, value_row in zip(key_values, value_values)
    ]
    return txn.execute_batch(sql, rows)
def simple_select_one(
    self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
):
    """Select several columns from the single row matching ``keyvalues``.

    Runs ``simple_select_one_txn`` through ``runInteraction``.

    Args:
        table : string giving the table name
        keyvalues : dict of column names and values to select the row with
        retcols : list of strings giving the names of the columns to return
        allow_none : If true, return None instead of failing if the SELECT
            statement returns no rows
        desc : description passed to ``runInteraction``
    Returns:
        Whatever ``runInteraction`` returns for the transaction function.
    """
    txn_func = self.simple_select_one_txn
    return self.runInteraction(
        desc, txn_func, table, keyvalues, retcols, allow_none
    )
def simple_select_one_onecol(
    self,
    table,
    keyvalues,
    retcol,
    allow_none=False,
    desc="simple_select_one_onecol",
):
    """Select a single column from the single row matching ``keyvalues``.

    Runs ``simple_select_one_onecol_txn`` through ``runInteraction``.

    Args:
        table : string giving the table name
        keyvalues : dict of column names and values to select the row with
        retcol : string giving the name of the column to return
        allow_none : forwarded (as a keyword) to the transaction function
        desc : description passed to ``runInteraction``
    Returns:
        Whatever ``runInteraction`` returns for the transaction function.
    """
    positional = (self.simple_select_one_onecol_txn, table, keyvalues, retcol)
    return self.runInteraction(desc, *positional, allow_none=allow_none)
@classmethod
def simple_select_one_onecol_txn(
    cls, txn, table, keyvalues, retcol, allow_none=False
):
    """Return ``retcol`` from the first row matching ``keyvalues``.

    Args:
        txn : Transaction object
        table : string giving the table name
        keyvalues : dict of column names and values to select the row with
        retcol : string giving the name of the column to return
        allow_none : if true, return None when nothing matched
    Returns:
        The first value returned by ``simple_select_onecol_txn``, or None
        when there is none and ``allow_none`` is true.
    Raises:
        StoreError: 404 if nothing matched and ``allow_none`` is false.
    """
    matches = cls.simple_select_onecol_txn(
        txn, table=table, keyvalues=keyvalues, retcol=retcol
    )
    if matches:
        return matches[0]
    if not allow_none:
        raise StoreError(404, "No row found")
    return None
@staticmethod
def simple_select_onecol_txn(txn, table, keyvalues, retcol):
    """Select one column from every row matching ``keyvalues``.

    Args:
        txn : Transaction object
        table (str): table name
        keyvalues (dict|None): column names and values to select the rows
            with; when empty or None no WHERE clause is applied.
        retcol (str): name of the column to retrieve.
    Returns:
        list: the first element of each row the query produced.
    """
    sql = "SELECT %s FROM %s" % (retcol, table)

    if keyvalues:
        # Iterate the dict directly (no six shim) and wrap the key in a
        # tuple so the %-format is safe whatever the key's type.
        sql += " WHERE %s" % " AND ".join("%s = ?" % (k,) for k in keyvalues)
        txn.execute(sql, list(keyvalues.values()))
    else:
        txn.execute(sql)

    return [row[0] for row in txn]
def simple_select_onecol(
    self, table, keyvalues, retcol, desc="simple_select_onecol"
):
    """Select one column from every row of ``table`` matching ``keyvalues``.

    Runs ``simple_select_onecol_txn`` through ``runInteraction``.

    Args:
        table (str): table name
        keyvalues (dict|None): column names and values to select the rows with
        retcol (str): column whose value we wish to retrieve.
        desc (str): description passed to ``runInteraction``
    Returns:
        Deferred: Results in a list
    """
    txn_func = self.simple_select_onecol_txn
    return self.runInteraction(desc, txn_func, table, keyvalues, retcol)
def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
    """Select zero or more rows from ``table`` as a list of dicts.

    Runs ``simple_select_list_txn`` through ``runInteraction``.

    Args:
        table (str): the table name
        keyvalues (dict[str, Any] | None):
            column names and values to select the rows with, or None to not
            apply a WHERE clause.
        retcols (iterable[str]): the names of the columns to return
        desc (str): description passed to ``runInteraction``
    Returns:
        defer.Deferred: resolves to list[dict[str, Any]]
    """
    txn_func = self.simple_select_list_txn
    return self.runInteraction(desc, txn_func, table, keyvalues, retcols)
@classmethod
def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
    """Run a SELECT on ``table`` and return the rows as a list of dicts.

    The query may match zero or more rows.

    Args:
        txn : Transaction object
        table (str): the table name
        keyvalues (dict[str, T] | None):
            column names and values to select the rows with, or None to not
            apply a WHERE clause.
        retcols (iterable[str]): the names of the columns to return
    Returns:
        The result of ``cls.cursor_to_dict(txn)`` after the query has run.
    """
    sql = "SELECT %s FROM %s" % (", ".join(retcols), table)

    if not keyvalues:
        txn.execute(sql)
    else:
        conditions = " AND ".join("%s = ?" % (k,) for k in keyvalues)
        txn.execute(sql + " WHERE " + conditions, list(keyvalues.values()))

    return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def simple_select_many_batch(
    self,
    table,
    column,
    iterable,
    retcols,
    keyvalues=None,
    desc="simple_select_many_batch",
    batch_size=100,
):
    """Executes a SELECT query on the named table, which may return zero or
    more rows, returning the result as a list of dicts.

    Filters rows by if value of `column` is in `iterable`. The values are
    split into chunks of at most `batch_size`, and each chunk is queried in
    its own `runInteraction` call via `simple_select_many_txn`.

    Args:
        table : string giving the table name
        column : column name to test for inclusion against `iterable`
        iterable : list
        retcols : list of strings giving the names of the columns to return
        keyvalues : dict of column names and values to select the rows with,
            or None (the default) for no extra constraints
        desc : description passed to `runInteraction`
        batch_size : maximum number of values from `iterable` per query
    Returns:
        The concatenated rows of every chunk, as a list of dicts (delivered
        through the `defer.inlineCallbacks` wrapper).
    """
    # The default is None rather than a shared mutable ``{}``.
    if keyvalues is None:
        keyvalues = {}

    results = []  # type: List[Dict[str, Any]]

    if not iterable:
        return results

    # iterables can not be sliced, so convert it to a list first
    it_list = list(iterable)

    for start in range(0, len(it_list), batch_size):
        rows = yield self.runInteraction(
            desc,
            self.simple_select_many_txn,
            table,
            column,
            it_list[start : start + batch_size],
            keyvalues,
            retcols,
        )
        results.extend(rows)

    return results
@classmethod
def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
    """Executes a SELECT query on the named table, which may return zero or
    more rows, returning the result as a list of dicts.

    Filters rows by if value of `column` is in `iterable`.

    Args:
        txn : Transaction object
        table : string giving the table name
        column : column name to test for inclusion against `iterable`
        iterable : list
        keyvalues : dict of column names and values to select the rows with
        retcols : list of strings giving the names of the columns to return
    Returns:
        An empty list when `iterable` is empty (no query is run), otherwise
        the result of `cls.cursor_to_dict(txn)`.
    """
    if not iterable:
        return []

    clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
    clauses = [clause]

    # Each extra equality constraint adds one clause and one bound value,
    # after the argument(s) of the inclusion clause.
    for key, value in keyvalues.items():
        clauses.append("%s = ?" % (key,))
        values.append(value)

    sql = "SELECT %s FROM %s WHERE %s" % (
        ", ".join(retcols),
        table,
        " AND ".join(clauses),
    )

    txn.execute(sql, values)
    return cls.cursor_to_dict(txn)
def simple_update(self, table, keyvalues, updatevalues, desc):
    """Update the rows of ``table`` matching ``keyvalues``.

    Runs ``simple_update_txn`` through ``runInteraction``.

    Args:
        table : string giving the table name
        keyvalues : dict of column names and values to select the rows with
        updatevalues : dict giving column names and values to update
        desc : description passed to ``runInteraction``
    Returns:
        Whatever ``runInteraction`` returns for the transaction function.
    """
    txn_func = self.simple_update_txn
    return self.runInteraction(desc, txn_func, table, keyvalues, updatevalues)
@staticmethod
def simple_update_txn(txn, table, keyvalues, updatevalues):
    """Execute an UPDATE on ``table`` and return the affected row count.

    Args:
        txn : Transaction object
        table : string giving the table name
        keyvalues : dict of column names and values to select the rows
            with; when empty or None the UPDATE has no WHERE clause.
        updatevalues : dict giving column names and values to update
    Returns:
        ``txn.rowcount`` after the statement has been executed.
    """
    if keyvalues:
        # Iterate the dict directly (no six shim) and wrap the key in a
        # tuple so the %-format is safe whatever the key's type.
        where = "WHERE %s" % " AND ".join("%s = ?" % (k,) for k in keyvalues)
        where_args = list(keyvalues.values())
    else:
        # Covers both an empty dict and None.
        where = ""
        where_args = []

    update_sql = "UPDATE %s SET %s %s" % (
        table,
        ", ".join("%s = ?" % (k,) for k in updatevalues),
        where,
    )

    txn.execute(update_sql, list(updatevalues.values()) + where_args)
    return txn.rowcount
def simple_update_one(
    self, table, keyvalues, updatevalues, desc="simple_update_one"
):
    """Executes an UPDATE query on the named table, setting new values for
    columns in a row matching the key values.

    Runs ``simple_update_one_txn`` through ``runInteraction``.

    Args:
        table : string giving the table name
        keyvalues : dict of column names and values to select the row with
        updatevalues : dict giving column names and values to update
        desc : description passed to ``runInteraction``
    Returns:
        Whatever ``runInteraction`` returns for the transaction function.
    """
    txn_func = self.simple_update_one_txn
    return self.runInteraction(desc, txn_func, table, keyvalues, updatevalues)
@classmethod
def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
    """Update the row matching ``keyvalues``, insisting on exactly one match.

    Args:
        txn : Transaction object
        table : string giving the table name
        keyvalues : dict of column names and values to select the row with
        updatevalues : dict giving column names and values to update
    Raises:
        StoreError: 404 if ``simple_update_txn`` reports no updated rows,
            500 if it reports more than one.
    """
    updated = cls.simple_update_txn(txn, table, keyvalues, updatevalues)

    if updated == 0:
        raise StoreError(404, "No row found (%s)" % (table,))
    elif updated > 1:
        raise StoreError(500, "More than one row matched (%s)" % (table,))
@staticmethod
def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
    """Select ``retcols`` from the row of ``table`` matching ``keyvalues``.

    Args:
        txn : Transaction object
        table : string giving the table name
        keyvalues : dict of column names and values to select the row with
        retcols : list of strings giving the names of the columns to return
        allow_none : if true, return None instead of failing when the
            SELECT returns no row
    Returns:
        dict mapping each name in ``retcols`` to the fetched row's value, or
        None when no row was found and ``allow_none`` is true.
    Raises:
        StoreError: 404 if no row was found and ``allow_none`` is false,
            500 if ``txn.rowcount`` is greater than one.
    """
    columns = ", ".join(retcols)
    conditions = " AND ".join("%s = ?" % (k,) for k in keyvalues)
    txn.execute(
        "SELECT %s FROM %s WHERE %s" % (columns, table, conditions),
        list(keyvalues.values()),
    )

    row = txn.fetchone()
    if row:
        if txn.rowcount > 1:
            raise StoreError(500, "More than one row matched (%s)" % (table,))
        return dict(zip(retcols, row))

    if allow_none:
        return None
    raise StoreError(404, "No row found (%s)" % (table,))
def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
    """Executes a DELETE query on the named table, expecting to delete a
    single row.

    Runs ``simple_delete_one_txn`` through ``runInteraction``.

    Args:
        table : string giving the table name
        keyvalues : dict of column names and values to select the row with
        desc : description passed to ``runInteraction``
    Returns:
        Whatever ``runInteraction`` returns for the transaction function.
    """
    txn_func = self.simple_delete_one_txn
    return self.runInteraction(desc, txn_func, table, keyvalues)
@staticmethod
def simple_delete_one_txn(txn, table, keyvalues):
    """Executes a DELETE query on the named table, expecting to delete a
    single row.

    Args:
        txn : Transaction object
        table : string giving the table name
        keyvalues : dict of column names and values to select the row with
    Raises:
        StoreError: 404 if ``txn.rowcount`` is zero after the DELETE,
            500 if it is greater than one.
    """
    conditions = " AND ".join("%s = ?" % (k,) for k in keyvalues)
    txn.execute(
        "DELETE FROM %s WHERE %s" % (table, conditions),
        list(keyvalues.values()),
    )

    deleted = txn.rowcount
    if deleted == 0:
        raise StoreError(404, "No row found (%s)" % (table,))
    elif deleted > 1:
        raise StoreError(500, "More than one row matched (%s)" % (table,))
def simple_delete(self, table, keyvalues, desc):
    """Delete the rows of ``table`` matching ``keyvalues``.

    Runs ``simple_delete_txn`` through ``runInteraction``.

    Args:
        table : string giving the table name
        keyvalues : dict of column names and values to select the rows with
        desc : description passed to ``runInteraction``
    Returns:
        Whatever ``runInteraction`` returns for the transaction function.
    """
    txn_func = self.simple_delete_txn
    return self.runInteraction(desc, txn_func, table, keyvalues)
@staticmethod
def simple_delete_txn(txn, table, keyvalues):
    """Execute a DELETE on ``table`` for the rows matching ``keyvalues``.

    Args:
        txn : Transaction object
        table : string giving the table name
        keyvalues : dict of column names and values to select the rows with
    Returns:
        ``txn.rowcount`` after the statement has been executed.
    """
    conditions = " AND ".join("%s = ?" % (k,) for k in keyvalues)
    txn.execute(
        "DELETE FROM %s WHERE %s" % (table, conditions),
        list(keyvalues.values()),
    )
    return txn.rowcount
def simple_delete_many(self, table, column, iterable, keyvalues, desc):
    """Delete the rows of ``table`` whose ``column`` value is in ``iterable``.

    Runs ``simple_delete_many_txn`` through ``runInteraction``.

    Args:
        table : string giving the table name
        column : column name to test for inclusion against `iterable`
        iterable : list
        keyvalues : dict of column names and values to select the rows with
        desc : description passed to ``runInteraction``
    Returns:
        Whatever ``runInteraction`` returns for the transaction function.
    """
    txn_func = self.simple_delete_many_txn
    return self.runInteraction(desc, txn_func, table, column, iterable, keyvalues)
@staticmethod
def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
    """Executes a DELETE query on the named table.

    Filters rows by if value of `column` is in `iterable`.

    Args:
        txn : Transaction object
        table : string giving the table name
        column : column name to test for inclusion against `iterable`
        iterable : list
        keyvalues : dict of column names and values to select the rows with
    Returns:
        int: Number rows deleted (0, without running a query, when
        `iterable` is empty)
    """
    if not iterable:
        return 0

    clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
    clauses = [clause]

    # Each extra equality constraint adds one clause and one bound value,
    # after the argument(s) of the inclusion clause.
    for key, value in keyvalues.items():
        clauses.append("%s = ?" % (key,))
        values.append(value)

    # `clauses` always holds at least the inclusion clause, so the WHERE
    # part is unconditional.
    sql = "DELETE FROM %s WHERE %s" % (table, " AND ".join(clauses))

    txn.execute(sql, values)
    return txn.rowcount
def get_cache_dict(
    self, db_conn, table, entity_column, stream_column, max_value, limit=100000
):
    """Build a mapping of entity -> max stream position for recent rows.

    Only rows whose ``stream_column`` is greater than ``max_value - limit``
    are considered. The query runs on a cursor opened from ``db_conn``,
    which is closed before returning, even if the query fails.

    Args:
        db_conn: database connection providing ``cursor()``
        table (str): the table to read from
        entity_column (str): column whose values become the mapping's keys
        stream_column (str): column holding the stream position
        max_value (int): the current maximum stream position
        limit (int): how far below ``max_value`` to look
    Returns:
        tuple[dict, int]: the mapping, and the smallest position in it (or
        ``max_value`` when the mapping is empty).
    """
    # Fetch a mapping of entity -> max stream position for "recent" entities.
    # It doesn't really matter how many we get, the StreamChangeCache will
    # do the right thing to ensure it respects the max size of cache.
    sql = (
        "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
        " WHERE %(stream)s > ? - %(limit)s"
        " GROUP BY %(entity)s"
    ) % {
        "table": table,
        "entity": entity_column,
        "stream": stream_column,
        "limit": limit,
    }

    sql = self.engine.convert_param_style(sql)

    txn = db_conn.cursor()
    try:
        txn.execute(sql, (int(max_value),))
        cache = {row[0]: int(row[1]) for row in txn}
    finally:
        # Release the cursor even when the query or row conversion raises.
        txn.close()

    min_val = min(cache.values()) if cache else max_value
    return cache, min_val
def simple_select_list_paginate(
    self,
    table,
    orderby,
    start,
    limit,
    retcols,
    filters=None,
    keyvalues=None,
    order_direction="ASC",
    desc="simple_select_list_paginate",
):
    """
    Select one page of rows from the named table as a list of dicts.

    Runs ``simple_select_list_paginate_txn`` through ``runInteraction``,
    forwarding ``filters``, ``keyvalues`` and ``order_direction`` as keyword
    arguments.

    Args:
        table (str): the table name
        orderby (str): Column to order the results by.
        start (int): Index to begin the query at.
        limit (int): Number of results to return.
        retcols (iterable[str]): the names of the columns to return
        filters (dict[str, T] | None):
            column names and values to filter the rows with, or None to not
            apply a WHERE ? LIKE ? clause.
        keyvalues (dict[str, T] | None):
            column names and values to select the rows with, or None to not
            apply a WHERE clause.
        order_direction (str): Whether the results should be ordered "ASC" or "DESC".
        desc (str): description passed to ``runInteraction``
    Returns:
        defer.Deferred: resolves to list[dict[str, Any]]
    """
    positional = (
        self.simple_select_list_paginate_txn,
        table,
        orderby,
        start,
        limit,
        retcols,
    )
    return self.runInteraction(
        desc,
        *positional,
        filters=filters,
        keyvalues=keyvalues,
        order_direction=order_direction,
    )
@classmethod
def simple_select_list_paginate_txn(
    cls,
    txn,
    table,
    orderby,
    start,
    limit,
    retcols,
    filters=None,
    keyvalues=None,
    order_direction="ASC",
):
    """
    Select one page of rows from the named table as a list of dicts.

    The page is ``limit`` rows starting at offset ``start`` in the ordering
    given by ``orderby`` and ``order_direction``.

    Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
    select attributes with exact matches. All constraints are joined together
    using 'AND'.

    Args:
        txn : Transaction object
        table (str): the table name
        orderby (str): Column to order the results by.
        start (int): Index to begin the query at.
        limit (int): Number of results to return.
        retcols (iterable[str]): the names of the columns to return
        filters (dict[str, T] | None):
            column names and values to filter the rows with, or None to not
            apply a WHERE ? LIKE ? clause.
        keyvalues (dict[str, T] | None):
            column names and values to select the rows with, or None to not
            apply a WHERE clause.
        order_direction (str): Whether the results should be ordered "ASC" or "DESC".
    Returns:
        The result of ``cls.cursor_to_dict(txn)`` after the query has run.
    Raises:
        ValueError: if ``order_direction`` is neither "ASC" nor "DESC".
    """
    if order_direction not in ("ASC", "DESC"):
        raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")

    # Collect the LIKE constraints first, then the equality constraints;
    # the bound arguments follow the same order.
    conditions = []
    args = []  # type: List[Any]
    if filters:
        conditions.extend("%s LIKE ?" % (k,) for k in filters)
        args.extend(filters.values())
    if keyvalues:
        conditions.extend("%s = ?" % (k,) for k in keyvalues)
        args.extend(keyvalues.values())

    where_clause = ("WHERE " + " AND ".join(conditions)) if conditions else ""

    sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
        ", ".join(retcols),
        table,
        where_clause,
        orderby,
        order_direction,
    )
    txn.execute(sql, args + [limit, start])

    return cls.cursor_to_dict(txn)
def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
    """Search ``table`` for rows whose ``col`` matches ``term``.

    Runs ``simple_search_list_txn`` through ``runInteraction``.

    Args:
        table (str): the table name
        term (str | None):
            term for searching the table matched to a column.
        col (str): column to query term should be matched to
        retcols (iterable[str]): the names of the columns to return
        desc (str): description passed to ``runInteraction``
    Returns:
        Whatever ``runInteraction`` returns for the transaction function.
    """
    txn_func = self.simple_search_list_txn
    return self.runInteraction(desc, txn_func, table, term, col, retcols)
@classmethod
def simple_search_list_txn(cls, txn, table, term, col, retcols):
    """Executes a SELECT query on the named table, which may return zero or
    more rows, returning the result as a list of dicts.

    Rows are selected when `col` contains `term` (SQL ``LIKE`` with a
    wildcard on either side of the term).

    Args:
        txn : Transaction object
        table (str): the table name
        term (str | None):
            term for searching the table matched to a column.
        col (str): column to query term should be matched to
        retcols (iterable[str]): the names of the columns to return
    Returns:
        The result of ``cls.cursor_to_dict(txn)``, or ``0`` when ``term``
        is empty or None (no query is run in that case).
    """
    if not term:
        # The falsy 0 is returned as before, for backward compatibility.
        return 0

    sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
    # The pattern is a bound parameter, not a %-format template, so "%"
    # needs no doubling: one wildcard on each side of the term.
    txn.execute(sql, ["%" + term + "%"])

    return cls.cursor_to_dict(txn)
def make_in_list_sql_clause(
    database_engine, column: str, iterable: Iterable
) -> Tuple[str, list]:
    """Returns an SQL clause that checks the given column is in the iterable.

    On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
    it expands to `column = ANY(?)`. While both DBs support the `IN` form,
    using the `ANY` form on postgres means that it views queries with
    different length iterables as the same, helping the query stats.

    Args:
        database_engine: engine object; its `supports_using_any_list`
            attribute selects the form of the clause.
        column: Name of the column
        iterable: The values to check the column against. It is consumed
            exactly once, so one-shot iterators are accepted.

    Returns:
        A tuple of SQL query and the args
    """
    # Materialise first: the IN form needs the values both to count the
    # placeholders and as the arguments, and an iterator can be walked once.
    values = list(iterable)

    if database_engine.supports_using_any_list:
        # This should hopefully be faster, but also makes postgres query
        # stats easier to understand.
        return "%s = ANY(?)" % (column,), [values]

    return "%s IN (%s)" % (column, ",".join("?" for _ in values)), values
KV = TypeVar("KV")


def make_tuple_comparison_clause(
    database_engine: BaseDatabaseEngine, keys: List[Tuple[str, KV]]
) -> Tuple[str, List[KV]]:
    """Returns a tuple comparison SQL clause

    Builds a clause that is true when the columns, compared
    lexicographically, are greater than the given values. Depending on what
    the SQL engine supports it looks like either "(a, b) > (?, ?)" or the
    nested form "(a >= ? AND (a > ? OR b > ?))".

    Args:
        database_engine: engine object; its `supports_tuple_comparison`
            attribute selects the form of the clause.
        keys: A set of (column, value) pairs to be compared.

    Returns:
        A tuple of SQL query and the args
    """
    if database_engine.supports_tuple_comparison:
        columns = ",".join(pair[0] for pair in keys)
        placeholders = ",".join("?" for _ in keys)
        return "(%s) > (%s)" % (columns, placeholders), [pair[1] for pair in keys]

    # Without native tuple comparison the condition is, conceptually,
    #
    #   (a > ?) OR
    #   (a = ? AND b > ?) OR
    #   ...
    #   (a = ? AND b = ? AND ... AND z > ?)
    #
    # It is emitted in the equivalent nested form below, which is apparently
    # easier for the query optimiser:
    #
    #   (a >= ? AND (a > ? OR
    #     (b >= ? AND (b > ? OR
    #       ...
    #         z > ?
    #       ...
    #     ))
    #   ))
    #
    # Every column except the last opens one "(col >= ? AND (col > ? OR "
    # level, binding its value twice; the last column is the innermost
    # "col > ?"; each opened level is then closed with "))".
    leading = keys[:-1]
    last_column, last_value = keys[-1]

    opening = "".join(
        "(%s >= ? AND (%s > ? OR " % (column, column) for column, _ in leading
    )
    closing = "))" * len(leading)

    args = []  # type: List[KV]
    for _, value in leading:
        args += [value, value]
    args.append(last_value)

    return opening + "%s > ?" % (last_column,) + closing, args
|