safe_math.h 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. /*
  2. * Copyright 2021-2022 The OpenSSL Project Authors. All Rights Reserved.
  3. *
  4. * Licensed under the Apache License 2.0 (the "License"). You may not use
  5. * this file except in compliance with the License. You can obtain a copy
  6. * in the file LICENSE in the source distribution or at
  7. * https://www.openssl.org/source/license.html
  8. */
  9. #ifndef OSSL_INTERNAL_SAFE_MATH_H
  10. # define OSSL_INTERNAL_SAFE_MATH_H
  11. # pragma once
  12. # include <openssl/e_os2.h> /* For 'ossl_inline' */
  /*
   * has(func) is used only in the #if directives of this header.  It expands
   * to a non-zero value when the compiler is believed to provide the
   * overflow-checking builtin func, and to 0 otherwise (always 0 when
   * OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING is defined), in which case the
   * portable fallback implementations are compiled instead.
   *
   * NOTE(review): "has" is a very generic macro name and this header never
   * #undefs it, so it stays defined in every file that includes safe_math.h.
   */
  # ifndef OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING
  #  ifdef __has_builtin
  #   define has(func) __has_builtin(func)
  #  elif defined(__GNUC__)
  /*
   * No __has_builtin: fall back to the GCC version number (releases newer
   * than 5 are trusted).  Checking defined(__GNUC__) first keeps non-GCC
   * compilers from evaluating an undefined identifier here, which -Wundef
   * warns about.
   */
  #   if __GNUC__ > 5
  #    define has(func) 1
  #   endif
  #  endif
  # endif /* OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING */
  # ifndef has
  #  define has(func) 0
  # endif
  23. /*
  24. * Safe addition helpers
  25. */
  26. # if has(__builtin_add_overflow)
  27. # define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \
  28. static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \
  29. type b, \
  30. int *err) \
  31. { \
  32. type r; \
  33. \
  34. if (!__builtin_add_overflow(a, b, &r)) \
  35. return r; \
  36. *err |= 1; \
  37. return a < 0 ? min : max; \
  38. }
  39. # define OSSL_SAFE_MATH_ADDU(type_name, type, max) \
  40. static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \
  41. type b, \
  42. int *err) \
  43. { \
  44. type r; \
  45. \
  46. if (!__builtin_add_overflow(a, b, &r)) \
  47. return r; \
  48. *err |= 1; \
  49. return a + b; \
  50. }
  51. # else /* has(__builtin_add_overflow) */
  52. # define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \
  53. static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \
  54. type b, \
  55. int *err) \
  56. { \
  57. if ((a < 0) ^ (b < 0) \
  58. || (a > 0 && b <= max - a) \
  59. || (a < 0 && b >= min - a) \
  60. || a == 0) \
  61. return a + b; \
  62. *err |= 1; \
  63. return a < 0 ? min : max; \
  64. }
  65. # define OSSL_SAFE_MATH_ADDU(type_name, type, max) \
  66. static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \
  67. type b, \
  68. int *err) \
  69. { \
  70. if (b > max - a) \
  71. *err |= 1; \
  72. return a + b; \
  73. }
  74. # endif /* has(__builtin_add_overflow) */
  75. /*
  76. * Safe subtraction helpers
  77. */
  78. # if has(__builtin_sub_overflow)
  79. # define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \
  80. static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \
  81. type b, \
  82. int *err) \
  83. { \
  84. type r; \
  85. \
  86. if (!__builtin_sub_overflow(a, b, &r)) \
  87. return r; \
  88. *err |= 1; \
  89. return a < 0 ? min : max; \
  90. }
  91. # else /* has(__builtin_sub_overflow) */
  92. # define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \
  93. static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \
  94. type b, \
  95. int *err) \
  96. { \
  97. if (!((a < 0) ^ (b < 0)) \
  98. || (b > 0 && a >= min + b) \
  99. || (b < 0 && a <= max + b) \
  100. || b == 0) \
  101. return a - b; \
  102. *err |= 1; \
  103. return a < 0 ? min : max; \
  104. }
  105. # endif /* has(__builtin_sub_overflow) */
  /*
   * Unsigned subtraction: returns a - b, which wraps when b exceeds a; that
   * case also sets bit 0 of *err.
   */
  # define OSSL_SAFE_MATH_SUBU(type_name, type) \
      static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \
                                                                 type b, \
                                                                 int *err) \
      { \
          if (a < b) \
              *err |= 1; \
          return a - b; \
      }
  115. /*
  116. * Safe multiplication helpers
  117. */
  118. # if has(__builtin_mul_overflow)
  119. # define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \
  120. static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \
  121. type b, \
  122. int *err) \
  123. { \
  124. type r; \
  125. \
  126. if (!__builtin_mul_overflow(a, b, &r)) \
  127. return r; \
  128. *err |= 1; \
  129. return (a < 0) ^ (b < 0) ? min : max; \
  130. }
  131. # define OSSL_SAFE_MATH_MULU(type_name, type, max) \
  132. static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \
  133. type b, \
  134. int *err) \
  135. { \
  136. type r; \
  137. \
  138. if (!__builtin_mul_overflow(a, b, &r)) \
  139. return r; \
  140. *err |= 1; \
  141. return a * b; \
  142. }
  143. # else /* has(__builtin_mul_overflow) */
  144. # define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \
  145. static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \
  146. type b, \
  147. int *err) \
  148. { \
  149. if (a == 0 || b == 0) \
  150. return 0; \
  151. if (a == 1) \
  152. return b; \
  153. if (b == 1) \
  154. return a; \
  155. if (a != min && b != min) { \
  156. const type x = a < 0 ? -a : a; \
  157. const type y = b < 0 ? -b : b; \
  158. \
  159. if (x <= max / y) \
  160. return a * b; \
  161. } \
  162. *err |= 1; \
  163. return (a < 0) ^ (b < 0) ? min : max; \
  164. }
  165. # define OSSL_SAFE_MATH_MULU(type_name, type, max) \
  166. static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \
  167. type b, \
  168. int *err) \
  169. { \
  170. if (b != 0 && a > max / b) \
  171. *err |= 1; \
  172. return a * b; \
  173. }
  174. # endif /* has(__builtin_mul_overflow) */
  /*
   * Safe division helpers
   *
   * safe_div_<type_name>(a, b, err) returns a / b (truncating, as C division
   * does).  Division by zero sets bit 0 of *err and returns min for negative
   * a and max otherwise.  The overflowing min / -1 sets the same bit and
   * returns max.
   */
  # define OSSL_SAFE_MATH_DIVS(type_name, type, min, max) \
      static ossl_inline ossl_unused type safe_div_ ## type_name(type a, \
                                                                 type b, \
                                                                 int *err) \
      { \
          type saturated; \
   \
          if (b == 0) \
              saturated = a < 0 ? min : max; \
          else if (a == min && b == -1) \
              saturated = max; \
          else \
              return a / b; \
          *err |= 1; \
          return saturated; \
      }
  /*
   * Unsigned division: a zero divisor sets bit 0 of *err and yields max.
   */
  # define OSSL_SAFE_MATH_DIVU(type_name, type, max) \
      static ossl_inline ossl_unused type safe_div_ ## type_name(type a, \
                                                                 type b, \
                                                                 int *err) \
      { \
          if (b == 0) { \
              *err |= 1; \
              return max; \
          } \
          return a / b; \
      }
  /*
   * Safe modulus helpers
   *
   * safe_mod_<type_name>(a, b, err) returns a % b for signed types.  A zero
   * divisor sets bit 0 of *err and returns 0.
   *
   * Any remainder modulo -1 is exactly 0, so that case is answered without
   * dividing: C leaves min % -1 undefined (the matching quotient min / -1
   * overflows), but the remainder itself is representable and is not an
   * error.  The min and max parameters are kept for interface compatibility.
   */
  # define OSSL_SAFE_MATH_MODS(type_name, type, min, max) \
      static ossl_inline ossl_unused type safe_mod_ ## type_name(type a, \
                                                                 type b, \
                                                                 int *err) \
      { \
          if (b == 0) { \
              *err |= 1; \
              return 0; \
          } \
          if (b == -1) \
              return 0; \
          return a % b; \
      }
  /*
   * Unsigned modulus: a zero divisor sets bit 0 of *err and yields 0.
   */
  # define OSSL_SAFE_MATH_MODU(type_name, type) \
      static ossl_inline ossl_unused type safe_mod_ ## type_name(type a, \
                                                                 type b, \
                                                                 int *err) \
      { \
          if (b == 0) { \
              *err |= 1; \
              return 0; \
          } \
          return a % b; \
      }
  /*
   * Safe negation helpers
   *
   * safe_neg_<type_name>(a, err) returns -a.  For signed types -min is not
   * representable: that case sets bit 0 of *err and returns min.
   */
  # define OSSL_SAFE_MATH_NEGS(type_name, type, min) \
      static ossl_inline ossl_unused type safe_neg_ ## type_name(type a, \
                                                                 int *err) \
      { \
          if (a == min) { \
              *err |= 1; \
              return min; \
          } \
          return -a; \
      }
  /*
   * Unsigned negation: only 0 can be negated exactly.  Any other a sets bit 0
   * of *err and returns the wrapped negation 1 + ~a (2^N - a for an N-bit
   * type).
   */
  # define OSSL_SAFE_MATH_NEGU(type_name, type) \
      static ossl_inline ossl_unused type safe_neg_ ## type_name(type a, \
                                                                 int *err) \
      { \
          if (a != 0) { \
              *err |= 1; \
              return 1 + ~a; \
          } \
          return 0; \
      }
  /*
   * Safe absolute value helpers
   *
   * safe_abs_<type_name>(a, err) returns |a|.  For signed types |min| is not
   * representable: that case sets bit 0 of *err and returns min unchanged.
   */
  # define OSSL_SAFE_MATH_ABSS(type_name, type, min) \
      static ossl_inline ossl_unused type safe_abs_ ## type_name(type a, \
                                                                 int *err) \
      { \
          if (a == min) { \
              *err |= 1; \
              return min; \
          } \
          return a >= 0 ? a : -a; \
      }
  /*
   * Unsigned absolute value is the identity; it never fails and never
   * touches *err.
   */
  # define OSSL_SAFE_MATH_ABSU(type_name, type) \
      static ossl_inline ossl_unused type safe_abs_ ## type_name(type a, \
                                                                 int *err) \
      { \
          (void)err; \
          return a; \
      }
  /*
   * Safe fused multiply divide helpers
   *
   * safe_muldiv_<type_name>(a, b, c, err) computes a * b / c while trying to
   * avoid overflow of the intermediate product.  The steps are:
   *
   * . A zero divisor is handled first: bit 0 of *err is set and the result
   *   is 0 if either factor is 0, otherwise max.
   *
   * . Next the product is attempted directly.  If it does not overflow, the
   *   product divided by c is returned.  For signed types that division can
   *   itself overflow (min / -1), which safe_div_* reports; unsigned types
   *   have no such case.
   *
   * . Otherwise the larger factor is divided first and a remainder
   *   correction is applied:
   *
   *       a b / c = (a / c) b + (a mod c) b / c,   where a > b
   *
   *   Each individual operation is overflow checked and reports through
   *   *err (the signed case again being the more delicate one).
   *
   * The algorithm is not perfect, but it should be "good enough".
   */
  # define OSSL_SAFE_MATH_MULDIVS(type_name, type, max) \
      static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a, \
                                                                    type b, \
                                                                    type c, \
                                                                    int *err) \
      { \
          int direct_err = 0; \
          type larger, smaller, product, quot, rem, whole, frac; \
   \
          if (c == 0) { \
              *err |= 1; \
              return (a == 0 || b == 0) ? 0 : max; \
          } \
          product = safe_mul_ ## type_name(a, b, &direct_err); \
          if (direct_err == 0) \
              return safe_div_ ## type_name(product, c, err); \
          larger = b > a ? b : a; \
          smaller = b > a ? a : b; \
          quot = safe_div_ ## type_name(larger, c, err); \
          rem = safe_mod_ ## type_name(larger, c, err); \
          whole = safe_mul_ ## type_name(quot, smaller, err); \
          frac = safe_div_ ## type_name(safe_mul_ ## type_name(rem, smaller, \
                                                               err), \
                                        c, err); \
          return safe_add_ ## type_name(whole, frac, err); \
      }
  /*
   * Unsigned a * b / c.  It uses the same scheme as the signed version:
   * try the product directly, otherwise divide the larger factor first and
   * add the remainder correction.  A zero divisor sets bit 0 of *err and
   * yields 0 if a factor is 0, else max.
   */
  # define OSSL_SAFE_MATH_MULDIVU(type_name, type, max) \
      static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a, \
                                                                    type b, \
                                                                    type c, \
                                                                    int *err) \
      { \
          int direct_err = 0; \
          type larger, smaller, product, whole, frac; \
   \
          if (c == 0) { \
              *err |= 1; \
              return (a == 0 || b == 0) ? 0 : max; \
          } \
          product = safe_mul_ ## type_name(a, b, &direct_err); \
          if (direct_err == 0) \
              return product / c; \
          larger = b > a ? b : a; \
          smaller = b > a ? a : b; \
          whole = safe_mul_ ## type_name(larger / c, smaller, err); \
          frac = safe_mul_ ## type_name(larger % c, smaller, err) / c; \
          return safe_add_ ## type_name(whole, frac, err); \
      }
  /*
   * Calculate a / b rounded up, i.e. towards positive infinity.
   *
   * For positive operands this is a / b + (a % b != 0), which is often
   * written (less safely, since a + b may overflow) as (a + b - 1) / b.
   * When the operands have opposite signs the exact quotient is negative and
   * C's truncation towards zero already rounds it up.  Only when both are
   * negative does a non-zero remainder require adding one.
   *
   * Division by zero sets bit 0 of the error flag and returns 0 when a == 0,
   * max otherwise.  errp may be NULL if the caller wants to ignore errors,
   * e.g. when it already knows that b != 0 (and, for signed types, that
   * a / b cannot overflow).
   */
  # define OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, max) \
      static ossl_inline ossl_unused type safe_div_round_up_ ## type_name \
          (type a, type b, int *errp) \
      { \
          type q, r; \
          int *err, err_local = 0; \
   \
          /* Allow errors to be ignored by callers */ \
          err = errp != NULL ? errp : &err_local; \
          /* Fast path, both positive */ \
          if (b > 0 && a > 0) { \
              /* Faster path: a + b - 1 cannot overflow here */ \
              if (a < max - b) \
                  return (a + b - 1) / b; \
              return a / b + (a % b != 0); \
          } \
          if (b == 0) { \
              *err |= 1; \
              return a == 0 ? 0 : max; \
          } \
          if (a == 0) \
              return 0; \
          /* \
           * Slow path: a and b are non-zero and not both positive, so \
           * (a > 0) == (b > 0) holds exactly when both are negative. \
           */ \
          q = safe_div_ ## type_name(a, b, err); \
          r = safe_mod_ ## type_name(a, b, err); \
          if (r != 0 && (a > 0) == (b > 0)) \
              return safe_add_ ## type_name(q, 1, err); \
          return q; \
      }
  /*
   * Calculate ranges of types.
   *
   * The signed limits are built without shifting a 1 into the sign bit, which
   * is undefined behaviour for signed types.  Every result is cast back to
   * type, so types narrower than int (promoted to int during the arithmetic)
   * still get their own limits: e.g. OSSL_SAFE_MATH_MAXU(unsigned short) is
   * USHRT_MAX rather than the int value -1, and OSSL_SAFE_MATH_MINS(short)
   * is SHRT_MIN rather than 32768.  Assumes 8-bit bytes and two's complement
   * signed types.
   */
  # define OSSL_SAFE_MATH_MAXS(type) \
      ((type)((((type)1 << (sizeof(type) * 8 - 2)) - 1) * 2 + 1))
  # define OSSL_SAFE_MATH_MINS(type) ((type)(-OSSL_SAFE_MATH_MAXS(type) - 1))
  # define OSSL_SAFE_MATH_MAXU(type) ((type)~(type)0)
  /*
   * Wrapper macros to create all the functions of a given type.
   *
   * OSSL_SAFE_MATH_SIGNED(type_name, type) expands to the full set of static
   * inline safe_<op>_<type_name>() helpers for one signed integer type, using
   * the type's own min and max.  The add/sub/mul/div/mod helpers are emitted
   * first because safe_muldiv_<type_name>() and
   * safe_div_round_up_<type_name>() call them.
   */
  # define OSSL_SAFE_MATH_SIGNED(type_name, type) \
      OSSL_SAFE_MATH_ADDS(type_name, type, OSSL_SAFE_MATH_MINS(type), \
                          OSSL_SAFE_MATH_MAXS(type)) \
      OSSL_SAFE_MATH_SUBS(type_name, type, OSSL_SAFE_MATH_MINS(type), \
                          OSSL_SAFE_MATH_MAXS(type)) \
      OSSL_SAFE_MATH_MULS(type_name, type, OSSL_SAFE_MATH_MINS(type), \
                          OSSL_SAFE_MATH_MAXS(type)) \
      OSSL_SAFE_MATH_DIVS(type_name, type, OSSL_SAFE_MATH_MINS(type), \
                          OSSL_SAFE_MATH_MAXS(type)) \
      OSSL_SAFE_MATH_MODS(type_name, type, OSSL_SAFE_MATH_MINS(type), \
                          OSSL_SAFE_MATH_MAXS(type)) \
      OSSL_SAFE_MATH_NEGS(type_name, type, OSSL_SAFE_MATH_MINS(type)) \
      OSSL_SAFE_MATH_ABSS(type_name, type, OSSL_SAFE_MATH_MINS(type)) \
      OSSL_SAFE_MATH_MULDIVS(type_name, type, OSSL_SAFE_MATH_MAXS(type)) \
      OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, OSSL_SAFE_MATH_MAXS(type))
  /*
   * OSSL_SAFE_MATH_UNSIGNED(type_name, type) expands to the full set of static
   * inline safe_<op>_<type_name>() helpers for one unsigned integer type.  The
   * basic arithmetic helpers precede the muldiv and div_round_up helpers,
   * which call them.
   */
  # define OSSL_SAFE_MATH_UNSIGNED(type_name, type) \
      OSSL_SAFE_MATH_ADDU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
      OSSL_SAFE_MATH_SUBU(type_name, type) \
      OSSL_SAFE_MATH_MULU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
      OSSL_SAFE_MATH_DIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
      OSSL_SAFE_MATH_MODU(type_name, type) \
      OSSL_SAFE_MATH_NEGU(type_name, type) \
      OSSL_SAFE_MATH_ABSU(type_name, type) \
      OSSL_SAFE_MATH_MULDIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
      OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, OSSL_SAFE_MATH_MAXU(type))
  412. #endif /* OSSL_INTERNAL_SAFE_MATH_H */