mpvecdigmuladd.c 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. #include "os.h"
  2. #include <mp.h>
  3. #include "dat.h"
  4. #define LO(x) ((x) & ((1<<(Dbits/2))-1))
  5. #define HI(x) ((x) >> (Dbits/2))
  6. static void
  7. mpdigmul(mpdigit a, mpdigit b, mpdigit *p)
  8. {
  9. mpdigit x, ah, al, bh, bl, p1, p2, p3, p4;
  10. int carry;
  11. // half digits
  12. ah = HI(a);
  13. al = LO(a);
  14. bh = HI(b);
  15. bl = LO(b);
  16. // partial products
  17. p1 = ah*bl;
  18. p2 = bh*al;
  19. p3 = bl*al;
  20. p4 = ah*bh;
  21. // p = ((p1+p2)<<(Dbits/2)) + (p4<<Dbits) + p3
  22. carry = 0;
  23. x = p1<<(Dbits/2);
  24. p3 += x;
  25. if(p3 < x)
  26. carry++;
  27. x = p2<<(Dbits/2);
  28. p3 += x;
  29. if(p3 < x)
  30. carry++;
  31. p4 += carry + HI(p1) + HI(p2); // can't carry out of the high digit
  32. p[0] = p3;
  33. p[1] = p4;
  34. }
  35. // prereq: p must have room for n+1 digits
  36. void
  37. mpvecdigmuladd(mpdigit *b, int n, mpdigit m, mpdigit *p)
  38. {
  39. int i;
  40. mpdigit carry, x, y, part[2];
  41. carry = 0;
  42. part[1] = 0;
  43. for(i = 0; i < n; i++){
  44. x = part[1] + carry;
  45. if(x < carry)
  46. carry = 1;
  47. else
  48. carry = 0;
  49. y = *p;
  50. mpdigmul(*b++, m, part);
  51. x += part[0];
  52. if(x < part[0])
  53. carry++;
  54. x += y;
  55. if(x < y)
  56. carry++;
  57. *p++ = x;
  58. }
  59. *p = part[1] + carry;
  60. }
  61. // prereq: p must have room for n+1 digits
  62. int
  63. mpvecdigmulsub(mpdigit *b, int n, mpdigit m, mpdigit *p)
  64. {
  65. int i;
  66. mpdigit x, y, part[2], borrow;
  67. borrow = 0;
  68. part[1] = 0;
  69. for(i = 0; i < n; i++){
  70. x = *p;
  71. y = x - borrow;
  72. if(y > x)
  73. borrow = 1;
  74. else
  75. borrow = 0;
  76. x = part[1];
  77. mpdigmul(*b++, m, part);
  78. x += part[0];
  79. if(x < part[0])
  80. borrow++;
  81. x = y - x;
  82. if(x > y)
  83. borrow++;
  84. *p++ = x;
  85. }
  86. x = *p;
  87. y = x - borrow - part[1];
  88. *p = y;
  89. if(y > x)
  90. return -1;
  91. else
  92. return 1;
  93. }