Selaa lähdekoodia

Make sydent.util pass `mypy --strict` (#418)

* Introduce a stubs directory and pull in new stub packages
* Allow newer versions of unpaddedbase64, so that we can use 2.1.0's type hints in CI
David Robertson 2 vuotta sitten
vanhempi
commit
4989ec00f4

+ 1 - 0
changelog.d/418.misc

@@ -0,0 +1 @@
+Add type hints to `sydent.util`.

+ 2 - 2
pyproject.toml

@@ -50,25 +50,25 @@ files = [
     # Find files that pass with
     #     find sydent tests -type d -not -name __pycache__ -exec bash -c "mypy --strict '{}' > /dev/null"  \; -print
     "sydent/users",
+    "sydent/util",
     # TODO the rest of CI checks these---mypy ought to too.
     # "tests",
     # "matrix_is_test",
     # "scripts",
     # "setup.py",
 ]
+mypy_path = "stubs"
 
 [[tool.mypy.overrides]]
 module = [
     "idna",
     "nacl.*",
     "netaddr",
-    "OpenSSL",
     "prometheus_client",
     "phonenumbers",
     "sentry_sdk",
     "signedjson.*",
     "sortedcontainers",
-    "unpaddedbase64",
 ]
 ignore_missing_imports = true
 

+ 3 - 1
setup.py

@@ -43,7 +43,7 @@ setup(
     install_requires=[
         "jinja2>=3.0.0",
         "signedjson==1.1.1",
-        "unpaddedbase64==1.1.0",
+        "unpaddedbase64>=1.1.0",
         "Twisted>=18.4.0",
         # twisted warns about about the absence of this
         "service_identity>=1.0.0",
@@ -61,6 +61,8 @@ setup(
             "isort==5.8.0",
             "mypy>=0.902",
             "mypy-zope>=0.3.1",
+            "types-Jinja2",
+            "types-PyOpenSSL",
             "types-PyYAML",
             "types-mock",
         ],

+ 0 - 0
stubs/twisted/__init__.pyi


+ 0 - 0
stubs/twisted/python/__init__.pyi


+ 11 - 0
stubs/twisted/python/log.pyi

@@ -0,0 +1,11 @@
+from typing import Optional, Union, Any
+
+from twisted.python.failure import Failure
+
+
+def err(
+    _stuff: Union[None, Exception, Failure] = None,
+    _why: Optional[str] = None,
+    **kw: Any,
+)-> None:
+    ...

+ 3 - 3
sydent/util/__init__.py

@@ -14,19 +14,19 @@
 
 import json
 import time
+from typing import NoReturn
 
 
-def time_msec():
+def time_msec() -> int:
     """
     Get the current time in milliseconds.
 
     :return: The current time in milliseconds.
-    :rtype: int
     """
     return int(time.time() * 1000)
 
 
-def _reject_invalid_json(val):
+def _reject_invalid_json(val: str) -> NoReturn:
     """Do not allow Infinity, -Infinity, or NaN values in JSON."""
     raise ValueError("Invalid JSON value: '%s'" % val)
 

+ 1 - 1
sydent/util/emailutils.py

@@ -93,7 +93,7 @@ def sendEmail(
         raise EmailAddressException()
 
     mailServer = sydent.config.email.smtp_server
-    mailPort = sydent.config.email.smtp_port
+    mailPort = int(sydent.config.email.smtp_port)
     mailUsername = sydent.config.email.smtp_username
     mailPassword = sydent.config.email.smtp_password
     mailTLSMode = sydent.config.email.tls_mode

+ 1 - 1
sydent/util/hash.py

@@ -14,7 +14,7 @@
 
 import hashlib
 
-import unpaddedbase64  # type: ignore
+import unpaddedbase64
 
 
 def sha256_and_url_safe_base64(input_text: str) -> str:

+ 1 - 1
sydent/util/ip_range.py

@@ -14,7 +14,7 @@
 import itertools
 from typing import Iterable, Optional
 
-from netaddr import AddrFormatError, IPNetwork, IPSet  # type: ignore
+from netaddr import AddrFormatError, IPNetwork, IPSet
 
 # IP ranges that are considered private / unroutable / don't make sense.
 DEFAULT_IP_RANGE_BLACKLIST = [

+ 1 - 1
sydent/util/stringutils.py

@@ -128,7 +128,7 @@ def is_valid_matrix_server_name(string: str) -> bool:
     return valid_ipv4_addr or valid_ipv6_literal or is_valid_hostname(host)
 
 
-def normalise_address(address, medium):
+def normalise_address(address: str, medium: str) -> str:
     if medium == "email":
         return address.casefold()
     else:

+ 42 - 32
sydent/util/ttlcache.py

@@ -12,52 +12,60 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import enum
 import logging
 import time
-from typing import Any, Tuple
+from typing import Callable, Dict, Generic, Tuple, TypeVar, Union
 
 import attr
-from sortedcontainers import SortedList  # type: ignore
+from sortedcontainers import SortedList
+from typing_extensions import Literal
 
 logger = logging.getLogger(__name__)
 
-SENTINEL = object()
 
+class Sentinel(enum.Enum):
+    token = enum.auto()
 
-class TTLCache:
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+class TTLCache(Generic[K, V]):
     """A key/value cache implementation where each entry has its own TTL"""
 
-    def __init__(self, cache_name, timer=time.time):
-        # map from key to _CacheEntry
-        self._data = {}
+    def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
+        self._data: Dict[K, _CacheEntry[K, V]] = {}
 
         # the _CacheEntries, sorted by expiry time
-        self._expiry_list = SortedList()
+        self._expiry_list: SortedList[_CacheEntry] = SortedList()
 
         self._timer = timer
 
-    def set(self, key, value, ttl: float) -> None:
+    def set(self, key: K, value: V, ttl: float) -> None:
         """Add/update an entry in the cache
 
         :param key: Key for this entry.
 
         :param value: Value for this entry.
 
-        :param paramttl: TTL for this entry, in seconds.
-        :type paramttl: float
+        :param ttl: TTL for this entry, in seconds.
         """
         expiry = self._timer() + ttl
 
         self.expire()
-        e = self._data.pop(key, SENTINEL)
-        if e != SENTINEL:
+        e = self._data.pop(key, Sentinel.token)
+        if e != Sentinel.token:
             self._expiry_list.remove(e)
 
         entry = _CacheEntry(expiry_time=expiry, key=key, value=value)
         self._data[key] = entry
         self._expiry_list.add(entry)
 
-    def get(self, key, default=SENTINEL):
+    def get(
+        self, key: K, default: Union[V, Literal[Sentinel.token]] = Sentinel.token
+    ) -> V:
         """Get a value from the cache
 
         :param key: The key to look up.
@@ -67,14 +75,14 @@ class TTLCache:
         :returns a value from the cache, or the default.
         """
         self.expire()
-        e = self._data.get(key, SENTINEL)
-        if e == SENTINEL:
-            if default == SENTINEL:
+        e = self._data.get(key, Sentinel.token)
+        if e is Sentinel.token:
+            if default is Sentinel.token:
                 raise KeyError(key)
             return default
         return e.value
 
-    def get_with_expiry(self, key) -> Tuple[Any, float]:
+    def get_with_expiry(self, key: K) -> Tuple[V, float]:
         """Get a value, and its expiry time, from the cache
 
         :param key: key to look up
@@ -92,7 +100,9 @@ class TTLCache:
             raise
         return e.value, e.expiry_time
 
-    def pop(self, key, default=SENTINEL):
+    def pop(
+        self, key: K, default: Union[V, Literal[Sentinel.token]] = Sentinel.token
+    ) -> V:
         """Remove a value from the cache
 
         If key is in the cache, remove it and return its value, else return default.
@@ -105,28 +115,28 @@ class TTLCache:
         :returns a value from the cache, or the default
         """
         self.expire()
-        e = self._data.pop(key, SENTINEL)
-        if e == SENTINEL:
-            if default == SENTINEL:
+        e = self._data.pop(key, Sentinel.token)
+        if e is Sentinel.token:
+            if default == Sentinel.token:
                 raise KeyError(key)
             return default
         self._expiry_list.remove(e)
         return e.value
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: K) -> V:
         return self.get(key)
 
-    def __delitem__(self, key):
+    def __delitem__(self, key: K) -> None:
         self.pop(key)
 
-    def __contains__(self, key):
+    def __contains__(self, key: K) -> bool:
         return key in self._data
 
-    def __len__(self):
+    def __len__(self) -> int:
         self.expire()
         return len(self._data)
 
-    def expire(self):
+    def expire(self) -> None:
         """Run the expiry on the cache. Any entries whose expiry times are due will
         be removed
         """
@@ -139,11 +149,11 @@ class TTLCache:
             del self._expiry_list[0]
 
 
-@attr.s(frozen=True, slots=True)
-class _CacheEntry:
+@attr.s(frozen=True)
+class _CacheEntry(Generic[K, V]):
     """TTLCache entry"""
 
     # expiry_time is the first attribute, so that entries are sorted by expiry.
-    expiry_time = attr.ib()
-    key = attr.ib()
-    value = attr.ib()
+    expiry_time: float = attr.ib()
+    key: K = attr.ib()
+    value: V = attr.ib()