mpmul.c 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. #include "os.h"
  2. #include <mp.h>
  3. #include "dat.h"
  //
  // from Knuth's 1969 Seminumerical Algorithms, pp 233-235 and pp 258-260
  //
  // the inner loop of the O(n^2) multiply is done by mpvecdigmuladd, which
  // is intended to be an assembly language routine (mpvecmul itself is the
  // C routine in this file).
  //
  // the Karatsuba trade-off (KARATSUBAMIN) was set empirically by measuring
  // the algorithms on a 400 MHz Pentium II.
  //
  // mpkaratsuba: p[0:alen+blen] = a[0:alen] * b[0:blen], using one level of
  // Karatsuba splitting (see Knuth pg 258). The three half-size products are
  // computed by calling mpvecmul, which may recurse back into this function.
  //
  // With B the digit base and n = ceil(alen/2), write
  //   a = u1*B^n + u0,   b = v1*B^n + v0
  // so that
  //   a*b = (B^2n + B^n)*u1*v1 + B^n*(u1-u0)*(v0-v1) + (B^n + 1)*u0*v0
  //
  // Preconditions, as arranged by mpvecmul (the only caller in this file):
  // alen >= blen, alen >= KARATSUBAMIN and blen > 1.
  // p does not need to be pre-zeroed: the final memmove overwrites all
  // alen+blen digits of p.
static void
mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
{
    mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod;
    int u0len, u1len, v0len, v1len, reslen;
    int sign, n;
    // split a into u1:u0, where u0 is the low n = ceil(alen/2) digits
    n = alen/2;
    if(alen&1)
        n++;
    u0len = n;
    u1len = alen-n;     // floor(alen/2), so u1len <= u0len
    // split b at the same digit position; if b has at most n digits its
    // high half v1 is empty
    if(blen > n){
        v0len = n;
        v1len = blen-n; // <= u1len <= v0len, because blen <= alen
    } else {
        v0len = blen;
        v1len = 0;
    }
    u0 = a;
    u1 = a + u0len;
    v0 = b;
    v1 = b + v0len;
    // scratch space: five regions of 2n+1 digits each
    //   u0v0     = t[0      : 2n+1 ]  first |u1-u0|, later u0*v0
    //   u1v1     = t[2n+1   : 4n+2 ]  first |v0-v1|, later u1*v1
    //   diffprod = t[4n+2   : 6n+3 ]  |u1-u0| * |v0-v1|
    //   res      = t[6n+3   : 10n+5]  4n+2-digit accumulator for the result
    // mallocz's second argument is assumed to request zero-filled memory
    // (TODO confirm). The first mpvecadd into res below treats res as an
    // existing zero value, and mpvecmul's schoolbook path only adds into
    // diffprod.
    t = mallocz(Dbytes*5*(2*n+1), 1);
    if(t == nil)
        sysfatal("mpkaratsuba: %r");
    u0v0 = t;
    u1v1 = t + (2*n+1);
    diffprod = t + 2*(2*n+1);
    res = t + 3*(2*n+1);
    // one less than the 4n+2 digits available in res; presumably the spare
    // digit receives the carry mpvecadd writes past its first operand's
    // length -- confirm against mpvecadd
    reslen = 4*n+1;
    // u0v0 region = |u1 - u0|; sign becomes -1 if u1 - u0 is negative
    sign = 1;
    if(mpveccmp(u1, u1len, u0, u0len) < 0){
        sign = -1;
        mpvecsub(u0, u0len, u1, u1len, u0v0);
    } else
        // u1 >= u0 and u1 has only u1len digits, so u0's digits at or above
        // u1len must be zero. Passing u1len as u0's length keeps the first
        // operand at least as long as the second (presumably required by
        // mpvecsub). The untouched high digits of u0v0 are still zero.
        mpvecsub(u1, u1len, u0, u1len, u0v0);
    // u1v1 region = |v0 - v1|; flip sign if v0 - v1 is negative
    if(mpveccmp(v0, v0len, v1, v1len) < 0){
        sign *= -1;
        // same reasoning as above: v0's digits at or above v1len are zero
        mpvecsub(v1, v1len, v0, v1len, u1v1);
    } else
        mpvecsub(v0, v0len, v1, v1len, u1v1);
    // diffprod = |u1-u0| * |v0-v1|; the differences occupy at most u0len
    // and v0len digits respectively
    mpvecmul(u0v0, u0len, u1v1, v0len, diffprod);
    // the differences are no longer needed: clear the first two regions
    // (u0v0 and u1v1) so they can receive the two half products
    memset(t, 0, 2*(2*n+1)*Dbytes);
    // u1v1 = u1*v1; stays zero when b has no high half
    if(v1len > 0)
        mpvecmul(u1, u1len, v1, v1len, u1v1);
    // u0v0 = u0*v0
    mpvecmul(u0, u0len, v0, v0len, u0v0);
    // res = u0*v0<<n + u0*v0   (shifts are counted in digits)
    mpvecadd(res, reslen, u0v0, u0len+v0len, res);
    mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n);
    // res += u1*v1<<n + u1*v1<<2*n
    if(v1len > 0){
        mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n);
        mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n);
    }
    // res += (u1-u0)*(v0-v1)<<n, applied as +/- |u1-u0|*|v0-v1| << n.
    // The final value equals a*b >= 0 and the low n digits are untouched,
    // so the subtraction on res+n should not underflow.
    if(sign < 0)
        mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n);
    else
        mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n);
    // copy out the alen+blen-digit product and release the scratch space
    memmove(p, res, (alen+blen)*Dbytes);
    free(t);
}
  84. #define KARATSUBAMIN 32
  85. void
  86. mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
  87. {
  88. int i;
  89. mpdigit d;
  90. mpdigit *t;
  91. // both mpvecdigmuladd and karatsuba are fastest when a is the longer vector
  92. if(alen < blen){
  93. i = alen;
  94. alen = blen;
  95. blen = i;
  96. t = a;
  97. a = b;
  98. b = t;
  99. }
  100. if(blen == 0){
  101. memset(p, 0, Dbytes*(alen+blen));
  102. return;
  103. }
  104. if(alen >= KARATSUBAMIN && blen > 1){
  105. // O(n^1.585)
  106. mpkaratsuba(a, alen, b, blen, p);
  107. } else {
  108. // O(n^2)
  109. for(i = 0; i < blen; i++){
  110. d = b[i];
  111. if(d != 0)
  112. mpvecdigmuladd(a, alen, d, &p[i]);
  113. }
  114. }
  115. }
  116. void
  117. mpmul(mpint *b1, mpint *b2, mpint *prod)
  118. {
  119. mpint *oprod;
  120. oprod = nil;
  121. if(prod == b1 || prod == b2){
  122. oprod = prod;
  123. prod = mpnew(0);
  124. }
  125. prod->top = 0;
  126. mpbits(prod, (b1->top+b2->top+1)*Dbits);
  127. mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
  128. prod->top = b1->top+b2->top+1;
  129. prod->sign = b1->sign*b2->sign;
  130. mpnorm(prod);
  131. if(oprod != nil){
  132. mpassign(prod, oprod);
  133. mpfree(prod);
  134. }
  135. }