ソースを参照

Merge pull request #6370 from SparkiDev/sp_int_copy_change

SP int: _sp_copy don't check a == b, change calls to _sp_copy
JacobBarthelmeh 1 年間 前
コミット
f05543c9e5
1 ファイル変更50 行追加47 行削除
  1. 50 47
      wolfcrypt/src/sp_int.c

+ 50 - 47
wolfcrypt/src/sp_int.c

@@ -5059,22 +5059,19 @@ void sp_forcezero(sp_int* a)
  */
 static void _sp_copy(const sp_int* a, sp_int* r)
 {
-    /* Only copy if different pointers. */
-    if (a != r) {
-        /* Copy words across. */
-        if (a->used == 0) {
-            r->dp[0] = 0;
-        }
-        else {
-            XMEMCPY(r->dp, a->dp, a->used * SP_WORD_SIZEOF);
-        }
-        /* Set number of used words in result. */
-        r->used = a->used;
-    #ifdef WOLFSSL_SP_INT_NEGATIVE
-        /* Set sign of result. */
-        r->sign = a->sign;
-    #endif
+    /* Copy words across. */
+    if (a->used == 0) {
+        r->dp[0] = 0;
+    }
+    else {
+        XMEMCPY(r->dp, a->dp, a->used * SP_WORD_SIZEOF);
     }
+    /* Set number of used words in result. */
+    r->used = a->used;
+#ifdef WOLFSSL_SP_INT_NEGATIVE
+    /* Set sign of result. */
+    r->sign = a->sign;
+#endif
 }
 
 /* Copy value of multi-precision number a into r.
@@ -5092,12 +5089,15 @@ int sp_copy(const sp_int* a, sp_int* r)
     if ((a == NULL) || (r == NULL)) {
         err = MP_VAL;
     }
-    /* Validated space in result. */
-    if ((err == MP_OKAY) && (a->used > r->size)) {
-        err = MP_VAL;
-    }
-    if (err == MP_OKAY) {
-        _sp_copy(a, r);
+    /* Only copy if different pointers. */
+    if (a != r) {
+        /* Validated space in result. */
+        if ((err == MP_OKAY) && (a->used > r->size)) {
+            err = MP_VAL;
+        }
+        if (err == MP_OKAY) {
+            _sp_copy(a, r);
+        }
     }
 
     return err;
@@ -8374,7 +8374,7 @@ static int _sp_div(const sp_int* a, const sp_int* d, sp_int* r, sp_int* rem,
     ret = _sp_cmp_abs(a, d);
     if (ret == MP_LT) {
         /* a = 0 * d + a */
-        if (rem != NULL) {
+        if ((rem != NULL) && (a != rem)) {
             _sp_copy(a, rem);
         }
         if (r != NULL) {
@@ -8622,7 +8622,7 @@ static int _sp_mod(const sp_int* a, const sp_int* m, sp_int* r)
             err = sp_add(t, m, r);
         }
         else {
-            err = sp_copy(t, r);
+            _sp_copy(t, r);
         }
     }
     FREE_SP_INT(t, NULL);
@@ -11818,7 +11818,9 @@ static int _sp_invmod_bin(const sp_int* a, const sp_int* m, sp_int* u,
 
     /* 1. u = m, v = a, b = 0, c = 1 */
     _sp_copy(m, u);
-    _sp_copy(a, v);
+    if (a != v) {
+        _sp_copy(a, v);
+    }
     _sp_zero(b);
     _sp_set(c, 1);
 
@@ -11920,7 +11922,9 @@ static int _sp_invmod_div(const sp_int* a, const sp_int* m, sp_int* x,
         mp_init(d);
 
         /* 1. x = m, y = a, b = 1, c = 0 */
-        _sp_copy(a, y);
+        if (a != y) {
+            _sp_copy(a, y);
+        }
         _sp_copy(m, x);
         _sp_set(b, 1);
         _sp_zero(c);
@@ -12128,7 +12132,7 @@ static int _sp_invmod(const sp_int* a, const sp_int* m, sp_int* r)
             }
         }
         else if (err == MP_OKAY) {
-            err = sp_copy(c, r);
+            _sp_copy(c, r);
         }
     }
 
@@ -12290,7 +12294,7 @@ static int _sp_invmod_mont_ct(const sp_int* a, const sp_int* m, sp_int* r,
         /* 1. pre[0] = 2^0 * a mod m
          *    Start with 1.a = a.
          */
-        err = sp_copy(a, pre[0]);
+        _sp_copy(a, pre[0]);
         /* 2. For i in 2..CT_INV_MOD_PRE_CNT
          *    For rest of entries in table.
          */
@@ -12325,7 +12329,7 @@ static int _sp_invmod_mont_ct(const sp_int* a, const sp_int* m, sp_int* r,
             }
         }
         /* 3. Set tmp to product of leading bits. */
-        err = sp_copy(pre[j-1], t);
+        _sp_copy(pre[j-1], t);
 
         /* 4. s = 0 */
         s = 0;
@@ -12402,7 +12406,7 @@ static int _sp_invmod_mont_ct(const sp_int* a, const sp_int* m, sp_int* r,
         }
         /* 9. Else r = t */
         else {
-            err = sp_copy(t, r);
+            _sp_copy(t, r);
         }
     }
 
@@ -12535,7 +12539,7 @@ static int _sp_exptmod_ex(const sp_int* b, const sp_int* e, int bits,
         }
         else {
             /* Copy base into working variable. */
-            err = sp_copy(b, t[0]);
+            _sp_copy(b, t[0]);
         }
     }
 
@@ -12543,7 +12547,7 @@ static int _sp_exptmod_ex(const sp_int* b, const sp_int* e, int bits,
         /* 3. t[1] = t[0]
          *    Set real working value to base.
          */
-        err = sp_copy(t[0], t[1]);
+        _sp_copy(t[0], t[1]);
 
         /* 4. For i in (bits-1)...0 */
         for (i = bits - 1; (err == MP_OKAY) && (i >= 0); i--) {
@@ -12591,7 +12595,7 @@ static int _sp_exptmod_ex(const sp_int* b, const sp_int* e, int bits,
     }
     if ((!done) && (err == MP_OKAY)) {
         /* 5. r = t[1] */
-        err = sp_copy(t[1], r);
+        _sp_copy(t[1], r);
     }
 
     FREE_SP_INT_ARRAY(t, NULL);
@@ -12661,7 +12665,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits,
         }
         else {
             /* Copy base into working variable. */
-            err = sp_copy(b, t[0]);
+            _sp_copy(b, t[0]);
         }
     }
 
@@ -12732,7 +12736,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits,
     }
     if ((!done) && (err == MP_OKAY)) {
         /* 8. r = t[1] */
-        err = sp_copy(t[1], r);
+        _sp_copy(t[1], r);
     }
 
     FREE_SP_INT_ARRAY(t, NULL);
@@ -12842,7 +12846,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits,
         }
         else {
             /* Copy base into entry of table to contain b^1. */
-            err = sp_copy(b, t[1]);
+            _sp_copy(b, t[1]);
         }
     }
 
@@ -12954,7 +12958,7 @@ static int _sp_exptmod_mont_ex(const sp_int* b, const sp_int* e, int bits,
     }
     if ((!done) && (err == MP_OKAY)) {
         /* 8. r = tr */
-        err = sp_copy(tr, r);
+        _sp_copy(tr, r);
     }
 
     FREE_SP_INT_ARRAY(t, NULL);
@@ -13188,7 +13192,7 @@ static int _sp_exptmod_base_2(const sp_int* e, int digits, const sp_int* m,
     }
     if (err == MP_OKAY) {
         /* 8. r = tr */
-        err = sp_copy(tr, r);
+        _sp_copy(tr, r);
     }
 
 #if 0
@@ -13538,7 +13542,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
         }
         else {
             /* Copy base into Montogmery base variable. */
-            err = sp_copy(b, bm);
+            _sp_copy(b, bm);
         }
     }
 
@@ -13556,7 +13560,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
         }
         if (err == MP_OKAY) {
             /* Copy Montgomery form of base into first element of table. */
-            err = sp_copy(bm, t[0]);
+            _sp_copy(bm, t[0]);
         }
         /* Calculate b^(2^(winBits-1)) */
         for (i = 1; (i < winBits) && (err == MP_OKAY); i++) {
@@ -13605,7 +13609,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
                     n <<= winBits;
                     c -= winBits;
                 }
-                err = sp_copy(t[y], tr);
+                _sp_copy(t[y], tr);
             }
             else {
                 /* 1 in Montgomery form. */
@@ -13729,7 +13733,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
     }
     if ((!done) && (err == MP_OKAY)) {
         /* Copy temporary result into parameter. */
-        err = sp_copy(tr, r);
+        _sp_copy(tr, r);
     }
 
 #ifndef WOLFSSL_SP_NO_MALLOC
@@ -13792,7 +13796,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
         }
         else {
             /* Copy base into temp. */
-            err = sp_copy(b, t[0]);
+            _sp_copy(b, t[0]);
         }
     }
 
@@ -13838,7 +13842,7 @@ static int _sp_exptmod_nct(const sp_int* b, const sp_int* e, const sp_int* m,
     }
     if ((!done) && (err == MP_OKAY)) {
         /* Copy temporary result into parameter. */
-        err = sp_copy(t[0], r);
+        _sp_copy(t[0], r);
     }
 
     FREE_SP_INT_ARRAY(t, NULL);
@@ -17817,10 +17821,9 @@ int sp_todecimal(const sp_int* a, char* str)
 
         ALLOC_SP_INT_SIZE(t, a->used + 1, err, NULL);
         if (err == MP_OKAY) {
-            err = sp_copy(a, t);
+            _sp_copy(a, t);
         }
         if (err == MP_OKAY) {
-
         #ifdef WOLFSSL_SP_INT_NEGATIVE
             if (a->sign == MP_NEG) {
                 /* Add negative sign character. */
@@ -17969,7 +17972,7 @@ int sp_radix_size(const sp_int* a, int radix, int* size)
             ALLOC_SP_INT(t, a->used, err, NULL);
             if (err == MP_OKAY) {
                 t->size = a->used;
-                err = sp_copy(a, t);
+                _sp_copy(a, t);
             }
 
             if (err == MP_OKAY) {
@@ -18786,7 +18789,7 @@ static WC_INLINE int _sp_gcd(const sp_int* a, const sp_int* b, sp_int* r)
     }
     if (err == MP_OKAY) {
         /* 5. r = u */
-        err = sp_copy(u, r);
+        _sp_copy(u, r);
     }
 
     FREE_SP_INT_ARRAY(d, NULL);