Browse Source

Fail test cases if they fail to await all awaitables (#8690)

Patrick Cloke 3 years ago
parent
commit
fd7c743445
3 changed files with 39 additions and 2 deletions
  1. 1 0
      changelog.d/8690.misc
  2. 33 1
      tests/test_utils/__init__.py
  3. 5 1
      tests/unittest.py

+ 1 - 0
changelog.d/8690.misc

@@ -0,0 +1 @@
+Fail tests if they do not await coroutines.

+ 33 - 1
tests/test_utils/__init__.py

@@ -17,8 +17,10 @@
 """
 Utilities for running the unit tests
 """
+import sys
+import warnings
 from asyncio import Future
-from typing import Any, Awaitable, TypeVar
+from typing import Any, Awaitable, Callable, TypeVar
 
 TV = TypeVar("TV")
 
@@ -48,3 +50,33 @@ def make_awaitable(result: Any) -> Awaitable[Any]:
     future = Future()  # type: ignore
     future.set_result(result)
     return future
+
+
+def setup_awaitable_errors() -> Callable[[], None]:
+    """
+    Convert warnings from a non-awaited coroutines into errors.
+    """
+    warnings.simplefilter("error", RuntimeWarning)
+
+    # unraisablehook was added in Python 3.8.
+    if not hasattr(sys, "unraisablehook"):
+        return lambda: None
+
+    # State shared between unraisablehook and check_for_unraisable_exceptions.
+    unraisable_exceptions = []
+    orig_unraisablehook = sys.unraisablehook  # type: ignore
+
+    def unraisablehook(unraisable):
+        unraisable_exceptions.append(unraisable.exc_value)
+
+    def cleanup():
+        """
+        A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
+        """
+        sys.unraisablehook = orig_unraisablehook  # type: ignore
+        if unraisable_exceptions:
+            raise unraisable_exceptions.pop()
+
+    sys.unraisablehook = unraisablehook  # type: ignore
+
+    return cleanup

+ 5 - 1
tests/unittest.py

@@ -54,7 +54,7 @@ from tests.server import (
     render,
     setup_test_homeserver,
 )
-from tests.test_utils import event_injection
+from tests.test_utils import event_injection, setup_awaitable_errors
 from tests.test_utils.logging_setup import setup_logging
 from tests.utils import default_config, setupdb
 
@@ -119,6 +119,10 @@ class TestCase(unittest.TestCase):
 
                 logging.getLogger().setLevel(level)
 
+            # Trial messes with the warnings configuration, thus this has to be
+            # done in the context of an individual TestCase.
+            self.addCleanup(setup_awaitable_errors())
+
             return orig()
 
         @around(self)