test_database.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. # Copyright 2020 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Callable, Tuple
  15. from unittest.mock import Mock, call
  16. from twisted.internet import defer
  17. from twisted.internet.defer import CancelledError, Deferred
  18. from twisted.test.proto_helpers import MemoryReactor
  19. from synapse.server import HomeServer
  20. from synapse.storage.database import (
  21. DatabasePool,
  22. LoggingDatabaseConnection,
  23. LoggingTransaction,
  24. make_tuple_comparison_clause,
  25. )
  26. from synapse.util import Clock
  27. from tests import unittest
  28. class TupleComparisonClauseTestCase(unittest.TestCase):
  29. def test_native_tuple_comparison(self) -> None:
  30. clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
  31. self.assertEqual(clause, "(a,b) > (?,?)")
  32. self.assertEqual(args, [1, 2])
  33. class ExecuteScriptTestCase(unittest.HomeserverTestCase):
  34. """Tests for `BaseDatabaseEngine.executescript` implementations."""
  35. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  36. self.store = hs.get_datastores().main
  37. self.db_pool: DatabasePool = self.store.db_pool
  38. self.get_success(
  39. self.db_pool.runInteraction(
  40. "create",
  41. lambda txn: txn.execute("CREATE TABLE foo (name TEXT PRIMARY KEY)"),
  42. )
  43. )
  44. def test_transaction(self) -> None:
  45. """Test that all statements are run in a single transaction."""
  46. def run(conn: LoggingDatabaseConnection) -> None:
  47. cur = conn.cursor(txn_name="test_transaction")
  48. self.db_pool.engine.executescript(
  49. cur,
  50. ";".join(
  51. [
  52. "INSERT INTO foo (name) VALUES ('transaction test')",
  53. # This next statement will fail. When `executescript` is not
  54. # transactional, the previous row will be observed later.
  55. "INSERT INTO foo (name) VALUES ('transaction test')",
  56. ]
  57. ),
  58. )
  59. self.get_failure(
  60. self.db_pool.runWithConnection(run),
  61. self.db_pool.engine.module.IntegrityError,
  62. )
  63. self.assertIsNone(
  64. self.get_success(
  65. self.db_pool.simple_select_one_onecol(
  66. "foo",
  67. keyvalues={"name": "transaction test"},
  68. retcol="name",
  69. allow_none=True,
  70. )
  71. ),
  72. "executescript is not running statements inside a transaction",
  73. )
  74. def test_commit(self) -> None:
  75. """Test that the script transaction remains open and can be committed."""
  76. def run(conn: LoggingDatabaseConnection) -> None:
  77. cur = conn.cursor(txn_name="test_commit")
  78. self.db_pool.engine.executescript(
  79. cur, "INSERT INTO foo (name) VALUES ('commit test')"
  80. )
  81. cur.execute("COMMIT")
  82. self.get_success(self.db_pool.runWithConnection(run))
  83. self.assertIsNotNone(
  84. self.get_success(
  85. self.db_pool.simple_select_one_onecol(
  86. "foo",
  87. keyvalues={"name": "commit test"},
  88. retcol="name",
  89. allow_none=True,
  90. )
  91. ),
  92. )
  93. def test_rollback(self) -> None:
  94. """Test that the script transaction remains open and can be rolled back."""
  95. def run(conn: LoggingDatabaseConnection) -> None:
  96. cur = conn.cursor(txn_name="test_rollback")
  97. self.db_pool.engine.executescript(
  98. cur, "INSERT INTO foo (name) VALUES ('rollback test')"
  99. )
  100. cur.execute("ROLLBACK")
  101. self.get_success(self.db_pool.runWithConnection(run))
  102. self.assertIsNone(
  103. self.get_success(
  104. self.db_pool.simple_select_one_onecol(
  105. "foo",
  106. keyvalues={"name": "rollback test"},
  107. retcol="name",
  108. allow_none=True,
  109. )
  110. ),
  111. "executescript is not leaving the script transaction open",
  112. )
  113. class CallbacksTestCase(unittest.HomeserverTestCase):
  114. """Tests for transaction callbacks."""
  115. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  116. self.store = hs.get_datastores().main
  117. self.db_pool: DatabasePool = self.store.db_pool
  118. def _run_interaction(
  119. self, func: Callable[[LoggingTransaction], object]
  120. ) -> Tuple[Mock, Mock]:
  121. """Run the given function in a database transaction, with callbacks registered.
  122. Args:
  123. func: The function to be run in a transaction. The transaction will be
  124. retried if `func` raises an `OperationalError`.
  125. Returns:
  126. Two mocks, which were registered as an `after_callback` and an
  127. `exception_callback` respectively, on every transaction attempt.
  128. """
  129. after_callback = Mock()
  130. exception_callback = Mock()
  131. def _test_txn(txn: LoggingTransaction) -> None:
  132. txn.call_after(after_callback, 123, 456, extra=789)
  133. txn.call_on_exception(exception_callback, 987, 654, extra=321)
  134. func(txn)
  135. try:
  136. self.get_success_or_raise(
  137. self.db_pool.runInteraction("test_transaction", _test_txn)
  138. )
  139. except Exception:
  140. pass
  141. return after_callback, exception_callback
  142. def test_after_callback(self) -> None:
  143. """Test that the after callback is called when a transaction succeeds."""
  144. after_callback, exception_callback = self._run_interaction(lambda txn: None)
  145. after_callback.assert_called_once_with(123, 456, extra=789)
  146. exception_callback.assert_not_called()
  147. def test_exception_callback(self) -> None:
  148. """Test that the exception callback is called when a transaction fails."""
  149. _test_txn = Mock(side_effect=ZeroDivisionError)
  150. after_callback, exception_callback = self._run_interaction(_test_txn)
  151. after_callback.assert_not_called()
  152. exception_callback.assert_called_once_with(987, 654, extra=321)
  153. def test_failed_retry(self) -> None:
  154. """Test that the exception callback is called for every failed attempt."""
  155. # Always raise an `OperationalError`.
  156. _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError)
  157. after_callback, exception_callback = self._run_interaction(_test_txn)
  158. after_callback.assert_not_called()
  159. exception_callback.assert_has_calls(
  160. [
  161. call(987, 654, extra=321),
  162. call(987, 654, extra=321),
  163. call(987, 654, extra=321),
  164. call(987, 654, extra=321),
  165. call(987, 654, extra=321),
  166. call(987, 654, extra=321),
  167. ]
  168. )
  169. self.assertEqual(exception_callback.call_count, 6) # no additional calls
  170. def test_successful_retry(self) -> None:
  171. """Test callbacks for a failed transaction followed by a successful attempt."""
  172. # Raise an `OperationalError` on the first attempt only.
  173. _test_txn = Mock(
  174. side_effect=[self.db_pool.engine.module.OperationalError, None]
  175. )
  176. after_callback, exception_callback = self._run_interaction(_test_txn)
  177. # Calling both `after_callback`s when the first attempt failed is rather
  178. # surprising (#12184). Let's document the behaviour in a test.
  179. after_callback.assert_has_calls(
  180. [
  181. call(123, 456, extra=789),
  182. call(123, 456, extra=789),
  183. ]
  184. )
  185. self.assertEqual(after_callback.call_count, 2) # no additional calls
  186. exception_callback.assert_not_called()
  187. class CancellationTestCase(unittest.HomeserverTestCase):
  188. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  189. self.store = hs.get_datastores().main
  190. self.db_pool: DatabasePool = self.store.db_pool
  191. def test_after_callback(self) -> None:
  192. """Test that the after callback is called when a transaction succeeds."""
  193. d: "Deferred[None]"
  194. after_callback = Mock()
  195. exception_callback = Mock()
  196. def _test_txn(txn: LoggingTransaction) -> None:
  197. txn.call_after(after_callback, 123, 456, extra=789)
  198. txn.call_on_exception(exception_callback, 987, 654, extra=321)
  199. d.cancel()
  200. d = defer.ensureDeferred(
  201. self.db_pool.runInteraction("test_transaction", _test_txn)
  202. )
  203. self.get_failure(d, CancelledError)
  204. after_callback.assert_called_once_with(123, 456, extra=789)
  205. exception_callback.assert_not_called()
  206. def test_exception_callback(self) -> None:
  207. """Test that the exception callback is called when a transaction fails."""
  208. d: "Deferred[None]"
  209. after_callback = Mock()
  210. exception_callback = Mock()
  211. def _test_txn(txn: LoggingTransaction) -> None:
  212. txn.call_after(after_callback, 123, 456, extra=789)
  213. txn.call_on_exception(exception_callback, 987, 654, extra=321)
  214. d.cancel()
  215. # Simulate a retryable failure on every attempt.
  216. raise self.db_pool.engine.module.OperationalError()
  217. d = defer.ensureDeferred(
  218. self.db_pool.runInteraction("test_transaction", _test_txn)
  219. )
  220. self.get_failure(d, CancelledError)
  221. after_callback.assert_not_called()
  222. exception_callback.assert_has_calls(
  223. [
  224. call(987, 654, extra=321),
  225. call(987, 654, extra=321),
  226. call(987, 654, extra=321),
  227. call(987, 654, extra=321),
  228. call(987, 654, extra=321),
  229. call(987, 654, extra=321),
  230. ]
  231. )
  232. self.assertEqual(exception_callback.call_count, 6) # no additional calls