瀏覽代碼

Add most missing type hints to synapse.util (#11328)

Patrick Cloke 2 年之前
父節點
當前提交
7468723697

+ 1 - 0
changelog.d/11328.misc

@@ -0,0 +1 @@
+Add type hints to `synapse.util`.

+ 3 - 84
mypy.ini

@@ -196,92 +196,11 @@ disallow_untyped_defs = True
 [mypy-synapse.streams.*]
 [mypy-synapse.streams.*]
 disallow_untyped_defs = True
 disallow_untyped_defs = True
 
 
-[mypy-synapse.util.batching_queue]
+[mypy-synapse.util.*]
 disallow_untyped_defs = True
 disallow_untyped_defs = True
 
 
-[mypy-synapse.util.caches.cached_call]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.caches.dictionary_cache]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.caches.lrucache]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.caches.response_cache]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.caches.stream_change_cache]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.caches.ttl_cache]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.daemonize]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.file_consumer]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.frozenutils]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.hash]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.httpresourcetree]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.iterutils]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.linked_list]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.logcontext]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.logformatter]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.macaroons]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.manhole]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.module_loader]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.msisdn]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.patch_inline_callbacks]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.ratelimitutils]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.retryutils]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.rlimit]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.stringutils]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.templates]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.threepids]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.wheel_timer]
-disallow_untyped_defs = True
-
-[mypy-synapse.util.versionstring]
-disallow_untyped_defs = True
+[mypy-synapse.util.caches.treecache]
+disallow_untyped_defs = False
 
 
 [mypy-tests.handlers.test_user_directory]
 [mypy-tests.handlers.test_user_directory]
 disallow_untyped_defs = True
 disallow_untyped_defs = True

+ 16 - 16
synapse/util/async_helpers.py

@@ -27,6 +27,7 @@ from typing import (
     Generic,
     Generic,
     Hashable,
     Hashable,
     Iterable,
     Iterable,
+    Iterator,
     Optional,
     Optional,
     Set,
     Set,
     TypeVar,
     TypeVar,
@@ -40,7 +41,6 @@ from typing_extensions import ContextManager
 from twisted.internet import defer
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
 from twisted.internet.defer import CancelledError
 from twisted.internet.interfaces import IReactorTime
 from twisted.internet.interfaces import IReactorTime
-from twisted.python import failure
 from twisted.python.failure import Failure
 from twisted.python.failure import Failure
 
 
 from synapse.logging.context import (
 from synapse.logging.context import (
@@ -78,7 +78,7 @@ class ObservableDeferred(Generic[_T]):
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_observers", [])
         object.__setattr__(self, "_observers", [])
 
 
-        def callback(r):
+        def callback(r: _T) -> _T:
             object.__setattr__(self, "_result", (True, r))
             object.__setattr__(self, "_result", (True, r))
 
 
             # once we have set _result, no more entries will be added to _observers,
             # once we have set _result, no more entries will be added to _observers,
@@ -98,7 +98,7 @@ class ObservableDeferred(Generic[_T]):
                     )
                     )
             return r
             return r
 
 
-        def errback(f):
+        def errback(f: Failure) -> Optional[Failure]:
             object.__setattr__(self, "_result", (False, f))
             object.__setattr__(self, "_result", (False, f))
 
 
             # once we have set _result, no more entries will be added to _observers,
             # once we have set _result, no more entries will be added to _observers,
@@ -109,7 +109,7 @@ class ObservableDeferred(Generic[_T]):
             for observer in observers:
             for observer in observers:
                 # This is a little bit of magic to correctly propagate stack
                 # This is a little bit of magic to correctly propagate stack
                 # traces when we `await` on one of the observer deferreds.
                 # traces when we `await` on one of the observer deferreds.
-                f.value.__failure__ = f
+                f.value.__failure__ = f  # type: ignore[union-attr]
                 try:
                 try:
                     observer.errback(f)
                     observer.errback(f)
                 except Exception as e:
                 except Exception as e:
@@ -314,7 +314,7 @@ class Linearizer:
         # will release the lock.
         # will release the lock.
 
 
         @contextmanager
         @contextmanager
-        def _ctx_manager(_):
+        def _ctx_manager(_: None) -> Iterator[None]:
             try:
             try:
                 yield
                 yield
             finally:
             finally:
@@ -355,7 +355,7 @@ class Linearizer:
         new_defer = make_deferred_yieldable(defer.Deferred())
         new_defer = make_deferred_yieldable(defer.Deferred())
         entry.deferreds[new_defer] = 1
         entry.deferreds[new_defer] = 1
 
 
-        def cb(_r):
+        def cb(_r: None) -> "defer.Deferred[None]":
             logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
             logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
             entry.count += 1
             entry.count += 1
 
 
@@ -371,7 +371,7 @@ class Linearizer:
             # code must be synchronous, so this is the only sensible place.)
             # code must be synchronous, so this is the only sensible place.)
             return self._clock.sleep(0)
             return self._clock.sleep(0)
 
 
-        def eb(e):
+        def eb(e: Failure) -> Failure:
             logger.info("defer %r got err %r", new_defer, e)
             logger.info("defer %r got err %r", new_defer, e)
             if isinstance(e, CancelledError):
             if isinstance(e, CancelledError):
                 logger.debug(
                 logger.debug(
@@ -435,7 +435,7 @@ class ReadWriteLock:
             await make_deferred_yieldable(curr_writer)
             await make_deferred_yieldable(curr_writer)
 
 
         @contextmanager
         @contextmanager
-        def _ctx_manager():
+        def _ctx_manager() -> Iterator[None]:
             try:
             try:
                 yield
                 yield
             finally:
             finally:
@@ -464,7 +464,7 @@ class ReadWriteLock:
         await make_deferred_yieldable(defer.gatherResults(to_wait_on))
         await make_deferred_yieldable(defer.gatherResults(to_wait_on))
 
 
         @contextmanager
         @contextmanager
-        def _ctx_manager():
+        def _ctx_manager() -> Iterator[None]:
             try:
             try:
                 yield
                 yield
             finally:
             finally:
@@ -524,7 +524,7 @@ def timeout_deferred(
 
 
     delayed_call = reactor.callLater(timeout, time_it_out)
     delayed_call = reactor.callLater(timeout, time_it_out)
 
 
-    def convert_cancelled(value: failure.Failure):
+    def convert_cancelled(value: Failure) -> Failure:
         # if the original deferred was cancelled, and our timeout has fired, then
         # if the original deferred was cancelled, and our timeout has fired, then
         # the reason it was cancelled was due to our timeout. Turn the CancelledError
         # the reason it was cancelled was due to our timeout. Turn the CancelledError
         # into a TimeoutError.
         # into a TimeoutError.
@@ -534,7 +534,7 @@ def timeout_deferred(
 
 
     deferred.addErrback(convert_cancelled)
     deferred.addErrback(convert_cancelled)
 
 
-    def cancel_timeout(result):
+    def cancel_timeout(result: _T) -> _T:
         # stop the pending call to cancel the deferred if it's been fired
         # stop the pending call to cancel the deferred if it's been fired
         if delayed_call.active():
         if delayed_call.active():
             delayed_call.cancel()
             delayed_call.cancel()
@@ -542,11 +542,11 @@ def timeout_deferred(
 
 
     deferred.addBoth(cancel_timeout)
     deferred.addBoth(cancel_timeout)
 
 
-    def success_cb(val):
+    def success_cb(val: _T) -> None:
         if not new_d.called:
         if not new_d.called:
             new_d.callback(val)
             new_d.callback(val)
 
 
-    def failure_cb(val):
+    def failure_cb(val: Failure) -> None:
         if not new_d.called:
         if not new_d.called:
             new_d.errback(val)
             new_d.errback(val)
 
 
@@ -557,13 +557,13 @@ def timeout_deferred(
 
 
 # This class can't be generic because it uses slots with attrs.
 # This class can't be generic because it uses slots with attrs.
 # See: https://github.com/python-attrs/attrs/issues/313
 # See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class DoneAwaitable:  # should be: Generic[R]
 class DoneAwaitable:  # should be: Generic[R]
     """Simple awaitable that returns the provided value."""
     """Simple awaitable that returns the provided value."""
 
 
-    value = attr.ib(type=Any)  # should be: R
+    value: Any  # should be: R
 
 
-    def __await__(self):
+    def __await__(self) -> Any:
         return self
         return self
 
 
     def __iter__(self) -> "DoneAwaitable":
     def __iter__(self) -> "DoneAwaitable":

+ 17 - 15
synapse/util/caches/__init__.py

@@ -17,7 +17,7 @@ import logging
 import typing
 import typing
 from enum import Enum, auto
 from enum import Enum, auto
 from sys import intern
 from sys import intern
-from typing import Callable, Dict, Optional, Sized
+from typing import Any, Callable, Dict, List, Optional, Sized
 
 
 import attr
 import attr
 from prometheus_client.core import Gauge
 from prometheus_client.core import Gauge
@@ -58,20 +58,20 @@ class EvictionReason(Enum):
     time = auto()
     time = auto()
 
 
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class CacheMetric:
 class CacheMetric:
 
 
-    _cache = attr.ib()
-    _cache_type = attr.ib(type=str)
-    _cache_name = attr.ib(type=str)
-    _collect_callback = attr.ib(type=Optional[Callable])
+    _cache: Sized
+    _cache_type: str
+    _cache_name: str
+    _collect_callback: Optional[Callable]
 
 
-    hits = attr.ib(default=0)
-    misses = attr.ib(default=0)
+    hits: int = 0
+    misses: int = 0
     eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
     eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
         factory=collections.Counter
         factory=collections.Counter
     )
     )
-    memory_usage = attr.ib(default=None)
+    memory_usage: Optional[int] = None
 
 
     def inc_hits(self) -> None:
     def inc_hits(self) -> None:
         self.hits += 1
         self.hits += 1
@@ -89,13 +89,14 @@ class CacheMetric:
         self.memory_usage += memory
         self.memory_usage += memory
 
 
     def dec_memory_usage(self, memory: int) -> None:
     def dec_memory_usage(self, memory: int) -> None:
+        assert self.memory_usage is not None
         self.memory_usage -= memory
         self.memory_usage -= memory
 
 
     def clear_memory_usage(self) -> None:
     def clear_memory_usage(self) -> None:
         if self.memory_usage is not None:
         if self.memory_usage is not None:
             self.memory_usage = 0
             self.memory_usage = 0
 
 
-    def describe(self):
+    def describe(self) -> List[str]:
         return []
         return []
 
 
     def collect(self) -> None:
     def collect(self) -> None:
@@ -118,8 +119,9 @@ class CacheMetric:
                         self.eviction_size_by_reason[reason]
                         self.eviction_size_by_reason[reason]
                     )
                     )
                 cache_total.labels(self._cache_name).set(self.hits + self.misses)
                 cache_total.labels(self._cache_name).set(self.hits + self.misses)
-                if getattr(self._cache, "max_size", None):
-                    cache_max_size.labels(self._cache_name).set(self._cache.max_size)
+                max_size = getattr(self._cache, "max_size", None)
+                if max_size:
+                    cache_max_size.labels(self._cache_name).set(max_size)
 
 
                 if TRACK_MEMORY_USAGE:
                 if TRACK_MEMORY_USAGE:
                     # self.memory_usage can be None if nothing has been inserted
                     # self.memory_usage can be None if nothing has been inserted
@@ -193,7 +195,7 @@ KNOWN_KEYS = {
 }
 }
 
 
 
 
-def intern_string(string):
+def intern_string(string: Optional[str]) -> Optional[str]:
     """Takes a (potentially) unicode string and interns it if it's ascii"""
     """Takes a (potentially) unicode string and interns it if it's ascii"""
     if string is None:
     if string is None:
         return None
         return None
@@ -204,7 +206,7 @@ def intern_string(string):
         return string
         return string
 
 
 
 
-def intern_dict(dictionary):
+def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:
     """Takes a dictionary and interns well known keys and their values"""
     """Takes a dictionary and interns well known keys and their values"""
     return {
     return {
         KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
         KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
@@ -212,7 +214,7 @@ def intern_dict(dictionary):
     }
     }
 
 
 
 
-def _intern_known_values(key, value):
+def _intern_known_values(key: str, value: Any) -> Any:
     intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
     intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
 
 
     if key in intern_keys:
     if key in intern_keys:

+ 1 - 1
synapse/util/caches/deferred_cache.py

@@ -289,7 +289,7 @@ class DeferredCache(Generic[KT, VT]):
         callbacks = [callback] if callback else []
         callbacks = [callback] if callback else []
         self.cache.set(key, value, callbacks=callbacks)
         self.cache.set(key, value, callbacks=callbacks)
 
 
-    def invalidate(self, key) -> None:
+    def invalidate(self, key: KT) -> None:
         """Delete a key, or tree of entries
         """Delete a key, or tree of entries
 
 
         If the cache is backed by a regular dict, then "key" must be of
         If the cache is backed by a regular dict, then "key" must be of

+ 42 - 25
synapse/util/caches/descriptors.py

@@ -19,12 +19,15 @@ import logging
 from typing import (
 from typing import (
     Any,
     Any,
     Callable,
     Callable,
+    Dict,
     Generic,
     Generic,
+    Hashable,
     Iterable,
     Iterable,
     Mapping,
     Mapping,
     Optional,
     Optional,
     Sequence,
     Sequence,
     Tuple,
     Tuple,
+    Type,
     TypeVar,
     TypeVar,
     Union,
     Union,
     cast,
     cast,
@@ -32,6 +35,7 @@ from typing import (
 from weakref import WeakValueDictionary
 from weakref import WeakValueDictionary
 
 
 from twisted.internet import defer
 from twisted.internet import defer
+from twisted.python.failure import Failure
 
 
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
 from synapse.util import unwrapFirstError
@@ -60,7 +64,12 @@ class _CachedFunction(Generic[F]):
 
 
 
 
 class _CacheDescriptorBase:
 class _CacheDescriptorBase:
-    def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
+    def __init__(
+        self,
+        orig: Callable[..., Any],
+        num_args: Optional[int],
+        cache_context: bool = False,
+    ):
         self.orig = orig
         self.orig = orig
 
 
         arg_spec = inspect.getfullargspec(orig)
         arg_spec = inspect.getfullargspec(orig)
@@ -172,14 +181,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        orig,
+        orig: Callable[..., Any],
         max_entries: int = 1000,
         max_entries: int = 1000,
         cache_context: bool = False,
         cache_context: bool = False,
     ):
     ):
         super().__init__(orig, num_args=None, cache_context=cache_context)
         super().__init__(orig, num_args=None, cache_context=cache_context)
         self.max_entries = max_entries
         self.max_entries = max_entries
 
 
-    def __get__(self, obj, owner):
+    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: LruCache[CacheKey, Any] = LruCache(
         cache: LruCache[CacheKey, Any] = LruCache(
             cache_name=self.orig.__name__,
             cache_name=self.orig.__name__,
             max_size=self.max_entries,
             max_size=self.max_entries,
@@ -189,7 +198,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
         sentinel = LruCacheDescriptor._Sentinel.sentinel
         sentinel = LruCacheDescriptor._Sentinel.sentinel
 
 
         @functools.wraps(self.orig)
         @functools.wraps(self.orig)
-        def _wrapped(*args, **kwargs):
+        def _wrapped(*args: Any, **kwargs: Any) -> Any:
             invalidate_callback = kwargs.pop("on_invalidate", None)
             invalidate_callback = kwargs.pop("on_invalidate", None)
             callbacks = (invalidate_callback,) if invalidate_callback else ()
             callbacks = (invalidate_callback,) if invalidate_callback else ()
 
 
@@ -245,19 +254,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
             return r1 + r2
             return r1 + r2
 
 
     Args:
     Args:
-        num_args (int): number of positional arguments (excluding ``self`` and
+        num_args: number of positional arguments (excluding ``self`` and
             ``cache_context``) to use as cache keys. Defaults to all named
             ``cache_context``) to use as cache keys. Defaults to all named
             args of the function.
             args of the function.
     """
     """
 
 
     def __init__(
     def __init__(
         self,
         self,
-        orig,
-        max_entries=1000,
-        num_args=None,
-        tree=False,
-        cache_context=False,
-        iterable=False,
+        orig: Callable[..., Any],
+        max_entries: int = 1000,
+        num_args: Optional[int] = None,
+        tree: bool = False,
+        cache_context: bool = False,
+        iterable: bool = False,
         prune_unread_entries: bool = True,
         prune_unread_entries: bool = True,
     ):
     ):
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
@@ -272,7 +281,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         self.iterable = iterable
         self.iterable = iterable
         self.prune_unread_entries = prune_unread_entries
         self.prune_unread_entries = prune_unread_entries
 
 
-    def __get__(self, obj, owner):
+    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
             name=self.orig.__name__,
             name=self.orig.__name__,
             max_entries=self.max_entries,
             max_entries=self.max_entries,
@@ -284,7 +293,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         get_cache_key = self.cache_key_builder
         get_cache_key = self.cache_key_builder
 
 
         @functools.wraps(self.orig)
         @functools.wraps(self.orig)
-        def _wrapped(*args, **kwargs):
+        def _wrapped(*args: Any, **kwargs: Any) -> Any:
             # If we're passed a cache_context then we'll want to call its invalidate()
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -335,13 +344,19 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
     of results.
     of results.
     """
     """
 
 
-    def __init__(self, orig, cached_method_name, list_name, num_args=None):
+    def __init__(
+        self,
+        orig: Callable[..., Any],
+        cached_method_name: str,
+        list_name: str,
+        num_args: Optional[int] = None,
+    ):
         """
         """
         Args:
         Args:
-            orig (function)
-            cached_method_name (str): The name of the cached method.
-            list_name (str): Name of the argument which is the bulk lookup list
-            num_args (int): number of positional arguments (excluding ``self``,
+            orig
+            cached_method_name: The name of the cached method.
+            list_name: Name of the argument which is the bulk lookup list
+            num_args: number of positional arguments (excluding ``self``,
                 but including list_name) to use as cache keys. Defaults to all
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
                 named args of the function.
         """
         """
@@ -360,13 +375,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 % (self.list_name, cached_method_name)
                 % (self.list_name, cached_method_name)
             )
             )
 
 
-    def __get__(self, obj, objtype=None):
+    def __get__(
+        self, obj: Optional[Any], objtype: Optional[Type] = None
+    ) -> Callable[..., Any]:
         cached_method = getattr(obj, self.cached_method_name)
         cached_method = getattr(obj, self.cached_method_name)
         cache: DeferredCache[CacheKey, Any] = cached_method.cache
         cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args
         num_args = cached_method.num_args
 
 
         @functools.wraps(self.orig)
         @functools.wraps(self.orig)
-        def wrapped(*args, **kwargs):
+        def wrapped(*args: Any, **kwargs: Any) -> Any:
             # If we're passed a cache_context then we'll want to call its
             # If we're passed a cache_context then we'll want to call its
             # invalidate() whenever we are invalidated
             # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -377,7 +394,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
 
 
             results = {}
             results = {}
 
 
-            def update_results_dict(res, arg):
+            def update_results_dict(res: Any, arg: Hashable) -> None:
                 results[arg] = res
                 results[arg] = res
 
 
             # list of deferreds to wait for
             # list of deferreds to wait for
@@ -389,13 +406,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             # otherwise a tuple is used.
             # otherwise a tuple is used.
             if num_args == 1:
             if num_args == 1:
 
 
-                def arg_to_cache_key(arg):
+                def arg_to_cache_key(arg: Hashable) -> Hashable:
                     return arg
                     return arg
 
 
             else:
             else:
                 keylist = list(keyargs)
                 keylist = list(keyargs)
 
 
-                def arg_to_cache_key(arg):
+                def arg_to_cache_key(arg: Hashable) -> Hashable:
                     keylist[self.list_pos] = arg
                     keylist[self.list_pos] = arg
                     return tuple(keylist)
                     return tuple(keylist)
 
 
@@ -421,7 +438,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     key = arg_to_cache_key(arg)
                     key = arg_to_cache_key(arg)
                     cache.set(key, deferred, callback=invalidate_callback)
                     cache.set(key, deferred, callback=invalidate_callback)
 
 
-                def complete_all(res):
+                def complete_all(res: Dict[Hashable, Any]) -> None:
                     # the wrapped function has completed. It returns a
                     # the wrapped function has completed. It returns a
                     # a dict. We can now resolve the observable deferreds in
                     # a dict. We can now resolve the observable deferreds in
                     # the cache and update our own result map.
                     # the cache and update our own result map.
@@ -430,7 +447,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                         deferreds_map[e].callback(val)
                         deferreds_map[e].callback(val)
                         results[e] = val
                         results[e] = val
 
 
-                def errback(f):
+                def errback(f: Failure) -> Failure:
                     # the wrapped function has failed. Invalidate any cache
                     # the wrapped function has failed. Invalidate any cache
                     # entries we're supposed to be populating, and fail
                     # entries we're supposed to be populating, and fail
                     # their deferreds.
                     # their deferreds.

+ 6 - 4
synapse/util/caches/expiringcache.py

@@ -19,6 +19,8 @@ from typing import Any, Generic, Optional, TypeVar, Union, overload
 import attr
 import attr
 from typing_extensions import Literal
 from typing_extensions import Literal
 
 
+from twisted.internet import defer
+
 from synapse.config import cache as cache_config
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util import Clock
 from synapse.util import Clock
@@ -81,7 +83,7 @@ class ExpiringCache(Generic[KT, VT]):
             # Don't bother starting the loop if things never expire
             # Don't bother starting the loop if things never expire
             return
             return
 
 
-        def f():
+        def f() -> "defer.Deferred[None]":
             return run_as_background_process(
             return run_as_background_process(
                 "prune_cache_%s" % self._cache_name, self._prune_cache
                 "prune_cache_%s" % self._cache_name, self._prune_cache
             )
             )
@@ -210,7 +212,7 @@ class ExpiringCache(Generic[KT, VT]):
         return False
         return False
 
 
 
 
-@attr.s(slots=True)
+@attr.s(slots=True, auto_attribs=True)
 class _CacheEntry:
 class _CacheEntry:
-    time = attr.ib(type=int)
-    value = attr.ib()
+    time: int
+    value: Any

+ 6 - 5
synapse/util/distributor.py

@@ -18,12 +18,13 @@ from twisted.internet import defer
 
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
 from synapse.util.async_helpers import maybe_awaitable
 from synapse.util.async_helpers import maybe_awaitable
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
-def user_left_room(distributor, user, room_id):
+def user_left_room(distributor: "Distributor", user: UserID, room_id: str) -> None:
     distributor.fire("user_left_room", user=user, room_id=room_id)
     distributor.fire("user_left_room", user=user, room_id=room_id)
 
 
 
 
@@ -63,7 +64,7 @@ class Distributor:
                 self.pre_registration[name] = []
                 self.pre_registration[name] = []
             self.pre_registration[name].append(observer)
             self.pre_registration[name].append(observer)
 
 
-    def fire(self, name: str, *args, **kwargs) -> None:
+    def fire(self, name: str, *args: Any, **kwargs: Any) -> None:
         """Dispatches the given signal to the registered observers.
         """Dispatches the given signal to the registered observers.
 
 
         Runs the observers as a background process. Does not return a deferred.
         Runs the observers as a background process. Does not return a deferred.
@@ -95,7 +96,7 @@ class Signal:
         Each observer callable may return a Deferred."""
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
         self.observers.append(observer)
 
 
-    def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
+    def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
         """Invokes every callable in the observer list, passing in the args and
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
         kwargs. Exceptions thrown by observers are logged but ignored. It is
         not an error to fire a signal with no observers.
         not an error to fire a signal with no observers.
@@ -103,7 +104,7 @@ class Signal:
         Returns a Deferred that will complete when all the observers have
         Returns a Deferred that will complete when all the observers have
         completed."""
         completed."""
 
 
-        async def do(observer):
+        async def do(observer: Callable[..., Any]) -> Any:
             try:
             try:
                 return await maybe_awaitable(observer(*args, **kwargs))
                 return await maybe_awaitable(observer(*args, **kwargs))
             except Exception as e:
             except Exception as e:
@@ -120,5 +121,5 @@ class Signal:
             defer.gatherResults(deferreds, consumeErrors=True)
             defer.gatherResults(deferreds, consumeErrors=True)
         )
         )
 
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<Signal name=%r>" % (self.name,)
         return "<Signal name=%r>" % (self.name,)

+ 61 - 14
synapse/util/gai_resolver.py

@@ -3,23 +3,52 @@
 # We copy it here as we need to instantiate `GAIResolver` manually, but it is a
 # We copy it here as we need to instantiate `GAIResolver` manually, but it is a
 # private class.
 # private class.
 
 
-
 from socket import (
 from socket import (
     AF_INET,
     AF_INET,
     AF_INET6,
     AF_INET6,
     AF_UNSPEC,
     AF_UNSPEC,
     SOCK_DGRAM,
     SOCK_DGRAM,
     SOCK_STREAM,
     SOCK_STREAM,
+    AddressFamily,
+    SocketKind,
     gaierror,
     gaierror,
     getaddrinfo,
     getaddrinfo,
 )
 )
+from typing import (
+    TYPE_CHECKING,
+    Callable,
+    List,
+    NoReturn,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 
 from zope.interface import implementer
 from zope.interface import implementer
 
 
 from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.address import IPv4Address, IPv6Address
-from twisted.internet.interfaces import IHostnameResolver, IHostResolution
+from twisted.internet.interfaces import (
+    IAddress,
+    IHostnameResolver,
+    IHostResolution,
+    IReactorThreads,
+    IResolutionReceiver,
+)
 from twisted.internet.threads import deferToThreadPool
 from twisted.internet.threads import deferToThreadPool
 
 
+if TYPE_CHECKING:
+    # The types below are copied from
+    # https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py
+    # so that the type hints can match the interfaces.
+    from twisted.python.runtime import platform
+
+    if platform.supportsThreads():
+        from twisted.python.threadpool import ThreadPool
+    else:
+        ThreadPool = object  # type: ignore[misc, assignment]
+
 
 
 @implementer(IHostResolution)
 @implementer(IHostResolution)
 class HostResolution:
 class HostResolution:
@@ -27,13 +56,13 @@ class HostResolution:
     The in-progress resolution of a given hostname.
     The in-progress resolution of a given hostname.
     """
     """
 
 
-    def __init__(self, name):
+    def __init__(self, name: str):
         """
         """
         Create a L{HostResolution} with the given name.
         Create a L{HostResolution} with the given name.
         """
         """
         self.name = name
         self.name = name
 
 
-    def cancel(self):
+    def cancel(self) -> NoReturn:
         # IHostResolution.cancel
         # IHostResolution.cancel
         raise NotImplementedError()
         raise NotImplementedError()
 
 
@@ -62,6 +91,17 @@ _socktypeToType = {
 }
 }
 
 
 
 
+_GETADDRINFO_RESULT = List[
+    Tuple[
+        AddressFamily,
+        SocketKind,
+        int,
+        str,
+        Union[Tuple[str, int], Tuple[str, int, int, int]],
+    ]
+]
+
+
 @implementer(IHostnameResolver)
 @implementer(IHostnameResolver)
 class GAIResolver:
 class GAIResolver:
     """
     """
@@ -69,7 +109,12 @@ class GAIResolver:
     L{getaddrinfo} in a thread.
     L{getaddrinfo} in a thread.
     """
     """
 
 
-    def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo):
+    def __init__(
+        self,
+        reactor: IReactorThreads,
+        getThreadPool: Optional[Callable[[], "ThreadPool"]] = None,
+        getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo,
+    ):
         """
         """
         Create a L{GAIResolver}.
         Create a L{GAIResolver}.
         @param reactor: the reactor to schedule result-delivery on
         @param reactor: the reactor to schedule result-delivery on
@@ -89,14 +134,16 @@ class GAIResolver:
         )
         )
         self._getaddrinfo = getaddrinfo
         self._getaddrinfo = getaddrinfo
 
 
-    def resolveHostName(
+    # The types on IHostnameResolver is incorrect in Twisted, see
+    # https://twistedmatrix.com/trac/ticket/10276
+    def resolveHostName(  # type: ignore[override]
         self,
         self,
-        resolutionReceiver,
-        hostName,
-        portNumber=0,
-        addressTypes=None,
-        transportSemantics="TCP",
-    ):
+        resolutionReceiver: IResolutionReceiver,
+        hostName: str,
+        portNumber: int = 0,
+        addressTypes: Optional[Sequence[Type[IAddress]]] = None,
+        transportSemantics: str = "TCP",
+    ) -> IHostResolution:
         """
         """
         See L{IHostnameResolver.resolveHostName}
         See L{IHostnameResolver.resolveHostName}
         @param resolutionReceiver: see interface
         @param resolutionReceiver: see interface
@@ -112,7 +159,7 @@ class GAIResolver:
         ]
         ]
         socketType = _transportToSocket[transportSemantics]
         socketType = _transportToSocket[transportSemantics]
 
 
-        def get():
+        def get() -> _GETADDRINFO_RESULT:
             try:
             try:
                 return self._getaddrinfo(
                 return self._getaddrinfo(
                     hostName, portNumber, addressFamily, socketType
                     hostName, portNumber, addressFamily, socketType
@@ -125,7 +172,7 @@ class GAIResolver:
         resolutionReceiver.resolutionBegan(resolution)
         resolutionReceiver.resolutionBegan(resolution)
 
 
         @d.addCallback
         @d.addCallback
-        def deliverResults(result):
+        def deliverResults(result: _GETADDRINFO_RESULT) -> None:
             for family, socktype, _proto, _cannoname, sockaddr in result:
             for family, socktype, _proto, _cannoname, sockaddr in result:
                 addrType = _afToType[family]
                 addrType = _afToType[family]
                 resolutionReceiver.addressResolved(
                 resolutionReceiver.addressResolved(

+ 8 - 1
synapse/util/metrics.py

@@ -64,6 +64,13 @@ in_flight = InFlightGauge(
     sub_metrics=["real_time_max", "real_time_sum"],
     sub_metrics=["real_time_max", "real_time_sum"],
 )
 )
 
 
+
+# This is dynamically created in InFlightGauge.__init__.
+class _InFlightMetric(Protocol):
+    real_time_max: float
+    real_time_sum: float
+
+
 T = TypeVar("T", bound=Callable[..., Any])
 T = TypeVar("T", bound=Callable[..., Any])
 
 
 
 
@@ -180,7 +187,7 @@ class Measure:
         """
         """
         return self._logging_context.get_resource_usage()
         return self._logging_context.get_resource_usage()
 
 
-    def _update_in_flight(self, metrics) -> None:
+    def _update_in_flight(self, metrics: _InFlightMetric) -> None:
         """Gets called when processing in flight metrics"""
         """Gets called when processing in flight metrics"""
         assert self.start is not None
         assert self.start is not None
         duration = self.clock.time() - self.start
         duration = self.clock.time() - self.start