Browse Source

Make it possible to use dmypy (#9692)

Running `dmypy run` will do a `mypy` check while spinning up a daemon
that makes rerunning `dmypy run` a lot faster.

`dmypy` doesn't support `follow_imports = silent` and has
`local_partial_types` enabled, so this PR enables those options and
fixes the issues that were newly raised. Note that `local_partial_types`
will be enabled by default in upcoming mypy releases.
Erik Johnston 3 years ago
parent
commit
b5efcb577e

+ 1 - 0
changelog.d/9692.misc

@@ -0,0 +1 @@
+Make it possible to use `dmypy`.

+ 2 - 1
mypy.ini

@@ -1,12 +1,13 @@
 [mypy]
 namespace_packages = True
 plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
-follow_imports = silent
+follow_imports = normal
 check_untyped_defs = True
 show_error_codes = True
 show_traceback = True
 mypy_path = stubs
 warn_unreachable = True
+local_partial_types = True
 
 # To find all folders that pass mypy you run:
 #

+ 5 - 0
synapse/api/auth.py

@@ -558,6 +558,9 @@ class Auth:
         Returns:
             bool: False if no access_token was given, True otherwise.
         """
+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
+
         query_params = request.args.get(b"access_token")
         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
         return bool(query_params) or bool(auth_headers)
@@ -574,6 +577,8 @@ class Auth:
             MissingClientTokenError: If there isn't a single access_token in the
                 request
         """
+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
 
         auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
         query_params = request.args.get(b"access_token")

+ 4 - 2
synapse/config/cache.py

@@ -24,7 +24,7 @@ from ._base import Config, ConfigError
 _CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR"
 
 # Map from canonicalised cache name to cache.
-_CACHES = {}
+_CACHES = {}  # type: Dict[str, Callable[[float], None]]
 
 # a lock on the contents of _CACHES
 _CACHES_LOCK = threading.Lock()
@@ -59,7 +59,9 @@ def _canonicalise_cache_name(cache_name: str) -> str:
     return cache_name.lower()
 
 
-def add_resizable_cache(cache_name: str, cache_resize_callback: Callable):
+def add_resizable_cache(
+    cache_name: str, cache_resize_callback: Callable[[float], None]
+):
     """Register a cache that's size can dynamically change
 
     Args:

+ 3 - 0
synapse/handlers/oidc_handler.py

@@ -149,6 +149,9 @@ class OidcHandler:
         Args:
             request: the incoming request from the browser.
         """
+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
+
         # The provider might redirect with an error.
         # In that case, just display it as-is.
         if b"error" in request.args:

+ 1 - 1
synapse/logging/opentracing.py

@@ -262,7 +262,7 @@ logger = logging.getLogger(__name__)
 # Block everything by default
 # A regex which matches the server_names to expose traces for.
 # None means 'block everything'.
-_homeserver_whitelist = None
+_homeserver_whitelist = None  # type: Optional[re.Pattern[str]]
 
 # Util methods
 

+ 1 - 1
synapse/replication/tcp/protocol.py

@@ -104,7 +104,7 @@ tcp_outbound_commands_counter = Counter(
 
 # A list of all connected protocols. This allows us to send metrics about the
 # connections.
-connected_connections = []
+connected_connections = []  # type: List[BaseReplicationStreamProtocol]
 
 
 logger = logging.getLogger(__name__)

+ 3 - 0
synapse/rest/admin/rooms.py

@@ -390,6 +390,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
     async def on_POST(
         self, request: SynapseRequest, room_identifier: str
     ) -> Tuple[int, JsonDict]:
+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
+
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 

+ 3 - 0
synapse/rest/admin/users.py

@@ -833,6 +833,9 @@ class UserMediaRestServlet(RestServlet):
     async def on_GET(
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
+
         await assert_requester_is_admin(self.auth, request)
 
         if not self.is_mine(UserID.from_string(user_id)):

+ 3 - 0
synapse/rest/client/v2_alpha/sync.py

@@ -91,6 +91,9 @@ class SyncRestServlet(RestServlet):
         self._event_serializer = hs.get_event_client_serializer()
 
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
+
         if b"from" in request.args:
             # /events used to use 'from', but /sync uses 'since'.
             # Lets be helpful and whine if we see a 'from'.

+ 2 - 0
synapse/rest/media/v1/preview_url_resource.py

@@ -187,6 +187,8 @@ class PreviewUrlResource(DirectServeJsonResource):
         respond_with_json(request, 200, {}, send_cors=True)
 
     async def _async_render_GET(self, request: SynapseRequest) -> None:
+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
 
         # XXX: if get_user_by_req fails, what should we do in an async render?
         requester = await self.auth.get_user_by_req(request)

+ 3 - 0
synapse/rest/synapse/client/pick_username.py

@@ -104,6 +104,9 @@ class AccountDetailsResource(DirectServeHtmlResource):
         respond_with_html(request, 200, html)
 
     async def _async_render_POST(self, request: SynapseRequest):
+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
+
         try:
             session_id = get_username_mapping_session_cookie_from_request(request)
         except SynapseError as e:

+ 2 - 2
synapse/util/caches/__init__.py

@@ -25,8 +25,8 @@ from synapse.config.cache import add_resizable_cache
 
 logger = logging.getLogger(__name__)
 
-caches_by_name = {}
-collectors_by_name = {}  # type: Dict
+caches_by_name = {}  # type: Dict[str, Sized]
+collectors_by_name = {}  # type: Dict[str, CacheMetric]
 
 cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
 cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])

+ 1 - 0
tests/replication/tcp/streams/test_typing.py

@@ -69,6 +69,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
         # The from token should be the token from the last RDATA we got.
+        assert request.args is not None
         self.assertEqual(int(request.args[b"from_token"][0]), token)
 
         self.test_handler.on_rdata.assert_called_once()

+ 2 - 2
tests/replication/test_multi_media_repo.py

@@ -15,7 +15,7 @@
 import logging
 import os
 from binascii import unhexlify
-from typing import Tuple
+from typing import Optional, Tuple
 
 from twisted.internet.protocol import Factory
 from twisted.protocols.tls import TLSMemoryBIOFactory
@@ -32,7 +32,7 @@ from tests.server import FakeChannel, FakeSite, FakeTransport, make_request
 
 logger = logging.getLogger(__name__)
 
-test_server_connection_factory = None
+test_server_connection_factory = None  # type: Optional[TestServerTLSConnectionFactory]
 
 
 class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):

+ 20 - 8
tests/server.py

@@ -2,7 +2,7 @@ import json
 import logging
 from collections import deque
 from io import SEEK_END, BytesIO
-from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union
+from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
 
 import attr
 from typing_extensions import Deque
@@ -13,8 +13,11 @@ from twisted.internet._resolver import SimpleResolverComplexifier
 from twisted.internet.defer import Deferred, fail, succeed
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import (
+    IHostnameResolver,
+    IProtocol,
+    IPullProducer,
+    IPushProducer,
     IReactorPluggableNameResolver,
-    IReactorTCP,
     IResolverSimple,
     ITransport,
 )
@@ -45,11 +48,11 @@ class FakeChannel:
     wire).
     """
 
-    site = attr.ib(type=Site)
+    site = attr.ib(type=Union[Site, "FakeSite"])
     _reactor = attr.ib()
     result = attr.ib(type=dict, default=attr.Factory(dict))
     _ip = attr.ib(type=str, default="127.0.0.1")
-    _producer = None
+    _producer = None  # type: Optional[Union[IPullProducer, IPushProducer]]
 
     @property
     def json_body(self):
@@ -159,7 +162,11 @@ class FakeChannel:
 
         Any cookines found are added to the given dict
         """
-        for h in self.headers.getRawHeaders("Set-Cookie"):
+        headers = self.headers.getRawHeaders("Set-Cookie")
+        if not headers:
+            return
+
+        for h in headers:
             parts = h.split(";")
             k, v = parts[0].split("=", maxsplit=1)
             cookies[k] = v
@@ -311,8 +318,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
         self._tcp_callbacks = {}
         self._udp = []
-        lookups = self.lookups = {}
-        self._thread_callbacks = deque()  # type: Deque[Callable[[], None]]()
+        lookups = self.lookups = {}  # type: Dict[str, str]
+        self._thread_callbacks = deque()  # type: Deque[Callable[[], None]]
 
         @implementer(IResolverSimple)
         class FakeResolver:
@@ -324,6 +331,9 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         self.nameResolver = SimpleResolverComplexifier(FakeResolver())
         super().__init__()
 
+    def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
+        raise NotImplementedError()
+
     def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
         p = udp.Port(port, protocol, interface, maxPacketSize, self)
         p.startListening()
@@ -621,7 +631,9 @@ class FakeTransport:
             self.disconnected = True
 
 
-def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
+def connect_client(
+    reactor: ThreadedMemoryReactorClock, client_id: int
+) -> Tuple[IProtocol, AccumulatingProtocol]:
     """
     Connect a client to a fake TCP transport.