123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280 |
- # Copyright 2020 The Matrix.org Foundation C.I.C.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Callable, Tuple
- from unittest.mock import Mock, call
- from twisted.internet import defer
- from twisted.internet.defer import CancelledError, Deferred
- from twisted.test.proto_helpers import MemoryReactor
- from synapse.server import HomeServer
- from synapse.storage.database import (
- DatabasePool,
- LoggingDatabaseConnection,
- LoggingTransaction,
- make_tuple_comparison_clause,
- )
- from synapse.util import Clock
- from tests import unittest
class TupleComparisonClauseTestCase(unittest.TestCase):
    """Tests for `make_tuple_comparison_clause`."""

    def test_native_tuple_comparison(self) -> None:
        """Check the clause and arguments built from two (column, value) pairs."""
        columns_and_values = [("a", 1), ("b", 2)]

        sql, params = make_tuple_comparison_clause(columns_and_values)

        # The columns and placeholders are each grouped into a single tuple,
        # and the values are returned in the same order as the columns.
        self.assertEqual(sql, "(a,b) > (?,?)")
        self.assertEqual(params, [1, 2])
class ExecuteScriptTestCase(unittest.HomeserverTestCase):
    """Tests for `BaseDatabaseEngine.executescript` implementations."""

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        """Keep a handle on the main store's database pool and create table `foo`."""
        self.store = hs.get_datastores().main
        self.db_pool: DatabasePool = self.store.db_pool

        def create_table(txn: LoggingTransaction) -> None:
            txn.execute("CREATE TABLE foo (name TEXT PRIMARY KEY)")

        self.get_success(self.db_pool.runInteraction("create", create_table))

    def _lookup_name(self, name: str) -> object:
        """Select the `name` column of the `foo` row keyed by `name`.

        The lookup is made with `allow_none=True`, so the tests below can compare
        the result against `None` to decide whether the row is present.
        """
        return self.get_success(
            self.db_pool.simple_select_one_onecol(
                "foo",
                keyvalues={"name": name},
                retcol="name",
                allow_none=True,
            )
        )

    def test_transaction(self) -> None:
        """Test that all statements are run in a single transaction."""
        insert_row = "INSERT INTO foo (name) VALUES ('transaction test')"
        # The second, identical insert will fail. When `executescript` is not
        # transactional, the row written by the first insert will be observed
        # later.
        script = ";".join([insert_row, insert_row])

        def run_script(conn: LoggingDatabaseConnection) -> None:
            cursor = conn.cursor(txn_name="test_transaction")
            self.db_pool.engine.executescript(cursor, script)

        self.get_failure(
            self.db_pool.runWithConnection(run_script),
            self.db_pool.engine.module.IntegrityError,
        )

        self.assertIsNone(
            self._lookup_name("transaction test"),
            "executescript is not running statements inside a transaction",
        )

    def test_commit(self) -> None:
        """Test that the script transaction remains open and can be committed."""

        def run_script(conn: LoggingDatabaseConnection) -> None:
            cursor = conn.cursor(txn_name="test_commit")
            self.db_pool.engine.executescript(
                cursor, "INSERT INTO foo (name) VALUES ('commit test')"
            )
            cursor.execute("COMMIT")

        self.get_success(self.db_pool.runWithConnection(run_script))

        # The explicitly committed row must be visible afterwards.
        self.assertIsNotNone(self._lookup_name("commit test"))

    def test_rollback(self) -> None:
        """Test that the script transaction remains open and can be rolled back."""

        def run_script(conn: LoggingDatabaseConnection) -> None:
            cursor = conn.cursor(txn_name="test_rollback")
            self.db_pool.engine.executescript(
                cursor, "INSERT INTO foo (name) VALUES ('rollback test')"
            )
            cursor.execute("ROLLBACK")

        self.get_success(self.db_pool.runWithConnection(run_script))

        # The explicitly rolled back row must not be visible afterwards.
        self.assertIsNone(
            self._lookup_name("rollback test"),
            "executescript is not leaving the script transaction open",
        )
class CallbacksTestCase(unittest.HomeserverTestCase):
    """Tests for transaction callbacks."""

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        """Keep a handle on the main store and its database pool."""
        self.store = hs.get_datastores().main
        self.db_pool: DatabasePool = self.store.db_pool

    def _run_interaction(
        self, func: Callable[[LoggingTransaction], object]
    ) -> Tuple[Mock, Mock]:
        """Run the given function in a database transaction, with callbacks registered.

        Any exception raised while running the transaction is swallowed, so that
        callers can inspect the returned mocks whether or not it succeeded.

        Args:
            func: The function to be run in a transaction. The transaction will be
                retried if `func` raises an `OperationalError`.

        Returns:
            Two mocks, which were registered as an `after_callback` and an
            `exception_callback` respectively, on every transaction attempt.
        """
        on_success = Mock()
        on_failure = Mock()

        def register_and_run(txn: LoggingTransaction) -> None:
            # Register both callbacks before handing over to `func`, so they are
            # in place on every attempt.
            txn.call_after(on_success, 123, 456, extra=789)
            txn.call_on_exception(on_failure, 987, 654, extra=321)
            func(txn)

        interaction = self.db_pool.runInteraction("test_transaction", register_and_run)
        try:
            self.get_success_or_raise(interaction)
        except Exception:
            # Deliberately ignored: the tests only assert on the callbacks.
            pass

        return on_success, on_failure

    def test_after_callback(self) -> None:
        """Test that the after callback is called when a transaction succeeds."""

        def succeed(txn: LoggingTransaction) -> None:
            return None

        after_callback, exception_callback = self._run_interaction(succeed)

        after_callback.assert_called_once_with(123, 456, extra=789)
        exception_callback.assert_not_called()

    def test_exception_callback(self) -> None:
        """Test that the exception callback is called when a transaction fails."""
        failing_txn = Mock(side_effect=ZeroDivisionError)

        after_callback, exception_callback = self._run_interaction(failing_txn)

        after_callback.assert_not_called()
        exception_callback.assert_called_once_with(987, 654, extra=321)

    def test_failed_retry(self) -> None:
        """Test that the exception callback is called for every failed attempt."""
        # Always raise an `OperationalError`.
        operational_error = self.db_pool.engine.module.OperationalError
        always_failing_txn = Mock(side_effect=operational_error)

        after_callback, exception_callback = self._run_interaction(always_failing_txn)

        after_callback.assert_not_called()
        # Six identical invocations are expected, one per attempt.
        expected_calls = [call(987, 654, extra=321)] * 6
        exception_callback.assert_has_calls(expected_calls)
        self.assertEqual(exception_callback.call_count, 6)  # no additional calls

    def test_successful_retry(self) -> None:
        """Test callbacks for a failed transaction followed by a successful attempt."""
        # Raise an `OperationalError` on the first attempt only.
        operational_error = self.db_pool.engine.module.OperationalError
        flaky_txn = Mock(side_effect=[operational_error, None])

        after_callback, exception_callback = self._run_interaction(flaky_txn)

        # Calling both `after_callback`s when the first attempt failed is rather
        # surprising (#12184). Let's document the behaviour in a test.
        expected_calls = [call(123, 456, extra=789)] * 2
        after_callback.assert_has_calls(expected_calls)
        self.assertEqual(after_callback.call_count, 2)  # no additional calls
        exception_callback.assert_not_called()
class CancellationTestCase(unittest.HomeserverTestCase):
    """Tests for transaction callbacks when the interaction is cancelled."""

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        """Keep a handle on the main store and its database pool."""
        self.store = hs.get_datastores().main
        self.db_pool: DatabasePool = self.store.db_pool

    def test_after_callback(self) -> None:
        """Test the after callback when the interaction is cancelled mid-transaction.

        The transaction function cancels the `Deferred` and then returns normally.
        The `Deferred` is expected to fail with `CancelledError`, while the after
        callback is still called exactly once and the exception callback never.
        """
        interaction: "Deferred[None]"
        on_success = Mock()
        on_failure = Mock()

        def cancel_then_succeed(txn: LoggingTransaction) -> None:
            txn.call_after(on_success, 123, 456, extra=789)
            txn.call_on_exception(on_failure, 987, 654, extra=321)
            # Cancel from inside the transaction function itself.
            interaction.cancel()

        interaction = defer.ensureDeferred(
            self.db_pool.runInteraction("test_transaction", cancel_then_succeed)
        )
        self.get_failure(interaction, CancelledError)

        on_success.assert_called_once_with(123, 456, extra=789)
        on_failure.assert_not_called()

    def test_exception_callback(self) -> None:
        """Test the exception callback when a cancelled interaction keeps failing.

        The transaction function cancels the `Deferred` and then raises an
        `OperationalError` on every attempt. The `Deferred` is expected to fail
        with `CancelledError`, the after callback is never called, and the
        exception callback is called six times.
        """
        interaction: "Deferred[None]"
        on_success = Mock()
        on_failure = Mock()

        def cancel_then_fail(txn: LoggingTransaction) -> None:
            txn.call_after(on_success, 123, 456, extra=789)
            txn.call_on_exception(on_failure, 987, 654, extra=321)
            interaction.cancel()
            # Simulate a retryable failure on every attempt.
            raise self.db_pool.engine.module.OperationalError()

        interaction = defer.ensureDeferred(
            self.db_pool.runInteraction("test_transaction", cancel_then_fail)
        )
        self.get_failure(interaction, CancelledError)

        on_success.assert_not_called()
        # Six identical invocations are expected, one per attempt.
        expected_calls = [call(987, 654, extra=321)] * 6
        on_failure.assert_has_calls(expected_calls)
        self.assertEqual(on_failure.call_count, 6)  # no additional calls
|