Browse Source

Convert some util functions to async (#8035)

Patrick Cloke 3 years ago
parent
commit
fe6cfc80ec
4 changed files with 39 additions and 61 deletions
  1. 1 0
      changelog.d/8035.misc
  2. 21 18
      synapse/util/metrics.py
  3. 6 10
      synapse/util/retryutils.py
  4. 11 33
      tests/util/test_retryutils.py

+ 1 - 0
changelog.d/8035.misc

@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.

+ 21 - 18
synapse/util/metrics.py

@@ -13,14 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import logging
 from functools import wraps
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 from synapse.logging.context import LoggingContext, current_context
 from synapse.metrics import InFlightGauge
 
@@ -62,25 +59,31 @@ in_flight = InFlightGauge(
 
 
 def measure_func(name=None):
-    def wrapper(func):
-        block_name = func.__name__ if name is None else name
+    """
+    Used to decorate an async function with a `Measure` context manager.
+
+    Usage:
 
-        if inspect.iscoroutinefunction(func):
+    @measure_func()
+    async def foo(...):
+        ...
 
-            @wraps(func)
-            async def measured_func(self, *args, **kwargs):
-                with Measure(self.clock, block_name):
-                    r = await func(self, *args, **kwargs)
-                return r
+    Which is analogous to:
 
-        else:
+    async def foo(...):
+        with Measure(...):
+            ...
+
+    """
+
+    def wrapper(func):
+        block_name = func.__name__ if name is None else name
 
-            @wraps(func)
-            @defer.inlineCallbacks
-            def measured_func(self, *args, **kwargs):
-                with Measure(self.clock, block_name):
-                    r = yield func(self, *args, **kwargs)
-                return r
+        @wraps(func)
+        async def measured_func(self, *args, **kwargs):
+            with Measure(self.clock, block_name):
+                r = await func(self, *args, **kwargs)
+            return r
 
         return measured_func
 

+ 6 - 10
synapse/util/retryutils.py

@@ -15,8 +15,6 @@
 import logging
 import random
 
-from twisted.internet import defer
-
 import synapse.logging.context
 from synapse.api.errors import CodeMessageException
 
@@ -54,8 +52,7 @@ class NotRetryingDestination(Exception):
         self.destination = destination
 
 
-@defer.inlineCallbacks
-def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
+async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
     """For a given destination check if we have previously failed to
     send a request there and are waiting before retrying the destination.
     If we are not ready to retry the destination, this will raise a
@@ -73,9 +70,9 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
     Example usage:
 
         try:
-            limiter = yield get_retry_limiter(destination, clock, store)
+            limiter = await get_retry_limiter(destination, clock, store)
             with limiter:
-                response = yield do_request()
+                response = await do_request()
         except NotRetryingDestination:
             # We aren't ready to retry that destination.
             raise
@@ -83,7 +80,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
     failure_ts = None
     retry_last_ts, retry_interval = (0, 0)
 
-    retry_timings = yield store.get_destination_retry_timings(destination)
+    retry_timings = await store.get_destination_retry_timings(destination)
 
     if retry_timings:
         failure_ts = retry_timings["failure_ts"]
@@ -222,10 +219,9 @@ class RetryDestinationLimiter(object):
             if self.failure_ts is None:
                 self.failure_ts = retry_last_ts
 
-        @defer.inlineCallbacks
-        def store_retry_timings():
+        async def store_retry_timings():
             try:
-                yield self.store.set_destination_retry_timings(
+                await self.store.set_destination_retry_timings(
                     self.destination,
                     self.failure_ts,
                     retry_last_ts,

+ 11 - 33
tests/util/test_retryutils.py

@@ -26,9 +26,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
     def test_new_destination(self):
         """A happy-path case with a new destination and a successful operation"""
         store = self.hs.get_datastore()
-        d = get_retry_limiter("test_dest", self.clock, store)
-        self.pump()
-        limiter = self.successResultOf(d)
+        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 
         # advance the clock a bit before making the request
         self.pump(1)
@@ -36,18 +34,14 @@ class RetryLimiterTestCase(HomeserverTestCase):
         with limiter:
             pass
 
-        d = store.get_destination_retry_timings("test_dest")
-        self.pump()
-        new_timings = self.successResultOf(d)
+        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
         self.assertIsNone(new_timings)
 
     def test_limiter(self):
         """General test case which walks through the process of a failing request"""
         store = self.hs.get_datastore()
 
-        d = get_retry_limiter("test_dest", self.clock, store)
-        self.pump()
-        limiter = self.successResultOf(d)
+        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 
         self.pump(1)
         try:
@@ -58,29 +52,22 @@ class RetryLimiterTestCase(HomeserverTestCase):
         except AssertionError:
             pass
 
-        # wait for the update to land
-        self.pump()
-
-        d = store.get_destination_retry_timings("test_dest")
-        self.pump()
-        new_timings = self.successResultOf(d)
+        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
         self.assertEqual(new_timings["failure_ts"], failure_ts)
         self.assertEqual(new_timings["retry_last_ts"], failure_ts)
         self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
 
         # now if we try again we should get a failure
-        d = get_retry_limiter("test_dest", self.clock, store)
-        self.pump()
-        self.failureResultOf(d, NotRetryingDestination)
+        self.get_failure(
+            get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
+        )
 
         #
         # advance the clock and try again
         #
 
         self.pump(MIN_RETRY_INTERVAL)
-        d = get_retry_limiter("test_dest", self.clock, store)
-        self.pump()
-        limiter = self.successResultOf(d)
+        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 
         self.pump(1)
         try:
@@ -91,12 +78,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
         except AssertionError:
             pass
 
-        # wait for the update to land
-        self.pump()
-
-        d = store.get_destination_retry_timings("test_dest")
-        self.pump()
-        new_timings = self.successResultOf(d)
+        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
         self.assertEqual(new_timings["failure_ts"], failure_ts)
         self.assertEqual(new_timings["retry_last_ts"], retry_ts)
         self.assertGreaterEqual(
@@ -110,9 +92,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
         # one more go, with success
         #
         self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
-        d = get_retry_limiter("test_dest", self.clock, store)
-        self.pump()
-        limiter = self.successResultOf(d)
+        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
 
         self.pump(1)
         with limiter:
@@ -121,7 +101,5 @@ class RetryLimiterTestCase(HomeserverTestCase):
         # wait for the update to land
         self.pump()
 
-        d = store.get_destination_retry_timings("test_dest")
-        self.pump()
-        new_timings = self.successResultOf(d)
+        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
         self.assertIsNone(new_timings)