test_database.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. # Copyright 2020 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Callable, Tuple
  15. from unittest.mock import Mock, call
  16. from twisted.internet import defer
  17. from twisted.internet.defer import CancelledError, Deferred
  18. from twisted.test.proto_helpers import MemoryReactor
  19. from synapse.server import HomeServer
  20. from synapse.storage.database import (
  21. DatabasePool,
  22. LoggingTransaction,
  23. make_tuple_comparison_clause,
  24. )
  25. from synapse.util import Clock
  26. from tests import unittest
  27. class TupleComparisonClauseTestCase(unittest.TestCase):
  28. def test_native_tuple_comparison(self) -> None:
  29. clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
  30. self.assertEqual(clause, "(a,b) > (?,?)")
  31. self.assertEqual(args, [1, 2])
  32. class CallbacksTestCase(unittest.HomeserverTestCase):
  33. """Tests for transaction callbacks."""
  34. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  35. self.store = hs.get_datastores().main
  36. self.db_pool: DatabasePool = self.store.db_pool
  37. def _run_interaction(
  38. self, func: Callable[[LoggingTransaction], object]
  39. ) -> Tuple[Mock, Mock]:
  40. """Run the given function in a database transaction, with callbacks registered.
  41. Args:
  42. func: The function to be run in a transaction. The transaction will be
  43. retried if `func` raises an `OperationalError`.
  44. Returns:
  45. Two mocks, which were registered as an `after_callback` and an
  46. `exception_callback` respectively, on every transaction attempt.
  47. """
  48. after_callback = Mock()
  49. exception_callback = Mock()
  50. def _test_txn(txn: LoggingTransaction) -> None:
  51. txn.call_after(after_callback, 123, 456, extra=789)
  52. txn.call_on_exception(exception_callback, 987, 654, extra=321)
  53. func(txn)
  54. try:
  55. self.get_success_or_raise(
  56. self.db_pool.runInteraction("test_transaction", _test_txn)
  57. )
  58. except Exception:
  59. pass
  60. return after_callback, exception_callback
  61. def test_after_callback(self) -> None:
  62. """Test that the after callback is called when a transaction succeeds."""
  63. after_callback, exception_callback = self._run_interaction(lambda txn: None)
  64. after_callback.assert_called_once_with(123, 456, extra=789)
  65. exception_callback.assert_not_called()
  66. def test_exception_callback(self) -> None:
  67. """Test that the exception callback is called when a transaction fails."""
  68. _test_txn = Mock(side_effect=ZeroDivisionError)
  69. after_callback, exception_callback = self._run_interaction(_test_txn)
  70. after_callback.assert_not_called()
  71. exception_callback.assert_called_once_with(987, 654, extra=321)
  72. def test_failed_retry(self) -> None:
  73. """Test that the exception callback is called for every failed attempt."""
  74. # Always raise an `OperationalError`.
  75. _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError)
  76. after_callback, exception_callback = self._run_interaction(_test_txn)
  77. after_callback.assert_not_called()
  78. exception_callback.assert_has_calls(
  79. [
  80. call(987, 654, extra=321),
  81. call(987, 654, extra=321),
  82. call(987, 654, extra=321),
  83. call(987, 654, extra=321),
  84. call(987, 654, extra=321),
  85. call(987, 654, extra=321),
  86. ]
  87. )
  88. self.assertEqual(exception_callback.call_count, 6) # no additional calls
  89. def test_successful_retry(self) -> None:
  90. """Test callbacks for a failed transaction followed by a successful attempt."""
  91. # Raise an `OperationalError` on the first attempt only.
  92. _test_txn = Mock(
  93. side_effect=[self.db_pool.engine.module.OperationalError, None]
  94. )
  95. after_callback, exception_callback = self._run_interaction(_test_txn)
  96. # Calling both `after_callback`s when the first attempt failed is rather
  97. # surprising (#12184). Let's document the behaviour in a test.
  98. after_callback.assert_has_calls(
  99. [
  100. call(123, 456, extra=789),
  101. call(123, 456, extra=789),
  102. ]
  103. )
  104. self.assertEqual(after_callback.call_count, 2) # no additional calls
  105. exception_callback.assert_not_called()
  106. class CancellationTestCase(unittest.HomeserverTestCase):
  107. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  108. self.store = hs.get_datastores().main
  109. self.db_pool: DatabasePool = self.store.db_pool
  110. def test_after_callback(self) -> None:
  111. """Test that the after callback is called when a transaction succeeds."""
  112. d: "Deferred[None]"
  113. after_callback = Mock()
  114. exception_callback = Mock()
  115. def _test_txn(txn: LoggingTransaction) -> None:
  116. txn.call_after(after_callback, 123, 456, extra=789)
  117. txn.call_on_exception(exception_callback, 987, 654, extra=321)
  118. d.cancel()
  119. d = defer.ensureDeferred(
  120. self.db_pool.runInteraction("test_transaction", _test_txn)
  121. )
  122. self.get_failure(d, CancelledError)
  123. after_callback.assert_called_once_with(123, 456, extra=789)
  124. exception_callback.assert_not_called()
  125. def test_exception_callback(self) -> None:
  126. """Test that the exception callback is called when a transaction fails."""
  127. d: "Deferred[None]"
  128. after_callback = Mock()
  129. exception_callback = Mock()
  130. def _test_txn(txn: LoggingTransaction) -> None:
  131. txn.call_after(after_callback, 123, 456, extra=789)
  132. txn.call_on_exception(exception_callback, 987, 654, extra=321)
  133. d.cancel()
  134. # Simulate a retryable failure on every attempt.
  135. raise self.db_pool.engine.module.OperationalError()
  136. d = defer.ensureDeferred(
  137. self.db_pool.runInteraction("test_transaction", _test_txn)
  138. )
  139. self.get_failure(d, CancelledError)
  140. after_callback.assert_not_called()
  141. exception_callback.assert_has_calls(
  142. [
  143. call(987, 654, extra=321),
  144. call(987, 654, extra=321),
  145. call(987, 654, extra=321),
  146. call(987, 654, extra=321),
  147. call(987, 654, extra=321),
  148. call(987, 654, extra=321),
  149. ]
  150. )
  151. self.assertEqual(exception_callback.call_count, 6) # no additional calls