123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- #include "os.h"
- #include <mp.h>
- #include "dat.h"
//
// from Knuth's 1969 Seminumerical Algorithms, pp 233-235 and pp 258-260
//
// mpvecdigmuladd is the routine that performs the inner loop
// (described by the original note as an assembly language routine).
//
// the karatsuba trade-off is set empirically by measuring the algorithms on
// a 400 MHz Pentium II.
//
- // karatsuba like (see knuth pg 258)
- // prereq: p is already zeroed
- static void
- mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
- {
- mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod;
- int u0len, u1len, v0len, v1len, reslen;
- int sign, n;
- // divide each piece in half
- n = alen/2;
- if(alen&1)
- n++;
- u0len = n;
- u1len = alen-n;
- if(blen > n){
- v0len = n;
- v1len = blen-n;
- } else {
- v0len = blen;
- v1len = 0;
- }
- u0 = a;
- u1 = a + u0len;
- v0 = b;
- v1 = b + v0len;
- // room for the partial products
- t = mallocz(Dbytes*5*(2*n+1), 1);
- if(t == nil)
- sysfatal("mpkaratsuba: %r");
- u0v0 = t;
- u1v1 = t + (2*n+1);
- diffprod = t + 2*(2*n+1);
- res = t + 3*(2*n+1);
- reslen = 4*n+1;
- // t[0] = (u1-u0)
- sign = 1;
- if(mpveccmp(u1, u1len, u0, u0len) < 0){
- sign = -1;
- mpvecsub(u0, u0len, u1, u1len, u0v0);
- } else
- mpvecsub(u1, u1len, u0, u1len, u0v0);
- // t[1] = (v0-v1)
- if(mpveccmp(v0, v0len, v1, v1len) < 0){
- sign *= -1;
- mpvecsub(v1, v1len, v0, v1len, u1v1);
- } else
- mpvecsub(v0, v0len, v1, v1len, u1v1);
- // t[4:5] = (u1-u0)*(v0-v1)
- mpvecmul(u0v0, u0len, u1v1, v0len, diffprod);
- // t[0:1] = u1*v1
- memset(t, 0, 2*(2*n+1)*Dbytes);
- if(v1len > 0)
- mpvecmul(u1, u1len, v1, v1len, u1v1);
- // t[2:3] = u0v0
- mpvecmul(u0, u0len, v0, v0len, u0v0);
- // res = u0*v0<<n + u0*v0
- mpvecadd(res, reslen, u0v0, u0len+v0len, res);
- mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n);
- // res += u1*v1<<n + u1*v1<<2*n
- if(v1len > 0){
- mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n);
- mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n);
- }
- // res += (u1-u0)*(v0-v1)<<n
- if(sign < 0)
- mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n);
- else
- mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n);
- memmove(p, res, (alen+blen)*Dbytes);
- free(t);
- }
- #define KARATSUBAMIN 32
- void
- mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
- {
- int i;
- mpdigit d;
- mpdigit *t;
- // both mpvecdigmuladd and karatsuba are fastest when a is the longer vector
- if(alen < blen){
- i = alen;
- alen = blen;
- blen = i;
- t = a;
- a = b;
- b = t;
- }
- if(blen == 0){
- memset(p, 0, Dbytes*(alen+blen));
- return;
- }
- if(alen >= KARATSUBAMIN && blen > 1){
- // O(n^1.585)
- mpkaratsuba(a, alen, b, blen, p);
- } else {
- // O(n^2)
- for(i = 0; i < blen; i++){
- d = b[i];
- if(d != 0)
- mpvecdigmuladd(a, alen, d, &p[i]);
- }
- }
- }
- void
- mpmul(mpint *b1, mpint *b2, mpint *prod)
- {
- mpint *oprod;
- oprod = nil;
- if(prod == b1 || prod == b2){
- oprod = prod;
- prod = mpnew(0);
- }
- prod->top = 0;
- mpbits(prod, (b1->top+b2->top+1)*Dbits);
- mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
- prod->top = b1->top+b2->top+1;
- prod->sign = b1->sign*b2->sign;
- mpnorm(prod);
- if(oprod != nil){
- mpassign(prod, oprod);
- mpfree(prod);
- }
- }
|