
Add missing type hints to `synapse.logging.context` (#11556)

Sean Quah, 2 years ago · parent commit 0147b3de20

+ 1 - 0
changelog.d/11556.misc

@@ -0,0 +1 @@
+Add missing type hints to `synapse.logging.context`.

+ 3 - 0
mypy.ini

@@ -167,6 +167,9 @@ disallow_untyped_defs = True
 [mypy-synapse.http.server]
 disallow_untyped_defs = True
 
+[mypy-synapse.logging.context]
+disallow_untyped_defs = True
+
 [mypy-synapse.metrics.*]
 disallow_untyped_defs = True
 
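The new stanza opts `synapse.logging.context` into strict function-signature checking. As a minimal illustration (reusing `logcontext_error` from the module's diff below), `disallow_untyped_defs = True` makes mypy reject any `def` without annotations:

```python
import logging

logger = logging.getLogger(__name__)


# Before this commit: `def logcontext_error(msg):` -- flagged as
# "error: Function is missing a type annotation" under the new stanza.
# After, as in the diff below:
def logcontext_error(msg: str) -> None:
    logger.warning(msg)
```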

+ 5 - 4
stubs/txredisapi.pyi

@@ -17,11 +17,12 @@
 from typing import Any, List, Optional, Type, Union
 
 from twisted.internet import protocol
+from twisted.internet.defer import Deferred
 
 class RedisProtocol(protocol.Protocol):
     def publish(self, channel: str, message: bytes): ...
-    async def ping(self) -> None: ...
-    async def set(
+    def ping(self) -> "Deferred[None]": ...
+    def set(
         self,
         key: str,
         value: Any,
@@ -29,8 +30,8 @@ class RedisProtocol(protocol.Protocol):
         pexpire: Optional[int] = None,
         only_if_not_exists: bool = False,
         only_if_exists: bool = False,
-    ) -> None: ...
-    async def get(self, key: str) -> Any: ...
+    ) -> "Deferred[None]": ...
+    def get(self, key: str) -> "Deferred[Any]": ...
 
 class SubscriberProtocol(RedisProtocol):
     def __init__(self, *args, **kwargs): ...
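The stub changes replace `async def` declarations with plain `def`s returning `Deferred`s, matching what txredisapi actually hands back at runtime. The distinction matters because `make_deferred_yieldable` (tightened in `synapse/logging/context.py` below) now accepts only `Deferred`s. A hedged caller sketch, with a hypothetical `fetch_key` helper:

```python
from typing import Any

from txredisapi import RedisProtocol

from synapse.logging.context import make_deferred_yieldable


async def fetch_key(redis: RedisProtocol, key: str) -> Any:
    # Per the corrected stub, `redis.get` returns a Deferred[Any] rather
    # than a coroutine, so it is a valid argument to the now
    # Deferred-only make_deferred_yieldable.
    return await make_deferred_yieldable(redis.get(key))
```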

+ 4 - 5
synapse/federation/federation_server.py

@@ -30,7 +30,6 @@ from typing import (
 
 from prometheus_client import Counter, Gauge, Histogram
 
-from twisted.internet import defer
 from twisted.internet.abstract import isIPAddress
 from twisted.python import failure
 
@@ -67,7 +66,7 @@ from synapse.replication.http.federation import (
 from synapse.storage.databases.main.lock import Lock
 from synapse.types import JsonDict, get_domain_from_id
 from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
-from synapse.util.async_helpers import Linearizer, concurrently_execute
+from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import parse_server_name
 
@@ -360,13 +359,13 @@ class FederationServer(FederationBase):
         # want to block things like to device messages from reaching clients
         # behind the potentially expensive handling of PDUs.
         pdu_results, _ = await make_deferred_yieldable(
-            defer.gatherResults(
-                [
+            gather_results(
+                (
                     run_in_background(
                         self._handle_pdus_in_txn, origin, transaction, request_time
                     ),
                     run_in_background(self._handle_edus_in_txn, origin, transaction),
-                ],
+                ),
                 consumeErrors=True,
             ).addErrback(unwrapFirstError)
         )
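Swapping the list for a tuple is what lets the new `gather_results` helper (defined in `synapse/util/async_helpers.py` below) keep per-element types: its overloads are keyed on tuple arity. A sketch with hypothetical deferreds standing in for the two handlers:

```python
from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable
from synapse.util.async_helpers import gather_results


async def sketch() -> None:
    d_pdus: "defer.Deferred[dict]" = defer.succeed({})
    d_edus: "defer.Deferred[None]" = defer.succeed(None)

    # mypy picks the two-element overload, inferring Tuple[dict, None],
    # so `pdu_results` keeps its type instead of degrading to Any.
    pdu_results, _ = await make_deferred_yieldable(
        gather_results((d_pdus, d_edus), consumeErrors=True)
    )
```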

+ 11 - 8
synapse/handlers/federation.py

@@ -360,31 +360,34 @@ class FederationHandler:
 
         logger.debug("calling resolve_state_groups in _maybe_backfill")
         resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
-        states = await make_deferred_yieldable(
+        states_list = await make_deferred_yieldable(
             defer.gatherResults(
                 [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
             )
         )
 
-        # dict[str, dict[tuple, str]], a map from event_id to state map of
-        # event_ids.
-        states = dict(zip(event_ids, [s.state for s in states]))
+        # A map from event_id to state map of event_ids.
+        state_ids: Dict[str, StateMap[str]] = dict(
+            zip(event_ids, [s.state for s in states_list])
+        )
 
         state_map = await self.store.get_events(
-            [e_id for ids in states.values() for e_id in ids.values()],
+            [e_id for ids in state_ids.values() for e_id in ids.values()],
             get_prev_content=False,
         )
-        states = {
+
+        # A map from event_id to state map of events.
+        state_events: Dict[str, StateMap[EventBase]] = {
             key: {
                 k: state_map[e_id]
                 for k, e_id in state_dict.items()
                 if e_id in state_map
             }
-            for key, state_dict in states.items()
+            for key, state_dict in state_ids.items()
         }
 
         for e_id in event_ids:
-            likely_extremeties_domains = get_domains_from_state(states[e_id])
+            likely_extremeties_domains = get_domains_from_state(state_events[e_id])
 
             success = await try_backfill(
                 [
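Splitting the reused `states` variable into `states_list`, `state_ids` and `state_events` is forced by mypy, which fixes a variable's type at its first binding. A reduced illustration (not Synapse code):

```python
from typing import Dict, List

states: List[str] = ["s1", "s2"]

# Re-binding `states` to a dict would be rejected:
#
#     states = {s: s for s in states}
#     error: Incompatible types in assignment (expression has type
#     "Dict[str, str]", variable has type "List[str]")
#
# hence the distinctly named, distinctly typed variables in the diff.
states_by_id: Dict[str, str] = {s: s for s in states}
```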

+ 19 - 14
synapse/handlers/initial_sync.py

@@ -13,21 +13,27 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
-
-from twisted.internet import defer
+from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
 from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import SynapseError
+from synapse.events import EventBase
 from synapse.events.validator import EventValidator
 from synapse.handlers.presence import format_user_presence_state
 from synapse.handlers.receipts import ReceiptEventSource
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage.roommember import RoomsForUser
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+    JsonDict,
+    Requester,
+    RoomStreamToken,
+    StateMap,
+    StreamToken,
+    UserID,
+)
 from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import concurrently_execute
+from synapse.util.async_helpers import concurrently_execute, gather_results
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.visibility import filter_events_for_client
 
@@ -190,14 +196,13 @@ class InitialSyncHandler:
                     )
                     deferred_room_state = run_in_background(
                         self.state_store.get_state_for_events, [event.event_id]
-                    )
-                    deferred_room_state.addCallback(
-                        lambda states: states[event.event_id]
+                    ).addCallback(
+                        lambda states: cast(StateMap[EventBase], states[event.event_id])
                     )
 
                 (messages, token), current_state = await make_deferred_yieldable(
-                    defer.gatherResults(
-                        [
+                    gather_results(
+                        (
                             run_in_background(
                                 self.store.get_recent_events_for_room,
                                 event.room_id,
@@ -205,7 +210,7 @@ class InitialSyncHandler:
                                 end_token=room_end_token,
                             ),
                             deferred_room_state,
-                        ]
+                        )
                     )
                 ).addErrback(unwrapFirstError)
 
@@ -454,8 +459,8 @@ class InitialSyncHandler:
             return receipts
 
         presence, receipts, (messages, token) = await make_deferred_yieldable(
-            defer.gatherResults(
-                [
+            gather_results(
+                (
                     run_in_background(get_presence),
                     run_in_background(get_receipts),
                     run_in_background(
@@ -464,7 +469,7 @@ class InitialSyncHandler:
                         limit=limit,
                         end_token=now_token.room_key,
                     ),
-                ],
+                ),
                 consumeErrors=True,
             ).addErrback(unwrapFirstError)
         )
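The `cast` inside the `addCallback` works around the loose typing of `Deferred.addCallback` in the twisted stubs: without it, the lambda's result would leave the deferred's eventual value typed imprecisely for the later `gather_results`. A rough sketch of the pattern, with hypothetical data:

```python
from typing import Dict, cast

from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable


async def sketch() -> None:
    # Stand-in for get_state_for_events: a map from event_id to state map.
    d: "defer.Deferred[Dict[str, Dict[str, int]]]" = defer.succeed(
        {"$event": {"key": 1}}
    )

    # addCallback is typed loosely, so the cast pins down the value type
    # that the transformed deferred will eventually carry.
    d2 = d.addCallback(lambda states: cast(Dict[str, int], states["$event"]))

    inner = await make_deferred_yieldable(d2)
```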

+ 6 - 7
synapse/handlers/message.py

@@ -21,7 +21,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
-from twisted.internet import defer
 from twisted.internet.interfaces import IDelayedCall
 
 from synapse import event_auth
@@ -57,7 +56,7 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
 from synapse.util import json_decoder, json_encoder, log_failure
-from synapse.util.async_helpers import Linearizer, unwrapFirstError
+from synapse.util.async_helpers import Linearizer, gather_results, unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.metrics import measure_func
 from synapse.visibility import filter_events_for_client
@@ -1168,9 +1167,9 @@ class EventCreationHandler:
 
         # We now persist the event (and update the cache in parallel, since we
         # don't want to block on it).
-        result = await make_deferred_yieldable(
-            defer.gatherResults(
-                [
+        result, _ = await make_deferred_yieldable(
+            gather_results(
+                (
                     run_in_background(
                         self._persist_event,
                         requester=requester,
@@ -1182,12 +1181,12 @@ class EventCreationHandler:
                     run_in_background(
                         self.cache_joined_hosts_for_event, event, context
                     ).addErrback(log_failure, "cache_joined_hosts_for_event failed"),
-                ],
+                ),
                 consumeErrors=True,
             )
         ).addErrback(unwrapFirstError)
 
-        return result[0]
+        return result
 
     async def _persist_event(
         self,
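Note the change in how the result is extracted: previously `result[0]` indexed the `List[Any]` from `defer.gatherResults`; the tuple unpack both discards the cache deferred's value explicitly and keeps the first element's type. A compact sketch with hypothetical deferreds:

```python
from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable
from synapse.util.async_helpers import gather_results


async def sketch() -> None:
    d_event: "defer.Deferred[str]" = defer.succeed("$event_id")
    d_cache: "defer.Deferred[None]" = defer.succeed(None)

    # Before: (await ...)[0] on a List[Any]; after: a typed unpack.
    result, _ = await make_deferred_yieldable(
        gather_results((d_event, d_cache), consumeErrors=True)
    )
```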

+ 5 - 2
synapse/http/federation/matrix_federation_agent.py

@@ -25,6 +25,7 @@ from zope.interface import implementer
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet.interfaces import (
+    IProtocol,
     IProtocolFactory,
     IReactorCore,
     IStreamClientEndpoint,
@@ -309,12 +310,14 @@ class MatrixHostnameEndpoint:
 
         self._srv_resolver = srv_resolver
 
-    def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
+    def connect(
+        self, protocol_factory: IProtocolFactory
+    ) -> "defer.Deferred[IProtocol]":
         """Implements IStreamClientEndpoint interface"""
 
         return run_in_background(self._do_connect, protocol_factory)
 
-    async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
+    async def _do_connect(self, protocol_factory: IProtocolFactory) -> IProtocol:
         first_exception = None
 
         server_list = await self._resolve_server()
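`connect` can now declare `Deferred[IProtocol]` because the `run_in_background` overloads (added in `synapse/logging/context.py` below) propagate `_do_connect`'s `IProtocol` return type. A hypothetical consumer, to show what the sharper annotation buys:

```python
from twisted.internet.interfaces import IProtocol, IProtocolFactory

from synapse.http.federation.matrix_federation_agent import (
    MatrixHostnameEndpoint,
)
from synapse.logging.context import make_deferred_yieldable


async def connect_and_use(
    endpoint: MatrixHostnameEndpoint, factory: IProtocolFactory
) -> IProtocol:
    # The awaited value is now known to be an IProtocol rather than Any.
    return await make_deferred_yieldable(endpoint.connect(factory))
```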

+ 103 - 46
synapse/logging/context.py

@@ -22,20 +22,33 @@ them.
 
 See doc/log_contexts.rst for details on how this works.
 """
-import inspect
 import logging
 import threading
 import typing
 import warnings
-from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
+from types import TracebackType
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    overload,
+)
 
 import attr
 from typing_extensions import Literal
 
 from twisted.internet import defer, threads
+from twisted.python.threadpool import ThreadPool
 
 if TYPE_CHECKING:
     from synapse.logging.scopecontextmanager import _LogContextScope
+    from synapse.types import ISynapseReactor
 
 logger = logging.getLogger(__name__)
 
@@ -66,7 +79,7 @@ except Exception:
 
 
 # a hook which can be set during testing to assert that we aren't abusing logcontexts.
-def logcontext_error(msg: str):
+def logcontext_error(msg: str) -> None:
     logger.warning(msg)
 
 
@@ -223,22 +236,19 @@ class _Sentinel:
     def __str__(self) -> str:
         return "sentinel"
 
-    def copy_to(self, record):
-        pass
-
-    def start(self, rusage: "Optional[resource.struct_rusage]"):
+    def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
         pass
 
-    def stop(self, rusage: "Optional[resource.struct_rusage]"):
+    def stop(self, rusage: "Optional[resource.struct_rusage]") -> None:
         pass
 
-    def add_database_transaction(self, duration_sec):
+    def add_database_transaction(self, duration_sec: float) -> None:
         pass
 
-    def add_database_scheduled(self, sched_sec):
+    def add_database_scheduled(self, sched_sec: float) -> None:
         pass
 
-    def record_event_fetch(self, event_count):
+    def record_event_fetch(self, event_count: int) -> None:
         pass
 
     def __bool__(self) -> Literal[False]:
@@ -379,7 +389,12 @@ class LoggingContext:
             )
         return self
 
-    def __exit__(self, type, value, traceback) -> None:
+    def __exit__(
+        self,
+        type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         """Restore the logging context in thread local storage to the state it
         was before this context was entered.
         Returns:
@@ -399,17 +414,6 @@ class LoggingContext:
         # recorded against the correct metrics.
         self.finished = True
 
-    def copy_to(self, record) -> None:
-        """Copy logging fields from this context to a log record or
-        another LoggingContext
-        """
-
-        # we track the current request
-        record.request = self.request
-
-        # we also track the current scope:
-        record.scope = self.scope
-
     def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
         """
         Record that this logcontext is currently running.
@@ -626,7 +630,12 @@ class PreserveLoggingContext:
     def __enter__(self) -> None:
         self._old_context = set_current_context(self._new_context)
 
-    def __exit__(self, type, value, traceback) -> None:
+    def __exit__(
+        self,
+        type: Optional[Type[BaseException]],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         context = set_current_context(self._old_context)
 
         if context != self._new_context:
@@ -711,16 +720,61 @@ def nested_logging_context(suffix: str) -> LoggingContext:
     )
 
 
-def preserve_fn(f):
+R = TypeVar("R")
+
+
+@overload
+def preserve_fn(  # type: ignore[misc]
+    f: Callable[..., Awaitable[R]],
+) -> Callable[..., "defer.Deferred[R]"]:
+    # The `type: ignore[misc]` above suppresses
+    # "Overloaded function signatures 1 and 2 overlap with incompatible return types"
+    ...
+
+
+@overload
+def preserve_fn(f: Callable[..., R]) -> Callable[..., "defer.Deferred[R]"]:
+    ...
+
+
+def preserve_fn(
+    f: Union[
+        Callable[..., R],
+        Callable[..., Awaitable[R]],
+    ]
+) -> Callable[..., "defer.Deferred[R]"]:
     """Function decorator which wraps the function with run_in_background"""
 
-    def g(*args, **kwargs):
+    def g(*args: Any, **kwargs: Any) -> "defer.Deferred[R]":
         return run_in_background(f, *args, **kwargs)
 
     return g
 
 
-def run_in_background(f, *args, **kwargs) -> defer.Deferred:
+@overload
+def run_in_background(  # type: ignore[misc]
+    f: Callable[..., Awaitable[R]], *args: Any, **kwargs: Any
+) -> "defer.Deferred[R]":
+    # The `type: ignore[misc]` above suppresses
+    # "Overloaded function signatures 1 and 2 overlap with incompatible return types"
+    ...
+
+
+@overload
+def run_in_background(
+    f: Callable[..., R], *args: Any, **kwargs: Any
+) -> "defer.Deferred[R]":
+    ...
+
+
+def run_in_background(
+    f: Union[
+        Callable[..., R],
+        Callable[..., Awaitable[R]],
+    ],
+    *args: Any,
+    **kwargs: Any,
+) -> "defer.Deferred[R]":
     """Calls a function, ensuring that the current context is restored after
     return from the function, and that the sentinel context is set once the
     deferred returned by the function completes.
@@ -751,6 +805,10 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
     # At this point we should have a Deferred, if not then f was a synchronous
     # function, wrap it in a Deferred for consistency.
     if not isinstance(res, defer.Deferred):
+        # `res` is not a `Deferred` and not a `Coroutine`.
+        # There are no other types of `Awaitable`s we expect to encounter in Synapse.
+        assert not isinstance(res, Awaitable)
+
         return defer.succeed(res)
 
     if res.called and not res.paused:
@@ -778,13 +836,14 @@ def run_in_background(f, *args, **kwargs) -> defer.Deferred:
     return res
 
 
-def make_deferred_yieldable(deferred):
-    """Given a deferred (or coroutine), make it follow the Synapse logcontext
-    rules:
+T = TypeVar("T")
+
 
-    If the deferred has completed (or is not actually a Deferred), essentially
-    does nothing (just returns another completed deferred with the
-    result/failure).
+def make_deferred_yieldable(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
+    """Given a deferred, make it follow the Synapse logcontext rules:
+
+    If the deferred has completed, essentially does nothing (just returns another
+    completed deferred with the result/failure).
 
     If the deferred has not yet completed, resets the logcontext before
     returning a deferred. Then, when the deferred completes, restores the
@@ -792,16 +851,6 @@ def make_deferred_yieldable(deferred):
 
     (This is more-or-less the opposite operation to run_in_background.)
     """
-    if inspect.isawaitable(deferred):
-        # If we're given a coroutine we convert it to a deferred so that we
-        # run it and find out if it immediately finishes, it it does then we
-        # don't need to fiddle with log contexts at all and can return
-        # immediately.
-        deferred = defer.ensureDeferred(deferred)
-
-    if not isinstance(deferred, defer.Deferred):
-        return deferred
-
     if deferred.called and not deferred.paused:
         # it looks like this deferred is ready to run any callbacks we give it
         # immediately. We may as well optimise out the logcontext faffery.
@@ -823,7 +872,9 @@ def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
     return result
 
 
-def defer_to_thread(reactor, f, *args, **kwargs):
+def defer_to_thread(
+    reactor: "ISynapseReactor", f: Callable[..., R], *args: Any, **kwargs: Any
+) -> "defer.Deferred[R]":
     """
     Calls the function `f` using a thread from the reactor's default threadpool and
     returns the result as a Deferred.
@@ -855,7 +906,13 @@ def defer_to_thread(reactor, f, *args, **kwargs):
     return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
 
 
-def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
+def defer_to_threadpool(
+    reactor: "ISynapseReactor",
+    threadpool: ThreadPool,
+    f: Callable[..., R],
+    *args: Any,
+    **kwargs: Any,
+) -> "defer.Deferred[R]":
     """
     A wrapper for twisted.internet.threads.deferToThreadpool, which handles
     logcontexts correctly.
@@ -897,7 +954,7 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         assert isinstance(curr_context, LoggingContext)
         parent_context = curr_context
 
-    def g():
+    def g() -> R:
         with LoggingContext(str(curr_context), parent_context=parent_context):
             return f(*args, **kwargs)
 

+ 56 - 1
synapse/util/async_helpers.py

@@ -30,9 +30,11 @@ from typing import (
     Iterator,
     Optional,
     Set,
+    Tuple,
     TypeVar,
     Union,
     cast,
+    overload,
 )
 
 import attr
@@ -234,6 +236,59 @@ def yieldable_gather_results(
     ).addErrback(unwrapFirstError)
 
 
+T1 = TypeVar("T1")
+T2 = TypeVar("T2")
+T3 = TypeVar("T3")
+
+
+@overload
+def gather_results(
+    deferredList: Tuple[()], consumeErrors: bool = ...
+) -> "defer.Deferred[Tuple[()]]":
+    ...
+
+
+@overload
+def gather_results(
+    deferredList: Tuple["defer.Deferred[T1]"],
+    consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1]]":
+    ...
+
+
+@overload
+def gather_results(
+    deferredList: Tuple["defer.Deferred[T1]", "defer.Deferred[T2]"],
+    consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1, T2]]":
+    ...
+
+
+@overload
+def gather_results(
+    deferredList: Tuple[
+        "defer.Deferred[T1]", "defer.Deferred[T2]", "defer.Deferred[T3]"
+    ],
+    consumeErrors: bool = ...,
+) -> "defer.Deferred[Tuple[T1, T2, T3]]":
+    ...
+
+
+def gather_results(  # type: ignore[misc]
+    deferredList: Tuple["defer.Deferred[T1]", ...],
+    consumeErrors: bool = False,
+) -> "defer.Deferred[Tuple[T1, ...]]":
+    """Combines a tuple of `Deferred`s into a single `Deferred`.
+
+    Wraps `defer.gatherResults` to provide type annotations that support heterogeneous
+    tuples of `Deferred`s.
+    """
+    # The `type: ignore[misc]` above suppresses
+    # "Overloaded function implementation cannot produce return type of signature 1/2/3"
+    deferred = defer.gatherResults(deferredList, consumeErrors=consumeErrors)
+    return deferred.addCallback(tuple)
+
+
 @attr.s(slots=True)
 class _LinearizerEntry:
     # The number of things executing.
@@ -352,7 +407,7 @@ class Linearizer:
 
         logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
 
-        new_defer = make_deferred_yieldable(defer.Deferred())
+        new_defer: "defer.Deferred[None]" = make_deferred_yieldable(defer.Deferred())
         entry.deferreds[new_defer] = 1
 
         def cb(_r: None) -> "defer.Deferred[None]":
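A quick demonstration of the helper (hypothetical deferreds): the overloads give a precise tuple type per arity, and the trailing `addCallback(tuple)` makes the runtime value match, since `defer.gatherResults` itself fires with a list.

```python
from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable
from synapse.util.async_helpers import gather_results


async def demo() -> None:
    d1: "defer.Deferred[int]" = defer.succeed(1)
    d2: "defer.Deferred[str]" = defer.succeed("two")

    # mypy selects the two-element overload: Deferred[Tuple[int, str]].
    pair = await make_deferred_yieldable(gather_results((d1, d2)))

    # defer.gatherResults fires with a list; the helper's
    # addCallback(tuple) converts it, so the value really is a tuple at
    # runtime too.
    assert pair == (1, "two")
```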

+ 1 - 0
synapse/util/caches/cached_call.py

@@ -76,6 +76,7 @@ class CachedCall(Generic[TV]):
 
         # Fire off the callable now if this is our first time
         if not self._deferred:
+            assert self._callable is not None
             self._deferred = run_in_background(self._callable)
 
             # we will never need the callable again, so make sure it can be GCed
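The `assert` narrows the `Optional` for mypy: the invariant that `_callable` is only cleared after `_deferred` is set is invisible to the type checker. The same narrowing appears in `synapse/util/file_consumer.py` below. A reduced illustration with a hypothetical class:

```python
from typing import Callable, Optional


class Once:
    def __init__(self, f: Callable[[], int]) -> None:
        self._callable: Optional[Callable[[], int]] = f

    def call(self) -> int:
        # Without this assert, mypy reports: error: "None" not callable.
        # The "set until first use" invariant is invisible to the checker.
        assert self._callable is not None
        return self._callable()
```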

+ 1 - 0
synapse/util/file_consumer.py

@@ -142,6 +142,7 @@ class BackgroundFileConsumer:
 
     def wait(self) -> "Deferred[None]":
         """Returns a deferred that resolves when finished writing to file"""
+        assert self._finished_deferred is not None
         return make_deferred_yieldable(self._finished_deferred)
 
     def _resume_paused_producer(self) -> None:

+ 0 - 35
tests/util/test_logcontext.py

@@ -152,46 +152,11 @@ class LoggingContextTestCase(unittest.TestCase):
             # now it should be restored
             self._check_test_key("one")
 
-    @defer.inlineCallbacks
-    def test_make_deferred_yieldable_on_non_deferred(self):
-        """Check that make_deferred_yieldable does the right thing when its
-        argument isn't actually a deferred"""
-
-        with LoggingContext("one"):
-            d1 = make_deferred_yieldable("bum")
-            self._check_test_key("one")
-
-            r = yield d1
-            self.assertEqual(r, "bum")
-            self._check_test_key("one")
-
     def test_nested_logging_context(self):
         with LoggingContext("foo"):
             nested_context = nested_logging_context(suffix="bar")
             self.assertEqual(nested_context.name, "foo-bar")
 
-    @defer.inlineCallbacks
-    def test_make_deferred_yieldable_with_await(self):
-        # an async function which returns an incomplete coroutine, but doesn't
-        # follow the synapse rules.
-
-        async def blocking_function():
-            d = defer.Deferred()
-            reactor.callLater(0, d.callback, None)
-            await d
-
-        sentinel_context = current_context()
-
-        with LoggingContext("one"):
-            d1 = make_deferred_yieldable(blocking_function())
-            # make sure that the context was reset by make_deferred_yieldable
-            self.assertIs(current_context(), sentinel_context)
-
-            yield d1
-
-            # now it should be restored
-            self._check_test_key("one")
-
 
 # a function which returns a deferred which has been "called", but
 # which had a function which returned another incomplete deferred on
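The two deleted tests exercised `make_deferred_yieldable` on a plain value and on a coroutine; with the function narrowed to accept only `Deferred`s (and the `inspect.isawaitable` shim removed), those call shapes become type errors rather than supported runtime paths. A caller holding a coroutine now converts explicitly, for example:

```python
from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable


async def work() -> int:
    return 1


async def caller() -> int:
    # Wrap the coroutine in a Deferred first; make_deferred_yieldable no
    # longer does this implicitly.
    return await make_deferred_yieldable(defer.ensureDeferred(work()))
```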