test_batching_queue.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244
  1. # Copyright 2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Tuple
  15. from prometheus_client import Gauge
  16. from twisted.internet import defer
  17. from synapse.logging.context import make_deferred_yieldable
  18. from synapse.util.batching_queue import (
  19. BatchingQueue,
  20. number_in_flight,
  21. number_of_keys,
  22. number_queued,
  23. )
  24. from tests.server import get_clock
  25. from tests.unittest import TestCase
  26. class BatchingQueueTestCase(TestCase):
  27. def setUp(self) -> None:
  28. self.clock, hs_clock = get_clock()
  29. # We ensure that we remove any existing metrics for "test_queue".
  30. try:
  31. number_queued.remove("test_queue")
  32. number_of_keys.remove("test_queue")
  33. number_in_flight.remove("test_queue")
  34. except KeyError:
  35. pass
  36. self._pending_calls: List[Tuple[List[str], defer.Deferred]] = []
  37. self.queue: BatchingQueue[str, str] = BatchingQueue(
  38. "test_queue", hs_clock, self._process_queue
  39. )
  40. async def _process_queue(self, values: List[str]) -> str:
  41. d: "defer.Deferred[str]" = defer.Deferred()
  42. self._pending_calls.append((values, d))
  43. return await make_deferred_yieldable(d)
  44. def _get_sample_with_name(self, metric: Gauge, name: str) -> float:
  45. """For a prometheus metric get the value of the sample that has a
  46. matching "name" label.
  47. """
  48. for sample in next(iter(metric.collect())).samples:
  49. if sample.labels.get("name") == name:
  50. return sample.value
  51. self.fail("Found no matching sample")
  52. def _assert_metrics(self, queued: int, keys: int, in_flight: int) -> None:
  53. """Assert that the metrics are correct"""
  54. sample = self._get_sample_with_name(number_queued, self.queue._name)
  55. self.assertEqual(
  56. sample,
  57. queued,
  58. "number_queued",
  59. )
  60. sample = self._get_sample_with_name(number_of_keys, self.queue._name)
  61. self.assertEqual(sample, keys, "number_of_keys")
  62. sample = self._get_sample_with_name(number_in_flight, self.queue._name)
  63. self.assertEqual(
  64. sample,
  65. in_flight,
  66. "number_in_flight",
  67. )
  68. def test_simple(self) -> None:
  69. """Tests the basic case of calling `add_to_queue` once and having
  70. `_process_queue` return.
  71. """
  72. self.assertFalse(self._pending_calls)
  73. queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo"))
  74. self._assert_metrics(queued=1, keys=1, in_flight=1)
  75. # The queue should wait a reactor tick before calling the processing
  76. # function.
  77. self.assertFalse(self._pending_calls)
  78. self.assertFalse(queue_d.called)
  79. # We should see a call to `_process_queue` after a reactor tick.
  80. self.clock.pump([0])
  81. self.assertEqual(len(self._pending_calls), 1)
  82. self.assertEqual(self._pending_calls[0][0], ["foo"])
  83. self.assertFalse(queue_d.called)
  84. self._assert_metrics(queued=0, keys=0, in_flight=1)
  85. # Return value of the `_process_queue` should be propagated back.
  86. self._pending_calls.pop()[1].callback("bar")
  87. self.assertEqual(self.successResultOf(queue_d), "bar")
  88. self._assert_metrics(queued=0, keys=0, in_flight=0)
  89. def test_batching(self) -> None:
  90. """Test that multiple calls at the same time get batched up into one
  91. call to `_process_queue`.
  92. """
  93. self.assertFalse(self._pending_calls)
  94. queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
  95. queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
  96. self._assert_metrics(queued=2, keys=1, in_flight=2)
  97. self.clock.pump([0])
  98. # We should see only *one* call to `_process_queue`
  99. self.assertEqual(len(self._pending_calls), 1)
  100. self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"])
  101. self.assertFalse(queue_d1.called)
  102. self.assertFalse(queue_d2.called)
  103. self._assert_metrics(queued=0, keys=0, in_flight=2)
  104. # Return value of the `_process_queue` should be propagated back to both.
  105. self._pending_calls.pop()[1].callback("bar")
  106. self.assertEqual(self.successResultOf(queue_d1), "bar")
  107. self.assertEqual(self.successResultOf(queue_d2), "bar")
  108. self._assert_metrics(queued=0, keys=0, in_flight=0)
  109. def test_queuing(self) -> None:
  110. """Test that we queue up requests while a `_process_queue` is being
  111. called.
  112. """
  113. self.assertFalse(self._pending_calls)
  114. queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
  115. self.clock.pump([0])
  116. self.assertEqual(len(self._pending_calls), 1)
  117. # We queue up work after the process function has been called, testing
  118. # that they get correctly queued up.
  119. queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2"))
  120. queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3"))
  121. # We should see only *one* call to `_process_queue`
  122. self.assertEqual(len(self._pending_calls), 1)
  123. self.assertEqual(self._pending_calls[0][0], ["foo1"])
  124. self.assertFalse(queue_d1.called)
  125. self.assertFalse(queue_d2.called)
  126. self.assertFalse(queue_d3.called)
  127. self._assert_metrics(queued=2, keys=1, in_flight=3)
  128. # Return value of the `_process_queue` should be propagated back to the
  129. # first.
  130. self._pending_calls.pop()[1].callback("bar1")
  131. self.assertEqual(self.successResultOf(queue_d1), "bar1")
  132. self.assertFalse(queue_d2.called)
  133. self.assertFalse(queue_d3.called)
  134. self._assert_metrics(queued=2, keys=1, in_flight=2)
  135. # We should now see a second call to `_process_queue`
  136. self.clock.pump([0])
  137. self.assertEqual(len(self._pending_calls), 1)
  138. self.assertEqual(self._pending_calls[0][0], ["foo2", "foo3"])
  139. self.assertFalse(queue_d2.called)
  140. self.assertFalse(queue_d3.called)
  141. self._assert_metrics(queued=0, keys=0, in_flight=2)
  142. # Return value of the `_process_queue` should be propagated back to the
  143. # second.
  144. self._pending_calls.pop()[1].callback("bar2")
  145. self.assertEqual(self.successResultOf(queue_d2), "bar2")
  146. self.assertEqual(self.successResultOf(queue_d3), "bar2")
  147. self._assert_metrics(queued=0, keys=0, in_flight=0)
  148. def test_different_keys(self) -> None:
  149. """Test that calls to different keys get processed in parallel."""
  150. self.assertFalse(self._pending_calls)
  151. queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1", key=1))
  152. self.clock.pump([0])
  153. queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2", key=2))
  154. self.clock.pump([0])
  155. # We queue up another item with key=2 to check that we will keep taking
  156. # things off the queue.
  157. queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3", key=2))
  158. # We should see two calls to `_process_queue`
  159. self.assertEqual(len(self._pending_calls), 2)
  160. self.assertEqual(self._pending_calls[0][0], ["foo1"])
  161. self.assertEqual(self._pending_calls[1][0], ["foo2"])
  162. self.assertFalse(queue_d1.called)
  163. self.assertFalse(queue_d2.called)
  164. self.assertFalse(queue_d3.called)
  165. self._assert_metrics(queued=1, keys=1, in_flight=3)
  166. # Return value of the `_process_queue` should be propagated back to the
  167. # first.
  168. self._pending_calls.pop(0)[1].callback("bar1")
  169. self.assertEqual(self.successResultOf(queue_d1), "bar1")
  170. self.assertFalse(queue_d2.called)
  171. self.assertFalse(queue_d3.called)
  172. self._assert_metrics(queued=1, keys=1, in_flight=2)
  173. # Return value of the `_process_queue` should be propagated back to the
  174. # second.
  175. self._pending_calls.pop()[1].callback("bar2")
  176. self.assertEqual(self.successResultOf(queue_d2), "bar2")
  177. self.assertFalse(queue_d3.called)
  178. # We should now see a call `_pending_calls` for `foo3`
  179. self.clock.pump([0])
  180. self.assertEqual(len(self._pending_calls), 1)
  181. self.assertEqual(self._pending_calls[0][0], ["foo3"])
  182. self.assertFalse(queue_d3.called)
  183. self._assert_metrics(queued=0, keys=0, in_flight=1)
  184. # Return value of the `_process_queue` should be propagated back to the
  185. # third deferred.
  186. self._pending_calls.pop()[1].callback("bar4")
  187. self.assertEqual(self.successResultOf(queue_d3), "bar4")
  188. self._assert_metrics(queued=0, keys=0, in_flight=0)