⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 lzz_px.c

📁 数值算法库for Unix
💻 C
📖 第 1 页 / 共 4 页
字号:
#include <NTL/lzz_pX.h>
#include <NTL/vec_double.h>
#include <NTL/new.h>

NTL_START_IMPL

/*
 * Tuning tables.  Judging by their names these are degree crossover
 * points at which the zz_pX routines (mod, mul, Newton iteration, div,
 * half-gcd, gcd, Berlekamp-Massey, trace) switch algorithms.
 * NOTE(review): how one of the 5 entries is selected, and how these
 * arrays relate to macros such as NTL_zz_pX_MUL_CROSSOVER used below,
 * is defined outside this section -- confirm before retuning.
 */
long zz_pX_mod_crossover[5] = {45, 45, 90, 180, 180};
long zz_pX_mul_crossover[5] = {90, 400, 600, 1500, 1500};
long zz_pX_newton_crossover[5] = {150, 150, 300, 700, 700};
long zz_pX_div_crossover[5] = {180, 180, 350, 750, 750};
long zz_pX_halfgcd_crossover[5] = {90, 90, 180, 350, 350};
long zz_pX_gcd_crossover[5] = {400, 400, 800, 1400, 1400};
long zz_pX_bermass_crossover[5] = {400, 480, 900, 1600, 1600};
long zz_pX_trace_crossover[5] = {200, 350, 450, 800, 800};

/* Nonzero when NTL_DOUBLE_PRECISION exceeds NTL_SP_NBITS by more than 10. */
#define QUICK_CRT (NTL_DOUBLE_PRECISION - NTL_SP_NBITS > 10)

/*
 * Returns a reference to a single function-local static zz_pX that is
 * default-constructed (presumably an empty rep, i.e. the zero polynomial).
 */
const zz_pX& zz_pX::zero()
{
   static zz_pX z;
   return z;
}

/* Reads the coefficient vector into x.rep, then drops trailing zeros. */
istream& operator>>(istream& s, zz_pX& x)
{
   s >> x.rep;
   x.normalize();
   return s;
}

/* Writes the coefficient vector a.rep to s. */
ostream& operator<<(ostream& s, const zz_pX& a)
{
   return s << a.rep;
}

/*
 * Strips trailing (high-order) zero coefficients so that rep is either
 * empty or ends in a nonzero leading coefficient.  rep[i] holds the
 * coefficient of X^i.
 */
void zz_pX::normalize()
{
   long n;
   const zz_p* p;
   n = rep.length();
   if (n == 0) return;
   p = rep.elts() + (n-1);
   while (n > 0 && IsZero(*p)) {
      p--; 
      n--;
   }
   rep.SetLength(n);
}

/* True iff a is the zero polynomial (empty coefficient vector). */
long IsZero(const zz_pX& a)
{
   return a.rep.length() == 0;
}

/* True iff a is the constant polynomial 1. */
long IsOne(const zz_pX& a)
{
    return a.rep.length() == 1 && IsOne(a.rep[0]);
}

/* x = coefficient of X^i in a; zero when i is negative or above deg(a). */
void GetCoeff(zz_p& x, const zz_pX& a, long i)
{
   if (i < 0 || i > deg(a))
      clear(x);
   else
      x = a.rep[i];
}

/*
 * Sets the coefficient of X^i in x to a, growing x and zero-filling the
 * new gap if i > deg(x).  Calls Error() for a negative i or for
 * i >= 2^(NTL_BITS_PER_LONG-4).  x is normalized afterwards, so writing
 * zero into the leading position lowers the degree.
 */
void SetCoeff(zz_pX& x, long i, zz_p a)
{
   long j, m;
   if (i < 0) 
      Error("SetCoeff: negative index");
   if (i >= (1L << (NTL_BITS_PER_LONG-4)))
      Error("overflow in SetCoeff");
   m = deg(x);
   if (i > m) {
      x.rep.SetLength(i+1);
      /* positions m+1 .. i-1 are new; zero them (position i is set below) */
      for (j = m+1; j < i; j++)
         clear(x.rep[j]);
   }
   x.rep[i] = a;
   x.normalize();
}

/*
 * Sets the coefficient of X^i in x to the residue of a; the common case
 * a == 1 is routed to the two-argument SetCoeff.
 */
void SetCoeff(zz_pX& x, long i, long a)
{
   if (a == 1)
      SetCoeff(x, i);
   else
      SetCoeff(x, i, to_zz_p(a));
}

/*
 * Sets the coefficient of X^i in x to 1 (via set()).  Same index checks
 * and growth behaviour as SetCoeff(x, i, zz_p).
 */
void SetCoeff(zz_pX& x, long i)
{
   long j, m;
   if (i < 0) 
      Error("coefficient index out of range");
   if (i >= (1L << (NTL_BITS_PER_LONG-4)))
      Error("overflow in SetCoeff");
   m = deg(x);
   if (i > m) {
      x.rep.SetLength(i+1);
      for (j = m+1; j < i; j++)
         clear(x.rep[j]);
   }
   set(x.rep[i]);
   x.normalize();
}

/* x = X. */
void SetX(zz_pX& x)
{
   clear(x);
   SetCoeff(x, 1);
}

/* True iff a == X (degree 1, leading coefficient 1, constant term 0). */
long IsX(const zz_pX& a)
{
   return deg(a) == 1 && IsOne(LeadCoeff(a)) && IsZero(ConstTerm(a));
}

/* Returns the coefficient of X^i in a, or zero when i is out of range. */
zz_p coeff(const zz_pX& a, long i)
{
   if (i < 0 || i > deg(a))
      return zz_p::zero();
   else
      return a.rep[i];
}

/* Returns the leading coefficient of a, or zero for the zero polynomial. */
zz_p LeadCoeff(const zz_pX& a)
{
   if (IsZero(a))
      return zz_p::zero();
   else
      return a.rep[deg(a)];
}

/* Returns the constant term rep[0] of a, or zero for the zero polynomial. */
zz_p ConstTerm(const zz_pX& a)
{
   if (IsZero(a))
      return zz_p::zero();
   else
      return a.rep[0];
}

/* x = constant polynomial a (empty rep when a is zero, keeping x normalized). */
void conv(zz_pX& x, zz_p a)
{
   if (IsZero(a))
      x.rep.SetLength(0);
   else {
      x.rep.SetLength(1);
      x.rep[0] = a;
   }
}

/* x = constant polynomial (a mod p); a == 0 takes a direct shortcut. */
void conv(zz_pX& x, long a)
{
   if (a == 0) {
      x.rep.SetLength(0);
      return;
   }
   zz_p t;
   conv(t, a);
   conv(x, t);
}

/* x = constant polynomial (a mod p); a == 0 takes a direct shortcut. */
void conv(zz_pX& x, const ZZ& a)
{
   if (a == 0) {
      x.rep.SetLength(0);
      return;
   }
   zz_p t;
   conv(t, a);
   conv(x, t);
}

/* x = polynomial with coefficient vector a (a[i] is the X^i coefficient). */
void conv(zz_pX& x, const vec_zz_p& a)
{
   x.rep = a;
   x.normalize();
}

/*
 * x = a + b.  x may alias a and/or b: x is resized before any input
 * pointer is taken, so the pointers stay valid.
 * The high-order tail of the longer operand is copied only when x is not
 * that operand (otherwise it is already in place).  normalize() runs
 * whenever neither copy branch runs -- in particular when
 * deg(a) == deg(b), where the leading coefficients may cancel.
 */
void add(zz_pX& x, const zz_pX& a, const zz_pX& b)
{
   long da = deg(a);
   long db = deg(b);
   long minab = min(da, db);
   long maxab = max(da, db);
   x.rep.SetLength(maxab+1);
   long i;
   const zz_p *ap, *bp; 
   zz_p* xp;
   /* common low-order part: minab+1 coefficient-wise sums */
   for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
        i; i--, ap++, bp++, xp++)
      add(*xp, (*ap), (*bp));
   if (da > minab && &x != &a)
      for (i = da-minab; i; i--, xp++, ap++)
         *xp = *ap;
   else if (db > minab && &x != &b)
      for (i = db-minab; i; i--, xp++, bp++)
         *xp = *bp;
   else
      x.normalize();
}

/* x = a + b for a constant b; x may alias a. */
void add(zz_pX& x, const zz_pX& a, zz_p b)
{
   if (a.rep.length() == 0) {
      conv(x, b);
   }
   else {
      if (&x != &a) x = a;
      add(x.rep[0], x.rep[0], b);
      /* needed when a is constant and a + b == 0 */
      x.normalize();
   }
}

/*
 * x = a - b, with the same aliasing rules as add().  Unlike add(), b's
 * high-order tail is always negated into x (in place when x aliases b).
 */
void sub(zz_pX& x, const zz_pX& a, const zz_pX& b)
{
   long da = deg(a);
   long db = deg(b);
   long minab = min(da, db);
   long maxab = max(da, db);
   x.rep.SetLength(maxab+1);
   long i;
   const zz_p *ap, *bp; 
   zz_p* xp;
   for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
        i; i--, ap++, bp++, xp++)
      sub(*xp, (*ap), (*bp));
   if (da > minab && &x != &a)
      for (i = da-minab; i; i--, xp++, ap++)
         *xp = *ap;
   else if (db > minab)
      for (i = db-minab; i; i--, xp++, bp++)
         negate(*xp, *bp);
   else
      x.normalize();
}

/*
 * x = a - b for a constant b; x may alias a.  When a is zero the result
 * is -b, and the trailing normalize() empties it if b is zero.
 */
void sub(zz_pX& x, const zz_pX& a, zz_p b)
{
   if (a.rep.length() == 0) {
      x.rep.SetLength(1);
      negate(x.rep[0], b);
   }
   else {
      if (&x != &a) x = a;
      sub(x.rep[0], x.rep[0], b);
   }
   x.normalize();
}

/* x = a - b for a constant a, computed as (-b) + a. */
void sub(zz_pX& x, zz_p a, const zz_pX& b)
{
   negate(x, b);
   add(x, x, a);
}

/*
 * x = -a coefficient-wise; x may alias a.  No normalization is needed
 * because negation maps nonzero coefficients to nonzero coefficients.
 */
void negate(zz_pX& x, const zz_pX& a)
{
   long n = a.rep.length();
   x.rep.SetLength(n);
   const zz_p* ap = a.rep.elts();
   zz_p* xp = x.rep.elts();
   long i;
   for (i = n; i; i--, ap++, xp++)
      negate((*xp), (*ap));
}

/*
 * x = a * b.  Routed to sqr() when a and b are the same object.  FFTMul
 * is used only when BOTH degrees exceed NTL_zz_pX_MUL_CROSSOVER;
 * otherwise PlainMul (schoolbook/Karatsuba).
 */
void mul(zz_pX& x, const zz_pX& a, const zz_pX& b)
{
   if (&a == &b) {
      sqr(x, a);
      return;
   }
   if (deg(a) > NTL_zz_pX_MUL_CROSSOVER && deg(b) > NTL_zz_pX_MUL_CROSSOVER)
      FFTMul(x, a, b);
   else
      PlainMul(x, a, b);
}

/* x = a^2: FFTSqr above NTL_zz_pX_MUL_CROSSOVER, PlainSqr otherwise. */
void sqr(zz_pX& x, const zz_pX& a)
{
   if (deg(a) > NTL_zz_pX_MUL_CROSSOVER)
      FFTSqr(x, a);
   else
      PlainSqr(x, a);
}

/* "plain" multiplication and squaring actually incorporate Karatsuba */

/*
 * Schoolbook product of ap[0..sa-1] and bp[0..sb-1] into xp[0..sa+sb-2].
 * Does nothing if either input is empty.  xp is zeroed before the inputs
 * are read, so it must not overlap ap or bp.  Operands are swapped so
 * the outer loop (with its per-row MulMod2 precomputation t1*pinv) runs
 * over the shorter one.
 */
void PlainMul(zz_p *xp, const zz_p *ap, long sa, const zz_p *bp, long sb)
{
   if (sa == 0 || sb == 0) return;
   long sx = sa+sb-1;
   if (sa < sb) {
      { long t = sa; sa = sb; sb = t; }
      { const zz_p *t = ap; ap = bp; bp = t; }
   }
   long i, j;
   for (i = 0; i < sx; i++)
      clear(xp[i]);
   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();
   for (i = 0; i < sb; i++) {
      long t1 = rep(bp[i]);
      double bpinv = ((double) t1)*pinv;
      zz_p *xp1 = xp+i;
      for (j = 0; j < sa; j++) {
         long t2;
         t2 = MulMod2(rep(ap[j]), t1, p, bpinv);
         xp1[j].LoopHole() = AddMod(t2, rep(xp1[j]), p);
      }
   }
}

/*
 * File-scope scratch buffers of doubles read by PlainMul_FP and
 * PlainSqr_FP.  They are never sized by those kernels: callers must
 * SetLength them first (PlainMul(zz_pX&, ...) does so).
 * NOTE(review): being static, these are shared by every caller -- confirm
 * the thread-safety expectations for this module.
 */
static vec_double a_buf, b_buf;

/*
 * r = x mod p, for a double x holding an integer value (callers pass
 * pinv = zz_p::ModulusInverse()).  The quotient is estimated as
 * long(x*pinv), and the remainder is corrected by at most one multiple
 * of p in each direction.
 */
inline void reduce(zz_p& r, double x, long p, double pinv)
{
   long rr = long(x - double(p)*double(long(x*pinv)));
   if (rr < 0) rr += p;
   if (rr >= p) rr -= p;
   r.LoopHole() = rr;
}

/*
 * Schoolbook product like PlainMul(zz_p*, ...), but each output
 * coefficient is accumulated in a double and reduced once.  Requires
 * a_buf.length() >= sa and b_buf.length() >= sb.  The inputs are copied
 * into the buffers before xp is written.  Each xp[i] sums at most
 * min(sa, sb) products of residues.  PlainMul(zz_pX&, ...) takes this
 * path only when p*p < NTL_FDOUBLE_PRECISION/KARX and one operand has
 * fewer than KARX coefficients, which keeps the sums exact.
 */
void PlainMul_FP(zz_p *xp, const zz_p *aap, long sa, const zz_p *bbp, long sb)
{
   if (sa == 0 || sb == 0) return;
   double *ap = a_buf.elts();
   double *bp = b_buf.elts();
   long d = sa+sb-2;
   long i, j, jmin, jmax;
   for (i = 0; i < sa; i++) ap[i] = double(rep(aap[i]));
   for (i = 0; i < sb; i++) bp[i] = double(rep(bbp[i]));
   double accum;
   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();
   for (i = 0; i <= d; i++) {
      /* valid j satisfy 0 <= j < sa and 0 <= i-j < sb */
      jmin = max(0, i-(sb-1));
      jmax = min((sa-1), i);
      accum = 0;
      for (j = jmin; j <= jmax; j++) {
         accum += ap[j]*bp[i-j];
      }
      reduce(xp[i], accum, p, pinv);
   }
}

/* Operand length below which multiplication falls back to schoolbook. */
#define KARX (16)

/*
 * T[0..hsa-1] = low half b[0..hsa-1] plus high half b[hsa..sb-1].  The
 * high half has m = sb-hsa entries; positions m..hsa-1 copy the low
 * half.  Assumes sb - hsa <= hsa, which holds in the Karatsuba callers
 * (hsa = ceil(sa/2), sb <= sa).
 */
void KarFold(zz_p *T, const zz_p *b, long sb, long hsa)
{
   long m = sb - hsa;
   long i;
   for (i = 0; i < m; i++)
      add(T[i], b[i], b[hsa+i]);
   for (i = m; i < hsa; i++)
      T[i] = b[i];
}

/* T[0..sb-1] -= b[0..sb-1]. */
void KarSub(zz_p *T, const zz_p *b, long sb)
{
   long i;
   for (i = 0; i < sb; i++)
      sub(T[i], T[i], b[i]);
}

/* T[0..sb-1] += b[0..sb-1]. */
void KarAdd(zz_p *T, const zz_p *b, long sb)
{
   long i;
   for (i = 0; i < sb; i++)
      add(T[i], T[i], b[i]);
}

/*
 * Merges b[0..sb-1] into c.  Positions below hsa are overwritten (c
 * holds nothing there yet); positions hsa..sb-1 are added, because the
 * degenerate Karatsuba case has already stored b*a_hi from c[hsa] on.
 */
void KarFix(zz_p *c, const zz_p *b, long sb, long hsa)
{
   long i;
   for (i = 0; i < hsa; i++)
      c[i] = b[i];
   for (i = hsa; i < sb; i++)
      add(c[i], c[i], b[i]);
}

/*
 * Karatsuba product of a[0..sa-1] and b[0..sb-1] into c[0..sa+sb-2],
 * using stk as scratch (sized by PlainMul(zz_pX&, ...)).
 * After swapping so that sa >= sb:
 *   - sb < KARX: falls back to schoolbook.
 *   - Normal case (hsa < sb): three half-size recursive products.
 *     (a_lo+a_hi)(b_lo+b_hi) - a_hi*b_hi - a_lo*b_lo is added at offset
 *     hsa.  This level consumes hsa + hsa + (2*hsa-1) = 4*hsa-1 scratch
 *     entries.
 *   - Degenerate case (b no longer than half of a): only a is split.
 * c is written while a and b are still being read, so it must not
 * overlap them.
 */
void KarMul(zz_p *c, const zz_p *a, long sa, const zz_p *b, long sb, zz_p *stk)
{
   if (sa < sb) {
      { long t = sa; sa = sb; sb = t; }
      { const zz_p *t = a; a = b; b = t; }
   }
   if (sb < KARX) {
      PlainMul(c, a, sa, b, sb);
      return;
   }
   long hsa = (sa + 1) >> 1;
   if (hsa < sb) {
      /* normal case */
      long hsa2 = hsa << 1;
      zz_p *T1, *T2, *T3;
      T1 = stk; stk += hsa;
      T2 = stk; stk += hsa;
      T3 = stk; stk += hsa2 - 1;
      /* compute T1 = a_lo + a_hi */
      KarFold(T1, a, sa, hsa);
      /* compute T2 = b_lo + b_hi */
      KarFold(T2, b, sb, hsa);
      /* recursively compute T3 = T1 * T2 */
      KarMul(T3, T1, hsa, T2, hsa, stk);
      /* recursively compute a_hi * b_hi into high part of c */
      /* and subtract from T3 */
      KarMul(c + hsa2, a+hsa, sa-hsa, b+hsa, sb-hsa, stk);
      KarSub(T3, c + hsa2, sa + sb - hsa2 - 1);
      /* recursively compute a_lo*b_lo into low part of c */
      /* and subtract from T3 */
      KarMul(c, a, hsa, b, hsa, stk);
      KarSub(T3, c, hsa2 - 1);
      /* c[hsa2-1] lies between the two half products (a_lo*b_lo fills */
      /* c[0..hsa2-2], a_hi*b_hi starts at c[hsa2]); zero it before adding */
      clear(c[hsa2 - 1]);
      /* finally, add T3 * X^{hsa} to c */
      KarAdd(c+hsa, T3, hsa2-1);
   }
   else {
      /* degenerate case */
      zz_p *T;
      T = stk; stk += hsa + sb - 1;
      /* recursively compute b*a_hi into high part of c */
      KarMul(c + hsa, a + hsa, sa - hsa, b, sb, stk);
      /* recursively compute b*a_lo into T */
      KarMul(T, a, hsa, b, sb, stk);
      KarFix(c, T, hsa + sb - 1, hsa);
   }
}

/*
 * Same recursion and scratch usage as KarMul, with PlainMul_FP as the
 * base case.  a_buf/b_buf must be large enough for the base-case
 * operands; PlainMul(zz_pX&, ...) sizes them to max(sa, sb).
 */
void KarMul_FP(zz_p *c, const zz_p *a, long sa, const zz_p *b, long sb, zz_p *stk)
{
   if (sa < sb) {
      { long t = sa; sa = sb; sb = t; }
      { const zz_p *t = a; a = b; b = t; }
   }
   if (sb < KARX) {
      PlainMul_FP(c, a, sa, b, sb);
      return;
   }
   long hsa = (sa + 1) >> 1;
   if (hsa < sb) {
      /* normal case */
      long hsa2 = hsa << 1;
      zz_p *T1, *T2, *T3;
      T1 = stk; stk += hsa;
      T2 = stk; stk += hsa;
      T3 = stk; stk += hsa2 - 1;
      /* compute T1 = a_lo + a_hi */
      KarFold(T1, a, sa, hsa);
      /* compute T2 = b_lo + b_hi */
      KarFold(T2, b, sb, hsa);
      /* recursively compute T3 = T1 * T2 */
      KarMul_FP(T3, T1, hsa, T2, hsa, stk);
      /* recursively compute a_hi * b_hi into high part of c */
      /* and subtract from T3 */
      KarMul_FP(c + hsa2, a+hsa, sa-hsa, b+hsa, sb-hsa, stk);
      KarSub(T3, c + hsa2, sa + sb - hsa2 - 1);
      /* recursively compute a_lo*b_lo into low part of c */
      /* and subtract from T3 */
      KarMul_FP(c, a, hsa, b, hsa, stk);
      KarSub(T3, c, hsa2 - 1);
      clear(c[hsa2 - 1]);
      /* finally, add T3 * X^{hsa} to c */
      KarAdd(c+hsa, T3, hsa2-1);
   }
   else {
      /* degenerate case */
      zz_p *T;
      T = stk; stk += hsa + sb - 1;
      /* recursively compute b*a_hi into high part of c */
      KarMul_FP(c + hsa, a + hsa, sa - hsa, b, sb, stk);
      /* recursively compute b*a_lo into T */
      KarMul_FP(T, a, hsa, b, sb, stk);
      KarFix(c, T, hsa + sb - 1, hsa);
   }
}

/*
 * c = a * b by schoolbook or Karatsuba multiplication (FFT dispatch
 * happens in mul()).  Special cases:
 *   - zero operands give zero;
 *   - constant operands use scalar mul();
 *   - a and b being the same object goes to PlainSqr;
 *   - c aliasing a or b: the aliased input is first copied into mem,
 *     because c.rep is resized and overwritten.
 * The floating-point kernels are used when use_FP holds, and the scratch
 * buffers are sized here.
 */
void PlainMul(zz_pX& c, const zz_pX& a, const zz_pX& b)
{
   long sa = a.rep.length();
   long sb = b.rep.length();
   if (sa == 0 || sb == 0) {
      clear(c);
      return;
   }
   if (sa == 1) {
      mul(c, b, a.rep[0]);
      return;
   }
   if (sb == 1) {
      mul(c, a, b.rep[0]);
      return;
   }
   if (&a == &b) {
      PlainSqr(c, a);
      return;
   }
   vec_zz_p mem;
   const zz_p *ap, *bp;
   zz_p *cp;
   if (&a == &c) {
      mem = a.rep;
      ap = mem.elts();
   }
   else
      ap = a.rep.elts();
   if (&b == &c) {
      mem = b.rep;
      bp = mem.elts();
   }
   else
      bp = b.rep.elts();
   c.rep.SetLength(sa+sb-1);
   cp = c.rep.elts();
   long p = zz_p::modulus();
   /* p^2 < NTL_FDOUBLE_PRECISION/KARX keeps a double sum of fewer than */
   /* KARX products of residues below NTL_FDOUBLE_PRECISION */
   long use_FP = ((p < NTL_SP_BOUND/KARX) &&
                  (double(p)*double(p) < NTL_FDOUBLE_PRECISION/KARX));
   if (sa < KARX || sb < KARX) {
      if (use_FP) {
         a_buf.SetLength(max(sa, sb));
         b_buf.SetLength(max(sa, sb));
         PlainMul_FP(cp, ap, sa, bp, sb);
      }
      else
         PlainMul(cp, ap, sa, bp, sb);
   }
   else {
      /* karatsuba */
      /* scratch size: each recursion level with half-length hn uses */
      /* 4*hn - 1 entries (T1, T2, T3 in KarMul), summed until hn < KARX */
      long n, hn, sp;
      n = max(sa, sb);
      sp = 0;
      do {
         hn = (n+1) >> 1;
         sp += (hn << 2) - 1;
         n = hn;
      } while (n >= KARX);
      vec_zz_p stk;
      stk.SetLength(sp);
      if (use_FP) {
         a_buf.SetLength(max(sa, sb));
         b_buf.SetLength(max(sa, sb));
         KarMul_FP(cp, ap, sa, bp, sb, stk.elts());
      }
      else
         KarMul(cp, ap, sa, bp, sb, stk.elts());
   }
   c.normalize();
}

/*
 * Squares aap[0..sa-1] into xp[0..2*sa-2] in double arithmetic.  For
 * each output index:
 *   - the first m2 = m/2 of the m symmetric products a[j]*a[i-j] are
 *     summed and the sum is doubled;
 *   - the middle square a[j]*a[j] is added when m is odd;
 *   - the result is reduced once.
 * Requires a_buf.length() >= sa.
 * NOTE(review): a single sum can involve up to sa products (sa < KARSX
 * when reached from KarSqr_FP).  The caller's bound on p is not visible
 * here -- confirm it accounts for KARSX rather than KARX.
 */
void PlainSqr_FP(zz_p *xp, const zz_p *aap, long sa)
{
   if (sa == 0) return;
   long da = sa-1;
   long d = 2*da;
   long i, j, jmin, jmax, m, m2;
   double *ap = a_buf.elts();
   for (i = 0; i < sa; i++) ap[i] = double(rep(aap[i]));
   double accum;
   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();
   for (i = 0; i <= d; i++) {
      jmin = max(0, i-da);
      jmax = min(da, i);
      m = jmax - jmin + 1;
      m2 = m >> 1;
      /* restrict to the first half of the symmetric range */
      jmax = jmin + m2 - 1;
      accum = 0;
      for (j = jmin; j <= jmax; j++) {
         accum += ap[j]*ap[i-j];
      }
      accum += accum;
      if (m & 1) {
         accum += ap[jmax + 1]*ap[jmax + 1];
      }
      reduce(xp[i], accum, p, pinv);
   }
}

/*
 * Schoolbook squaring of ap[0..sa-1] into xp[0..2*sa-2], computing each
 * cross product a[j]*a[k] (j < k) only once.  xp is zeroed before ap is
 * read, so it must not overlap ap.  Iteration j:
 *   - completes xp[2j]: the cross terms accumulated so far are doubled
 *     and a[j]^2 is added;
 *   - adds a[j]*a[k] for k > j into xp[j+k];
 *   - doubles xp[2j+1], which receives no later contributions.
 * The top coefficient xp[2*sa-2] = a[sa-1]^2 is written last.
 */
void PlainSqr(zz_p *xp, const zz_p *ap, long sa)
{
   if (sa == 0) return;
   long i, j, k, cnt;
   cnt = 2*sa-1;
   for (i = 0; i < cnt; i++)
      clear(xp[i]);
   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();
   long t1, t2;
   i = -1;
   for (j = 0; j <= sa-2; j++) {
      i += 2;                              /* i == 2*j + 1 */
      t1 = MulMod(rep(ap[j]), rep(ap[j]), p, pinv);
      t2 = rep(xp[i-1]);
      t2 = AddMod(t2, t2, p);
      t2 = AddMod(t2, t1, p);
      xp[i-1].LoopHole() = t2;
      cnt = sa - 1 - j;
      const zz_p *ap1 = ap+(j+1);
      zz_p *xp1 = xp+i;
      t1 = rep(ap[j]);
      double tpinv = ((double) t1)*pinv;
      for (k = 0; k < cnt; k++) {
         t2 = MulMod2(rep(ap1[k]), t1, p, tpinv);
         t2 = AddMod(t2, rep(xp1[k]), p);
         xp1[k].LoopHole() = t2;
      }
      t2 = rep(*xp1);
      t2 = AddMod(t2, t2, p);
      (*xp1).LoopHole() = t2;
   }
   t1 = rep(ap[sa-1]);
   t1 = MulMod(t1, t1, p, pinv);
   xp[2*sa-2].LoopHole() = t1;
}

/* Operand length below which squaring falls back to schoolbook. */
#define KARSX (30)

/*
 * Karatsuba squaring of a[0..sa-1] into c[0..2*sa-2], using stk as
 * scratch.  T2 = (a_lo+a_hi)^2 - a_hi^2 - a_lo^2 is added at offset hsa.
 * Each level consumes hsa + (2*hsa - 1) scratch entries.  c must not
 * overlap a.
 */
void KarSqr(zz_p *c, const zz_p *a, long sa, zz_p *stk)
{
   if (sa < KARSX) {
      PlainSqr(c, a, sa);
      return;
   }
   long hsa = (sa + 1) >> 1;
   long hsa2 = hsa << 1;
   zz_p *T1, *T2;
   T1 = stk; stk += hsa;
   T2 = stk; stk += hsa2-1;
   KarFold(T1, a, sa, hsa);
   KarSqr(T2, T1, hsa, stk);
   KarSqr(c + hsa2, a+hsa, sa-hsa, stk);
   KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);
   KarSqr(c, a, hsa, stk);
   KarSub(T2, c, hsa2 - 1);
   /* c[hsa2-1] is not written by either half square; zero it before adding */
   clear(c[hsa2 - 1]);
   KarAdd(c+hsa, T2, hsa2-1);
}

void KarSqr_FP(zz_p *c, const zz_p *a, long sa, zz_p *stk)
{
   if (sa < KARSX) {
      PlainSqr_FP(c, a, sa);
      return;
   }
   long hsa = (sa + 1) >> 1;
   long hsa2 = hsa << 1;
   zz_p *T1, *T2;
   T1 = stk; stk += hsa;
   T2 = stk; stk += hsa2-1;
   KarFold(T1, a, sa, hsa);
   KarSqr_FP(T2, T1, hsa, stk);
   KarSqr_FP(c + hsa2, a+hsa, sa-hsa, stk);
   KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -