Browse Source

Make `sydent.terms` pass `mypy --strict` (#428)

Shout out to getForClient, which a) mixes presentation with data and b)
was a massive PITA to type hint. It's very very stringly typed.

* Explicitly check for a string master_version
David Robertson 2 years ago
parent
commit
5de734962a
3 changed files with 48 additions and 26 deletions
  1. 1 0
      changelog.d/428.misc
  2. 1 0
      pyproject.toml
  3. 46 26
      sydent/terms/terms.py

+ 1 - 0
changelog.d/428.misc

@@ -0,0 +1 @@
+Make `sydent.terms` pass `mypy --strict`.

+ 1 - 0
pyproject.toml

@@ -51,6 +51,7 @@ files = [
     #     find sydent tests -type d -not -name __pycache__ -exec bash -c "mypy --strict '{}' > /dev/null"  \; -print
     #     find sydent tests -type d -not -name __pycache__ -exec bash -c "mypy --strict '{}' > /dev/null"  \; -print
     "sydent/config",
     "sydent/config",
     "sydent/db",
     "sydent/db",
+    "sydent/terms",
     "sydent/threepid",
     "sydent/threepid",
     "sydent/users",
     "sydent/users",
     "sydent/util",
     "sydent/util",

+ 46 - 26
sydent/terms/terms.py

@@ -13,9 +13,10 @@
 # limitations under the License.
 # limitations under the License.
 
 
 import logging
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Union
 
 
 import yaml
 import yaml
+from typing_extensions import TypedDict
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from sydent.sydent import Sydent
     from sydent.sydent import Sydent
@@ -23,8 +24,26 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
+class TermConfig(TypedDict):
+    master_version: str
+    docs: Mapping[str, "Policy"]
+
+
+class Policy(TypedDict):
+    version: str
+    langs: Mapping[str, "LocalisedPolicy"]
+
+
+class LocalisedPolicy(TypedDict):
+    name: str
+    url: str
+
+
+VersionOrLang = Union[str, LocalisedPolicy]
+
+
 class Terms:
 class Terms:
-    def __init__(self, yamlObj: Optional[Dict[str, Any]]) -> None:
+    def __init__(self, yamlObj: Optional[TermConfig]) -> None:
         """
         """
         :param yamlObj: The parsed YAML.
         :param yamlObj: The parsed YAML.
         """
         """
@@ -35,20 +54,19 @@ class Terms:
         :return: The global (master) version of the terms, or None if there
         :return: The global (master) version of the terms, or None if there
             are no terms of service for this server.
             are no terms of service for this server.
         """
         """
-        version = None if self._rawTerms is None else self._rawTerms["master_version"]
-
-        # Ensure we're dealing with unicode.
-        if version and isinstance(version, bytes):
-            version = version.decode("UTF-8")
-
-        return version
-
-    def getForClient(self) -> Dict[str, dict]:
+        if self._rawTerms is None:
+            return None
+        return self._rawTerms["master_version"]
+
+    def getForClient(self) -> Dict[str, Dict[str, Dict[str, VersionOrLang]]]:
+        # Examples:
+        # "policy" -> "terms_of_service", "version" -> "1.2.3"
+        # "policy" -> "terms_of_service", "en" -> LocalisedPolicy
         """
         """
         :return: A dict which value for the "policies" key is a dict which contains the
         :return: A dict which value for the "policies" key is a dict which contains the
             "docs" part of the terms' YAML. That nested dict is empty if no terms.
             "docs" part of the terms' YAML. That nested dict is empty if no terms.
         """
         """
-        policies = {}
+        policies: Dict[str, Dict[str, VersionOrLang]] = {}
         if self._rawTerms is not None:
         if self._rawTerms is not None:
             for docName, doc in self._rawTerms["docs"].items():
             for docName, doc in self._rawTerms["docs"].items():
                 policies[docName] = {
                 policies[docName] = {
@@ -66,11 +84,6 @@ class Terms:
             for docName, doc in self._rawTerms["docs"].items():
             for docName, doc in self._rawTerms["docs"].items():
                 for langName, lang in doc["langs"].items():
                 for langName, lang in doc["langs"].items():
                     url = lang["url"]
                     url = lang["url"]
-
-                    # Ensure we're dealing with unicode.
-                    if url and isinstance(url, bytes):
-                        url = url.decode("UTF-8")
-
                     urls.add(url)
                     urls.add(url)
         return urls
         return urls
 
 
@@ -87,36 +100,43 @@ class Terms:
         agreed = set()
         agreed = set()
         urlset = set(urls)
         urlset = set(urls)
 
 
-        if self._rawTerms is not None:
+        if self._rawTerms is None:
+            if urls:
+                raise ValueError("No configured terms, but user accepted some terms")
+            else:
+                return True
+
+        else:
             for docName, doc in self._rawTerms["docs"].items():
             for docName, doc in self._rawTerms["docs"].items():
                 for lang in doc["langs"].values():
                 for lang in doc["langs"].values():
                     if lang["url"] in urlset:
                     if lang["url"] in urlset:
                         agreed.add(docName)
                         agreed.add(docName)
                         break
                         break
 
 
-        required = set(self._rawTerms["docs"].keys())
-        return agreed == required
+            required = set(self._rawTerms["docs"].keys())
+            return agreed == required
 
 
 
 
 def get_terms(sydent: "Sydent") -> Optional[Terms]:
 def get_terms(sydent: "Sydent") -> Optional[Terms]:
-    """Read and parse terms as specified in the config.
-
-    :returns Terms
-    """
+    """Read and parse terms as specified in the config."""
     # TODO - move some of this to parse_config
     # TODO - move some of this to parse_config
 
 
     termsPath = sydent.config.general.terms_path
     termsPath = sydent.config.general.terms_path
 
 
     try:
     try:
-        termsYaml = None
-
         if termsPath == "":
         if termsPath == "":
             return Terms(None)
             return Terms(None)
 
 
         with open(termsPath) as fp:
         with open(termsPath) as fp:
             termsYaml = yaml.safe_load(fp)
             termsYaml = yaml.safe_load(fp)
+
+        # TODO use something like jsonschema instead of this handwritten code.
         if "master_version" not in termsYaml:
         if "master_version" not in termsYaml:
             raise Exception("No master version")
             raise Exception("No master version")
+        elif not isinstance(termsYaml["master_version"], str):
+            raise TypeError(
+                f"master_version should be a string, not {termsYaml['master_version']!r}"
+            )
         if "docs" not in termsYaml:
         if "docs" not in termsYaml:
             raise Exception("No 'docs' key in terms")
             raise Exception("No 'docs' key in terms")
         for docName, doc in termsYaml["docs"].items():
         for docName, doc in termsYaml["docs"].items():