📄 zzx1.c
字号:
#include <NTL/ZZX.h>
#include <NTL/new.h>

NTL_START_IMPL


/* Convert a ZZX into a zz_pX under the current zz_p modulus.
   The coefficient vector is converted element-wise; x is then normalized
   because reduced leading coefficients may have become zero. */
void conv(zz_pX& x, const ZZX& a)
{
   conv(x.rep, a.rep);
   x.normalize();
}


/* Convert a zz_pX into a ZZX by converting each zz_p coefficient to a ZZ,
   then normalizing the result. */
void conv(ZZX& x, const zz_pX& a)
{
   conv(x.rep, a.rep);
   x.normalize();
}


/* Incremental Chinese remaindering with a single-precision modulus.

   On entry the coefficients of gg are known modulo a; G holds the same
   polynomial modulo the current zz_p modulus p. On exit every coefficient
   of gg is the residue modulo a*p that agrees with gg mod a and with G
   mod p, taken in the symmetric range around zero, and a is replaced by a*p.

   a must be invertible modulo p (InvMod is applied to a mod p).

   Returns 1 if any coefficient of gg changed (including coefficients that
   had to be reduced into the symmetric range mod a, or new coefficients
   coming from G), and 0 otherwise. */
long CRT(ZZX& gg, ZZ& a, const zz_pX& G)
{
   long n = gg.rep.length();

   long p = zz_p::modulus();

   ZZ new_a;
   mul(new_a, a, p);

   long a_inv;
   a_inv = rem(a, p);
   a_inv = InvMod(a_inv, p);

   long p1;
   p1 = p >> 1;

   ZZ a1;
   RightShift(a1, a, 1);

   long p_odd = (p & 1);

   long modified = 0;

   long h;
   ZZ ah;

   long m = G.rep.length();

   long max_mn = max(m, n);

   gg.rep.SetLength(max_mn);

   ZZ g;
   long i;

   for (i = 0; i < n; i++) {
      /* g := gg[i] reduced into the symmetric range mod a */
      if (!CRTInRange(gg.rep[i], a)) {
         modified = 1;
         rem(g, gg.rep[i], a);
         if (g > a1) sub(g, g, a);
      }
      else
         g = gg.rep[i];

      /* h := (G[i] - g) / a mod p, as a symmetric residue mod p */
      h = rem(g, p);

      if (i < m)
         h = SubMod(rep(G.rep[i]), h, p);
      else
         h = NegateMod(h, p);

      h = MulMod(h, a_inv, p);
      if (h > p1)
         h = h - p;

      if (h != 0) {
         modified = 1;
         mul(ah, a, h);

         /* For even p, h == p/2 could be added with either sign; subtracting
            when g > 0 keeps |g| the smaller of the two choices. */
         if (!p_odd && g > 0 && (h == p1))
            sub(g, g, ah);
         else
            add(g, g, ah);
      }

      gg.rep[i] = g;
   }

   /* Coefficients present only in G: the old value is 0 mod a, so the
      new value is a * (G[i] / a mod p). */
   for (; i < m; i++) {
      h = rep(G.rep[i]);
      h = MulMod(h, a_inv, p);
      if (h > p1)
         h = h - p;

      modified = 1;
      mul(g, a, h);
      gg.rep[i] = g;
   }

   gg.normalize();
   a = new_a;

   return modified;
}


/* Incremental Chinese remaindering with a multi-precision modulus.

   Same contract as the zz_pX version above, with p = ZZ_p::modulus():
   gg (known mod a) is combined with G (mod p) so that, on exit, gg holds
   symmetric residues mod a*p and a has been replaced by a*p.
   a must be invertible modulo p.

   Returns 1 if any coefficient of gg changed, 0 otherwise. */
long CRT(ZZX& gg, ZZ& a, const ZZ_pX& G)
{
   long n = gg.rep.length();

   const ZZ& p = ZZ_p::modulus();

   ZZ new_a;
   mul(new_a, a, p);

   ZZ a_inv;
   rem(a_inv, a, p);
   InvMod(a_inv, a_inv, p);

   ZZ p1;
   RightShift(p1, p, 1);

   ZZ a1;
   RightShift(a1, a, 1);

   long p_odd = IsOdd(p);

   long modified = 0;

   ZZ h;
   ZZ ah;

   long m = G.rep.length();

   long max_mn = max(m, n);

   gg.rep.SetLength(max_mn);

   ZZ g;
   long i;

   for (i = 0; i < n; i++) {
      /* g := gg[i] reduced into the symmetric range mod a */
      if (!CRTInRange(gg.rep[i], a)) {
         modified = 1;
         rem(g, gg.rep[i], a);
         if (g > a1) sub(g, g, a);
      }
      else
         g = gg.rep[i];

      /* h := (G[i] - g) / a mod p, as a symmetric residue mod p */
      rem(h, g, p);

      if (i < m)
         SubMod(h, rep(G.rep[i]), h, p);
      else
         NegateMod(h, h, p);

      MulMod(h, h, a_inv, p);
      if (h > p1)
         sub(h, h, p);

      if (h != 0) {
         modified = 1;
         mul(ah, a, h);

         /* For even p, h == p/2 is ambiguous; subtracting when g > 0
            keeps |g| the smaller of the two choices. */
         if (!p_odd && g > 0 && (h == p1))
            sub(g, g, ah);
         else
            add(g, g, ah);
      }

      gg.rep[i] = g;
   }

   /* Coefficients present only in G */
   for (; i < m; i++) {
      h = rep(G.rep[i]);
      MulMod(h, h, a_inv, p);
      if (h > p1)
         sub(h, h, p);

      modified = 1;
      mul(g, a, h);
      gg.rep[i] = g;
   }

   gg.normalize();
   a = new_a;

   return modified;
}


/* Compute a = b * 2^l mod p, where p = 2^n+1.
   0<=l<=n and 0<b<p are assumed.
   Since 2^n == -1 (mod p), the top l bits of b, once shifted left by l,
   wrap around with a sign change: the result is
   (low n-l bits of b) * 2^l - (b >> (n-l)), corrected into [0, p).
   a and b may be the same object. */
static void LeftRotate(ZZ& a, const ZZ& b, long l, const ZZ& p, long n)
{
   if (l == 0) {
      if (&a != &b) {
         a = b;
      }
      return;
   }

   /* tmp := upper l bits of b.
      NOTE(review): tmp is a function-local static scratch value (avoids a
      fresh allocation per call), which makes this function non-reentrant
      and unsafe to call concurrently from several threads. */
   static ZZ tmp;
   RightShift(tmp, b, n - l);
   /* a := 2^l * lower n - l bits of b */
   trunc(a, b, n - l);
   LeftShift(a, a, l);
   /* a -= tmp */
   sub(a, a, tmp);
   if (sign(a) < 0) {
      add(a, a, p);
   }
}


/* Compute a = b * 2^l mod p, where p = 2^n+1. 0<=b<p is assumed.
   l may be any long, including negative values: it is first reduced
   modulo 2n (2^{2n} == 1 mod p). For l in [n, 2n) the identity
   2^l == -2^{l-n} (mod p) is used. a and b may be the same object. */
static void Rotate(ZZ& a, const ZZ& b, long l, const ZZ& p, long n)
{
   if (IsZero(b)) {
      clear(a);
      return;
   }

   /* l %= 2n, with a non-negative result also for negative l */
   if (l >= 0) {
      l %= (n << 1);
   }
   else {
      l = (n << 1) - 1 - (-(l + 1) % (n << 1));
   }

   /* a = b * 2^l mod p */
   if (l < n) {
      LeftRotate(a, b, l, p, n);
   }
   else {
      LeftRotate(a, b, l - n, p, n);
      SubPos(a, p, a);
   }
}


/* Fast Fourier Transform.
   a is a vector of length 2^l, 2^l divides 2n, p = 2^n+1,
   w = 2^r mod p is a primitive (2^l)th root of unity.
   Returns a(1),a(w),...,a(w^{2^l-1}) mod p in bit-reverse order.
   The entries of a are assumed to lie in [0, p); all multiplications by
   powers of w are done with Rotate, i.e. by shifts rather than general
   multiplications. */
static void fft(vec_ZZ& a, long r, long l, const ZZ& p, long n)
{
   long round;
   long off, i, j, e;
   long halfsize;
   ZZ tmp, tmp1;

   for (round = 0; round < l; round++, r <<= 1) {
      halfsize = 1L << (l - 1 - round);
      for (i = (1L << round) - 1, off = 0; i >= 0; i--, off += halfsize) {
         for (j = 0, e = 0; j < halfsize; j++, off++, e+=r) {
            /* One butterfly :
               ( a[off], a[off+halfsize] ) *= ( 1  w^{j2^round} )
                                              ( 1 -w^{j2^round} ) */
            /* tmp = a[off] - a[off + halfsize] mod p */
            sub(tmp, a[off], a[off + halfsize]);
            if (sign(tmp) < 0) {
               add(tmp, tmp, p);
            }
            /* a[off] += a[off + halfsize] mod p */
            add(a[off], a[off], a[off + halfsize]);
            sub(tmp1, a[off], p);
            if (sign(tmp1) >= 0) {
               a[off] = tmp1;
            }
            /* a[off + halfsize] = tmp * w^{j2^round} mod p */
            Rotate(a[off + halfsize], tmp, e, p, n);
         }
      }
   }
}


/* Inverse FFT. r must be the same as in the call to FFT.
Result is by 2^l too large. */static void ifft(vec_ZZ& a, long r, long l, const ZZ& p, long n){ long round; long off, i, j, e; long halfsize; ZZ tmp, tmp1; for (round = l - 1, r <<= l - 1; round >= 0; round--, r >>= 1) { halfsize = 1L << (l - 1 - round); for (i = (1L << round) - 1, off = 0; i >= 0; i--, off += halfsize) { for (j = 0, e = 0; j < halfsize; j++, off++, e+=r) { /* One inverse butterfly : ( a[off], a[off+halfsize] ) *= ( 1 1 ) ( w^{-j2^round} -w^{-j2^round} ) */ /* a[off + halfsize] *= w^{-j2^round} mod p */ Rotate(a[off + halfsize], a[off + halfsize], -e, p, n); /* tmp = a[off] - a[off + halfsize] */ sub(tmp, a[off], a[off + halfsize]); /* a[off] += a[off + halfsize] mod p */ add(a[off], a[off], a[off + halfsize]); sub(tmp1, a[off], p); if (sign(tmp1) >= 0) { a[off] = tmp1; } /* a[off+halfsize] = tmp mod p */ if (sign(tmp) < 0) { add(a[off+halfsize], tmp, p); } else { a[off+halfsize] = tmp; } } } }}/* Multiplication a la Schoenhage & Strassen, modulo a "Fermat" number p = 2^{mr}+1, where m is a power of two and r is odd. Then w = 2^r is a primitive 2mth root of unity, i.e., polynomials whose product has degree less than 2m can be multiplied, provided that the coefficients of the product polynomial are at most 2^{mr-1} in absolute value. The algorithm is not called recursively; coefficient arithmetic is done directly.*/void SSMul(ZZX& c, const ZZX& a, const ZZX& b){ long na = deg(a); long nb = deg(b); if (na <= 0 || nb <= 0) { PlainMul(c, a, b); return; } long n = na + nb; /* degree of the product */ /* Choose m and r suitably */ long l = NextPowerOfTwo(n + 1) - 1; /* 2^l <= n < 2^{l+1} */ long m2 = 1L << (l + 1); /* m2 = 2m = 2^{l+1} */ /* Bitlength of the product: if the coefficients of a are absolutely less than 2^ka and the coefficients of b are absolutely less than 2^kb, then the coefficients of ab are absolutely less than (min(na,nb)+1)2^{ka+kb} <= 2^bound. 
*/ long bound = 2 + NumBits(min(na, nb)) + MaxBits(a) + MaxBits(b); /* Let r be minimal so that mr > bound */ long r = (bound >> l) + 1; long mr = r << l; /* p := 2^{mr}+1 */ ZZ p; set(p); LeftShift(p, p, mr); add(p, p, 1); /* Make coefficients of a and b positive */ vec_ZZ aa, bb; aa.SetLength(m2); bb.SetLength(m2); long i; for (i = 0; i <= deg(a); i++) { if (sign(a.rep[i]) >= 0) { aa[i] = a.rep[i]; } else { add(aa[i], a.rep[i], p); } } for (i = 0; i <= deg(b); i++) { if (sign(b.rep[i]) >= 0) { bb[i] = b.rep[i]; } else { add(bb[i], b.rep[i], p); } } /* 2m-point FFT's mod p */ fft(aa, r, l + 1, p, mr); fft(bb, r, l + 1, p, mr); /* Pointwise multiplication aa := aa * bb mod p */ ZZ tmp, ai; for (i = 0; i < m2; i++) { mul(ai, aa[i], bb[i]); if (NumBits(ai) > mr) { RightShift(tmp, ai, mr); trunc(ai, ai, mr); sub(ai, ai, tmp); if (sign(ai) < 0) { add(ai, ai, p); } } aa[i] = ai; } ifft(aa, r, l + 1, p, mr); /* Retrieve c, dividing by 2m, and subtracting p where necessary */ c.rep.SetLength(n + 1); for (i = 0; i <= n; i++) { ai = aa[i]; ZZ& ci = c.rep[i]; if (!IsZero(ai)) { /* ci = -ai * 2^{mr-l-1} = ai * 2^{-l-1} = ai / 2m mod p */ LeftRotate(ai, ai, mr - l - 1, p, mr); sub(tmp, p, ai); if (NumBits(tmp) >= mr) { /* ci >= (p-1)/2 */ negate(ci, ai); /* ci = -ai = ci - p */ } else ci = tmp; } else clear(ci); }}// SSRatio computes how much bigger the SS moduls must be// to accomodate the necessary roots of unity.// This is useful in determining algorithm crossover points.double SSRatio(long na, long maxa, long nb, long maxb){ if (na <= 0 || nb <= 0) return 0; long n = na + nb; /* degree of the product */ long l = NextPowerOfTwo(n + 1) - 1; /* 2^l <= n < 2^{l+1} */ long bound = 2 + NumBits(min(na, nb)) + maxa + maxb; long r = (bound >> l) + 1; long mr = r << l; return double(mr + 1)/double(bound);}void HomMul(ZZX& x, const ZZX& a, const ZZX& b){ if (&a == &b) { HomSqr(x, a); return; } long da = deg(a); long db = deg(b); if (da < 0 || db < 0) { clear(x); return; } long bound 
= 2 + NumBits(min(da, db)+1) + MaxBits(a) + MaxBits(b); ZZ prod; set(prod); long i, nprimes; zz_pBak bak; bak.save(); for (nprimes = 0; NumBits(prod) <= bound; nprimes++) { if (nprimes >= NumFFTPrimes) zz_p::FFTInit(nprimes); mul(prod, prod, FFTPrime[nprimes]); } ZZ coeff; ZZ t1; long tt; vec_ZZ c; c.SetLength(da+db+1); long j; for (i = 0; i < nprimes; i++) { zz_p::FFTInit(i); long p = zz_p::modulus(); div(t1, prod, p); tt = rem(t1, p); tt = InvMod(tt, p); mul(coeff, t1, tt); zz_pX A, B, C; conv(A, a); conv(B, b); mul(C, A, B); long m = deg(C); for (j = 0; j <= m; j++) { /* c[j] += coeff*rep(C.rep[j]) */ mul(t1, coeff, rep(C.rep[j])); add(c[j], c[j], t1); } } x.rep.SetLength(da+db+1); ZZ prod2; RightShift(prod2, prod, 1); for (j = 0; j <= da+db; j++) { rem(t1, c[j], prod); if (t1 > prod2) sub(x.rep[j], t1, prod); else x.rep[j] = t1; } x.normalize(); bak.restore();}staticlong MaxSize(const ZZX& a){ long res = 0; long n = a.rep.length(); long i; for (i = 0; i < n; i++) { long t = a.rep[i].size(); if (t > res) res = t; } return res;}void mul(ZZX& c, const ZZX& a, const ZZX& b){ if (IsZero(a) || IsZero(b)) { clear(c); return; } if (&a == &b) { sqr(c, a); return; } long maxa = MaxSize(a); long maxb = MaxSize(b); long k = min(maxa, maxb); long s = min(deg(a), deg(b)) + 1; if (s == 1 || (k == 1 && s < 40) || (k == 2 && s < 20) || (k == 3 && s < 10)) { PlainMul(c, a, b); return; } if (s < 80 || (k < 30 && s < 150)) { KarMul(c, a, b); return; } if (maxa + maxb >= 40 && SSRatio(deg(a), MaxBits(a), deg(b), MaxBits(b)) < 1.75) SSMul(c, a, b); else HomMul(c, a, b);}void SSSqr(ZZX& c, const ZZX& a){ long na = deg(a); if (na <= 0) { PlainSqr(c, a); return; } long n = na + na; /* degree of the product */ long l = NextPowerOfTwo(n + 1) - 1; /* 2^l <= n < 2^{l+1} */ long m2 = 1L << (l + 1); /* m2 = 2m = 2^{l+1} */ long bound = 2 + NumBits(na) + 2*MaxBits(a); long r = (bound >> l) + 1; long mr = r << l; /* p := 2^{mr}+1 */ ZZ p; set(p); LeftShift(p, p, mr); add(p, p, 1); vec_ZZ aa; 
aa.SetLength(m2); long i; for (i = 0; i <= deg(a); i++) { if (sign(a.rep[i]) >= 0) { aa[i] = a.rep[i]; } else { add(aa[i], a.rep[i], p); } } /* 2m-point FFT's mod p */ fft(aa, r, l + 1, p, mr); /* Pointwise multiplication aa := aa * aa mod p */ ZZ tmp, ai; for (i = 0; i < m2; i++) { sqr(ai, aa[i]); if (NumBits(ai) > mr) { RightShift(tmp, ai, mr); trunc(ai, ai, mr); sub(ai, ai, tmp); if (sign(ai) < 0) { add(ai, ai, p); } } aa[i] = ai; } ifft(aa, r, l + 1, p, mr); ZZ ci; /* Retrieve c, dividing by 2m, and subtracting p where necessary */ c.rep.SetLength(n + 1); for (i = 0; i <= n; i++) { ai = aa[i]; ZZ& ci = c.rep[i]; if (!IsZero(ai)) { /* ci = -ai * 2^{mr-l-1} = ai * 2^{-l-1} = ai / 2m mod p */ LeftRotate(ai, ai, mr - l - 1, p, mr); sub(tmp, p, ai); if (NumBits(tmp) >= mr) { /* ci >= (p-1)/2 */ negate(ci, ai); /* ci = -ai = ci - p */ } else ci = tmp; } else clear(ci); }}void HomSqr(ZZX& x, const ZZX& a){ long da = deg(a); if (da < 0) { clear(x); return; } long bound = 2 + NumBits(da+1) + 2*MaxBits(a); ZZ prod; set(prod); long i, nprimes; zz_pBak bak; bak.save(); for (nprimes = 0; NumBits(prod) <= bound; nprimes++) { if (nprimes >= NumFFTPrimes) zz_p::FFTInit(nprimes); mul(prod, prod, FFTPrime[nprimes]); } ZZ coeff; ZZ t1; long tt; vec_ZZ c; c.SetLength(da+da+1); long j; for (i = 0; i < nprimes; i++) { zz_p::FFTInit(i); long p = zz_p::modulus(); div(t1, prod, p); tt = rem(t1, p); tt = InvMod(tt, p); mul(coeff, t1, tt); zz_pX A, C; conv(A, a); sqr(C, A); long m = deg(C); for (j = 0; j <= m; j++) { /* c[j] += coeff*rep(C.rep[j]) */ mul(t1, coeff, rep(C.rep[j])); add(c[j], c[j], t1); } } x.rep.SetLength(da+da+1); ZZ prod2; RightShift(prod2, prod, 1); for (j = 0; j <= da+da; j++) { rem(t1, c[j], prod); if (t1 > prod2) sub(x.rep[j], t1, prod); else x.rep[j] = t1; } x.normalize(); bak.restore();}void sqr(ZZX& c, const ZZX& a){ if (IsZero(a)) { clear(c); return; } long maxa = MaxSize(a); long k = maxa; long s = deg(a) + 1; if (s == 1 || (k == 1 && s < 50) || (k == 2 && s < 
25) || (k == 3 && s < 25) || (k == 4 && s < 10)) { PlainSqr(c, a); return; } if (s < 80 || (k < 30 && s < 150)) { KarSqr(c, a); return; } long mba = MaxBits(a); if (2*maxa >= 40 && SSRatio(deg(a), mba, deg(a), mba) < 1.75) SSSqr(c, a); else HomSqr(c, a);}void mul(ZZX& x, const ZZX& a, const ZZ& b){ ZZ t; long i, da; const ZZ *ap; ZZ* xp; if (IsZero(b)) { clear(x); return; } t = b; da = deg(a); x.rep.SetLength(da+1); ap = a.rep.elts(); xp = x.rep.elts(); for (i = 0; i <= da; i++) mul(xp[i], ap[i], t);}void mul(ZZX& x, const ZZX& a, long b){ long i, da; const ZZ *ap; ZZ* xp; if (b == 0) { clear(x); return; } da = deg(a); x.rep.SetLength(da+1);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -