1
0

check_pydantic_models.py 14 KB


  1. #! /usr/bin/env python
  2. # Copyright 2022 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. A script which enforces that Synapse always uses strict types when defining a Pydantic
  17. model.
  18. Pydantic does not yet offer a strict mode, but it is planned for pydantic v2. See
  19. https://github.com/pydantic/pydantic/issues/1098
  20. https://pydantic-docs.helpmanual.io/blog/pydantic-v2/#strict-mode
  21. Until then, this script is a best effort to stop us from introducing type coercion bugs
  22. (like the infamous stringy power levels fixed in room version 10).
  23. """
  24. import argparse
  25. import contextlib
  26. import functools
  27. import importlib
  28. import logging
  29. import os
  30. import pkgutil
  31. import sys
  32. import textwrap
  33. import traceback
  34. import unittest.mock
  35. from contextlib import contextmanager
  36. from typing import Any, Callable, Dict, Generator, List, Set, Type, TypeVar
  37. from parameterized import parameterized
  38. from pydantic import BaseModel as PydanticBaseModel, conbytes, confloat, conint, constr
  39. from pydantic.typing import get_args
  40. from typing_extensions import ParamSpec
  41. logger = logging.getLogger(__name__)
  42. CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: List[Callable] = [
  43. constr,
  44. conbytes,
  45. conint,
  46. confloat,
  47. ]
  48. TYPES_THAT_PYDANTIC_WILL_COERCE_TO = [
  49. str,
  50. bytes,
  51. int,
  52. float,
  53. bool,
  54. ]
  55. P = ParamSpec("P")
  56. R = TypeVar("R")
  57. class ModelCheckerException(Exception):
  58. """Dummy exception. Allows us to detect unwanted types during a module import."""
  59. class MissingStrictInConstrainedTypeException(ModelCheckerException):
  60. factory_name: str
  61. def __init__(self, factory_name: str):
  62. self.factory_name = factory_name
  63. class FieldHasUnwantedTypeException(ModelCheckerException):
  64. message: str
  65. def __init__(self, message: str):
  66. self.message = message
  67. def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]:
  68. """We patch `constr` and friends with wrappers that enforce strict=True."""
  69. @functools.wraps(factory)
  70. def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
  71. if "strict" not in kwargs:
  72. raise MissingStrictInConstrainedTypeException(factory.__name__)
  73. if not kwargs["strict"]:
  74. raise MissingStrictInConstrainedTypeException(factory.__name__)
  75. return factory(*args, **kwargs)
  76. return wrapper
  77. def field_type_unwanted(type_: Any) -> bool:
  78. """Very rough attempt to detect if a type is unwanted as a Pydantic annotation.
  79. At present, we exclude types which will coerce, or any generic type involving types
  80. which will coerce."""
  81. logger.debug("Is %s unwanted?")
  82. if type_ in TYPES_THAT_PYDANTIC_WILL_COERCE_TO:
  83. logger.debug("yes")
  84. return True
  85. logger.debug("Maybe. Subargs are %s", get_args(type_))
  86. rv = any(field_type_unwanted(t) for t in get_args(type_))
  87. logger.debug("Conclusion: %s %s unwanted", type_, "is" if rv else "is not")
  88. return rv
  89. class PatchedBaseModel(PydanticBaseModel):
  90. """A patched version of BaseModel that inspects fields after models are defined.
  91. We complain loudly if we see an unwanted type.
  92. Beware: ModelField.type_ is presumably private; this is likely to be very brittle.
  93. """
  94. @classmethod
  95. def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs: object):
  96. for field in cls.__fields__.values():
  97. # Note that field.type_ and field.outer_type are computed based on the
  98. # annotation type, see pydantic.fields.ModelField._type_analysis
  99. if field_type_unwanted(field.outer_type_):
  100. # TODO: this only reports the first bad field. Can we find all bad ones
  101. # and report them all?
  102. raise FieldHasUnwantedTypeException(
  103. f"{cls.__module__}.{cls.__qualname__} has field '{field.name}' "
  104. f"with unwanted type `{field.outer_type_}`"
  105. )
  106. @contextmanager
  107. def monkeypatch_pydantic() -> Generator[None, None, None]:
  108. """Patch pydantic with our snooping versions of BaseModel and the con* functions.
  109. If the snooping functions see something they don't like, they'll raise a
  110. ModelCheckingException instance.
  111. """
  112. with contextlib.ExitStack() as patches:
  113. # Most Synapse code ought to import the patched objects directly from
  114. # `pydantic`. But we also patch their containing modules `pydantic.main` and
  115. # `pydantic.types` for completeness.
  116. patch_basemodel1 = unittest.mock.patch(
  117. "pydantic.BaseModel", new=PatchedBaseModel
  118. )
  119. patch_basemodel2 = unittest.mock.patch(
  120. "pydantic.main.BaseModel", new=PatchedBaseModel
  121. )
  122. patches.enter_context(patch_basemodel1)
  123. patches.enter_context(patch_basemodel2)
  124. for factory in CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG:
  125. wrapper: Callable = make_wrapper(factory)
  126. patch1 = unittest.mock.patch(f"pydantic.{factory.__name__}", new=wrapper)
  127. patch2 = unittest.mock.patch(
  128. f"pydantic.types.{factory.__name__}", new=wrapper
  129. )
  130. patches.enter_context(patch1)
  131. patches.enter_context(patch2)
  132. yield
  133. def format_model_checker_exception(e: ModelCheckerException) -> str:
  134. """Work out which line of code caused e. Format the line in a human-friendly way."""
  135. # TODO. FieldHasUnwantedTypeException gives better error messages. Can we ditch the
  136. # patches of constr() etc, and instead inspect fields to look for ConstrainedStr
  137. # with strict=False? There is some difficulty with the inheritance hierarchy
  138. # because StrictStr < ConstrainedStr < str.
  139. if isinstance(e, FieldHasUnwantedTypeException):
  140. return e.message
  141. elif isinstance(e, MissingStrictInConstrainedTypeException):
  142. frame_summary = traceback.extract_tb(e.__traceback__)[-2]
  143. return (
  144. f"Missing `strict=True` from {e.factory_name}() call \n"
  145. + traceback.format_list([frame_summary])[0].lstrip()
  146. )
  147. else:
  148. raise ValueError(f"Unknown exception {e}") from e
  149. def lint() -> int:
  150. """Try to import all of Synapse and see if we spot any Pydantic type coercions.
  151. Print any problems, then return a status code suitable for sys.exit."""
  152. failures = do_lint()
  153. if failures:
  154. print(f"Found {len(failures)} problem(s)")
  155. for failure in sorted(failures):
  156. print(failure)
  157. return os.EX_DATAERR if failures else os.EX_OK
  158. def do_lint() -> Set[str]:
  159. """Try to import all of Synapse and see if we spot any Pydantic type coercions."""
  160. failures = set()
  161. with monkeypatch_pydantic():
  162. logger.debug("Importing synapse")
  163. try:
  164. # TODO: make "synapse" an argument so we can target this script at
  165. # a subpackage
  166. module = importlib.import_module("synapse")
  167. except ModelCheckerException as e:
  168. logger.warning("Bad annotation found when importing synapse")
  169. failures.add(format_model_checker_exception(e))
  170. return failures
  171. try:
  172. logger.debug("Fetching subpackages")
  173. module_infos = list(
  174. pkgutil.walk_packages(module.__path__, f"{module.__name__}.")
  175. )
  176. except ModelCheckerException as e:
  177. logger.warning("Bad annotation found when looking for modules to import")
  178. failures.add(format_model_checker_exception(e))
  179. return failures
  180. for module_info in module_infos:
  181. logger.debug("Importing %s", module_info.name)
  182. try:
  183. importlib.import_module(module_info.name)
  184. except ModelCheckerException as e:
  185. logger.warning(
  186. f"Bad annotation found when importing {module_info.name}"
  187. )
  188. failures.add(format_model_checker_exception(e))
  189. return failures
  190. def run_test_snippet(source: str) -> None:
  191. """Exec a snippet of source code in an isolated environment."""
  192. # To emulate `source` being called at the top level of the module,
  193. # the globals and locals we provide apparently have to be the same mapping.
  194. #
  195. # > Remember that at the module level, globals and locals are the same dictionary.
  196. # > If exec gets two separate objects as globals and locals, the code will be
  197. # > executed as if it were embedded in a class definition.
  198. globals_: Dict[str, object]
  199. locals_: Dict[str, object]
  200. globals_ = locals_ = {}
  201. exec(textwrap.dedent(source), globals_, locals_)
  202. class TestConstrainedTypesPatch(unittest.TestCase):
  203. def test_expression_without_strict_raises(self) -> None:
  204. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  205. run_test_snippet(
  206. """
  207. from pydantic import constr
  208. constr()
  209. """
  210. )
  211. def test_called_as_module_attribute_raises(self) -> None:
  212. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  213. run_test_snippet(
  214. """
  215. import pydantic
  216. pydantic.constr()
  217. """
  218. )
  219. def test_wildcard_import_raises(self) -> None:
  220. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  221. run_test_snippet(
  222. """
  223. from pydantic import *
  224. constr()
  225. """
  226. )
  227. def test_alternative_import_raises(self) -> None:
  228. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  229. run_test_snippet(
  230. """
  231. from pydantic.types import constr
  232. constr()
  233. """
  234. )
  235. def test_alternative_import_attribute_raises(self) -> None:
  236. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  237. run_test_snippet(
  238. """
  239. import pydantic.types
  240. pydantic.types.constr()
  241. """
  242. )
  243. def test_kwarg_but_no_strict_raises(self) -> None:
  244. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  245. run_test_snippet(
  246. """
  247. from pydantic import constr
  248. constr(min_length=10)
  249. """
  250. )
  251. def test_kwarg_strict_False_raises(self) -> None:
  252. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  253. run_test_snippet(
  254. """
  255. from pydantic import constr
  256. constr(strict=False)
  257. """
  258. )
  259. def test_kwarg_strict_True_doesnt_raise(self) -> None:
  260. with monkeypatch_pydantic():
  261. run_test_snippet(
  262. """
  263. from pydantic import constr
  264. constr(strict=True)
  265. """
  266. )
  267. def test_annotation_without_strict_raises(self) -> None:
  268. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  269. run_test_snippet(
  270. """
  271. from pydantic import constr
  272. x: constr()
  273. """
  274. )
  275. def test_field_annotation_without_strict_raises(self) -> None:
  276. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  277. run_test_snippet(
  278. """
  279. from pydantic import BaseModel, conint
  280. class C:
  281. x: conint()
  282. """
  283. )
  284. class TestFieldTypeInspection(unittest.TestCase):
  285. @parameterized.expand(
  286. [
  287. ("str",),
  288. ("bytes"),
  289. ("int",),
  290. ("float",),
  291. ("bool"),
  292. ("Optional[str]",),
  293. ("Union[None, str]",),
  294. ("List[str]",),
  295. ("List[List[str]]",),
  296. ("Dict[StrictStr, str]",),
  297. ("Dict[str, StrictStr]",),
  298. ("TypedDict('D', x=int)",),
  299. ]
  300. )
  301. def test_field_holding_unwanted_type_raises(self, annotation: str) -> None:
  302. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  303. run_test_snippet(
  304. f"""
  305. from typing import *
  306. from pydantic import *
  307. class C(BaseModel):
  308. f: {annotation}
  309. """
  310. )
  311. @parameterized.expand(
  312. [
  313. ("StrictStr",),
  314. ("StrictBytes"),
  315. ("StrictInt",),
  316. ("StrictFloat",),
  317. ("StrictBool"),
  318. ("constr(strict=True, min_length=10)",),
  319. ("Optional[StrictStr]",),
  320. ("Union[None, StrictStr]",),
  321. ("List[StrictStr]",),
  322. ("List[List[StrictStr]]",),
  323. ("Dict[StrictStr, StrictStr]",),
  324. ("TypedDict('D', x=StrictInt)",),
  325. ]
  326. )
  327. def test_field_holding_accepted_type_doesnt_raise(self, annotation: str) -> None:
  328. with monkeypatch_pydantic():
  329. run_test_snippet(
  330. f"""
  331. from typing import *
  332. from pydantic import *
  333. class C(BaseModel):
  334. f: {annotation}
  335. """
  336. )
  337. def test_field_holding_str_raises_with_alternative_import(self) -> None:
  338. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  339. run_test_snippet(
  340. """
  341. from pydantic.main import BaseModel
  342. class C(BaseModel):
  343. f: str
  344. """
  345. )
  346. parser = argparse.ArgumentParser()
  347. parser.add_argument("mode", choices=["lint", "test"], default="lint", nargs="?")
  348. parser.add_argument("-v", "--verbose", action="store_true")
  349. if __name__ == "__main__":
  350. args = parser.parse_args(sys.argv[1:])
  351. logging.basicConfig(
  352. format="%(asctime)s %(name)s:%(lineno)d %(levelname)s %(message)s",
  353. level=logging.DEBUG if args.verbose else logging.INFO,
  354. )
  355. # suppress logs we don't care about
  356. logging.getLogger("xmlschema").setLevel(logging.WARNING)
  357. if args.mode == "lint":
  358. sys.exit(lint())
  359. elif args.mode == "test":
  360. unittest.main(argv=sys.argv[:1])