📄 lzz_px.c
字号:
#include <NTL/lzz_pX.h>
#include <NTL/vec_double.h>
#include <NTL/new.h>

NTL_START_IMPL

/* Algorithm-selection crossover tables (one entry per slot, 5 slots).
 * NOTE(review): the indexing scheme for these tables is not visible in
 * this excerpt -- presumably selected by modulus size class; confirm. */
long zz_pX_mod_crossover[5] = {45, 45, 90, 180, 180};
long zz_pX_mul_crossover[5] = {90, 400, 600, 1500, 1500};
long zz_pX_newton_crossover[5] = {150, 150, 300, 700, 700};
long zz_pX_div_crossover[5] = {180, 180, 350, 750, 750};
long zz_pX_halfgcd_crossover[5] = {90, 90, 180, 350, 350};
long zz_pX_gcd_crossover[5] = {400, 400, 800, 1400, 1400};
long zz_pX_bermass_crossover[5] = {400, 480, 900, 1600, 1600};
long zz_pX_trace_crossover[5] = {200, 350, 450, 800, 800};

/* Not referenced in the visible part of this file. */
#define QUICK_CRT (NTL_DOUBLE_PRECISION - NTL_SP_NBITS > 10)


/* Shared, lazily-constructed zero polynomial. */
const zz_pX& zz_pX::zero()
{
   static zz_pX z;
   return z;
}


/* Reads the coefficient vector, then strips high-order zeros. */
istream& operator>>(istream& s, zz_pX& x)
{
   s >> x.rep;
   x.normalize();
   return s;
}

/* Writes the coefficient vector. */
ostream& operator<<(ostream& s, const zz_pX& a)
{
   return s << a.rep;
}


/* Removes trailing (high-order) zero coefficients from rep, so that the
 * zero polynomial has an empty rep. */
void zz_pX::normalize()
{
   long n;
   const zz_p* p;

   n = rep.length();
   if (n == 0) return;
   p = rep.elts() + (n-1);
   while (n > 0 && IsZero(*p)) {
      p--;
      n--;
   }
   rep.SetLength(n);
}


/* True iff a is the zero polynomial (empty rep). */
long IsZero(const zz_pX& a)
{
   return a.rep.length() == 0;
}


/* True iff a is the constant polynomial 1. */
long IsOne(const zz_pX& a)
{
    return a.rep.length() == 1 && IsOne(a.rep[0]);
}


/* x = coefficient i of a; indices outside [0, deg(a)] yield 0. */
void GetCoeff(zz_p& x, const zz_pX& a, long i)
{
   if (i < 0 || i > deg(a))
      clear(x);
   else
      x = a.rep[i];
}


/* Sets coefficient i of x to a.
 * Errors on a negative index or an index >= 2^(NTL_BITS_PER_LONG-4). */
void SetCoeff(zz_pX& x, long i, zz_p a)
{
   long j, m;

   if (i < 0)
      Error("SetCoeff: negative index");

   if (i >= (1L << (NTL_BITS_PER_LONG-4)))
      Error("overflow in SetCoeff");

   m = deg(x);

   /* Writing a zero above the current degree cannot change x.  Return
    * early instead of growing rep to i+1 only for normalize() to shrink
    * it back (a large i would otherwise force a huge useless allocation). */
   if (i > m && IsZero(a))
      return;

   if (i > m) {
      x.rep.SetLength(i+1);
      for (j = m+1; j < i; j++)
         clear(x.rep[j]);
   }
   x.rep[i] = a;
   x.normalize();
}


/* Sets coefficient i of x to the residue of a; a == 1 takes the
 * SetCoeff(x, i) fast path. */
void SetCoeff(zz_pX& x, long i, long a)
{
   if (a == 1)
      SetCoeff(x, i);
   else
      SetCoeff(x, i, to_zz_p(a));
}


/* Sets coefficient i of x to 1.
 * Errors on a negative index or an index >= 2^(NTL_BITS_PER_LONG-4). */
void SetCoeff(zz_pX& x, long i)
{
   long j, m;

   if (i < 0)
      Error("coefficient index out of range");

   if (i >= (1L << (NTL_BITS_PER_LONG-4)))
      Error("overflow in SetCoeff");

   m = deg(x);

   if (i > m) {
      x.rep.SetLength(i+1);
      for (j = m+1; j < i; j++)
         clear(x.rep[j]);
   }
   set(x.rep[i]);
   x.normalize();
}


/* x = X (the monomial of degree 1). */
void SetX(zz_pX& x)
{
   clear(x);
   SetCoeff(x, 1);
}


/* True iff a == X. */
long IsX(const zz_pX& a)
{
   return deg(a) == 1 && IsOne(LeadCoeff(a)) && IsZero(ConstTerm(a));
}


/* Coefficient i of a, or 0 when i is outside [0, deg(a)]. */
zz_p coeff(const zz_pX& a, long i)
{
   if (i < 0 || i > deg(a))
      return zz_p::zero();
   else
      return a.rep[i];
}


/* Leading coefficient of a (0 for the zero polynomial). */
zz_p LeadCoeff(const zz_pX& a)
{
   if (IsZero(a))
      return zz_p::zero();
   else
      return a.rep[deg(a)];
}


/* Constant term of a (0 for the zero polynomial). */
zz_p ConstTerm(const zz_pX& a)
{
   if (IsZero(a))
      return zz_p::zero();
   else
      return a.rep[0];
}


/* x = constant polynomial a. */
void conv(zz_pX& x, zz_p a)
{
   if (IsZero(a))
      x.rep.SetLength(0);
   else {
      x.rep.SetLength(1);
      x.rep[0] = a;
   }
}


/* x = constant polynomial (a mod p). */
void conv(zz_pX& x, long a)
{
   if (a == 0) {
      x.rep.SetLength(0);
      return;
   }

   zz_p t;

   conv(t, a);
   conv(x, t);
}


/* x = constant polynomial (a mod p). */
void conv(zz_pX& x, const ZZ& a)
{
   if (a == 0) {
      x.rep.SetLength(0);
      return;
   }

   zz_p t;

   conv(t, a);
   conv(x, t);
}


/* x = polynomial with coefficient vector a (low degree first). */
void conv(zz_pX& x, const vec_zz_p& a)
{
   x.rep = a;
   x.normalize();
}


/* x = a + b.  x may alias a and/or b. */
void add(zz_pX& x, const zz_pX& a, const zz_pX& b)
{
   long da = deg(a);
   long db = deg(b);
   long minab = min(da, db);
   long maxab = max(da, db);
   x.rep.SetLength(maxab+1);

   long i;
   const zz_p *ap, *bp;
   zz_p* xp;

   /* Element pointers are taken after SetLength, which may reallocate
    * x.rep (and therefore a.rep or b.rep when aliased). */
   for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
        i; i--, ap++, bp++, xp++)
      add(*xp, (*ap), (*bp));

   /* Copy the longer operand's tail; it is already in place when x
    * aliases that operand.  Only equal degrees can cancel the top term. */
   if (da > minab && &x != &a) {
      for (i = da-minab; i; i--, xp++, ap++)
         *xp = *ap;
   }
   else if (db > minab && &x != &b) {
      for (i = db-minab; i; i--, xp++, bp++)
         *xp = *bp;
   }
   else
      x.normalize();
}


/* x = a + b for a constant b.  x may alias a. */
void add(zz_pX& x, const zz_pX& a, zz_p b)
{
   if (a.rep.length() == 0) {
      conv(x, b);
   }
   else {
      if (&x != &a) x = a;
      add(x.rep[0], x.rep[0], b);
      x.normalize();
   }
}


/* x = a - b.  x may alias a and/or b. */
void sub(zz_pX& x, const zz_pX& a, const zz_pX& b)
{
   long da = deg(a);
   long db = deg(b);
   long minab = min(da, db);
   long maxab = max(da, db);
   x.rep.SetLength(maxab+1);

   long i;
   const zz_p *ap, *bp;
   zz_p* xp;

   for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
        i; i--, ap++, bp++, xp++)
      sub(*xp, (*ap), (*bp));

   /* a's tail is copied (skipped when already in place); b's tail must be
    * negated even when x aliases b, so there is no alias test there. */
   if (da > minab && &x != &a) {
      for (i = da-minab; i; i--, xp++, ap++)
         *xp = *ap;
   }
   else if (db > minab) {
      for (i = db-minab; i; i--, xp++, bp++)
         negate(*xp, *bp);
   }
   else
      x.normalize();
}


/* x = a - b for a constant b.  x may alias a. */
void sub(zz_pX& x, const zz_pX& a, zz_p b)
{
   if (a.rep.length() == 0) {
      x.rep.SetLength(1);
      negate(x.rep[0], b);
   }
   else {
      if (&x != &a) x = a;
      sub(x.rep[0], x.rep[0], b);
   }

   x.normalize();
}


/* x = a - b for a constant a. */
void sub(zz_pX& x, zz_p a, const zz_pX& b)
{
   negate(x, b);
   add(x, x, a);
}


/* x = -a.  x may alias a. */
void negate(zz_pX& x, const zz_pX& a)
{
   long n = a.rep.length();
   x.rep.SetLength(n);

   const zz_p* ap = a.rep.elts();
   zz_p* xp = x.rep.elts();
   long i;

   for (i = n; i; i--, ap++, xp++)
      negate((*xp), (*ap));
}


/* x = a * b.  Uses FFT multiplication when both degrees exceed
 * NTL_zz_pX_MUL_CROSSOVER, otherwise PlainMul; a == b is routed to sqr. */
void mul(zz_pX& x, const zz_pX& a, const zz_pX& b)
{
   if (&a == &b) {
      sqr(x, a);
      return;
   }

   if (deg(a) > NTL_zz_pX_MUL_CROSSOVER && deg(b) > NTL_zz_pX_MUL_CROSSOVER)
      FFTMul(x, a, b);
   else
      PlainMul(x, a, b);
}


/* x = a^2, FFT above NTL_zz_pX_MUL_CROSSOVER, PlainSqr otherwise. */
void sqr(zz_pX& x, const zz_pX& a)
{
   if (deg(a) > NTL_zz_pX_MUL_CROSSOVER)
      FFTSqr(x, a);
   else
      PlainSqr(x, a);
}


/* "plain" multiplication and squaring actually incorporates Karatsuba */

/* Schoolbook product of coefficient arrays: xp[0..sa+sb-2] = ap * bp.
 * xp must not overlap the inputs; it is cleared before accumulating. */
void PlainMul(zz_p *xp, const zz_p *ap, long sa, const zz_p *bp, long sb)
{
   if (sa == 0 || sb == 0) return;

   long sx = sa+sb-1;

   /* Make ap the longer operand so the inner loop is the long one. */
   if (sa < sb) {
      { long t = sa; sa = sb; sb = t; }
      { const zz_p *t = ap; ap = bp; bp = t; }
   }

   long i, j;

   for (i = 0; i < sx; i++)
      clear(xp[i]);

   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();

   for (i = 0; i < sb; i++) {
      long t1 = rep(bp[i]);
      /* Precomputed t1/p, reused for every MulMod2 in the inner loop. */
      double bpinv = ((double) t1)*pinv;
      zz_p *xp1 = xp+i;
      for (j = 0; j < sa; j++) {
         long t2;
         t2 = MulMod2(rep(ap[j]), t1, p, bpinv);
         xp1[j].LoopHole() = AddMod(t2, rep(xp1[j]), p);
      }
   }
}


/* File-scope double scratch buffers for the *_FP kernels; callers must
 * size them first (PlainMul(zz_pX&, ...) does).  This is shared mutable
 * state, so concurrent calls into the *_FP kernels would race. */
static vec_double a_buf, b_buf;


/* r = x mod p for an integer-valued double x >= 0, using an approximate
 * quotient floor(x*pinv) followed by one correction step in each
 * direction. */
inline void reduce(zz_p& r, double x, long p, double pinv)
{
   long rr = long(x - double(p)*double(long(x*pinv)));
   if (rr < 0) rr += p;
   if (rr >= p) rr -= p;

   r.LoopHole() = rr;
}


/* Same product as PlainMul(zz_p*, ...), but each output coefficient is
 * accumulated exactly in a double and reduced once.  Each output sums at
 * most min(sa, sb) products, each < p^2; the caller's use_FP test keeps
 * that below NTL_FDOUBLE_PRECISION.  Requires a_buf.length() >= sa and
 * b_buf.length() >= sb. */
void PlainMul_FP(zz_p *xp, const zz_p *aap, long sa, const zz_p *bbp, long sb)
{
   if (sa == 0 || sb == 0) return;

   double *ap = a_buf.elts();
   double *bp = b_buf.elts();

   long d = sa+sb-2;

   long i, j, jmin, jmax;

   for (i = 0; i < sa; i++) ap[i] = double(rep(aap[i]));
   for (i = 0; i < sb; i++) bp[i] = double(rep(bbp[i]));

   double accum;

   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();

   for (i = 0; i <= d; i++) {
      /* Valid j satisfy 0 <= j < sa and 0 <= i-j < sb. */
      jmin = max(0, i-(sb-1));
      jmax = min((sa-1), i);
      accum = 0;
      for (j = jmin; j <= jmax; j++) {
         accum += ap[j]*bp[i-j];
      }
      reduce(xp[i], accum, p, pinv);
   }
}


/* Operand length below which Karatsuba multiplication falls back to the
 * plain kernels. */
#define KARX (16)

/* T[0..hsa-1] = b_lo + b_hi, where b_lo = b[0..hsa-1] and
 * b_hi = b[hsa..sb-1] (requires sb <= 2*hsa). */
void KarFold(zz_p *T, const zz_p *b, long sb, long hsa)
{
   long m = sb - hsa;
   long i;

   for (i = 0; i < m; i++)
      add(T[i], b[i], b[hsa+i]);

   for (i = m; i < hsa; i++)
      T[i] = b[i];
}


/* T[0..sb-1] -= b[0..sb-1]. */
void KarSub(zz_p *T, const zz_p *b, long sb)
{
   long i;

   for (i = 0; i < sb; i++)
      sub(T[i], T[i], b[i]);
}


/* T[0..sb-1] += b[0..sb-1]. */
void KarAdd(zz_p *T, const zz_p *b, long sb)
{
   long i;

   for (i = 0; i < sb; i++)
      add(T[i], T[i], b[i]);
}


/* c[0..hsa-1] = b[0..hsa-1];  c[hsa..sb-1] += b[hsa..sb-1]. */
void KarFix(zz_p *c, const zz_p *b, long sb, long hsa)
{
   long i;

   for (i = 0; i < hsa; i++)
      c[i] = b[i];

   for (i = hsa; i < sb; i++)
      add(c[i], c[i], b[i]);
}


/* Karatsuba product c[0..sa+sb-2] = a * b.
 * stk is scratch space; PlainMul(zz_pX&, ...) sizes it by summing
 * 4*hn - 1 over the halving levels. */
void KarMul(zz_p *c, const zz_p *a, long sa, const zz_p *b, long sb, zz_p *stk)
{
   if (sa < sb) {
      { long t = sa; sa = sb; sb = t; }
      { const zz_p *t = a; a = b; b = t; }
   }

   if (sb < KARX) {
      PlainMul(c, a, sa, b, sb);
      return;
   }

   long hsa = (sa + 1) >> 1;

   if (hsa < sb) {
      /* normal case */

      long hsa2 = hsa << 1;

      zz_p *T1, *T2, *T3;

      T1 = stk; stk += hsa;
      T2 = stk; stk += hsa;
      T3 = stk; stk += hsa2 - 1;

      /* compute T1 = a_lo + a_hi */

      KarFold(T1, a, sa, hsa);

      /* compute T2 = b_lo + b_hi */

      KarFold(T2, b, sb, hsa);

      /* recursively compute T3 = T1 * T2 */

      KarMul(T3, T1, hsa, T2, hsa, stk);

      /* recursively compute a_hi * b_hi into high part of c */
      /* and subtract from T3 */

      KarMul(c + hsa2, a+hsa, sa-hsa, b+hsa, sb-hsa, stk);
      KarSub(T3, c + hsa2, sa + sb - hsa2 - 1);

      /* recursively compute a_lo*b_lo into low part of c */
      /* and subtract from T3 */

      KarMul(c, a, hsa, b, hsa, stk);
      KarSub(T3, c, hsa2 - 1);

      /* a_lo*b_lo occupies only c[0..hsa2-2]; zero the gap before the
       * high product. */
      clear(c[hsa2 - 1]);

      /* finally, add T3 * X^{hsa} to c */

      KarAdd(c+hsa, T3, hsa2-1);
   }
   else {
      /* degenerate case (b no longer than half of a) */

      zz_p *T;

      T = stk; stk += hsa + sb - 1;

      /* recursively compute b*a_hi into high part of c */

      KarMul(c + hsa, a + hsa, sa - hsa, b, sb, stk);

      /* recursively compute b*a_lo into T */

      KarMul(T, a, hsa, b, sb, stk);

      KarFix(c, T, hsa + sb - 1, hsa);
   }
}


/* Same as KarMul, but with PlainMul_FP as the base case.  a_buf and b_buf
 * must already be sized for the largest operand. */
void KarMul_FP(zz_p *c, const zz_p *a, long sa, const zz_p *b, long sb, zz_p *stk)
{
   if (sa < sb) {
      { long t = sa; sa = sb; sb = t; }
      { const zz_p *t = a; a = b; b = t; }
   }

   if (sb < KARX) {
      PlainMul_FP(c, a, sa, b, sb);
      return;
   }

   long hsa = (sa + 1) >> 1;

   if (hsa < sb) {
      /* normal case */

      long hsa2 = hsa << 1;

      zz_p *T1, *T2, *T3;

      T1 = stk; stk += hsa;
      T2 = stk; stk += hsa;
      T3 = stk; stk += hsa2 - 1;

      /* compute T1 = a_lo + a_hi */

      KarFold(T1, a, sa, hsa);

      /* compute T2 = b_lo + b_hi */

      KarFold(T2, b, sb, hsa);

      /* recursively compute T3 = T1 * T2 */

      KarMul_FP(T3, T1, hsa, T2, hsa, stk);

      /* recursively compute a_hi * b_hi into high part of c */
      /* and subtract from T3 */

      KarMul_FP(c + hsa2, a+hsa, sa-hsa, b+hsa, sb-hsa, stk);
      KarSub(T3, c + hsa2, sa + sb - hsa2 - 1);

      /* recursively compute a_lo*b_lo into low part of c */
      /* and subtract from T3 */

      KarMul_FP(c, a, hsa, b, hsa, stk);
      KarSub(T3, c, hsa2 - 1);

      clear(c[hsa2 - 1]);

      /* finally, add T3 * X^{hsa} to c */

      KarAdd(c+hsa, T3, hsa2-1);
   }
   else {
      /* degenerate case */

      zz_p *T;

      T = stk; stk += hsa + sb - 1;

      /* recursively compute b*a_hi into high part of c */

      KarMul_FP(c + hsa, a + hsa, sa - hsa, b, sb, stk);

      /* recursively compute b*a_lo into T */

      KarMul_FP(T, a, hsa, b, sb, stk);

      KarFix(c, T, hsa + sb - 1, hsa);
   }
}


/* c = a * b using schoolbook or Karatsuba multiplication (no FFT).
 * c may alias a or b; a == b is routed to PlainSqr. */
void PlainMul(zz_pX& c, const zz_pX& a, const zz_pX& b)
{
   long sa = a.rep.length();
   long sb = b.rep.length();

   if (sa == 0 || sb == 0) {
      clear(c);
      return;
   }

   if (sa == 1) {
      mul(c, b, a.rep[0]);
      return;
   }

   if (sb == 1) {
      mul(c, a, b.rep[0]);
      return;
   }

   if (&a == &b) {
      PlainSqr(c, a);
      return;
   }

   /* If c aliases an input, copy that input first: c.rep is resized and
    * overwritten below.  a == b was handled above, so at most one copy
    * is needed. */
   vec_zz_p mem;

   const zz_p *ap, *bp;
   zz_p *cp;

   if (&a == &c) {
      mem = a.rep;
      ap = mem.elts();
   }
   else
      ap = a.rep.elts();

   if (&b == &c) {
      mem = b.rep;
      bp = mem.elts();
   }
   else
      bp = b.rep.elts();

   c.rep.SetLength(sa+sb-1);
   cp = c.rep.elts();

   /* The double-accumulating kernels are used only when KARX products of
    * size < p^2 fit below NTL_FDOUBLE_PRECISION, which presumably bounds
    * exactly-representable double integers (TODO confirm).  Their base
    * cases sum fewer than KARX products per coefficient. */
   long p = zz_p::modulus();
   long use_FP = ((p < NTL_SP_BOUND/KARX) &&
                  (double(p)*double(p) < NTL_FDOUBLE_PRECISION/KARX));

   if (sa < KARX || sb < KARX) {
      if (use_FP) {
         a_buf.SetLength(max(sa, sb));
         b_buf.SetLength(max(sa, sb));
         PlainMul_FP(cp, ap, sa, bp, sb);
      }
      else
         PlainMul(cp, ap, sa, bp, sb);
   }
   else {
      /* karatsuba */

      /* Scratch size: each halving level of KarMul consumes at most
       * 4*hn - 1 entries; recursion stops once the length drops below
       * KARX. */
      long n, hn, sp;

      n = max(sa, sb);
      sp = 0;
      do {
         hn = (n+1) >> 1;
         sp += (hn << 2) - 1;
         n = hn;
      } while (n >= KARX);

      vec_zz_p stk;
      stk.SetLength(sp);

      if (use_FP) {
         a_buf.SetLength(max(sa, sb));
         b_buf.SetLength(max(sa, sb));
         KarMul_FP(cp, ap, sa, bp, sb, stk.elts());
      }
      else
         KarMul(cp, ap, sa, bp, sb, stk.elts());
   }

   c.normalize();
}


/* xp[0..2*sa-2] = (aap)^2, accumulating each output coefficient in a
 * double.  The symmetric products a_j*a_{i-j} are summed once and
 * doubled, and the middle square is added for odd-length ranges.
 * Requires a_buf.length() >= sa. */
void PlainSqr_FP(zz_p *xp, const zz_p *aap, long sa)
{
   if (sa == 0) return;

   long da = sa-1;
   long d = 2*da;

   long i, j, jmin, jmax, m, m2;

   double *ap = a_buf.elts();

   for (i = 0; i < sa; i++) ap[i] = double(rep(aap[i]));

   double accum;
   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();

   for (i = 0; i <= d; i++) {
      jmin = max(0, i-da);
      jmax = min(da, i);
      m = jmax - jmin + 1;
      m2 = m >> 1;
      /* Only the lower half of the index range is visited; the upper
       * half mirrors it. */
      jmax = jmin + m2 - 1;
      accum = 0;
      for (j = jmin; j <= jmax; j++) {
         accum += ap[j]*ap[i-j];
      }
      accum += accum;
      if (m & 1) {
         accum += ap[jmax + 1]*ap[jmax + 1];
      }

      reduce(xp[i], accum, p, pinv);
   }
}


/* Schoolbook square xp[0..2*sa-2] = ap^2.
 * At step j, first x[2j] = 2*(accumulated cross terms) + a_j^2.  Then
 * a_j*a_k (k > j) is accumulated into x[j+k], and x[2j+1] is doubled
 * once its cross terms are complete. */
void PlainSqr(zz_p *xp, const zz_p *ap, long sa)
{
   if (sa == 0) return;

   long i, j, k, cnt;

   cnt = 2*sa-1;
   for (i = 0; i < cnt; i++)
      clear(xp[i]);

   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();
   long t1, t2;

   i = -1;
   for (j = 0; j <= sa-2; j++) {
      i += 2;

      t1 = MulMod(rep(ap[j]), rep(ap[j]), p, pinv);
      t2 = rep(xp[i-1]);
      t2 = AddMod(t2, t2, p);
      t2 = AddMod(t2, t1, p);
      xp[i-1].LoopHole() = t2;

      cnt = sa - 1 - j;
      const zz_p *ap1 = ap+(j+1);
      zz_p *xp1 = xp+i;
      t1 = rep(ap[j]);
      double tpinv = ((double) t1)*pinv;

      for (k = 0; k < cnt; k++) {
         t2 = MulMod2(rep(ap1[k]), t1, p, tpinv);
         t2 = AddMod(t2, rep(xp1[k]), p);
         xp1[k].LoopHole() = t2;
      }
      t2 = rep(*xp1);
      t2 = AddMod(t2, t2, p);
      (*xp1).LoopHole() = t2;
   }

   t1 = rep(ap[sa-1]);
   t1 = MulMod(t1, t1, p, pinv);
   xp[2*sa-2].LoopHole() = t1;
}


/* Length below which Karatsuba squaring falls back to the plain kernels. */
#define KARSX (30)

/* Karatsuba square c[0..2*sa-2] = a^2, using stk as scratch. */
void KarSqr(zz_p *c, const zz_p *a, long sa, zz_p *stk)
{
   if (sa < KARSX) {
      PlainSqr(c, a, sa);
      return;
   }

   long hsa = (sa + 1) >> 1;
   long hsa2 = hsa << 1;

   zz_p *T1, *T2;

   T1 = stk; stk += hsa;
   T2 = stk; stk += hsa2-1;

   /* T2 = (a_lo + a_hi)^2 */
   KarFold(T1, a, sa, hsa);
   KarSqr(T2, T1, hsa, stk);

   /* high part of c = a_hi^2; subtract it from T2 */
   KarSqr(c + hsa2, a+hsa, sa-hsa, stk);
   KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);

   /* low part of c = a_lo^2; subtract it from T2 */
   KarSqr(c, a, hsa, stk);
   KarSub(T2, c, hsa2 - 1);

   clear(c[hsa2 - 1]);

   /* add the middle term T2 * X^{hsa} */
   KarAdd(c+hsa, T2, hsa2-1);
}


/* Same as KarSqr, with PlainSqr_FP as the base case. */
void KarSqr_FP(zz_p *c, const zz_p *a, long sa, zz_p *stk)
{
   if (sa < KARSX) {
      PlainSqr_FP(c, a, sa);
      return;
   }

   long hsa = (sa + 1) >> 1;
   long hsa2 = hsa << 1;

   zz_p *T1, *T2;

   T1 = stk; stk += hsa;
   T2 = stk; stk += hsa2-1;

   KarFold(T1, a, sa, hsa);
   KarSqr_FP(T2, T1, hsa, stk);

   KarSqr_FP(c + hsa2, a+hsa, sa-hsa, stk);
   KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);
   /* NOTE(review): this copy of the source is truncated here; the rest of
    * KarSqr_FP and the remainder of the file are not visible in this
    * excerpt and are intentionally left as-is. */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -