📄 zz_px.c
字号:
#include <NTL/ZZ_pX.h>

// The mul & sqr routines use routines from ZZX,
// which is faster for small degree polynomials.
// Define this macro to revert to old strategy.

#ifndef NTL_OLD_ZZ_pX_MUL

#include <NTL/ZZX.h>

#endif

#include <NTL/new.h>

NTL_START_IMPL


// Returns a reference to a shared, function-local static polynomial that is
// default-constructed and never modified here.
const ZZ_pX& ZZ_pX::zero()
{
   static ZZ_pX z;
   return z;
}


// Assigns the constant polynomial a (delegates to conv).
ZZ_pX& ZZ_pX::operator=(long a)
{
   conv(*this, a);
   return *this;
}


// Assigns the constant polynomial a (delegates to conv).
ZZ_pX& ZZ_pX::operator=(const ZZ_p& a)
{
   conv(*this, a);
   return *this;
}


// Reads the coefficient vector from s, then strips trailing zero
// coefficients so that x is in normalized form.
istream& operator>>(istream& s, ZZ_pX& x)
{
   s >> x.rep;
   x.normalize();
   return s;
}


// Writes the coefficient vector of a to s.
ostream& operator<<(ostream& s, const ZZ_pX& a)
{
   return s << a.rep;
}


// Strips trailing zero coefficients by shrinking rep to the index of the
// last nonzero coefficient plus one (length 0 if all are zero).
void ZZ_pX::normalize()
{
   long n;
   const ZZ_p* p;

   n = rep.length();
   if (n == 0) return;

   // Start one past the end and pre-decrement inside the test: the
   // short-circuit on n > 0 guarantees p is never moved before the first
   // element, which would be undefined pointer arithmetic.
   p = rep.elts() + n;
   while (n > 0 && IsZero(*--p)) {
      n--;
   }
   rep.SetLength(n);
}


// Returns nonzero iff a has no coefficients (the zero polynomial).
long IsZero(const ZZ_pX& a)
{
   return a.rep.length() == 0;
}


// Returns nonzero iff a has exactly one coefficient and it is one.
long IsOne(const ZZ_pX& a)
{
   return a.rep.length() == 1 && IsOne(a.rep[0]);
}


// x = coefficient of X^i in a; x is cleared when i is outside [0, deg(a)].
void GetCoeff(ZZ_p& x, const ZZ_pX& a, long i)
{
   if (i < 0 || i > deg(a))
      clear(x);
   else
      x = a.rep[i];
}


// Sets the coefficient of X^i in x to a, growing x (and clearing the newly
// exposed coefficients between the old degree and i) when i > deg(x).
// Calls Error for a negative index or for i >= 2^(NTL_BITS_PER_LONG-4).
void SetCoeff(ZZ_pX& x, long i, const ZZ_p& a)
{
   long j, m;

   if (i < 0)
      Error("SetCoeff: negative index");

   if (i >= (1L << (NTL_BITS_PER_LONG-4)))
      Error("overflow in SetCoeff");

   m = deg(x);

   if (i > m) {
      // a may alias a coefficient of x: record its position before the
      // resize and re-fetch it afterwards.
      // NOTE(review): presumably position() yields -1 when a is not stored
      // inside x.rep -- confirm against the vector documentation.
      long pos = x.rep.position(a);
      x.rep.SetLength(i+1);
      if (pos != -1)
         x.rep[i] = x.rep.RawGet(pos);
      else
         x.rep[i] = a;
      for (j = m+1; j < i; j++)
         clear(x.rep[j]);
   }
   else
      x.rep[i] = a;

   x.normalize();
}


// Sets the coefficient of X^i in x to the integer a.  a == 1 is routed to
// the dedicated "set to one" overload; otherwise a is converted to ZZ_p via
// a temporary first.
void SetCoeff(ZZ_pX& x, long i, long a)
{
   if (a == 1)
      SetCoeff(x, i);
   else {
      ZZ_pTemp TT;
      ZZ_p& T = TT.val();
      conv(T, a);
      SetCoeff(x, i, T);
   }
}


// Sets the coefficient of X^i in x to one, growing x and clearing the
// intermediate coefficients when i > deg(x).
// Calls Error for a negative index or for i >= 2^(NTL_BITS_PER_LONG-4).
void SetCoeff(ZZ_pX& x, long i)
{
   long j, m;

   if (i < 0)
      Error("coefficient index out of range");

   if (i >= (1L << (NTL_BITS_PER_LONG-4)))
      Error("overflow in SetCoeff");

   m = deg(x);

   if (i > m) {
      x.rep.SetLength(i+1);
      for (j = m+1; j < i; j++)
         clear(x.rep[j]);
   }
   set(x.rep[i]);
   x.normalize();
}


// x = X (the monomial of degree one).
void SetX(ZZ_pX& x)
{
   clear(x);
   SetCoeff(x, 1);
}


// Returns nonzero iff a is exactly the monomial X.
long IsX(const ZZ_pX& a)
{
   return deg(a) == 1 && IsOne(LeadCoeff(a)) && IsZero(ConstTerm(a));
}


// Returns a reference to the coefficient of X^i in a, or to ZZ_p::zero()
// when i is outside [0, deg(a)].
const ZZ_p& coeff(const ZZ_pX& a, long i)
{
   if (i < 0 || i > deg(a))
      return ZZ_p::zero();
   else
      return a.rep[i];
}


// Returns a reference to the leading coefficient of a, or to ZZ_p::zero()
// when a is the zero polynomial.
const ZZ_p& LeadCoeff(const ZZ_pX& a)
{
   if (IsZero(a))
      return ZZ_p::zero();
   else
      return a.rep[deg(a)];
}


// Returns a reference to the constant term of a, or to ZZ_p::zero() when a
// is the zero polynomial.
const ZZ_p& ConstTerm(const ZZ_pX& a)
{
   if (IsZero(a))
      return ZZ_p::zero();
   else
      return a.rep[0];
}


// x = the constant polynomial a (the zero polynomial when a is zero).
void conv(ZZ_pX& x, const ZZ_p& a)
{
   if (IsZero(a))
      x.rep.SetLength(0);
   else {
      x.rep.SetLength(1);
      x.rep[0] = a;

      // note: if a aliases x.rep[i], i > 0, this code
      //       will still work, since it is assumed that
      //       SetLength(1) will not relocate or destroy x.rep[i]
   }
}


// x = the constant polynomial a; 0 and 1 take the clear/set shortcuts,
// anything else is converted to ZZ_p through a temporary.
void conv(ZZ_pX& x, long a)
{
   if (a == 0)
      clear(x);
   else if (a == 1)
      set(x);
   else {
      ZZ_pTemp TT;
      ZZ_p& T = TT.val();
      conv(T, a);
      conv(x, T);
   }
}


// x = the constant polynomial a, converted to ZZ_p through a temporary.
void conv(ZZ_pX& x, const ZZ& a)
{
   if (IsZero(a))
      clear(x);
   else {
      ZZ_pTemp TT;
      ZZ_p& T = TT.val();
      conv(T, a);
      conv(x, T);
   }
}


// x = the polynomial whose coefficient vector is a (then normalized).
void conv(ZZ_pX& x, const vec_ZZ_p& a)
{
   x.rep = a;
   x.normalize();
}


// x = a + b.  The first min(deg)+1 coefficients are added pairwise; the
// remaining coefficients of the longer operand are copied unless x already
// is that operand.
void add(ZZ_pX& x, const ZZ_pX& a, const ZZ_pX& b)
{
   long da = deg(a);
   long db = deg(b);
   long minab = min(da, db);
   long maxab = max(da, db);
   x.rep.SetLength(maxab+1);

   long i;
   const ZZ_p *ap, *bp;
   ZZ_p* xp;

   for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
        i; i--, ap++, bp++, xp++)
      add(*xp, (*ap), (*bp));

   // normalize() is only reached when no tail was copied in this call.
   if (da > minab && &x != &a)
      for (i = da-minab; i; i--, xp++, ap++)
         *xp = *ap;
   else if (db > minab && &x != &b)
      for (i = db-minab; i; i--, xp++, bp++)
         *xp = *bp;
   else
      x.normalize();
}


// x = a + b, where b is added to the constant term of a.
void add(ZZ_pX& x, const ZZ_pX& a, const ZZ_p& b)
{
   long n = a.rep.length();
   if (n == 0) {
      conv(x, b);
   }
   else if (&x == &a) {
      add(x.rep[0], a.rep[0], b);
      x.normalize();
   }
   else if (x.rep.MaxLength() == 0) {
      x = a;
      add(x.rep[0], a.rep[0], b);
      x.normalize();
   }
   else {
      // ugly...b could alias a coeff of x
      // The constant term is therefore computed before x is resized and
      // its remaining coefficients are overwritten.

      ZZ_p *xp = x.rep.elts();
      add(xp[0], a.rep[0], b);
      x.rep.SetLength(n);
      xp = x.rep.elts();
      const ZZ_p *ap = a.rep.elts();
      long i;
      for (i = 1; i < n; i++)
         xp[i] = ap[i];
      x.normalize();
   }
}


// x = a + b, where the integer b is added to the constant term of a.
void add(ZZ_pX& x, const ZZ_pX& a, long b)
{
   if (a.rep.length() == 0) {
      conv(x, b);
   }
   else {
      if (&x != &a) x = a;
      add(x.rep[0], x.rep[0], b);
      x.normalize();
   }
}


// x = a - b.  The first min(deg)+1 coefficients are subtracted pairwise;
// a's remaining coefficients are copied (unless x is a), or b's remaining
// coefficients are negated into x.
void sub(ZZ_pX& x, const ZZ_pX& a, const ZZ_pX& b)
{
   long da = deg(a);
   long db = deg(b);
   long minab = min(da, db);
   long maxab = max(da, db);
   x.rep.SetLength(maxab+1);

   long i;
   const ZZ_p *ap, *bp;
   ZZ_p* xp;

   for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
        i; i--, ap++, bp++, xp++)
      sub(*xp, (*ap), (*bp));

   // Unlike add(), b's tail must be processed even when x aliases b,
   // because it has to be negated in place.
   if (da > minab && &x != &a)
      for (i = da-minab; i; i--, xp++, ap++)
         *xp = *ap;
   else if (db > minab)
      for (i = db-minab; i; i--, xp++, bp++)
         negate(*xp, *bp);
   else
      x.normalize();
}


// x = a - b, where b is subtracted from the constant term of a.
void sub(ZZ_pX& x, const ZZ_pX& a, const ZZ_p& b)
{
   long n = a.rep.length();
   if (n == 0) {
      conv(x, b);
      negate(x, x);
   }
   else if (&x == &a) {
      sub(x.rep[0], a.rep[0], b);
      x.normalize();
   }
   else if (x.rep.MaxLength() == 0) {
      x = a;
      sub(x.rep[0], a.rep[0], b);
      x.normalize();
   }
   else {
      // ugly...b could alias a coeff of x
      // The constant term is therefore computed before x is resized and
      // its remaining coefficients are overwritten.

      ZZ_p *xp = x.rep.elts();
      sub(xp[0], a.rep[0], b);
      x.rep.SetLength(n);
      xp = x.rep.elts();
      const ZZ_p *ap = a.rep.elts();
      long i;
      for (i = 1; i < n; i++)
         xp[i] = ap[i];
      x.normalize();
   }
}


// x = a - b, where the integer b is subtracted from the constant term of a.
void sub(ZZ_pX& x, const ZZ_pX& a, long b)
{
   if (b == 0) {
      x = a;
      return;
   }

   if (a.rep.length() == 0) {
      x.rep.SetLength(1);
      x.rep[0] = b;
      negate(x.rep[0], x.rep[0]);
   }
   else {
      if (&x != &a) x = a;
      sub(x.rep[0], x.rep[0], b);
   }
   x.normalize();
}


// x = a - b, computed as (-b) + a.  a is copied into a temporary first, so
// the result is still correct if a refers to a coefficient of x.
void sub(ZZ_pX& x, const ZZ_p& a, const ZZ_pX& b)
{
   ZZ_pTemp TT;
   ZZ_p& T = TT.val();
   T = a;

   negate(x, b);
   add(x, x, T);
}


// x = a - b for an integer a, computed as (-b) + a.
void sub(ZZ_pX& x, long a, const ZZ_pX& b)
{
   ZZ_pTemp TT;
   ZZ_p& T = TT.val();
   T = a;

   negate(x, b);
   add(x, x, T);
}


// x = -a, coefficient by coefficient.
void negate(ZZ_pX& x, const ZZ_pX& a)
{
   long n = a.rep.length();
   x.rep.SetLength(n);

   const ZZ_p* ap = a.rep.elts();
   ZZ_p* xp = x.rep.elts();
   long i;

   for (i = n; i; i--, ap++, xp++)
      negate((*xp), (*ap));
}


#ifndef NTL_OLD_ZZ_pX_MUL

// These crossovers are tuned for a Pentium, but hopefully
// they should be OK on other machines as well.

const long SS_kbound = 40;
const double SS_rbound = 1.25;


// c = a * b.  Dispatches on k = ZZ_p::ModulusSize() and
// s = min(deg(a), deg(b)) + 1:
//   - small s (threshold depends on k): PlainMul
//   - s < 80: KarMul over ZZX, converting operands and result
//   - otherwise: SSMul over ZZX when k >= SS_kbound and SSRatio is below
//     SS_rbound, else FFTMul
// A product of an operand with itself is forwarded to sqr.
void mul(ZZ_pX& c, const ZZ_pX& a, const ZZ_pX& b)
{
   if (IsZero(a) || IsZero(b)) {
      clear(c);
      return;
   }

   if (&a == &b) {
      sqr(c, a);
      return;
   }

   long k = ZZ_p::ModulusSize();
   long s = min(deg(a), deg(b)) + 1;

   if (s == 1 || (k == 1 && s < 40) || (k == 2 && s < 20) ||
                 (k == 3 && s < 12) || (k <= 5 && s < 8) ||
                 (k <= 12 && s < 4) ) {
      PlainMul(c, a, b);
   }
   else if (s < 80) {
      ZZX A, B, C;
      conv(A, a);
      conv(B, b);
      KarMul(C, A, B);
      conv(c, C);
   }
   else {
      long mbits;
      mbits = NumBits(ZZ_p::modulus());
      if (k >= SS_kbound &&
          SSRatio(deg(a), mbits, deg(b), mbits) < SS_rbound) {
         ZZX A, B, C;
         conv(A, a);
         conv(B, b);
         SSMul(C, A, B);
         conv(c, C);
      }
      else {
         FFTMul(c, a, b);
      }
   }
}


// c = a * a.  Same dispatch scheme as mul, with its own crossover table:
// PlainSqr for small s = deg(a) + 1, KarSqr over ZZX for s < 80, then
// SSSqr over ZZX or FFTSqr.
void sqr(ZZ_pX& c, const ZZ_pX& a)
{
   if (IsZero(a)) {
      clear(c);
      return;
   }

   long k = ZZ_p::ModulusSize();
   long s = deg(a) + 1;

   if (s == 1 || (k == 1 && s < 50) || (k == 2 && s < 25) ||
                 (k == 3 && s < 25) || (k <= 6 && s < 12) ||
                 (k <= 8 && s < 8)  || (k == 9 && s < 6)  ||
                 (k <= 30 && s < 4) ) {
      PlainSqr(c, a);
   }
   else if (s < 80) {
      ZZX C, A;
      conv(A, a);
      KarSqr(C, A);
      conv(c, C);
   }
   else {
      long mbits;
      mbits = NumBits(ZZ_p::modulus());
      if (k >= SS_kbound &&
          SSRatio(deg(a), mbits, deg(a), mbits) < SS_rbound) {
         ZZX A, C;
         conv(A, a);
         SSSqr(C, A);
         conv(c, C);
      }
      else {
         FFTSqr(c, a);
      }
   }
}

#else

// Old strategy: x = a * b using FFTMul when both degrees exceed
// NTL_ZZ_pX_FFT_CROSSOVER, PlainMul otherwise.
void mul(ZZ_pX& x, const ZZ_pX& a, const ZZ_pX& b)
{
   if (&a == &b) {
      sqr(x, a);
      return;
   }

   if (deg(a) > NTL_ZZ_pX_FFT_CROSSOVER && deg(b) > NTL_ZZ_pX_FFT_CROSSOVER)
      FFTMul(x, a, b);
   else
      PlainMul(x, a, b);
}


// Old strategy: x = a * a using FFTSqr when deg(a) exceeds
// NTL_ZZ_pX_FFT_CROSSOVER, PlainSqr otherwise.
void sqr(ZZ_pX& x, const ZZ_pX& a)
{
   if (deg(a) > NTL_ZZ_pX_FFT_CROSSOVER)
      FFTSqr(x, a);
   else
      PlainSqr(x, a);
}

#endif


// x = a * b by the schoolbook method.  For each output coefficient the
// products of the operands' integer representatives are summed in a ZZ
// accumulator and converted back to ZZ_p once at the end.
// An input that aliases x is copied first so it can be read while x is
// being overwritten.
void PlainMul(ZZ_pX& x, const ZZ_pX& a, const ZZ_pX& b)
{
   long da = deg(a);
   long db = deg(b);

   if (da < 0 || db < 0) {
      clear(x);
      return;
   }

   if (da == 0) {
      mul(x, b, a.rep[0]);
      return;
   }

   if (db == 0) {
      mul(x, a, b.rep[0]);
      return;
   }

   long d = da+db;

   const ZZ_p *ap, *bp;
   ZZ_p *xp;

   ZZ_pX la, lb;

   if (&x == &a) {
      la = a;
      ap = la.rep.elts();
   }
   else
      ap = a.rep.elts();

   if (&x == &b) {
      lb = b;
      bp = lb.rep.elts();
   }
   else
      bp = b.rep.elts();

   x.rep.SetLength(d+1);

   xp = x.rep.elts();

   long i, j, jmin, jmax;
   // Function-local statics: the scratch integers are shared between calls.
   static ZZ t, accum;

   for (i = 0; i <= d; i++) {
      // j ranges over the indices with 0 <= j <= da and 0 <= i-j <= db.
      jmin = max(0, i-db);
      jmax = min(da, i);
      clear(accum);
      for (j = jmin; j <= jmax; j++) {
         mul(t, rep(ap[j]), rep(bp[i-j]));
         add(accum, accum, t);
      }
      conv(xp[i], accum);
   }
   x.normalize();
}


// x = a * a by the schoolbook method, exploiting symmetry: for each output
// coefficient only half of the cross products are computed and doubled,
// and the middle square is added when the number of terms is odd.
void PlainSqr(ZZ_pX& x, const ZZ_pX& a)
{
   long da = deg(a);

   if (da < 0) {
      clear(x);
      return;
   }

   long d = 2*da;

   const ZZ_p *ap;
   ZZ_p *xp;

   ZZ_pX la;

   // Copy a when it aliases the output so it can still be read.
   if (&x == &a) {
      la = a;
      ap = la.rep.elts();
   }
   else
      ap = a.rep.elts();

   x.rep.SetLength(d+1);

   xp = x.rep.elts();

   long i, j, jmin, jmax;
   long m, m2;
   // Function-local statics: the scratch integers are shared between calls.
   static ZZ t, accum;

   for (i = 0; i <= d; i++) {
      jmin = max(0, i-da);
      jmax = min(da, i);
      m = jmax - jmin + 1;     // number of terms contributing to x[i]
      m2 = m >> 1;             // number of distinct cross-product pairs
      jmax = jmin + m2 - 1;
      clear(accum);
      for (j = jmin; j <= jmax; j++) {
         mul(t, rep(ap[j]), rep(ap[i-j]));
         add(accum, accum, t);
      }
      add(accum, accum, accum);   // each cross product occurs twice
      if (m & 1) {
         sqr(t, rep(ap[jmax + 1]));
         add(accum, accum, t);
      }

      conv(xp[i], accum);
   }

   x.normalize();
}


// Classical long division: computes q and r with a = q*b + r, deg(r) < deg(b).
// Calls Error when b is zero.  The running remainder is kept as a vector of
// ZZ values that are only converted back to ZZ_p when a quotient
// coefficient is extracted and when r is written out.
void PlainDivRem(ZZ_pX& q, ZZ_pX& r, const ZZ_pX& a, const ZZ_pX& b)
{
   long da, db, dq, i, j, LCIsOne;
   const ZZ_p *bp;
   ZZ_p *qp;
   ZZ *xp;


   ZZ_p LCInv, t;
   static ZZ s;

   da = deg(a);
   db = deg(b);

   if (db < 0) Error("ZZ_pX: division by zero");

   if (da < db) {
      r = a;
      clear(q);
      return;
   }

   ZZ_pX lb;

   // q is resized below, so keep a private copy of b when q aliases it.
   if (&q == &b) {
      lb = b;
      bp = lb.rep.elts();
   }
   else
      bp = b.rep.elts();

   // Skip the multiplication by the inverse when b is monic.
   if (IsOne(bp[db]))
      LCIsOne = 1;
   else {
      LCIsOne = 0;
      inv(LCInv, bp[db]);
   }

   // NOTE(review): the second argument presumably pre-sizes each entry so
   // the accumulated products fit without reallocation -- confirm.
   ZZVec x(da + 1, ZZ_pInfo->ExtendedModulusSize);

   for (i = 0; i <= da; i++)
      x[i] = rep(a.rep[i]);

   xp = x.elts();

   dq = da - db;
   q.rep.SetLength(dq+1);
   qp = q.rep.elts();

   for (i = dq; i >= 0; i--) {
      conv(t, xp[i+db]);
      if (!LCIsOne)
         mul(t, t, LCInv);
      qp[i] = t;
      negate(t, t);

      // Subtract q[i] * b from the running remainder (adds -q[i] * b).
      for (j = db-1; j >= 0; j--) {
         mul(s, rep(t), rep(bp[j]));
         add(xp[i+j], xp[i+j], s);
      }
   }

   r.rep.SetLength(db);
   for (i = 0; i < db; i++)
      conv(r.rep[i], xp[i]);
   r.normalize();
}


// r = a mod b by classical long division, using the caller-supplied ZZ
// vector x as scratch space; the code indexes x[0..deg(a)].
// Calls Error when b is zero.
void PlainRem(ZZ_pX& r, const ZZ_pX& a, const ZZ_pX& b, ZZVec& x)
{
   long da, db, dq, i, j, LCIsOne;
   const ZZ_p *bp;
   ZZ *xp;


   ZZ_p LCInv, t;
   static ZZ s;

   da = deg(a);
   db = deg(b);

   if (db < 0) Error("ZZ_pX: division by zero");

   if (da < db) {
      r = a;
      return;
   }

   bp = b.rep.elts();

   // Skip the multiplication by the inverse when b is monic.
   if (IsOne(bp[db]))
      LCIsOne = 1;
   else {
      LCIsOne = 0;
      inv(LCInv, bp[db]);
   }

   for (i = 0; i <= da; i++)
      x[i] = rep(a.rep[i]);

   xp = x.elts();

   dq = da - db;

   for (i = dq; i >= 0; i--) {
      conv(t, xp[i+db]);
      if (!LCIsOne)
         mul(t, t, LCInv);
      negate(t, t);

      // Subtract (quotient coefficient) * b from the running remainder.
      for (j = db-1; j >= 0; j--) {
         mul(s, rep(t), rep(bp[j]));
         add(xp[i+j], xp[i+j], s);
      }
   }

   r.rep.SetLength(db);
   for (i = 0; i < db; i++)
      conv(r.rep[i], xp[i]);
   r.normalize();
}


// Variant of PlainDivRem taking a caller-supplied ZZVec.
// NOTE(review): the remainder of this definition lies beyond the visible
// excerpt; the fragment below is reproduced unchanged apart from layout.
void PlainDivRem(ZZ_pX& q, ZZ_pX& r, const ZZ_pX& a, const ZZ_pX& b, ZZVec& x)
{
   long da, db, dq, i, j, LCIsOne;
   const ZZ_p *bp;
   ZZ_p *qp;
   ZZ *xp;


   ZZ_p LCInv, t;
   static ZZ s;

   da = deg(a);
   db = deg(b);

   if (db < 0) Error("ZZ_pX: division by zero");

   if (da < db) {
      r = a;
      clear(q);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -