📄 fft.c
字号:
#include <NTL/FFT.h>
#include <NTL/new.h>

NTL_START_IMPL

/*
 * Global tables describing the FFT primes set up so far.  All of them are
 * indexed by the prime's index i, 0 <= i < NumFFTPrimes, and are grown by
 * UseFFTPrime() below.
 *
 *   FFTPrime[i]         the prime q itself
 *   RootTable[i][j]     for 0 <= j <= mr (mr = CalcMaxRoot(q)): obtained from
 *                       the root w returned by NextFFTPrime by repeated
 *                       squaring, RootTable[i][mr] = w and
 *                       RootTable[i][j] = RootTable[i][j+1]^2 mod q
 *   RootInvTable[i][j]  same construction starting from InvMod(w, q)
 *   TwoInvTable[i][j]   (1/2)^j mod q
 *   FFTPrimeInv[i]      1/double(q)
 */
long NumFFTPrimes = 0;

long *FFTPrime = 0;
long **RootTable = 0;
long **RootInvTable = 0;
long **TwoInvTable = 0;
double *FFTPrimeInv = 0;


/*
 * Tests whether n is usable as an FFT prime.
 *
 * Returns 1 and stores a root of unity in w on success, 0 otherwise.
 *
 * Writing n-1 = 2^k * m with m odd, the loop looks for a random x whose
 * m-th power has multiplicative order exactly 2^k modulo n.  Any witness
 * that contradicts primality (the squaring chain does not pass through
 * n-1 before reaching 1) makes the function return 0 immediately.
 *
 * On success, w is that element squared (k - NTL_FFTMaxRoot) times when
 * k > NTL_FFTMaxRoot, i.e. an element of order 2^min(k, NTL_FFTMaxRoot).
 */
static
long IsFFTPrime(long n, long& w)
{
   long  m, x, y, z;
   long j, k;

   /* cheap trial division first */
   if (n % 3 == 0) return 0;
   if (n % 5 == 0) return 0;
   if (n % 7 == 0) return 0;

   /* split n-1 = 2^k * m, m odd */
   m = n - 1;
   k = 0;
   while ((m & 1) == 0) {
      m = m >> 1;
      k++;
   }

   for (;;) {
      x = RandomBnd(n);

      if (x == 0) continue;
      z = PowerMod(x, m, n);
      if (z == 1) continue;   /* order divides m: useless, try another x */

      x = z;
      j = 0;
      do {
         y = z;
         z = MulMod(y, y, n);
         j++;
      } while (j != k && z != 1);

      /* the value just before reaching 1 must be -1 mod n */
      if (z != 1 || y !=  n-1) return 0;

      /* j == k means x needed all k squarings: order is exactly 2^k */
      if (j == k)
         break;
   }

   /* x^{2^k} = 1 mod n, x^{2^{k-1}} = -1 mod n */

   long TrialBound;

   TrialBound = m >> k;
   if (TrialBound > 0) {
      /* m >= 2^k: the order argument alone is not relied upon here, so
         the code adds a probabilistic test plus extra trial division */
      if (!ProbPrime(n, 5)) return 0;

      /* we have to do trial division by special numbers */

      TrialBound = SqrRoot(TrialBound);

      long a, b;

      /* candidates have the form a*2^k + 1 */
      for (a = 1; a <= TrialBound; a++) {
         b = (a << k) + 1;
         if (n % b == 0) return 0;
      }
   }

   /* n is an FFT prime */

   /* reduce the order of x from 2^k down to 2^NTL_FFTMaxRoot if needed */
   for (j = NTL_FFTMaxRoot; j < k; j++)
      x = MulMod(x, x, n);

   w = x;
   return 1;
}


/*
 * Produces the next FFT prime q together with a root of unity w for it
 * (as returned by IsFFTPrime).
 *
 * The search position is kept in function-local statics, so successive
 * calls continue where the previous one stopped.  Candidates have the form
 *
 *    2^(NTL_SP_NBITS-1) + k*2^(m+1) + 2^m + 1
 *
 * with k counting down for each m, and m counting down from
 * NTL_FFTMaxRootBnd.  Calls Error() once m drops below 5.
 */
static
void NextFFTPrime(long& q, long& w)
{
   static long m = NTL_FFTMaxRootBnd + 1;
   static long k = 0;

   long t, cand;

   for (;;) {
      if (k == 0) {
         /* exhausted all k for this m: move to the next smaller m */
         m--;
         if (m < 5) Error("ran out of FFT primes");
         k = 1L << (NTL_SP_NBITS-m-2);
      }

      k--;

      cand = (1L << (NTL_SP_NBITS-1)) + (k << (m+1)) + (1L << m) + 1;

      if (!IsFFTPrime(cand, t)) continue;
      q = cand;
      w = t;
      return;
   }
}


/*
 * Returns min(v, NTL_FFTMaxRoot), where 2^v is the largest power of two
 * dividing p-1.
 */
long CalcMaxRoot(long p)
{
   p = p-1;
   long k = 0;
   while ((p & 1) == 0) {
      p = p >> 1;
      k++;
   }

   if (k > NTL_FFTMaxRoot)
      return NTL_FFTMaxRoot;
   else
      return k;
}


/*
 * Makes sure the FFT prime with the given index is available.
 *
 * index must satisfy 0 <= index <= NumFFTPrimes (Error() otherwise).  If
 * the prime already exists this is a no-op; if index == NumFFTPrimes a new
 * prime is generated, the global tables are extended, and NumFFTPrimes is
 * incremented.
 *
 * Allocation failures are reported through Error("out of space").
 * NOTE(review): the realloc results overwrite the old pointers directly;
 * that is only harmless if Error() does not return -- confirm.
 */
void UseFFTPrime(long index)
{
   if (index < 0 || index > NumFFTPrimes)
      Error("invalid FFT prime index");

   if (index < NumFFTPrimes) return;

   long q, w;

   NextFFTPrime(q, w);

   long mr = CalcMaxRoot(q);

   // tables are allocated in increments of 100

   if (index == 0) {
      FFTPrime = (long *) malloc(sizeof(long)*100);
      RootTable = (long **) malloc(sizeof(long *)*100);
      RootInvTable = (long **) malloc(sizeof(long *)*100);
      TwoInvTable = (long **) malloc(sizeof(long *)*100);
      FFTPrimeInv = (double *) malloc(sizeof(double)*100);
   }
   else if ((index % 100) == 0) {
      FFTPrime = (long *) realloc(FFTPrime, sizeof(long)*(index+100));
      RootTable = (long **) realloc(RootTable, sizeof(long *)*(index+100));
      RootInvTable = (long **) realloc(RootInvTable,
                                       sizeof(long *)*(index+100));
      TwoInvTable = (long **) realloc(TwoInvTable,
                                      sizeof(long *)*(index+100));
      FFTPrimeInv = (double *) realloc(FFTPrimeInv,
                                       sizeof(double)*(index+100));
   }

   if (!FFTPrime || !RootTable || !RootInvTable || !TwoInvTable ||
       !FFTPrimeInv)
      Error("out of space");

   FFTPrime[index] = q;

   long *rt, *rit, *tit;

   if (!(rt = RootTable[index] = (long *)malloc(sizeof(long)*(mr+1))))
      Error("out of space");
   if (!(rit = RootInvTable[index] = (long *)malloc(sizeof(long)*(mr+1))))
      Error("out of space");
   if (!(tit = TwoInvTable[index] = (long *)malloc(sizeof(long)*(mr+1))))
      Error("out of space");

   long j;
   long t;

   /* rt[j] = w^(2^(mr-j)): each entry is the square of the next one */
   rt[mr] = w;
   for (j = mr-1; j >= 0; j--)
      rt[j] = MulMod(rt[j+1], rt[j+1], q);

   /* same chain for the inverse root */
   rit[mr] = InvMod(w, q);
   for (j = mr-1; j >= 0; j--)
      rit[j] = MulMod(rit[j+1], rit[j+1], q);

   /* tit[j] = 2^{-j} mod q */
   t = InvMod(2, q);
   tit[0] = 1;
   for (j = 1; j <= mr; j++)
      tit[j] = MulMod(tit[j-1], t, q);

   FFTPrimeInv[index] = 1/double(q);

   NumFFTPrimes++;
}


/*
 * Increments a k-bit counter in bit-reversed order.
 *
 * Starting from the most significant of the k bits, set bits are cleared
 * (the carry moving towards the low end) until a clear bit is found, which
 * is then set.  If all k bits were set the result wraps to 0.
 *
 * For k <= 0 there are no bits to touch and a is returned unchanged; the
 * early return also avoids evaluating 1L << (k-1) with a negative shift
 * count, which is undefined behavior.
 */
long RevInc(long a, long k)
{
   long j, m;

   if (k <= 0) return a;

   j = k;
   m = 1L << (k-1);

   while (j && (m & a)) {
      a ^= m;
      m >>= 1;
      j--;
   }
   if (j) a ^= m;
   return a;
}


/*
 * Copies the 2^k entries of a into A in bit-reversed index order:
 * A[rev(i)] = a[i], where rev is produced by iterating RevInc.
 *
 * The permutation for each k is computed once and cached in a static
 * table with NTL_FFTMaxRoot+1 slots, so k must lie in
 * [0, NTL_FFTMaxRoot]; anything else is reported through Error() rather
 * than indexing outside the cache.
 */
void BitReverseCopy(long *A, const long *a, long k)
{
   static long* mem[NTL_FFTMaxRoot+1];

   if (k < 0 || k > NTL_FFTMaxRoot)
      Error("BitReverseCopy: bad args");

   long n = 1L << k;
   long* rev;
   long i, j;

   rev = mem[k];
   if (!rev) {
      /* first use for this k: build and cache the permutation */
      rev = mem[k] = NTL_NEW_OP long[n];
      if (!rev) Error("out of memory in BitReverseCopy");
      for (i = 0, j = 0; i < n; i++, j = RevInc(j, k))
         rev[i] = j;
   }

   for (i = 0; i < n; i++)
      A[rev[i]] = a[i];
}


#ifdef NTL_FFT_PIPELINE

/*****************************************************

   This version of the FFT is written with an explicit
   "software pipeline", which sometimes speeds things up.

*******************************************************/

/*
 * A     output array of 2^k entries
 * a     input array of 2^k entries (not modified)
 * k     log2 of the transform length
 * q     modulus
 * root  table indexed by stage; root[s] is read for s = 2..k
 *
 * The input is first copied into A in bit-reversed order, then k rounds
 * of butterflies are applied in place.  Within each round the product for
 * the next butterfly (tt) is started before the current one is stored.
 *
 * NOTE(review): MulMod2(a, b, q, bqinv) is assumed to take the
 * precomputed quotient b/q as its last argument -- confirm against NTL.
 */
void FFT(long* A, const long* a, long k, long q, const long* root)

// performs a 2^k-point convolution modulo q

{
   if (k == 0) {
      A[0] = a[0];
      return;
   }

   if (k == 1) {
      A[0] = AddMod(a[0], a[1], q);
      A[1] = SubMod(a[0], a[1], q);
      return;
   }

   // assume k > 1

   long n = 1L << k;
   long s, m, m2, j;
   long t, u, v, w, z, tt;
   long *p1, *p2, *ub, *ub1;
   double qinv = ((double) 1)/((double) q);
   double wqinv, zqinv;

   BitReverseCopy(A, a, k);

   ub = A+n;

   /* round 1: length-2 butterflies, no multiplication needed */
   p2 = A;
   while (p2 < ub) {
      u = *p2;
      v = *(p2+1);
      *p2 = AddMod(u, v, q);
      *(p2+1) = SubMod(u, v, q);
      p2 += 2;
   }

   /* rounds 2 .. k-1 */
   for (s = 2; s < k; s++) {
      m = 1L << s;
      m2 = m >> 1;

      /* j = 0: multiplier is 1 */
      p2 = A;
      p1 = p2 + m2;
      while (p2 < ub) {
         u = *p2;
         v = *p1;
         *p2 = AddMod(u, v, q);
         *p1 = SubMod(u, v, q);
         p1 += m;
         p2 += m;
      }

      z = root[s];
      w = z;
      for (j = 1; j < m2; j++) {
         wqinv = ((double) w)*qinv;
         p2 = A + j;
         p1 = p2 + m2;

         ub1 = ub-m;

         /* prime the pipeline with the first product */
         u = *p2;
         t = MulMod2(*p1, w, q, wqinv);

         while (p2 < ub1) {
            tt = MulMod2(*(p1+m), w, q, wqinv);
            *p2 = AddMod(u, t, q);
            *p1 = SubMod(u, t, q);
            p1 += m;
            p2 += m;
            u = *p2;
            t = tt;
         }
         /* drain: last butterfly of this column */
         *p2 = AddMod(u, t, q);
         *p1 = SubMod(u, t, q);

         w = MulMod2(z, w, q, wqinv);
      }
   }

   /* round k: a single block of length n, multiplier changes every step */
   m2 = n >> 1;
   z = root[k];
   zqinv = ((double) z)*qinv;
   w = 1;
   p2 = A;
   p1 = A + m2;
   m2--;
   u = *p2;
   t = *p1;
   while (m2) {
      w = MulMod2(w, z, q, zqinv);
      tt = MulMod(*(p1+1), w, q, qinv);
      *p2 = AddMod(u, t, q);
      *p1 = SubMod(u, t, q);
      p2++;
      p1++;
      u = *p2;
      t = tt;
      m2--;
   }
   *p2 = AddMod(u, t, q);
   *p1 = SubMod(u, t, q);
}

#else

/*****************************************************

   This version of the FFT has no "software pipeline".

******************************************************/

/*
 * A     output array of 2^k entries
 * a     input array of 2^k entries (not modified)
 * k     log2 of the transform length
 * q     modulus
 * root  table indexed by stage; root[s] is read for s = 2..k
 *
 * The input is first copied into A in bit-reversed order, then k rounds
 * of butterflies are applied in place.
 *
 * NOTE(review): MulMod2(a, b, q, bqinv) is assumed to take the
 * precomputed quotient b/q as its last argument -- confirm against NTL.
 */
void FFT(long* A, const long* a, long k, long q, const long* root)

// performs a 2^k-point convolution modulo q

{
   if (k == 0) {
      A[0] = a[0];
      return;
   }

   if (k == 1) {
      A[0] = AddMod(a[0], a[1], q);
      A[1] = SubMod(a[0], a[1], q);
      return;
   }

   // assume k > 1

   long n = 1L << k;
   long s, m, m2, j;
   long t, u, v, w, z;
   long *p1, *p2, *ub;
   double qinv = ((double) 1)/((double) q);
   double wqinv, zqinv;

   BitReverseCopy(A, a, k);

   ub = A+n;

   /* round 1: length-2 butterflies, no multiplication needed */
   p2 = A;
   while (p2 < ub) {
      u = *p2;
      v = *(p2+1);
      *p2 = AddMod(u, v, q);
      *(p2+1) = SubMod(u, v, q);
      p2 += 2;
   }

   /* rounds 2 .. k-1 */
   for (s = 2; s < k; s++) {
      m = 1L << s;
      m2 = m >> 1;

      /* j = 0: multiplier is 1 */
      p2 = A;
      p1 = p2 + m2;
      while (p2 < ub) {
         u = *p2;
         v = *p1;
         *p2 = AddMod(u, v, q);
         *p1 = SubMod(u, v, q);
         p1 += m;
         p2 += m;
      }

      z = root[s];
      w = z;
      for (j = 1; j < m2; j++) {
         wqinv = ((double) w)*qinv;
         p2 = A + j;
         p1 = p2 + m2;
         while (p2 < ub) {
            u = *p2;
            v = *p1;
            t = MulMod2(v, w, q, wqinv);
            *p2 = AddMod(u, t, q);
            *p1 = SubMod(u, t, q);
            p1 += m;
            p2 += m;
         }
         w = MulMod2(z, w, q, wqinv);
      }
   }

   /* round k: a single block of length n, multiplier changes every step */
   m2 = n >> 1;
   z = root[k];
   zqinv = ((double) z)*qinv;
   w = 1;
   p2 = A;
   p1 = A + m2;
   for (j = 0; j < m2; j++) {
      u = *p2;
      v = *p1;
      t = MulMod(v, w, q, qinv);
      *p2 = AddMod(u, t, q);
      *p1 = SubMod(u, t, q);
      w = MulMod2(w, z, q, zqinv);
      p2++;
      p1++;
   }
}

#endif

NTL_END_IMPL
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -