Browse Source

Refuse to start if DB has an unsafe locale (#12262)

Shay 2 years ago
parent
commit
e78d4f61fc

+ 1 - 0
changelog.d/12262.misc

@@ -0,0 +1 @@
+Refuse to start if DB has non-`C` locale, unless config flag `allow_unsafe_db_locale` is set to true.

+ 4 - 3
docs/postgres.md

@@ -234,12 +234,13 @@ host    all         all             ::1/128     ident
 ### Fixing incorrect `COLLATE` or `CTYPE`
 
 Synapse will refuse to set up a new database if it has the wrong values of
-`COLLATE` and `CTYPE` set, and will log warnings on existing databases. Using
-different locales can cause issues if the locale library is updated from
+`COLLATE` and `CTYPE` set. Synapse will also refuse to start an existing database with incorrect values
+of `COLLATE` and `CTYPE` unless the config flag `allow_unsafe_locale`, found in the 
+`database` section of the config, is set to true. Using different locales can cause issues if the locale library is updated from
 underneath the database, or if a different version of the locale is used on any
 replicas.
 
-The safest way to fix the issue is to dump the database and recreate it with
+If you have a databse with an unsafe locale, the safest way to fix the issue is to dump the database and recreate it with
 the correct locale parameter (as shown above). It is also possible to change the
 parameters on a live database and run a `REINDEX` on the entire database,
 however extreme care must be taken to avoid database corruption.

+ 6 - 0
docs/sample_config.yaml

@@ -783,6 +783,12 @@ caches:
 # 'txn_limit' gives the maximum number of transactions to run per connection
 # before reconnecting. Defaults to 0, which means no limit.
 #
+# 'allow_unsafe_locale' is an option specific to Postgres. Under the default behavior, Synapse will refuse to
+# start if the postgres db is set to a non-C locale. You can override this behavior (which is *not* recommended)
+# by setting 'allow_unsafe_locale' to true. Note that doing so may corrupt your database. You can find more information
+# here: https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype and here:
+# https://wiki.postgresql.org/wiki/Locale_data_changes
+#
 # 'args' gives options which are passed through to the database engine,
 # except for options starting 'cp_', which are used to configure the Twisted
 # connection pool. For a reference to valid arguments, see:

+ 6 - 0
synapse/config/database.py

@@ -37,6 +37,12 @@ DEFAULT_CONFIG = """\
 # 'txn_limit' gives the maximum number of transactions to run per connection
 # before reconnecting. Defaults to 0, which means no limit.
 #
+# 'allow_unsafe_locale' is an option specific to Postgres. Under the default behavior, Synapse will refuse to
+# start if the postgres db is set to a non-C locale. You can override this behavior (which is *not* recommended)
+# by setting 'allow_unsafe_locale' to true. Note that doing so may corrupt your database. You can find more information
+# here: https://matrix-org.github.io/synapse/latest/postgres.html#fixing-incorrect-collate-or-ctype and here:
+# https://wiki.postgresql.org/wiki/Locale_data_changes
+#
 # 'args' gives options which are passed through to the database engine,
 # except for options starting 'cp_', which are used to configure the Twisted
 # connection pool. For a reference to valid arguments, see:

+ 30 - 15
synapse/storage/engines/postgres.py

@@ -47,17 +47,26 @@ class PostgresEngine(BaseDatabaseEngine):
         self.default_isolation_level = (
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
+        self.config = database_config
 
     @property
     def single_threaded(self) -> bool:
         return False
 
+    def get_db_locale(self, txn):
+        txn.execute(
+            "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+        )
+        collation, ctype = txn.fetchone()
+        return collation, ctype
+
     def check_database(self, db_conn, allow_outdated_version: bool = False):
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
         # revision numbers into two-decimal-digit numbers and appending them
         # together. For example, version 8.1.5 will be returned as 80105
         self._version = db_conn.server_version
+        allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
 
         # Are we on a supported PostgreSQL version?
         if not allow_outdated_version and self._version < 100000:
@@ -72,33 +81,39 @@ class PostgresEngine(BaseDatabaseEngine):
                     "See docs/postgres.md for more information." % (rows[0][0],)
                 )
 
-            txn.execute(
-                "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
-            )
-            collation, ctype = txn.fetchone()
+            collation, ctype = self.get_db_locale(txn)
             if collation != "C":
                 logger.warning(
-                    "Database has incorrect collation of %r. Should be 'C'\n"
-                    "See docs/postgres.md for more information.",
+                    "Database has incorrect collation of %r. Should be 'C'",
                     collation,
                 )
+                if not allow_unsafe_locale:
+                    raise IncorrectDatabaseSetup(
+                        "Database has incorrect collation of %r. Should be 'C'\n"
+                        "See docs/postgres.md for more information. You can override this check by"
+                        "setting 'allow_unsafe_locale' to true in the database config.",
+                        collation,
+                    )
 
             if ctype != "C":
-                logger.warning(
-                    "Database has incorrect ctype of %r. Should be 'C'\n"
-                    "See docs/postgres.md for more information.",
-                    ctype,
-                )
+                if not allow_unsafe_locale:
+                    logger.warning(
+                        "Database has incorrect ctype of %r. Should be 'C'",
+                        ctype,
+                    )
+                    raise IncorrectDatabaseSetup(
+                        "Database has incorrect ctype of %r. Should be 'C'\n"
+                        "See docs/postgres.md for more information. You can override this check by"
+                        "setting 'allow_unsafe_locale' to true in the database config.",
+                        ctype,
+                    )
 
     def check_new_database(self, txn):
         """Gets called when setting up a brand new database. This allows us to
         apply stricter checks on new databases versus existing database.
         """
 
-        txn.execute(
-            "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
-        )
-        collation, ctype = txn.fetchone()
+        collation, ctype = self.get_db_locale(txn)
 
         errors = []
 

+ 46 - 0
tests/storage/test_unsafe_locale.py

@@ -0,0 +1,46 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from unittest.mock import MagicMock, patch
+
+from synapse.storage.database import make_conn
+from synapse.storage.engines._base import IncorrectDatabaseSetup
+
+from tests.unittest import HomeserverTestCase
+from tests.utils import USE_POSTGRES_FOR_TESTS
+
+
+class UnsafeLocaleTest(HomeserverTestCase):
+    if not USE_POSTGRES_FOR_TESTS:
+        skip = "Requires Postgres"
+
+    @patch("synapse.storage.engines.postgres.PostgresEngine.get_db_locale")
+    def test_unsafe_locale(self, mock_db_locale: MagicMock) -> None:
+        mock_db_locale.return_value = ("B", "B")
+        database = self.hs.get_datastores().databases[0]
+
+        db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
+        with self.assertRaises(IncorrectDatabaseSetup):
+            database.engine.check_database(db_conn)
+        with self.assertRaises(IncorrectDatabaseSetup):
+            database.engine.check_new_database(db_conn)
+        db_conn.close()
+
+    def test_safe_locale(self) -> None:
+        database = self.hs.get_datastores().databases[0]
+
+        db_conn = make_conn(database._database_config, database.engine, "test_unsafe")
+        with db_conn.cursor() as txn:
+            res = database.engine.get_db_locale(txn)
+        self.assertEqual(res, ("C", "C"))
+        db_conn.close()