123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443 |
- /*
- * Copyright 2021-2022 The OpenSSL Project Authors. All Rights Reserved.
- *
- * Licensed under the Apache License 2.0 (the "License"). You may not use
- * this file except in compliance with the License. You can obtain a copy
- * in the file LICENSE in the source distribution or at
- * https://www.openssl.org/source/license.html
- */
- #ifndef OSSL_INTERNAL_SAFE_MATH_H
- # define OSSL_INTERNAL_SAFE_MATH_H
- # pragma once
- # include <openssl/e_os2.h> /* For 'ossl_inline' */
- # ifndef OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING
- # ifdef __has_builtin
- # define has(func) __has_builtin(func)
- # elif __GNUC__ > 5
- # define has(func) 1
- # endif
- # endif /* OPENSSL_NO_BUILTIN_OVERFLOW_CHECKING */
- # ifndef has
- # define has(func) 0
- # endif
- /*
- * Safe addition helpers
- */
- # if has(__builtin_add_overflow)
- # define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \
- static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- type r; \
- \
- if (!__builtin_add_overflow(a, b, &r)) \
- return r; \
- *err |= 1; \
- return a < 0 ? min : max; \
- }
- # define OSSL_SAFE_MATH_ADDU(type_name, type, max) \
- static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- type r; \
- \
- if (!__builtin_add_overflow(a, b, &r)) \
- return r; \
- *err |= 1; \
- return a + b; \
- }
- # else /* has(__builtin_add_overflow) */
- # define OSSL_SAFE_MATH_ADDS(type_name, type, min, max) \
- static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- if ((a < 0) ^ (b < 0) \
- || (a > 0 && b <= max - a) \
- || (a < 0 && b >= min - a) \
- || a == 0) \
- return a + b; \
- *err |= 1; \
- return a < 0 ? min : max; \
- }
- # define OSSL_SAFE_MATH_ADDU(type_name, type, max) \
- static ossl_inline ossl_unused type safe_add_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- if (b > max - a) \
- *err |= 1; \
- return a + b; \
- }
- # endif /* has(__builtin_add_overflow) */
- /*
- * Safe subtraction helpers
- */
- # if has(__builtin_sub_overflow)
- # define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \
- static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- type r; \
- \
- if (!__builtin_sub_overflow(a, b, &r)) \
- return r; \
- *err |= 1; \
- return a < 0 ? min : max; \
- }
- # else /* has(__builtin_sub_overflow) */
- # define OSSL_SAFE_MATH_SUBS(type_name, type, min, max) \
- static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- if (!((a < 0) ^ (b < 0)) \
- || (b > 0 && a >= min + b) \
- || (b < 0 && a <= max + b) \
- || b == 0) \
- return a - b; \
- *err |= 1; \
- return a < 0 ? min : max; \
- }
- # endif /* has(__builtin_sub_overflow) */
- # define OSSL_SAFE_MATH_SUBU(type_name, type) \
- static ossl_inline ossl_unused type safe_sub_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- if (b > a) \
- *err |= 1; \
- return a - b; \
- }
- /*
- * Safe multiplication helpers
- */
- # if has(__builtin_mul_overflow)
- # define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \
- static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- type r; \
- \
- if (!__builtin_mul_overflow(a, b, &r)) \
- return r; \
- *err |= 1; \
- return (a < 0) ^ (b < 0) ? min : max; \
- }
- # define OSSL_SAFE_MATH_MULU(type_name, type, max) \
- static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- type r; \
- \
- if (!__builtin_mul_overflow(a, b, &r)) \
- return r; \
- *err |= 1; \
- return a * b; \
- }
- # else /* has(__builtin_mul_overflow) */
- # define OSSL_SAFE_MATH_MULS(type_name, type, min, max) \
- static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- if (a == 0 || b == 0) \
- return 0; \
- if (a == 1) \
- return b; \
- if (b == 1) \
- return a; \
- if (a != min && b != min) { \
- const type x = a < 0 ? -a : a; \
- const type y = b < 0 ? -b : b; \
- \
- if (x <= max / y) \
- return a * b; \
- } \
- *err |= 1; \
- return (a < 0) ^ (b < 0) ? min : max; \
- }
- # define OSSL_SAFE_MATH_MULU(type_name, type, max) \
- static ossl_inline ossl_unused type safe_mul_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- if (b != 0 && a > max / b) \
- *err |= 1; \
- return a * b; \
- }
- # endif /* has(__builtin_mul_overflow) */
- /*
- * Safe division helpers
- */
- # define OSSL_SAFE_MATH_DIVS(type_name, type, min, max) \
- static ossl_inline ossl_unused type safe_div_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- if (b == 0) { \
- *err |= 1; \
- return a < 0 ? min : max; \
- } \
- if (b == -1 && a == min) { \
- *err |= 1; \
- return max; \
- } \
- return a / b; \
- }
- # define OSSL_SAFE_MATH_DIVU(type_name, type, max) \
- static ossl_inline ossl_unused type safe_div_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- if (b != 0) \
- return a / b; \
- *err |= 1; \
- return max; \
- }
- /*
- * Safe modulus helpers
- */
- # define OSSL_SAFE_MATH_MODS(type_name, type, min, max) \
- static ossl_inline ossl_unused type safe_mod_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- if (b == 0) { \
- *err |= 1; \
- return 0; \
- } \
- if (b == -1 && a == min) { \
- *err |= 1; \
- return max; \
- } \
- return a % b; \
- }
- # define OSSL_SAFE_MATH_MODU(type_name, type) \
- static ossl_inline ossl_unused type safe_mod_ ## type_name(type a, \
- type b, \
- int *err) \
- { \
- if (b != 0) \
- return a % b; \
- *err |= 1; \
- return 0; \
- }
- /*
- * Safe negation helpers
- */
- # define OSSL_SAFE_MATH_NEGS(type_name, type, min) \
- static ossl_inline ossl_unused type safe_neg_ ## type_name(type a, \
- int *err) \
- { \
- if (a != min) \
- return -a; \
- *err |= 1; \
- return min; \
- }
- # define OSSL_SAFE_MATH_NEGU(type_name, type) \
- static ossl_inline ossl_unused type safe_neg_ ## type_name(type a, \
- int *err) \
- { \
- if (a == 0) \
- return a; \
- *err |= 1; \
- return 1 + ~a; \
- }
- /*
- * Safe absolute value helpers
- */
- # define OSSL_SAFE_MATH_ABSS(type_name, type, min) \
- static ossl_inline ossl_unused type safe_abs_ ## type_name(type a, \
- int *err) \
- { \
- if (a != min) \
- return a < 0 ? -a : a; \
- *err |= 1; \
- return min; \
- }
- # define OSSL_SAFE_MATH_ABSU(type_name, type) \
- static ossl_inline ossl_unused type safe_abs_ ## type_name(type a, \
- int *err) \
- { \
- return a; \
- }
- /*
- * Safe fused multiply divide helpers
- *
- * These are a bit obscure:
- * . They begin by checking the denominator for zero and getting rid of this
- * corner case.
- *
- * . Second is an attempt to do the multiplication directly, if it doesn't
- * overflow, the quotient is returned (for signed values there is a
- * potential problem here which isn't present for unsigned).
- *
- * . Finally, the multiplication/division is transformed so that the larger
- * of the numerators is divided first. This requires a remainder
- * correction:
- *
- * a b / c = (a / c) b + (a mod c) b / c, where a > b
- *
- * The individual operations need to be overflow checked (again signed
- * being more problematic).
- *
- * The algorithm used is not perfect but it should be "good enough".
- */
- # define OSSL_SAFE_MATH_MULDIVS(type_name, type, max) \
- static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a, \
- type b, \
- type c, \
- int *err) \
- { \
- int e2 = 0; \
- type q, r, x, y; \
- \
- if (c == 0) { \
- *err |= 1; \
- return a == 0 || b == 0 ? 0 : max; \
- } \
- x = safe_mul_ ## type_name(a, b, &e2); \
- if (!e2) \
- return safe_div_ ## type_name(x, c, err); \
- if (b > a) { \
- x = b; \
- b = a; \
- a = x; \
- } \
- q = safe_div_ ## type_name(a, c, err); \
- r = safe_mod_ ## type_name(a, c, err); \
- x = safe_mul_ ## type_name(r, b, err); \
- y = safe_mul_ ## type_name(q, b, err); \
- q = safe_div_ ## type_name(x, c, err); \
- return safe_add_ ## type_name(y, q, err); \
- }
- # define OSSL_SAFE_MATH_MULDIVU(type_name, type, max) \
- static ossl_inline ossl_unused type safe_muldiv_ ## type_name(type a, \
- type b, \
- type c, \
- int *err) \
- { \
- int e2 = 0; \
- type x, y; \
- \
- if (c == 0) { \
- *err |= 1; \
- return a == 0 || b == 0 ? 0 : max; \
- } \
- x = safe_mul_ ## type_name(a, b, &e2); \
- if (!e2) \
- return x / c; \
- if (b > a) { \
- x = b; \
- b = a; \
- a = x; \
- } \
- x = safe_mul_ ## type_name(a % c, b, err); \
- y = safe_mul_ ## type_name(a / c, b, err); \
- return safe_add_ ## type_name(y, x / c, err); \
- }
- /*
- * Calculate a / b rounding up:
- * i.e. a / b + (a % b != 0)
- * Which is usually (less safely) converted to (a + b - 1) / b
- * If you *know* that b != 0, then it's safe to ignore err.
- */
- #define OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, max) \
- static ossl_inline ossl_unused type safe_div_round_up_ ## type_name \
- (type a, type b, int *errp) \
- { \
- type x; \
- int *err, err_local = 0; \
- \
- /* Allow errors to be ignored by callers */ \
- err = errp != NULL ? errp : &err_local; \
- /* Fast path, both positive */ \
- if (b > 0 && a > 0) { \
- /* Faster path: no overflow concerns */ \
- if (a < max - b) \
- return (a + b - 1) / b; \
- return a / b + (a % b != 0); \
- } \
- if (b == 0) { \
- *err |= 1; \
- return a == 0 ? 0 : max; \
- } \
- if (a == 0) \
- return 0; \
- /* Rather slow path because there are negatives involved */ \
- x = safe_mod_ ## type_name(a, b, err); \
- return safe_add_ ## type_name(safe_div_ ## type_name(a, b, err), \
- x != 0, err); \
- }
- /* Calculate ranges of types */
- # define OSSL_SAFE_MATH_MINS(type) ((type)1 << (sizeof(type) * 8 - 1))
- # define OSSL_SAFE_MATH_MAXS(type) (~OSSL_SAFE_MATH_MINS(type))
- # define OSSL_SAFE_MATH_MAXU(type) (~(type)0)
- /*
- * Wrapper macros to create all the functions of a given type
- */
- # define OSSL_SAFE_MATH_SIGNED(type_name, type) \
- OSSL_SAFE_MATH_ADDS(type_name, type, OSSL_SAFE_MATH_MINS(type), \
- OSSL_SAFE_MATH_MAXS(type)) \
- OSSL_SAFE_MATH_SUBS(type_name, type, OSSL_SAFE_MATH_MINS(type), \
- OSSL_SAFE_MATH_MAXS(type)) \
- OSSL_SAFE_MATH_MULS(type_name, type, OSSL_SAFE_MATH_MINS(type), \
- OSSL_SAFE_MATH_MAXS(type)) \
- OSSL_SAFE_MATH_DIVS(type_name, type, OSSL_SAFE_MATH_MINS(type), \
- OSSL_SAFE_MATH_MAXS(type)) \
- OSSL_SAFE_MATH_MODS(type_name, type, OSSL_SAFE_MATH_MINS(type), \
- OSSL_SAFE_MATH_MAXS(type)) \
- OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, \
- OSSL_SAFE_MATH_MAXS(type)) \
- OSSL_SAFE_MATH_MULDIVS(type_name, type, OSSL_SAFE_MATH_MAXS(type)) \
- OSSL_SAFE_MATH_NEGS(type_name, type, OSSL_SAFE_MATH_MINS(type)) \
- OSSL_SAFE_MATH_ABSS(type_name, type, OSSL_SAFE_MATH_MINS(type))
- # define OSSL_SAFE_MATH_UNSIGNED(type_name, type) \
- OSSL_SAFE_MATH_ADDU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
- OSSL_SAFE_MATH_SUBU(type_name, type) \
- OSSL_SAFE_MATH_MULU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
- OSSL_SAFE_MATH_DIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
- OSSL_SAFE_MATH_MODU(type_name, type) \
- OSSL_SAFE_MATH_DIV_ROUND_UP(type_name, type, \
- OSSL_SAFE_MATH_MAXU(type)) \
- OSSL_SAFE_MATH_MULDIVU(type_name, type, OSSL_SAFE_MATH_MAXU(type)) \
- OSSL_SAFE_MATH_NEGU(type_name, type) \
- OSSL_SAFE_MATH_ABSU(type_name, type)
- #endif /* OSSL_INTERNAL_SAFE_MATH_H */
|