⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 zz_px.cpp

📁 数值算法库for Windows
💻 CPP
📖 第 1 页 / 共 4 页
字号:

#include <NTL/ZZ_pX.h>


// The mul & sqr routines use routines from ZZX,
// which are faster for small-degree polynomials.
// Define the macro NTL_OLD_ZZ_pX_MUL to revert to the old strategy.


#ifndef NTL_OLD_ZZ_pX_MUL

#include <NTL/ZZX.h>

#endif

#include <NTL/new.h>



NTL_START_IMPL




// Returns a reference to a single shared, default-constructed ZZ_pX.
const ZZ_pX& ZZ_pX::zero()
{
   // function-local static: constructed once, on the first call
   static ZZ_pX sharedZero;
   return sharedZero;
}


// Assigns the integer a to this polynomial (via conv) and returns *this.
ZZ_pX& ZZ_pX::operator=(long a)
{
   ZZ_pX& self = *this;
   conv(self, a);
   return self;
}


// Assigns the scalar a to this polynomial (via conv) and returns *this.
ZZ_pX& ZZ_pX::operator=(const ZZ_p& a)
{
   ZZ_pX& self = *this;
   conv(self, a);
   return self;
}


// Reads the coefficient vector of x from s, then strips leading zero
// coefficients so that x is left in normalized form.
istream& operator>>(istream& s, ZZ_pX& x)
{
   ZZ_pX& target = x;

   s >> target.rep;
   target.normalize();

   return s;
}

// Writes the coefficient vector of a to s and returns the resulting stream.
ostream& operator<<(ostream& s, const ZZ_pX& a)
{
   ostream& out = (s << a.rep);
   return out;
}


// Strips leading (high-degree) zero coefficients from rep, so that either
// rep is empty or its last element is non-zero.
void ZZ_pX::normalize()
{
   long n = rep.length();
   if (n == 0) return;

   // Start one past the last coefficient and pre-decrement before each test.
   // Because "n > 0" is checked first, p is never moved below rep.elts(),
   // even when every coefficient is zero (forming a pointer before the start
   // of an array is undefined behavior).
   const ZZ_p* p = rep.elts() + n;
   while (n > 0 && IsZero(*--p))
      n--;

   rep.SetLength(n);
}


// Returns 1 if a has no stored coefficients, 0 otherwise.
long IsZero(const ZZ_pX& a)
{
   if (a.rep.length() == 0)
      return 1;
   return 0;
}


// Returns 1 if a has exactly one coefficient and that coefficient is one,
// 0 otherwise.
long IsOne(const ZZ_pX& a)
{
    if (a.rep.length() != 1)
       return 0;

    return IsOne(a.rep[0]) ? 1 : 0;
}

// Stores the coefficient of X^i of a into x; indices outside [0, deg(a)]
// yield zero.
void GetCoeff(ZZ_p& x, const ZZ_pX& a, long i)
{
   const bool inRange = (i >= 0) && (i <= deg(a));

   if (inRange)
      x = a.rep[i];
   else
      clear(x);
}

// Sets the coefficient of X^i in x to a, growing x.rep when i exceeds the
// current degree, and then strips any leading zero coefficients.
// Calls Error() if i is negative or i >= 2^(NTL_BITS_PER_LONG-4).
void SetCoeff(ZZ_pX& x, long i, const ZZ_p& a)
{
   long j, m;

   if (i < 0) 
      Error("SetCoeff: negative index");

   if (i >= (1L << (NTL_BITS_PER_LONG-4)))
      Error("overflow in SetCoeff");

   m = deg(x);

   if (i > m) {
      // a may refer to an element stored inside x.rep.  Record its position
      // (-1 is treated as "not inside x.rep") before resizing -- presumably
      // because SetLength may move the storage and invalidate a.
      long pos = x.rep.position(a);
      x.rep.SetLength(i+1);

      if (pos != -1)
         x.rep[i] = x.rep.RawGet(pos);   // re-read the aliased element by index
      else
         x.rep[i] = a;

      // zero the newly exposed coefficients strictly between the old degree and i
      for (j = m+1; j < i; j++)
         clear(x.rep[j]);
   }
   else
      x.rep[i] = a;

   // a may have been zero, or may have overwritten the leading coefficient
   x.normalize();
}

// Sets the coefficient of X^i in x to the integer a.  The value 1 is
// forwarded to the two-argument SetCoeff; every other value is first
// converted to a ZZ_p temporary.
void SetCoeff(ZZ_pX& x, long i, long a)
{
   if (a != 1) {
      ZZ_pTemp scratch;
      ZZ_p& converted = scratch.val();

      conv(converted, a);
      SetCoeff(x, i, converted);
      return;
   }

   SetCoeff(x, i);
}

// Sets the coefficient of X^i in x to one, growing x.rep if necessary.
// Calls Error() if i is negative or i >= 2^(NTL_BITS_PER_LONG-4).
void SetCoeff(ZZ_pX& x, long i)
{
   if (i < 0) 
      Error("coefficient index out of range");

   if (i >= (1L << (NTL_BITS_PER_LONG-4)))
      Error("overflow in SetCoeff");

   const long oldDeg = deg(x);

   if (oldDeg < i) {
      x.rep.SetLength(i+1);

      // clear the new slots strictly between the old degree and i
      for (long k = oldDeg + 1; k < i; k++)
         clear(x.rep[k]);
   }

   set(x.rep[i]);
   x.normalize();
}


// Makes x the monomial X: clears it, then sets the coefficient of X^1 to one.
void SetX(ZZ_pX& x)
{
   const long linearIndex = 1;

   clear(x);
   SetCoeff(x, linearIndex);
}


// Returns 1 if a is exactly the monomial X (degree 1, leading coefficient
// one, constant term zero), 0 otherwise.
long IsX(const ZZ_pX& a)
{
   if (deg(a) != 1)
      return 0;

   if (!IsOne(LeadCoeff(a)))
      return 0;

   if (!IsZero(ConstTerm(a)))
      return 0;

   return 1;
}
      
      

// Returns a reference to the coefficient of X^i of a, or to ZZ_p::zero()
// when i lies outside [0, deg(a)].
const ZZ_p& coeff(const ZZ_pX& a, long i)
{
   if (i >= 0 && i <= deg(a))
      return a.rep[i];

   return ZZ_p::zero();
}


// Returns a reference to the highest-degree coefficient of a, or to
// ZZ_p::zero() when a is the zero polynomial.
const ZZ_p& LeadCoeff(const ZZ_pX& a)
{
   if (!IsZero(a))
      return a.rep[deg(a)];

   return ZZ_p::zero();
}

// Returns a reference to the degree-0 coefficient of a, or to
// ZZ_p::zero() when a is the zero polynomial.
const ZZ_p& ConstTerm(const ZZ_pX& a)
{
   if (!IsZero(a))
      return a.rep[0];

   return ZZ_p::zero();
}



// Converts the scalar a into the constant polynomial x (empty when a is zero).
void conv(ZZ_pX& x, const ZZ_p& a)
{
   if (IsZero(a)) {
      x.rep.SetLength(0);
      return;
   }

   // note: if a aliases x.rep[i], i > 0, this code
   //       will still work, since it is assumed that
   //       SetLength(1) will not relocate or destroy x.rep[i]
   x.rep.SetLength(1);
   x.rep[0] = a;
}

// Converts the integer a into the constant polynomial x.  The values 0 and 1
// take the clear()/set() shortcuts; anything else goes through a ZZ_p
// temporary.
void conv(ZZ_pX& x, long a)
{
   switch (a) {
   case 0:
      clear(x);
      break;

   case 1:
      set(x);
      break;

   default: {
      ZZ_pTemp scratch;
      ZZ_p& converted = scratch.val();
      conv(converted, a);
      conv(x, converted);
      break;
   }
   }
}

// Converts the big integer a into the constant polynomial x.  A zero input
// clears x directly; otherwise a is first converted to a ZZ_p temporary.
void conv(ZZ_pX& x, const ZZ& a)
{
   if (!IsZero(a)) {
      ZZ_pTemp scratch;
      ZZ_p& converted = scratch.val();

      conv(converted, a);
      conv(x, converted);
      return;
   }

   clear(x);
}

// Builds x from the coefficient vector a (copied into x.rep), then strips
// leading zero coefficients.
void conv(ZZ_pX& x, const vec_ZZ_p& a)
{
   ZZ_pX& dst = x;

   dst.rep = a;
   dst.normalize();
}


// x = a + b.  x may be the same object as a or b.
void add(ZZ_pX& x, const ZZ_pX& a, const ZZ_pX& b)
{
   const long da = deg(a);
   const long db = deg(b);
   const long lo = min(da, db);
   const long hi = max(da, db);

   x.rep.SetLength(hi+1);

   // raw arrays are fetched only after the resize, since x may alias a or b
   const ZZ_p* ap = a.rep.elts();
   const ZZ_p* bp = b.rep.elts();
   ZZ_p* xp = x.rep.elts();

   // overlapping part: coefficient-wise sum
   long k;
   for (k = 0; k <= lo; k++)
      add(xp[k], ap[k], bp[k]);

   if (da > lo && &x != &a) {
      // a is longer: copy its remaining high coefficients
      for (; k <= da; k++)
         xp[k] = ap[k];
   }
   else if (db > lo && &x != &b) {
      // b is longer: copy its remaining high coefficients
      for (; k <= db; k++)
         xp[k] = bp[k];
   }
   else {
      // no tail was copied, so leading coefficients may have cancelled
      x.normalize();
   }
}


// x = a + b, where the scalar b is added to the constant term of a.
// The branches differ only in how they cope with x being the same object
// as a, and with b possibly referring to a coefficient stored inside x.
void add(ZZ_pX& x, const ZZ_pX& a, const ZZ_p& b)
{
   long n = a.rep.length();
   if (n == 0) {
      // a is zero: the result is just the constant polynomial b
      conv(x, b);
   }
   else if (&x == &a) {
      // in place: only the constant term changes
      add(x.rep[0], a.rep[0], b);
      x.normalize();
   }
   else if (x.rep.MaxLength() == 0) {
      // NOTE(review): MaxLength() == 0 presumably means x has never held any
      // coefficients, so b cannot alias one of them -- confirm against the
      // vector documentation.
      x = a;
      add(x.rep[0], a.rep[0], b);
      x.normalize();
   }
   else {
      // ugly...b could alias a coeff of x

      // b is consumed first (into xp[0]), before x.rep is resized and
      // the remaining coefficients are overwritten with those of a
      ZZ_p *xp = x.rep.elts();
      add(xp[0], a.rep[0], b);
      x.rep.SetLength(n);
      xp = x.rep.elts();   // re-fetch the array pointer after the resize
      const ZZ_p *ap = a.rep.elts();
      long i;
      for (i = 1; i < n; i++)
         xp[i] = ap[i];
      x.normalize();
   }
}

// x = a + b, where the integer b is added to the constant term of a.
void add(ZZ_pX& x, const ZZ_pX& a, long b)
{
   // a is zero: the result is simply the constant b
   if (a.rep.length() == 0) {
      conv(x, b);
      return;
   }

   if (&x != &a)
      x = a;

   ZZ_p& constTerm = x.rep[0];
   add(constTerm, constTerm, b);

   x.normalize();
}

// x = a - b.  x may be the same object as a or b.
void sub(ZZ_pX& x, const ZZ_pX& a, const ZZ_pX& b)
{
   const long da = deg(a);
   const long db = deg(b);
   const long lo = min(da, db);
   const long hi = max(da, db);

   x.rep.SetLength(hi+1);

   // raw arrays are fetched only after the resize, since x may alias a or b
   const ZZ_p* ap = a.rep.elts();
   const ZZ_p* bp = b.rep.elts();
   ZZ_p* xp = x.rep.elts();

   // overlapping part: coefficient-wise difference
   long k;
   for (k = 0; k <= lo; k++)
      sub(xp[k], ap[k], bp[k]);

   if (da > lo && &x != &a) {
      // a is longer: copy its remaining high coefficients
      for (; k <= da; k++)
         xp[k] = ap[k];
   }
   else if (db > lo) {
      // b is longer: the tail of the result is the negated tail of b
      // (done even when x is b, since the values must change sign)
      for (; k <= db; k++)
         negate(xp[k], bp[k]);
   }
   else {
      // no tail was written, so leading coefficients may have cancelled
      x.normalize();
   }
}

// x = a - b, where the scalar b is subtracted from the constant term of a.
// The branches differ only in how they cope with x being the same object
// as a, and with b possibly referring to a coefficient stored inside x.
void sub(ZZ_pX& x, const ZZ_pX& a, const ZZ_p& b)
{
   long n = a.rep.length();
   if (n == 0) {
      // a is zero: the result is the constant polynomial -b
      conv(x, b);
      negate(x, x);
   }
   else if (&x == &a) {
      // in place: only the constant term changes
      sub(x.rep[0], a.rep[0], b);
      x.normalize();
   }
   else if (x.rep.MaxLength() == 0) {
      // NOTE(review): MaxLength() == 0 presumably means x has never held any
      // coefficients, so b cannot alias one of them -- confirm against the
      // vector documentation.
      x = a;
      sub(x.rep[0], a.rep[0], b);
      x.normalize();
   }
   else {
      // ugly...b could alias a coeff of x

      // b is consumed first (into xp[0]), before x.rep is resized and
      // the remaining coefficients are overwritten with those of a
      ZZ_p *xp = x.rep.elts();
      sub(xp[0], a.rep[0], b);
      x.rep.SetLength(n);
      xp = x.rep.elts();   // re-fetch the array pointer after the resize
      const ZZ_p *ap = a.rep.elts();
      long i;
      for (i = 1; i < n; i++)
         xp[i] = ap[i];
      x.normalize();
   }
}

// x = a - b, where the integer b is subtracted from the constant term of a.
void sub(ZZ_pX& x, const ZZ_pX& a, long b)
{
   // nothing to subtract: plain copy
   if (b == 0) {
      x = a;
      return;
   }

   if (a.rep.length() != 0) {
      if (&x != &a)
         x = a;

      ZZ_p& constTerm = x.rep[0];
      sub(constTerm, constTerm, b);
   }
   else {
      // a is zero: the result is the constant -b
      x.rep.SetLength(1);

      ZZ_p& constTerm = x.rep[0];
      constTerm = b;
      negate(constTerm, constTerm);
   }

   x.normalize();
}

// x = a - b for a scalar a, computed as (-b) + a.
void sub(ZZ_pX& x, const ZZ_p& a, const ZZ_pX& b)
{
   // a is copied into a temporary before x is overwritten by negate()
   ZZ_pTemp scratch;
   ZZ_p& savedA = scratch.val();
   savedA = a;

   negate(x, b);
   add(x, x, savedA);
}

// x = a - b for an integer a, computed as (-b) + a with a first converted
// to a ZZ_p temporary.
void sub(ZZ_pX& x, long a, const ZZ_pX& b)
{
   ZZ_pTemp scratch;
   ZZ_p& convertedA = scratch.val();
   convertedA = a;

   negate(x, b);
   add(x, x, convertedA);
}

// x = -a, negating each coefficient in turn.
void negate(ZZ_pX& x, const ZZ_pX& a)
{
   const long len = a.rep.length();
   x.rep.SetLength(len);

   const ZZ_p* src = a.rep.elts();
   ZZ_p* dst = x.rep.elts();

   for (long k = 0; k < len; k++)
      negate(dst[k], src[k]);
}


#ifndef NTL_OLD_ZZ_pX_MUL

// These crossovers are tuned for a Pentium, but hopefully
// they should be OK on other machines as well.


// Thresholds consulted by mul() and sqr() for large operands: the
// SSMul/SSSqr path (via ZZX) is taken only when ZZ_p::ModulusSize() is at
// least SS_kbound and the SSRatio(...) value is below SS_rbound; otherwise
// FFTMul/FFTSqr is used.
const long SS_kbound = 40;
const double SS_rbound = 1.25;


// c = a * b.  Dispatches to one of four strategies based on the modulus
// size k and on s = min(deg(a), deg(b)) + 1:
//   - PlainMul for small s (thresholds depend on k),
//   - KarMul over ZZX for s < 80,
//   - SSMul over ZZX, or FFTMul, for larger operands.
void mul(ZZ_pX& c, const ZZ_pX& a, const ZZ_pX& b)
{
   if (IsZero(a) || IsZero(b)) {
      clear(c);
      return;
   }

   // same operand twice: squaring has its own dispatcher
   if (&a == &b) {
      sqr(c, a);
      return;
   }

   const long k = ZZ_p::ModulusSize();
   const long s = min(deg(a), deg(b)) + 1;

   const bool usePlain =
      s == 1 || (k == 1 && s < 40) || (k == 2 && s < 20) ||
                (k == 3 && s < 12) || (k <= 5 && s < 8) ||
                (k <= 12 && s < 4);

   if (usePlain) {
      PlainMul(c, a, b);
      return;
   }

   if (s < 80) {
      // multiply over ZZX, then convert the product back
      ZZX A, B, C;
      conv(A, a);
      conv(B, b);
      KarMul(C, A, B);
      conv(c, C);
      return;
   }

   const long mbits = NumBits(ZZ_p::modulus());

   if (k >= SS_kbound &&
       SSRatio(deg(a), mbits, deg(b), mbits) < SS_rbound) {
      ZZX A, B, C;
      conv(A, a);
      conv(B, b);
      SSMul(C, A, B);
      conv(c, C);
      return;
   }

   FFTMul(c, a, b);
}

// c = a * a.  Dispatches to one of four strategies based on the modulus
// size k and on s = deg(a) + 1:
//   - PlainSqr for small s (thresholds depend on k),
//   - KarSqr over ZZX for s < 80,
//   - SSSqr over ZZX, or FFTSqr, for larger operands.
void sqr(ZZ_pX& c, const ZZ_pX& a)
{
   if (IsZero(a)) {
      clear(c);
      return;
   }

   const long k = ZZ_p::ModulusSize();
   const long s = deg(a) + 1;

   const bool usePlain =
      s == 1 || (k == 1 && s < 50) || (k == 2 && s < 25) ||
                (k == 3 && s < 25) || (k <= 6 && s < 12) ||
                (k <= 8 && s < 8)  || (k == 9 && s < 6)  ||
                (k <= 30 && s < 4);

   if (usePlain) {
      PlainSqr(c, a);
      return;
   }

   if (s < 80) {
      // square over ZZX, then convert the result back
      ZZX C, A;
      conv(A, a);
      KarSqr(C, A);
      conv(c, C);
      return;
   }

   const long mbits = NumBits(ZZ_p::modulus());

   if (k >= SS_kbound &&
       SSRatio(deg(a), mbits, deg(a), mbits) < SS_rbound) {
      ZZX A, C;
      conv(A, a);
      SSSqr(C, A);
      conv(c, C);
      return;
   }

   FFTSqr(c, a);
}

#else

// x = a * b (old strategy).  FFTMul is used only when both degrees exceed
// NTL_ZZ_pX_FFT_CROSSOVER; otherwise PlainMul.
void mul(ZZ_pX& x, const ZZ_pX& a, const ZZ_pX& b)
{
   // same operand twice: delegate to sqr
   if (&a == &b) {
      sqr(x, a);
      return;
   }

   const bool bothLarge = deg(a) > NTL_ZZ_pX_FFT_CROSSOVER &&
                          deg(b) > NTL_ZZ_pX_FFT_CROSSOVER;

   if (!bothLarge) {
      PlainMul(x, a, b);
      return;
   }

   FFTMul(x, a, b);
}

// x = a * a (old strategy).  FFTSqr is used when deg(a) exceeds
// NTL_ZZ_pX_FFT_CROSSOVER; otherwise PlainSqr.
void sqr(ZZ_pX& x, const ZZ_pX& a)
{
   if (deg(a) <= NTL_ZZ_pX_FFT_CROSSOVER) {
      PlainSqr(x, a);
      return;
   }

   FFTSqr(x, a);
}


#endif


// x = a * b by the quadratic schoolbook method.  Each output coefficient is
// accumulated as a sum of ZZ products and converted to ZZ_p once at the end.
// x may be the same object as a and/or b.
void PlainMul(ZZ_pX& x, const ZZ_pX& a, const ZZ_pX& b)
{
   const long da = deg(a);
   const long db = deg(b);

   if (da < 0 || db < 0) {
      clear(x);
      return;
   }

   // a constant factor reduces to a scalar multiplication
   if (da == 0) {
      mul(x, b, a.rep[0]);
      return;
   }

   if (db == 0) {
      mul(x, a, b.rep[0]);
      return;
   }

   const long d = da + db;

   // x.rep is resized below, so an input that is x itself is copied first
   ZZ_pX aCopy, bCopy;

   const ZZ_p* ap = a.rep.elts();
   if (&x == &a) {
      aCopy = a;
      ap = aCopy.rep.elts();
   }

   const ZZ_p* bp = b.rep.elts();
   if (&x == &b) {
      bCopy = b;
      bp = bCopy.rep.elts();
   }

   x.rep.SetLength(d+1);
   ZZ_p* xp = x.rep.elts();

   static ZZ t, accum;

   for (long i = 0; i <= d; i++) {
      // j ranges over the indices with 0 <= j <= da and 0 <= i-j <= db
      const long lo = max(0, i-db);
      const long hi = min(da, i);

      clear(accum);
      for (long j = lo; j <= hi; j++) {
         mul(t, rep(ap[j]), rep(bp[i-j]));
         add(accum, accum, t);
      }

      conv(xp[i], accum);
   }

   x.normalize();
}

// x = a * a by the quadratic schoolbook method, exploiting symmetry: for
// each output index only half of the cross products are computed and the
// partial sum is doubled; a middle square term is added when the number of
// terms is odd.  x may be the same object as a.
void PlainSqr(ZZ_pX& x, const ZZ_pX& a)
{
   const long da = deg(a);

   if (da < 0) {
      clear(x);
      return;
   }

   const long d = 2*da;

   // x.rep is resized below, so copy a first if it is x itself
   ZZ_pX aCopy;
   const ZZ_p* ap = a.rep.elts();
   if (&x == &a) {
      aCopy = a;
      ap = aCopy.rep.elts();
   }

   x.rep.SetLength(d+1);
   ZZ_p* xp = x.rep.elts();

   static ZZ t, accum;

   for (long i = 0; i <= d; i++) {
      const long lo = max(0, i-da);
      const long hi = min(da, i);
      const long terms = hi - lo + 1;       // number of products a[j]*a[i-j]
      const long half = terms >> 1;
      const long lastPaired = lo + half - 1;

      clear(accum);
      for (long j = lo; j <= lastPaired; j++) {
         mul(t, rep(ap[j]), rep(ap[i-j]));
         add(accum, accum, t);
      }

      // each paired product occurs twice in the full sum
      add(accum, accum, accum);

      if (terms & 1) {
         // odd count: the unpaired middle term is a square
         sqr(t, rep(ap[lastPaired + 1]));
         add(accum, accum, t);
      }

      conv(xp[i], accum);
   }

   x.normalize();
}

// Classical long division: computes quotient q and remainder r of a
// divided by b.  q receives deg(a)-deg(b)+1 coefficients and r at most
// deg(b) coefficients (normalized afterwards).
// Calls Error() when b is zero.  When deg(a) < deg(b), r = a and q = 0.
void PlainDivRem(ZZ_pX& q, ZZ_pX& r, const ZZ_pX& a, const ZZ_pX& b)
{
   long da, db, dq, i, j, LCIsOne;
   const ZZ_p *bp;
   ZZ_p *qp;
   ZZ *xp;


   ZZ_p LCInv, t;
   static ZZ s;   // scratch for products; static, so shared across calls

   da = deg(a);
   db = deg(b);

   if (db < 0) Error("ZZ_pX: division by zero");

   if (da < db) {
      r = a;
      clear(q);
      return;
   }

   ZZ_pX lb;

   // q.rep is resized below while b's coefficients are still being read,
   // so work from a copy of b when q and b are the same object
   if (&q == &b) {
      lb = b;
      bp = lb.rep.elts();
   }
   else
      bp = b.rep.elts();

   // skip the multiplication by the inverse when b is monic
   if (IsOne(bp[db]))
      LCIsOne = 1;
   else {
      LCIsOne = 0;
      inv(LCInv, bp[db]);
   }

   // working copy of a's coefficients as ZZ values
   // NOTE(review): entries are presumably pre-sized to
   // ExtendedModulusSize so the unreduced sums below fit -- confirm
   ZZVec x(da + 1, ZZ_pInfo->ExtendedModulusSize);

   for (i = 0; i <= da; i++)
      x[i] = rep(a.rep[i]);

   xp = x.elts();

   dq = da - db;
   q.rep.SetLength(dq+1);
   qp = q.rep.elts();

   // produce quotient coefficients from the highest degree downwards
   for (i = dq; i >= 0; i--) {
      // convert the accumulated ZZ value back to ZZ_p to get the
      // current leading coefficient
      conv(t, xp[i+db]);
      if (!LCIsOne)
	 mul(t, t, LCInv);
      qp[i] = t;
      negate(t, t);

      // add (-q[i]) * b[j] into the working coefficients; the sums are
      // kept as plain ZZ values here and converted only when read back
      for (j = db-1; j >= 0; j--) {
	 mul(s, rep(t), rep(bp[j]));
	 add(xp[i+j], xp[i+j], s);
      }
   }

   // the low db working coefficients form the remainder
   r.rep.SetLength(db);
   for (i = 0; i < db; i++)
      conv(r.rep[i], xp[i]);
   r.normalize();
}


// Classical long division computing only the remainder r of a divided by b.
// x is caller-supplied scratch space; entries x[0..deg(a)] are overwritten,
// so it must hold at least deg(a)+1 elements.
// Calls Error() when b is zero.  When deg(a) < deg(b), r = a.
void PlainRem(ZZ_pX& r, const ZZ_pX& a, const ZZ_pX& b, ZZVec& x)
{
   long da, db, dq, i, j, LCIsOne;
   const ZZ_p *bp;
   ZZ *xp;


   ZZ_p LCInv, t;
   static ZZ s;   // scratch for products; static, so shared across calls

   da = deg(a);
   db = deg(b);

   if (db < 0) Error("ZZ_pX: division by zero");

   if (da < db) {
      r = a;
      return;
   }

   bp = b.rep.elts();

   // skip the multiplication by the inverse when b is monic
   if (IsOne(bp[db]))
      LCIsOne = 1;
   else {
      LCIsOne = 0;
      inv(LCInv, bp[db]);
   }

   // load a's coefficients into the scratch vector as ZZ values
   for (i = 0; i <= da; i++)
      x[i] = rep(a.rep[i]);

   xp = x.elts();

   dq = da - db;

   // eliminate the coefficients of degree >= db, highest degree first
   for (i = dq; i >= 0; i--) {
      // convert the accumulated ZZ value back to ZZ_p to get the
      // current leading coefficient
      conv(t, xp[i+db]);
      if (!LCIsOne)
	 mul(t, t, LCInv);
      negate(t, t);

      // add (-t) * b[j] into the working coefficients; the sums are kept
      // as plain ZZ values here and converted only when read back
      for (j = db-1; j >= 0; j--) {
	 mul(s, rep(t), rep(bp[j]));
	 add(xp[i+j], xp[i+j], s);
      }
   }

   // the low db working coefficients form the remainder
   r.rep.SetLength(db);
   for (i = 0; i < db; i++)
      conv(r.rep[i], xp[i]);
   r.normalize();
}


void PlainDivRem(ZZ_pX& q, ZZ_pX& r, const ZZ_pX& a, const ZZ_pX& b, ZZVec& x)
{
   long da, db, dq, i, j, LCIsOne;
   const ZZ_p *bp;
   ZZ_p *qp;
   ZZ *xp;


   ZZ_p LCInv, t;
   static ZZ s;

   da = deg(a);
   db = deg(b);

   if (db < 0) Error("ZZ_pX: division by zero");

   if (da < db) {
      r = a;
      clear(q);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -