bn_sqr.c 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. /*
  2. * Copyright 1995-2017 The OpenSSL Project Authors. All Rights Reserved.
  3. *
  4. * Licensed under the OpenSSL license (the "License"). You may not use
  5. * this file except in compliance with the License. You can obtain a copy
  6. * in the file LICENSE in the source distribution or at
  7. * https://www.openssl.org/source/license.html
  8. */
  9. #include "internal/cryptlib.h"
  10. #include "bn_lcl.h"
  11. /* r must not be a */
  12. /*
  13. * I've just gone over this and it is now %20 faster on x86 - eay - 27 Jun 96
  14. */
  15. int BN_sqr(BIGNUM *r, const BIGNUM *a, BN_CTX *ctx)
  16. {
  17. int max, al;
  18. int ret = 0;
  19. BIGNUM *tmp, *rr;
  20. bn_check_top(a);
  21. al = a->top;
  22. if (al <= 0) {
  23. r->top = 0;
  24. r->neg = 0;
  25. return 1;
  26. }
  27. BN_CTX_start(ctx);
  28. rr = (a != r) ? r : BN_CTX_get(ctx);
  29. tmp = BN_CTX_get(ctx);
  30. if (rr == NULL || tmp == NULL)
  31. goto err;
  32. max = 2 * al; /* Non-zero (from above) */
  33. if (bn_wexpand(rr, max) == NULL)
  34. goto err;
  35. if (al == 4) {
  36. #ifndef BN_SQR_COMBA
  37. BN_ULONG t[8];
  38. bn_sqr_normal(rr->d, a->d, 4, t);
  39. #else
  40. bn_sqr_comba4(rr->d, a->d);
  41. #endif
  42. } else if (al == 8) {
  43. #ifndef BN_SQR_COMBA
  44. BN_ULONG t[16];
  45. bn_sqr_normal(rr->d, a->d, 8, t);
  46. #else
  47. bn_sqr_comba8(rr->d, a->d);
  48. #endif
  49. } else {
  50. #if defined(BN_RECURSION)
  51. if (al < BN_SQR_RECURSIVE_SIZE_NORMAL) {
  52. BN_ULONG t[BN_SQR_RECURSIVE_SIZE_NORMAL * 2];
  53. bn_sqr_normal(rr->d, a->d, al, t);
  54. } else {
  55. int j, k;
  56. j = BN_num_bits_word((BN_ULONG)al);
  57. j = 1 << (j - 1);
  58. k = j + j;
  59. if (al == j) {
  60. if (bn_wexpand(tmp, k * 2) == NULL)
  61. goto err;
  62. bn_sqr_recursive(rr->d, a->d, al, tmp->d);
  63. } else {
  64. if (bn_wexpand(tmp, max) == NULL)
  65. goto err;
  66. bn_sqr_normal(rr->d, a->d, al, tmp->d);
  67. }
  68. }
  69. #else
  70. if (bn_wexpand(tmp, max) == NULL)
  71. goto err;
  72. bn_sqr_normal(rr->d, a->d, al, tmp->d);
  73. #endif
  74. }
  75. rr->neg = 0;
  76. /*
  77. * If the most-significant half of the top word of 'a' is zero, then the
  78. * square of 'a' will max-1 words.
  79. */
  80. if (a->d[al - 1] == (a->d[al - 1] & BN_MASK2l))
  81. rr->top = max - 1;
  82. else
  83. rr->top = max;
  84. if (r != rr && BN_copy(r, rr) == NULL)
  85. goto err;
  86. ret = 1;
  87. err:
  88. bn_check_top(rr);
  89. bn_check_top(tmp);
  90. BN_CTX_end(ctx);
  91. return ret;
  92. }
  93. /* tmp must have 2*n words */
  94. void bn_sqr_normal(BN_ULONG *r, const BN_ULONG *a, int n, BN_ULONG *tmp)
  95. {
  96. int i, j, max;
  97. const BN_ULONG *ap;
  98. BN_ULONG *rp;
  99. max = n * 2;
  100. ap = a;
  101. rp = r;
  102. rp[0] = rp[max - 1] = 0;
  103. rp++;
  104. j = n;
  105. if (--j > 0) {
  106. ap++;
  107. rp[j] = bn_mul_words(rp, ap, j, ap[-1]);
  108. rp += 2;
  109. }
  110. for (i = n - 2; i > 0; i--) {
  111. j--;
  112. ap++;
  113. rp[j] = bn_mul_add_words(rp, ap, j, ap[-1]);
  114. rp += 2;
  115. }
  116. bn_add_words(r, r, r, max);
  117. /* There will not be a carry */
  118. bn_sqr_words(tmp, a, n);
  119. bn_add_words(r, r, tmp, max);
  120. }
  #ifdef BN_RECURSION
  /*-
   * Recursive squaring of the n2-word array |a| into the 2*n2-word array |r|.
   *
   * n2 must be a power of 2.
   *
   * |t| is scratch space.  The first 2*n2 words are used directly and
   * &t[2*n2] is handed to the half-size recursive calls as their scratch, so
   * |t| has to extend beyond 2*n2 words (BN_sqr() provides 4*n2).
   *
   * Writing a = a[1]*B + a[0] with n2/2-word halves, we calculate
   *   a[0]*a[0]
   *   a[0]*a[0] + a[1]*a[1] - (a[0]-a[1])*(a[0]-a[1])    (= 2*a[0]*a[1])
   *   a[1]*a[1]
   * and add the middle term into r at an offset of n2/2 words.
   */
  void bn_sqr_recursive(BN_ULONG *r, const BN_ULONG *a, int n2, BN_ULONG *t)
  {
      const int half = n2 / 2;
      const BN_ULONG *a_hi;
      BN_ULONG *r_hi, *mid, *scratch;
      int cmp, carry;

      /* Small sizes are handed to the non-recursive routines. */
      switch (n2) {
      case 4:
  # ifndef BN_SQR_COMBA
          bn_sqr_normal(r, a, 4, t);
  # else
          bn_sqr_comba4(r, a);
  # endif
          return;
      case 8:
  # ifndef BN_SQR_COMBA
          bn_sqr_normal(r, a, 8, t);
  # else
          bn_sqr_comba8(r, a);
  # endif
          return;
      default:
          break;
      }
      if (n2 < BN_SQR_RECURSIVE_SIZE_NORMAL) {
          bn_sqr_normal(r, a, n2, t);
          return;
      }

      a_hi = &a[half];            /* upper half of the input */
      r_hi = &r[n2];              /* upper half of the output */
      mid = &t[n2];               /* second n2 words of t */
      scratch = &t[n2 * 2];       /* scratch for the half-size calls */

      /*
       * mid = (a[0]-a[1])^2.  The difference is taken larger-minus-smaller
       * into t[0..half) and then squared; equal halves give zero directly.
       */
      cmp = bn_cmp_words(a, a_hi, half);
      if (cmp == 0) {
          memset(mid, 0, sizeof(*t) * n2);
      } else {
          if (cmp > 0)
              bn_sub_words(t, a, a_hi, half);
          else
              bn_sub_words(t, a_hi, a, half);
          bn_sqr_recursive(mid, t, half, scratch);
      }

      /* Lower half of r = a[0]^2, upper half of r = a[1]^2. */
      bn_sqr_recursive(r, a, half, scratch);
      bn_sqr_recursive(r_hi, a_hi, half, scratch);

      /*-
       * Here:
       *   mid        holds (a[0]-a[1])^2
       *   r[0..n2)   holds a[0]*a[0]
       *   r[n2..)    holds a[1]*a[1]
       */

      /* t[0..n2) = a[0]^2 + a[1]^2 */
      carry = (int)(bn_add_words(t, r, r_hi, n2));

      /* mid = a[0]^2 + a[1]^2 - (a[0]-a[1])^2 */
      carry -= (int)(bn_sub_words(mid, t, mid, n2));

      /*-
       * Here:
       *   mid        holds a[0]^2 + a[1]^2 - (a[0]-a[1])^2
       *   r[0..n2)   holds a[0]*a[0]
       *   r[n2..)    holds a[1]*a[1]
       *   carry      holds the carry bits
       */
      carry += (int)(bn_add_words(&r[half], &r[half], mid, n2));

      if (carry != 0) {
          BN_ULONG *dst = &r[half + n2];
          BN_ULONG sum = (*dst + carry) & BN_MASK2;

          *dst = sum;

          /*
           * The overflow will stop before we over write words we should not
           * overwrite
           */
          if (sum < (BN_ULONG)carry) {
              do {
                  dst++;
                  *dst = (*dst + 1) & BN_MASK2;
              } while (*dst == 0);
          }
      }
  }
  #endif