Pārlūkot izejas kodu

Make `sydent.terms` pass `mypy --strict` (#428)

Shout out to getForClient, which a) mixes presentation with data and b)
was a massive PITA to type hint. It's very very stringly typed.

* Explicitly check for a string master_version
David Robertson 2 gadi atpakaļ
vecāks
revīzija
5de734962a
3 mainītis faili ar 48 papildinājumiem un 26 dzēšanām
  1. 1 0
      changelog.d/428.misc
  2. 1 0
      pyproject.toml
  3. 46 26
      sydent/terms/terms.py

+ 1 - 0
changelog.d/428.misc

@@ -0,0 +1 @@
+Make `sydent.terms` pass `mypy --strict`.

+ 1 - 0
pyproject.toml

@@ -51,6 +51,7 @@ files = [
     #     find sydent tests -type d -not -name __pycache__ -exec bash -c "mypy --strict '{}' > /dev/null"  \; -print
     "sydent/config",
     "sydent/db",
+    "sydent/terms",
     "sydent/threepid",
     "sydent/users",
     "sydent/util",

+ 46 - 26
sydent/terms/terms.py

@@ -13,9 +13,10 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Union
 
 import yaml
+from typing_extensions import TypedDict
 
 if TYPE_CHECKING:
     from sydent.sydent import Sydent
@@ -23,8 +24,26 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class TermConfig(TypedDict):
+    master_version: str
+    docs: Mapping[str, "Policy"]
+
+
+class Policy(TypedDict):
+    version: str
+    langs: Mapping[str, "LocalisedPolicy"]
+
+
+class LocalisedPolicy(TypedDict):
+    name: str
+    url: str
+
+
+VersionOrLang = Union[str, LocalisedPolicy]
+
+
 class Terms:
-    def __init__(self, yamlObj: Optional[Dict[str, Any]]) -> None:
+    def __init__(self, yamlObj: Optional[TermConfig]) -> None:
         """
         :param yamlObj: The parsed YAML.
         """
@@ -35,20 +54,19 @@ class Terms:
         :return: The global (master) version of the terms, or None if there
             are no terms of service for this server.
         """
-        version = None if self._rawTerms is None else self._rawTerms["master_version"]
-
-        # Ensure we're dealing with unicode.
-        if version and isinstance(version, bytes):
-            version = version.decode("UTF-8")
-
-        return version
-
-    def getForClient(self) -> Dict[str, dict]:
+        if self._rawTerms is None:
+            return None
+        return self._rawTerms["master_version"]
+
+    def getForClient(self) -> Dict[str, Dict[str, Dict[str, VersionOrLang]]]:
+        # Examples:
+        # "policy" -> "terms_of_service", "version" -> "1.2.3"
+        # "policy" -> "terms_of_service", "en" -> LocalisedPolicy
         """
         :return: A dict which value for the "policies" key is a dict which contains the
             "docs" part of the terms' YAML. That nested dict is empty if no terms.
         """
-        policies = {}
+        policies: Dict[str, Dict[str, VersionOrLang]] = {}
         if self._rawTerms is not None:
             for docName, doc in self._rawTerms["docs"].items():
                 policies[docName] = {
@@ -66,11 +84,6 @@ class Terms:
             for docName, doc in self._rawTerms["docs"].items():
                 for langName, lang in doc["langs"].items():
                     url = lang["url"]
-
-                    # Ensure we're dealing with unicode.
-                    if url and isinstance(url, bytes):
-                        url = url.decode("UTF-8")
-
                     urls.add(url)
         return urls
 
@@ -87,36 +100,43 @@ class Terms:
         agreed = set()
         urlset = set(urls)
 
-        if self._rawTerms is not None:
+        if self._rawTerms is None:
+            if urls:
+                raise ValueError("No configured terms, but user accepted some terms")
+            else:
+                return True
+
+        else:
             for docName, doc in self._rawTerms["docs"].items():
                 for lang in doc["langs"].values():
                     if lang["url"] in urlset:
                         agreed.add(docName)
                         break
 
-        required = set(self._rawTerms["docs"].keys())
-        return agreed == required
+            required = set(self._rawTerms["docs"].keys())
+            return agreed == required
 
 
 def get_terms(sydent: "Sydent") -> Optional[Terms]:
-    """Read and parse terms as specified in the config.
-
-    :returns Terms
-    """
+    """Read and parse terms as specified in the config."""
     # TODO - move some of this to parse_config
 
     termsPath = sydent.config.general.terms_path
 
     try:
-        termsYaml = None
-
         if termsPath == "":
             return Terms(None)
 
         with open(termsPath) as fp:
             termsYaml = yaml.safe_load(fp)
+
+        # TODO use something like jsonschema instead of this handwritten code.
         if "master_version" not in termsYaml:
             raise Exception("No master version")
+        elif not isinstance(termsYaml["master_version"], str):
+            raise TypeError(
+                f"master_version should be a string, not {termsYaml['master_version']!r}"
+            )
         if "docs" not in termsYaml:
             raise Exception("No 'docs' key in terms")
         for docName, doc in termsYaml["docs"].items():