certs.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. #***************************************************************************
  4. # _ _ ____ _
  5. # Project ___| | | | _ \| |
  6. # / __| | | | |_) | |
  7. # | (__| |_| | _ <| |___
  8. # \___|\___/|_| \_\_____|
  9. #
  10. # Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
  11. #
  12. # This software is licensed as described in the file COPYING, which
  13. # you should have received as part of this distribution. The terms
  14. # are also available at https://curl.se/docs/copyright.html.
  15. #
  16. # You may opt to use, copy, modify, merge, publish, distribute and/or sell
  17. # copies of the Software, and permit persons to whom the Software is
  18. # furnished to do so, under the terms of the COPYING file.
  19. #
  20. # This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
  21. # KIND, either express or implied.
  22. #
  23. # SPDX-License-Identifier: curl
  24. #
  25. ###########################################################################
  26. #
  27. import ipaddress
  28. import os
  29. import re
  30. from datetime import timedelta, datetime, timezone
  31. from typing import List, Any, Optional
  32. from cryptography import x509
  33. from cryptography.hazmat.backends import default_backend
  34. from cryptography.hazmat.primitives import hashes
  35. from cryptography.hazmat.primitives.asymmetric import ec, rsa
  36. from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
  37. from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
  38. from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key
  39. from cryptography.x509 import ExtendedKeyUsageOID, NameOID
  40. EC_SUPPORTED = {}
  41. EC_SUPPORTED.update([(curve.name.upper(), curve) for curve in [
  42. ec.SECP192R1,
  43. ec.SECP224R1,
  44. ec.SECP256R1,
  45. ec.SECP384R1,
  46. ]])
  47. def _private_key(key_type):
  48. if isinstance(key_type, str):
  49. key_type = key_type.upper()
  50. m = re.match(r'^(RSA)?(\d+)$', key_type)
  51. if m:
  52. key_type = int(m.group(2))
  53. if isinstance(key_type, int):
  54. return rsa.generate_private_key(
  55. public_exponent=65537,
  56. key_size=key_type,
  57. backend=default_backend()
  58. )
  59. if not isinstance(key_type, ec.EllipticCurve) and key_type in EC_SUPPORTED:
  60. key_type = EC_SUPPORTED[key_type]
  61. return ec.generate_private_key(
  62. curve=key_type,
  63. backend=default_backend()
  64. )
  65. class CertificateSpec:
  66. def __init__(self, name: Optional[str] = None,
  67. domains: Optional[List[str]] = None,
  68. email: Optional[str] = None,
  69. key_type: Optional[str] = None,
  70. single_file: bool = False,
  71. valid_from: timedelta = timedelta(days=-1),
  72. valid_to: timedelta = timedelta(days=89),
  73. client: bool = False,
  74. check_valid: bool = True,
  75. sub_specs: Optional[List['CertificateSpec']] = None):
  76. self._name = name
  77. self.domains = domains
  78. self.client = client
  79. self.email = email
  80. self.key_type = key_type
  81. self.single_file = single_file
  82. self.valid_from = valid_from
  83. self.valid_to = valid_to
  84. self.sub_specs = sub_specs
  85. self.check_valid = check_valid
  86. @property
  87. def name(self) -> Optional[str]:
  88. if self._name:
  89. return self._name
  90. elif self.domains:
  91. return self.domains[0]
  92. return None
  93. @property
  94. def type(self) -> Optional[str]:
  95. if self.domains and len(self.domains):
  96. return "server"
  97. elif self.client:
  98. return "client"
  99. elif self.name:
  100. return "ca"
  101. return None
  102. class Credentials:
  103. def __init__(self,
  104. name: str,
  105. cert: Any,
  106. pkey: Any,
  107. issuer: Optional['Credentials'] = None):
  108. self._name = name
  109. self._cert = cert
  110. self._pkey = pkey
  111. self._issuer = issuer
  112. self._cert_file = None
  113. self._pkey_file = None
  114. self._store = None
  115. @property
  116. def name(self) -> str:
  117. return self._name
  118. @property
  119. def subject(self) -> x509.Name:
  120. return self._cert.subject
  121. @property
  122. def key_type(self):
  123. if isinstance(self._pkey, RSAPrivateKey):
  124. return f"rsa{self._pkey.key_size}"
  125. elif isinstance(self._pkey, EllipticCurvePrivateKey):
  126. return f"{self._pkey.curve.name}"
  127. else:
  128. raise Exception(f"unknown key type: {self._pkey}")
  129. @property
  130. def private_key(self) -> Any:
  131. return self._pkey
  132. @property
  133. def certificate(self) -> Any:
  134. return self._cert
  135. @property
  136. def cert_pem(self) -> bytes:
  137. return self._cert.public_bytes(Encoding.PEM)
  138. @property
  139. def pkey_pem(self) -> bytes:
  140. return self._pkey.private_bytes(
  141. Encoding.PEM,
  142. PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa') else PrivateFormat.PKCS8,
  143. NoEncryption())
  144. @property
  145. def issuer(self) -> Optional['Credentials']:
  146. return self._issuer
  147. def set_store(self, store: 'CertStore'):
  148. self._store = store
  149. def set_files(self, cert_file: str, pkey_file: Optional[str] = None,
  150. combined_file: Optional[str] = None):
  151. self._cert_file = cert_file
  152. self._pkey_file = pkey_file
  153. self._combined_file = combined_file
  154. @property
  155. def cert_file(self) -> str:
  156. return self._cert_file
  157. @property
  158. def pkey_file(self) -> Optional[str]:
  159. return self._pkey_file
  160. @property
  161. def combined_file(self) -> Optional[str]:
  162. return self._combined_file
  163. def get_first(self, name) -> Optional['Credentials']:
  164. creds = self._store.get_credentials_for_name(name) if self._store else []
  165. return creds[0] if len(creds) else None
  166. def get_credentials_for_name(self, name) -> List['Credentials']:
  167. return self._store.get_credentials_for_name(name) if self._store else []
  168. def issue_certs(self, specs: List[CertificateSpec],
  169. chain: Optional[List['Credentials']] = None) -> List['Credentials']:
  170. return [self.issue_cert(spec=spec, chain=chain) for spec in specs]
  171. def issue_cert(self, spec: CertificateSpec,
  172. chain: Optional[List['Credentials']] = None) -> 'Credentials':
  173. key_type = spec.key_type if spec.key_type else self.key_type
  174. creds = None
  175. if self._store:
  176. creds = self._store.load_credentials(
  177. name=spec.name, key_type=key_type, single_file=spec.single_file,
  178. issuer=self, check_valid=spec.check_valid)
  179. if creds is None:
  180. creds = TestCA.create_credentials(spec=spec, issuer=self, key_type=key_type,
  181. valid_from=spec.valid_from, valid_to=spec.valid_to)
  182. if self._store:
  183. self._store.save(creds, single_file=spec.single_file)
  184. if spec.type == "ca":
  185. self._store.save_chain(creds, "ca", with_root=True)
  186. if spec.sub_specs:
  187. if self._store:
  188. sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name))
  189. creds.set_store(sub_store)
  190. subchain = chain.copy() if chain else []
  191. subchain.append(self)
  192. creds.issue_certs(spec.sub_specs, chain=subchain)
  193. return creds
  194. class CertStore:
  195. def __init__(self, fpath: str):
  196. self._store_dir = fpath
  197. if not os.path.exists(self._store_dir):
  198. os.makedirs(self._store_dir)
  199. self._creds_by_name = {}
  200. @property
  201. def path(self) -> str:
  202. return self._store_dir
  203. def save(self, creds: Credentials, name: Optional[str] = None,
  204. chain: Optional[List[Credentials]] = None,
  205. single_file: bool = False) -> None:
  206. name = name if name is not None else creds.name
  207. cert_file = self.get_cert_file(name=name, key_type=creds.key_type)
  208. pkey_file = self.get_pkey_file(name=name, key_type=creds.key_type)
  209. comb_file = self.get_combined_file(name=name, key_type=creds.key_type)
  210. if single_file:
  211. pkey_file = None
  212. with open(cert_file, "wb") as fd:
  213. fd.write(creds.cert_pem)
  214. if chain:
  215. for c in chain:
  216. fd.write(c.cert_pem)
  217. if pkey_file is None:
  218. fd.write(creds.pkey_pem)
  219. if pkey_file is not None:
  220. with open(pkey_file, "wb") as fd:
  221. fd.write(creds.pkey_pem)
  222. with open(comb_file, "wb") as fd:
  223. fd.write(creds.cert_pem)
  224. if chain:
  225. for c in chain:
  226. fd.write(c.cert_pem)
  227. fd.write(creds.pkey_pem)
  228. creds.set_files(cert_file, pkey_file, comb_file)
  229. self._add_credentials(name, creds)
  230. def save_chain(self, creds: Credentials, infix: str, with_root=False):
  231. name = creds.name
  232. chain = [creds]
  233. while creds.issuer is not None:
  234. creds = creds.issuer
  235. chain.append(creds)
  236. if not with_root and len(chain) > 1:
  237. chain = chain[:-1]
  238. chain_file = os.path.join(self._store_dir, f'{name}-{infix}.pem')
  239. with open(chain_file, "wb") as fd:
  240. for c in chain:
  241. fd.write(c.cert_pem)
  242. def _add_credentials(self, name: str, creds: Credentials):
  243. if name not in self._creds_by_name:
  244. self._creds_by_name[name] = []
  245. self._creds_by_name[name].append(creds)
  246. def get_credentials_for_name(self, name) -> List[Credentials]:
  247. return self._creds_by_name[name] if name in self._creds_by_name else []
  248. def get_cert_file(self, name: str, key_type=None) -> str:
  249. key_infix = ".{0}".format(key_type) if key_type is not None else ""
  250. return os.path.join(self._store_dir, f'{name}{key_infix}.cert.pem')
  251. def get_pkey_file(self, name: str, key_type=None) -> str:
  252. key_infix = ".{0}".format(key_type) if key_type is not None else ""
  253. return os.path.join(self._store_dir, f'{name}{key_infix}.pkey.pem')
  254. def get_combined_file(self, name: str, key_type=None) -> str:
  255. return os.path.join(self._store_dir, f'{name}.pem')
  256. def load_pem_cert(self, fpath: str) -> x509.Certificate:
  257. with open(fpath) as fd:
  258. return x509.load_pem_x509_certificate("".join(fd.readlines()).encode())
  259. def load_pem_pkey(self, fpath: str):
  260. with open(fpath) as fd:
  261. return load_pem_private_key("".join(fd.readlines()).encode(), password=None)
  262. def load_credentials(self, name: str, key_type=None,
  263. single_file: bool = False,
  264. issuer: Optional[Credentials] = None,
  265. check_valid: bool = False):
  266. cert_file = self.get_cert_file(name=name, key_type=key_type)
  267. pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type)
  268. comb_file = self.get_combined_file(name=name, key_type=key_type)
  269. if os.path.isfile(cert_file) and os.path.isfile(pkey_file):
  270. cert = self.load_pem_cert(cert_file)
  271. pkey = self.load_pem_pkey(pkey_file)
  272. try:
  273. now = datetime.now(tz=timezone.utc)
  274. if check_valid and \
  275. ((cert.not_valid_after_utc < now) or
  276. (cert.not_valid_before_utc > now)):
  277. return None
  278. except AttributeError: # older python
  279. now = datetime.now()
  280. if check_valid and \
  281. ((cert.not_valid_after < now) or
  282. (cert.not_valid_before > now)):
  283. return None
  284. creds = Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
  285. creds.set_store(self)
  286. creds.set_files(cert_file, pkey_file, comb_file)
  287. self._add_credentials(name, creds)
  288. return creds
  289. return None
  290. class TestCA:
  291. @classmethod
  292. def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials:
  293. store = CertStore(fpath=store_dir)
  294. creds = store.load_credentials(name="ca", key_type=key_type, issuer=None)
  295. if creds is None:
  296. creds = TestCA._make_ca_credentials(name=name, key_type=key_type)
  297. store.save(creds, name="ca")
  298. creds.set_store(store)
  299. return creds
  300. @staticmethod
  301. def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any,
  302. valid_from: timedelta = timedelta(days=-1),
  303. valid_to: timedelta = timedelta(days=89),
  304. ) -> Credentials:
  305. """Create a certificate signed by this CA for the given domains.
  306. :returns: the certificate and private key PEM file paths
  307. """
  308. if spec.domains and len(spec.domains):
  309. creds = TestCA._make_server_credentials(name=spec.name, domains=spec.domains,
  310. issuer=issuer, valid_from=valid_from,
  311. valid_to=valid_to, key_type=key_type)
  312. elif spec.client:
  313. creds = TestCA._make_client_credentials(name=spec.name, issuer=issuer,
  314. email=spec.email, valid_from=valid_from,
  315. valid_to=valid_to, key_type=key_type)
  316. elif spec.name:
  317. creds = TestCA._make_ca_credentials(name=spec.name, issuer=issuer,
  318. valid_from=valid_from, valid_to=valid_to,
  319. key_type=key_type)
  320. else:
  321. raise Exception(f"unrecognized certificate specification: {spec}")
  322. return creds
  323. @staticmethod
  324. def _make_x509_name(org_name: str = None, common_name: str = None, parent: x509.Name = None) -> x509.Name:
  325. name_pieces = []
  326. if org_name:
  327. oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME
  328. name_pieces.append(x509.NameAttribute(oid, org_name))
  329. elif common_name:
  330. name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name))
  331. if parent:
  332. name_pieces.extend([rdn for rdn in parent])
  333. return x509.Name(name_pieces)
  334. @staticmethod
  335. def _make_csr(
  336. subject: x509.Name,
  337. pkey: Any,
  338. issuer_subject: Optional[Credentials],
  339. valid_from_delta: timedelta = None,
  340. valid_until_delta: timedelta = None
  341. ):
  342. pubkey = pkey.public_key()
  343. issuer_subject = issuer_subject if issuer_subject is not None else subject
  344. valid_from = datetime.now()
  345. if valid_until_delta is not None:
  346. valid_from += valid_from_delta
  347. valid_until = datetime.now()
  348. if valid_until_delta is not None:
  349. valid_until += valid_until_delta
  350. return (
  351. x509.CertificateBuilder()
  352. .subject_name(subject)
  353. .issuer_name(issuer_subject)
  354. .public_key(pubkey)
  355. .not_valid_before(valid_from)
  356. .not_valid_after(valid_until)
  357. .serial_number(x509.random_serial_number())
  358. .add_extension(
  359. x509.SubjectKeyIdentifier.from_public_key(pubkey),
  360. critical=False,
  361. )
  362. )
  363. @staticmethod
  364. def _add_ca_usages(csr: Any) -> Any:
  365. return csr.add_extension(
  366. x509.BasicConstraints(ca=True, path_length=9),
  367. critical=True,
  368. ).add_extension(
  369. x509.KeyUsage(
  370. digital_signature=True,
  371. content_commitment=False,
  372. key_encipherment=False,
  373. data_encipherment=False,
  374. key_agreement=False,
  375. key_cert_sign=True,
  376. crl_sign=True,
  377. encipher_only=False,
  378. decipher_only=False),
  379. critical=True
  380. ).add_extension(
  381. x509.ExtendedKeyUsage([
  382. ExtendedKeyUsageOID.CLIENT_AUTH,
  383. ExtendedKeyUsageOID.SERVER_AUTH,
  384. ExtendedKeyUsageOID.CODE_SIGNING,
  385. ]),
  386. critical=True
  387. )
  388. @staticmethod
  389. def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any:
  390. names = []
  391. for name in domains:
  392. try:
  393. names.append(x509.IPAddress(ipaddress.ip_address(name)))
  394. except:
  395. names.append(x509.DNSName(name))
  396. return csr.add_extension(
  397. x509.BasicConstraints(ca=False, path_length=None),
  398. critical=True,
  399. ).add_extension(
  400. x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
  401. issuer.certificate.extensions.get_extension_for_class(
  402. x509.SubjectKeyIdentifier).value),
  403. critical=False
  404. ).add_extension(
  405. x509.SubjectAlternativeName(names), critical=True,
  406. ).add_extension(
  407. x509.ExtendedKeyUsage([
  408. ExtendedKeyUsageOID.SERVER_AUTH,
  409. ]),
  410. critical=False
  411. )
  412. @staticmethod
  413. def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: str = None) -> Any:
  414. cert = csr.add_extension(
  415. x509.BasicConstraints(ca=False, path_length=None),
  416. critical=True,
  417. ).add_extension(
  418. x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
  419. issuer.certificate.extensions.get_extension_for_class(
  420. x509.SubjectKeyIdentifier).value),
  421. critical=False
  422. )
  423. if rfc82name:
  424. cert.add_extension(
  425. x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]),
  426. critical=True,
  427. )
  428. cert.add_extension(
  429. x509.ExtendedKeyUsage([
  430. ExtendedKeyUsageOID.CLIENT_AUTH,
  431. ]),
  432. critical=True
  433. )
  434. return cert
  435. @staticmethod
  436. def _make_ca_credentials(name, key_type: Any,
  437. issuer: Credentials = None,
  438. valid_from: timedelta = timedelta(days=-1),
  439. valid_to: timedelta = timedelta(days=89),
  440. ) -> Credentials:
  441. pkey = _private_key(key_type=key_type)
  442. if issuer is not None:
  443. issuer_subject = issuer.certificate.subject
  444. issuer_key = issuer.private_key
  445. else:
  446. issuer_subject = None
  447. issuer_key = pkey
  448. subject = TestCA._make_x509_name(org_name=name, parent=issuer.subject if issuer else None)
  449. csr = TestCA._make_csr(subject=subject,
  450. issuer_subject=issuer_subject, pkey=pkey,
  451. valid_from_delta=valid_from, valid_until_delta=valid_to)
  452. csr = TestCA._add_ca_usages(csr)
  453. cert = csr.sign(private_key=issuer_key,
  454. algorithm=hashes.SHA256(),
  455. backend=default_backend())
  456. return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
  457. @staticmethod
  458. def _make_server_credentials(name: str, domains: List[str], issuer: Credentials,
  459. key_type: Any,
  460. valid_from: timedelta = timedelta(days=-1),
  461. valid_to: timedelta = timedelta(days=89),
  462. ) -> Credentials:
  463. name = name
  464. pkey = _private_key(key_type=key_type)
  465. subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
  466. csr = TestCA._make_csr(subject=subject,
  467. issuer_subject=issuer.certificate.subject, pkey=pkey,
  468. valid_from_delta=valid_from, valid_until_delta=valid_to)
  469. csr = TestCA._add_leaf_usages(csr, domains=domains, issuer=issuer)
  470. cert = csr.sign(private_key=issuer.private_key,
  471. algorithm=hashes.SHA256(),
  472. backend=default_backend())
  473. return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
  474. @staticmethod
  475. def _make_client_credentials(name: str,
  476. issuer: Credentials, email: Optional[str],
  477. key_type: Any,
  478. valid_from: timedelta = timedelta(days=-1),
  479. valid_to: timedelta = timedelta(days=89),
  480. ) -> Credentials:
  481. pkey = _private_key(key_type=key_type)
  482. subject = TestCA._make_x509_name(common_name=name, parent=issuer.subject)
  483. csr = TestCA._make_csr(subject=subject,
  484. issuer_subject=issuer.certificate.subject, pkey=pkey,
  485. valid_from_delta=valid_from, valid_until_delta=valid_to)
  486. csr = TestCA._add_client_usages(csr, issuer=issuer, rfc82name=email)
  487. cert = csr.sign(private_key=issuer.private_key,
  488. algorithm=hashes.SHA256(),
  489. backend=default_backend())
  490. return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)