📄 lzz_px.cpp
字号:
if (x.k != y.k) Error("FFT rep mismatch");
k = x.k;
n = 1L << k;
z.SetSize(k);
long index = zz_pInfo->index;
if (index >= 0) {
long *zp = &z.tbl[0][0];
const long *xp = &x.tbl[0][0];
const long *yp = &y.tbl[0][0];
long q = FFTPrime[index];
double qinv = FFTPrimeInv[index];
for (j = 0; j < n; j++)
zp[j] = MulMod(xp[j], yp[j], q, qinv);
}
else {
for (i = 0; i < zz_pInfo->NumPrimes; i++) {
long *zp = &z.tbl[i][0];
const long *xp = &x.tbl[i][0];
const long *yp = &y.tbl[i][0];
long q = FFTPrime[i];
double qinv = FFTPrimeInv[i];
for (j = 0; j < n; j++)
zp[j] = MulMod(xp[j], yp[j], q, qinv);
}
}
}
void sub(fftRep& z, const fftRep& x, const fftRep& y)
// z = x - y, entry by entry, reducing each table row modulo its FFT prime.
// x and y must share the same size exponent k ("FFT rep mismatch" otherwise);
// z is resized to k.  Entry t of z depends only on entry t of x and y.
{
   if (x.k != y.k) Error("FFT rep mismatch");

   long logLen = x.k;
   long len = 1L << logLen;
   z.SetSize(logLen);

   long index = zz_pInfo->index;
   // index >= 0: only row 0 is used, modulo FFTPrime[index].
   // index <  0: row r is used modulo FFTPrime[r], for every prime.
   long rows = (index >= 0) ? 1 : zz_pInfo->NumPrimes;

   for (long r = 0; r < rows; r++) {
      long modulus = (index >= 0) ? FFTPrime[index] : FFTPrime[r];
      const long *lhs = &x.tbl[r][0];
      const long *rhs = &y.tbl[r][0];
      long *out = &z.tbl[r][0];

      for (long t = 0; t < len; t++)
         out[t] = SubMod(lhs[t], rhs[t], modulus);
   }
}
void add(fftRep& z, const fftRep& x, const fftRep& y)
// z = x + y, entry by entry, reducing each table row modulo its FFT prime.
// x and y must share the same size exponent k ("FFT rep mismatch" otherwise);
// z is resized to k.  Entry t of z depends only on entry t of x and y.
{
   if (x.k != y.k) Error("FFT rep mismatch");

   long logLen = x.k;
   long len = 1L << logLen;
   z.SetSize(logLen);

   long index = zz_pInfo->index;
   // index >= 0: only row 0 is used, modulo FFTPrime[index].
   // index <  0: row r is used modulo FFTPrime[r], for every prime.
   long rows = (index >= 0) ? 1 : zz_pInfo->NumPrimes;

   for (long r = 0; r < rows; r++) {
      long modulus = (index >= 0) ? FFTPrime[index] : FFTPrime[r];
      const long *lhs = &x.tbl[r][0];
      const long *rhs = &y.tbl[r][0];
      long *out = &z.tbl[r][0];

      for (long t = 0; t < len; t++)
         out[t] = AddMod(lhs[t], rhs[t], modulus);
   }
}
void reduce(fftRep& x, const fftRep& a, long k)
// reduces a 2^l point FFT-rep to a 2^k point FFT-rep
// input may alias output
//
// With l = a.k (must satisfy k <= l, else "reduce: bad operands"), every
// row keeps each 2^(l-k)-th entry: x.tbl[r][j] = a.tbl[r][j << (l-k)].
// a.k is read before x.SetSize(k) because, when x is a, SetSize changes it.
// Writes go forward and the read index (j << (l-k)) is never below the
// write index j, so an aliased source entry is read before it is overwritten.
{
   long srcLog = a.k;
   long len = 1L << k;
   if (srcLog < k) Error("reduce: bad operands");

   x.SetSize(k);

   long shift = srcLog - k;
   for (long row = 0; row < zz_pInfo->NumPrimes; row++) {
      const long *src = &a.tbl[row][0];
      long *dst = &x.tbl[row][0];
      for (long j = 0; j < len; j++)
         dst[j] = src[j << shift];
   }
}
void AddExpand(fftRep& x, const fftRep& a)
// x = x + (an "expanded" version of a)
//
// With l = x.k and k = a.k (requires k <= l, else "AddExpand: bad args"),
// adds a.tbl[r][j] into x.tbl[r][j << (l-k)] for j < 2^k, modulo the row's
// FFT prime.  All other entries of x are left unchanged; x is not resized.
{
   long dstLog = x.k;
   long srcLog = a.k;
   long len = 1L << srcLog;
   if (dstLog < srcLog) Error("AddExpand: bad args");

   long stride = dstLog - srcLog;
   long index = zz_pInfo->index;
   // index >= 0: only row 0 is used, modulo FFTPrime[index].
   // index <  0: row r is used modulo FFTPrime[r], for every prime.
   long rows = (index >= 0) ? 1 : zz_pInfo->NumPrimes;

   for (long r = 0; r < rows; r++) {
      long modulus = (index >= 0) ? FFTPrime[index] : FFTPrime[r];
      const long *src = &a.tbl[r][0];
      long *dst = &x.tbl[r][0];

      for (long j = 0; j < len; j++) {
         long pos = j << stride;
         dst[pos] = AddMod(dst[pos], src[j], modulus);
      }
   }
}
void FFTMul(zz_pX& x, const zz_pX& a, const zz_pX& b)
// x = a * b: transform both operands, multiply pointwise, transform back.
// A zero operand yields x = 0 without any transforms.
{
   if (IsZero(a) || IsZero(b)) {
      clear(x);
      return;
   }

   long prodDeg = deg(a) + deg(b);
   // size exponent chosen from the prodDeg+1 coefficients the product has
   long logLen = NextPowerOfTwo(prodDeg + 1);

   fftRep ra(INIT_SIZE, logLen), rb(INIT_SIZE, logLen);
   TofftRep(ra, a, logLen);
   TofftRep(rb, b, logLen);
   mul(ra, ra, rb);
   FromfftRep(x, ra, 0, prodDeg);
}
void FFTSqr(zz_pX& x, const zz_pX& a)
// x = a^2: one forward transform, a pointwise square, one transform back.
// a == 0 yields x = 0 without any transforms.
{
   if (IsZero(a)) {
      clear(x);
      return;
   }

   long sqDeg = 2*deg(a);
   // size exponent chosen from the sqDeg+1 coefficients of the square
   long logLen = NextPowerOfTwo(sqDeg + 1);

   fftRep ra(INIT_SIZE, logLen);
   TofftRep(ra, a, logLen);
   mul(ra, ra, ra);
   FromfftRep(x, ra, 0, sqDeg);
}
void CopyReverse(zz_pX& x, const zz_pX& a, long lo, long hi)
// x[0..hi-lo] = reverse(a[lo..hi]), with zero fill
// input may not alias output
//
// Coefficient i of x is a[hi-i]; source indices outside [0, a.rep.length())
// contribute zero.  x is normalized afterwards.
{
   long outLen = hi - lo + 1;
   long srcLen = a.rep.length();

   x.rep.SetLength(outLen);
   const zz_p* src = a.rep.elts();
   zz_p* dst = x.rep.elts();

   // i walks x upward while j walks a downward from hi
   for (long i = 0, j = hi; i < outLen; i++, j--) {
      if (j >= 0 && j < srcLen)
         dst[i] = src[j];
      else
         clear(dst[i]);
   }

   x.normalize();
}
void copy(zz_pX& x, const zz_pX& a, long lo, long hi)
// x[0..hi-lo] = a[lo..hi], with zero fill
// input may not alias output
//
// Coefficient i of x is a[lo+i]; source indices outside [0, a.rep.length())
// contribute zero.  x is normalized afterwards.
{
   long outLen = hi - lo + 1;
   long srcLen = a.rep.length();

   x.rep.SetLength(outLen);
   const zz_p* src = a.rep.elts();
   zz_p* dst = x.rep.elts();

   // i walks x, j walks a, both upward
   for (long i = 0, j = lo; i < outLen; i++, j++) {
      if (j >= 0 && j < srcLen)
         dst[i] = src[j];
      else
         clear(dst[i]);
   }

   x.normalize();
}
void rem21(zz_pX& x, const zz_pX& a, const zz_pXModulus& F)
// x = a mod F.f, for deg(a) <= 2*n-2 where n = F.n (errors otherwise).
// Uses PlainRem when F.UseFFT is false or deg(a)-n <= NTL_zz_pX_MOD_CROSSOVER.
// Otherwise it reduces with FFTs, using the reps F.HRep and F.FRep that are
// stored in F (built elsewhere; their exact contents are not visible here).
{
long i, da, ds, n, kk;
da = deg(a);
n = F.n;
if (da > 2*n-2)
Error("bad args to rem(zz_pX,zz_pX,zz_pXModulus)");
// deg(a) < n: a is already reduced
if (da < n) {
x = a;
return;
}
// small degree excess (or no FFT data in F): classical remainder
if (!F.UseFFT || da - n <= NTL_zz_pX_MOD_CROSSOVER) {
PlainRem(x, a, F.f);
return;
}
fftRep R1(INIT_SIZE, F.l);
zz_pX P1(INIT_SIZE, n);
// Quotient step: transform a restricted to indices n..2*(n-1) (presumably
// what the lo/hi arguments of TofftRep select -- confirm), multiply by
// F.HRep, and take product coefficients n-2..2n-4.  DivRem21 returns this
// same P1 as its quotient.
TofftRep(R1, a, F.l, n, 2*(n-1));
mul(R1, R1, F.HRep);
FromfftRep(P1, R1, n-2, 2*n-4);
// P1 = coefficients 0..n-1 of (quotient times F.FRep), on a 2^F.k-point rep
TofftRep(R1, P1, F.k);
mul(R1, R1, F.FRep);
FromfftRep(P1, R1, 0, n-1);
ds = deg(P1);
kk = 1L << F.k;
// x[i] = a[i] - P1[i] (P1[i] taken as 0 above ds), plus a[i+kk] whenever
// i+kk <= deg(a).
// NOTE(review): the a[i+kk] term presumably compensates for wrap-around in
// the cyclic 2^F.k-point product -- confirm.
// NOTE(review): rem calls rem21(buf, buf, F), so x may alias a.  aa is then
// read through a pointer taken after x.rep.SetLength(n) shrank the vector;
// this relies on the shrink keeping the old elements in place.
x.rep.SetLength(n);
const zz_p* aa = a.rep.elts();
const zz_p* ss = P1.rep.elts();
zz_p* xx = x.rep.elts();
for (i = 0; i < n; i++) {
if (i <= ds)
sub(xx[i], aa[i], ss[i]);
else
xx[i] = aa[i];
if (i + kk <= da)
add(xx[i], xx[i], aa[i+kk]);
}
x.normalize();
}
void DivRem21(zz_pX& q, zz_pX& x, const zz_pX& a, const zz_pXModulus& F)
// q = a / F.f and x = a mod F.f, for deg(a) <= 2*n-2 where n = F.n
// (errors otherwise).
// This is the same algorithm as rem21, except that the intermediate quotient
// polynomial is kept and returned in q.  Uses PlainDivRem when F.UseFFT is
// false or deg(a)-n <= NTL_zz_pX_MOD_CROSSOVER.
{
   long i, da, ds, n, kk;

   da = deg(a);
   n = F.n;

   // the message names this routine; it previously named rem()
   if (da > 2*n-2)
      Error("bad args to DivRem(zz_pX,zz_pX,zz_pX,zz_pXModulus)");

   // deg(a) < n: quotient 0, remainder a
   if (da < n) {
      x = a;
      clear(q);
      return;
   }

   // small degree excess (or no FFT data in F): classical division
   if (!F.UseFFT || da - n <= NTL_zz_pX_MOD_CROSSOVER) {
      PlainDivRem(q, x, a, F.f);
      return;
   }

   fftRep R1(INIT_SIZE, F.l);
   zz_pX P1(INIT_SIZE, n), qq;

   // Quotient step: transform a restricted to indices n..2*(n-1) (presumably
   // what the lo/hi arguments of TofftRep select -- confirm), multiply by
   // F.HRep, and take product coefficients n-2..2n-4.
   TofftRep(R1, a, F.l, n, 2*(n-1));
   mul(R1, R1, F.HRep);
   FromfftRep(P1, R1, n-2, 2*n-4);
   qq = P1;   // keep the quotient; P1 is reused below

   // P1 = coefficients 0..n-1 of (quotient times F.FRep), on a 2^F.k-point rep
   TofftRep(R1, P1, F.k);
   mul(R1, R1, F.FRep);
   FromfftRep(P1, R1, 0, n-1);
   ds = deg(P1);

   kk = 1L << F.k;

   // x[i] = a[i] - P1[i] (P1[i] taken as 0 above ds), plus a[i+kk] whenever
   // i+kk <= deg(a).
   // NOTE(review): the a[i+kk] term presumably compensates for wrap-around in
   // the cyclic 2^F.k-point product -- confirm.
   // NOTE(review): DivRem calls DivRem21(qbuf, buf, buf, F), so x may alias a.
   // aa is then read through a pointer taken after x.rep.SetLength(n) shrank
   // the vector; this relies on the shrink keeping the old elements in place.
   x.rep.SetLength(n);
   const zz_p* aa = a.rep.elts();
   const zz_p* ss = P1.rep.elts();
   zz_p* xx = x.rep.elts();

   for (i = 0; i < n; i++) {
      if (i <= ds)
         sub(xx[i], aa[i], ss[i]);
      else
         xx[i] = aa[i];

      if (i + kk <= da)
         add(xx[i], xx[i], aa[i+kk]);
   }

   x.normalize();

   // q is written only here, after every read of a
   q = qq;
}
void div21(zz_pX& x, const zz_pX& a, const zz_pXModulus& F)
// x = a / F.f (quotient only), for deg(a) <= 2*n-2 where n = F.n
// (errors otherwise).
// Uses PlainDiv when F.UseFFT is false or deg(a)-n <= NTL_zz_pX_MOD_CROSSOVER.
// Otherwise it runs the quotient step of rem21/DivRem21 and writes the
// quotient straight into x; no remainder is formed.
{
   long da, n;

   da = deg(a);
   n = F.n;

   // the message names this routine; it previously named rem()
   if (da > 2*n-2)
      Error("bad args to div(zz_pX,zz_pX,zz_pXModulus)");

   // deg(a) < n: quotient is 0
   if (da < n) {
      clear(x);
      return;
   }

   // small degree excess (or no FFT data in F): classical division
   if (!F.UseFFT || da - n <= NTL_zz_pX_MOD_CROSSOVER) {
      PlainDiv(x, a, F.f);
      return;
   }

   // The former scratch polynomial P1(INIT_SIZE, n) was never used here;
   // it is dropped to avoid a needless n-coefficient allocation per call.
   fftRep R1(INIT_SIZE, F.l);

   // transform a restricted to indices n..2*(n-1) (presumably what the lo/hi
   // arguments of TofftRep select -- confirm), multiply by F.HRep, and take
   // product coefficients n-2..2n-4 as the quotient
   TofftRep(R1, a, F.l, n, 2*(n-1));
   mul(R1, R1, F.HRep);
   FromfftRep(x, R1, n-2, 2*n-4);
}
void rem(zz_pX& x, const zz_pX& a, const zz_pXModulus& F)
// x = a mod F.f for a of any degree.
// Errors if F is uninitialized (F.n < 0).  Dispatch:
//  - deg(a) <= 2n-2: handled by rem21 directly.
//  - F.UseFFT is false, or deg(a)-n <= NTL_zz_pX_MOD_CROSSOVER: PlainRem.
//  - otherwise: a is fed into a running remainder buf, from its top
//    coefficient downward, in chunks that keep buf within rem21's limit.
// x is assigned only at the end, after all reads of a.
{
long da = deg(a);
long n = F.n;
if (n < 0) Error("rem: uninitialized modulus");
if (da <= 2*n-2) {
rem21(x, a, F);
return;
}
else if (!F.UseFFT || da-n <= NTL_zz_pX_MOD_CROSSOVER) {
PlainRem(x, a, F.f);
return;
}
// buf: running remainder, at most 2n-1 coefficients (the rem21 limit).
// INIT_SIZE presumably reserves capacity only; buf must start empty for the
// first chunk size to be positive.
zz_pX buf(INIT_SIZE, 2*n-1);
// a_len: number of low coefficients of a not yet fed into buf
long a_len = da+1;
while (a_len > 0) {
long old_buf_len = buf.rep.length();
// take as many coefficients as fit while keeping buf at <= 2n-1 of them
long amt = min(2*n-1-old_buf_len, a_len);
buf.rep.SetLength(old_buf_len+amt);
long i;
// shift buf up by amt positions (multiply by X^amt); go top-down so no
// entry is overwritten before it is moved
for (i = old_buf_len+amt-1; i >= amt; i--)
buf.rep[i] = buf.rep[i-amt];
// fill the vacated low positions with a[a_len-amt .. a_len-1]
for (i = amt-1; i >= 0; i--)
buf.rep[i] = a.rep[a_len-amt+i];
buf.normalize();
rem21(buf, buf, F);
a_len -= amt;
}
x = buf;
}
void DivRem(zz_pX& q, zz_pX& r, const zz_pX& a, const zz_pXModulus& F)
// q = a / F.f and r = a mod F.f for a of any degree.
// Errors if F is uninitialized (F.n < 0).  Dispatch mirrors rem: DivRem21 for
// deg(a) <= 2n-2, PlainDivRem below the crossover (or without F.UseFFT), and
// otherwise chunked processing of a from its top coefficient down.  Each
// pass's partial quotient is stored into qq at the matching offset.
// q and r are assigned only at the end, after all reads of a.
{
long da = deg(a);
long n = F.n;
if (n < 0) Error("DivRem: uninitialized modulus");
if (da <= 2*n-2) {
DivRem21(q, r, a, F);
return;
}
else if (!F.UseFFT || da-n <= NTL_zz_pX_MOD_CROSSOVER) {
PlainDivRem(q, r, a, F.f);
return;
}
// buf: running remainder (<= 2n-1 coefficients); qbuf: per-pass quotient
zz_pX buf(INIT_SIZE, 2*n-1);
zz_pX qbuf(INIT_SIZE, n-1);
zz_pX qq;
// the full quotient has degree da-n, hence da-n+1 slots
qq.rep.SetLength(da-n+1);
// a_len: number of low coefficients of a not yet fed into buf
long a_len = da+1;
// q_hi: start of the quotient slots written by the previous pass
// (initially one past the top slot)
long q_hi = da-n+1;
while (a_len > 0) {
long old_buf_len = buf.rep.length();
// take as many coefficients as fit while keeping buf at <= 2n-1 of them
long amt = min(2*n-1-old_buf_len, a_len);
buf.rep.SetLength(old_buf_len+amt);
long i;
// shift buf up by amt positions (multiply by X^amt), top-down
for (i = old_buf_len+amt-1; i >= amt; i--)
buf.rep[i] = buf.rep[i-amt];
// fill the vacated low positions with a[a_len-amt .. a_len-1]
for (i = amt-1; i >= 0; i--)
buf.rep[i] = a.rep[a_len-amt+i];
buf.normalize();
DivRem21(qbuf, buf, buf, F);
long dl = qbuf.rep.length();
a_len = a_len - amt;
// this pass's quotient lands at offset a_len (the lowest index of a
// consumed so far)
for(i = 0; i < dl; i++)
qq.rep[a_len+i] = qbuf.rep[i];
// qbuf may be shorter than the gap up to the previous pass's slots;
// zero the remainder of that gap
for(i = dl+a_len; i < q_hi; i++)
clear(qq.rep[i]);
q_hi = a_len;
}
r = buf;
qq.normalize();
q = qq;
}
void div(zz_pX& q, const zz_pX& a, const zz_pXModulus& F)
// q = a / F.f for a of any degree.
// Errors if F is uninitialized (F.n < 0).  Dispatch mirrors DivRem: div21
// for deg(a) <= 2n-2, PlainDiv below the crossover (or without F.UseFFT),
// and otherwise chunked processing of a from its top coefficient down.  The
// last pass uses div21 because no further remainder is needed.
// q is assigned only at the end, after all reads of a.
{
long da = deg(a);
long n = F.n;
if (n < 0) Error("div: uninitialized modulus");
if (da <= 2*n-2) {
div21(q, a, F);
return;
}
else if (!F.UseFFT || da-n <= NTL_zz_pX_MOD_CROSSOVER) {
PlainDiv(q, a, F.f);
return;
}
// buf: running remainder (<= 2n-1 coefficients); qbuf: per-pass quotient
zz_pX buf(INIT_SIZE, 2*n-1);
zz_pX qbuf(INIT_SIZE, n-1);
zz_pX qq;
// the full quotient has degree da-n, hence da-n+1 slots
qq.rep.SetLength(da-n+1);
// a_len: number of low coefficients of a not yet fed into buf
long a_len = da+1;
// q_hi: start of the quotient slots written by the previous pass
long q_hi = da-n+1;
while (a_len > 0) {
long old_buf_len = buf.rep.length();
// take as many coefficients as fit while keeping buf at <= 2n-1 of them
long amt = min(2*n-1-old_buf_len, a_len);
buf.rep.SetLength(old_buf_len+amt);
long i;
// shift buf up by amt positions (multiply by X^amt), top-down
for (i = old_buf_len+amt-1; i >= amt; i--)
buf.rep[i] = buf.rep[i-amt];
// fill the vacated low positions with a[a_len-amt .. a_len-1]
for (i = amt-1; i >= 0; i--)
buf.rep[i] = a.rep[a_len-amt+i];
buf.normalize();
a_len = a_len - amt;
// a remainder is needed only if more coefficients of a remain
if (a_len > 0)
DivRem21(qbuf, buf, buf, F);
else
div21(qbuf, buf, F);
long dl = qbuf.rep.length();
// this pass's quotient lands at offset a_len
for(i = 0; i < dl; i++)
qq.rep[a_len+i] = qbuf.rep[i];
// zero any gap between this pass's quotient and the previous pass's slots
for(i = dl+a_len; i < q_hi; i++)
clear(qq.rep[i]);
q_hi = a_len;
}
qq.normalize();
q = qq;
}
void MulMod(zz_pX& x, const zz_pX& a, const zz_pX& b, const zz_pXModulus& F)
// x = a*b mod F.f.  Requires deg(a) < n and deg(b) < n, where n = F.n.
// Errors if F is uninitialized or either degree is too large; a zero operand
// gives x = 0.  Below NTL_zz_pX_MUL_CROSSOVER (or without F.UseFFT) this is a
// plain mul followed by rem.  Otherwise the product is kept in FFT form and
// reduced with the same F.HRep / F.FRep steps rem21 uses.
{
long da, db, d, n, k;
da = deg(a);
db = deg(b);
n = F.n;
if (n < 0) Error("MulMod: uninitialized modulus");
if (da >= n || db >= n)
Error("bad args to MulMod(zz_pX,zz_pX,zz_pX,zz_pXModulus)");
if (da < 0 || db < 0) {
clear(x);
return;
}
if (!F.UseFFT || da <= NTL_zz_pX_MUL_CROSSOVER || db <= NTL_zz_pX_MUL_CROSSOVER) {
zz_pX P1;
mul(P1, a, b);
rem(x, P1, F);
return;
}
// d = number of coefficients of a*b
d = da + db + 1;
k = NextPowerOfTwo(d);
// R1 is later passed to reduce(R1, R1, F.k), which errors unless R1.k >= F.k
k = max(k, F.k);
fftRep R1(INIT_SIZE, k), R2(INIT_SIZE, F.l);
zz_pX P1(INIT_SIZE, n);
// R1 = a*b in FFT form
TofftRep(R1, a, k);
TofftRep(R2, b, k);
mul(R1, R1, R2);
// P1 = product coefficients n..d-1 (the high part).  Per the original
// comment, R1 is left intact; R2 is presumably used as scratch.
NDFromfftRep(P1, R1, n, d-1, R2); // save R1 for future use
// quotient step, as in rem21: high part times F.HRep, coefficients n-2..2n-4
TofftRep(R2, P1, F.l);
mul(R2, R2, F.HRep);
FromfftRep(P1, R2, n-2, 2*n-4);
// quotient times F.FRep on a 2^F.k-point rep
TofftRep(R2, P1, F.k);
mul(R2, R2, F.FRep);
// bring the saved product down to 2^F.k points, subtract, and keep the
// low n coefficients as the result
reduce(R1, R1, F.k);
sub(R1, R1, R2);
FromfftRep(x, R1, 0, n-1);
}
void SqrMod(zz_pX& x, const zz_pX& a, const zz_pXModulus& F)
// x = a^2 mod F.f.  Requires deg(a) < n, where n = F.n.
// Errors if F is uninitialized or deg(a) is too large.  Below
// NTL_zz_pX_MUL_CROSSOVER (or without F.UseFFT) this is sqr followed by rem;
// that path also covers a == 0.  Otherwise it follows the same FFT reduction
// sequence as MulMod, with a single forward transform of a.
{
long da, d, n, k;
da = deg(a);
n = F.n;
if (n < 0) Error("SqrMod: uninitialized modulus");
if (da >= n)
Error("bad args to SqrMod(zz_pX,zz_pX,zz_pXModulus)");
if (!F.UseFFT || da <= NTL_zz_pX_MUL_CROSSOVER) {
zz_pX P1;
sqr(P1, a);
rem(x, P1, F);
return;
}
// d = number of coefficients of a^2
d = 2*da + 1;
k = NextPowerOfTwo(d);
// R1 is later passed to reduce(R1, R1, F.k), which errors unless R1.k >= F.k
k = max(k, F.k);
fftRep R1(INIT_SIZE, k), R2(INIT_SIZE, F.l);
zz_pX P1(INIT_SIZE, n);
// R1 = a^2 in FFT form
TofftRep(R1, a, k);
mul(R1, R1, R1);
// P1 = coefficients n..d-1 of a^2 (the high part).  Per the original
// comment, R1 is left intact; R2 is presumably used as scratch.
NDFromfftRep(P1, R1, n, d-1, R2); // save R1 for future use
// quotient step, as in rem21: high part times F.HRep, coefficients n-2..2n-4
TofftRep(R2, P1, F.l);
mul(R2, R2, F.HRep);
FromfftRep(P1, R2, n-2, 2*n-4);
// quotient times F.FRep on a 2^F.k-point rep
TofftRep(R2, P1, F.k);
mul(R2, R2, F.FRep);
// bring the saved square down to 2^F.k points, subtract, and keep the
// low n coefficients as the result
reduce(R1, R1, F.k);
sub(R1, R1, R2);
FromfftRep(x, R1, 0, n-1);
}
void PlainInvTrunc(zz_pX& x, const zz_pX& a, long m)
/* x = (1/a) % X^m, input not output, constant term a is nonzero */
{
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -