📄 mul_fft.c
字号:
TRACE (printf ("recurse: %dx%d limbs -> %d times %dx%d (%1.2f)\n", n, n, K2, nprime2, nprime2, 2.0*(double)n/nprime2/K2)); for (i = 0; i < K; i++,ap++,bp++) { mpn_fft_norm (*ap, n); if (!sqr) mpn_fft_norm (*bp, n); mpn_mul_fft_internal (*ap, *ap, *bp, n, k, K2, Ap, Bp, A, B, nprime2, l, Mp2, _fft_l, T, 1); } } else { mp_ptr a, b, tp, tpn; mp_limb_t cc; int n2 = 2 * n; tp = TMP_ALLOC_LIMBS (n2); tpn = tp+n; TRACE (printf (" mpn_mul_n %d of %d limbs\n", K, n)); for (i = 0; i < K; i++) { a = *ap++; b = *bp++; if (sqr) mpn_sqr_n (tp, a, n); else mpn_mul_n (tp, b, a, n); if (a[n] != 0) cc = mpn_add_n (tpn, tpn, b, n); else cc = 0; if (b[n] != 0) cc += mpn_add_n (tpn, tpn, a, n) + a[n]; if (cc != 0) { cc = mpn_add_1 (tp, tp, n2, cc); ASSERT_NOCARRY (mpn_add_1 (tp, tp, n2, cc)); } a[n] = mpn_sub_n (a, tp, tpn, n) && mpn_add_1 (a, a, n, CNST_LIMB(1)); } } TMP_FREE(marker);}/* input: A^[l[k][0]] A^[l[k][1]] ... A^[l[k][K-1]] output: K*A[0] K*A[K-1] ... K*A[1]. Assumes the Ap[] are pseudo-normalized, i.e. 0 <= Ap[][n] <= 1. 
This condition is also fulfilled at exit.*/static voidmpn_fft_fftinv (mp_ptr *Ap, int K, mp_size_t omega, mp_size_t n, mp_ptr tp){ if (K == 2) { mp_limb_t cy;#if HAVE_NATIVE_mpn_addsub_n cy = mpn_addsub_n (Ap[0], Ap[1], Ap[0], Ap[1], n + 1) & 1;#else MPN_COPY (tp, Ap[0], n + 1); mpn_add_n (Ap[0], Ap[0], Ap[1], n + 1); cy = mpn_sub_n (Ap[1], tp, Ap[1], n + 1);#endif if (Ap[0][n] > CNST_LIMB(1)) /* can be 2 or 3 */ Ap[0][n] = CNST_LIMB(1) - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - CNST_LIMB(1)); if (cy) /* Ap[1][n] can be -1 or -2 */ Ap[1][n] = mpn_add_1 (Ap[1], Ap[1], n, ~Ap[1][n] + CNST_LIMB(1)); } else { int j, K2 = K / 2; mp_ptr *Bp = Ap + K2, tmp; TMP_DECL(marker); TMP_MARK(marker); tmp = TMP_ALLOC_LIMBS (n + 1); mpn_fft_fftinv (Ap, K2, 2 * omega, n, tp); mpn_fft_fftinv (Bp, K2, 2 * omega, n, tp); /* A[j] <- A[j] + omega^j A[j+K/2] A[j+K/2] <- A[j] + omega^(j+K/2) A[j+K/2] */ for (j = 0; j < K2; j++,Ap++,Bp++) { MPN_COPY (tp, Bp[0], n + 1); mpn_fft_mul_2exp_modF (Bp[0], (j + K2) * omega, n, tmp); mpn_fft_add_modF (Bp[0], Ap[0], n); mpn_fft_mul_2exp_modF (tp, j * omega, n, tmp); mpn_fft_add_modF (Ap[0], tp, n); } TMP_FREE(marker); }}/* A <- A/2^k mod 2^(n*BITS_PER_MP_LIMB)+1 */static voidmpn_fft_div_2exp_modF (mp_ptr ap, int k, mp_size_t n, mp_ptr tp){ int i; i = 2 * n * BITS_PER_MP_LIMB; i = (i - k) % i; mpn_fft_mul_2exp_modF (ap, i, n, tp); /* 1/2^k = 2^(2nL-k) mod 2^(n*BITS_PER_MP_LIMB)+1 */ /* normalize so that A < 2^(n*BITS_PER_MP_LIMB)+1 */ mpn_fft_norm (ap, n);}/* R <- A mod 2^(n*BITS_PER_MP_LIMB)+1, n <= an <= 3*n */static voidmpn_fft_norm_modF (mp_ptr rp, mp_ptr ap, mp_size_t n, mp_size_t an){ mp_size_t l; ASSERT (n <= an && an <= 3 * n); if (an > 2 * n) { l = n; rp[n] = mpn_add_1 (rp + an - 2 * n, ap + an - 2 * n, 3 * n - an, mpn_add_n (rp, ap, ap + 2 * n, an - 2 * n)); } else { l = an - n; MPN_COPY (rp, ap, n); rp[n] = 0; } if (mpn_sub_n (rp, rp, ap + n, l)) { if (mpn_sub_1 (rp + l, rp + l, n + 1 - l, CNST_LIMB(1))) rp[n] = mpn_add_1 (rp, rp, n, 
CNST_LIMB(1)); }}static voidmpn_mul_fft_internal (mp_ptr op, mp_srcptr n, mp_srcptr m, mp_size_t pl, int k, int K, mp_ptr *Ap, mp_ptr *Bp, mp_ptr A, mp_ptr B, mp_size_t nprime, mp_size_t l, mp_size_t Mp, int **_fft_l, mp_ptr T, int rec){ int i, sqr, pla, lo, sh, j; mp_ptr p; sqr = n == m; TRACE (printf ("pl=%d k=%d K=%d np=%d l=%d Mp=%d rec=%d sqr=%d\n", pl,k,K,nprime,l,Mp,rec,sqr)); /* decomposition of inputs into arrays Ap[i] and Bp[i] */ if (rec) for (i = 0; i < K; i++) { Ap[i] = A + i * (nprime + 1); Bp[i] = B + i * (nprime + 1); /* store the next M bits of n into A[i] */ /* supposes that M is a multiple of BITS_PER_MP_LIMB */ MPN_COPY (Ap[i], n, l); n += l; MPN_ZERO (Ap[i]+l, nprime + 1 - l); /* set most significant bits of n and m (important in recursive calls) */ if (i == K - 1) Ap[i][l] = n[0]; mpn_fft_mul_2exp_modF (Ap[i], i * Mp, nprime, T); if (!sqr) { MPN_COPY (Bp[i], m, l); m += l; MPN_ZERO (Bp[i] + l, nprime + 1 - l); if (i == K - 1) Bp[i][l] = m[0]; mpn_fft_mul_2exp_modF (Bp[i], i * Mp, nprime, T); } } /* direct fft's */ if (sqr) mpn_fft_fft_sqr (Ap, K, _fft_l + k, 2 * Mp, nprime, 1, T); else mpn_fft_fft (Ap, Bp, K, _fft_l + k, 2 * Mp, nprime, 1, T); /* term to term multiplications */ mpn_fft_mul_modF_K (Ap, (sqr) ? Ap : Bp, nprime, K); /* inverse fft's */ mpn_fft_fftinv (Ap, K, 2 * Mp, nprime, T); /* division of terms after inverse fft */ for (i = 0; i < K; i++) mpn_fft_div_2exp_modF (Ap[i], k + ((K - i) % K) * Mp, nprime, T); /* addition of terms in result p */ MPN_ZERO (T, nprime + 1); pla = l * (K - 1) + nprime + 1; /* number of required limbs for p */ p = B; /* B has K*(n' + 1) limbs, which is >= pla, i.e. 
enough */ MPN_ZERO (p, pla); sqr = 0; /* will accumulate the (signed) carry at p[pla] */ for (i = K - 1,lo = l * i + nprime,sh = l * i; i >= 0; i--,lo -= l,sh -= l) { mp_ptr n = p+sh; j = (K-i)%K; if (mpn_add_n (n, n, Ap[j], nprime + 1)) sqr += mpn_add_1 (n + nprime + 1, n + nprime + 1, pla - sh - nprime - 1, CNST_LIMB(1)); T[2 * l]=i + 1; /* T = (i + 1)*2^(2*M) */ if (mpn_cmp (Ap[j],T,nprime + 1)>0) { /* subtract 2^N'+1 */ sqr -= mpn_sub_1 (n, n, pla - sh, CNST_LIMB(1)); sqr -= mpn_sub_1 (p + lo, p + lo, pla - lo, CNST_LIMB(1)); } } if (sqr == -1) { if ((sqr = mpn_add_1 (p + pla - pl,p + pla - pl,pl, CNST_LIMB(1)))) { /* p[pla-pl]...p[pla-1] are all zero */ mpn_sub_1 (p + pla - pl - 1, p + pla - pl - 1, pl + 1, CNST_LIMB(1)); mpn_sub_1 (p + pla - 1, p + pla - 1, 1, CNST_LIMB(1)); } } else if (sqr == 1) { if (pla >= 2 * pl) { while ((sqr = mpn_add_1 (p + pla - 2 * pl, p + pla - 2 * pl, 2 * pl, sqr))) ; } else { sqr = mpn_sub_1 (p + pla - pl, p + pla - pl, pl, sqr); ASSERT (sqr == 0); } } else ASSERT (sqr == 0); /* here p < 2^(2M) [K 2^(M(K-1)) + (K-1) 2^(M(K-2)) + ... ] < K 2^(2M) [2^(M(K-1)) + 2^(M(K-2)) + ... 
] < K 2^(2M) 2^(M(K-1))*2 = 2^(M*K+M+k+1) */ mpn_fft_norm_modF (op, p, pl, pla);}/* op <- n*m mod 2^N+1 with fft of size 2^k where N=pl*BITS_PER_MP_LIMB n and m have respectively nl and ml limbs op must have space for pl+1 limbs Assumes pl is multiple of 2^k.*/voidmpn_mul_fft (mp_ptr op, mp_size_t pl, mp_srcptr n, mp_size_t nl, mp_srcptr m, mp_size_t ml, int k){ int K,maxLK,i,j; mp_size_t N, Nprime, nprime, M, Mp, l; mp_ptr *Ap,*Bp, A, T, B; int **_fft_l; int sqr = (n == m && nl == ml); TMP_DECL(marker); TRACE (printf ("\nmpn_mul_fft pl=%ld nl=%ld ml=%ld k=%d\n", pl, nl, ml, k)); TMP_MARK(marker); N = pl * BITS_PER_MP_LIMB; _fft_l = TMP_ALLOC_TYPE (k + 1, int*); for (i = 0; i <= k; i++) _fft_l[i] = TMP_ALLOC_TYPE (1<<i, int); mpn_fft_initl (_fft_l, k); K = 1 << k; ASSERT_ALWAYS (pl % K == 0); M = N/K; /* exact: N = 2^k M */ l = M / BITS_PER_MP_LIMB; /* l = pl / K also */ maxLK = (K>BITS_PER_MP_LIMB) ? K : BITS_PER_MP_LIMB; Nprime = ((2 * M + k + 2 + maxLK) / maxLK) * maxLK; /* ceil((2*M+k+3)/maxLK)*maxLK; */ nprime = Nprime / BITS_PER_MP_LIMB; /* with B := BITS_PER_MP_LIMB, nprime >= 2*M/B = 2*N/(K*B) = 2*pl/K = 2*l */ TRACE (printf ("N=%d K=%d, M=%d, l=%d, maxLK=%d, Np=%d, np=%d\n", N, K, M, l, maxLK, Nprime, nprime)); /* we should ensure that recursively, nprime is a multiple of the next K */ if (nprime >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD)) { unsigned long K2; while (nprime % (K2 = 1 << mpn_fft_best_k (nprime, sqr))) { nprime = ((nprime + K2 - 1) / K2) * K2; Nprime = nprime * BITS_PER_MP_LIMB; /* warning: since nprime changed, K2 may change too! 
*/ } TRACE (printf ("new maxLK=%d, Np=%d, np=%d\n", maxLK, Nprime, nprime)); ASSERT(nprime % K2 == 0); } ASSERT_ALWAYS (nprime < pl); /* otherwise we'll loop */ T = TMP_ALLOC_LIMBS (nprime + 1); Mp = Nprime/K; TRACE (printf ("%dx%d limbs -> %d times %dx%d limbs (%1.2f)\n", pl,pl,K,nprime,nprime,2.0*(double)N/Nprime/K); printf (" temp space %ld\n", 2 * K * (nprime + 1))); A = __GMP_ALLOCATE_FUNC_LIMBS (2 * K * (nprime + 1)); B = A + K * (nprime + 1); Ap = TMP_ALLOC_MP_PTRS (K); Bp = TMP_ALLOC_MP_PTRS (K); /* special decomposition for main call */ for (i = 0; i < K; i++) { Ap[i] = A + i * (nprime + 1); Bp[i] = B + i * (nprime + 1); /* store the next M bits of n into A[i] */ /* supposes that M is a multiple of BITS_PER_MP_LIMB */ if (nl > 0) { j = (nl>=l) ? l : nl; /* limbs to store in Ap[i] */ MPN_COPY (Ap[i], n, j); n += l; MPN_ZERO (Ap[i] + j, nprime + 1 - j); mpn_fft_mul_2exp_modF (Ap[i], i * Mp, nprime, T); } else MPN_ZERO (Ap[i], nprime + 1); nl -= l; if (n != m) { if (ml > 0) { j = (ml>=l) ? l : ml; /* limbs to store in Bp[i] */ MPN_COPY (Bp[i], m, j); m += l; MPN_ZERO (Bp[i] + j, nprime + 1 - j); mpn_fft_mul_2exp_modF (Bp[i], i * Mp, nprime, T); } else MPN_ZERO (Bp[i], nprime + 1); } ml -= l; } mpn_mul_fft_internal (op, n, m, pl, k, K, Ap, Bp, A, B, nprime, l, Mp, _fft_l, T, 0); TMP_FREE(marker); __GMP_FREE_FUNC_LIMBS (A, 2 * K * (nprime + 1));}/* Multiply {n,nl}*{m,ml} and write the result to {op,nl+ml}. FIXME: Duplicating the result like this is wasteful, do something better perhaps at the norm_modF stage above. 
*/

/* The product is formed by mpn_mul_fft in a scratch buffer of pl+1 limbs,
   where k is chosen by mpn_fft_best_k and pl by mpn_fft_next_size for an
   (nl+ml)-limb result.  The limbs of that buffer from position nl+ml upward
   are asserted to be zero, and the low nl+ml limbs are copied to op.  The
   scratch buffer is obtained with __GMP_ALLOCATE_FUNC_LIMBS and released
   before returning.  */
void
mpn_mul_fft_full (mp_ptr op,
                  mp_srcptr n, mp_size_t nl,
                  mp_srcptr m, mp_size_t ml)
{
  const mp_size_t rn = nl + ml;               /* limbs in the product */
  const int squaring = (n == m && nl == ml);  /* forwarded to mpn_fft_best_k */
  const int k = mpn_fft_best_k (rn, squaring);
  const mp_size_t pl = mpn_fft_next_size (rn, k);
  mp_ptr scratch;

  TRACE (printf ("mpn_mul_fft_full nl=%ld ml=%ld -> pl=%ld k=%d\n", nl, ml, pl, k));

  scratch = __GMP_ALLOCATE_FUNC_LIMBS (pl + 1);
  mpn_mul_fft (scratch, pl, n, nl, m, ml, k);

  /* everything above the nl+ml product limbs must be zero */
  ASSERT_MPN_ZERO_P (scratch + rn, pl + 1 - rn);
  MPN_COPY (op, scratch, rn);

  __GMP_FREE_FUNC_LIMBS (scratch, pl + 1);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -