⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 zzx1.cpp

📁 数值算法库for Windows
💻 CPP
📖 第 1 页 / 共 3 页
字号:


#include <NTL/ZZX.h>

#include <NTL/new.h>

NTL_START_IMPL





// x = a reduced coefficient-wise modulo the current zz_p modulus.
// Reduction may turn leading coefficients into zero, so the result is
// normalized afterwards.
void conv(zz_pX& x, const ZZX& a)
{
   const vec_ZZ& src = a.rep;
   vec_zz_p& dst = x.rep;

   conv(dst, src);
   x.normalize();
}


// x = integer lift of each residue coefficient of a (via the element-wise
// vector conversion), followed by normalization.
void conv(ZZX& x, const zz_pX& a)
{
   const vec_zz_p& src = a.rep;
   vec_ZZ& dst = x.rep;

   conv(dst, src);
   x.normalize();
}


// Incremental Chinese remaindering of a ZZX against one more word-sized
// modulus p = zz_p::modulus().
//
// On entry gg's coefficients are interpreted modulo a, and G supplies
// residues modulo p. On return each coefficient of gg is congruent to its
// old value mod a and to the matching coefficient of G mod p, and a has
// been replaced by a*p. Positions where only one of gg / G has a
// coefficient treat the missing one as zero.
//
// Returns 1 if gg was changed in any way (a coefficient reduced or
// adjusted, or gg extended to G's length), 0 otherwise.
//
// NOTE(review): InvMod below requires a mod p to be invertible; callers
// presumably guarantee gcd(a, p) == 1.
long CRT(ZZX& gg, ZZ& a, const zz_pX& G)
{
   long n = gg.rep.length();

   long p = zz_p::modulus();

   // Product modulus; a itself is only overwritten at the very end
   // because the loops below still need the old value.
   ZZ new_a;
   mul(new_a, a, p);

   // a_inv = a^{-1} mod p
   long a_inv;
   a_inv = rem(a, p);
   a_inv = InvMod(a_inv, p);

   // p1 = floor(p/2): correction terms h are mapped into (-p/2, p/2]
   long p1;
   p1 = p >> 1;

   // a1 = floor(a/2): threshold for the symmetric residue mod a
   ZZ a1;
   RightShift(a1, a, 1);

   long p_odd = (p & 1);

   long modified = 0;

   long h;
   ZZ ah;

   long m = G.rep.length();

   long max_mn = max(m, n);

   // Any new slots (when G is longer than gg) start out as zero.
   gg.rep.SetLength(max_mn);

   ZZ g;
   long i;

   for (i = 0; i < n; i++) {
      // If CRTInRange rejects the old coefficient, bring it into the
      // symmetric range mod a first.
      if (!CRTInRange(gg.rep[i], a)) {
         modified = 1;
         rem(g, gg.rep[i], a);
         if (g > a1) sub(g, g, a);
      }
      else
         g = gg.rep[i];
   
      // h = (G[i] - g) * a^{-1} mod p, so that g + a*h == G[i] (mod p)
      // while g + a*h == g (mod a).
      h = rem(g, p);

      if (i < m)
         h = SubMod(rep(G.rep[i]), h, p);
      else
         h = NegateMod(h, p);

      h = MulMod(h, a_inv, p);
      if (h > p1)
         h = h - p;
   
      if (h != 0) {
         modified = 1;
         mul(ah, a, h);
   
         // For even p, h == p/2 and h == -p/2 are the same residue; when
         // g > 0, subtracting keeps the combined value closer to zero.
         if (!p_odd && g > 0 && (h == p1))
            sub(g, g, ah);
         else
            add(g, g, ah);
      }

      gg.rep[i] = g;
   }


   // Positions present only in G: the old coefficient is zero, so the
   // combined value is simply a*h with h = G[i] * a^{-1} mod p.
   for (; i < m; i++) {
      h = rep(G.rep[i]);
      h = MulMod(h, a_inv, p);
      if (h > p1)
         h = h - p;
   
      modified = 1;
      mul(g, a, h);
      gg.rep[i] = g;
   }

   gg.normalize();
   a = new_a;

   return modified;
}

// Incremental Chinese remaindering of a ZZX against one more (multi-
// precision) modulus p = ZZ_p::modulus(). Same algorithm as the zz_pX
// overload, with ZZ arithmetic in place of single-word arithmetic.
//
// On entry gg's coefficients are interpreted modulo a, and G supplies
// residues modulo p. On return each coefficient of gg is congruent to its
// old value mod a and to the matching coefficient of G mod p, and a has
// been replaced by a*p. Positions where only one of gg / G has a
// coefficient treat the missing one as zero.
//
// Returns 1 if gg was changed in any way, 0 otherwise.
//
// NOTE(review): InvMod below requires a mod p to be invertible; callers
// presumably guarantee gcd(a, p) == 1.
long CRT(ZZX& gg, ZZ& a, const ZZ_pX& G)
{
   long n = gg.rep.length();

   const ZZ& p = ZZ_p::modulus();

   // Product modulus; a itself is only overwritten at the very end.
   ZZ new_a;
   mul(new_a, a, p);

   // a_inv = a^{-1} mod p
   ZZ a_inv;
   rem(a_inv, a, p);
   InvMod(a_inv, a_inv, p);

   // p1 = floor(p/2): correction terms h are mapped into (-p/2, p/2]
   ZZ p1;
   RightShift(p1, p, 1);

   // a1 = floor(a/2): threshold for the symmetric residue mod a
   ZZ a1;
   RightShift(a1, a, 1);

   long p_odd = IsOdd(p);

   long modified = 0;

   ZZ h;
   ZZ ah;

   long m = G.rep.length();

   long max_mn = max(m, n);

   // Any new slots (when G is longer than gg) start out as zero.
   gg.rep.SetLength(max_mn);

   ZZ g;
   long i;

   for (i = 0; i < n; i++) {
      // If CRTInRange rejects the old coefficient, bring it into the
      // symmetric range mod a first.
      if (!CRTInRange(gg.rep[i], a)) {
         modified = 1;
         rem(g, gg.rep[i], a);
         if (g > a1) sub(g, g, a);
      }
      else
         g = gg.rep[i];
   
      // h = (G[i] - g) * a^{-1} mod p, so that g + a*h == G[i] (mod p)
      // while g + a*h == g (mod a).
      rem(h, g, p);

      if (i < m)
         SubMod(h, rep(G.rep[i]), h, p);
      else
         NegateMod(h, h, p);

      MulMod(h, h, a_inv, p);
      if (h > p1)
         sub(h, h, p);
   
      if (h != 0) {
         modified = 1;
         mul(ah, a, h);
   
         // For even p, h == p/2 and h == -p/2 are the same residue; when
         // g > 0, subtracting keeps the combined value closer to zero.
         if (!p_odd && g > 0 && (h == p1))
            sub(g, g, ah);
         else
            add(g, g, ah);
      }

      gg.rep[i] = g;
   }


   // Positions present only in G: the old coefficient is zero, so the
   // combined value is simply a*h with h = G[i] * a^{-1} mod p.
   for (; i < m; i++) {
      h = rep(G.rep[i]);
      MulMod(h, h, a_inv, p);
      if (h > p1)
         sub(h, h, p);
   
      modified = 1;
      mul(g, a, h);
      gg.rep[i] = g;
   }

   gg.normalize();
   a = new_a;

   return modified;
}




/* LeftRotate: a = b * 2^l mod p for p = 2^n + 1.
   Preconditions (not checked): 0 <= l <= n and 0 < b < p.

   Because 2^n == -1 (mod p), splitting b = hi * 2^(n-l) + lo gives
      b * 2^l == lo * 2^l - hi   (mod p),
   so the product is one shift, one truncation and one subtraction.
   a may alias b: b is fully read (into hi and a) before a is modified.

   NOTE: the scratch value is a function-local static, so this routine is
   not reentrant. */
static void LeftRotate(ZZ& a, const ZZ& b, long l, const ZZ& p, long n)
{
  if (l != 0) {
    long split = n - l;
    static ZZ hi;              /* scratch reused across calls */

    RightShift(hi, b, split);  /* hi = upper l bits of b        */
    trunc(a, b, split);        /* a  = lower n - l bits of b    */
    LeftShift(a, a, l);        /* a  = lo * 2^l                 */
    sub(a, a, hi);             /* a  = lo * 2^l - hi            */
    if (sign(a) < 0)
      add(a, a, p);
  }
  else if (&a != &b) {
    a = b;
  }
}


/* Compute a = b * 2^l mod p, where p = 2^n+1.
   0 <= b < p is assumed. l may be any integer, including negative values:
   since 2^n == -1 and hence 2^{2n} == 1 (mod p), l is first reduced into
   [0, 2n). a may alias b. */
static void Rotate(ZZ& a, const ZZ& b, long l, const ZZ& p, long n)
{
  /* LeftRotate requires b > 0, so handle zero separately. */
  if (IsZero(b)) {
    clear(a);
    return;
  }

  /* l %= 2n, without ever applying % to a negative operand (whose result
     sign was implementation-defined before C++11). */
  long two_n = n << 1;
  if (l >= 0) {
    l %= two_n;
  } else {
    l = two_n - 1 - (-(l + 1) % two_n);
  }

  /* a = b * 2^l mod p */
  if (l < n) {
    LeftRotate(a, b, l, p, n);
  } else {
    /* 2^l = 2^{l-n} * 2^n == -2^{l-n} (mod p): rotate, then negate. */
    LeftRotate(a, b, l - n, p, n);
    SubPos(a, p, a);
  }
}



/* Fast Fourier Transform, in place, modulo p = 2^n+1.

   a is a vector of length 2^l, 2^l divides 2n, and w = 2^r mod p is a
   primitive (2^l)th root of unity. Entries of a are assumed to lie in
   [0, p). On return a holds a(1),a(w),...,a(w^{2^l-1}) mod p in
   bit-reverse order, each entry again in [0, p).

   Every root of unity is a power of two, so multiplying by one is done
   with Rotate (shifts plus a subtraction) rather than a general product. */
static void fft(vec_ZZ& a, long r, long l, const ZZ& p, long n)
{
  long round;
  long off, i, j, e;
  long halfsize;
  ZZ tmp, tmp1;

  /* r is doubled after each round, so with r0 the value passed in, round k
     uses the twiddle base 2^{r0*2^k} = w^{2^k}. */
  for (round = 0; round < l; round++, r <<= 1) {
    halfsize =  1L << (l - 1 - round);
    /* 2^round blocks of length 2*halfsize: the inner loop advances off
       across the lower half of a block, and the outer increment then skips
       its upper half. */
    for (i = (1L << round) - 1, off = 0; i >= 0; i--, off += halfsize) {
      /* e = j*r is the exponent of 2 in the j-th twiddle factor. */
      for (j = 0, e = 0; j < halfsize; j++, off++, e+=r) {
	/* One butterfly:
	 ( a[off], a[off+halfsize] ) *= ( 1  w^{j2^round} )
	                                ( 1 -w^{j2^round} ) */
	/* tmp = a[off] - a[off + halfsize] mod p */
	sub(tmp, a[off], a[off + halfsize]);
	if (sign(tmp) < 0) {
	  add(tmp, tmp, p);
	}
	/* a[off] += a[off + halfsize] mod p; both summands are < p, so one
	   conditional subtraction suffices */
	add(a[off], a[off], a[off + halfsize]);
	sub(tmp1, a[off], p);
	if (sign(tmp1) >= 0) {
	  a[off] = tmp1;
	}
	/* a[off + halfsize] = tmp * w^{j2^round} mod p */
	Rotate(a[off + halfsize], tmp, e, p, n);
      }
    }
  }
}

/* Inverse FFT, in place, modulo p = 2^n+1. It takes the bit-reverse-ordered
   output of fft back to natural order. r and l must be the same as in the
   call to fft. The result is 2^l times the true inverse transform; SSMul
   and SSSqr divide that factor out afterwards. Entries are assumed to lie
   in [0, p) and are left in [0, p). */
static void ifft(vec_ZZ& a, long r, long l, const ZZ& p, long n)
{
  long round;
  long off, i, j, e;
  long halfsize;
  ZZ tmp, tmp1;

  /* Rounds run in the reverse order of fft: r starts at its last-round
     value r*2^{l-1} and is halved after each round. */
  for (round = l - 1, r <<= l - 1; round >= 0; round--, r >>= 1) {
    halfsize = 1L << (l - 1 - round);
    /* Same block/offset traversal as in fft. */
    for (i = (1L << round) - 1, off = 0; i >= 0; i--, off += halfsize) {
      for (j = 0, e = 0; j < halfsize; j++, off++, e+=r) {
	/* One inverse butterfly:
	 ( a[off], a[off+halfsize] ) *= ( 1               1             )
	                                ( w^{-j2^round}  -w^{-j2^round} ) */
	/* a[off + halfsize] *= w^{-j2^round} mod p (Rotate reduces the
	   negative exponent mod 2n) */
	Rotate(a[off + halfsize], a[off + halfsize], -e, p, n);
	/* tmp = a[off] - a[off + halfsize], reduced further below */
	sub(tmp, a[off], a[off + halfsize]);

	/* a[off] += a[off + halfsize] mod p */
	add(a[off], a[off], a[off + halfsize]);
	sub(tmp1, a[off], p);
	if (sign(tmp1) >= 0) {
	  a[off] = tmp1;
	}
	/* a[off+halfsize] = tmp mod p */
	if (sign(tmp) < 0) {
	  add(a[off+halfsize], tmp, p);
	} else {
	  a[off+halfsize] = tmp;
	}
      }
    }
  }
}



/* Multiplication a la Schoenhage & Strassen, modulo a "Fermat" number
   p = 2^{mr}+1, where m is a power of two and r >= 1 (the code below picks
   the minimal admissible r; it need not be odd). Since 2^{mr} == -1
   (mod p), w = 2^r is a primitive 2mth root of unity, i.e., polynomials
   whose product has degree less than 2m can be multiplied, provided that
   the coefficients of the product polynomial are less than 2^{mr-1} in
   absolute value. The algorithm is not called recursively; coefficient
   arithmetic is done directly with ZZ operations.

   Operands of degree <= 0 (including the zero polynomial) are handed to
   PlainMul. c may alias a or b: c is written only after a and b have
   been read for the last time. */

void SSMul(ZZX& c, const ZZX& a, const ZZX& b)
{
  long na = deg(a);
  long nb = deg(b);

  if (na <= 0 || nb <= 0) {
    PlainMul(c, a, b);
    return;
  }

  long n = na + nb; /* degree of the product */


  /* Choose m and r suitably */
  long l = NextPowerOfTwo(n + 1) - 1; /* 2^l <= n < 2^{l+1} */
  long m2 = 1L << (l + 1); /* m2 = 2m = 2^{l+1} */
  /* Bitlength of the product: if the coefficients of a are absolutely less
     than 2^ka and the coefficients of b are absolutely less than 2^kb, then
     the coefficients of ab are absolutely less than
     (min(na,nb)+1)2^{ka+kb} <= 2^bound. */
  long bound = 2 + NumBits(min(na, nb)) + MaxBits(a) + MaxBits(b);
  /* Let r be minimal so that mr > bound */
  long r = (bound >> l) + 1;
  long mr = r << l;

  /* p := 2^{mr}+1 */
  ZZ p;
  set(p);
  LeftShift(p, p, mr);
  add(p, p, 1);

  /* Lift coefficients of a and b into [0, p) (negative ones get + p).
     SetLength zero-pads both to length 2m, and n < 2m, so the cyclic
     convolution computed below does not wrap around. */
  vec_ZZ aa, bb;
  aa.SetLength(m2);
  bb.SetLength(m2);

  long i;
  for (i = 0; i <= deg(a); i++) {
    if (sign(a.rep[i]) >= 0) {
      aa[i] = a.rep[i];
    } else {
      add(aa[i], a.rep[i], p);
    }
  }

  for (i = 0; i <= deg(b); i++) {
    if (sign(b.rep[i]) >= 0) {
      bb[i] = b.rep[i];
    } else {
      add(bb[i], b.rep[i], p);
    }
  }

  /* 2m-point FFT's mod p */
  fft(aa, r, l + 1, p, mr);
  fft(bb, r, l + 1, p, mr);

  /* Pointwise multiplication aa := aa * bb mod p. The reduction uses
     2^{mr} == -1: x = hi*2^{mr} + lo  ==>  x == lo - hi (mod p). */
  ZZ tmp, ai;
  for (i = 0; i < m2; i++) {
    mul(ai, aa[i], bb[i]);
    if (NumBits(ai) > mr) {
      RightShift(tmp, ai, mr);
      trunc(ai, ai, mr);
      sub(ai, ai, tmp);
      if (sign(ai) < 0) {
	add(ai, ai, p);
      }
    }
    aa[i] = ai;
  }
  
  /* Inverse transform; the result is 2m times too large. */
  ifft(aa, r, l + 1, p, mr);

  /* Retrieve c, dividing by 2m, and subtracting p where necessary */
  c.rep.SetLength(n + 1);
  for (i = 0; i <= n; i++) {
    ai = aa[i];
    ZZ& ci = c.rep[i];
    if (!IsZero(ai)) {
      /* ci = -ai * 2^{mr-l-1} = ai * 2^{-l-1} = ai / 2m mod p */
      LeftRotate(ai, ai, mr - l - 1, p, mr);
      sub(tmp, p, ai);
      if (NumBits(tmp) >= mr) { /* ci >= (p-1)/2 */
	negate(ci, ai); /* ci = -ai = ci - p */
      }
      else
        ci = tmp;
    } 
    else
       clear(ci);
  }
}


// SSRatio computes how much bigger the Schoenhage-Strassen modulus must be
// than the coefficient bound, in order to accommodate the necessary roots
// of unity. This is useful in determining algorithm crossover points.

// Returns (bits of the SS modulus 2^{mr}+1) / (required coefficient bit
// bound) for operands of degrees na, nb whose coefficients have at most
// maxa, maxb bits. The parameter choice mirrors SSMul exactly. Returns 0
// when either operand has degree <= 0 (SSMul does not use FFTs there).
double SSRatio(long na, long maxa, long nb, long maxb)
{
  if (na > 0 && nb > 0) {
    long deg_prod = na + nb;
    long lg = NextPowerOfTwo(deg_prod + 1) - 1;   /* 2^lg <= deg_prod < 2^{lg+1} */
    long needed = NumBits(min(na, nb)) + maxa + maxb + 2;
    long modbits = ((needed >> lg) + 1) << lg;    /* minimal multiple of 2^lg above needed */

    return double(modbits + 1)/double(needed);
  }

  return 0;
}

// Multi-modular ("homomorphic") multiplication: x = a * b.
//
// Chooses enough FFT primes that their product prod is at least 2^bound,
// where bound bounds the bit length of the product's coefficients (with
// slack for the sign). It multiplies a and b as zz_pX modulo each prime and
// recombines the results coefficient-wise by the Chinese Remainder
// Theorem, taking the representative in the symmetric range
// (-prod/2, prod/2].
//
// The current zz_p modulus is saved on entry and restored before return.
// If a and b are the same object, HomSqr is used instead.
void HomMul(ZZX& x, const ZZX& a, const ZZX& b)
{
   if (&a == &b) {
      HomSqr(x, a);
      return;
   }

   long da = deg(a);
   long db = deg(b);

   if (da < 0 || db < 0) {
      clear(x);
      return;
   }

   long bound = 2 + NumBits(min(da, db)+1) + MaxBits(a) + MaxBits(b);


   ZZ prod;
   set(prod);

   long i, nprimes;

   // FFTInit below changes the global zz_p modulus; remember the caller's.
   zz_pBak bak;
   bak.save();

   // Accumulate primes until prod >= 2^bound.
   for (nprimes = 0; NumBits(prod) <= bound; nprimes++) {
      // NOTE(review): presumably makes FFTPrime[nprimes] available when it
      // has not been generated yet -- confirm against the lzz_p docs.
      if (nprimes >= NumFFTPrimes)
         zz_p::FFTInit(nprimes);
      mul(prod, prod, FFTPrime[nprimes]);
   }


   ZZ coeff;
   ZZ t1;
   long tt;

   // c[j] accumulates sum_i coeff_i * (product mod p_i)[j].
   vec_ZZ c;

   c.SetLength(da+db+1);

   long j;

   for (i = 0; i < nprimes; i++) {
      zz_p::FFTInit(i);
      long p = zz_p::modulus();

      // coeff = (prod/p) * ((prod/p)^{-1} mod p): the CRT basis element
      // that is 1 mod p and 0 mod every other prime in prod.
      div(t1, prod, p);
      tt = rem(t1, p);
      tt = InvMod(tt, p);
      mul(coeff, t1, tt);

      zz_pX A, B, C;

      conv(A, a);
      conv(B, b);
      mul(C, A, B);

      long m = deg(C);

      for (j = 0; j <= m; j++) {
         /* c[j] += coeff*rep(C.rep[j]) */
         mul(t1, coeff, rep(C.rep[j]));
         add(c[j], c[j], t1); 
      }
   }

   // x may alias a or b; they are not read after this point.
   x.rep.SetLength(da+db+1);

   ZZ prod2;
   RightShift(prod2, prod, 1);

   // Reduce mod prod and map into the symmetric range.
   for (j = 0; j <= da+db; j++) {
      rem(t1, c[j], prod);

      if (t1 > prod2)
         sub(x.rep[j], t1, prod);
      else
         x.rep[j] = t1;
   }

   x.normalize();

   bak.restore();
}

// Largest ZZ::size() value among the coefficients of a; 0 when a has no
// coefficients. Used by mul/sqr to choose an algorithm.
static
long MaxSize(const ZZX& a)
{
   long best = 0;

   for (long k = a.rep.length() - 1; k >= 0; k--) {
      long sz = a.rep[k].size();
      if (sz > best)
         best = sz;
   }

   return best;
}


// Polynomial product c = a * b over ZZ.
// Picks an algorithm from the operand shape. min_len is the length of the
// shorter operand and min_size is the smaller of the two largest
// coefficient sizes (per MaxSize). Small cases use schoolbook
// multiplication, medium ones Karatsuba, and large ones
// Schoenhage-Strassen or multi-modular multiplication.
void mul(ZZX& c, const ZZX& a, const ZZX& b)
{
   if (IsZero(a) || IsZero(b)) {
      clear(c);
      return;
   }

   // Same object on both sides: squaring is cheaper.
   if (&a == &b) {
      sqr(c, a);
      return;
   }

   long maxa = MaxSize(a);
   long maxb = MaxSize(b);

   long min_size = min(maxa, maxb);
   long min_len = min(deg(a), deg(b)) + 1;

   bool use_plain =
         min_len == 1 ||
         (min_size == 1 && min_len < 40) ||
         (min_size == 2 && min_len < 20) ||
         (min_size == 3 && min_len < 10);

   bool use_karatsuba =
         min_len < 80 || (min_size < 30 && min_len < 150);

   if (use_plain)
      PlainMul(c, a, b);
   else if (use_karatsuba)
      KarMul(c, a, b);
   else if (maxa + maxb >= 40 &&
            SSRatio(deg(a), MaxBits(a), deg(b), MaxBits(b)) < 1.75)
      SSMul(c, a, b);
   else
      HomMul(c, a, b);
}


/* Squaring a la Schoenhage & Strassen: the one-operand counterpart of
   SSMul. With n = 2*deg(a), 2m = 2^{l+1} > n and p = 2^{mr}+1, the
   coefficients of a are lifted to [0, p) and transformed by a 2m-point FFT
   (root of unity w = 2^r). They are then squared pointwise mod p,
   transformed back, divided by 2m and mapped to the symmetric range mod p.
   mr is chosen larger than a bit bound on the coefficients of a^2 (plus
   slack), so the symmetric residues are the true coefficients.
   Polynomials of degree <= 0 are handed to PlainSqr. */
void SSSqr(ZZX& c, const ZZX& a)
{
  long na = deg(a);
  if (na <= 0) {
    PlainSqr(c, a);
    return;
  }

  long n = na + na; /* degree of the product */

  long l = NextPowerOfTwo(n + 1) - 1; /* 2^l <= n < 2^{l+1} */
  long m2 = 1L << (l + 1); /* m2 = 2m = 2^{l+1} */
  /* |coefficients of a^2| < (na+1) 2^{2 MaxBits(a)} <= 2^{bound-2} */
  long bound = 2 + NumBits(na) + 2*MaxBits(a);
  /* r minimal such that mr > bound */
  long r = (bound >> l) + 1;
  long mr = r << l;

  /* p := 2^{mr}+1 */
  ZZ p;
  set(p);
  LeftShift(p, p, mr);
  add(p, p, 1);

  /* Lift the coefficients of a into [0, p). SetLength zero-pads to length
     2m, and n < 2m, so the cyclic convolution does not wrap around. */
  vec_ZZ aa;
  aa.SetLength(m2);

  long i;
  for (i = 0; i <= na; i++) {
    if (sign(a.rep[i]) >= 0) {
      aa[i] = a.rep[i];
    } else {
      add(aa[i], a.rep[i], p);
    }
  }

  /* 2m-point FFT mod p */
  fft(aa, r, l + 1, p, mr);

  /* Pointwise squaring aa := aa * aa mod p. The reduction uses
     2^{mr} == -1: x = hi*2^{mr} + lo  ==>  x == lo - hi (mod p). */
  ZZ tmp, ai;
  for (i = 0; i < m2; i++) {
    sqr(ai, aa[i]);
    if (NumBits(ai) > mr) {
      RightShift(tmp, ai, mr);
      trunc(ai, ai, mr);
      sub(ai, ai, tmp);
      if (sign(ai) < 0) {
        add(ai, ai, p);
      }
    }
    aa[i] = ai;
  }

  /* Inverse transform; the result is 2m times too large. */
  ifft(aa, r, l + 1, p, mr);

  /* Retrieve c, dividing by 2m, and subtracting p where necessary.
     c may alias a; a is not read from here on. */
  c.rep.SetLength(n + 1);

  for (i = 0; i <= n; i++) {
    ai = aa[i];
    ZZ& ci = c.rep[i];
    if (!IsZero(ai)) {
      /* ci = -ai * 2^{mr-l-1} = ai * 2^{-l-1} = ai / 2m mod p */
      LeftRotate(ai, ai, mr - l - 1, p, mr);
      sub(tmp, p, ai);
      if (NumBits(tmp) >= mr) { /* tmp >= (p-1)/2: take tmp - p = -ai */
        negate(ci, ai);
      }
      else
        ci = tmp;
    }
    else
      clear(ci);
  }
}

// Multi-modular ("homomorphic") squaring: x = a^2.
//
// Chooses enough FFT primes that their product prod is at least 2^bound,
// where bound bounds the bit length of the coefficients of a^2 (with slack
// for the sign). It squares a as a zz_pX modulo each prime and recombines
// the results coefficient-wise by the Chinese Remainder Theorem, taking the
// representative in the symmetric range (-prod/2, prod/2].
//
// The current zz_p modulus is saved on entry and restored before return.
void HomSqr(ZZX& x, const ZZX& a)
{

   long da = deg(a);

   if (da < 0) {
      clear(x);
      return;
   }

   long bound = 2 + NumBits(da+1) + 2*MaxBits(a);


   ZZ prod;
   set(prod);

   long i, nprimes;

   // FFTInit below changes the global zz_p modulus; remember the caller's.
   zz_pBak bak;
   bak.save();

   // Accumulate primes until prod >= 2^bound.
   for (nprimes = 0; NumBits(prod) <= bound; nprimes++) {
      // NOTE(review): presumably makes FFTPrime[nprimes] available when it
      // has not been generated yet -- confirm against the lzz_p docs.
      if (nprimes >= NumFFTPrimes)
         zz_p::FFTInit(nprimes);
      mul(prod, prod, FFTPrime[nprimes]);
   }


   ZZ coeff;
   ZZ t1;
   long tt;

   // c[j] accumulates sum_i coeff_i * (square mod p_i)[j].
   vec_ZZ c;

   c.SetLength(da+da+1);

   long j;

   for (i = 0; i < nprimes; i++) {
      zz_p::FFTInit(i);
      long p = zz_p::modulus();

      // coeff = (prod/p) * ((prod/p)^{-1} mod p): the CRT basis element
      // that is 1 mod p and 0 mod every other prime in prod.
      div(t1, prod, p);
      tt = rem(t1, p);
      tt = InvMod(tt, p);
      mul(coeff, t1, tt);

      zz_pX A, C;

      conv(A, a);
      sqr(C, A);

      long m = deg(C);

      for (j = 0; j <= m; j++) {
         /* c[j] += coeff*rep(C.rep[j]) */
         mul(t1, coeff, rep(C.rep[j]));
         add(c[j], c[j], t1); 
      }
   }

   // x may alias a; a is not read after this point.
   x.rep.SetLength(da+da+1);

   ZZ prod2;
   RightShift(prod2, prod, 1);

   // Reduce mod prod and map into the symmetric range.
   for (j = 0; j <= da+da; j++) {
      rem(t1, c[j], prod);

      if (t1 > prod2)
         sub(x.rep[j], t1, prod);
      else
         x.rep[j] = t1;
   }

   x.normalize();

   bak.restore();
}


// Polynomial square c = a^2 over ZZ.
// Picks an algorithm from the number of coefficients (len) and the largest
// coefficient size (maxa, per MaxSize). Small cases use schoolbook
// squaring, medium ones Karatsuba, and large ones Schoenhage-Strassen or
// multi-modular squaring.
void sqr(ZZX& c, const ZZX& a)
{
   if (IsZero(a)) {
      clear(c);
      return;
   }

   long maxa = MaxSize(a);
   long len = deg(a) + 1;

   bool use_plain =
         len == 1 ||
         (maxa == 1 && len < 50) ||
         (maxa == 2 && len < 25) ||
         (maxa == 3 && len < 25) ||
         (maxa == 4 && len < 10);

   if (use_plain) {
      PlainSqr(c, a);
      return;
   }

   if (len < 80 || (maxa < 30 && len < 150)) {
      KarSqr(c, a);
      return;
   }

   long mba = MaxBits(a);

   if (2*maxa >= 40 && SSRatio(deg(a), mba, deg(a), mba) < 1.75)
      SSSqr(c, a);
   else
      HomSqr(c, a);
}


// x = a * b for an integer scalar b; a zero scalar gives the zero
// polynomial.
// b is copied up front, presumably so the result stays correct if b refers
// to a coefficient of x, which is resized and overwritten below.
void mul(ZZX& x, const ZZX& a, const ZZ& b)
{
   if (IsZero(b)) {
      clear(x);
      return;
   }

   ZZ scalar(b);
   long len = deg(a) + 1;
   x.rep.SetLength(len);

   // Raw pointers are taken after the resize, so they are valid even if
   // SetLength reallocated x.rep.
   const ZZ *src = a.rep.elts();
   ZZ *dst = x.rep.elts();

   for (long k = 0; k < len; k++)
      mul(dst[k], src[k], scalar);
}

void mul(ZZX& x, const ZZX& a, long b)
{
   long i, da;

   const ZZ *ap;
   ZZ* xp;

   if (b == 0) {
      clear(x);
      return;
   }

   da = deg(a);
   x.rep.SetLength(da+1);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -