⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 fft.cpp

📁 数值算法库for Windows
💻 CPP
字号:


#include <NTL/FFT.h>

#include <NTL/new.h>

NTL_START_IMPL

// Number of FFT primes set up so far by UseFFTPrime; entries
// [0, NumFFTPrimes) of the tables below are valid.
long NumFFTPrimes = 0;

// Per-prime tables, allocated and grown by UseFFTPrime (in chunks of 100
// entries).  For prime number i, with q = FFTPrime[i] and
// mr = CalcMaxRoot(q):
//   FFTPrime[i]         -- the prime q itself
//   RootTable[i][j]     -- a 2^j-th root of unity mod q, 0 <= j <= mr, with
//                          RootTable[i][j] == RootTable[i][j+1]^2 mod q
//   RootInvTable[i][j]  -- the inverse mod q of RootTable[i][j]
//   TwoInvTable[i][j]   -- (1/2)^j mod q, 0 <= j <= mr
//   FFTPrimeInv[i]      -- 1/double(q); NOTE(review): presumably the
//                          precomputed inverse handed to MulMod-style helpers
long *FFTPrime = 0;
long **RootTable = 0;
long **RootInvTable = 0;
long **TwoInvTable = 0;
double *FFTPrimeInv = 0;


// Decides whether n is a prime usable for the FFT tables.  If it is,
// returns 1 and stores in w a root of unity mod n; returns 0 as soon as n
// is found to be composite.
//
// Writing n - 1 = m * 2^k with m odd, the routine searches for a random x
// such that x^m has order exactly 2^k mod n.  The root returned in w is
// that x^m, squared (k - NTL_FFTMaxRoot) extra times when k exceeds
// NTL_FFTMaxRoot, i.e. a root of unity of order 2^min(k, NTL_FFTMaxRoot).
//
// The random search doubles as a Miller-Rabin style compositeness test.
// When m < 2^k (m >> k == 0), having x^((n-1)/2) == -1 mod n already proves
// n prime (Proth's theorem).  Otherwise a ProbPrime filter is followed by
// trial division by the only possible factors, a*2^k + 1.
//
// n is expected to be odd and large: for even n (k == 0) the squaring
// loop's `j != k` exit never fires, and 3, 5, 7 themselves are rejected
// by the divisibility checks.  NextFFTPrime only passes such candidates.
static
long IsFFTPrime(long n, long& w)
{
   long  m, x, y, z;
   long j, k;

   // Cheap rejection of multiples of small primes.
   if (n % 3 == 0) return 0;

   if (n % 5 == 0) return 0;

   if (n % 7 == 0) return 0;
   
   // Split n - 1 = m * 2^k with m odd.
   m = n - 1;
   k = 0;
   while ((m & 1) == 0) {
      m = m >> 1;
      k++;
   }

   for (;;) {
      x = RandomBnd(n);   // NOTE(review): presumably uniform in [0, n)

      if (x == 0) continue;
      z = PowerMod(x, m, n);
      if (z == 1) continue;   // x^m == 1 carries no information; retry

      x = z;   // remember the candidate root x^m
      j = 0;
      // Square until reaching 1 or having squared k times; y keeps the
      // value before the last squaring.
      do {
         y = z;
         z = MulMod(y, y, n);
         j++;
      } while (j != k && z != 1);

      // For prime n, x^(n-1) == 1 (so z must reach 1 within k squarings)
      // and the only square root of 1 other than 1 is -1 == n-1.
      if (z != 1 || y !=  n-1) return 0;

      // j == k: x^m has order exactly 2^k.  Otherwise its order is a
      // smaller power of two, so try another x.
      if (j == k) 
         break;
   }

   /* x^{2^k} = 1 mod n, x^{2^{k-1}} = -1 mod n */

   long TrialBound;

   TrialBound = m >> k;
   if (TrialBound > 0) {
      // m >= 2^k: the witness alone does not prove primality.
      if (!ProbPrime(n, 5)) return 0;
   
      /* we have to do trial division by special numbers */
      // x has order 2^k modulo every prime factor p of n, so p == 1 mod 2^k.
      // If n is composite, its smallest factor is a*2^k + 1 with
      // a*a <= m >> k, so testing those divisors completes the proof.
   
      TrialBound = SqrRoot(TrialBound);
   
      long a, b;
   
      for (a = 1; a <= TrialBound; a++) {
         b = (a << k) + 1;
         if (n % b == 0) return 0; 
      }
   }

   /* n is an FFT prime */

   // Reduce the root's order from 2^k to 2^NTL_FFTMaxRoot if k is larger.
   for (j = NTL_FFTMaxRoot; j < k; j++)
      x = MulMod(x, x, n);

   w = x;
   return 1;
}


static
void NextFFTPrime(long& q, long& w)
{
   static long m = NTL_FFTMaxRootBnd + 1;
   static long k = 0;

   long t, cand;

   for (;;) {
      if (k == 0) {
         m--;
         if (m < 5) Error("ran out of FFT primes");
         k = 1L << (NTL_SP_NBITS-m-2);
      }

      k--;

      cand = (1L << (NTL_SP_NBITS-1)) + (k << (m+1)) + (1L << m) + 1;

      if (!IsFFTPrime(cand, t)) continue;
      q = cand;
      w = t;
      return;
   }
}


// Returns the exponent e of the largest power of two 2^e dividing p - 1,
// capped at NTL_FFTMaxRoot.  For an FFT prime p this is the largest e for
// which UseFFTPrime tabulates a 2^e-th root of unity mod p.
//
// The count stops as soon as the cap is reached.  Apart from saving work,
// this guarantees termination for p == 1: p - 1 == 0 is divisible by every
// power of two, and an uncapped count would loop forever.  p == 1 therefore
// yields NTL_FFTMaxRoot.
long CalcMaxRoot(long p)
{
   p = p-1;
   long k = 0;
   while ((p & 1) == 0 && k < NTL_FFTMaxRoot) {
      p = p >> 1;
      k++;
   }

   return k;
}


// Makes FFT prime number `index` available, generating it on demand.
//
// Primes are produced strictly in sequence: `index` must lie in
// [0, NumFFTPrimes].
//   - If index < NumFFTPrimes, the prime already exists and nothing is done.
//   - If index == NumFFTPrimes, the next prime q and its root of unity w are
//     obtained from NextFFTPrime and appended.
//   - Any other index is reported through Error.
//
// With mr = CalcMaxRoot(q), the new entries are
//   FFTPrime[index]            = q
//   RootTable[index][0..mr]    : [mr] = w, each earlier entry the square
//                                mod q of the next one
//   RootInvTable[index][0..mr] : the same chain started from w^{-1} mod q,
//                                i.e. the inverses of the RootTable entries
//   TwoInvTable[index][0..mr]  : (2^{-1})^j mod q
//   FFTPrimeInv[index]         = 1/double(q)
// after which NumFFTPrimes is incremented.
//
// The top-level tables grow in chunks of FFTTableChunk entries; the
// per-prime rows are malloc'ed.  Allocation failure is reported with
// Error("out of space"), and the existing tables are left intact.
void UseFFTPrime(long index)
{
   // Growth increment for the top-level tables.
   const long FFTTableChunk = 100;

   if (index < 0 || index > NumFFTPrimes)
      Error("invalid FFT prime index");

   if (index < NumFFTPrimes) return;

   long q, w;

   NextFFTPrime(q, w);

   long mr = CalcMaxRoot(q);

   if (index % FFTTableChunk == 0) {
      // Grow every table to index + FFTTableChunk entries.  Each result goes
      // through a temporary, so a failed realloc leaves the old (still valid)
      // table in place instead of overwriting the global with a null pointer.
      // On the first call the pointers are still 0 and realloc acts as malloc.
      long newsize = index + FFTTableChunk;

      long *np = (long *) realloc(FFTPrime, sizeof(long)*newsize);
      if (np) FFTPrime = np;

      long **nrt = (long **) realloc(RootTable, sizeof(long *)*newsize);
      if (nrt) RootTable = nrt;

      long **nrit = (long **) realloc(RootInvTable, sizeof(long *)*newsize);
      if (nrit) RootInvTable = nrit;

      long **ntit = (long **) realloc(TwoInvTable, sizeof(long *)*newsize);
      if (ntit) TwoInvTable = ntit;

      double *npinv = (double *) realloc(FFTPrimeInv, sizeof(double)*newsize);
      if (npinv) FFTPrimeInv = npinv;

      if (!np || !nrt || !nrit || !ntit || !npinv) {
         Error("out of space");
         return;   // never index past a table that failed to grow
      }
   }

   // Per-prime rows: allocate all three before publishing any of them, so
   // a failure leaves neither dangling rows nor a half-initialized entry.
   long *rt = (long *) malloc(sizeof(long)*(mr+1));
   long *rit = (long *) malloc(sizeof(long)*(mr+1));
   long *tit = (long *) malloc(sizeof(long)*(mr+1));
   if (!rt || !rit || !tit) {
      free(rt);
      free(rit);
      free(tit);
      Error("out of space");
      return;
   }

   FFTPrime[index] = q;
   RootTable[index] = rt;
   RootInvTable[index] = rit;
   TwoInvTable[index] = tit;

   long j;

   // Roots of unity: rt[mr] = w, and each lower entry squares the next.
   rt[mr] = w;
   for (j = mr-1; j >= 0; j--)
      rt[j] = MulMod(rt[j+1], rt[j+1], q);

   // The same chain started from w^{-1}; rit[j] is the inverse of rt[j].
   rit[mr] = InvMod(w, q);
   for (j = mr-1; j >= 0; j--)
      rit[j] = MulMod(rit[j+1], rit[j+1], q);

   // Powers of 1/2 mod q: tit[j] = 2^{-j} mod q.
   long half = InvMod(2, q);
   tit[0] = 1;
   for (j = 1; j <= mr; j++)
      tit[j] = MulMod(tit[j-1], half, q);

   FFTPrimeInv[index] = 1/double(q);

   NumFFTPrimes++;
}
   

// Advances a k-bit bit-reversed counter: returns rev_k(rev_k(a) + 1), where
// rev_k reverses the low k bits.  Starting at 0 and applying RevInc
// repeatedly enumerates the bit-reversal permutation of 0 .. 2^k - 1
// (for k == 3: 0, 4, 2, 6, 1, 5, 3, 7).  The all-ones value wraps to 0.
//
// For k <= 0 there are no bits to increment, and a is returned unchanged.
// Returning early also avoids evaluating 1L << (k-1) with a negative shift
// count, which is undefined behavior.
long RevInc(long a, long k)
{
   if (k <= 0) return a;

   long j = k;
   long m = 1L << (k-1);   // most significant of the k counter bits

   // The "carry" runs from the top bit downward: clear each set bit ...
   while (j && (m & a)) {
      a ^= m;
      m >>= 1;
      j--;
   }
   // ... then set the first clear one, unless all k bits were set (wrap).
   if (j) a ^= m;
   return a;
}

void BitReverseCopy(long *A, const long *a, long k)

{
   static long* mem[NTL_FFTMaxRoot+1];

   long n = 1L << k;
   long* rev;
   long i, j;

   rev = mem[k];
   if (!rev) {
      rev = mem[k] = NTL_NEW_OP long[n];
      if (!rev) Error("out of memory in BitReverseCopy");
      for (i = 0, j = 0; i < n; i++, j = RevInc(j, k))
         rev[i] = j;
   }

   for (i = 0; i < n; i++)
      A[rev[i]] = a[i];
}


#ifdef NTL_FFT_PIPELINE

/*****************************************************

   This version of the FFT is written with an explicit
   "software pipeline", which sometimes speeds things up.

*******************************************************/

void FFT(long* A, const long* a, long k, long q, const long* root)

// Computes the 2^k-point DFT (number-theoretic transform) of a[0..2^k-1]
// modulo q and writes it to A[0..2^k-1]:
//   A[i] = sum_t a[t] * root[k]^(i*t)   (mod q),
// assuming root[s] is a primitive 2^s-th root of unity mod q with
// root[s-1] == root[s]^2.  The RootTable / RootInvTable rows built by
// UseFFTPrime have this shape.  NOTE(review): confirm every caller passes
// such a row.  Only root[2..k] are read.
//
// Method: a is scattered into A in bit-reversed order (BitReverseCopy),
// then k radix-2 butterfly stages run in place, leaving A in natural order.
// This variant is "software pipelined": in the twiddled loops the next
// product (tt) is computed before the current butterfly is stored.
//
// NOTE(review): the k == 1 branch reads a[0] after writing A[0], so it is
// not safe for A == a.  Inputs are presumably expected to be reduced mod q
// already -- TODO confirm against AddMod/SubMod.

{
   if (k == 0) {   // 1-point transform is the identity
      A[0] = a[0];
      return;
   }

   if (k == 1) {   // 2-point transform: one butterfly, no twiddle
      A[0] = AddMod(a[0], a[1], q);
      A[1] = SubMod(a[0], a[1], q);
      return;
   }

   // assume k > 1

   long n = 1L << k;
   long s, m, m2, j;
   long t, u, v, w, z, tt;
   long *p1, *p2, *ub, *ub1;
   double qinv = ((double) 1)/((double) q);   // for the 4-argument MulMod
   double wqinv, zqinv;                       // w/q and z/q, for MulMod2

   BitReverseCopy(A, a, k);

   ub = A+n;   // one past the end of A

   // Stage 1 (blocks of 2): twiddle is 1, plain add/subtract.
   p2 = A;
   while (p2 < ub) {
      u = *p2;
      v = *(p2+1);
      *p2 = AddMod(u, v, q);
      *(p2+1) = SubMod(u, v, q);
      p2 += 2;
   }

   // Stages s = 2 .. k-1: blocks of m = 2^s, butterflies span m2 = m/2.
   for (s = 2; s < k; s++) {
      m = 1L << s;
      m2 = m >> 1;

      // Butterfly j == 0 of every block: twiddle 1, no multiplication.
      p2 = A;
      p1 = p2 + m2;
      while (p2 < ub) {
         u = *p2;
         v = *p1;
         *p2 = AddMod(u, v, q);
         *p1 = SubMod(u, v, q);
         p1 += m;
         p2 += m;
      }

      // Butterflies j = 1 .. m2-1 use twiddle w = z^j, z = root[s].
      z = root[s];
      w = z;
      for (j = 1; j < m2; j++) {
         wqinv = ((double) w)*qinv;
         p2 = A + j;
         p1 = p2 + m2;

         ub1 = ub-m;   // the last block is finished after the loop

         // Pipeline prologue: load the first pair and its twiddled product.
         u = *p2;
         t = MulMod2(*p1, w, q, wqinv);

         while (p2 < ub1) {
            // Next block's product is computed before this block is stored.
            tt = MulMod2(*(p1+m), w, q, wqinv);
            *p2 = AddMod(u, t, q);
            *p1 = SubMod(u, t, q);
            p1 += m;
            p2 += m;
            u = *p2;
            t = tt;
         }
         // Pipeline epilogue: butterfly for the last block.
         *p2 = AddMod(u, t, q);
         *p1 = SubMod(u, t, q);
         
         w = MulMod2(z, w, q, wqinv);   // w = z^(j+1)
      }
   }

   // Final stage s = k: one block of size n with twiddles root[k]^j,
   // j = 0 .. n/2-1.  The loop stores all but the last pair, whose
   // butterfly is finished after it.
   m2 = n >> 1;
   z = root[k];
   zqinv = ((double) z)*qinv;
   w = 1;
   p2 = A;
   p1 = A + m2;
   m2--;
   u = *p2;
   t = *p1;   // j == 0: twiddle is 1
   while (m2) {
      w = MulMod2(w, z, q, zqinv);          // twiddle for the next pair
      tt = MulMod(*(p1+1), w, q, qinv);     // next product, computed early
      *p2 = AddMod(u, t, q);
      *p1 = SubMod(u, t, q);
      p2++;
      p1++;
      u = *p2;
      t = tt;
      m2--;
   }
   *p2 = AddMod(u, t, q);
   *p1 = SubMod(u, t, q);
}



#else


/*****************************************************

   This version of the FFT has no "software pipeline".

******************************************************/



void FFT(long* A, const long* a, long k, long q, const long* root)

// Computes the 2^k-point DFT (number-theoretic transform) of a[0..2^k-1]
// modulo q and writes it to A[0..2^k-1]:
//   A[i] = sum_t a[t] * root[k]^(i*t)   (mod q),
// assuming root[s] is a primitive 2^s-th root of unity mod q with
// root[s-1] == root[s]^2.  The RootTable / RootInvTable rows built by
// UseFFTPrime have this shape.  NOTE(review): confirm every caller passes
// such a row.  Only root[2..k] are read.
//
// Method: a is scattered into A in bit-reversed order (BitReverseCopy),
// then k radix-2 butterfly stages run in place, leaving A in natural order.
//
// NOTE(review): the k == 1 branch reads a[0] after writing A[0], so it is
// not safe for A == a.  Inputs are presumably expected to be reduced mod q
// already -- TODO confirm against AddMod/SubMod.

{
   if (k == 0) {   // 1-point transform is the identity
      A[0] = a[0];
      return;
   }

   if (k == 1) {   // 2-point transform: one butterfly, no twiddle
      A[0] = AddMod(a[0], a[1], q);
      A[1] = SubMod(a[0], a[1], q);
      return;
   }

   // assume k > 1

   long n = 1L << k;
   long s, m, m2, j;
   long t, u, v, w, z;
   long *p1, *p2, *ub;
   double qinv = ((double) 1)/((double) q);   // for the 4-argument MulMod
   double wqinv, zqinv;                       // w/q and z/q, for MulMod2

   BitReverseCopy(A, a, k);

   ub = A+n;   // one past the end of A

   // Stage 1 (blocks of 2): twiddle is 1, plain add/subtract.
   p2 = A;
   while (p2 < ub) {
      u = *p2;
      v = *(p2+1);
      *p2 = AddMod(u, v, q);
      *(p2+1) = SubMod(u, v, q);
      p2 += 2;
   }

   // Stages s = 2 .. k-1: blocks of m = 2^s, butterflies span m2 = m/2.
   for (s = 2; s < k; s++) {
      m = 1L << s;
      m2 = m >> 1;

      // Butterfly j == 0 of every block: twiddle 1, no multiplication.
      p2 = A;
      p1 = p2 + m2;
      while (p2 < ub) {
         u = *p2;
         v = *p1;
         *p2 = AddMod(u, v, q);
         *p1 = SubMod(u, v, q);
         p1 += m;
         p2 += m;
      }

      // Butterflies j = 1 .. m2-1 use twiddle w = z^j, z = root[s].
      z = root[s];
      w = z;
      for (j = 1; j < m2; j++) {
         wqinv = ((double) w)*qinv;
         p2 = A + j;
         p1 = p2 + m2;
         while (p2 < ub) {
            u = *p2;
            v = *p1;
            t = MulMod2(v, w, q, wqinv);
            *p2 = AddMod(u, t, q);
            *p1 = SubMod(u, t, q);
            p1 += m;
            p2 += m;
         }
         w = MulMod2(z, w, q, wqinv);   // w = z^(j+1)
      }
   }

   // Final stage s = k: one block of size n with twiddles root[k]^j,
   // j = 0 .. n/2-1.
   m2 = n >> 1;
   z = root[k];
   zqinv = ((double) z)*qinv;
   w = 1;
   p2 = A;
   p1 = A + m2;
   for (j = 0; j < m2; j++) {
      u = *p2;
      v = *p1;
      t = MulMod(v, w, q, qinv);
      *p2 = AddMod(u, t, q);
      *p1 = SubMod(u, t, q);
      w = MulMod2(w, z, q, zqinv);   // advance twiddle to root[k]^(j+1)
      p2++;
      p1++;
   }
}

#endif

NTL_END_IMPL

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -