Browse Source

Add missing types to opentracing. (#13345)

After this change `synapse.logging` is fully typed.
Patrick Cloke 1 year ago
parent
commit
50122754c8

+ 1 - 1
changelog.d/13328.misc

@@ -1 +1 @@
-Add type hints to `trace` decorator.
+Add missing type hints to open tracing module.

+ 1 - 0
changelog.d/13345.misc

@@ -0,0 +1 @@
+Add missing type hints to open tracing module.

+ 0 - 3
mypy.ini

@@ -84,9 +84,6 @@ disallow_untyped_defs = False
 [mypy-synapse.http.matrixfederationclient]
 disallow_untyped_defs = False
 
-[mypy-synapse.logging.opentracing]
-disallow_untyped_defs = False
-
 [mypy-synapse.metrics._reactor_metrics]
 disallow_untyped_defs = False
 # This module imports select.epoll. That exists on Linux, but doesn't on macOS.

+ 1 - 1
synapse/federation/transport/server/_base.py

@@ -309,7 +309,7 @@ class BaseFederationServlet:
                 raise
 
             # update the active opentracing span with the authenticated entity
-            set_tag("authenticated_entity", origin)
+            set_tag("authenticated_entity", str(origin))
 
             # if the origin is authenticated and whitelisted, use its span context
             # as the parent.

+ 4 - 4
synapse/handlers/device.py

@@ -118,8 +118,8 @@ class DeviceWorkerHandler:
         ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
         _update_device_from_client_ips(device, ips)
 
-        set_tag("device", device)
-        set_tag("ips", ips)
+        set_tag("device", str(device))
+        set_tag("ips", str(ips))
 
         return device
 
@@ -170,7 +170,7 @@ class DeviceWorkerHandler:
         """
 
         set_tag("user_id", user_id)
-        set_tag("from_token", from_token)
+        set_tag("from_token", str(from_token))
         now_room_key = self.store.get_room_max_token()
 
         room_ids = await self.store.get_rooms_for_user(user_id)
@@ -795,7 +795,7 @@ class DeviceListUpdater:
         """
 
         set_tag("origin", origin)
-        set_tag("edu_content", edu_content)
+        set_tag("edu_content", str(edu_content))
         user_id = edu_content.pop("user_id")
         device_id = edu_content.pop("device_id")
         stream_id = str(edu_content.pop("stream_id"))  # They may come as ints

+ 8 - 8
synapse/handlers/e2e_keys.py

@@ -138,8 +138,8 @@ class E2eKeysHandler:
                 else:
                     remote_queries[user_id] = device_ids
 
-            set_tag("local_key_query", local_query)
-            set_tag("remote_key_query", remote_queries)
+            set_tag("local_key_query", str(local_query))
+            set_tag("remote_key_query", str(remote_queries))
 
             # First get local devices.
             # A map of destination -> failure response.
@@ -343,7 +343,7 @@ class E2eKeysHandler:
             failure = _exception_to_failure(e)
             failures[destination] = failure
             set_tag("error", True)
-            set_tag("reason", failure)
+            set_tag("reason", str(failure))
 
         return
 
@@ -405,7 +405,7 @@ class E2eKeysHandler:
         Returns:
             A map from user_id -> device_id -> device details
         """
-        set_tag("local_query", query)
+        set_tag("local_query", str(query))
         local_query: List[Tuple[str, Optional[str]]] = []
 
         result_dict: Dict[str, Dict[str, dict]] = {}
@@ -477,8 +477,8 @@ class E2eKeysHandler:
                 domain = get_domain_from_id(user_id)
                 remote_queries.setdefault(domain, {})[user_id] = one_time_keys
 
-        set_tag("local_key_query", local_query)
-        set_tag("remote_key_query", remote_queries)
+        set_tag("local_key_query", str(local_query))
+        set_tag("remote_key_query", str(remote_queries))
 
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
@@ -508,7 +508,7 @@ class E2eKeysHandler:
                 failure = _exception_to_failure(e)
                 failures[destination] = failure
                 set_tag("error", True)
-                set_tag("reason", failure)
+                set_tag("reason", str(failure))
 
         await make_deferred_yieldable(
             defer.gatherResults(
@@ -611,7 +611,7 @@ class E2eKeysHandler:
 
         result = await self.store.count_e2e_one_time_keys(user_id, device_id)
 
-        set_tag("one_time_key_counts", result)
+        set_tag("one_time_key_counts", str(result))
         return {"one_time_key_counts": result}
 
     async def _upload_one_time_keys_for_user(

+ 2 - 2
synapse/handlers/e2e_room_keys.py

@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, Optional, cast
 
 from typing_extensions import Literal
 
@@ -97,7 +97,7 @@ class E2eRoomKeysHandler:
                 user_id, version, room_id, session_id
             )
 
-            log_kv(results)
+            log_kv(cast(JsonDict, results))
             return results
 
     @trace

+ 35 - 9
synapse/logging/opentracing.py

@@ -182,6 +182,8 @@ from typing import (
     Type,
     TypeVar,
     Union,
+    cast,
+    overload,
 )
 
 import attr
@@ -328,6 +330,7 @@ class _Sentinel(enum.Enum):
 
 P = ParamSpec("P")
 R = TypeVar("R")
+T = TypeVar("T")
 
 
 def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
@@ -343,22 +346,43 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
     return _only_if_tracing_inner
 
 
-def ensure_active_span(message: str, ret=None):
+@overload
+def ensure_active_span(
+    message: str,
+) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
+    ...
+
+
+@overload
+def ensure_active_span(
+    message: str, ret: T
+) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
+    ...
+
+
+def ensure_active_span(
+    message: str, ret: Optional[T] = None
+) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
     """Executes the operation only if opentracing is enabled and there is an active span.
     If there is no active span it logs message at the error level.
 
     Args:
         message: Message which fills in "There was no active span when trying to %s"
             in the error log if there is no active span and opentracing is enabled.
-        ret (object): return value if opentracing is None or there is no active span.
+        ret: return value if opentracing is None or there is no active span.
 
-    Returns (object): The result of the func or ret if opentracing is disabled or there
+    Returns:
+        The result of the func, falling back to ret if opentracing is disabled or there
         was no active span.
     """
 
-    def ensure_active_span_inner_1(func):
+    def ensure_active_span_inner_1(
+        func: Callable[P, R]
+    ) -> Callable[P, Union[Optional[T], R]]:
         @wraps(func)
-        def ensure_active_span_inner_2(*args, **kwargs):
+        def ensure_active_span_inner_2(
+            *args: P.args, **kwargs: P.kwargs
+        ) -> Union[Optional[T], R]:
             if not opentracing:
                 return ret
 
@@ -464,7 +488,7 @@ def start_active_span(
     finish_on_close: bool = True,
     *,
     tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
     """Starts an active opentracing span.
 
     Records the start time for the span, and sets it as the "active span" in the
@@ -502,7 +526,7 @@ def start_active_span_follows_from(
     *,
     inherit_force_tracing: bool = False,
     tracer: Optional["opentracing.Tracer"] = None,
-):
+) -> "opentracing.Scope":
     """Starts an active opentracing span, with additional references to previous spans
 
     Args:
@@ -717,7 +741,9 @@ def inject_response_headers(response_headers: Headers) -> None:
         response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
 
 
-@ensure_active_span("get the active span context as a dict", ret={})
+@ensure_active_span(
+    "get the active span context as a dict", ret=cast(Dict[str, str], {})
+)
 def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
     """
     Gets a span context as a dict. This can be used instead of manually
@@ -886,7 +912,7 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
         for i, arg in enumerate(argspec.args[1:]):
             set_tag("ARG_" + arg, args[i])  # type: ignore[index]
         set_tag("args", args[len(argspec.args) :])  # type: ignore[index]
-        set_tag("kwargs", kwargs)
+        set_tag("kwargs", str(kwargs))
         return func(*args, **kwargs)
 
     return _tag_args_inner

+ 1 - 1
synapse/metrics/background_process_metrics.py

@@ -235,7 +235,7 @@ def run_as_background_process(
                         f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
                     )
                 else:
-                    ctx = nullcontext()
+                    ctx = nullcontext()  # type: ignore[assignment]
                 with ctx:
                     return await func(*args, **kwargs)
             except Exception:

+ 3 - 1
synapse/rest/client/keys.py

@@ -208,7 +208,9 @@ class KeyChangesServlet(RestServlet):
 
         # We want to enforce they do pass us one, but we ignore it and return
         # changes after the "to" as well as before.
-        set_tag("to", parse_string(request, "to"))
+        #
+        # XXX This does not enforce that "to" is passed.
+        set_tag("to", str(parse_string(request, "to")))
 
         from_token = await StreamToken.from_string(self.store, from_token_string)
 

+ 1 - 1
synapse/storage/databases/main/deviceinbox.py

@@ -436,7 +436,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             (user_id, device_id), None
         )
 
-        set_tag("last_deleted_stream_id", last_deleted_stream_id)
+        set_tag("last_deleted_stream_id", str(last_deleted_stream_id))
 
         if last_deleted_stream_id:
             has_changed = self._device_inbox_stream_cache.has_entity_changed(

+ 2 - 2
synapse/storage/databases/main/devices.py

@@ -706,8 +706,8 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
             else:
                 results[user_id] = await self.get_cached_devices_for_user(user_id)
 
-        set_tag("in_cache", results)
-        set_tag("not_in_cache", user_ids_not_in_cache)
+        set_tag("in_cache", str(results))
+        set_tag("not_in_cache", str(user_ids_not_in_cache))
 
         return user_ids_not_in_cache, results
 

+ 3 - 3
synapse/storage/databases/main/end_to_end_keys.py

@@ -146,7 +146,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             key data.  The key data will be a dict in the same format as the
             DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
         """
-        set_tag("query_list", query_list)
+        set_tag("query_list", str(query_list))
         if not query_list:
             return {}
 
@@ -418,7 +418,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
-            set_tag("new_keys", new_keys)
+            set_tag("new_keys", str(new_keys))
             # We are protected from race between lookup and insertion due to
             # a unique constraint. If there is a race of two calls to
             # `add_e2e_one_time_keys` then they'll conflict and we will only
@@ -1161,7 +1161,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             set_tag("user_id", user_id)
             set_tag("device_id", device_id)
             set_tag("time_now", time_now)
-            set_tag("device_keys", device_keys)
+            set_tag("device_keys", str(device_keys))
 
             old_key_json = self.db_pool.simple_select_one_onecol_txn(
                 txn,

+ 21 - 9
tests/logging/test_opentracing.py

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import cast
+
 from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactorClock
 
@@ -40,6 +42,15 @@ from tests.unittest import TestCase
 
 
 class LogContextScopeManagerTestCase(TestCase):
+    """
+    Test logging contexts and active opentracing spans.
+
+    There's casts throughout this from generic opentracing objects (e.g.
+    opentracing.Span) to the ones specific to Jaeger since they have additional
+    properties that these tests depend on. This is safe since the only supported
+    opentracing backend is Jaeger.
+    """
+
     if LogContextScopeManager is None:
         skip = "Requires opentracing"  # type: ignore[unreachable]
     if jaeger_client is None:
@@ -69,7 +80,7 @@ class LogContextScopeManagerTestCase(TestCase):
 
             # start_active_span should start and activate a span.
             scope = start_active_span("span", tracer=self._tracer)
-            span = scope.span
+            span = cast(jaeger_client.Span, scope.span)
             self.assertEqual(self._tracer.active_span, span)
             self.assertIsNotNone(span.start_time)
 
@@ -91,6 +102,7 @@ class LogContextScopeManagerTestCase(TestCase):
         with LoggingContext("root context"):
             with start_active_span("root span", tracer=self._tracer) as root_scope:
                 self.assertEqual(self._tracer.active_span, root_scope.span)
+                root_context = cast(jaeger_client.SpanContext, root_scope.span.context)
 
                 scope1 = start_active_span(
                     "child1",
@@ -99,9 +111,8 @@ class LogContextScopeManagerTestCase(TestCase):
                 self.assertEqual(
                     self._tracer.active_span, scope1.span, "child1 was not activated"
                 )
-                self.assertEqual(
-                    scope1.span.context.parent_id, root_scope.span.context.span_id
-                )
+                context1 = cast(jaeger_client.SpanContext, scope1.span.context)
+                self.assertEqual(context1.parent_id, root_context.span_id)
 
                 scope2 = start_active_span_follows_from(
                     "child2",
@@ -109,17 +120,18 @@ class LogContextScopeManagerTestCase(TestCase):
                     tracer=self._tracer,
                 )
                 self.assertEqual(self._tracer.active_span, scope2.span)
-                self.assertEqual(
-                    scope2.span.context.parent_id, scope1.span.context.span_id
-                )
+                context2 = cast(jaeger_client.SpanContext, scope2.span.context)
+                self.assertEqual(context2.parent_id, context1.span_id)
 
                 with scope1, scope2:
                     pass
 
                 # the root scope should be restored
                 self.assertEqual(self._tracer.active_span, root_scope.span)
-                self.assertIsNotNone(scope2.span.end_time)
-                self.assertIsNotNone(scope1.span.end_time)
+                span2 = cast(jaeger_client.Span, scope2.span)
+                span1 = cast(jaeger_client.Span, scope1.span)
+                self.assertIsNotNone(span2.end_time)
+                self.assertIsNotNone(span1.end_time)
 
             self.assertIsNone(self._tracer.active_span)