batching_queue.py

# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import (
    Awaitable,
    Callable,
    Dict,
    Generic,
    Hashable,
    List,
    Set,
    Tuple,
    TypeVar,
)

from prometheus_client import Gauge

from twisted.internet import defer

from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock

logger = logging.getLogger(__name__)

V = TypeVar("V")
R = TypeVar("R")
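
# NOTE (added commentary, not in the upstream file): these gauges are wired up
# via `set_function` in `BatchingQueue.__init__`, so their values are computed
# lazily from the queue's current state at scrape time rather than being
# updated on every mutation.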

number_queued = Gauge(
    "synapse_util_batching_queue_number_queued",
    "The number of items waiting in the queue across all keys",
    labelnames=("name",),
)

number_in_flight = Gauge(
    "synapse_util_batching_queue_number_pending",
    "The number of items across all keys either being processed or waiting in a queue",
    labelnames=("name",),
)

number_of_keys = Gauge(
    "synapse_util_batching_queue_number_of_keys",
    "The number of distinct keys that have items queued",
    labelnames=("name",),
)


class BatchingQueue(Generic[V, R]):
    """A queue that batches up work, calling the provided processing function
    with all pending work (for a given key).

    The provided processing function will only be called once at a time for
    each key. It will be called on the next reactor tick after `add_to_queue`
    has been called, and will keep being called until the queue has been
    drained (for the given key).

    If the processing function raises an exception then the exception is
    proxied through to the callers waiting on that batch of work.

    Note that the return value of `add_to_queue` will be the return value of
    the processing function that processed the given item. This means that
    the returned value will likely include data for other items that were in
    the batch. (See the illustrative usage sketch at the bottom of this
    module.)

    Args:
        name: A name for the queue, used for logging contexts and metrics.
            This must be unique, otherwise the metrics will be wrong.
        clock: The clock to use to schedule work.
        process_batch_callback: The callback to be run to process a batch of
            work.
    """

    def __init__(
        self,
        name: str,
        clock: Clock,
        process_batch_callback: Callable[[List[V]], Awaitable[R]],
    ):
        self._name = name
        self._clock = clock

        # The set of keys currently being processed.
        self._processing_keys: Set[Hashable] = set()

        # The currently pending batch of values by key, with a Deferred to
        # call with the result of the corresponding `_process_batch_callback`
        # call.
        self._next_values: Dict[Hashable, List[Tuple[V, defer.Deferred]]] = {}

        # The function to call with batches of values.
        self._process_batch_callback = process_batch_callback

        number_queued.labels(self._name).set_function(
            lambda: sum(len(q) for q in self._next_values.values())
        )

        number_of_keys.labels(self._name).set_function(lambda: len(self._next_values))

        self._number_in_flight_metric: Gauge = number_in_flight.labels(self._name)

    async def add_to_queue(self, value: V, key: Hashable = ()) -> R:
        """Adds the value to the queue with the given key, returning the
        result of the processing function for the batch that included the
        given value.

        The optional `key` argument allows sharding the queue by some key.
        Queues for different keys are then processed in parallel: the process
        batch function may be called concurrently, but each call will only
        receive batched values from a single key.
        """

        # First we create a deferred and add it and the value to the list of
        # pending items.
        d: defer.Deferred[R] = defer.Deferred()
        self._next_values.setdefault(key, []).append((value, d))

        # If we're not currently processing the key, fire off a background
        # process to start processing.
        if key not in self._processing_keys:
            run_as_background_process(self._name, self._process_queue, key)

        with self._number_in_flight_metric.track_inprogress():
            return await make_deferred_yieldable(d)

    async def _process_queue(self, key: Hashable) -> None:
        """A background task to repeatedly pull things off the queue for the
        given key and call the `self._process_batch_callback` with the values.
        """

        if key in self._processing_keys:
            return

        try:
            self._processing_keys.add(key)

            while True:
                # We purposefully wait a reactor tick to allow us to batch
                # together requests that we're about to receive. A common
                # pattern is to call `add_to_queue` multiple times at once,
                # and deferring to the next reactor tick allows us to batch
                # all of those up.
                await self._clock.sleep(0)

                next_values = self._next_values.pop(key, [])
                if not next_values:
                    # We've exhausted the queue.
                    break

                try:
                    values = [value for value, _ in next_values]
                    results = await self._process_batch_callback(values)

                    with PreserveLoggingContext():
                        for _, deferred in next_values:
                            deferred.callback(results)

                except Exception as e:
                    with PreserveLoggingContext():
                        for _, deferred in next_values:
                            if deferred.called:
                                continue

                            deferred.errback(e)

        finally:
            self._processing_keys.discard(key)
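

# ---------------------------------------------------------------------------
# Illustrative usage sketch (added commentary; not part of the upstream
# module). It shows how a `BatchingQueue` coalesces concurrent callers into a
# single batch and how the `key` argument shards the queue. The names
# `_fetch_users_from_db` and `example_usage`, and the idea of obtaining the
# clock from something like `hs.get_clock()`, are hypothetical stand-ins for
# real call sites.


async def _fetch_users_from_db(user_ids: List[str]) -> Dict[str, int]:
    # Hypothetical batch callback: in real code this would make one database
    # round-trip for the whole batch. Every caller whose value landed in this
    # batch receives this same return value.
    return {user_id: len(user_id) for user_id in user_ids}


async def example_usage(clock: Clock) -> None:
    queue: BatchingQueue[str, Dict[str, int]] = BatchingQueue(
        "fetch_users", clock, _fetch_users_from_db
    )

    # Calls made on the same reactor tick are batched into one invocation of
    # `_fetch_users_from_db` on the next tick. Passing a `key` shards the
    # queue: batches for different keys are processed in parallel, but each
    # callback call only ever sees values from one key.
    result = await queue.add_to_queue("@alice:example.com", key="users")
    logger.info("batch result: %s", result)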