Browse Source

Add types to synapse.util. (#10601)

reivilibre 2 years ago
parent
commit
524b8ead77
41 changed files with 401 additions and 254 deletions
  1. 1 0
      changelog.d/10601.misc
  2. 64 11
      mypy.ini
  3. 1 1
      stubs/txredisapi.pyi
  4. 4 4
      synapse/api/ratelimiting.py
  5. 17 16
      synapse/config/ratelimiting.py
  6. 6 2
      synapse/federation/sender/__init__.py
  7. 1 0
      synapse/handlers/account_validity.py
  8. 3 0
      synapse/handlers/appservice.py
  9. 3 2
      synapse/handlers/presence.py
  10. 1 1
      synapse/handlers/typing.py
  11. 7 4
      synapse/rest/client/register.py
  12. 1 1
      synapse/rest/synapse/client/new_user_consent.py
  13. 1 1
      synapse/rest/synapse/client/pick_username.py
  14. 1 0
      synapse/storage/databases/main/registration.py
  15. 7 1
      synapse/types.py
  16. 24 16
      synapse/util/__init__.py
  17. 10 6
      synapse/util/async_helpers.py
  18. 1 1
      synapse/util/batching_queue.py
  19. 7 7
      synapse/util/caches/__init__.py
  20. 7 7
      synapse/util/caches/deferred_cache.py
  21. 14 10
      synapse/util/caches/dictionary_cache.py
  22. 3 2
      synapse/util/caches/lrucache.py
  23. 1 1
      synapse/util/caches/stream_change_cache.py
  24. 8 8
      synapse/util/caches/treecache.py
  25. 1 1
      synapse/util/daemonize.py
  26. 12 11
      synapse/util/distributor.py
  27. 29 19
      synapse/util/file_consumer.py
  28. 3 2
      synapse/util/frozenutils.py
  29. 14 13
      synapse/util/httpresourcetree.py
  30. 4 4
      synapse/util/linked_list.py
  31. 1 1
      synapse/util/macaroons.py
  32. 34 18
      synapse/util/manhole.py
  33. 2 2
      synapse/util/patch_inline_callbacks.py
  34. 31 26
      synapse/util/ratelimitutils.py
  35. 42 27
      synapse/util/retryutils.py
  36. 1 1
      synapse/util/rlimit.py
  37. 4 4
      synapse/util/templates.py
  38. 8 4
      synapse/util/threepids.py
  39. 1 1
      synapse/util/versionstring.py
  40. 19 16
      synapse/util/wheel_timer.py
  41. 2 2
      tests/unittest.py

+ 1 - 0
changelog.d/10601.misc

@@ -0,0 +1 @@
+Add type annotations to the synapse.util package.

+ 64 - 11
mypy.ini

@@ -74,17 +74,7 @@ files =
   synapse/storage/util,
   synapse/streams,
   synapse/types.py,
-  synapse/util/async_helpers.py,
-  synapse/util/caches,
-  synapse/util/daemonize.py,
-  synapse/util/hash.py,
-  synapse/util/iterutils.py,
-  synapse/util/linked_list.py,
-  synapse/util/metrics.py,
-  synapse/util/macaroons.py,
-  synapse/util/module_loader.py,
-  synapse/util/msisdn.py,
-  synapse/util/stringutils.py,
+  synapse/util,
   synapse/visibility.py,
   tests/replication,
   tests/test_event_auth.py,
@@ -102,6 +92,69 @@ files =
 [mypy-synapse.rest.client.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.util.batching_queue]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.caches.dictionary_cache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.file_consumer]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.frozenutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.hash]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.httpresourcetree]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.iterutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.linked_list]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.logcontext]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.logformatter]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.macaroons]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.manhole]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.module_loader]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.msisdn]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.ratelimitutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.retryutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.rlimit]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.stringutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.templates]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.threepids]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.wheel_timer]
+disallow_untyped_defs = True
+
 [mypy-pymacaroons.*]
 ignore_missing_imports = True
 

+ 1 - 1
stubs/txredisapi.pyi

@@ -73,4 +73,4 @@ class RedisFactory(protocol.ReconnectingClientFactory):
     def buildProtocol(self, addr) -> RedisProtocol: ...
 
 class SubscriberFactory(RedisFactory):
-    def __init__(self): ...
+    def __init__(self) -> None: ...

+ 4 - 4
synapse/api/ratelimiting.py

@@ -46,7 +46,7 @@ class Ratelimiter:
         #   * How many times an action has occurred since a point in time
         #   * The point in time
         #   * The rate_hz of this particular entry. This can vary per request
-        self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict()
+        self.actions: OrderedDict[Hashable, Tuple[float, float, float]] = OrderedDict()
 
     async def can_do_action(
         self,
@@ -56,7 +56,7 @@ class Ratelimiter:
         burst_count: Optional[int] = None,
         update: bool = True,
         n_actions: int = 1,
-        _time_now_s: Optional[int] = None,
+        _time_now_s: Optional[float] = None,
     ) -> Tuple[bool, float]:
         """Can the entity (e.g. user or IP address) perform the action?
 
@@ -160,7 +160,7 @@ class Ratelimiter:
 
         return allowed, time_allowed
 
-    def _prune_message_counts(self, time_now_s: int):
+    def _prune_message_counts(self, time_now_s: float):
         """Remove message count entries that have not exceeded their defined
         rate_hz limit
 
@@ -188,7 +188,7 @@ class Ratelimiter:
         burst_count: Optional[int] = None,
         update: bool = True,
         n_actions: int = 1,
-        _time_now_s: Optional[int] = None,
+        _time_now_s: Optional[float] = None,
     ):
         """Checks if an action can be performed. If not, raises a LimitExceededError
 

+ 17 - 16
synapse/config/ratelimiting.py

@@ -14,6 +14,8 @@
 
 from typing import Dict, Optional
 
+import attr
+
 from ._base import Config
 
 
@@ -29,18 +31,13 @@ class RateLimitConfig:
         self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
 
 
+@attr.s(auto_attribs=True)
 class FederationRateLimitConfig:
-    _items_and_default = {
-        "window_size": 1000,
-        "sleep_limit": 10,
-        "sleep_delay": 500,
-        "reject_limit": 50,
-        "concurrent": 3,
-    }
-
-    def __init__(self, **kwargs):
-        for i in self._items_and_default.keys():
-            setattr(self, i, kwargs.get(i) or self._items_and_default[i])
+    window_size: int = 1000
+    sleep_limit: int = 10
+    sleep_delay: int = 500
+    reject_limit: int = 50
+    concurrent: int = 3
 
 
 class RatelimitConfig(Config):
@@ -69,11 +66,15 @@ class RatelimitConfig(Config):
         else:
             self.rc_federation = FederationRateLimitConfig(
                 **{
-                    "window_size": config.get("federation_rc_window_size"),
-                    "sleep_limit": config.get("federation_rc_sleep_limit"),
-                    "sleep_delay": config.get("federation_rc_sleep_delay"),
-                    "reject_limit": config.get("federation_rc_reject_limit"),
-                    "concurrent": config.get("federation_rc_concurrent"),
+                    k: v
+                    for k, v in {
+                        "window_size": config.get("federation_rc_window_size"),
+                        "sleep_limit": config.get("federation_rc_sleep_limit"),
+                        "sleep_delay": config.get("federation_rc_sleep_delay"),
+                        "reject_limit": config.get("federation_rc_reject_limit"),
+                        "concurrent": config.get("federation_rc_concurrent"),
+                    }.items()
+                    if v is not None
                 }
             )
 

+ 6 - 2
synapse/federation/sender/__init__.py

@@ -22,6 +22,7 @@ from prometheus_client import Counter
 from typing_extensions import Literal
 
 from twisted.internet import defer
+from twisted.internet.interfaces import IDelayedCall
 
 import synapse.metrics
 from synapse.api.presence import UserPresenceState
@@ -284,7 +285,9 @@ class FederationSender(AbstractFederationSender):
         )
 
         # wake up destinations that have outstanding PDUs to be caught up
-        self._catchup_after_startup_timer = self.clock.call_later(
+        self._catchup_after_startup_timer: Optional[
+            IDelayedCall
+        ] = self.clock.call_later(
             CATCH_UP_STARTUP_DELAY_SEC,
             run_as_background_process,
             "wake_destinations_needing_catchup",
@@ -406,7 +409,7 @@ class FederationSender(AbstractFederationSender):
 
                         now = self.clock.time_msec()
                         ts = await self.store.get_received_ts(event.event_id)
-
+                        assert ts is not None
                         synapse.metrics.event_processing_lag_by_event.labels(
                             "federation_sender"
                         ).observe((now - ts) / 1000)
@@ -435,6 +438,7 @@ class FederationSender(AbstractFederationSender):
                 if events:
                     now = self.clock.time_msec()
                     ts = await self.store.get_received_ts(events[-1].event_id)
+                    assert ts is not None
 
                     synapse.metrics.event_processing_lag.labels(
                         "federation_sender"

+ 1 - 0
synapse/handlers/account_validity.py

@@ -398,6 +398,7 @@ class AccountValidityHandler:
         """
         now = self.clock.time_msec()
         if expiration_ts is None:
+            assert self._account_validity_period is not None
             expiration_ts = now + self._account_validity_period
 
         await self.store.set_account_validity_for_user(

+ 3 - 0
synapse/handlers/appservice.py

@@ -131,6 +131,8 @@ class ApplicationServicesHandler:
 
                         now = self.clock.time_msec()
                         ts = await self.store.get_received_ts(event.event_id)
+                        assert ts is not None
+
                         synapse.metrics.event_processing_lag_by_event.labels(
                             "appservice_sender"
                         ).observe((now - ts) / 1000)
@@ -166,6 +168,7 @@ class ApplicationServicesHandler:
                     if events:
                         now = self.clock.time_msec()
                         ts = await self.store.get_received_ts(events[-1].event_id)
+                        assert ts is not None
 
                         synapse.metrics.event_processing_lag.labels(
                             "appservice_sender"

+ 3 - 2
synapse/handlers/presence.py

@@ -28,6 +28,7 @@ from bisect import bisect
 from contextlib import contextmanager
 from typing import (
     TYPE_CHECKING,
+    Any,
     Callable,
     Collection,
     Dict,
@@ -615,7 +616,7 @@ class PresenceHandler(BasePresenceHandler):
         super().__init__(hs)
         self.hs = hs
         self.server_name = hs.hostname
-        self.wheel_timer = WheelTimer()
+        self.wheel_timer: WheelTimer[str] = WheelTimer()
         self.notifier = hs.get_notifier()
         self._presence_enabled = hs.config.use_presence
 
@@ -924,7 +925,7 @@ class PresenceHandler(BasePresenceHandler):
 
         prev_state = await self.current_state_for_user(user_id)
 
-        new_fields = {"last_active_ts": self.clock.time_msec()}
+        new_fields: Dict[str, Any] = {"last_active_ts": self.clock.time_msec()}
         if prev_state.state == PresenceState.UNAVAILABLE:
             new_fields["state"] = PresenceState.ONLINE
 

+ 1 - 1
synapse/handlers/typing.py

@@ -73,7 +73,7 @@ class FollowerTypingHandler:
         self._room_typing: Dict[str, Set[str]] = {}
 
         self._member_last_federation_poke: Dict[RoomMember, int] = {}
-        self.wheel_timer = WheelTimer(bucket_size=5000)
+        self.wheel_timer: WheelTimer[RoomMember] = WheelTimer(bucket_size=5000)
         self._latest_room_serial = 0
 
         self.clock.looping_call(self._handle_timeouts, 5000)

+ 7 - 4
synapse/rest/client/register.py

@@ -330,11 +330,11 @@ class UsernameAvailabilityRestServlet(RestServlet):
                 # Artificially delay requests if rate > sleep_limit/window_size
                 sleep_limit=1,
                 # Amount of artificial delay to apply
-                sleep_msec=1000,
+                sleep_delay=1000,
                 # Error with 429 if more than reject_limit requests are queued
                 reject_limit=1,
                 # Allow 1 request at a time
-                concurrent_requests=1,
+                concurrent=1,
             ),
         )
 
@@ -763,7 +763,10 @@ class RegisterRestServlet(RestServlet):
         Returns:
              dictionary for response from /register
         """
-        result = {"user_id": user_id, "home_server": self.hs.hostname}
+        result: JsonDict = {
+            "user_id": user_id,
+            "home_server": self.hs.hostname,
+        }
         if not params.get("inhibit_login", False):
             device_id = params.get("device_id")
             initial_display_name = params.get("initial_device_display_name")
@@ -814,7 +817,7 @@ class RegisterRestServlet(RestServlet):
             user_id, device_id, initial_display_name, is_guest=True
         )
 
-        result = {
+        result: JsonDict = {
             "user_id": user_id,
             "device_id": device_id,
             "access_token": access_token,

+ 1 - 1
synapse/rest/synapse/client/new_user_consent.py

@@ -52,7 +52,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
                 yield hs.config.sso.sso_template_dir
             yield hs.config.sso.default_template_dir
 
-        self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+        self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config)
 
     async def _async_render_GET(self, request: Request) -> None:
         try:

+ 1 - 1
synapse/rest/synapse/client/pick_username.py

@@ -80,7 +80,7 @@ class AccountDetailsResource(DirectServeHtmlResource):
                 yield hs.config.sso.sso_template_dir
             yield hs.config.sso.default_template_dir
 
-        self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+        self._jinja_env = build_jinja_env(list(template_search_dirs()), hs.config)
 
     async def _async_render_GET(self, request: Request) -> None:
         try:

+ 1 - 0
synapse/storage/databases/main/registration.py

@@ -1091,6 +1091,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 delta equal to 10% of the validity period.
         """
         now_ms = self._clock.time_msec()
+        assert self._account_validity_period is not None
         expiration_ts = now_ms + self._account_validity_period
 
         if use_delta:

+ 7 - 1
synapse/types.py

@@ -38,6 +38,7 @@ from twisted.internet.interfaces import (
     IReactorCore,
     IReactorPluggableNameResolver,
     IReactorTCP,
+    IReactorThreads,
     IReactorTime,
 )
 
@@ -63,7 +64,12 @@ JsonDict = Dict[str, Any]
 # Note that this seems to require inheriting *directly* from Interface in order
 # for mypy-zope to realize it is an interface.
 class ISynapseReactor(
-    IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
+    IReactorTCP,
+    IReactorPluggableNameResolver,
+    IReactorTime,
+    IReactorCore,
+    IReactorThreads,
+    Interface,
 ):
     """The interfaces necessary for Synapse to function."""
 

+ 24 - 16
synapse/util/__init__.py

@@ -15,27 +15,35 @@
 import json
 import logging
 import re
-from typing import Pattern
+import typing
+from typing import Any, Callable, Dict, Generator, Pattern
 
 import attr
 from frozendict import frozendict
 
 from twisted.internet import defer, task
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IDelayedCall, IReactorTime
+from twisted.internet.task import LoopingCall
+from twisted.python.failure import Failure
 
 from synapse.logging import context
 
+if typing.TYPE_CHECKING:
+    pass
+
 logger = logging.getLogger(__name__)
 
 
 _WILDCARD_RUN = re.compile(r"([\?\*]+)")
 
 
-def _reject_invalid_json(val):
+def _reject_invalid_json(val: Any) -> None:
     """Do not allow Infinity, -Infinity, or NaN values in JSON."""
     raise ValueError("Invalid JSON value: '%s'" % val)
 
 
-def _handle_frozendict(obj):
+def _handle_frozendict(obj: Any) -> Dict[Any, Any]:
     """Helper for json_encoder. Makes frozendicts serializable by returning
     the underlying dict
     """
@@ -60,10 +68,10 @@ json_encoder = json.JSONEncoder(
 json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
 
 
-def unwrapFirstError(failure):
+def unwrapFirstError(failure: Failure) -> Failure:
     # defer.gatherResults and DeferredLists wrap failures.
     failure.trap(defer.FirstError)
-    return failure.value.subFailure
+    return failure.value.subFailure  # type: ignore[union-attr]  # Issue in Twisted's annotations
 
 
 @attr.s(slots=True)
@@ -75,25 +83,25 @@ class Clock:
         reactor: The Twisted reactor to use.
     """
 
-    _reactor = attr.ib()
+    _reactor: IReactorTime = attr.ib()
 
-    @defer.inlineCallbacks
-    def sleep(self, seconds):
-        d = defer.Deferred()
+    @defer.inlineCallbacks  # type: ignore[arg-type]  # Issue in Twisted's type annotations
+    def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]":
+        d: defer.Deferred[float] = defer.Deferred()
         with context.PreserveLoggingContext():
             self._reactor.callLater(seconds, d.callback, seconds)
             res = yield d
         return res
 
-    def time(self):
+    def time(self) -> float:
         """Returns the current system time in seconds since epoch."""
         return self._reactor.seconds()
 
-    def time_msec(self):
+    def time_msec(self) -> int:
         """Returns the current system time in milliseconds since epoch."""
         return int(self.time() * 1000)
 
-    def looping_call(self, f, msec, *args, **kwargs):
+    def looping_call(self, f: Callable, msec: float, *args, **kwargs) -> LoopingCall:
         """Call a function repeatedly.
 
         Waits `msec` initially before calling `f` for the first time.
@@ -102,8 +110,8 @@ class Clock:
         other than trivial, you probably want to wrap it in run_as_background_process.
 
         Args:
-            f(function): The function to call repeatedly.
-            msec(float): How long to wait between calls in milliseconds.
+            f: The function to call repeatedly.
+            msec: How long to wait between calls in milliseconds.
             *args: Positional arguments to pass to function.
             **kwargs: Key arguments to pass to function.
         """
@@ -113,7 +121,7 @@ class Clock:
         d.addErrback(log_failure, "Looping call died", consumeErrors=False)
         return call
 
-    def call_later(self, delay, callback, *args, **kwargs):
+    def call_later(self, delay, callback, *args, **kwargs) -> IDelayedCall:
         """Call something later
 
         Note that the function will be called with no logcontext, so if it is anything
@@ -133,7 +141,7 @@ class Clock:
         with context.PreserveLoggingContext():
             return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
 
-    def cancel_call_later(self, timer, ignore_errs=False):
+    def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None:
         try:
             timer.cancel()
         except Exception:

+ 10 - 6
synapse/util/async_helpers.py

@@ -37,6 +37,7 @@ import attr
 from typing_extensions import ContextManager
 
 from twisted.internet import defer
+from twisted.internet.base import ReactorBase
 from twisted.internet.defer import CancelledError
 from twisted.internet.interfaces import IReactorTime
 from twisted.python import failure
@@ -268,6 +269,7 @@ class Linearizer:
         if not clock:
             from twisted.internet import reactor
 
+            assert isinstance(reactor, ReactorBase)
             clock = Clock(reactor)
         self._clock = clock
         self.max_count = max_count
@@ -411,7 +413,7 @@ class ReadWriteLock:
     # writers and readers have been resolved. The new writer replaces the latest
     # writer.
 
-    def __init__(self):
+    def __init__(self) -> None:
         # Latest readers queued
         self.key_to_current_readers: Dict[str, Set[defer.Deferred]] = {}
 
@@ -503,7 +505,7 @@ def timeout_deferred(
 
     timed_out = [False]
 
-    def time_it_out():
+    def time_it_out() -> None:
         timed_out[0] = True
 
         try:
@@ -550,19 +552,21 @@ def timeout_deferred(
     return new_d
 
 
+# This class can't be generic because it uses slots with attrs.
+# See: https://github.com/python-attrs/attrs/issues/313
 @attr.s(slots=True, frozen=True)
-class DoneAwaitable:
+class DoneAwaitable:  # should be: Generic[R]
     """Simple awaitable that returns the provided value."""
 
-    value = attr.ib()
+    value = attr.ib(type=Any)  # should be: R
 
     def __await__(self):
         return self
 
-    def __iter__(self):
+    def __iter__(self) -> "DoneAwaitable":
         return self
 
-    def __next__(self):
+    def __next__(self) -> None:
         raise StopIteration(self.value)
 
 

+ 1 - 1
synapse/util/batching_queue.py

@@ -122,7 +122,7 @@ class BatchingQueue(Generic[V, R]):
 
         # First we create a defer and add it and the value to the list of
         # pending items.
-        d = defer.Deferred()
+        d: defer.Deferred[R] = defer.Deferred()
         self._next_values.setdefault(key, []).append((value, d))
 
         # If we're not currently processing the key fire off a background

+ 7 - 7
synapse/util/caches/__init__.py

@@ -64,32 +64,32 @@ class CacheMetric:
     evicted_size = attr.ib(default=0)
     memory_usage = attr.ib(default=None)
 
-    def inc_hits(self):
+    def inc_hits(self) -> None:
         self.hits += 1
 
-    def inc_misses(self):
+    def inc_misses(self) -> None:
         self.misses += 1
 
-    def inc_evictions(self, size=1):
+    def inc_evictions(self, size: int = 1) -> None:
         self.evicted_size += size
 
-    def inc_memory_usage(self, memory: int):
+    def inc_memory_usage(self, memory: int) -> None:
         if self.memory_usage is None:
             self.memory_usage = 0
 
         self.memory_usage += memory
 
-    def dec_memory_usage(self, memory: int):
+    def dec_memory_usage(self, memory: int) -> None:
         self.memory_usage -= memory
 
-    def clear_memory_usage(self):
+    def clear_memory_usage(self) -> None:
         if self.memory_usage is not None:
             self.memory_usage = 0
 
     def describe(self):
         return []
 
-    def collect(self):
+    def collect(self) -> None:
         try:
             if self._cache_type == "response_cache":
                 response_cache_size.labels(self._cache_name).set(len(self._cache))

+ 7 - 7
synapse/util/caches/deferred_cache.py

@@ -93,7 +93,7 @@ class DeferredCache(Generic[KT, VT]):
             TreeCache, "MutableMapping[KT, CacheEntry]"
         ] = cache_type()
 
-        def metrics_cb():
+        def metrics_cb() -> None:
             cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
 
         # cache is used for completed results and maps to the result itself, rather than
@@ -113,7 +113,7 @@ class DeferredCache(Generic[KT, VT]):
     def max_entries(self):
         return self.cache.max_size
 
-    def check_thread(self):
+    def check_thread(self) -> None:
         expected_thread = self.thread
         if expected_thread is None:
             self.thread = threading.current_thread()
@@ -235,7 +235,7 @@ class DeferredCache(Generic[KT, VT]):
 
         self._pending_deferred_cache[key] = entry
 
-        def compare_and_pop():
+        def compare_and_pop() -> bool:
             """Check if our entry is still the one in _pending_deferred_cache, and
             if so, pop it.
 
@@ -256,7 +256,7 @@ class DeferredCache(Generic[KT, VT]):
 
             return False
 
-        def cb(result):
+        def cb(result) -> None:
             if compare_and_pop():
                 self.cache.set(key, result, entry.callbacks)
             else:
@@ -268,7 +268,7 @@ class DeferredCache(Generic[KT, VT]):
                 # not have been. Either way, let's double-check now.
                 entry.invalidate()
 
-        def eb(_fail):
+        def eb(_fail) -> None:
             compare_and_pop()
             entry.invalidate()
 
@@ -314,7 +314,7 @@ class DeferredCache(Generic[KT, VT]):
             for entry in iterate_tree_cache_entry(entry):
                 entry.invalidate()
 
-    def invalidate_all(self):
+    def invalidate_all(self) -> None:
         self.check_thread()
         self.cache.clear()
         for entry in self._pending_deferred_cache.values():
@@ -332,7 +332,7 @@ class CacheEntry:
         self.callbacks = set(callbacks)
         self.invalidated = False
 
-    def invalidate(self):
+    def invalidate(self) -> None:
         if not self.invalidated:
             self.invalidated = True
             for callback in self.callbacks:

+ 14 - 10
synapse/util/caches/dictionary_cache.py

@@ -27,10 +27,14 @@ logger = logging.getLogger(__name__)
 KT = TypeVar("KT")
 # The type of the dictionary keys.
 DKT = TypeVar("DKT")
+# The type of the dictionary values.
+DV = TypeVar("DV")
 
 
+# This class can't be generic because it uses slots with attrs.
+# See: https://github.com/python-attrs/attrs/issues/313
 @attr.s(slots=True)
-class DictionaryEntry:
+class DictionaryEntry:  # should be: Generic[DKT, DV].
     """Returned when getting an entry from the cache
 
     Attributes:
@@ -43,10 +47,10 @@ class DictionaryEntry:
     """
 
     full = attr.ib(type=bool)
-    known_absent = attr.ib()
-    value = attr.ib()
+    known_absent = attr.ib(type=Set[Any])  # should be: Set[DKT]
+    value = attr.ib(type=Dict[Any, Any])  # should be: Dict[DKT, DV]
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.value)
 
 
@@ -56,7 +60,7 @@ class _Sentinel(enum.Enum):
     sentinel = object()
 
 
-class DictionaryCache(Generic[KT, DKT]):
+class DictionaryCache(Generic[KT, DKT, DV]):
     """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
     fetching a subset of dictionary keys for a particular key.
     """
@@ -87,7 +91,7 @@ class DictionaryCache(Generic[KT, DKT]):
 
         Args:
             key
-            dict_key: If given a set of keys then return only those keys
+            dict_keys: If given a set of keys then return only those keys
                 that exist in the cache.
 
         Returns:
@@ -125,7 +129,7 @@ class DictionaryCache(Generic[KT, DKT]):
         self,
         sequence: int,
         key: KT,
-        value: Dict[DKT, Any],
+        value: Dict[DKT, DV],
         fetched_keys: Optional[Set[DKT]] = None,
     ) -> None:
         """Updates the entry in the cache
@@ -151,15 +155,15 @@ class DictionaryCache(Generic[KT, DKT]):
                 self._update_or_insert(key, value, fetched_keys)
 
     def _update_or_insert(
-        self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]
+        self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]
     ) -> None:
         # We pop and reinsert as we need to tell the cache the size may have
         # changed
 
-        entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
+        entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
         entry.value.update(value)
         entry.known_absent.update(known_absent)
         self.cache[key] = entry
 
-    def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None:
+    def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
         self.cache[key] = DictionaryEntry(True, known_absent, value)

+ 3 - 2
synapse/util/caches/lrucache.py

@@ -35,6 +35,7 @@ from typing import (
 from typing_extensions import Literal
 
 from twisted.internet import reactor
+from twisted.internet.interfaces import IReactorTime
 
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -341,7 +342,7 @@ class LruCache(Generic[KT, VT]):
         # Default `clock` to something sensible. Note that we rename it to
         # `real_clock` so that mypy doesn't think its still `Optional`.
         if clock is None:
-            real_clock = Clock(reactor)
+            real_clock = Clock(cast(IReactorTime, reactor))
         else:
             real_clock = clock
 
@@ -384,7 +385,7 @@ class LruCache(Generic[KT, VT]):
 
         lock = threading.Lock()
 
-        def evict():
+        def evict() -> None:
             while cache_len() > self.max_size:
                 # Get the last node in the list (i.e. the oldest node).
                 todelete = list_root.prev_node

+ 1 - 1
synapse/util/caches/stream_change_cache.py

@@ -195,7 +195,7 @@ class StreamChangeCache:
             for entity in r:
                 del self._entity_to_key[entity]
 
-    def _evict(self):
+    def _evict(self) -> None:
         while len(self._cache) > self._max_size:
             k, r = self._cache.popitem(0)
             self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)

+ 8 - 8
synapse/util/caches/treecache.py

@@ -35,17 +35,17 @@ class TreeCache:
         root = {key_1: {key_2: _value}}
     """
 
-    def __init__(self):
-        self.size = 0
+    def __init__(self) -> None:
+        self.size: int = 0
         self.root = TreeCacheNode()
 
-    def __setitem__(self, key, value):
-        return self.set(key, value)
+    def __setitem__(self, key, value) -> None:
+        self.set(key, value)
 
-    def __contains__(self, key):
+    def __contains__(self, key) -> bool:
         return self.get(key, SENTINEL) is not SENTINEL
 
-    def set(self, key, value):
+    def set(self, key, value) -> None:
         if isinstance(value, TreeCacheNode):
             # this would mean we couldn't tell where our tree ended and the value
             # started.
@@ -73,7 +73,7 @@ class TreeCache:
                 return default
         return node.get(key[-1], default)
 
-    def clear(self):
+    def clear(self) -> None:
         self.size = 0
         self.root = TreeCacheNode()
 
@@ -128,7 +128,7 @@ class TreeCache:
     def values(self):
         return iterate_tree_cache_entry(self.root)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return self.size
 
 

+ 1 - 1
synapse/util/daemonize.py

@@ -126,7 +126,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
     signal.signal(signal.SIGTERM, sigterm)
 
     # Cleanup pid file at exit.
-    def exit():
+    def exit() -> None:
         logger.warning("Stopping daemon.")
         os.remove(pid_file)
         sys.exit(0)

+ 12 - 11
synapse/util/distributor.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import Any, Callable, Dict, List
 
 from twisted.internet import defer
 
@@ -37,11 +38,11 @@ class Distributor:
       model will do for today.
     """
 
-    def __init__(self):
-        self.signals = {}
-        self.pre_registration = {}
+    def __init__(self) -> None:
+        self.signals: Dict[str, Signal] = {}
+        self.pre_registration: Dict[str, List[Callable]] = {}
 
-    def declare(self, name):
+    def declare(self, name: str) -> None:
         if name in self.signals:
             raise KeyError("%r already has a signal named %s" % (self, name))
 
@@ -52,7 +53,7 @@ class Distributor:
             for observer in self.pre_registration[name]:
                 signal.observe(observer)
 
-    def observe(self, name, observer):
+    def observe(self, name: str, observer: Callable) -> None:
         if name in self.signals:
             self.signals[name].observe(observer)
         else:
@@ -62,7 +63,7 @@ class Distributor:
                 self.pre_registration[name] = []
             self.pre_registration[name].append(observer)
 
-    def fire(self, name, *args, **kwargs):
+    def fire(self, name: str, *args, **kwargs) -> None:
         """Dispatches the given signal to the registered observers.
 
         Runs the observers as a background process. Does not return a deferred.
@@ -83,18 +84,18 @@ class Signal:
     method into all of the observers.
     """
 
-    def __init__(self, name):
-        self.name = name
-        self.observers = []
+    def __init__(self, name: str):
+        self.name: str = name
+        self.observers: List[Callable] = []
 
-    def observe(self, observer):
+    def observe(self, observer: Callable) -> None:
         """Adds a new callable to the observer list which will be invoked by
         the 'fire' method.
 
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
 
-    def fire(self, *args, **kwargs):
+    def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
         not an error to fire a signal with no observers.

+ 29 - 19
synapse/util/file_consumer.py

@@ -13,10 +13,14 @@
 # limitations under the License.
 
 import queue
+from typing import BinaryIO, Optional, Union, cast
 
 from twisted.internet import threads
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IPullProducer, IPushProducer
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import ISynapseReactor
 
 
 class BackgroundFileConsumer:
@@ -24,9 +28,9 @@ class BackgroundFileConsumer:
     and pull producers
 
     Args:
-        file_obj (file): The file like object to write to. Closed when
+        file_obj: The file like object to write to. Closed when
             finished.
-        reactor (twisted.internet.reactor): the Twisted reactor to use
+        reactor: the Twisted reactor to use
     """
 
     # For PushProducers pause if we have this many unwritten slices
@@ -34,13 +38,13 @@ class BackgroundFileConsumer:
     # And resume once the size of the queue is less than this
     _RESUME_ON_QUEUE_SIZE = 2
 
-    def __init__(self, file_obj, reactor):
-        self._file_obj = file_obj
+    def __init__(self, file_obj: BinaryIO, reactor: ISynapseReactor) -> None:
+        self._file_obj: BinaryIO = file_obj
 
-        self._reactor = reactor
+        self._reactor: ISynapseReactor = reactor
 
         # Producer we're registered with
-        self._producer = None
+        self._producer: Optional[Union[IPushProducer, IPullProducer]] = None
 
         # True if PushProducer, false if PullProducer
         self.streaming = False
@@ -51,20 +55,22 @@ class BackgroundFileConsumer:
 
         # Queue of slices of bytes to be written. When producer calls
         # unregister a final None is sent.
-        self._bytes_queue = queue.Queue()
+        self._bytes_queue: queue.Queue[Optional[bytes]] = queue.Queue()
 
         # Deferred that is resolved when finished writing
-        self._finished_deferred = None
+        self._finished_deferred: Optional[Deferred[None]] = None
 
         # If the _writer thread throws an exception it gets stored here.
-        self._write_exception = None
+        self._write_exception: Optional[Exception] = None
 
-    def registerProducer(self, producer, streaming):
+    def registerProducer(
+        self, producer: Union[IPushProducer, IPullProducer], streaming: bool
+    ) -> None:
         """Part of IConsumer interface
 
         Args:
-            producer (IProducer)
-            streaming (bool): True if push based producer, False if pull
+            producer
+            streaming: True if push based producer, False if pull
                 based.
         """
         if self._producer:
@@ -81,29 +87,33 @@ class BackgroundFileConsumer:
         if not streaming:
             self._producer.resumeProducing()
 
-    def unregisterProducer(self):
+    def unregisterProducer(self) -> None:
         """Part of IConsumer interface"""
         self._producer = None
+        assert self._finished_deferred is not None
         if not self._finished_deferred.called:
             self._bytes_queue.put_nowait(None)
 
-    def write(self, bytes):
+    def write(self, write_bytes: bytes) -> None:
         """Part of IConsumer interface"""
         if self._write_exception:
             raise self._write_exception
 
+        assert self._finished_deferred is not None
         if self._finished_deferred.called:
             raise Exception("consumer has closed")
 
-        self._bytes_queue.put_nowait(bytes)
+        self._bytes_queue.put_nowait(write_bytes)
 
         # If this is a PushProducer and the queue is getting behind
         # then we pause the producer.
         if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
             self._paused_producer = True
-            self._producer.pauseProducing()
+            assert self._producer is not None
+            # cast safe because `streaming` means this is an IPushProducer
+            cast(IPushProducer, self._producer).pauseProducing()
 
-    def _writer(self):
+    def _writer(self) -> None:
         """This is run in a background thread to write to the file."""
         try:
             while self._producer or not self._bytes_queue.empty():
@@ -130,11 +140,11 @@ class BackgroundFileConsumer:
         finally:
             self._file_obj.close()
 
-    def wait(self):
+    def wait(self) -> "Deferred[None]":
         """Returns a deferred that resolves when finished writing to file"""
         return make_deferred_yieldable(self._finished_deferred)
 
-    def _resume_paused_producer(self):
+    def _resume_paused_producer(self) -> None:
         """Gets called if we should resume producing after being paused"""
         if self._paused_producer and self._producer:
             self._paused_producer = False

+ 3 - 2
synapse/util/frozenutils.py

@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any
 
 from frozendict import frozendict
 
 
-def freeze(o):
+def freeze(o: Any) -> Any:
     if isinstance(o, dict):
         return frozendict({k: freeze(v) for k, v in o.items()})
 
@@ -33,7 +34,7 @@ def freeze(o):
     return o
 
 
-def unfreeze(o):
+def unfreeze(o: Any) -> Any:
     if isinstance(o, (dict, frozendict)):
         return {k: unfreeze(v) for k, v in o.items()}
 

+ 14 - 13
synapse/util/httpresourcetree.py

@@ -13,42 +13,43 @@
 # limitations under the License.
 
 import logging
+from typing import Dict
 
-from twisted.web.resource import NoResource
+from twisted.web.resource import NoResource, Resource
 
 logger = logging.getLogger(__name__)
 
 
-def create_resource_tree(desired_tree, root_resource):
+def create_resource_tree(
+    desired_tree: Dict[str, Resource], root_resource: Resource
+) -> Resource:
     """Create the resource tree for this homeserver.
 
     This is unduly complicated because Twisted does not support putting
     child resources more than 1 level deep at a time.
 
     Args:
-        web_client (bool): True to enable the web client.
-        root_resource (twisted.web.resource.Resource): The root
-            resource to add the tree to.
+        desired_tree: Dict from desired paths to desired resources.
+        root_resource: The root resource to add the tree to.
     Returns:
-        twisted.web.resource.Resource: the ``root_resource`` with a tree of
-        child resources added to it.
+        The ``root_resource`` with a tree of child resources added to it.
     """
 
     # ideally we'd just use getChild and putChild but getChild doesn't work
     # unless you give it a Request object IN ADDITION to the name :/ So
     # instead, we'll store a copy of this mapping so we can actually add
     # extra resources to existing nodes. See self._resource_id for the key.
-    resource_mappings = {}
-    for full_path, res in desired_tree.items():
+    resource_mappings: Dict[str, Resource] = {}
+    for full_path_str, res in desired_tree.items():
         # twisted requires all resources to be bytes
-        full_path = full_path.encode("utf-8")
+        full_path = full_path_str.encode("utf-8")
 
         logger.info("Attaching %s to path %s", res, full_path)
         last_resource = root_resource
         for path_seg in full_path.split(b"/")[1:-1]:
             if path_seg not in last_resource.listNames():
                 # resource doesn't exist, so make a "dummy resource"
-                child_resource = NoResource()
+                child_resource: Resource = NoResource()
                 last_resource.putChild(path_seg, child_resource)
                 res_id = _resource_id(last_resource, path_seg)
                 resource_mappings[res_id] = child_resource
@@ -83,7 +84,7 @@ def create_resource_tree(desired_tree, root_resource):
     return root_resource
 
 
-def _resource_id(resource, path_seg):
+def _resource_id(resource: Resource, path_seg: bytes) -> str:
     """Construct an arbitrary resource ID so you can retrieve the mapping
     later.
 
@@ -96,4 +97,4 @@ def _resource_id(resource, path_seg):
     Returns:
         str: A unique string which can be a key to the child Resource.
     """
-    return "%s-%s" % (resource, path_seg)
+    return "%s-%r" % (resource, path_seg)

+ 4 - 4
synapse/util/linked_list.py

@@ -74,7 +74,7 @@ class ListNode(Generic[P]):
             new_node._refs_insert_after(node)
         return new_node
 
-    def remove_from_list(self):
+    def remove_from_list(self) -> None:
         """Remove this node from the list."""
         with self._LOCK:
             self._refs_remove_node_from_list()
@@ -84,7 +84,7 @@ class ListNode(Generic[P]):
         # immediately rather than at the next GC.
         self.cache_entry = None
 
-    def move_after(self, node: "ListNode"):
+    def move_after(self, node: "ListNode") -> None:
         """Move this node from its current location in the list to after the
         given node.
         """
@@ -103,7 +103,7 @@ class ListNode(Generic[P]):
             # Insert self back into the list, after target node
             self._refs_insert_after(node)
 
-    def _refs_remove_node_from_list(self):
+    def _refs_remove_node_from_list(self) -> None:
         """Internal method to *just* remove the node from the list, without
         e.g. clearing out the cache entry.
         """
@@ -122,7 +122,7 @@ class ListNode(Generic[P]):
         self.prev_node = None
         self.next_node = None
 
-    def _refs_insert_after(self, node: "ListNode"):
+    def _refs_insert_after(self, node: "ListNode") -> None:
         """Internal method to insert the node after the given node."""
 
         # This method should only be called when we're not already in the list.

+ 1 - 1
synapse/util/macaroons.py

@@ -77,7 +77,7 @@ def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> N
             should be considered expired. Normally the current time.
     """
 
-    def verify_expiry_caveat(caveat: str):
+    def verify_expiry_caveat(caveat: str) -> bool:
         time_msec = get_time_ms()
         prefix = "time < "
         if not caveat.startswith(prefix):

+ 34 - 18
synapse/util/manhole.py

@@ -15,6 +15,7 @@
 import inspect
 import sys
 import traceback
+from typing import Any, Dict, Optional
 
 from twisted.conch import manhole_ssh
 from twisted.conch.insults import insults
@@ -22,6 +23,9 @@ from twisted.conch.manhole import ColoredManhole, ManholeInterpreter
 from twisted.conch.ssh.keys import Key
 from twisted.cred import checkers, portal
 from twisted.internet import defer
+from twisted.internet.protocol import Factory
+
+from synapse.config.server import ManholeConfig
 
 PUBLIC_KEY = (
     "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDHhGATaW4KhE23+7nrH4jFx3yLq9OjaEs5"
@@ -61,22 +65,22 @@ EddTrx3TNpr1D5m/f+6mnXWrc8u9y1+GNx9yz889xMjIBTBI9KqaaOs=
 -----END RSA PRIVATE KEY-----"""
 
 
-def manhole(settings, globals):
+def manhole(settings: ManholeConfig, globals: Dict[str, Any]) -> Factory:
     """Starts a ssh listener with password authentication using
     the given username and password. Clients connecting to the ssh
     listener will find themselves in a colored python shell with
     the supplied globals.
 
     Args:
-        username(str): The username ssh clients should auth with.
-        password(str): The password ssh clients should auth with.
-        globals(dict): The variables to expose in the shell.
+        username: The username ssh clients should auth with.
+        password: The password ssh clients should auth with.
+        globals: The variables to expose in the shell.
 
     Returns:
-        twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
+        A factory to pass to ``listenTCP``
     """
     username = settings.username
-    password = settings.password
+    password = settings.password.encode("ascii")
     priv_key = settings.priv_key
     if priv_key is None:
         priv_key = Key.fromString(PRIVATE_KEY)
@@ -84,19 +88,22 @@ def manhole(settings, globals):
     if pub_key is None:
         pub_key = Key.fromString(PUBLIC_KEY)
 
-    if not isinstance(password, bytes):
-        password = password.encode("ascii")
-
     checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password})
 
     rlm = manhole_ssh.TerminalRealm()
-    rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
+    # mypy ignored here because:
+    # - can't deduce types of lambdas
+    # - variable is Type[ServerProtocol], expr is Callable[[], ServerProtocol]
+    rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(  # type: ignore[misc,assignment]
         SynapseManhole, dict(globals, __name__="__console__")
     )
 
     factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
-    factory.privateKeys[b"ssh-rsa"] = priv_key
-    factory.publicKeys[b"ssh-rsa"] = pub_key
+
+    # conch has the wrong type on these dicts (says bytes to bytes,
+    # should be bytes to Keys judging by how it's used).
+    factory.privateKeys[b"ssh-rsa"] = priv_key  # type: ignore[assignment]
+    factory.publicKeys[b"ssh-rsa"] = pub_key  # type: ignore[assignment]
 
     return factory
 
@@ -104,7 +111,7 @@ def manhole(settings, globals):
 class SynapseManhole(ColoredManhole):
     """Overrides connectionMade to create our own ManholeInterpreter"""
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         super().connectionMade()
 
         # replace the manhole interpreter with our own impl
@@ -114,13 +121,14 @@ class SynapseManhole(ColoredManhole):
 
 
 class SynapseManholeInterpreter(ManholeInterpreter):
-    def showsyntaxerror(self, filename=None):
+    def showsyntaxerror(self, filename: Optional[str] = None) -> None:
         """Display the syntax error that just occurred.
 
         Overrides the base implementation, ignoring sys.excepthook. We always want
         any syntax errors to be sent to the terminal, rather than sentry.
         """
         type, value, tb = sys.exc_info()
+        assert value is not None
         sys.last_type = type
         sys.last_value = value
         sys.last_traceback = tb
@@ -138,7 +146,7 @@ class SynapseManholeInterpreter(ManholeInterpreter):
         lines = traceback.format_exception_only(type, value)
         self.write("".join(lines))
 
-    def showtraceback(self):
+    def showtraceback(self) -> None:
         """Display the exception that just occurred.
 
         Overrides the base implementation, ignoring sys.excepthook. We always want
@@ -146,14 +154,22 @@ class SynapseManholeInterpreter(ManholeInterpreter):
         """
         sys.last_type, sys.last_value, last_tb = ei = sys.exc_info()
         sys.last_traceback = last_tb
+        assert last_tb is not None
+
         try:
             # We remove the first stack item because it is our own code.
             lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next)
             self.write("".join(lines))
         finally:
-            last_tb = ei = None
-
-    def displayhook(self, obj):
+            # On the line below, last_tb and ei appear to be dead.
+            # It's unclear whether there is a reason behind this line.
+            # It conceivably could be because an exception raised in this block
+            # will keep the local frame (containing these local variables) around.
+            # This was adapted from CPython's Lib/code.py; see here:
+            # https://github.com/python/cpython/blob/4dc4300c686f543d504ab6fa9fe600eaf11bb695/Lib/code.py#L131-L150
+            last_tb = ei = None  # type: ignore
+
+    def displayhook(self, obj: Any) -> None:
         """
         We override the displayhook so that we automatically convert coroutines
         into Deferreds. (Our superclass' displayhook will take care of the rest,

+ 2 - 2
synapse/util/patch_inline_callbacks.py

@@ -24,7 +24,7 @@ from twisted.python.failure import Failure
 _already_patched = False
 
 
-def do_patch():
+def do_patch() -> None:
     """
     Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
     """
@@ -107,7 +107,7 @@ def do_patch():
     _already_patched = True
 
 
-def _check_yield_points(f: Callable, changes: List[str]):
+def _check_yield_points(f: Callable, changes: List[str]) -> Callable:
     """Wraps a generator that is about to be passed to defer.inlineCallbacks
     checking that after every yield the log contexts are correct.
 

+ 31 - 26
synapse/util/ratelimitutils.py

@@ -15,33 +15,36 @@
 import collections
 import contextlib
 import logging
+import typing
+from typing import Any, DefaultDict, Iterator, List, Set
 
 from twisted.internet import defer
 
 from synapse.api.errors import LimitExceededError
+from synapse.config.ratelimiting import FederationRateLimitConfig
 from synapse.logging.context import (
     PreserveLoggingContext,
     make_deferred_yieldable,
     run_in_background,
 )
+from synapse.util import Clock
+
+if typing.TYPE_CHECKING:
+    from contextlib import _GeneratorContextManager
 
 logger = logging.getLogger(__name__)
 
 
 class FederationRateLimiter:
-    def __init__(self, clock, config):
-        """
-        Args:
-            clock (Clock)
-            config (FederationRateLimitConfig)
-        """
-
-        def new_limiter():
+    def __init__(self, clock: Clock, config: FederationRateLimitConfig):
+        def new_limiter() -> "_PerHostRatelimiter":
             return _PerHostRatelimiter(clock=clock, config=config)
 
-        self.ratelimiters = collections.defaultdict(new_limiter)
+        self.ratelimiters: DefaultDict[
+            str, "_PerHostRatelimiter"
+        ] = collections.defaultdict(new_limiter)
 
-    def ratelimit(self, host):
+    def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
         """Used to ratelimit an incoming request from a given host
 
         Example usage:
@@ -60,11 +63,11 @@ class FederationRateLimiter:
 
 
 class _PerHostRatelimiter:
-    def __init__(self, clock, config):
+    def __init__(self, clock: Clock, config: FederationRateLimitConfig):
         """
         Args:
-            clock (Clock)
-            config (FederationRateLimitConfig)
+            clock
+            config
         """
         self.clock = clock
 
@@ -75,21 +78,23 @@ class _PerHostRatelimiter:
         self.concurrent_requests = config.concurrent
 
         # request_id objects for requests which have been slept
-        self.sleeping_requests = set()
+        self.sleeping_requests: Set[object] = set()
 
         # map from request_id object to Deferred for requests which are ready
         # for processing but have been queued
-        self.ready_request_queue = collections.OrderedDict()
+        self.ready_request_queue: collections.OrderedDict[
+            object, defer.Deferred[None]
+        ] = collections.OrderedDict()
 
         # request id objects for requests which are in progress
-        self.current_processing = set()
+        self.current_processing: Set[object] = set()
 
         # times at which we have recently (within the last window_size ms)
         # received requests.
-        self.request_times = []
+        self.request_times: List[int] = []
 
     @contextlib.contextmanager
-    def ratelimit(self):
+    def ratelimit(self) -> "Iterator[defer.Deferred[None]]":
         # `contextlib.contextmanager` takes a generator and turns it into a
         # context manager. The generator should only yield once with a value
         # to be returned by manager.
@@ -102,7 +107,7 @@ class _PerHostRatelimiter:
         finally:
             self._on_exit(request_id)
 
-    def _on_enter(self, request_id):
+    def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
         time_now = self.clock.time_msec()
 
         # remove any entries from request_times which aren't within the window
@@ -120,9 +125,9 @@ class _PerHostRatelimiter:
 
         self.request_times.append(time_now)
 
-        def queue_request():
+        def queue_request() -> "defer.Deferred[None]":
             if len(self.current_processing) >= self.concurrent_requests:
-                queue_defer = defer.Deferred()
+                queue_defer: defer.Deferred[None] = defer.Deferred()
                 self.ready_request_queue[request_id] = queue_defer
                 logger.info(
                     "Ratelimiter: queueing request (queue now %i items)",
@@ -145,7 +150,7 @@ class _PerHostRatelimiter:
 
             self.sleeping_requests.add(request_id)
 
-            def on_wait_finished(_):
+            def on_wait_finished(_: Any) -> "defer.Deferred[None]":
                 logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
                 self.sleeping_requests.discard(request_id)
                 queue_defer = queue_request()
@@ -155,19 +160,19 @@ class _PerHostRatelimiter:
         else:
             ret_defer = queue_request()
 
-        def on_start(r):
+        def on_start(r: object) -> object:
             logger.debug("Ratelimit [%s]: Processing req", id(request_id))
             self.current_processing.add(request_id)
             return r
 
-        def on_err(r):
+        def on_err(r: object) -> object:
             # XXX: why is this necessary? this is called before we start
             # processing the request so why would the request be in
             # current_processing?
             self.current_processing.discard(request_id)
             return r
 
-        def on_both(r):
+        def on_both(r: object) -> object:
             # Ensure that we've properly cleaned up.
             self.sleeping_requests.discard(request_id)
             self.ready_request_queue.pop(request_id, None)
@@ -177,7 +182,7 @@ class _PerHostRatelimiter:
         ret_defer.addBoth(on_both)
         return make_deferred_yieldable(ret_defer)
 
-    def _on_exit(self, request_id):
+    def _on_exit(self, request_id: object) -> None:
         logger.debug("Ratelimit [%s]: Processed req", id(request_id))
         self.current_processing.discard(request_id)
         try:

+ 42 - 27
synapse/util/retryutils.py

@@ -13,9 +13,13 @@
 # limitations under the License.
 import logging
 import random
+from types import TracebackType
+from typing import Any, Optional, Type
 
 import synapse.logging.context
 from synapse.api.errors import CodeMessageException
+from synapse.storage import DataStore
+from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
 
@@ -30,17 +34,17 @@ MAX_RETRY_INTERVAL = 2 ** 62
 
 
 class NotRetryingDestination(Exception):
-    def __init__(self, retry_last_ts, retry_interval, destination):
+    def __init__(self, retry_last_ts: int, retry_interval: int, destination: str):
         """Raised by the limiter (and federation client) to indicate that we are
         deliberately not attempting to contact a given server.
 
         Args:
-            retry_last_ts (int): the unix ts in milliseconds of our last attempt
+            retry_last_ts: the unix ts in milliseconds of our last attempt
                 to contact the server.  0 indicates that the last attempt was
                 successful or that we've never actually attempted to connect.
-            retry_interval (int): the time in milliseconds to wait until the next
+            retry_interval: the time in milliseconds to wait until the next
                 attempt.
-            destination (str): the domain in question
+            destination: the domain in question
         """
 
         msg = "Not retrying server %s." % (destination,)
@@ -51,7 +55,13 @@ class NotRetryingDestination(Exception):
         self.destination = destination
 
 
-async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
+async def get_retry_limiter(
+    destination: str,
+    clock: Clock,
+    store: DataStore,
+    ignore_backoff: bool = False,
+    **kwargs: Any,
+) -> "RetryDestinationLimiter":
     """For a given destination check if we have previously failed to
     send a request there and are waiting before retrying the destination.
     If we are not ready to retry the destination, this will raise a
@@ -60,10 +70,10 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
     CodeMessageException with code < 500)
 
     Args:
-        destination (str): name of homeserver
-        clock (synapse.util.clock): timing source
-        store (synapse.storage.transactions.TransactionStore): datastore
-        ignore_backoff (bool): true to ignore the historical backoff data and
+        destination: name of homeserver
+        clock: timing source
+        store: datastore
+        ignore_backoff: true to ignore the historical backoff data and
             try the request anyway. We will still reset the retry_interval on success.
 
     Example usage:
@@ -114,13 +124,13 @@ async def get_retry_limiter(destination, clock, store, ignore_backoff=False, **k
 class RetryDestinationLimiter:
     def __init__(
         self,
-        destination,
-        clock,
-        store,
-        failure_ts,
-        retry_interval,
-        backoff_on_404=False,
-        backoff_on_failure=True,
+        destination: str,
+        clock: Clock,
+        store: DataStore,
+        failure_ts: Optional[int],
+        retry_interval: int,
+        backoff_on_404: bool = False,
+        backoff_on_failure: bool = True,
     ):
         """Marks the destination as "down" if an exception is thrown in the
         context, except for CodeMessageException with code < 500.
@@ -128,17 +138,17 @@ class RetryDestinationLimiter:
         If no exception is raised, marks the destination as "up".
 
         Args:
-            destination (str)
-            clock (Clock)
-            store (DataStore)
-            failure_ts (int|None): when this destination started failing (in ms since
+            destination
+            clock
+            store
+            failure_ts: when this destination started failing (in ms since
                 the epoch), or zero if the last request was successful
-            retry_interval (int): The next retry interval taken from the
+            retry_interval: The next retry interval taken from the
                 database in milliseconds, or zero if the last request was
                 successful.
-            backoff_on_404 (bool): Back off if we get a 404
+            backoff_on_404: Back off if we get a 404
 
-            backoff_on_failure (bool): set to False if we should not increase the
+            backoff_on_failure: set to False if we should not increase the
                 retry interval on a failure.
         """
         self.clock = clock
@@ -150,10 +160,15 @@ class RetryDestinationLimiter:
         self.backoff_on_404 = backoff_on_404
         self.backoff_on_failure = backoff_on_failure
 
-    def __enter__(self):
+    def __enter__(self) -> None:
         pass
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
         valid_err_code = False
         if exc_type is None:
             valid_err_code = True
@@ -161,7 +176,7 @@ class RetryDestinationLimiter:
             # avoid treating exceptions which don't derive from Exception as
             # failures; this is mostly so as not to catch defer._DefGen.
             valid_err_code = True
-        elif issubclass(exc_type, CodeMessageException):
+        elif isinstance(exc_val, CodeMessageException):
             # Some error codes are perfectly fine for some APIs, whereas other
             # APIs may expect to never receive e.g. a 404. It's important to
             # handle 404 as some remote servers will return a 404 when the HS
@@ -216,7 +231,7 @@ class RetryDestinationLimiter:
             if self.failure_ts is None:
                 self.failure_ts = retry_last_ts
 
-        async def store_retry_timings():
+        async def store_retry_timings() -> None:
             try:
                 await self.store.set_destination_retry_timings(
                     self.destination,

+ 1 - 1
synapse/util/rlimit.py

@@ -18,7 +18,7 @@ import resource
 logger = logging.getLogger("synapse.app.homeserver")
 
 
-def change_resource_limit(soft_file_no):
+def change_resource_limit(soft_file_no: int) -> None:
     try:
         soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
 

+ 4 - 4
synapse/util/templates.py

@@ -16,7 +16,7 @@
 
 import time
 import urllib.parse
-from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
+from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union
 
 import jinja2
 
@@ -25,9 +25,9 @@ if TYPE_CHECKING:
 
 
 def build_jinja_env(
-    template_search_directories: Iterable[str],
+    template_search_directories: Sequence[str],
     config: "HomeServerConfig",
-    autoescape: Union[bool, Callable[[str], bool], None] = None,
+    autoescape: Union[bool, Callable[[Optional[str]], bool], None] = None,
 ) -> jinja2.Environment:
     """Set up a Jinja2 environment to load templates from the given search path
 
@@ -110,5 +110,5 @@ def _create_mxc_to_http_filter(
     return mxc_to_http_filter
 
 
-def _format_ts_filter(value: int, format: str):
+def _format_ts_filter(value: int, format: str) -> str:
     return time.strftime(format, time.localtime(value / 1000))

+ 8 - 4
synapse/util/threepids.py

@@ -14,6 +14,10 @@
 
 import logging
 import re
+import typing
+
+if typing.TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -28,13 +32,13 @@ logger = logging.getLogger(__name__)
 MAX_EMAIL_ADDRESS_LENGTH = 500
 
 
-def check_3pid_allowed(hs, medium, address):
+def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool:
     """Checks whether a given format of 3PID is allowed to be used on this HS
 
     Args:
-        hs (synapse.server.HomeServer): server
-        medium (str): 3pid medium - e.g. email, msisdn
-        address (str): address within that medium (e.g. "wotan@matrix.org")
+        hs: server
+        medium: 3pid medium - e.g. email, msisdn
+        address: address within that medium (e.g. "wotan@matrix.org")
             msisdns need to first have been canonicalised
     Returns:
         bool: whether the 3PID medium/address is allowed to be added to this HS

+ 1 - 1
synapse/util/versionstring.py

@@ -19,7 +19,7 @@ import subprocess
 logger = logging.getLogger(__name__)
 
 
-def get_version_string(module):
+def get_version_string(module) -> str:
     """Given a module calculate a git-aware version string for it.
 
     If called on a module not in a git checkout will return `__version__`.

+ 19 - 16
synapse/util/wheel_timer.py

@@ -11,38 +11,41 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Generic, List, TypeVar
 
+T = TypeVar("T")
 
-class _Entry:
+
+class _Entry(Generic[T]):
     __slots__ = ["end_key", "queue"]
 
-    def __init__(self, end_key):
-        self.end_key = end_key
-        self.queue = []
+    def __init__(self, end_key: int) -> None:
+        self.end_key: int = end_key
+        self.queue: List[T] = []
 
 
-class WheelTimer:
+class WheelTimer(Generic[T]):
     """Stores arbitrary objects that will be returned after their timers have
     expired.
     """
 
-    def __init__(self, bucket_size=5000):
+    def __init__(self, bucket_size: int = 5000) -> None:
         """
         Args:
-            bucket_size (int): Size of buckets in ms. Corresponds roughly to the
+            bucket_size: Size of buckets in ms. Corresponds roughly to the
                 accuracy of the timer.
         """
-        self.bucket_size = bucket_size
-        self.entries = []
-        self.current_tick = 0
+        self.bucket_size: int = bucket_size
+        self.entries: List[_Entry[T]] = []
+        self.current_tick: int = 0
 
-    def insert(self, now, obj, then):
+    def insert(self, now: int, obj: T, then: int) -> None:
         """Inserts object into timer.
 
         Args:
-            now (int): Current time in msec
-            obj (object): Object to be inserted
-            then (int): When to return the object strictly after.
+            now: Current time in msec
+            obj: Object to be inserted
+            then: When to return the object strictly after.
         """
         then_key = int(then / self.bucket_size) + 1
 
@@ -70,7 +73,7 @@ class WheelTimer:
 
         self.entries[-1].queue.append(obj)
 
-    def fetch(self, now):
+    def fetch(self, now: int) -> List[T]:
         """Fetch any objects that have timed out
 
         Args:
@@ -87,5 +90,5 @@ class WheelTimer:
 
         return ret
 
-    def __len__(self):
+    def __len__(self) -> int:
         return sum(len(entry.queue) for entry in self.entries)

+ 2 - 2
tests/unittest.py

@@ -734,9 +734,9 @@ class TestTransportLayerServer(JsonResource):
             FederationRateLimitConfig(
                 window_size=1,
                 sleep_limit=1,
-                sleep_msec=1,
+                sleep_delay=1,
                 reject_limit=1000,
-                concurrent_requests=1000,
+                concurrent=1000,
             ),
         )