📄 lzz_px.c
字号:
KarSqr_FP(c, a, hsa, stk); KarSub(T2, c, hsa2 - 1); clear(c[hsa2 - 1]); KarAdd(c+hsa, T2, hsa2-1);}void PlainSqr(zz_pX& c, const zz_pX& a){ if (IsZero(a)) { clear(c); return; } vec_zz_p mem; const zz_p *ap; zz_p *cp; long sa = a.rep.length(); if (&a == &c) { mem = a.rep; ap = mem.elts(); } else ap = a.rep.elts(); c.rep.SetLength(2*sa-1); cp = c.rep.elts(); long p = zz_p::modulus(); long use_FP = ((p < NTL_SP_BOUND/KARSX) && (double(p)*double(p) < NTL_FDOUBLE_PRECISION/KARSX)); if (sa < KARSX) { if (use_FP) { a_buf.SetLength(sa); PlainSqr_FP(cp, ap, sa); } else PlainSqr(cp, ap, sa); } else { /* karatsuba */ long n, hn, sp; n = sa; sp = 0; do { hn = (n+1) >> 1; sp += hn+hn+hn - 1; n = hn; } while (n >= KARSX); vec_zz_p stk; stk.SetLength(sp); if (use_FP) { a_buf.SetLength(sa); KarSqr_FP(cp, ap, sa, stk.elts()); } else KarSqr(cp, ap, sa, stk.elts()); } c.normalize();}void PlainDivRem(zz_pX& q, zz_pX& r, const zz_pX& a, const zz_pX& b){ long da, db, dq, i, j, LCIsOne; const zz_p *bp; zz_p *qp; zz_p *xp; zz_p LCInv, t; zz_p s; da = deg(a); db = deg(b); if (db < 0) Error("zz_pX: division by zero"); if (da < db) { r = a; clear(q); return; } zz_pX lb; if (&q == &b) { lb = b; bp = lb.rep.elts(); } else bp = b.rep.elts(); if (IsOne(bp[db])) LCIsOne = 1; else { LCIsOne = 0; inv(LCInv, bp[db]); } vec_zz_p x; if (&r == &a) xp = r.rep.elts(); else { x = a.rep; xp = x.elts(); } dq = da - db; q.rep.SetLength(dq+1); qp = q.rep.elts(); long p = zz_p::modulus(); double pinv = zz_p::ModulusInverse(); for (i = dq; i >= 0; i--) { t = xp[i+db]; if (!LCIsOne) mul(t, t, LCInv); qp[i] = t; negate(t, t); long T = rep(t); double Tpinv = ((double) T)*pinv; for (j = db-1; j >= 0; j--) { long S = MulMod2(rep(bp[j]), T, p, Tpinv); S = AddMod(S, rep(xp[i+j]), p); xp[i+j].LoopHole() = S; } } r.rep.SetLength(db); if (&r != &a) { for (i = 0; i < db; i++) r.rep[i] = xp[i]; } r.normalize();}void PlainDiv(zz_pX& q, const zz_pX& a, const zz_pX& b){ long da, db, dq, i, j, LCIsOne; const zz_p *bp; zz_p *qp; zz_p *xp; zz_p LCInv, t; zz_p s; da = deg(a); db = deg(b); if (db < 0) Error("zz_pX: division by zero"); if (da < db) { clear(q); return; } zz_pX lb; if (&q == &b) { lb = b; bp = lb.rep.elts(); } else bp = b.rep.elts(); if (IsOne(bp[db])) LCIsOne = 1; else { LCIsOne = 0; inv(LCInv, bp[db]); } vec_zz_p x; x.SetLength(da+1-db); for (i = db; i <= da; i++) x[i-db] = a.rep[i]; xp = x.elts(); dq = da - db; q.rep.SetLength(dq+1); qp = q.rep.elts(); long p = zz_p::modulus(); double pinv = zz_p::ModulusInverse(); for (i = dq; i >= 0; i--) { t = xp[i]; if (!LCIsOne) mul(t, t, LCInv); qp[i] = t; negate(t, t); long T = rep(t); double Tpinv = ((double) T)*pinv; long lastj = max(0, db-i); for (j = db-1; j >= lastj; j--) { long S = MulMod2(rep(bp[j]), T, p, Tpinv); S = AddMod(S, rep(xp[i+j-db]), p); xp[i+j-db].LoopHole() = S; } }}void PlainRem(zz_pX& r, const zz_pX& a, const zz_pX& b){ long da, db, dq, i, j, LCIsOne; const zz_p *bp; zz_p *xp; zz_p LCInv, t; zz_p s; da = deg(a); db = deg(b); if (db < 0) Error("zz_pX: division by zero"); if (da < db) { r = a; return; } bp = b.rep.elts(); if (IsOne(bp[db])) LCIsOne = 1; else { LCIsOne = 0; inv(LCInv, bp[db]); } vec_zz_p x; if (&r == &a) xp = r.rep.elts(); else { x = a.rep; xp = x.elts(); } dq = da - db; long p = zz_p::modulus(); double pinv = zz_p::ModulusInverse(); for (i = dq; i >= 0; i--) { t = xp[i+db]; if (!LCIsOne) mul(t, t, LCInv); negate(t, t); long T = rep(t); double Tpinv = ((double) T)*pinv; for (j = db-1; j >= 0; j--) { long S = MulMod2(rep(bp[j]), T, p, Tpinv); S = AddMod(S, rep(xp[i+j]), p); xp[i+j].LoopHole() = S; } } r.rep.SetLength(db); if (&r != &a) { for (i = 0; i < db; i++) r.rep[i] = xp[i]; } r.normalize();}void mul(zz_pX& x, const zz_pX& a, zz_p b){ if (IsZero(b)) { clear(x); return; } if (IsOne(b)) { x = a; return; } long i, da; const zz_p *ap; zz_p* xp; long t; t = rep(b); long p = zz_p::modulus(); double pinv = zz_p::ModulusInverse(); double bpinv = t*pinv; da = deg(a); x.rep.SetLength(da+1); ap = a.rep.elts(); xp = x.rep.elts(); for (i = 0; i <= da; i++) xp[i].LoopHole() = MulMod2(rep(ap[i]), t, p, bpinv); x.normalize();}void PlainGCD(zz_pX& x, const zz_pX& a, const zz_pX& b){ zz_p t; if (IsZero(b)) x = a; else if (IsZero(a)) x = b; else { long n = max(deg(a),deg(b)) + 1; zz_pX u(INIT_SIZE, n), v(INIT_SIZE, n); u = a; v = b; do { PlainRem(u, u, v); swap(u, v); } while (!IsZero(v)); x = u; } if (IsZero(x)) return; if (IsOne(LeadCoeff(x))) return; /* make gcd monic */ inv(t, LeadCoeff(x)); mul(x, x, t); } void PlainXGCD(zz_pX& d, zz_pX& s, zz_pX& t, const zz_pX& a, const zz_pX& b){ zz_p z; if (IsZero(b)) { set(s); clear(t); d = a; } else if (IsZero(a)) { clear(s); set(t); d = b; } else { long e = max(deg(a), deg(b)) + 1; zz_pX temp(INIT_SIZE, e), u(INIT_SIZE, e), v(INIT_SIZE, e), u0(INIT_SIZE, e), v0(INIT_SIZE, e), u1(INIT_SIZE, e), v1(INIT_SIZE, e), u2(INIT_SIZE, e), v2(INIT_SIZE, e), q(INIT_SIZE, e); set(u1); clear(v1); clear(u2); set(v2); u = a; v = b; do { DivRem(q, u, u, v); swap(u, v); u0 = u2; v0 = v2; mul(temp, q, u2); sub(u2, u1, temp); mul(temp, q, v2); sub(v2, v1, temp); u1 = u0; v1 = v0; } while (!IsZero(v)); d = u; s = u1; t = v1; } if (IsZero(d)) return; if (IsOne(LeadCoeff(d))) return; /* make gcd monic */ inv(z, LeadCoeff(d)); mul(d, d, z); mul(s, s, z); mul(t, t, z);}void MulMod(zz_pX& x, const zz_pX& a, const zz_pX& b, const zz_pX& f){ if (deg(a) >= deg(f) || deg(b) >= deg(f) || deg(f) == 0) Error("MulMod: bad args"); zz_pX t; mul(t, a, b); rem(x, t, f);}void SqrMod(zz_pX& x, const zz_pX& a, const zz_pX& f){ if (deg(a) >= deg(f) || deg(f) == 0) Error("SqrMod: bad args"); zz_pX t; sqr(t, a); rem(x, t, f);}void InvMod(zz_pX& x, const zz_pX& a, const zz_pX& f){ if (deg(a) >= deg(f) || deg(f) == 0) Error("InvMod: bad args"); zz_pX d, t; XGCD(d, x, t, a, f); if (!IsOne(d)) Error("zz_pX InvMod: can't compute multiplicative inverse");}long InvModStatus(zz_pX& x, const zz_pX& a, const zz_pX& f){ if (deg(a) >= deg(f) || deg(f) == 0) Error("InvModStatus: bad args"); zz_pX d, t; XGCD(d, x, t, a, f); if (!IsOne(d)) { x = d; return 1; } else return 0;}staticvoid MulByXModAux(zz_pX& h, const zz_pX& a, const zz_pX& f){ long i, n, m; zz_p* hh; const zz_p *aa, *ff; zz_p t, z; n = deg(f); m = deg(a); if (m >= n || n == 0) Error("MulByXMod: bad args"); if (m < 0) { clear(h); return; } if (m < n-1) { h.rep.SetLength(m+2); hh = h.rep.elts(); aa = a.rep.elts(); for (i = m+1; i >= 1; i--) hh[i] = aa[i-1]; clear(hh[0]); } else { h.rep.SetLength(n); hh = h.rep.elts(); aa = a.rep.elts(); ff = f.rep.elts(); negate(z, aa[n-1]); if (!IsOne(ff[n])) div(z, z, ff[n]); for (i = n-1; i >= 1; i--) { mul(t, z, ff[i]); add(hh[i], aa[i-1], t); } mul(hh[0], z, ff[0]); h.normalize(); }}void MulByXMod(zz_pX& h, const zz_pX& a, const zz_pX& f){ if (&h == &f) { zz_pX hh; MulByXModAux(hh, a, f); h = hh; } else MulByXModAux(h, a, f);}void random(zz_pX& x, long n){ long i; x.rep.SetLength(n); for (i = 0; i < n; i++) random(x.rep[i]); x.normalize();}void fftRep::SetSize(long NewK){ if (NewK < -1 || NewK >= NTL_BITS_PER_LONG-1) Error("bad arg to fftRep::SetSize()"); if (NewK <= MaxK) { k = NewK; return; } if (NumPrimes != zz_pInfo->NumPrimes) Error("fftRep: inconsistent use"); long i, n; if (MaxK != -1) for (i = 0; i < zz_pInfo->NumPrimes; i++) free(tbl[i]); n = 1L << NewK; for (i = 0; i < zz_pInfo->NumPrimes; i++) { if ( !(tbl[i] = (long *) malloc(n * (sizeof (long)))) ) Error("out of space in fftRep::SetSize()"); } k = MaxK = NewK;}fftRep::fftRep(const fftRep& R){ k = MaxK = R.k; NumPrimes = R.NumPrimes; if (k < 0) return; long i, j, n; n = 1L << k; for (i = 0; i < NumPrimes; i++) { if ( !(tbl[i] = (long *) malloc(n * (sizeof (long)))) ) Error("out of space in fftRep"); for (j = 0; j < n; j++) tbl[i][j] = R.tbl[i][j]; }}fftRep& fftRep::operator=(const fftRep& R){ if (this == &R) return *this; if (NumPrimes != R.NumPrimes) Error("fftRep: inconsistent use"); if (R.k < 0) { k = -1; return *this; } if (R.k > MaxK) { long i, n; if (MaxK != -1) { for (i = 0; i < NumPrimes; i++) free(tbl[i]); } n = 1L << R.k; for (i = 0; i < NumPrimes; i++) { if ( !(tbl[i] = (long *) malloc(n * (sizeof (long)))) ) Error("out of space in fftRep"); } k = MaxK = R.k; } else { k = R.k; } long i, j, n; n = 1L << k; for (i = 0; i < NumPrimes; i++) for (j = 0; j < n; j++) tbl[i][j] = R.tbl[i][j]; return *this;}fftRep::~fftRep(){ if (MaxK == -1) return; for (long i = 0; i < NumPrimes; i++) free(tbl[i]);}static vec_long FFTBuf;void FromModularRep(zz_p& x, long *a){ long n = zz_pInfo->NumPrimes; long p = zz_pInfo->p; double pinv = zz_pInfo->pinv; long q, s, t; long i; double y;#if QUICK_CRT y = 0; for (i = 0; i < n; i++) y = y + ((double) a[i])*zz_pInfo->x[i]; y = y - long(y*pinv)*p; y = y + 0.5; while (y >= p) y -= p; while (y < 0) y += p; q = long(y);#else long Q, r; double qq; y = 0; qq = 0; for (i = 0; i < n; i++) { r = MulDivRem(Q, a[i], zz_pInfo->u[i], FFTPrime[i], zz_pInfo->x[i]); qq = qq + Q; y = y + r*FFTPrimeInv[i]; } y = qq + long(y + 0.5); y = y - long(y*pinv)*p; while (y >= p) y -= p; while (y < 0) y += p; q = long(y); #endif t = 0; for (i = 0; i < n; i++) { s = MulMod(a[i], zz_pInfo->CoeffModP[i], p, pinv); t = AddMod(t, s, p); } s = MulMod(q, zz_pInfo->MinusMModP, p, pinv); t = AddMod(t, s, p); x.LoopHole() = t;}void TofftRep(fftRep& y, const zz_pX& x, long k, long lo, long hi)// computes an n = 2^k point convolution.// if deg(x) >= 2^k, then x is first reduced modulo X^n-1.{ long n, i, j, m, j1; vec_long& s = FFTBuf;; zz_p accum; long NumPrimes = zz_pInfo->NumPrimes; if (k > zz_pInfo->MaxRoot) Error("Polynomial too big for FFT"); if (lo < 0) Error("bad arg to TofftRep"); hi = min(hi, deg(x)); y.SetSize(k); n = 1L << k; m = max(hi-lo + 1, 0); const zz_p *xx = x.rep.elts(); long index = zz_pInfo->index; if (index >= 0) { for (j = 0; j < n; j++) { if (j >= m) { y.tbl[0][j] = 0; } else { accum = xx[j+lo]; for (j1 = j + n; j1 < m; j1 += n) add(accum, accum, xx[j1+lo]); y.tbl[0][j] = rep(accum); } } } else { for (j = 0; j < n; j++) { if (j >= m) { for (i = 0; i < NumPrimes; i++) y.tbl[i][j] = 0; } else { accum = xx[j+lo]; for (j1 = j + n; j1 < m; j1 += n) add(accum, accum, xx[j1+lo]); for (i = 0; i < NumPrimes; i++) { long q = FFTPrime[i]; long t = rep(accum); if (t >= q) t -= q; y.tbl[i][j] = t; } } } } s.SetLength(n); long *sp = s.elts(); if (index >= 0) { long *Root = &RootTable[index][0]; long *yp = &y.tbl[0][0]; FFT(sp, yp, y.k, FFTPrime[index], Root); for (j = 0; j < n; j++) yp[j] = sp[j]; } else { for (i = 0; i < zz_pInfo->NumPrimes; i++) { long *Root = &RootTable[i][0]; long *yp = &y.tbl[i][0]; FFT(sp, yp, y.k, FFTPrime[i], Root); for (j = 0; j < n; j++) yp[j] = sp[j]; } }}void RevTofftRep(fftRep& y, const vec_zz_p& x, long k, long lo, long hi, long offset)// computes an n = 2^k point convolution of X^offset*x[lo..hi] mod X^n-1// using "inverted" evaluation points.{ long n, i, j, m, j1; vec_long& s = FFTBuf; zz_p accum; long NumPrimes = zz_pInfo->NumPrimes; if (k > zz_pInfo->MaxRoot) Error("Polynomial too big for FFT"); if (lo < 0) Error("bad arg to TofftRep"); hi = min(hi, x.length()-1); y.SetSize(k); n = 1L << k; m = max(hi-lo + 1, 0); const zz_p *xx = x.elts(); long index = zz_pInfo->index; offset = offset & (n-1); if (index >= 0) { for (j = 0; j < n; j++) { if (j >= m) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -