⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 lzz_px.cpp

📁 数值算法库for Windows
💻 CPP
📖 第 1 页 / 共 5 页
字号:
   }
   else {
      /* karatsuba */

      long n, hn, sp;

      n = max(sa, sb);
      sp = 0;
      do {
         hn = (n+1) >> 1;
         sp += (hn << 2) - 1;
         n = hn;
      } while (n >= KARX);

      vec_zz_p stk;
      stk.SetLength(sp);

      if (use_FP) {
         a_buf.SetLength(max(sa, sb));
         b_buf.SetLength(max(sa, sb));
         KarMul_FP(cp, ap, sa, bp, sb, stk.elts());
      }
      else
         KarMul(cp, ap, sa, bp, sb, stk.elts());
   }

   c.normalize();
}

// Schoolbook squaring of the sa coefficients at aap into xp[0 .. 2*sa-2],
// accumulating each output coefficient in a double before a single modular
// reduction.
//
// The coefficients are first converted into the shared scratch buffer a_buf,
// which must already hold at least sa entries (PlainSqr on zz_pX sizes it
// before taking an _FP path).  Exactness relies on every partial sum staying
// within double precision; the _FP path is only selected by PlainSqr when
// p*p < NTL_FDOUBLE_PRECISION/KARSX.
//
// For output index k, the products a[j]*a[k-j] pair up symmetrically, so
// only the first half is summed and then doubled; when the number of terms
// is odd the middle square is added separately.
void PlainSqr_FP(zz_p *xp, const zz_p *aap, long sa)
{
   if (sa == 0) return;

   long top = sa - 1;    // degree of the input
   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();

   double *fa = a_buf.elts();
   for (long k = 0; k < sa; k++)
      fa[k] = double(rep(aap[k]));

   for (long k = 0; k <= 2*top; k++) {
      // Valid j range for a[j]*a[k-j] with both indices in [0, top].
      long lo = max(0, k - top);
      long hi = min(top, k);
      long terms = hi - lo + 1;
      long half = terms >> 1;

      double sum = 0;
      for (long j = lo; j < lo + half; j++)
         sum += fa[j]*fa[k-j];
      sum += sum;    // account for the mirrored products

      if (terms & 1) {
         // Odd count: the unpaired middle term is a square.
         double mid = fa[lo + half];
         sum += mid*mid;
      }

      reduce(xp[k], sum, p, pinv);
   }
}


// Schoolbook squaring using single-precision modular arithmetic: writes the
// 2*sa-1 coefficients of (sum_j ap[j] X^j)^2 into xp.
//
// Each cross product ap[j]*ap[l] (j < l) is computed only once.  The output
// coefficient it lands in is doubled exactly once, after it has received all
// of its cross terms.  The squares ap[j]^2 are added after that doubling.
void PlainSqr(zz_p *xp, const zz_p *ap, long sa)
{
   if (sa == 0) return;

   long i, j, k, cnt;

   // Zero the whole output first; cross terms are accumulated into it.
   cnt = 2*sa-1;
   for (i = 0; i < cnt; i++)
      clear(xp[i]);

   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();
   long t1, t2;

   // Row j processes input coefficient ap[j]; i tracks 2*j+1.
   i = -1;
   for (j = 0; j <= sa-2; j++) {
      i += 2;

      // xp[2j] received its last cross term in row j-1, so it is complete:
      // double it, then add the square ap[j]^2.
      t1 = MulMod(rep(ap[j]), rep(ap[j]), p, pinv);
      t2 = rep(xp[i-1]);
      t2 = AddMod(t2, t2, p);
      t2 = AddMod(t2, t1, p);
      xp[i-1].LoopHole() = t2;

      // Accumulate the cross terms ap[j]*ap[l] for l = j+1 .. sa-1 into
      // xp[j+l], i.e. xp[2j+1 .. j+sa-1].  tpinv is t1*pinv, precomputed
      // once per row and passed to MulMod2 together with the fixed t1.
      cnt = sa - 1 - j;
      const zz_p *ap1 = ap+(j+1);
      zz_p *xp1 = xp+i;
      t1 = rep(ap[j]);
      double tpinv = ((double) t1)*pinv;

      for (k = 0; k < cnt; k++) {
         t2 = MulMod2(rep(ap1[k]), t1, p, tpinv);
         t2 = AddMod(t2, rep(xp1[k]), p);
         xp1[k].LoopHole() = t2;
      }
      // Later rows start at index >= 2j+3, so xp[2j+1] is now complete and
      // holds only cross terms: double it.
      t2 = rep(*xp1);
      t2 = AddMod(t2, t2, p);
      (*xp1).LoopHole() = t2;
   }


   // The top coefficient has no cross terms: it is just ap[sa-1]^2.
   t1 = rep(ap[sa-1]);
   t1 = MulMod(t1, t1, p, pinv);
   xp[2*sa-2].LoopHole() = t1;
}

// Squaring crossover: inputs with fewer than KARSX coefficients are squared
// by the schoolbook routines (PlainSqr / PlainSqr_FP); larger inputs use
// Karatsuba.  It also appears in the precision test in PlainSqr(zz_pX&, ...)
// that decides whether the floating-point (_FP) variants are safe.
#define KARSX (30)

// Karatsuba squaring of the sa coefficients at a into c (2*sa-1 outputs),
// falling back to the schoolbook PlainSqr below KARSX coefficients.
//
// stk is caller-provided scratch.  Each recursion level carves off
// hsa + (hsa2-1) = 3*hsa - 1 entries (T1 and T2) and hands the rest to the
// recursive calls.  PlainSqr(zz_pX&, const zz_pX&) sizes the stack by
// summing exactly that amount over all levels.
void KarSqr(zz_p *c, const zz_p *a, long sa, zz_p *stk)
{
   if (sa < KARSX) {
      PlainSqr(c, a, sa);
      return;
   }

   // Split point: the low part has hsa = ceil(sa/2) coefficients and the
   // high part sa-hsa <= hsa.
   long hsa = (sa + 1) >> 1;
   long hsa2 = hsa << 1;

   zz_p *T1, *T2;

   // T1: hsa entries.  T2: hsa2-1 entries, the size of a square of an
   // hsa-term polynomial.
   T1 = stk; stk += hsa;
   T2 = stk; stk += hsa2-1;

   // NOTE(review): KarFold is defined elsewhere; presumably T1 = lo + hi,
   // making T2 = (lo + hi)^2.
   KarFold(T1, a, sa, hsa);
   KarSqr(T2, T1, hsa, stk);


   // hi^2 goes straight into its final position c[hsa2 ..] and is then
   // removed from T2 (KarSub presumably subtracts its second argument).
   KarSqr(c + hsa2, a+hsa, sa-hsa, stk);
   KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);


   // lo^2 goes into c[0 .. hsa2-2] and is likewise removed from T2, which
   // should then hold the middle term 2*lo*hi.
   KarSqr(c, a, hsa, stk);
   KarSub(T2, c, hsa2 - 1);

   // c[hsa2-1] sits between the two squares and was not written above.
   clear(c[hsa2 - 1]);

   // Add the middle term in at offset hsa.
   KarAdd(c+hsa, T2, hsa2-1);
}

// Karatsuba squaring identical to KarSqr except that the base case uses the
// floating-point schoolbook PlainSqr_FP.  That base case reads the shared
// buffer a_buf, which the caller (PlainSqr on zz_pX) sizes to sa entries
// before the first call.
//
// stk is caller-provided scratch; each level consumes 3*hsa - 1 entries
// (T1 and T2) before recursing.
void KarSqr_FP(zz_p *c, const zz_p *a, long sa, zz_p *stk)
{
   if (sa < KARSX) {
      PlainSqr_FP(c, a, sa);
      return;
   }

   // Low part: hsa = ceil(sa/2) coefficients; high part: sa-hsa.
   long hsa = (sa + 1) >> 1;
   long hsa2 = hsa << 1;

   zz_p *T1, *T2;

   // T1: hsa entries; T2: hsa2-1 entries.
   T1 = stk; stk += hsa;
   T2 = stk; stk += hsa2-1;

   // NOTE(review): presumably T1 = lo + hi, so T2 = (lo + hi)^2.
   KarFold(T1, a, sa, hsa);
   KarSqr_FP(T2, T1, hsa, stk);


   // hi^2 into c[hsa2 ..], then removed from T2.
   KarSqr_FP(c + hsa2, a+hsa, sa-hsa, stk);
   KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);


   // lo^2 into c[0 .. hsa2-2], then removed from T2 (leaving 2*lo*hi).
   KarSqr_FP(c, a, hsa, stk);
   KarSub(T2, c, hsa2 - 1);

   // Gap coefficient between the two squares.
   clear(c[hsa2 - 1]);

   // Add the middle term in at offset hsa.
   KarAdd(c+hsa, T2, hsa2-1);
}

// Squares a into c (c may be the same object as a).
//
// Dispatch rules:
// - Below KARSX coefficients the schoolbook routines are used; otherwise
//   Karatsuba is used.
// - In either case the floating-point (_FP) variant is chosen when p is
//   small enough that the double accumulation in PlainSqr_FP stays exact.
// - Those variants need the shared buffer a_buf sized to the input length.
void PlainSqr(zz_pX& c, const zz_pX& a)
{
   if (IsZero(a)) {
      clear(c);
      return;
   }

   long sa = a.rep.length();

   // Resizing c below would clobber a when they are the same object, so in
   // that case read from a private copy of the coefficients.
   vec_zz_p mem;
   const zz_p *ap;
   if (&a != &c)
      ap = a.rep.elts();
   else {
      mem = a.rep;
      ap = mem.elts();
   }

   c.rep.SetLength(2*sa-1);
   zz_p *cp = c.rep.elts();

   long p = zz_p::modulus();
   long use_FP = ((p < NTL_SP_BOUND/KARSX) && 
                 (double(p)*double(p) < NTL_FDOUBLE_PRECISION/KARSX));

   if (sa >= KARSX) {
      // Karatsuba: each recursion level on an n-term input uses
      // 3*ceil(n/2) - 1 scratch entries; sum that over every level that
      // still recurses.
      long sp = 0;
      long n = sa;
      do {
         n = (n+1) >> 1;
         sp += 3*n - 1;
      } while (n >= KARSX);

      vec_zz_p stk;
      stk.SetLength(sp);

      if (use_FP) {
         a_buf.SetLength(sa);
         KarSqr_FP(cp, ap, sa, stk.elts());
      }
      else
         KarSqr(cp, ap, sa, stk.elts());
   }
   else if (use_FP) {
      a_buf.SetLength(sa);
      PlainSqr_FP(cp, ap, sa);
   }
   else
      PlainSqr(cp, ap, sa);

   c.normalize();
}


// Schoolbook division with remainder: computes q and r with a = q*b + r and
// deg(r) < deg(b).  Calls Error("zz_pX: division by zero") when b is zero.
//
// Aliasing:
// - q may be the same object as b; the divisor is copied before q is resized.
// - r may be the same object as a; the reduction is then done in place in
//   r's coefficient vector.
// - Otherwise the dividend is copied into a scratch vector first.
void PlainDivRem(zz_pX& q, zz_pX& r, const zz_pX& a, const zz_pX& b)
{
   long da, db, dq, i, j, LCIsOne;
   const zz_p *bp;
   zz_p *qp;
   zz_p *xp;

   zz_p LCInv, t;

   da = deg(a);
   db = deg(b);

   if (db < 0) Error("zz_pX: division by zero");

   if (da < db) {
      r = a;
      clear(q);
      return;
   }

   // q is resized below; if it is b itself, work from a private copy.
   zz_pX lb;

   if (&q == &b) {
      lb = b;
      bp = lb.rep.elts();
   }
   else
      bp = b.rep.elts();

   // For a monic divisor, skip the per-step multiplication by LC(b)^{-1}.
   if (IsOne(bp[db]))
      LCIsOne = 1;
   else {
      LCIsOne = 0;
      inv(LCInv, bp[db]);
   }

   // xp: working copy of the dividend.  Its low db coefficients end up
   // holding the remainder.
   vec_zz_p x;
   if (&r == &a)
      xp = r.rep.elts();
   else {
      x = a.rep;
      xp = x.elts();
   }

   dq = da - db;
   q.rep.SetLength(dq+1);
   qp = q.rep.elts();

   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();

   // Eliminate the leading coefficient from the top down.
   for (i = dq; i >= 0; i--) {
      t = xp[i+db];
      if (!LCIsOne)
         mul(t, t, LCInv);
      qp[i] = t;
      negate(t, t);

      long T = rep(t);
      double Tpinv = ((double) T)*pinv;   // precomputed for MulMod2 with fixed T

      // xp[i .. i+db-1] -= qp[i] * b[0 .. db-1]
      for (j = db-1; j >= 0; j--) {
         long S = MulMod2(rep(bp[j]), T, p, Tpinv);
         S = AddMod(S, rep(xp[i+j]), p);
         xp[i+j].LoopHole() = S;
      }
   }

   r.rep.SetLength(db);
   if (&r != &a) {
      for (i = 0; i < db; i++)
         r.rep[i] = xp[i];
   }
   r.normalize();
}

// Schoolbook division without remainder: sets q to the quotient of a by b.
// Calls Error("zz_pX: division by zero") when b is zero.
//
// Only the quotient is needed, so the working vector x holds just the top
// da-db+1 coefficients of a, with x[k] corresponding to a[k+db].  Updates
// that would fall below that window only affect the discarded remainder and
// are skipped via lastj.
void PlainDiv(zz_pX& q, const zz_pX& a, const zz_pX& b)
{
   long da, db, dq, i, j, LCIsOne;
   const zz_p *bp;
   zz_p *qp;
   zz_p *xp;

   zz_p LCInv, t;

   da = deg(a);
   db = deg(b);

   if (db < 0) Error("zz_pX: division by zero");

   if (da < db) {
      clear(q);
      return;
   }

   // q is resized below; if it is b itself, work from a private copy.
   zz_pX lb;

   if (&q == &b) {
      lb = b;
      bp = lb.rep.elts();
   }
   else
      bp = b.rep.elts();

   // For a monic divisor, skip the per-step multiplication by LC(b)^{-1}.
   if (IsOne(bp[db]))
      LCIsOne = 1;
   else {
      LCIsOne = 0;
      inv(LCInv, bp[db]);
   }

   // Copy a's top coefficients.  This happens before q is resized, so it is
   // safe when q is the same object as a.
   vec_zz_p x;
   x.SetLength(da+1-db);
   for (i = db; i <= da; i++)
      x[i-db] = a.rep[i];

   xp = x.elts();

   dq = da - db;
   q.rep.SetLength(dq+1);
   qp = q.rep.elts();

   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();

   for (i = dq; i >= 0; i--) {
      t = xp[i];
      if (!LCIsOne)
         mul(t, t, LCInv);
      qp[i] = t;
      negate(t, t);

      long T = rep(t);
      double Tpinv = ((double) T)*pinv;   // precomputed for MulMod2 with fixed T

      // Only indices i+j-db >= 0 are inside the kept window.
      long lastj = max(0, db-i);

      for (j = db-1; j >= lastj; j--) {
         long S = MulMod2(rep(bp[j]), T, p, Tpinv);
         S = AddMod(S, rep(xp[i+j-db]), p);
         xp[i+j-db].LoopHole() = S;
      }
   }
}


// Schoolbook remainder: sets r = a mod b, with deg(r) < deg(b).
// Calls Error("zz_pX: division by zero") when b is zero.
//
// When r is the same object as a, the reduction runs in place in r's
// coefficient vector; otherwise a is first copied into a scratch vector.
// Quotient coefficients are computed but not stored.
void PlainRem(zz_pX& r, const zz_pX& a, const zz_pX& b)
{
   long da, db, dq, i, j, LCIsOne;
   const zz_p *bp;
   zz_p *xp;

   zz_p LCInv, t;

   da = deg(a);
   db = deg(b);

   if (db < 0) Error("zz_pX: division by zero");

   if (da < db) {
      r = a;
      return;
   }

   bp = b.rep.elts();

   // For a monic divisor, skip the per-step multiplication by LC(b)^{-1}.
   if (IsOne(bp[db]))
      LCIsOne = 1;
   else {
      LCIsOne = 0;
      inv(LCInv, bp[db]);
   }

   vec_zz_p x;

   if (&r == &a)
      xp = r.rep.elts();
   else {
      x = a.rep;
      xp = x.elts();
   }

   dq = da - db;

   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();

   for (i = dq; i >= 0; i--) {
      t = xp[i+db];
      if (!LCIsOne)
         mul(t, t, LCInv);
      negate(t, t);

      long T = rep(t);
      double Tpinv = ((double) T)*pinv;   // precomputed for MulMod2 with fixed T

      // xp[i .. i+db-1] -= (quotient coefficient) * b[0 .. db-1]
      for (j = db-1; j >= 0; j--) {
         long S = MulMod2(rep(bp[j]), T, p, Tpinv);
         S = AddMod(S, rep(xp[i+j]), p);
         xp[i+j].LoopHole() = S;
      }
   }

   r.rep.SetLength(db);
   if (&r != &a) {
      for (i = 0; i < db; i++)
         r.rep[i] = xp[i];
   }
   r.normalize();
}


// Scalar multiplication: x = a * b for a polynomial a and a constant b.
// Handles b == 0 (x cleared) and b == 1 (plain copy) up front.  The general
// case multiplies each coefficient with MulMod2, using b*pinv precomputed
// once.  x may be the same object as a.
void mul(zz_pX& x, const zz_pX& a, zz_p b)
{
   if (IsZero(b)) {
      clear(x);
      return;
   }

   if (IsOne(b)) {
      x = a;
      return;
   }

   long p = zz_p::modulus();
   double pinv = zz_p::ModulusInverse();
   long bb = rep(b);
   double bpinv = double(bb)*pinv;

   long n = deg(a) + 1;
   x.rep.SetLength(n);

   const zz_p *src = a.rep.elts();
   zz_p *dst = x.rep.elts();

   for (long k = 0; k < n; k++)
      dst[k].LoopHole() = MulMod2(rep(src[k]), bb, p, bpinv);

   x.normalize();
}



// Euclidean gcd using the schoolbook remainder.  The result is scaled to be
// monic unless it is zero (which only happens when a and b are both zero).
void PlainGCD(zz_pX& x, const zz_pX& a, const zz_pX& b)
{
   if (IsZero(b))
      x = a;
   else if (IsZero(a))
      x = b;
   else {
      long n = max(deg(a),deg(b)) + 1;
      zz_pX u(INIT_SIZE, n), v(INIT_SIZE, n);

      u = a;
      v = b;

      // Replace (u, v) by (v, u mod v) until the remainder vanishes; the
      // last nonzero divisor is the gcd.
      for (;;) {
         PlainRem(u, u, v);
         if (IsZero(u)) break;
         swap(u, v);
      }

      x = v;
   }

   if (IsZero(x) || IsOne(LeadCoeff(x))) return;

   // Make the gcd monic.
   zz_p t;
   inv(t, LeadCoeff(x));
   mul(x, x, t);
}



         

// Extended Euclidean algorithm (schoolbook).  Computes d = gcd(a, b) and
// cofactors s, t with d = s*a + t*b, then scales d, s and t so that d is
// monic (unless d is zero).
//
// Outputs may be the same objects as the inputs: every branch reads (or
// copies) a and b before writing d, s or t.  d, s and t are assumed to be
// distinct objects.
void PlainXGCD(zz_pX& d, zz_pX& s, zz_pX& t, const zz_pX& a, const zz_pX& b)
{
   zz_p z;


   if (IsZero(b)) {
      // gcd(a, 0) = a = 1*a + 0*b.  Copy a into d BEFORE set/clear: if s or
      // t is the same object as a, writing them first would destroy a.
      d = a;
      set(s);
      clear(t);
   }
   else if (IsZero(a)) {
      // gcd(0, b) = b = 0*a + 1*b; same ordering concern as above for b.
      d = b;
      clear(s);
      set(t);
   }
   else {
      long e = max(deg(a), deg(b)) + 1;

      zz_pX temp(INIT_SIZE, e), u(INIT_SIZE, e), v(INIT_SIZE, e), u0(INIT_SIZE, e), v0(INIT_SIZE, e), 
            u1(INIT_SIZE, e), v1(INIT_SIZE, e), u2(INIT_SIZE, e), v2(INIT_SIZE, e), q(INIT_SIZE, e);


      // Loop invariants: u = u1*a + v1*b and v = u2*a + v2*b.
      set(u1); clear(v1);
      clear(u2); set(v2);
      u = a; v = b;

      do {
         // (u, v) <- (v, u - q*v); update the cofactor pairs to match.
         DivRem(q, u, u, v);
         swap(u, v);
         u0 = u2;
         v0 = v2;
         mul(temp, q, u2);
         sub(u2, u1, temp);
         mul(temp, q, v2);
         sub(v2, v1, temp);
         u1 = u0;
         v1 = v0;
      } while (!IsZero(v));

      d = u;
      s = u1;
      t = v1;
   }

   if (IsZero(d)) return;
   if (IsOne(LeadCoeff(d))) return;

   /* make gcd monic */

   inv(z, LeadCoeff(d));
   mul(d, d, z);
   mul(s, s, z);
   mul(t, t, z);
}


// x = (a * b) mod f.  Calls Error("MulMod: bad args") unless deg(f) != 0
// and both deg(a) and deg(b) are below deg(f).
void MulMod(zz_pX& x, const zz_pX& a, const zz_pX& b, const zz_pX& f)
{
   long n = deg(f);
   if (n == 0 || deg(a) >= n || deg(b) >= n)
      Error("MulMod: bad args");

   zz_pX prod;
   mul(prod, a, b);
   rem(x, prod, f);
}

// x = a^2 mod f.  Calls Error("SqrMod: bad args") unless deg(f) != 0 and
// deg(a) < deg(f).
void SqrMod(zz_pX& x, const zz_pX& a, const zz_pX& f)
{
   long n = deg(f);
   if (n == 0 || deg(a) >= n) Error("SqrMod: bad args");

   zz_pX sq;
   sqr(sq, a);
   rem(x, sq, f);
}


// x = a^{-1} mod f, computed as the XGCD cofactor of a.
// - Calls Error("InvMod: bad args") unless deg(f) != 0 and deg(a) < deg(f).
// - Calls Error("zz_pX InvMod: can't compute multiplicative inverse") when
//   gcd(a, f) != 1.
void InvMod(zz_pX& x, const zz_pX& a, const zz_pX& f)
{
   long n = deg(f);
   if (n == 0 || deg(a) >= n) Error("InvMod: bad args");

   zz_pX g, cofactor;
   XGCD(g, x, cofactor, a, f);

   if (!IsOne(g))
      Error("zz_pX InvMod: can't compute multiplicative inverse");
}

// Attempts x = a^{-1} mod f.
// - Returns 0 on success.
// - If gcd(a, f) != 1, sets x to that gcd and returns 1.
// - Calls Error("InvModStatus: bad args") unless deg(f) != 0 and
//   deg(a) < deg(f).
long InvModStatus(zz_pX& x, const zz_pX& a, const zz_pX& f)
{
   long n = deg(f);
   if (n == 0 || deg(a) >= n) Error("InvModStatus: bad args");

   zz_pX g, cofactor;
   XGCD(g, x, cofactor, a, f);

   if (IsOne(g))
      return 0;

   x = g;
   return 1;
}




static
void MulByXModAux(zz_pX& h, const zz_pX& a, const zz_pX& f)
{
   long i, n, m;
   zz_p* hh;
   const zz_p *aa, *ff;

   zz_p t, z;

   n = deg(f);
   m = deg(a);

   if (m >= n || n == 0) Error("MulByXMod: bad args");

   if (m < 0) {
      clear(h);
      return;
   }

   if (m < n-1) {
      h.rep.SetLength(m+2);
      hh = h.rep.elts();
      aa = a.rep.elts();

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -