📄 integer.cpp
字号:
{
assert(N>=2 && N%2==0);
if (P::MultiplyRecursionLimit() >= 8 && N==8)
P::Multiply8(R, A, B);
else if (P::MultiplyRecursionLimit() >= 4 && N==4)
P::Multiply4(R, A, B);
else if (N==2)
P::Multiply2(R, A, B);
else
DoRecursiveMultiply<P>(R, T, A, B, N, NULL); // VC60 workaround: needs this NULL
}
// Karatsuba-style step used by RecursiveMultiply: R[2*N] = A[N] * B[N] using
// three half-size products (A0*B0, A1*B1 and (A1-A0)*(B0-B1)) instead of four.
// A0/A1, B0/B1, R0..R3 and T0..T3 are macros defined outside this chunk (they
// are #undef'd further down); presumably they name consecutive N2-word pieces
// of A, B, R and T -- TODO confirm against their definitions.
// R0/R1 are borrowed as scratch for the operand differences before the
// A0*B0 product overwrites them; T2 onward is scratch for the recursion.
template <class P>
void DoRecursiveMultiply(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy)
{
const unsigned int N2 = N/2;
int carry;
int aComp = Compare(A0, A1, N2);
int bComp = Compare(B0, B1, N2);
// 3*aComp + bComp gives a distinct label for every sign pair (assuming
// Compare returns -1, 0 or 1). Only +-2 and +-4 have both differences
// nonzero; every other value means one difference is zero, so the cross
// product (A1-A0)*(B0-B1) is zero and handled by the default case.
switch (2*aComp + aComp + bComp)
{
case -4:
// A0<A1 and B0<B1: B0-B1 wraps, so the raw product carries an extra
// (A1-A0) in its upper half; remove it. The true product is negative,
// which carry = -1 compensates for.
P::Subtract(R0, A1, A0, N2);
P::Subtract(R1, B0, B1, N2);
RecursiveMultiply<P>(T0, T2, R0, R1, N2);
P::Subtract(T1, T1, R0, N2);
carry = -1;
break;
case -2:
// A0<A1 and B0>B1: both differences are non-negative, product is exact.
P::Subtract(R0, A1, A0, N2);
P::Subtract(R1, B0, B1, N2);
RecursiveMultiply<P>(T0, T2, R0, R1, N2);
carry = 0;
break;
case 2:
// A0>A1 and B0<B1: (A0-A1)*(B1-B0) equals (A1-A0)*(B0-B1) and both
// factors are non-negative.
P::Subtract(R0, A0, A1, N2);
P::Subtract(R1, B1, B0, N2);
RecursiveMultiply<P>(T0, T2, R0, R1, N2);
carry = 0;
break;
case 4:
// A0>A1 and B0>B1: A1-A0 wraps, so remove the extra (B0-B1) from the
// upper half; the true product is negative, hence carry = -1.
P::Subtract(R0, A1, A0, N2);
P::Subtract(R1, B0, B1, N2);
RecursiveMultiply<P>(T0, T2, R0, R1, N2);
P::Subtract(T1, T1, R1, N2);
carry = -1;
break;
default:
// one of the differences is zero, so the cross product is zero
SetWords(T0, 0, N);
carry = 0;
}
RecursiveMultiply<P>(R0, T2, A0, B0, N2);
RecursiveMultiply<P>(R2, T2, A1, B1, N2);
// now T[01] holds (A1-A0)*(B0-B1), R[01] holds A0*B0, R[23] holds A1*B1
// T[01] + A0*B0 + A1*B1 is the middle term A0*B1 + A1*B0; add it into R[12].
// carry accumulates the overflow words (plus the -1 sign adjustment above)
// that still have to be propagated into R3.
carry += P::Add(T0, T0, R0, N);
carry += P::Add(T0, T0, R2, N);
carry += P::Add(R1, R1, T0, N);
assert (carry >= 0 && carry <= 2);
Increment(R3, N2, carry);
}
// R[2*N] - result = A*A
// T[2*N] - temporary work space
// A[N] --- number to be squared
template <class P>
void DoRecursiveSquare(word *R, word *T, const word *A, unsigned int N, const P *dummy=NULL);

// Squares the N-word number A into the 2N-word result R.
// Sizes 8 and 4 go to the policy's fixed-size kernels when P's
// SquareRecursionLimit() allows them; size 2 always uses Square2.
// Anything larger is split by DoRecursiveSquare.
// The branches form a single if/else-if chain on purpose. If the N==8 test
// were a separate `if`, an N==8 square done by Square8 would fall through to
// the final `else` and be recomputed by DoRecursiveSquare.
template <class P>
inline void RecursiveSquare(word *R, word *T, const word *A, unsigned int N, const P *dummy=NULL)
{
	assert(N && N%2==0);
	if (P::SquareRecursionLimit() >= 8 && N==8)
		P::Square8(R, A);
	else if (P::SquareRecursionLimit() >= 4 && N==4)
		P::Square4(R, A);
	else if (N==2)
		P::Square2(R, A);
	else
		DoRecursiveSquare<P>(R, T, A, N, NULL);	// VC60 workaround: needs this NULL
}
// One splitting step of the recursive squaring: with A = A1*2^k + A0,
// A*A = A1^2*2^(2k) + 2*A0*A1*2^k + A0^2.
// Both squares are written straight into the result. The cross product goes
// into T[01] and is added into R[12] twice, and the carries ripple into R3.
// (R0..R3, T0..T2, A0, A1 are helper macros defined outside this block.)
template <class P>
void DoRecursiveSquare(word *R, word *T, const word *A, unsigned int N, const P *dummy)
{
	const unsigned int N2 = N/2;

	RecursiveSquare<P>(R0, T2, A0, N2);		// low square
	RecursiveSquare<P>(R2, T2, A1, N2);		// high square
	RecursiveMultiply<P>(T0, T2, A0, A1, N2);	// cross product A0*A1

	// the cross term appears twice in the expansion
	word overflow = 0;
	for (int pass = 0; pass < 2; ++pass)
		overflow += P::Add(R1, R1, T0, N);

	Increment(R3, N2, overflow);
}
// R[N] - bottom half of A*B
// T[N] - temporary work space
// A[N] - multiplier
// B[N] - multiplicand
template <class P>
void DoRecursiveMultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy=NULL);

// Computes only the low N words of A*B.
// Sizes 8 and 4 are handed to the policy's fixed-size "Bottom" kernels when
// P's MultiplyBottomRecursionLimit() permits; size 2 always uses
// Multiply2Bottom. Every other size is split by DoRecursiveMultiplyBottom.
template <class P>
inline void RecursiveMultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy=NULL)
{
	assert(N>=2 && N%2==0);

	if (P::MultiplyBottomRecursionLimit() >= 8 && N==8)
	{
		P::Multiply8Bottom(R, A, B);
		return;
	}
	if (P::MultiplyBottomRecursionLimit() >= 4 && N==4)
	{
		P::Multiply4Bottom(R, A, B);
		return;
	}
	if (N==2)
	{
		P::Multiply2Bottom(R, A, B);
		return;
	}
	DoRecursiveMultiplyBottom<P>(R, T, A, B, N, NULL);
}
// One splitting step for the low half of A*B.
// With A = A1*2^k + A0 and B = B1*2^k + B0, the low 2k bits are
// A0*B0 + (A1*B0 + A0*B1)*2^k. So A0*B0 is formed in full, and each cross
// term only needs its own low half, added into R1.
// (A0, A1, B0, B1, R1, T0, T1 are helper macros defined outside this block.)
template <class P>
void DoRecursiveMultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N, const P *dummy)
{
	const unsigned int N2 = N/2;

	// full product of the low halves fills all N words of R
	RecursiveMultiply<P>(R, T, A0, B0, N2);

	// truncated cross terms: first A1*B0, then A0*B1
	const word *const crossA[2] = {A1, A0};
	const word *const crossB[2] = {B0, B1};
	for (unsigned int k = 0; k < 2; ++k)
	{
		RecursiveMultiplyBottom<P>(T0, T1, crossA[k], crossB[k], N2);
		P::Add(R1, R1, T0, N2);
	}
}
// R[N] --- upper half of A*B
// T[2*N] - temporary work space
// L[N] --- lower half of A*B
// A[N] --- multiplier
// B[N] --- multiplicand
// Computes only the upper N words of the 2N-word product A*B. The caller
// supplies the already-known lower N words as L, so in the general case the
// product A0*B0 is never formed. The cross term (A1-A0)*(B0-B1) is built with
// the same sign-case switch as DoRecursiveMultiply. The carry into the upper
// half is then reconstructed from L.
// A0/A1, B0/B1, R0/R1 and T0..T3 are macros defined outside this chunk
// (presumably N/2-word pieces); R0/R1 double as scratch.
template <class P>
void RecursiveMultiplyTop(word *R, word *T, const word *L, const word *A, const word *B, unsigned int N, const P *dummy=NULL)
{
assert(N>=2 && N%2==0);
// Small sizes: form the full product in T and copy out its top half.
// Unlike RecursiveMultiply, Multiply4 is used here without checking
// P::MultiplyRecursionLimit(). The dword casts move two words at a time;
// this assumes dword is exactly two words and that R/T are suitably aligned.
if (N==4)
{
P::Multiply4(T, A, B);
((dword *)R)[0] = ((dword *)T)[2];
((dword *)R)[1] = ((dword *)T)[3];
}
else if (N==2)
{
P::Multiply2(T, A, B);
((dword *)R)[0] = ((dword *)T)[1];
}
else
{
const unsigned int N2 = N/2;
int carry;
int aComp = Compare(A0, A1, N2);
int bComp = Compare(B0, B1, N2);
// Same labelling as DoRecursiveMultiply: 3*aComp + bComp. Only +-2 and +-4
// have both differences nonzero; negative products (+-4) get carry = -1
// after removing the wrap-around term from T1.
switch (2*aComp + aComp + bComp)
{
case -4:
P::Subtract(R0, A1, A0, N2);
P::Subtract(R1, B0, B1, N2);
RecursiveMultiply<P>(T0, T2, R0, R1, N2);
P::Subtract(T1, T1, R0, N2);
carry = -1;
break;
case -2:
P::Subtract(R0, A1, A0, N2);
P::Subtract(R1, B0, B1, N2);
RecursiveMultiply<P>(T0, T2, R0, R1, N2);
carry = 0;
break;
case 2:
P::Subtract(R0, A0, A1, N2);
P::Subtract(R1, B1, B0, N2);
RecursiveMultiply<P>(T0, T2, R0, R1, N2);
carry = 0;
break;
case 4:
P::Subtract(R0, A1, A0, N2);
P::Subtract(R1, B0, B1, N2);
RecursiveMultiply<P>(T0, T2, R0, R1, N2);
P::Subtract(T1, T1, R1, N2);
carry = -1;
break;
default:
// one of the differences is zero, so the cross product is zero
SetWords(T0, 0, N);
carry = 0;
}
// T[23] = A1*B1, with R0 used as scratch
RecursiveMultiply<P>(T2, R0, A1, B1, N2);
// now T[01] holds (A1-A0)*(B0-B1), T[23] holds A1*B1
// R0 = (L+N2) - L - T0 over N2 words; the borrows are counted in c2.
word c2 = P::Subtract(R0, L+N2, L, N2);
c2 += P::Subtract(R0, R0, T0, N2);
// t is 1 when R0 is below the low half of A1*B1. It acts as an extra
// borrow/carry correction folded into both R0 and the outgoing carry.
// NOTE(review): the derivation of this test is not shown here -- TODO
// document why comparing against T2 is sufficient.
word t = (Compare(R0, T2, N2) == -1);
carry += t;
carry += Increment(R0, N2, c2+t);
// add the high halves of the cross term and of A1*B1
carry += P::Add(R0, R0, T1, N2);
carry += P::Add(R0, R0, T3, N2);
assert (carry >= 0 && carry <= 2);
// R1 = high half of A1*B1 plus the accumulated carry
CopyWords(R1, T3, N2);
Increment(R1, N2, carry);
}
}
// C[N] = A[N] + B[N] using the portable LowLevel kernel.
// The returned word is the carry out of the top word; callers in this file
// test it or accumulate it.
inline word Add(word *C, const word *A, const word *B, unsigned int N)
{
	const word carryOut = LowLevel::Add(C, A, B, N);
	return carryOut;
}
// C[N] = A[N] - B[N] using the portable LowLevel kernel.
// The returned word is the borrow out of the top word; a nonzero value means
// the difference wrapped around.
inline word Subtract(word *C, const word *A, const word *B, unsigned int N)
{
	const word borrowOut = LowLevel::Subtract(C, A, B, N);
	return borrowOut;
}
// Full product R[2*N] = A[N] * B[N].
// The P4Optimized kernels are used when SSE2 support was compiled in
// (SSE2_INTRINSICS_AVAILABLE) and HasSSE2() reports it at run time.
// Otherwise the portable LowLevel kernels are used.
inline void Multiply(word *R, word *T, const word *A, const word *B, unsigned int N)
{
#ifdef SSE2_INTRINSICS_AVAILABLE
	if (HasSSE2())
	{
		RecursiveMultiply<P4Optimized>(R, T, A, B, N);
		return;
	}
#endif
	RecursiveMultiply<LowLevel>(R, T, A, B, N);
}
// R[2*N] = A[N] squared. The SSE2 kernels are used when compiled in and
// available at run time; otherwise the portable LowLevel ones.
inline void Square(word *R, word *T, const word *A, unsigned int N)
{
#ifdef SSE2_INTRINSICS_AVAILABLE
	if (HasSSE2())
	{
		RecursiveSquare<P4Optimized>(R, T, A, N);
		return;
	}
#endif
	RecursiveSquare<LowLevel>(R, T, A, N);
}
// R[N] = low N words of A[N] * B[N]. The SSE2 kernels are used when compiled
// in and available at run time; otherwise the portable LowLevel ones.
inline void MultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N)
{
#ifdef SSE2_INTRINSICS_AVAILABLE
	if (HasSSE2())
	{
		RecursiveMultiplyBottom<P4Optimized>(R, T, A, B, N);
		return;
	}
#endif
	RecursiveMultiplyBottom<LowLevel>(R, T, A, B, N);
}
// R[N] = upper N words of A[N] * B[N], given L[N] = the lower N words of the
// same product. The SSE2 kernels are used when compiled in and available at
// run time; otherwise the portable LowLevel ones.
inline void MultiplyTop(word *R, word *T, const word *L, const word *A, const word *B, unsigned int N)
{
#ifdef SSE2_INTRINSICS_AVAILABLE
	if (HasSSE2())
	{
		RecursiveMultiplyTop<P4Optimized>(R, T, L, A, B, N);
		return;
	}
#endif
	RecursiveMultiplyTop<LowLevel>(R, T, L, A, B, N);
}
// R[NA+NB] - result = A*B
// T[NA+NB] - temporary work space
// A[NA] ---- multiplier
// B[NB] ---- multiplicand
// Multiplies operands of different lengths by splitting the longer one into
// NA-word chunks and multiplying each chunk by the shorter operand.
// After the optional swap, NB must be an even multiple of NA (asserted).
void AsymmetricMultiply(word *R, word *T, const word *A, unsigned int NA, const word *B, unsigned int NB)
{
// equal lengths: one square or one multiply covers everything
if (NA == NB)
{
if (A == B)
Square(R, T, A, NA);
else
Multiply(R, T, A, B, NA);
return;
}
// make A the shorter operand
if (NA > NB)
{
std::swap(A, B);
std::swap(NA, NB);
}
assert(NB % NA == 0);
assert((NB/NA)%2 == 0); // NB is an even multiple of NA
// A fits in a single word (A[1]==0): zero, one and other single-word
// multipliers are handled without the chunked multiply.
if (NA==2 && !A[1])
{
switch (A[0])
{
case 0:
SetWords(R, 0, NB+2);
return;
case 1:
CopyWords(R, B, NB);
R[NB] = R[NB+1] = 0;
return;
default:
R[NB] = LinearMultiply(R, B, A[0], NB);
R[NB+1] = 0;
return;
}
}
// Chunk products (each 2*NA words) alternate between R and T so that none
// of them overlap. The product for chunk 0 fills R[0..2NA). Its upper half
// is saved to T[2NA..3NA) because the chunk at offset NA will overwrite it.
Multiply(R, T, A, B, NA);
CopyWords(T+2*NA, R+NA, NA);
unsigned i;
// Chunks at even multiples of NA (2NA, 4NA, ...) go to T+NA+i, right after
// the saved half. T itself is passed as Multiply's work space; this
// presumably needs at most 2*NA words and so stays below T+2*NA.
// TODO confirm against RecursiveMultiply.
for (i=2*NA; i<NB; i+=2*NA)
Multiply(T+NA+i, T, A, B+i, NA);
// chunks at odd multiples of NA (NA, 3NA, ...) go directly into R+i
for (i=NA; i<NB; i+=2*NA)
Multiply(R+i, T, A, B+i, NA);
// Combine the two interleaved strips: add T[2NA..NA+NB) into R[NA..NB)
// and propagate any carry into the top NA words of R.
if (Add(R+NA, R+NA, T+2*NA, NB-NA))
Increment(R+NB, NA);
}
// R[N] ----- result = A inverse mod 2**(WORD_BITS*N)
// T[3*N/2] - temporary work space
// A[N] ----- an odd number as input
// Doubles the precision of the inverse at each level of recursion. R0 becomes
// the inverse of the low half A0 mod 2**(WORD_BITS*N/2), and the high half R1
// is then derived from it. A0/A1, R0/R1 and T0/T1 are macros defined outside
// this chunk (presumably N/2-word halves).
void RecursiveInverseModPower2(word *R, word *T, const word *A, unsigned int N)
{
if (N==2)
// base case; AtomicInverseModPower2 is defined elsewhere and presumably
// inverts the two-word A directly
AtomicInverseModPower2(R, A[0], A[1]);
else
{
const unsigned int N2 = N/2;
RecursiveInverseModPower2(R0, T0, A0, N2);
// R0*A0 == 1 mod 2**(WORD_BITS*N2), so the low half of that product is
// the constant 1; build it in T0 for MultiplyTop.
T0[0] = 1;
SetWords(T0+1, 0, N2-1);
// R1 (temporarily) = high half of R0*A0
MultiplyTop(R1, T1, T0, R0, A0, N2);
// T0 = low half of R0*A1
MultiplyBottom(T0, T1, R0, A1, N2);
// T0 = e = (high(R0*A0) + low(R0*A1)) mod 2**(WORD_BITS*N2).
// This e is the word block just above the 1 in R0*A.
Add(T0, R1, T0, N2);
// R1 = low half of R0 * (-e). Then (R0 + R1*2**(WORD_BITS*N2)) * A == 1
// mod 2**(WORD_BITS*N), because R1*A0 == -e.
TwosComplement(T0, N2);
MultiplyBottom(R1, T1, R0, T0, N2);
}
}
// R[N] --- result = X/(2**(WORD_BITS*N)) mod M
// T[3*N] - temporary work space
// X[2*N] - number to be reduced
// M[N] --- modulus
// U[N] --- multiplicative inverse of M mod 2**(WORD_BITS*N)
void MontgomeryReduce(word *R, word *T, const word *X, const word *M, const word *U, unsigned int N)
{
	// R = q = (low N words of X) * U mod 2**(WORD_BITS*N).
	// Then q*M agrees with X in its low N words.
	MultiplyBottom(R, T, X, U, N);
	// T[0..N) = upper N words of q*M, using X's low half as the known lower
	// half of that product. T[N..3N) is MultiplyTop's scratch space.
	MultiplyTop(T, T+N, X, R, M, N);
	// T[0..N) = (X - q*M) / 2**(WORD_BITS*N); a borrow means it went negative
	const word borrow = Subtract(T, X+N, T, N);
	// Always compute the corrected value T + M into T[N..2N). Then select the
	// right candidate with a mask instead of a branch, so that whether the
	// final correction was needed is not revealed by the execution path.
	const word carry = Add(T+N, T, M, N);
	assert(carry || !borrow);
	(void)carry;	// only consulted by the assertion
	CopyWords(R, T + ((0-borrow) & N), N);
}
// R[N] --- result = X/(2**(WORD_BITS*N/2)) mod M
// T[2*N] - temporary work space
// X[2*N] - number to be reduced
// M[N] --- modulus
// U[N/2] - multiplicative inverse of M mod 2**(WORD_BITS*N/2)
// V[N] --- 2**(WORD_BITS*3*N/2) mod M
// The top quarter X3 is folded back into the lower words through V, so that a
// half-width Montgomery step with U only has to clear the lowest N/2 words.
// c2 and c3 track signed carries/borrows out of the middle and top halves.
// T0..T3 and R0/R1 are macros defined outside this block (presumably N/2-word
// quarters/halves of T and R); the X*, M*, V* pieces are defined below.
void HalfMontgomeryReduce(word *R, word *T, const word *X, const word *M, const word *U, const word *V, unsigned int N)
{
assert(N%2==0 && N>=4);
#define M0 M
#define M1 (M+N2)
#define V0 V
#define V1 (V+N2)
#define X0 X
#define X1 (X+N2)
#define X2 (X+N)
#define X3 (X+N+N2)
const unsigned int N2 = N/2;
// T[01] = V0*X3 + X[0..N)  (carry in c2)
Multiply(T0, T2, V0, X3, N2);
int c2 = Add(T0, T0, X0, N);
// T3 = Montgomery quotient for the low N2 words: T0*U mod 2**(WORD_BITS*N2)
MultiplyBottom(T3, T2, T0, U, N2);
// T2 = upper half of T3*M0, whose lower half equals T0 by construction of
// T3; then T2 = T1 - T2
MultiplyTop(T2, R, T0, T3, M0, N2);
c2 -= Subtract(T2, T1, T2, N2);
// T[01] = T3*M1, then T0 = T2 - T0
Multiply(T0, R, T3, M1, N2);
c2 -= Subtract(T0, T2, T0, N2);
// T1 = X2 - T1, borrow tracked negatively in c3
int c3 = -(int)Subtract(T1, X2, T1, N2);
// R[01] = V1*X3, then R += T[01]
Multiply(R0, T2, V1, X3, N2);
c3 += Add(R, R, T, N);
// fold the middle-half carry/borrow into R1
if (c2>0)
c3 += Increment(R1, N2);
else if (c2<0)
c3 -= Decrement(R1, N2, -c2);
// a net overflow or underflow of the top half is corrected by one
// subtraction or addition of M
assert(c3>=-1 && c3<=1);
if (c3>0)
Subtract(R, R, M, N);
else if (c3<0)
Add(R, R, M, N);
#undef M0
#undef M1
#undef V0
#undef V1
#undef X0
#undef X1
#undef X2
#undef X3
}
// Release the operand-piece helper macros (defined before this chunk and
// used by the recursive multiply/square/reduce routines above) so they do
// not leak into the division code below.
#undef A0
#undef A1
#undef B0
#undef B1
#undef T0
#undef T1
#undef T2
#undef T3
#undef R0
#undef R1
#undef R2
#undef R3
// do a 3 word by 2 word divide, returns quotient and leaves remainder in A
// A[0..2] is the dividend, with A[2] taken as the most significant word
// (assumed from MAKE_DWORD(A[1], A[2]) below -- confirm MAKE_DWORD's argument
// order). {B1,B0} is the divisor. On return A[2] is 0 and {A[1],A[0]} is the
// remainder, which is less than {B1,B0} (the loop exit condition).
// The dword arithmetic assumes dword is twice the width of word.
static word SubatomicDivide(word *A, word B0, word B1)
{
// assert {A[2],A[1]} < {B1,B0}, so quotient can fit in a word
assert(A[2] < B1 || (A[2]==B1 && A[1] < B0));
dword p, u;
word Q;
// estimate the quotient: do a 2 word by 1 word divide
// Dividing by B1+1 instead of {B1,B0} can only underestimate. If B1+1
// wraps to 0, the divisor is effectively 2**WORD_BITS and the estimate is
// simply the top word.
if (B1+1 == 0)
Q = A[2];
else
Q = word(MAKE_DWORD(A[1], A[2]) / (B1+1));
// now subtract Q*B from A
// (word)(0-HIGH_WORD(u)) turns the all-ones high word of an underflowed
// difference into a borrow of 1.
p = (dword) B0*Q;
u = (dword) A[0] - LOW_WORD(p);
A[0] = LOW_WORD(u);
u = (dword) A[1] - HIGH_WORD(p) - (word)(0-HIGH_WORD(u)) - (dword)B1*Q;
A[1] = LOW_WORD(u);
A[2] += HIGH_WORD(u);
// Q <= actual quotient, so fix it
// subtract B one more time while the remainder is still >= B
while (A[2] || A[1] > B1 || (A[1]==B1 && A[0]>=B0))
{
u = (dword) A[0] - B0;
A[0] = LOW_WORD(u);
u = (dword) A[1] - B1 - (word)(0-HIGH_WORD(u));
A[1] = LOW_WORD(u);
A[2] += HIGH_WORD(u);
Q++;
assert(Q); // shouldn't overflow
}
return Q;
}
// do a 4 word by 2 word divide, returns 2 word quotient in Q0 and Q1
// A[0..3] is the dividend and B[0..1] the divisor (index 0 least significant).
// The remainder is not returned. Divide() passes the divisor's top two words
// plus one, which wraps to zero when those words are all ones; that is why
// a zero divisor is read as 2**(2*WORD_BITS) here.
static inline void AtomicDivide(word *Q, const word *A, const word *B)
{
if (!B[0] && !B[1]) // if divisor is 0, we assume divisor==2**(2*WORD_BITS)
{
Q[0] = A[2];
Q[1] = A[3];
}
else
{
// Two 3-by-2 word steps, most significant first. Each leaves its
// remainder in place for the next step to consume.
word T[4];
T[0] = A[0]; T[1] = A[1]; T[2] = A[2]; T[3] = A[3];
Q[1] = SubatomicDivide(T+1, B[0], B[1]);
Q[0] = SubatomicDivide(T, B[0], B[1]);
#ifndef NDEBUG
// multiply quotient and divisor and add remainder, make sure it equals dividend
assert(!T[2] && !T[3] && (T[1] < B[1] || (T[1]==B[1] && T[0]<B[0])));
word P[4];
LowLevel::Multiply2(P, Q, B);
Add(P, P, T, 4);
assert(memcmp(P, A, 4*WORD_SIZE)==0);
#endif
}
}
// for use by Divide(), corrects the underestimated quotient {Q1,Q0}
// R[N+2] - partial remainder. Q*B is subtracted from it; on return R[N] is 0
//          and R[0..N) < B (the while-loop exit condition).
// T[N+2] - scratch that receives the product Q*B
// Q[2] --- two-word quotient estimate, incremented in place as needed
// B[N] --- divisor, N even and nonzero
static void CorrectQuotientEstimate(word *R, word *T, word *Q, const word *B, unsigned int N)
{
assert(N && N%2==0);
if (Q[1])
{
// Two-word Q: multiply it by B two words at a time. First the
// non-overlapping 4-word products at offsets 0, 4, 8, ... are written.
// Then the ones at offsets 2, 6, 10, ... are accumulated on top. A carry
// out of Multiply2Add increments the two-word value T[i+4..i+5].
T[N] = T[N+1] = 0;
unsigned i;
for (i=0; i<N; i+=4)
LowLevel::Multiply2(T+i, Q, B+i);
for (i=2; i<N; i+=4)
if (LowLevel::Multiply2Add(T+i, Q, B+i))
T[i+5] += (++T[i+4]==0);
}
else
{
// single-word Q: one linear pass
T[N] = LinearMultiply(T, B, Q[0], N);
T[N+1] = 0;
}
// the estimate never exceeds the true quotient, so this cannot underflow
word borrow = Subtract(R, R, T, N+2);
assert(!borrow && !R[N+1]);
// bump the quotient until the remainder drops below B
while (R[N] || Compare(R, B, N) >= 0)
{
R[N] -= Subtract(R, R, B, N);
Q[1] += (++Q[0]==0);
assert(Q[0] || Q[1]); // no overflow
}
}
// R[NB] -------- remainder = A%B
// Q[NA-NB+2] --- quotient = A/B
// T[NA+2*NB+4] - temp work space
// A[NA] -------- dividend
// B[NB] -------- divisor
// Long division that produces two quotient words per step. Both operands are
// first normalized by shiftWords whole words plus shiftBits bits, so that the
// divisor's top bit is set; the remainder is shifted back at the end.
// Requires even, nonzero NA >= NB and a divisor whose top two words are not
// both zero (all asserted).
void Divide(word *R, word *Q, word *T, const word *A, unsigned int NA, const word *B, unsigned int NB)
{
assert(NA && NB && NA%2==0 && NB%2==0);
assert(B[NB-1] || B[NB-2]);
assert(NB <= NA);
// set up temporary work space
// TA: NA+2 words, normalized dividend reduced in place to the remainder
// TB: NB words, normalized divisor
// TP: the remaining words, scratch for CorrectQuotientEstimate
word *const TA=T;
word *const TB=T+NA+2;
word *const TP=T+NA+2+NB;
// copy B into TB and normalize it so that TB has highest bit set to 1
// If B's top word is zero, shift by one whole word first (the second-highest
// word is then nonzero by the assertion above).
unsigned shiftWords = (B[NB-1]==0);
TB[0] = TB[NB-1] = 0;
CopyWords(TB+shiftWords, B, NB-shiftWords);
unsigned shiftBits = WORD_BITS - BitPrecision(TB[NB-1]);
assert(shiftBits < WORD_BITS);
ShiftWordsLeftByBits(TB, NB, shiftBits);
// copy A into TA and normalize it
TA[0] = TA[NA] = TA[NA+1] = 0;
CopyWords(TA+shiftWords, A, NA);
ShiftWordsLeftByBits(TA, NA+2, shiftBits);
// If the two extra top words of TA are small, the top quotient words are at
// most a small count: find them by repeated subtraction. Otherwise extend
// NA by 2 so the main loop also produces the top two quotient words.
if (TA[NA+1]==0 && TA[NA] <= 1)
{
Q[NA-NB+1] = Q[NA-NB] = 0;
while (TA[NA] || Compare(TA+NA-NB, TB, NB) >= 0)
{
TA[NA] -= Subtract(TA+NA-NB, TA+NA-NB, TB, NB);
++Q[NA-NB];
}
}
else
{
NA+=2;
assert(Compare(TA+NA-NB, TB, NB) < 0);
}
// BT = top two words of TB, plus one. Dividing by this can only
// underestimate the quotient, which CorrectQuotientEstimate then fixes. If
// the +1 wraps to zero, AtomicDivide treats the divisor as 2**(2*WORD_BITS).
word BT[2];
BT[0] = TB[NB-2] + 1;
BT[1] = TB[NB-1] + (BT[0]==0);
// start reducing TA mod TB, 2 words at a time
// (i is unsigned; the loop still terminates because NB >= 2)
for (unsigned i=NA-2; i>=NB; i-=2)
{
AtomicDivide(Q+i-NB, TA+i-2, BT);
CorrectQuotientEstimate(TA+i-NB, TP, Q+i-NB, TB, NB);
}
// copy TA into R, and denormalize it
// starting at TA+shiftWords undoes the whole-word part of the normalization
CopyWords(R, TA+shiftWords, NB);
ShiftWordsRightByBits(R, NB, shiftBits);
}
static inline unsigned int EvenWordCount(const word *X, unsigned int N)
{
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -