|
@@ -22,20 +22,33 @@ them.
|
|
|
|
|
|
See doc/log_contexts.rst for details on how this works.
|
|
|
"""
|
|
|
-import inspect
|
|
|
import logging
|
|
|
import threading
|
|
|
import typing
|
|
|
import warnings
|
|
|
-from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
|
|
|
+from types import TracebackType
|
|
|
+from typing import (
|
|
|
+ TYPE_CHECKING,
|
|
|
+ Any,
|
|
|
+ Awaitable,
|
|
|
+ Callable,
|
|
|
+ Optional,
|
|
|
+ Tuple,
|
|
|
+ Type,
|
|
|
+ TypeVar,
|
|
|
+ Union,
|
|
|
+ overload,
|
|
|
+)
|
|
|
|
|
|
import attr
|
|
|
from typing_extensions import Literal
|
|
|
|
|
|
from twisted.internet import defer, threads
|
|
|
+from twisted.python.threadpool import ThreadPool
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
from synapse.logging.scopecontextmanager import _LogContextScope
|
|
|
+ from synapse.types import ISynapseReactor
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
@@ -66,7 +79,7 @@ except Exception:
|
|
|
|
|
|
|
|
|
# a hook which can be set during testing to assert that we aren't abusing logcontexts.
|
|
|
-def logcontext_error(msg: str):
|
|
|
+def logcontext_error(msg: str) -> None:
|
|
|
logger.warning(msg)
|
|
|
|
|
|
|
|
@@ -223,22 +236,19 @@ class _Sentinel:
|
|
|
def __str__(self) -> str:
|
|
|
return "sentinel"
|
|
|
|
|
|
- def copy_to(self, record):
|
|
|
- pass
|
|
|
-
|
|
|
- def start(self, rusage: "Optional[resource.struct_rusage]"):
|
|
|
+ def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
|
|
|
pass
|
|
|
|
|
|
- def stop(self, rusage: "Optional[resource.struct_rusage]"):
|
|
|
+ def stop(self, rusage: "Optional[resource.struct_rusage]") -> None:
|
|
|
pass
|
|
|
|
|
|
- def add_database_transaction(self, duration_sec):
|
|
|
+ def add_database_transaction(self, duration_sec: float) -> None:
|
|
|
pass
|
|
|
|
|
|
- def add_database_scheduled(self, sched_sec):
|
|
|
+ def add_database_scheduled(self, sched_sec: float) -> None:
|
|
|
pass
|
|
|
|
|
|
- def record_event_fetch(self, event_count):
|
|
|
+ def record_event_fetch(self, event_count: int) -> None:
|
|
|
pass
|
|
|
|
|
|
def __bool__(self) -> Literal[False]:
|
|
@@ -379,7 +389,12 @@ class LoggingContext:
|
|
|
)
|
|
|
return self
|
|
|
|
|
|
- def __exit__(self, type, value, traceback) -> None:
|
|
|
+ def __exit__(
|
|
|
+ self,
|
|
|
+ type: Optional[Type[BaseException]],
|
|
|
+ value: Optional[BaseException],
|
|
|
+ traceback: Optional[TracebackType],
|
|
|
+ ) -> None:
|
|
|
"""Restore the logging context in thread local storage to the state it
|
|
|
was before this context was entered.
|
|
|
Returns:
|
|
@@ -399,17 +414,6 @@ class LoggingContext:
|
|
|
# recorded against the correct metrics.
|
|
|
self.finished = True
|
|
|
|
|
|
- def copy_to(self, record) -> None:
|
|
|
- """Copy logging fields from this context to a log record or
|
|
|
- another LoggingContext
|
|
|
- """
|
|
|
-
|
|
|
- # we track the current request
|
|
|
- record.request = self.request
|
|
|
-
|
|
|
- # we also track the current scope:
|
|
|
- record.scope = self.scope
|
|
|
-
|
|
|
def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
|
|
|
"""
|
|
|
Record that this logcontext is currently running.
|
|
@@ -626,7 +630,12 @@ class PreserveLoggingContext:
|
|
|
def __enter__(self) -> None:
|
|
|
self._old_context = set_current_context(self._new_context)
|
|
|
|
|
|
- def __exit__(self, type, value, traceback) -> None:
|
|
|
+ def __exit__(
|
|
|
+ self,
|
|
|
+ type: Optional[Type[BaseException]],
|
|
|
+ value: Optional[BaseException],
|
|
|
+ traceback: Optional[TracebackType],
|
|
|
+ ) -> None:
|
|
|
context = set_current_context(self._old_context)
|
|
|
|
|
|
if context != self._new_context:
|
|
@@ -711,16 +720,61 @@ def nested_logging_context(suffix: str) -> LoggingContext:
|
|
|
)
|
|
|
|
|
|
|
|
|
-def preserve_fn(f):
|
|
|
+R = TypeVar("R")
|
|
|
+
|
|
|
+
|
|
|
+@overload
|
|
|
+def preserve_fn( # type: ignore[misc]
|
|
|
+ f: Callable[..., Awaitable[R]],
|
|
|
+) -> Callable[..., "defer.Deferred[R]"]:
|
|
|
+ # The `type: ignore[misc]` above suppresses
|
|
|
+ # "Overloaded function signatures 1 and 2 overlap with incompatible return types"
|
|
|
+ ...
|
|
|
+
|
|
|
+
|
|
|
+@overload
|
|
|
+def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]:
|
|
|
+ ...
|
|
|
+
|
|
|
+
|
|
|
+def preserve_fn(
|
|
|
+ f: Union[
|
|
|
+ Callable[..., R],
|
|
|
+ Callable[..., Awaitable[R]],
|
|
|
+ ]
|
|
|
+) -> Callable[..., "defer.Deferred[R]"]:
|
|
|
"""Function decorator which wraps the function with run_in_background"""
|
|
|
|
|
|
- def g(*args, **kwargs):
|
|
|
+ def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]":
|
|
|
return run_in_background(f, *args, **kwargs)
|
|
|
|
|
|
return g
|
|
|
|
|
|
|
|
|
-def run_in_background(f, *args, **kwargs) -> defer.Deferred:
|
|
|
+@overload
|
|
|
+def run_in_background( # type: ignore[misc]
|
|
|
+ f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any
|
|
|
+) -> "defer.Deferred[R]":
|
|
|
+ # The `type: ignore[misc]` above suppresses
|
|
|
+ # "Overloaded function signatures 1 and 2 overlap with incompatible return types"
|
|
|
+ ...
|
|
|
+
|
|
|
+
|
|
|
+@overload
|
|
|
+def run_in_background(
|
|
|
+ f: Callable[..., R], *args: Any, **kwargs: Any
|
|
|
+) -> "defer.Deferred[R]":
|
|
|
+ ...
|
|
|
+
|
|
|
+
|
|
|
+def run_in_background(
|
|
|
+ f: Union[
|
|
|
+ Callable[..., R],
|
|
|
+ Callable[..., Awaitable[R]],
|
|
|
+ ],
|
|
|
+ *args: Any,
|
|
|
+ **kwargs: Any,
|
|
|
+) -> "defer.Deferred[R]":
|
|
|
"""Calls a function, ensuring that the current context is restored after
|
|
|
return from the function, and that the sentinel context is set once the
|
|
|
deferred returned by the function completes.
|
|
@@ -751,6 +805,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
|
|
|
# At this point we should have a Deferred, if not then f was a synchronous
|
|
|
# function, wrap it in a Deferred for consistency.
|
|
|
if not isinstance(res, defer.Deferred):
|
|
|
+ # `res` is not a `Deferred` and not a `Coroutine`.
|
|
|
+ # There are no other types of `Awaitable`s we expect to encounter in Synapse.
|
|
|
+ assert not isinstance(res, Awaitable)
|
|
|
+
|
|
|
return defer.succeed(res)
|
|
|
|
|
|
if res.called and not res.paused:
|
|
@@ -778,13 +836,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
|
|
|
return res
|
|
|
|
|
|
|
|
|
-def make_deferred_yieldable(deferred):
|
|
|
- """Given a deferred (or coroutine), make it follow the Synapse logcontext
|
|
|
- rules:
|
|
|
+T = TypeVar("T")
|
|
|
+
|
|
|
|
|
|
- If the deferred has completed (or is not actually a Deferred), essentially
|
|
|
- does nothing (just returns another completed deferred with the
|
|
|
- result/failure).
|
|
|
+def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
|
|
+ """Given a deferred, make it follow the Synapse logcontext rules:
|
|
|
+
|
|
|
+ If the deferred has completed, essentially does nothing (just returns another
|
|
|
+ completed deferred with the result/failure).
|
|
|
|
|
|
If the deferred has not yet completed, resets the logcontext before
|
|
|
returning a deferred. Then, when the deferred completes, restores the
|
|
@@ -792,16 +851,6 @@ def make_deferred_yieldable(deferred):
|
|
|
|
|
|
(This is more-or-less the opposite operation to run_in_background.)
|
|
|
"""
|
|
|
- if inspect.isawaitable(deferred):
|
|
|
- # If we're given a coroutine we convert it to a deferred so that we
|
|
|
- # run it and find out if it immediately finishes, it it does then we
|
|
|
- # don't need to fiddle with log contexts at all and can return
|
|
|
- # immediately.
|
|
|
- deferred = defer.ensureDeferred(deferred)
|
|
|
-
|
|
|
- if not isinstance(deferred, defer.Deferred):
|
|
|
- return deferred
|
|
|
-
|
|
|
if deferred.called and not deferred.paused:
|
|
|
# it looks like this deferred is ready to run any callbacks we give it
|
|
|
# immediately. We may as well optimise out the logcontext faffery.
|
|
@@ -823,7 +872,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
|
|
|
return result
|
|
|
|
|
|
|
|
|
-def defer_to_thread(reactor, f, *args, **kwargs):
|
|
|
+def defer_to_thread(
|
|
|
+ reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any
|
|
|
+) -> "defer.Deferred[R]":
|
|
|
"""
|
|
|
Calls the function `f` using a thread from the reactor's default threadpool and
|
|
|
returns the result as a Deferred.
|
|
@@ -855,7 +906,13 @@ def defer_to_thread(reactor, f, *args, **kwargs):
|
|
|
return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
|
|
|
|
|
|
|
|
|
-def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
|
|
|
+def defer_to_threadpool(
|
|
|
+ reactor: "ISynapseReactor",
|
|
|
+ threadpool: ThreadPool,
|
|
|
+ f: Callable[..., R],
|
|
|
+ *args: Any,
|
|
|
+ **kwargs: Any,
|
|
|
+) -> "defer.Deferred[R]":
|
|
|
"""
|
|
|
A wrapper for twisted.internet.threads.deferToThreadpool, which handles
|
|
|
logcontexts correctly.
|
|
@@ -897,7 +954,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
|
|
|
assert isinstance(curr_context, LoggingContext)
|
|
|
parent_context = curr_context
|
|
|
|
|
|
- def g():
|
|
|
+ def g() -> R:
|
|
|
with LoggingContext(str(curr_context), parent_context=parent_context):
|
|
|
return f(*args, **kwargs)
|
|
|
|