check_pydantic_models.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. #! /usr/bin/env python
  2. # Copyright 2022 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. A script which enforces that Synapse always uses strict types when defining a Pydantic
  17. model.
  18. Pydantic does not yet offer a strict mode, but it is planned for pydantic v2. See
  19. https://github.com/pydantic/pydantic/issues/1098
  20. https://pydantic-docs.helpmanual.io/blog/pydantic-v2/#strict-mode
  21. Until then, this script is a best effort to stop us from introducing type coercion bugs
  22. (like the infamous stringy power levels fixed in room version 10).
  23. """
  24. import argparse
  25. import contextlib
  26. import functools
  27. import importlib
  28. import logging
  29. import os
  30. import pkgutil
  31. import sys
  32. import textwrap
  33. import traceback
  34. import unittest.mock
  35. from contextlib import contextmanager
  36. from typing import (
  37. TYPE_CHECKING,
  38. Any,
  39. Callable,
  40. Dict,
  41. Generator,
  42. List,
  43. Set,
  44. Type,
  45. TypeVar,
  46. )
  47. from parameterized import parameterized
  48. from synapse._pydantic_compat import HAS_PYDANTIC_V2
  49. if TYPE_CHECKING or HAS_PYDANTIC_V2:
  50. from pydantic.v1 import (
  51. BaseModel as PydanticBaseModel,
  52. conbytes,
  53. confloat,
  54. conint,
  55. constr,
  56. )
  57. from pydantic.v1.typing import get_args
  58. else:
  59. from pydantic import (
  60. BaseModel as PydanticBaseModel,
  61. conbytes,
  62. confloat,
  63. conint,
  64. constr,
  65. )
  66. from pydantic.typing import get_args
  67. from typing_extensions import ParamSpec
  68. logger = logging.getLogger(__name__)
  69. CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: List[Callable] = [
  70. constr,
  71. conbytes,
  72. conint,
  73. confloat,
  74. ]
  75. TYPES_THAT_PYDANTIC_WILL_COERCE_TO = [
  76. str,
  77. bytes,
  78. int,
  79. float,
  80. bool,
  81. ]
  82. P = ParamSpec("P")
  83. R = TypeVar("R")
  84. class ModelCheckerException(Exception):
  85. """Dummy exception. Allows us to detect unwanted types during a module import."""
  86. class MissingStrictInConstrainedTypeException(ModelCheckerException):
  87. factory_name: str
  88. def __init__(self, factory_name: str):
  89. self.factory_name = factory_name
  90. class FieldHasUnwantedTypeException(ModelCheckerException):
  91. message: str
  92. def __init__(self, message: str):
  93. self.message = message
  94. def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]:
  95. """We patch `constr` and friends with wrappers that enforce strict=True."""
  96. @functools.wraps(factory)
  97. def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
  98. if "strict" not in kwargs:
  99. raise MissingStrictInConstrainedTypeException(factory.__name__)
  100. if not kwargs["strict"]:
  101. raise MissingStrictInConstrainedTypeException(factory.__name__)
  102. return factory(*args, **kwargs)
  103. return wrapper
  104. def field_type_unwanted(type_: Any) -> bool:
  105. """Very rough attempt to detect if a type is unwanted as a Pydantic annotation.
  106. At present, we exclude types which will coerce, or any generic type involving types
  107. which will coerce."""
  108. logger.debug("Is %s unwanted?")
  109. if type_ in TYPES_THAT_PYDANTIC_WILL_COERCE_TO:
  110. logger.debug("yes")
  111. return True
  112. logger.debug("Maybe. Subargs are %s", get_args(type_))
  113. rv = any(field_type_unwanted(t) for t in get_args(type_))
  114. logger.debug("Conclusion: %s %s unwanted", type_, "is" if rv else "is not")
  115. return rv
  116. class PatchedBaseModel(PydanticBaseModel):
  117. """A patched version of BaseModel that inspects fields after models are defined.
  118. We complain loudly if we see an unwanted type.
  119. Beware: ModelField.type_ is presumably private; this is likely to be very brittle.
  120. """
  121. @classmethod
  122. def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs: object):
  123. for field in cls.__fields__.values():
  124. # Note that field.type_ and field.outer_type are computed based on the
  125. # annotation type, see pydantic.fields.ModelField._type_analysis
  126. if field_type_unwanted(field.outer_type_):
  127. # TODO: this only reports the first bad field. Can we find all bad ones
  128. # and report them all?
  129. raise FieldHasUnwantedTypeException(
  130. f"{cls.__module__}.{cls.__qualname__} has field '{field.name}' "
  131. f"with unwanted type `{field.outer_type_}`"
  132. )
  133. @contextmanager
  134. def monkeypatch_pydantic() -> Generator[None, None, None]:
  135. """Patch pydantic with our snooping versions of BaseModel and the con* functions.
  136. If the snooping functions see something they don't like, they'll raise a
  137. ModelCheckingException instance.
  138. """
  139. with contextlib.ExitStack() as patches:
  140. # Most Synapse code ought to import the patched objects directly from
  141. # `pydantic`. But we also patch their containing modules `pydantic.main` and
  142. # `pydantic.types` for completeness.
  143. patch_basemodel1 = unittest.mock.patch(
  144. "pydantic.BaseModel", new=PatchedBaseModel
  145. )
  146. patch_basemodel2 = unittest.mock.patch(
  147. "pydantic.main.BaseModel", new=PatchedBaseModel
  148. )
  149. patches.enter_context(patch_basemodel1)
  150. patches.enter_context(patch_basemodel2)
  151. for factory in CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG:
  152. wrapper: Callable = make_wrapper(factory)
  153. patch1 = unittest.mock.patch(f"pydantic.{factory.__name__}", new=wrapper)
  154. patch2 = unittest.mock.patch(
  155. f"pydantic.types.{factory.__name__}", new=wrapper
  156. )
  157. patches.enter_context(patch1)
  158. patches.enter_context(patch2)
  159. yield
  160. def format_model_checker_exception(e: ModelCheckerException) -> str:
  161. """Work out which line of code caused e. Format the line in a human-friendly way."""
  162. # TODO. FieldHasUnwantedTypeException gives better error messages. Can we ditch the
  163. # patches of constr() etc, and instead inspect fields to look for ConstrainedStr
  164. # with strict=False? There is some difficulty with the inheritance hierarchy
  165. # because StrictStr < ConstrainedStr < str.
  166. if isinstance(e, FieldHasUnwantedTypeException):
  167. return e.message
  168. elif isinstance(e, MissingStrictInConstrainedTypeException):
  169. frame_summary = traceback.extract_tb(e.__traceback__)[-2]
  170. return (
  171. f"Missing `strict=True` from {e.factory_name}() call \n"
  172. + traceback.format_list([frame_summary])[0].lstrip()
  173. )
  174. else:
  175. raise ValueError(f"Unknown exception {e}") from e
  176. def lint() -> int:
  177. """Try to import all of Synapse and see if we spot any Pydantic type coercions.
  178. Print any problems, then return a status code suitable for sys.exit."""
  179. failures = do_lint()
  180. if failures:
  181. print(f"Found {len(failures)} problem(s)")
  182. for failure in sorted(failures):
  183. print(failure)
  184. return os.EX_DATAERR if failures else os.EX_OK
  185. def do_lint() -> Set[str]:
  186. """Try to import all of Synapse and see if we spot any Pydantic type coercions."""
  187. failures = set()
  188. with monkeypatch_pydantic():
  189. logger.debug("Importing synapse")
  190. try:
  191. # TODO: make "synapse" an argument so we can target this script at
  192. # a subpackage
  193. module = importlib.import_module("synapse")
  194. except ModelCheckerException as e:
  195. logger.warning("Bad annotation found when importing synapse")
  196. failures.add(format_model_checker_exception(e))
  197. return failures
  198. try:
  199. logger.debug("Fetching subpackages")
  200. module_infos = list(
  201. pkgutil.walk_packages(module.__path__, f"{module.__name__}.")
  202. )
  203. except ModelCheckerException as e:
  204. logger.warning("Bad annotation found when looking for modules to import")
  205. failures.add(format_model_checker_exception(e))
  206. return failures
  207. for module_info in module_infos:
  208. logger.debug("Importing %s", module_info.name)
  209. try:
  210. importlib.import_module(module_info.name)
  211. except ModelCheckerException as e:
  212. logger.warning(
  213. f"Bad annotation found when importing {module_info.name}"
  214. )
  215. failures.add(format_model_checker_exception(e))
  216. return failures
  217. def run_test_snippet(source: str) -> None:
  218. """Exec a snippet of source code in an isolated environment."""
  219. # To emulate `source` being called at the top level of the module,
  220. # the globals and locals we provide apparently have to be the same mapping.
  221. #
  222. # > Remember that at the module level, globals and locals are the same dictionary.
  223. # > If exec gets two separate objects as globals and locals, the code will be
  224. # > executed as if it were embedded in a class definition.
  225. globals_: Dict[str, object]
  226. locals_: Dict[str, object]
  227. globals_ = locals_ = {}
  228. exec(textwrap.dedent(source), globals_, locals_)
  229. class TestConstrainedTypesPatch(unittest.TestCase):
  230. def test_expression_without_strict_raises(self) -> None:
  231. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  232. run_test_snippet(
  233. """
  234. try:
  235. from pydantic.v1 import constr
  236. except ImportError:
  237. from pydantic import constr
  238. constr()
  239. """
  240. )
  241. def test_called_as_module_attribute_raises(self) -> None:
  242. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  243. run_test_snippet(
  244. """
  245. import pydantic
  246. pydantic.constr()
  247. """
  248. )
  249. def test_wildcard_import_raises(self) -> None:
  250. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  251. run_test_snippet(
  252. """
  253. try:
  254. from pydantic.v1 import *
  255. except ImportError:
  256. from pydantic import *
  257. constr()
  258. """
  259. )
  260. def test_alternative_import_raises(self) -> None:
  261. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  262. run_test_snippet(
  263. """
  264. try:
  265. from pydantic.v1.types import constr
  266. except ImportError:
  267. from pydantic.types import constr
  268. constr()
  269. """
  270. )
  271. def test_alternative_import_attribute_raises(self) -> None:
  272. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  273. run_test_snippet(
  274. """
  275. try:
  276. from pydantic.v1 import types as pydantic_types
  277. except ImportError:
  278. from pydantic import types as pydantic_types
  279. pydantic_types.constr()
  280. """
  281. )
  282. def test_kwarg_but_no_strict_raises(self) -> None:
  283. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  284. run_test_snippet(
  285. """
  286. try:
  287. from pydantic.v1 import constr
  288. except ImportError:
  289. from pydantic import constr
  290. constr(min_length=10)
  291. """
  292. )
  293. def test_kwarg_strict_False_raises(self) -> None:
  294. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  295. run_test_snippet(
  296. """
  297. try:
  298. from pydantic.v1 import constr
  299. except ImportError:
  300. from pydantic import constr
  301. constr(strict=False)
  302. """
  303. )
  304. def test_kwarg_strict_True_doesnt_raise(self) -> None:
  305. with monkeypatch_pydantic():
  306. run_test_snippet(
  307. """
  308. try:
  309. from pydantic.v1 import constr
  310. except ImportError:
  311. from pydantic import constr
  312. constr(strict=True)
  313. """
  314. )
  315. def test_annotation_without_strict_raises(self) -> None:
  316. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  317. run_test_snippet(
  318. """
  319. try:
  320. from pydantic.v1 import constr
  321. except ImportError:
  322. from pydantic import constr
  323. x: constr()
  324. """
  325. )
  326. def test_field_annotation_without_strict_raises(self) -> None:
  327. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  328. run_test_snippet(
  329. """
  330. try:
  331. from pydantic.v1 import BaseModel, conint
  332. except ImportError:
  333. from pydantic import BaseModel, conint
  334. class C:
  335. x: conint()
  336. """
  337. )
  338. class TestFieldTypeInspection(unittest.TestCase):
  339. @parameterized.expand(
  340. [
  341. ("str",),
  342. ("bytes"),
  343. ("int",),
  344. ("float",),
  345. ("bool"),
  346. ("Optional[str]",),
  347. ("Union[None, str]",),
  348. ("List[str]",),
  349. ("List[List[str]]",),
  350. ("Dict[StrictStr, str]",),
  351. ("Dict[str, StrictStr]",),
  352. ("TypedDict('D', x=int)",),
  353. ]
  354. )
  355. def test_field_holding_unwanted_type_raises(self, annotation: str) -> None:
  356. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  357. run_test_snippet(
  358. f"""
  359. from typing import *
  360. try:
  361. from pydantic.v1 import *
  362. except ImportError:
  363. from pydantic import *
  364. class C(BaseModel):
  365. f: {annotation}
  366. """
  367. )
  368. @parameterized.expand(
  369. [
  370. ("StrictStr",),
  371. ("StrictBytes"),
  372. ("StrictInt",),
  373. ("StrictFloat",),
  374. ("StrictBool"),
  375. ("constr(strict=True, min_length=10)",),
  376. ("Optional[StrictStr]",),
  377. ("Union[None, StrictStr]",),
  378. ("List[StrictStr]",),
  379. ("List[List[StrictStr]]",),
  380. ("Dict[StrictStr, StrictStr]",),
  381. ("TypedDict('D', x=StrictInt)",),
  382. ]
  383. )
  384. def test_field_holding_accepted_type_doesnt_raise(self, annotation: str) -> None:
  385. with monkeypatch_pydantic():
  386. run_test_snippet(
  387. f"""
  388. from typing import *
  389. try:
  390. from pydantic.v1 import *
  391. except ImportError:
  392. from pydantic import *
  393. class C(BaseModel):
  394. f: {annotation}
  395. """
  396. )
  397. def test_field_holding_str_raises_with_alternative_import(self) -> None:
  398. with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
  399. run_test_snippet(
  400. """
  401. try:
  402. from pydantic.v1.main import BaseModel
  403. except ImportError:
  404. from pydantic.main import BaseModel
  405. class C(BaseModel):
  406. f: str
  407. """
  408. )
  409. parser = argparse.ArgumentParser()
  410. parser.add_argument("mode", choices=["lint", "test"], default="lint", nargs="?")
  411. parser.add_argument("-v", "--verbose", action="store_true")
  412. if __name__ == "__main__":
  413. args = parser.parse_args(sys.argv[1:])
  414. logging.basicConfig(
  415. format="%(asctime)s %(name)s:%(lineno)d %(levelname)s %(message)s",
  416. level=logging.DEBUG if args.verbose else logging.INFO,
  417. )
  418. # suppress logs we don't care about
  419. logging.getLogger("xmlschema").setLevel(logging.WARNING)
  420. if args.mode == "lint":
  421. sys.exit(lint())
  422. elif args.mode == "test":
  423. unittest.main(argv=sys.argv[:1])