⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 lzz_px.cpp

📁 数值算法库for Windows
💻 CPP
📖 第 1 页 / 共 5 页
字号:

#include <NTL/lzz_pX.h>
#include <NTL/vec_double.h>

#include <NTL/new.h>

NTL_START_IMPL



/* Crossover tables: five thresholds for each operation family (mod, mul,
   newton, div, halfgcd, gcd, bermass, trace).
   NOTE(review): presumably these are polynomial-degree cutoffs used to choose
   between classical and asymptotically fast algorithms, with the array index
   selecting a modulus-size class.  The code that indexes them is not in this
   part of the file -- confirm against the NTL_zz_pX_*_CROSSOVER macros. */
long zz_pX_mod_crossover[5] = {45, 45, 90, 180, 180};
long zz_pX_mul_crossover[5] = {90, 400, 600, 1500, 1500};
long zz_pX_newton_crossover[5] = {150, 150, 300, 700, 700};
long zz_pX_div_crossover[5] = {180, 180, 350, 750, 750};
long zz_pX_halfgcd_crossover[5] = {90, 90, 180, 350, 350};
long zz_pX_gcd_crossover[5] = {400, 400, 800, 1400, 1400};
long zz_pX_bermass_crossover[5] = {400, 480, 900, 1600, 1600};
long zz_pX_trace_crossover[5] = {200, 350, 450, 800, 800};

/* Nonzero when NTL_DOUBLE_PRECISION exceeds NTL_SP_NBITS by more than 10.
   Not referenced in the visible part of this file. */
#define QUICK_CRT (NTL_DOUBLE_PRECISION - NTL_SP_NBITS > 10)



/* Returns a reference to a single default-constructed zz_pX kept in a
   function-local static, so every call yields the same object. */
const zz_pX& zz_pX::zero()
{
   static zz_pX z;
   return z;
}



/* Reads a polynomial from s: the coefficient vector is read directly into
   x.rep and then normalized, so trailing zero coefficients in the input are
   stripped.  Returns s. */
istream& operator>>(istream& s, zz_pX& x)
{
   s >> x.rep;
   x.normalize();
   return s;
}

/* Writes a polynomial to s by writing its coefficient vector a.rep. */
ostream& operator<<(ostream& s, const zz_pX& a)
{
   return s << a.rep;
}


/* Strips trailing zero coefficients from rep, so that afterwards rep is
   either empty or its last entry is nonzero.

   The scan starts one past the last coefficient and pre-decrements only
   after n > 0 has been checked.  The pointer therefore never moves before
   the first element of the array, even when every coefficient is zero
   (stepping below the start of an array is undefined pointer arithmetic). */
void zz_pX::normalize()
{
   long n = rep.length();
   if (n == 0) return;

   const zz_p* p = rep.elts() + n;
   while (n > 0 && IsZero(*--p)) {
      n--;
   }
   rep.SetLength(n);
}


/* Returns 1 if a has no coefficients stored (the zero polynomial),
   0 otherwise. */
long IsZero(const zz_pX& a)
{
   const long ncoeffs = a.rep.length();
   return ncoeffs == 0;
}


/* Returns 1 if a is the constant polynomial 1 (exactly one stored
   coefficient and that coefficient is one), 0 otherwise. */
long IsOne(const zz_pX& a)
{
   if (a.rep.length() != 1)
      return 0;

   return IsOne(a.rep[0]) ? 1 : 0;
}

/* Stores the coefficient of X^i of a into x.  Indices outside
   0..deg(a) yield zero. */
void GetCoeff(zz_p& x, const zz_pX& a, long i)
{
   if (i >= 0 && i <= deg(a)) {
      x = a.rep[i];
      return;
   }

   clear(x);
}

/* Sets the coefficient of X^i in x to a.  Reports an error for a negative
   index or an index at or above 2^(NTL_BITS_PER_LONG-4).  If i exceeds the
   current degree, the polynomial is extended and the new coefficients
   between the old degree and i are zeroed.  x is normalized afterwards, so
   storing zero in the leading position lowers the degree. */
void SetCoeff(zz_pX& x, long i, zz_p a)
{
   if (i < 0)
      Error("SetCoeff: negative index");

   if (i >= (1L << (NTL_BITS_PER_LONG-4)))
      Error("overflow in SetCoeff");

   const long old_deg = deg(x);

   if (old_deg < i) {
      x.rep.SetLength(i+1);

      /* zero-fill the gap between the old top coefficient and position i */
      long k = old_deg + 1;
      while (k < i) {
         clear(x.rep[k]);
         k++;
      }
   }

   x.rep[i] = a;
   x.normalize();
}

/* Sets the coefficient of X^i in x to the integer a (converted to zz_p).
   The value 1 is routed to the two-argument SetCoeff, which stores one
   directly. */
void SetCoeff(zz_pX& x, long i, long a)
{
   if (a != 1) {
      SetCoeff(x, i, to_zz_p(a));
      return;
   }

   SetCoeff(x, i);
}

/* Sets the coefficient of X^i in x to one.  Reports an error for a negative
   index or an index at or above 2^(NTL_BITS_PER_LONG-4).  If i exceeds the
   current degree, the polynomial is extended and the new coefficients
   between the old degree and i are zeroed.  x is normalized afterwards. */
void SetCoeff(zz_pX& x, long i)
{
   if (i < 0)
      Error("coefficient index out of range");

   if (i >= (1L << (NTL_BITS_PER_LONG-4)))
      Error("overflow in SetCoeff");

   const long old_deg = deg(x);

   if (old_deg < i) {
      x.rep.SetLength(i+1);

      /* zero-fill the gap between the old top coefficient and position i */
      long k = old_deg + 1;
      while (k < i) {
         clear(x.rep[k]);
         k++;
      }
   }

   set(x.rep[i]);
   x.normalize();
}


/* Sets x to the monomial X: clears x, then sets the coefficient of X^1
   to one. */
void SetX(zz_pX& x)
{
   clear(x);
   SetCoeff(x, 1);
}


/* Returns 1 if a is exactly the monomial X (degree 1, leading coefficient
   one, constant term zero), 0 otherwise. */
long IsX(const zz_pX& a)
{
   if (deg(a) != 1)
      return 0;

   if (!IsOne(LeadCoeff(a)))
      return 0;

   return IsZero(ConstTerm(a)) ? 1 : 0;
}
      
      

/* Returns the coefficient of X^i of a, or zero when i lies outside
   0..deg(a). */
zz_p coeff(const zz_pX& a, long i)
{
   if (i >= 0 && i <= deg(a))
      return a.rep[i];

   return zz_p::zero();
}


/* Returns the coefficient of the highest-degree term of a, or zero when a
   is the zero polynomial. */
zz_p LeadCoeff(const zz_pX& a)
{
   if (!IsZero(a))
      return a.rep[deg(a)];

   return zz_p::zero();
}

/* Returns the constant term of a, or zero when a is the zero polynomial. */
zz_p ConstTerm(const zz_pX& a)
{
   if (!IsZero(a))
      return a.rep[0];

   return zz_p::zero();
}



/* Sets x to the constant polynomial a.  A zero scalar gives the empty
   (zero) representation rather than a one-element vector. */
void conv(zz_pX& x, zz_p a)
{
   if (!IsZero(a)) {
      x.rep.SetLength(1);
      x.rep[0] = a;
      return;
   }

   x.rep.SetLength(0);
}

/* Sets x to the constant polynomial given by the integer a.  Zero is
   handled directly; any other value is first converted to zz_p and then
   passed to the scalar conversion. */
void conv(zz_pX& x, long a)
{
   if (a != 0) {
      zz_p scalar;

      conv(scalar, a);
      conv(x, scalar);
   }
   else {
      x.rep.SetLength(0);
   }
}

/* Sets x to the constant polynomial given by the big integer a.  Zero is
   handled directly; any other value is first converted to zz_p and then
   passed to the scalar conversion. */
void conv(zz_pX& x, const ZZ& a)
{
   if (!(a == 0)) {
      zz_p scalar;

      conv(scalar, a);
      conv(x, scalar);
   }
   else {
      x.rep.SetLength(0);
   }
}


/* Sets x to the polynomial whose coefficient vector is a (a[i] is the
   coefficient of X^i), then strips trailing zero coefficients. */
void conv(zz_pX& x, const vec_zz_p& a)
{
   x.rep = a;
   x.normalize();
}


/* x = a + b.  x may alias a and/or b.

   The coefficients both operands share are added pairwise.  The remaining
   high coefficients of the longer operand are copied into x unless x is
   that operand, in which case they are already in place.  Normalization
   runs only when no tail was copied: that covers equal degrees (where the
   leading terms may cancel) and the aliased cases. */
void add(zz_pX& x, const zz_pX& a, const zz_pX& b)
{
   const long deg_a = deg(a);
   const long deg_b = deg(b);
   const long lo = min(deg_a, deg_b);
   const long hi = max(deg_a, deg_b);
   x.rep.SetLength(hi+1);

   /* fetch the element pointers only after x has been resized */
   const zz_p *pa = a.rep.elts();
   const zz_p *pb = b.rep.elts();
   zz_p *px = x.rep.elts();

   long k;

   for (k = 0; k <= lo; k++)
      add(px[k], pa[k], pb[k]);

   if (deg_a > lo && &x != &a) {
      for (k = lo+1; k <= deg_a; k++)
         px[k] = pa[k];
   }
   else if (deg_b > lo && &x != &b) {
      for (k = lo+1; k <= deg_b; k++)
         px[k] = pb[k];
   }
   else {
      x.normalize();
   }
}

/* x = a + b for a scalar b.  x may alias a.  When a is zero the result is
   just the constant b; otherwise b is added to the constant term and the
   result is normalized. */
void add(zz_pX& x, const zz_pX& a, zz_p b)
{
   if (a.rep.length() != 0) {
      if (&x != &a) x = a;

      zz_p& c0 = x.rep[0];
      add(c0, c0, b);
      x.normalize();
   }
   else {
      conv(x, b);
   }
}


/* x = a - b.  x may alias a and/or b.

   The coefficients both operands share are subtracted pairwise.  If a is
   longer, its high coefficients are copied into x (unless x is a, where
   they are already in place).  If b is longer, its high coefficients are
   negated into x (also correct in place when x is b).  Otherwise x is
   normalized. */
void sub(zz_pX& x, const zz_pX& a, const zz_pX& b)
{
   const long deg_a = deg(a);
   const long deg_b = deg(b);
   const long lo = min(deg_a, deg_b);
   const long hi = max(deg_a, deg_b);
   x.rep.SetLength(hi+1);

   /* fetch the element pointers only after x has been resized */
   const zz_p *pa = a.rep.elts();
   const zz_p *pb = b.rep.elts();
   zz_p *px = x.rep.elts();

   long k;

   for (k = 0; k <= lo; k++)
      sub(px[k], pa[k], pb[k]);

   if (deg_a > lo && &x != &a) {
      for (k = lo+1; k <= deg_a; k++)
         px[k] = pa[k];
   }
   else if (deg_b > lo) {
      for (k = lo+1; k <= deg_b; k++)
         negate(px[k], pb[k]);
   }
   else {
      x.normalize();
   }
}

/* x = a - b for a scalar b.  x may alias a.  When a is zero the result is
   the constant -b; otherwise b is subtracted from the constant term.  The
   result is normalized in both cases. */
void sub(zz_pX& x, const zz_pX& a, zz_p b)
{
   if (a.rep.length() != 0) {
      if (&x != &a) x = a;
      sub(x.rep[0], x.rep[0], b);
   }
   else {
      x.rep.SetLength(1);
      negate(x.rep[0], b);
   }

   x.normalize();
}

/* x = a - b for a scalar a, computed as (-b) + a: b is negated into x and
   the scalar is then added to x in place. */
void sub(zz_pX& x, zz_p a, const zz_pX& b)
{
   negate(x, b);
   add(x, x, a);
}

/* x = -a, coefficient by coefficient.  x may alias a.  The result has the
   same length as a, so no normalization step is performed. */
void negate(zz_pX& x, const zz_pX& a)
{
   const long len = a.rep.length();
   x.rep.SetLength(len);

   const zz_p *src = a.rep.elts();
   zz_p *dst = x.rep.elts();

   long k;

   for (k = 0; k < len; k++)
      negate(dst[k], src[k]);
}

/* x = a * b.  When both arguments are the same object the call is routed
   to sqr.  Otherwise FFTMul is used if both degrees exceed
   NTL_zz_pX_MUL_CROSSOVER, and PlainMul in all other cases. */
void mul(zz_pX& x, const zz_pX& a, const zz_pX& b)
{
   if (&a == &b) {
      sqr(x, a);
      return;
   }

   const long use_fft = (deg(a) > NTL_zz_pX_MUL_CROSSOVER &&
                         deg(b) > NTL_zz_pX_MUL_CROSSOVER);

   if (use_fft) {
      FFTMul(x, a, b);
      return;
   }

   PlainMul(x, a, b);
}

/* x = a * a.  Uses FFTSqr when deg(a) exceeds NTL_zz_pX_MUL_CROSSOVER and
   PlainSqr otherwise. */
void sqr(zz_pX& x, const zz_pX& a)
{
   if (deg(a) <= NTL_zz_pX_MUL_CROSSOVER) {
      PlainSqr(x, a);
      return;
   }

   FFTSqr(x, a);
}

/* "plain" multiplication and squaring actually incorporate Karatsuba */

/* Schoolbook product on raw coefficient arrays: writes the sa+sb-1
   coefficients of (ap[0..sa)) * (bp[0..sb)) to xp.  Does nothing if either
   length is zero.  The output is cleared first and then accumulated into,
   so xp must not overlap either input. */
void PlainMul(zz_p *xp, const zz_p *ap, long sa, const zz_p *bp, long sb)
{
   if (sa == 0 || sb == 0) return;

   const long out_len = sa+sb-1;

   /* arrange for (ap, sa) to be the longer operand so that the inner
      loop is the long one */
   if (sa < sb) {
      long tmp_len = sa;
      sa = sb;
      sb = tmp_len;

      const zz_p *tmp_ptr = ap;
      ap = bp;
      bp = tmp_ptr;
   }

   long i, j;

   for (i = 0; i < out_len; i++)
      clear(xp[i]);

   const long p = zz_p::modulus();
   const double pinv = zz_p::ModulusInverse();

   for (i = 0; i < sb; i++) {
      const long bi = rep(bp[i]);

      /* bi/p, computed once per row and passed to MulMod2 */
      const double bi_pinv = ((double) bi)*pinv;

      zz_p *row = xp+i;

      for (j = 0; j < sa; j++) {
         const long prod = MulMod2(rep(ap[j]), bi, p, bi_pinv);
         row[j].LoopHole() = AddMod(prod, rep(row[j]), p);
      }
   }
}

/* File-scope scratch buffers holding the operands of PlainMul_FP as
   doubles.  PlainMul_FP does not resize them; the caller must set their
   lengths before calling it. */
static vec_double a_buf, b_buf;

/* Reduces the double x modulo p and stores the result in r.  pinv is used
   to estimate the quotient floor(x/p); the estimate can be off by one in
   either direction, so the remainder is corrected by at most one addition
   or subtraction of p. */
inline void reduce(zz_p& r, double x, long p, double pinv)
{
   const long q = long(x*pinv);
   long rem = long(x - double(p)*double(q));

   if (rem < 0) rem += p;
   if (rem >= p) rem -= p;

   r.LoopHole() = rem;
}

/* Schoolbook product using double-precision accumulation: each output
   coefficient is built as an unreduced sum of products in a double and
   reduced modulo p once at the end.  Does nothing if either length is zero.

   The operands are first copied as doubles into the file-scope buffers
   a_buf and b_buf, which the caller must already have sized to at least sa
   and sb entries.
   NOTE(review): exactness relies on the accumulated sums fitting in a
   double's mantissa; the caller is presumably responsible for checking
   that -- confirm at the call site. */
void PlainMul_FP(zz_p *xp, const zz_p *aap, long sa, const zz_p *bbp, long sb)
{
   if (sa == 0 || sb == 0) return;

   double *fa = a_buf.elts();
   double *fb = b_buf.elts();

   long k, j;

   for (k = 0; k < sa; k++) fa[k] = double(rep(aap[k]));
   for (k = 0; k < sb; k++) fb[k] = double(rep(bbp[k]));

   const long p = zz_p::modulus();
   const double pinv = zz_p::ModulusInverse();

   const long top = sa+sb-2;

   for (k = 0; k <= top; k++) {
      /* range of j for which both fa[j] and fb[k-j] exist */
      const long j_lo = max(0, k-(sb-1));
      const long j_hi = min((sa-1), k);

      double sum = 0;

      for (j = j_lo; j <= j_hi; j++) {
         sum += fa[j]*fb[k-j];
      }

      reduce(xp[k], sum, p, pinv);
   }
}

/* Karatsuba cutoff: KarMul and KarMul_FP switch to the plain (schoolbook)
   routines once the shorter operand has fewer than KARX coefficients. */
#define KARX (16)

/* Writes into T[0..hsa) the sum of the low part b[0..hsa) and the high part
   b[hsa..sb) of b.  The high part has sb-hsa entries; positions beyond it
   are plain copies of the low part. */
void KarFold(zz_p *T, const zz_p *b, long sb, long hsa)
{
   const long hi_len = sb - hsa;
   long k;

   k = 0;
   while (k < hi_len) {
      add(T[k], b[k], b[hsa+k]);
      k++;
   }

   k = hi_len;
   while (k < hsa) {
      T[k] = b[k];
      k++;
   }
}

/* T[i] -= b[i] for i in 0..sb-1. */
void KarSub(zz_p *T, const zz_p *b, long sb)
{
   zz_p *dst = T;
   const zz_p *src = b;
   long remaining;

   for (remaining = sb; remaining > 0; remaining--, dst++, src++)
      sub(*dst, *dst, *src);
}

/* T[i] += b[i] for i in 0..sb-1. */
void KarAdd(zz_p *T, const zz_p *b, long sb)
{
   zz_p *dst = T;
   const zz_p *src = b;
   long remaining;

   for (remaining = sb; remaining > 0; remaining--, dst++, src++)
      add(*dst, *dst, *src);
}

/* Merges b[0..sb) into c: the first hsa entries of c are overwritten with
   those of b, and the remaining entries b[hsa..sb) are added onto what c
   already holds at the same positions. */
void KarFix(zz_p *c, const zz_p *b, long sb, long hsa)
{
   long k;

   k = 0;
   while (k < hsa) {
      c[k] = b[k];
      k++;
   }

   k = hsa;
   while (k < sb) {
      add(c[k], c[k], b[k]);
      k++;
   }
}


/* Karatsuba multiplication on raw coefficient arrays.

   Writes the product of a[0..sa) and b[0..sb) to c, using stk as a scratch
   area from which temporaries are carved; recursive calls receive the
   portion of stk beyond the current level's temporaries.
   NOTE(review): the required size of stk is determined by the caller, which
   is not visible here -- confirm before changing the temporaries' sizes.

   Once the shorter operand has fewer than KARX coefficients the work is
   handed to PlainMul. */
void KarMul(zz_p *c, const zz_p *a, long sa, const zz_p *b, long sb, zz_p *stk)
{
   /* make a the longer operand */
   if (sa < sb) {
      { long t = sa; sa = sb; sb = t; }
      { const zz_p *t = a; a = b; b = t; }
   }

   if (sb < KARX) {
      PlainMul(c, a, sa, b, sb);
      return;
   }

   /* hsa = ceil(sa/2): the split point, and the length of the low halves */
   long hsa = (sa + 1) >> 1;

   if (hsa < sb) {
      /* normal case: b is long enough to be split at hsa as well */

      long hsa2 = hsa << 1;

      zz_p *T1, *T2, *T3;

      /* carve T1 (hsa), T2 (hsa) and T3 (2*hsa-1) out of the scratch area */
      T1 = stk; stk += hsa;
      T2 = stk; stk += hsa;
      T3 = stk; stk += hsa2 - 1;

      /* compute T1 = a_lo + a_hi */

      KarFold(T1, a, sa, hsa);

      /* compute T2 = b_lo + b_hi */

      KarFold(T2, b, sb, hsa);

      /* recursively compute T3 = T1 * T2 */

      KarMul(T3, T1, hsa, T2, hsa, stk);

      /* recursively compute a_hi * b_hi into high part of c */
      /* and subtract from T3 */

      KarMul(c + hsa2, a+hsa, sa-hsa, b+hsa, sb-hsa, stk);
      KarSub(T3, c + hsa2, sa + sb - hsa2 - 1);


      /* recursively compute a_lo*b_lo into low part of c */
      /* and subtract from T3 */

      KarMul(c, a, hsa, b, hsa, stk);
      KarSub(T3, c, hsa2 - 1);

      /* the low product fills c[0..hsa2-2] and the high product starts at */
      /* c[hsa2]; the one slot in between has not been written yet */
      clear(c[hsa2 - 1]);

      /* finally, add T3 * X^{hsa} to c */

      KarAdd(c+hsa, T3, hsa2-1);
   }
   else {
      /* degenerate case: b is no longer than half of a, so only a is split */

      zz_p *T;

      T = stk; stk += hsa + sb - 1;

      /* recursively compute b*a_hi into high part of c */

      KarMul(c + hsa, a + hsa, sa - hsa, b, sb, stk);

      /* recursively compute b*a_lo into T */

      KarMul(T, a, hsa, b, sb, stk);

      /* copy T[0..hsa) into c and add the rest of T onto the part of c */
      /* that already holds b*a_hi */
      KarFix(c, T, hsa + sb - 1, hsa);
   }
}

/* Karatsuba multiplication on raw coefficient arrays, identical in structure
   to KarMul except that the base case is PlainMul_FP (double-precision
   accumulation) instead of PlainMul.

   Writes the product of a[0..sa) and b[0..sb) to c, using stk as a scratch
   area from which temporaries are carved; recursive calls receive the
   portion of stk beyond the current level's temporaries.
   NOTE(review): PlainMul_FP reads the file-scope buffers a_buf and b_buf
   without resizing them, and the required size of stk is set by the caller;
   neither call site is visible here -- confirm both are sized correctly. */
void KarMul_FP(zz_p *c, const zz_p *a, long sa, const zz_p *b, long sb, zz_p *stk)
{
   /* make a the longer operand */
   if (sa < sb) {
      { long t = sa; sa = sb; sb = t; }
      { const zz_p *t = a; a = b; b = t; }
   }

   if (sb < KARX) {
      PlainMul_FP(c, a, sa, b, sb);
      return;
   }

   /* hsa = ceil(sa/2): the split point, and the length of the low halves */
   long hsa = (sa + 1) >> 1;

   if (hsa < sb) {
      /* normal case: b is long enough to be split at hsa as well */

      long hsa2 = hsa << 1;

      zz_p *T1, *T2, *T3;

      /* carve T1 (hsa), T2 (hsa) and T3 (2*hsa-1) out of the scratch area */
      T1 = stk; stk += hsa;
      T2 = stk; stk += hsa;
      T3 = stk; stk += hsa2 - 1;

      /* compute T1 = a_lo + a_hi */

      KarFold(T1, a, sa, hsa);

      /* compute T2 = b_lo + b_hi */

      KarFold(T2, b, sb, hsa);

      /* recursively compute T3 = T1 * T2 */

      KarMul_FP(T3, T1, hsa, T2, hsa, stk);

      /* recursively compute a_hi * b_hi into high part of c */
      /* and subtract from T3 */

      KarMul_FP(c + hsa2, a+hsa, sa-hsa, b+hsa, sb-hsa, stk);
      KarSub(T3, c + hsa2, sa + sb - hsa2 - 1);


      /* recursively compute a_lo*b_lo into low part of c */
      /* and subtract from T3 */

      KarMul_FP(c, a, hsa, b, hsa, stk);
      KarSub(T3, c, hsa2 - 1);

      /* the low product fills c[0..hsa2-2] and the high product starts at */
      /* c[hsa2]; the one slot in between has not been written yet */
      clear(c[hsa2 - 1]);

      /* finally, add T3 * X^{hsa} to c */

      KarAdd(c+hsa, T3, hsa2-1);
   }
   else {
      /* degenerate case: b is no longer than half of a, so only a is split */

      zz_p *T;

      T = stk; stk += hsa + sb - 1;

      /* recursively compute b*a_hi into high part of c */

      KarMul_FP(c + hsa, a + hsa, sa - hsa, b, sb, stk);

      /* recursively compute b*a_lo into T */

      KarMul_FP(T, a, hsa, b, sb, stk);

      /* copy T[0..hsa) into c and add the rest of T onto the part of c */
      /* that already holds b*a_hi */
      KarFix(c, T, hsa + sb - 1, hsa);
   }
}


void PlainMul(zz_pX& c, const zz_pX& a, const zz_pX& b)
{
   long sa = a.rep.length();
   long sb = b.rep.length();

   if (sa == 0 || sb == 0) {
      clear(c);
      return;
   }

   if (sa == 1) {
      mul(c, b, a.rep[0]);
      return;
   }

   if (sb == 1) {
      mul(c, a, b.rep[0]);
      return;
   }

   if (&a == &b) {
      PlainSqr(c, a);
      return;
   }

   vec_zz_p mem;

   const zz_p *ap, *bp;
   zz_p *cp;

   if (&a == &c) {
      mem = a.rep;
      ap = mem.elts();
   }
   else
      ap = a.rep.elts();

   if (&b == &c) {
      mem = b.rep;
      bp = mem.elts();
   }
   else
      bp = b.rep.elts();

   c.rep.SetLength(sa+sb-1);
   cp = c.rep.elts();

   long p = zz_p::modulus();
   long use_FP = ((p < NTL_SP_BOUND/KARX) && 
                 (double(p)*double(p) < NTL_FDOUBLE_PRECISION/KARX));

   if (sa < KARX || sb < KARX) {
      if (use_FP) {
         a_buf.SetLength(max(sa, sb));
         b_buf.SetLength(max(sa, sb));

         PlainMul_FP(cp, ap, sa, bp, sb);
      }
      else
         PlainMul(cp, ap, sa, bp, sb);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -