📄 integer.cpp
字号:
const word *x = (word *)temp+7*4;
const __m64 *mx = (__m64 *)x;
const word *y = (word *)temp+7*4*2;
const __m64 *my = (__m64 *)y;
const word *z = (word *)temp+7*4*3;
const __m64 *mz = (__m64 *)z;
P4_Mul(temp, (__m128i *)A, (__m128i *)B);
P4_Mul(temp+7, (__m128i *)A+1, (__m128i *)B);
P4_Mul(temp+14, (__m128i *)A, (__m128i *)B+1);
P4_Mul(temp+21, (__m128i *)A+1, (__m128i *)B+1);
C[0] = w[0];
__m64 s1, s2, s3, s4;
__m64 w1 = _mm_cvtsi32_si64(w[1]);
__m64 w4 = mw[2];
__m64 w6 = mw[3];
__m64 w8 = mw[4];
__m64 w10 = mw[5];
__m64 w12 = mw[6];
__m64 w14 = mw[7];
__m64 w16 = mw[8];
__m64 w18 = mw[9];
__m64 w20 = mw[10];
__m64 w22 = mw[11];
__m64 w26 = _mm_cvtsi32_si64(w[26]);
__m64 w27 = _mm_cvtsi32_si64(w[27]);
__m64 x0 = _mm_cvtsi32_si64(x[0]);
__m64 x1 = _mm_cvtsi32_si64(x[1]);
__m64 x4 = mx[2];
__m64 x6 = mx[3];
__m64 x8 = mx[4];
__m64 x10 = mx[5];
__m64 x12 = mx[6];
__m64 x14 = mx[7];
__m64 x16 = mx[8];
__m64 x18 = mx[9];
__m64 x20 = mx[10];
__m64 x22 = mx[11];
__m64 x26 = _mm_cvtsi32_si64(x[26]);
__m64 x27 = _mm_cvtsi32_si64(x[27]);
__m64 y0 = _mm_cvtsi32_si64(y[0]);
__m64 y1 = _mm_cvtsi32_si64(y[1]);
__m64 y4 = my[2];
__m64 y6 = my[3];
__m64 y8 = my[4];
__m64 y10 = my[5];
__m64 y12 = my[6];
__m64 y14 = my[7];
__m64 y16 = my[8];
__m64 y18 = my[9];
__m64 y20 = my[10];
__m64 y22 = my[11];
__m64 y26 = _mm_cvtsi32_si64(y[26]);
__m64 y27 = _mm_cvtsi32_si64(y[27]);
__m64 z0 = _mm_cvtsi32_si64(z[0]);
__m64 z1 = _mm_cvtsi32_si64(z[1]);
__m64 z4 = mz[2];
__m64 z6 = mz[3];
__m64 z8 = mz[4];
__m64 z10 = mz[5];
__m64 z12 = mz[6];
__m64 z14 = mz[7];
__m64 z16 = mz[8];
__m64 z18 = mz[9];
__m64 z20 = mz[10];
__m64 z22 = mz[11];
__m64 z26 = _mm_cvtsi32_si64(z[26]);
s1 = _mm_add_si64(w1, w4);
C[1] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s2 = _mm_add_si64(w6, w8);
s1 = _mm_add_si64(s1, s2);
C[2] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s2 = _mm_add_si64(w10, w12);
s1 = _mm_add_si64(s1, s2);
C[3] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s3 = _mm_add_si64(x0, y0);
s2 = _mm_add_si64(w14, w16);
s1 = _mm_add_si64(s1, s3);
s1 = _mm_add_si64(s1, s2);
C[4] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s3 = _mm_add_si64(x1, y1);
s4 = _mm_add_si64(x4, y4);
s1 = _mm_add_si64(s1, w18);
s3 = _mm_add_si64(s3, s4);
s1 = _mm_add_si64(s1, w20);
s1 = _mm_add_si64(s1, s3);
C[5] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s3 = _mm_add_si64(x6, y6);
s4 = _mm_add_si64(x8, y8);
s1 = _mm_add_si64(s1, w22);
s3 = _mm_add_si64(s3, s4);
s1 = _mm_add_si64(s1, w26);
s1 = _mm_add_si64(s1, s3);
C[6] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s3 = _mm_add_si64(x10, y10);
s4 = _mm_add_si64(x12, y12);
s1 = _mm_add_si64(s1, w27);
s3 = _mm_add_si64(s3, s4);
s1 = _mm_add_si64(s1, s3);
C[7] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s3 = _mm_add_si64(x14, y14);
s4 = _mm_add_si64(x16, y16);
s1 = _mm_add_si64(s1, z0);
s3 = _mm_add_si64(s3, s4);
s1 = _mm_add_si64(s1, s3);
C[8] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s3 = _mm_add_si64(x18, y18);
s4 = _mm_add_si64(x20, y20);
s1 = _mm_add_si64(s1, z1);
s3 = _mm_add_si64(s3, s4);
s1 = _mm_add_si64(s1, z4);
s1 = _mm_add_si64(s1, s3);
C[9] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s3 = _mm_add_si64(x22, y22);
s4 = _mm_add_si64(x26, y26);
s1 = _mm_add_si64(s1, z6);
s3 = _mm_add_si64(s3, s4);
s1 = _mm_add_si64(s1, z8);
s1 = _mm_add_si64(s1, s3);
C[10] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s3 = _mm_add_si64(x27, y27);
s1 = _mm_add_si64(s1, z10);
s1 = _mm_add_si64(s1, z12);
s1 = _mm_add_si64(s1, s3);
C[11] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s3 = _mm_add_si64(z14, z16);
s1 = _mm_add_si64(s1, s3);
C[12] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s3 = _mm_add_si64(z18, z20);
s1 = _mm_add_si64(s1, s3);
C[13] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s3 = _mm_add_si64(z22, z26);
s1 = _mm_add_si64(s1, s3);
C[14] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
C[15] = z[27] + _mm_cvtsi64_si32(s1);
_mm_empty();
}
// Writes the eight words C[0..7] from the operands A and B using three P4_Mul
// calls and an MMX carry chain.
// NOTE(review): by its name this is presumably the low 8 words of the 8x8-word
// product A*B -- confirm against P4_Mul's output layout, which is not visible here.
// NOTE(review): the casts to __m128i* presumably require A and B to be suitably
// aligned for P4_Mul -- confirm.
void P4Optimized::Multiply8Bottom(word *C, const word *A, const word *B)
{
// Scratch area for three P4_Mul results, 7 __m128i each.
__m128i temp[21];
// w, x and y are word-sized views of the three results; mw, mx and my are
// 64-bit views of the same memory.  The word offsets 7*4 and 7*4*2 line up with
// temp+7 and temp+14 only if a word is 4 bytes -- assumed, TODO confirm.
const word *w = (word *)temp;
const __m64 *mw = (__m64 *)w;
const word *x = (word *)temp+7*4;
const __m64 *mx = (__m64 *)x;
const word *y = (word *)temp+7*4*2;
const __m64 *my = (__m64 *)y;
// Operand pieces fed to P4_Mul: first __m128i of A with first of B,
// second of A with first of B, first of A with second of B.
P4_Mul(temp, (__m128i *)A, (__m128i *)B);
P4_Mul(temp+7, (__m128i *)A+1, (__m128i *)B);
P4_Mul(temp+14, (__m128i *)A, (__m128i *)B+1);
C[0] = w[0];
__m64 s1, s2, s3, s4;
// Naming: wK / xK / yK is the value found at word index K of the respective
// view.  Values read through mw/mx/my are 64 bits wide (index K/2 of the 64-bit
// view); values read through _mm_cvtsi32_si64 are a single word widened to 64 bits.
__m64 w1 = _mm_cvtsi32_si64(w[1]);
__m64 w4 = mw[2];
__m64 w6 = mw[3];
__m64 w8 = mw[4];
__m64 w10 = mw[5];
__m64 w12 = mw[6];
__m64 w14 = mw[7];
__m64 w16 = mw[8];
__m64 w18 = mw[9];
__m64 w20 = mw[10];
__m64 w22 = mw[11];
__m64 w26 = _mm_cvtsi32_si64(w[26]);
__m64 x0 = _mm_cvtsi32_si64(x[0]);
__m64 x1 = _mm_cvtsi32_si64(x[1]);
__m64 x4 = mx[2];
__m64 x6 = mx[3];
__m64 x8 = mx[4];
__m64 y0 = _mm_cvtsi32_si64(y[0]);
__m64 y1 = _mm_cvtsi32_si64(y[1]);
__m64 y4 = my[2];
__m64 y6 = my[3];
__m64 y8 = my[4];
// Carry chain: s1 is a 64-bit accumulator.  For each output word the
// contributing terms are added into s1, the low 32 bits are stored to C[i],
// and s1 is shifted right by 32 so the overflow carries into the next word.
s1 = _mm_add_si64(w1, w4);
C[1] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s2 = _mm_add_si64(w6, w8);
s1 = _mm_add_si64(s1, s2);
C[2] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s2 = _mm_add_si64(w10, w12);
s1 = _mm_add_si64(s1, s2);
C[3] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
// From C[4] on, terms from the x and y results join the sum.
s3 = _mm_add_si64(x0, y0);
s2 = _mm_add_si64(w14, w16);
s1 = _mm_add_si64(s1, s3);
s1 = _mm_add_si64(s1, s2);
C[4] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s3 = _mm_add_si64(x1, y1);
s4 = _mm_add_si64(x4, y4);
s1 = _mm_add_si64(s1, w18);
s3 = _mm_add_si64(s3, s4);
s1 = _mm_add_si64(s1, w20);
s1 = _mm_add_si64(s1, s3);
C[5] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
s3 = _mm_add_si64(x6, y6);
s4 = _mm_add_si64(x8, y8);
s1 = _mm_add_si64(s1, w22);
s3 = _mm_add_si64(s3, s4);
s1 = _mm_add_si64(s1, w26);
s1 = _mm_add_si64(s1, s3);
C[6] = _mm_cvtsi64_si32(s1);
s1 = _mm_srli_si64(s1, 32);
// Last word: summed in ordinary word arithmetic, so anything that would carry
// past C[7] is simply dropped (no further output word is written).
C[7] = _mm_cvtsi64_si32(s1) + w[27] + x[10] + y[10] + x[12] + y[12];
// Leave MMX state (EMMS) now that the __m64 arithmetic is finished.
_mm_empty();
}
#endif // #ifdef SSE2_INTRINSICS_AVAILABLE
// ********************************************************
// Shorthand for the sub-ranges of the operand, scratch and result buffers used
// by the recursive routines that follow.  These are plain textual macros: they
// only work where variables named A, B, T, R, N and N2 (as applicable) are in
// scope at the point of use.  The routines below define N2 as N/2.
// Low and high halves of A (offsets 0 and N2).
#define A0 A
#define A1 (A+N2)
// Low and high halves of B (offsets 0 and N2).
#define B0 B
#define B1 (B+N2)
// Four consecutive N2-word pieces of T (offsets 0, N2, N, N+N2).
#define T0 T
#define T1 (T+N2)
#define T2 (T+N)
#define T3 (T+N+N2)
// Four consecutive N2-word pieces of R (offsets 0, N2, N, N+N2).
#define R0 R
#define R1 (R+N2)
#define R2 (R+N)
#define R3 (R+N+N2)
// R[2*N] - result = A*B
// T[2*N] - temporary work space
// A[N] --- multiplier
// B[N] --- multiplicand
//
// N must be even and at least 2 (asserted).  Sizes 8, 4 and 2 are handed to the
// LowLevel kernels (8 and 4 only when MultiplyRecursionLimit() permits); any
// other size is split into halves and assembled from three half-size products:
// A0*B0, A1*B1 and (A1-A0)*(B0-B1).
void RecursiveMultiply(word *R, word *T, const word *A, const word *B, unsigned int N)
{
assert(N>=2 && N%2==0);
if (LowLevel::MultiplyRecursionLimit() >= 8 && N==8)
LowLevel::Multiply8(R, A, B);
else if (LowLevel::MultiplyRecursionLimit() >= 4 && N==4)
LowLevel::Multiply4(R, A, B);
else if (N==2)
LowLevel::Multiply2(R, A, B);
else
{
const unsigned int N2 = N/2;
// Signed correction applied to the top quarter at the end; starts at -1 in the
// cases where the middle product below is stored in wrapped (negative) form.
int carry;
int aComp = Compare(A0, A1, N2);
int bComp = Compare(B0, B1, N2);
// 2*aComp + aComp + bComp == 3*aComp + bComp.  Assuming Compare returns -1, 0
// or +1 (TODO confirm), the four labelled values are exactly the combinations
// where neither comparison is 0; every other combination lands in default.
// In each labelled case R0 and R1 serve as scratch for the two half-size
// differences, and T[01] receives their product.
switch (2*aComp + aComp + bComp)
{
case -4:
// aComp == -1, bComp == -1
LowLevel::Subtract(R0, A1, A0, N2);
LowLevel::Subtract(R1, B0, B1, N2);
RecursiveMultiply(T0, T2, R0, R1, N2);
// Fix-up of the upper half of the product using R0, paired with carry = -1.
// NOTE(review): presumably compensates for B0-B1 having wrapped -- confirm.
LowLevel::Subtract(T1, T1, R0, N2);
carry = -1;
break;
case -2:
// aComp == -1, bComp == +1
LowLevel::Subtract(R0, A1, A0, N2);
LowLevel::Subtract(R1, B0, B1, N2);
RecursiveMultiply(T0, T2, R0, R1, N2);
carry = 0;
break;
case 2:
// aComp == +1, bComp == -1: both differences are taken the other way round
// (A0-A1 and B1-B0), which leaves the product unchanged.
LowLevel::Subtract(R0, A0, A1, N2);
LowLevel::Subtract(R1, B1, B0, N2);
RecursiveMultiply(T0, T2, R0, R1, N2);
carry = 0;
break;
case 4:
// aComp == +1, bComp == +1
LowLevel::Subtract(R0, A1, A0, N2);
LowLevel::Subtract(R1, B0, B1, N2);
RecursiveMultiply(T0, T2, R0, R1, N2);
// Fix-up of the upper half of the product using R1, paired with carry = -1.
// NOTE(review): presumably compensates for A1-A0 having wrapped -- confirm.
LowLevel::Subtract(T1, T1, R1, N2);
carry = -1;
break;
default:
// At least one comparison returned 0: the middle product is stored as zero.
SetWords(T0, 0, N);
carry = 0;
}
// These two calls overwrite R0/R1 (used as scratch above) with real results.
RecursiveMultiply(R0, T2, A0, B0, N2);
RecursiveMultiply(R2, T2, A1, B1, N2);
// now T[01] holds (A1-A0)*(B0-B1), R[01] holds A0*B0, R[23] holds A1*B1
// Form the middle term T[01] + R[01] + R[23] and add it into the result at
// offset N2 (R1), accumulating every carry-out together with the sign fix-up.
carry += LowLevel::Add(T0, T0, R0, N);
carry += LowLevel::Add(T0, T0, R2, N);
carry += LowLevel::Add(R1, R1, T0, N);
assert (carry >= 0 && carry <= 2);
// Push the net carry into the top quarter of the result.
Increment(R3, N2, carry);
}
}
// R[2*N] - result = A*A
// T[2*N] - temporary work space
// A[N] --- number to be squared
//
// N must be non-zero and even (asserted).  Sizes 8, 4 and 2 go to the LowLevel
// kernels (8 and 4 only when SquareRecursionLimit() permits).  Any other size
// is split into halves A0 (low) and A1 (high) and assembled from A0*A0, A1*A1
// and the cross product A0*A1, which is added in twice at offset N2.
void RecursiveSquare(word *R, word *T, const word *A, unsigned int N)
{
	assert(N && N%2==0);

	// The three size tests must form a single if / else-if chain.  With a plain
	// `if` on the N==4 test, an N==8 call handled by Square8 would still reach
	// the recursive branch below and compute the whole square a second time.
	if (LowLevel::SquareRecursionLimit() >= 8 && N==8)
		LowLevel::Square8(R, A);
	else if (LowLevel::SquareRecursionLimit() >= 4 && N==4)
		LowLevel::Square4(R, A);
	else if (N==2)
		LowLevel::Square2(R, A);
	else
	{
		const unsigned int N2 = N/2;

		RecursiveSquare(R0, T2, A0, N2);        // R[01] = A0*A0
		RecursiveSquare(R2, T2, A1, N2);        // R[23] = A1*A1
		RecursiveMultiply(T0, T2, A0, A1, N2);  // T[01] = A0*A1

		// Add the cross product twice into the middle of the result and
		// propagate the combined carry-out into the top quarter.
		word carry = LowLevel::Add(R1, R1, T0, N);
		carry += LowLevel::Add(R1, R1, T0, N);
		Increment(R3, N2, carry);
	}
}
// R[N] - bottom half of A*B
// T[N] - temporary work space
// A[N] - multiplier
// B[N] - multiplicand
//
// N must be even and at least 2 (asserted).  Sizes 8, 4 and 2 are delegated to
// the LowLevel "Bottom" kernels (8 and 4 only when
// MultiplyBottomRecursionLimit() permits).  Otherwise the full product A0*B0
// fills R, and the bottom halves of the two cross products A1*B0 and A0*B1 are
// added into its upper half.
void RecursiveMultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N)
{
	assert(N>=2 && N%2==0);

	// Fixed-size kernels first; each one finishes the job on its own.
	if (LowLevel::MultiplyBottomRecursionLimit() >= 8 && N==8)
	{
		LowLevel::Multiply8Bottom(R, A, B);
		return;
	}
	if (LowLevel::MultiplyBottomRecursionLimit() >= 4 && N==4)
	{
		LowLevel::Multiply4Bottom(R, A, B);
		return;
	}
	if (N==2)
	{
		LowLevel::Multiply2Bottom(R, A, B);
		return;
	}

	// The A0/A1/B0/B1/T0/T1/R1 macros expand in terms of this exact name.
	const unsigned int N2 = N/2;

	// Full N-word product of the two low halves lands in R[0..N).
	RecursiveMultiply(R, T, A0, B0, N2);

	// Cross terms, in the order A1*B0 then A0*B1.  Only the bottom N2 words of
	// each are computed (into T0) and then added to the upper half of R.
	const word *const crossLeft[2] = { A1, A0 };
	const word *const crossRight[2] = { B0, B1 };
	for (unsigned int term = 0; term < 2; ++term)
	{
		RecursiveMultiplyBottom(T0, T1, crossLeft[term], crossRight[term], N2);
		LowLevel::Add(R1, R1, T0, N2);
	}
}
// R[N] --- upper half of A*B
// T[2*N] - temporary work space
// L[N] --- lower half of A*B
// A[N] --- multiplier
// B[N] --- multiplicand
//
// N must be even and at least 2 (asserted).  For N == 4 and N == 2 the full
// product is computed into T and its upper half copied to R.  For larger N the
// upper half is assembled from (A1-A0)*(B0-B1), A1*B1 and the already known
// lower half L; A0*B0 is not computed in that branch.
void RecursiveMultiplyTop(word *R, word *T, const word *L, const word *A, const word *B, unsigned int N)
{
assert(N>=2 && N%2==0);
if (N==4)
{
// Full 8-word product into T, then keep words 4..7.
// WORD_SIZE is presumably the size of a word in bytes -- TODO confirm.
LowLevel::Multiply4(T, A, B);
memcpy(R, T+4, 4*WORD_SIZE);
}
else if (N==2)
{
// Full 4-word product into T, then keep words 2..3.
LowLevel::Multiply2(T, A, B);
memcpy(R, T+2, 2*WORD_SIZE);
}
else
{
const unsigned int N2 = N/2;
// Signed running carry; starts at -1 in the cases where the middle product
// below is stored in wrapped (negative) form.
int carry;
int aComp = Compare(A0, A1, N2);
int bComp = Compare(B0, B1, N2);
// 2*aComp + aComp + bComp == 3*aComp + bComp.  Assuming Compare returns -1, 0
// or +1 (TODO confirm), the labelled values are the combinations where neither
// comparison is 0.  R0 and R1 serve as scratch for the half-size differences.
switch (2*aComp + aComp + bComp)
{
case -4:
// aComp == -1, bComp == -1
LowLevel::Subtract(R0, A1, A0, N2);
LowLevel::Subtract(R1, B0, B1, N2);
RecursiveMultiply(T0, T2, R0, R1, N2);
// Fix-up of the upper half of the product using R0, paired with carry = -1.
LowLevel::Subtract(T1, T1, R0, N2);
carry = -1;
break;
case -2:
// aComp == -1, bComp == +1
LowLevel::Subtract(R0, A1, A0, N2);
LowLevel::Subtract(R1, B0, B1, N2);
RecursiveMultiply(T0, T2, R0, R1, N2);
carry = 0;
break;
case 2:
// aComp == +1, bComp == -1: both differences taken the other way round.
LowLevel::Subtract(R0, A0, A1, N2);
LowLevel::Subtract(R1, B1, B0, N2);
RecursiveMultiply(T0, T2, R0, R1, N2);
carry = 0;
break;
case 4:
// aComp == +1, bComp == +1
LowLevel::Subtract(R0, A1, A0, N2);
LowLevel::Subtract(R1, B0, B1, N2);
RecursiveMultiply(T0, T2, R0, R1, N2);
// Fix-up of the upper half of the product using R1, paired with carry = -1.
LowLevel::Subtract(T1, T1, R1, N2);
carry = -1;
break;
default:
// At least one comparison returned 0: the middle product is stored as zero.
SetWords(T0, 0, N);
carry = 0;
}
// R0 is passed as the scratch buffer here; its contents are rebuilt below.
RecursiveMultiply(T2, R0, A1, B1, N2);
// now T[01] holds (A1-A0)*(B0-B1), T[23] holds A1*B1
// R0 = (upper half of L) - (lower half of L) - T0; c2 counts the borrows.
word c2 = LowLevel::Subtract(R0, L+N2, L, N2);
c2 += LowLevel::Subtract(R0, R0, T0, N2);
// t is 1 when R0 compares below the low half of A1*B1 (T2), otherwise 0.
// NOTE(review): this appears to recover the carry out of the lower half
// without forming A0*B0 -- confirm against the derivation before relying on it.
word t = (Compare(R0, T2, N2) == -1);
carry += t;
carry += Increment(R0, N2, c2+t);
// Add the upper halves of the middle product (T1) and of A1*B1 (T3) into R0.
carry += LowLevel::Add(R0, R0, T1, N2);
carry += LowLevel::Add(R0, R0, T3, N2);
assert (carry >= 0 && carry <= 2);
// Upper half of the result: upper half of A1*B1 plus the accumulated carry.
CopyWords(R1, T3, N2);
Increment(R1, N2, carry);
}
}
// Forwards to LowLevel::Add with the same arguments and returns its result
// (presumably the carry out of the N-word addition C = A + B -- TODO confirm).
inline word Add(word *C, const word *A, const word *B, unsigned int N)
{
return LowLevel::Add(C, A, B, N);
}
// Forwards to LowLevel::Subtract with the same arguments and returns its result
// (presumably the borrow out of the N-word subtraction C = A - B -- TODO confirm).
inline word Subtract(word *C, const word *A, const word *B, unsigned int N)
{
return LowLevel::Subtract(C, A, B, N);
}
// Computes R[2*N] = A[N] * B[N] by forwarding to RecursiveMultiply; T is the
// scratch buffer that routine requires.
inline void Multiply(word *R, word *T, const word *A, const word *B, unsigned int N)
{
RecursiveMultiply(R, T, A, B, N);
}
// Computes R[2*N] = A[N] * A[N] by forwarding to RecursiveSquare; T is the
// scratch buffer that routine requires.
inline void Square(word *R, word *T, const word *A, unsigned int N)
{
RecursiveSquare(R, T, A, N);
}
// Computes the bottom N words of A[N] * B[N] into R by forwarding to
// RecursiveMultiplyBottom; T is the scratch buffer that routine requires.
inline void MultiplyBottom(word *R, word *T, const word *A, const word *B, unsigned int N)
{
RecursiveMultiplyBottom(R, T, A, B, N);
}
// Computes the upper N words of A[N] * B[N] into R by forwarding to
// RecursiveMultiplyTop; L is the already computed lower half of the product
// and T is the scratch buffer that routine requires.
inline void MultiplyTop(word *R, word *T, const word *L, const word *A, const word *B, unsigned int N)
{
RecursiveMultiplyTop(R, T, L, A, B, N);
}
// Multiplies the N-word number A by the single word B.
// The low word of each step is written to C, the high word is fed into the next
// step, and the high word left over after the last step is returned (0 when N
// is 0).
// NOTE(review): DWord::MultiplyAndAdd(a, b, c) presumably yields a*b + c as a
// double-width value -- confirm against DWord.
static word LinearMultiply(word *C, const word *A, word B, unsigned int N)
{
	word highWord = 0;
	const word *src = A;
	word *dst = C;

	// Each source word is read before the matching destination word is written,
	// exactly as in an index-based walk over A and C.
	for (unsigned int remaining = N; remaining != 0; --remaining)
	{
		DWord product = DWord::MultiplyAndAdd(*src, B, highWord);
		++src;
		*dst = product.GetLowHalf();
		++dst;
		highWord = product.GetHighHalf();
	}

	return highWord;
}
// R[NA+NB] - result = A*B
// T[NA+NB] - temporary work space
// A[NA] ---- multiplier
// B[NB] ---- multiplicand
void AsymmetricMultiply(word *R, word *T, const word *A, unsigned int NA, const word *B, unsigned int NB)
{
if (NA == NB)
{
if (A == B)
Square(R, T, A, NA);
else
Multiply(R, T, A, B, NA);
return;
}
if (NA > NB)
{
std::swap(A, B);
std::swap(NA, NB);
}
assert(NB % NA == 0);
assert((NB/NA)%2 == 0); // NB is an even multiple of NA
if (NA==2 && !A[1])
{
switch (A[0])
{
case 0:
SetWords(R, 0, NB+2);
return;
case 1:
CopyWords(R, B, NB);
R[NB] = R[NB+1] = 0;
return;
default:
R[NB] = LinearMultiply(R, B, A[0], NB);
R[NB+1] = 0;
return;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -