Browse Source

Use native UPSERTs where possible (#4306)

Amber Brown 5 years ago
parent
commit
58f6c48183

+ 1 - 5
.coveragerc

@@ -1,11 +1,7 @@
 [run]
 branch = True
 parallel = True
-source = synapse
-
-[paths]
-source=
-   coverage
+include = synapse/*
 
 [report]
 precision = 2

+ 3 - 3
.gitignore

@@ -25,9 +25,9 @@ homeserver*.pid
 *.tls.dh
 *.tls.key
 
-.coverage
-.coverage.*
-!.coverage.rc
+.coverage*
+coverage.*
+!.coveragerc
 htmlcov
 
 demo/*/*.db

+ 1 - 0
changelog.d/4306.misc

@@ -0,0 +1 @@
+Synapse will now take advantage of native UPSERT functionality in PostgreSQL 9.5+ and SQLite 3.24+.

+ 137 - 11
synapse/storage/_base.py

@@ -192,6 +192,41 @@ class SQLBaseStore(object):
 
         self.database_engine = hs.database_engine
 
+        # A set of tables that are not safe to use native upserts in.
+        self._unsafe_to_upsert_tables = {"user_ips"}
+
+        if self.database_engine.can_native_upsert:
+            # Check ASAP (and then later, every 1s) to see if we have finished
+            # background updates of tables that aren't safe to update.
+            self._clock.call_later(0.0, self._check_safe_to_upsert)
+
+    @defer.inlineCallbacks
+    def _check_safe_to_upsert(self):
+        """
+        Is it safe to use native UPSERT?
+
+        If there are background updates, we will need to wait, as they may be
+        the addition of indexes that set the UNIQUE constraint that we require.
+
+        If the background updates have not completed, wait a second and check again.
+        """
+        updates = yield self._simple_select_list(
+            "background_updates",
+            keyvalues=None,
+            retcols=["update_name"],
+            desc="check_background_updates",
+        )
+        updates = [x["update_name"] for x in updates]
+
+        # The User IPs table in schema #53 was missing a unique index, which we
+        # run as a background update.
+        if "user_ips_device_unique_index" not in updates:
+            self._unsafe_to_upsert_tables.discard("user_id")
+
+        # If there's any tables left to check, reschedule to run.
+        if self._unsafe_to_upsert_tables:
+            self._clock.call_later(1.0, self._check_safe_to_upsert)
+
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
 
@@ -494,8 +529,15 @@ class SQLBaseStore(object):
         txn.executemany(sql, vals)
 
     @defer.inlineCallbacks
-    def _simple_upsert(self, table, keyvalues, values,
-                       insertion_values={}, desc="_simple_upsert", lock=True):
+    def _simple_upsert(
+        self,
+        table,
+        keyvalues,
+        values,
+        insertion_values={},
+        desc="_simple_upsert",
+        lock=True
+    ):
         """
 
         `lock` should generally be set to True (the default), but can be set
@@ -516,16 +558,21 @@ class SQLBaseStore(object):
                 inserting
             lock (bool): True to lock the table when doing the upsert.
         Returns:
-            Deferred(bool): True if a new entry was created, False if an
-                existing one was updated.
+            Deferred(None or bool): Native upserts always return None. Emulated
+            upserts return True if a new entry was created, False if an existing
+            one was updated.
         """
         attempts = 0
         while True:
             try:
                 result = yield self.runInteraction(
                     desc,
-                    self._simple_upsert_txn, table, keyvalues, values, insertion_values,
-                    lock=lock
+                    self._simple_upsert_txn,
+                    table,
+                    keyvalues,
+                    values,
+                    insertion_values,
+                    lock=lock,
                 )
                 defer.returnValue(result)
             except self.database_engine.module.IntegrityError as e:
@@ -537,12 +584,59 @@ class SQLBaseStore(object):
 
                 # presumably we raced with another transaction: let's retry.
                 logger.warn(
-                    "IntegrityError when upserting into %s; retrying: %s",
-                    table, e
+                    "%s when upserting into %s; retrying: %s", e.__name__, table, e
                 )
 
-    def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
-                           lock=True):
+    def _simple_upsert_txn(
+        self,
+        txn,
+        table,
+        keyvalues,
+        values,
+        insertion_values={},
+        lock=True,
+    ):
+        """
+        Pick the UPSERT method which works best on the platform. Either the
+        native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
+
+        Args:
+            txn: The transaction to use.
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key tables and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+            lock (bool): True to lock the table when doing the upsert.
+        Returns:
+            Deferred(None or bool): Native upserts always return None. Emulated
+            upserts return True if a new entry was created, False if an existing
+            one was updated.
+        """
+        if (
+            self.database_engine.can_native_upsert
+            and table not in self._unsafe_to_upsert_tables
+        ):
+            return self._simple_upsert_txn_native_upsert(
+                txn,
+                table,
+                keyvalues,
+                values,
+                insertion_values=insertion_values,
+            )
+        else:
+            return self._simple_upsert_txn_emulated(
+                txn,
+                table,
+                keyvalues,
+                values,
+                insertion_values=insertion_values,
+                lock=lock,
+            )
+
+    def _simple_upsert_txn_emulated(
+        self, txn, table, keyvalues, values, insertion_values={}, lock=True
+    ):
         # We need to lock the table :(, unless we're *really* careful
         if lock:
             self.database_engine.lock_table(txn, table)
@@ -577,12 +671,44 @@ class SQLBaseStore(object):
         sql = "INSERT INTO %s (%s) VALUES (%s)" % (
             table,
             ", ".join(k for k in allvalues),
-            ", ".join("?" for _ in allvalues)
+            ", ".join("?" for _ in allvalues),
         )
         txn.execute(sql, list(allvalues.values()))
         # successfully inserted
         return True
 
+    def _simple_upsert_txn_native_upsert(
+        self, txn, table, keyvalues, values, insertion_values={}
+    ):
+        """
+        Use the native UPSERT functionality in recent PostgreSQL versions.
+
+        Args:
+            table (str): The table to upsert into
+            keyvalues (dict): The unique key tables and their new values
+            values (dict): The nonunique columns and their new values
+            insertion_values (dict): additional key/values to use only when
+                inserting
+        Returns:
+            None
+        """
+        allvalues = {}
+        allvalues.update(keyvalues)
+        allvalues.update(values)
+        allvalues.update(insertion_values)
+
+        sql = (
+            "INSERT INTO %s (%s) VALUES (%s) "
+            "ON CONFLICT (%s) DO UPDATE SET %s"
+        ) % (
+            table,
+            ", ".join(k for k in allvalues),
+            ", ".join("?" for _ in allvalues),
+            ", ".join(k for k in keyvalues),
+            ", ".join(k + "=EXCLUDED." + k for k in values),
+        )
+        txn.execute(sql, list(allvalues.values()))
+
     def _simple_select_one(self, table, keyvalues, retcols,
                            allow_none=False, desc="_simple_select_one"):
         """Executes a SELECT query on the named table, which is expected to

+ 4 - 1
synapse/storage/client_ips.py

@@ -257,7 +257,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
         )
 
     def _update_client_ips_batch_txn(self, txn, to_update):
-        self.database_engine.lock_table(txn, "user_ips")
+        if "user_ips" in self._unsafe_to_upsert_tables or (
+            not self.database_engine.can_native_upsert
+        ):
+            self.database_engine.lock_table(txn, "user_ips")
 
         for entry in iteritems(to_update):
             (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry

+ 1 - 1
synapse/storage/engines/__init__.py

@@ -18,7 +18,7 @@ import platform
 
 from ._base import IncorrectDatabaseSetup
 from .postgres import PostgresEngine
-from .sqlite3 import Sqlite3Engine
+from .sqlite import Sqlite3Engine
 
 SUPPORTED_MODULE = {
     "sqlite3": Sqlite3Engine,

+ 14 - 0
synapse/storage/engines/postgres.py

@@ -38,6 +38,13 @@ class PostgresEngine(object):
         return sql.replace("?", "%s")
 
     def on_new_connection(self, db_conn):
+
+        # Get the version of PostgreSQL that we're using. As per the psycopg2
+        # docs: The number is formed by converting the major, minor, and
+        # revision numbers into two-decimal-digit numbers and appending them
+        # together. For example, version 8.1.5 will be returned as 80105
+        self._version = db_conn.server_version
+
         db_conn.set_isolation_level(
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
@@ -54,6 +61,13 @@ class PostgresEngine(object):
 
         cursor.close()
 
+    @property
+    def can_native_upsert(self):
+        """
+        Can we use native UPSERTs? This requires PostgreSQL 9.5+.
+        """
+        return self._version >= 90500
+
     def is_deadlock(self, error):
         if isinstance(error, self.module.DatabaseError):
             # https://www.postgresql.org/docs/current/static/errcodes-appendix.html

+ 9 - 0
synapse/storage/engines/sqlite3.py → synapse/storage/engines/sqlite.py

@@ -15,6 +15,7 @@
 
 import struct
 import threading
+from sqlite3 import sqlite_version_info
 
 from synapse.storage.prepare_database import prepare_database
 
@@ -30,6 +31,14 @@ class Sqlite3Engine(object):
         self._current_state_group_id = None
         self._current_state_group_id_lock = threading.Lock()
 
+    @property
+    def can_native_upsert(self):
+        """
+        Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
+        more work we haven't done yet to tell what was inserted vs updated.
+        """
+        return sqlite_version_info >= (3, 24, 0)
+
     def check_database(self, txn):
         pass
 

+ 7 - 2
synapse/storage/pusher.py

@@ -215,7 +215,7 @@ class PusherStore(PusherWorkerStore):
         with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
             # (app_id, pushkey, user_name) so _simple_upsert will retry
-            newly_inserted = yield self._simple_upsert(
+            yield self._simple_upsert(
                 table="pushers",
                 keyvalues={
                     "app_id": app_id,
@@ -238,7 +238,12 @@ class PusherStore(PusherWorkerStore):
                 lock=False,
             )
 
-            if newly_inserted:
+            user_has_pusher = self.get_if_user_has_pusher.cache.get(
+                (user_id,), None, update_metrics=False
+            )
+
+            if user_has_pusher is not True:
+                # invalidate, since we the user might not have had a pusher before
                 yield self.runInteraction(
                     "add_pusher",
                     self._invalidate_cache_and_stream,

+ 40 - 15
synapse/storage/user_directory.py

@@ -168,14 +168,14 @@ class UserDirectoryStore(SQLBaseStore):
             if isinstance(self.database_engine, PostgresEngine):
                 # We weight the localpart most highly, then display name and finally
                 # server name
-                if new_entry:
+                if self.database_engine.can_native_upsert:
                     sql = """
                         INSERT INTO user_directory_search(user_id, vector)
                         VALUES (?,
                             setweight(to_tsvector('english', ?), 'A')
                             || setweight(to_tsvector('english', ?), 'D')
                             || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
-                        )
+                        ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
                     """
                     txn.execute(
                         sql,
@@ -185,20 +185,45 @@ class UserDirectoryStore(SQLBaseStore):
                         )
                     )
                 else:
-                    sql = """
-                        UPDATE user_directory_search
-                        SET vector = setweight(to_tsvector('english', ?), 'A')
-                            || setweight(to_tsvector('english', ?), 'D')
-                            || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
-                        WHERE user_id = ?
-                    """
-                    txn.execute(
-                        sql,
-                        (
-                            get_localpart_from_id(user_id), get_domain_from_id(user_id),
-                            display_name, user_id,
+                    # TODO: Remove this code after we've bumped the minimum version
+                    # of postgres to always support upserts, so we can get rid of
+                    # `new_entry` usage
+                    if new_entry is True:
+                        sql = """
+                            INSERT INTO user_directory_search(user_id, vector)
+                            VALUES (?,
+                                setweight(to_tsvector('english', ?), 'A')
+                                || setweight(to_tsvector('english', ?), 'D')
+                                || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+                            )
+                        """
+                        txn.execute(
+                            sql,
+                            (
+                                user_id, get_localpart_from_id(user_id),
+                                get_domain_from_id(user_id), display_name,
+                            )
+                        )
+                    elif new_entry is False:
+                        sql = """
+                            UPDATE user_directory_search
+                            SET vector = setweight(to_tsvector('english', ?), 'A')
+                                || setweight(to_tsvector('english', ?), 'D')
+                                || setweight(to_tsvector('english', COALESCE(?, '')), 'B')
+                            WHERE user_id = ?
+                        """
+                        txn.execute(
+                            sql,
+                            (
+                                get_localpart_from_id(user_id),
+                                get_domain_from_id(user_id),
+                                display_name, user_id,
+                            )
+                        )
+                    else:
+                        raise RuntimeError(
+                            "upsert returned None when 'can_native_upsert' is False"
                         )
-                    )
             elif isinstance(self.database_engine, Sqlite3Engine):
                 value = "%s %s" % (user_id, display_name,) if display_name else user_id
                 self._simple_upsert_txn(

+ 1 - 0
tests/storage/test_base.py

@@ -49,6 +49,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         self.db_pool.runWithConnection = runWithConnection
 
         config = Mock()
+        config._enable_native_upserts = False
         config.event_cache_size = 1
         config.database_config = {"name": "sqlite3"}
         hs = TestHomeServer(

+ 9 - 3
tests/test_server.py

@@ -19,7 +19,7 @@ from six import StringIO
 
 from twisted.internet.defer import Deferred
 from twisted.python.failure import Failure
-from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
+from twisted.test.proto_helpers import AccumulatingProtocol
 from twisted.web.resource import Resource
 from twisted.web.server import NOT_DONE_YET
 
@@ -30,12 +30,18 @@ from synapse.util import Clock
 from synapse.util.logcontext import make_deferred_yieldable
 
 from tests import unittest
-from tests.server import FakeTransport, make_request, render, setup_test_homeserver
+from tests.server import (
+    FakeTransport,
+    ThreadedMemoryReactorClock,
+    make_request,
+    render,
+    setup_test_homeserver,
+)
 
 
 class JsonResourceTests(unittest.TestCase):
     def setUp(self):
-        self.reactor = MemoryReactorClock()
+        self.reactor = ThreadedMemoryReactorClock()
         self.hs_clock = Clock(self.reactor)
         self.homeserver = setup_test_homeserver(
             self.addCleanup, http_client=None, clock=self.hs_clock, reactor=self.reactor

+ 10 - 2
tests/unittest.py

@@ -96,7 +96,7 @@ class TestCase(unittest.TestCase):
 
         method = getattr(self, methodName)
 
-        level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR))
+        level = getattr(method, "loglevel", getattr(self, "loglevel", logging.WARNING))
 
         @around(self)
         def setUp(orig):
@@ -333,7 +333,15 @@ class HomeserverTestCase(TestCase):
         """
         kwargs = dict(kwargs)
         kwargs.update(self._hs_args)
-        return setup_test_homeserver(self.addCleanup, *args, **kwargs)
+        hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
+        stor = hs.get_datastore()
+
+        # Run the database background updates.
+        if hasattr(stor, "do_next_background_update"):
+            while not self.get_success(stor.has_completed_background_updates()):
+                self.get_success(stor.do_next_background_update(1))
+
+        return hs
 
     def pump(self, by=0.0):
         """

+ 1 - 0
tox.ini

@@ -149,4 +149,5 @@ deps =
     codecov
 commands =
     coverage combine
+    coverage xml
     codecov -X gcov