📄 mul_fft.c
字号:
/* An implementation in GMP of Scho"nhage's fast multiplication algorithm modulo 2^N+1, by Paul Zimmermann, INRIA Lorraine, February 1998. THE CONTENTS OF THIS FILE ARE FOR INTERNAL USE AND THE FUNCTIONS HAVE MUTABLE INTERFACES. IT IS ONLY SAFE TO REACH THEM THROUGH DOCUMENTED INTERFACES. IT IS ALMOST GUARANTEED THAT THEY'LL CHANGE OR DISAPPEAR IN A FUTURE GNU MP RELEASE.Copyright 1998, 1999, 2000, 2001, 2002, 2004 Free Software Foundation, Inc.This file is part of the GNU MP Library.The GNU MP Library is free software; you can redistribute it and/or modifyit under the terms of the GNU Lesser General Public License as published bythe Free Software Foundation; either version 2.1 of the License, or (at youroption) any later version.The GNU MP Library is distributed in the hope that it will be useful, butWITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITYor FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General PublicLicense for more details.You should have received a copy of the GNU Lesser General Public Licensealong with the GNU MP Library; see the file COPYING.LIB. If not, write tothe Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,MA 02111-1307, USA. *//* References: Schnelle Multiplikation grosser Zahlen, by Arnold Scho"nhage and Volker Strassen, Computing 7, p. 281-292, 1971. Asymptotically fast algorithms for the numerical multiplication and division of polynomials with complex coefficients, by Arnold Scho"nhage, Computer Algebra, EUROCAM'82, LNCS 144, p. 3-15, 1982. Tapes versus Pointers, a study in implementing fast algorithms, by Arnold Scho"nhage, Bulletin of the EATCS, 30, p. 23-32, 1986. See also http://www.loria.fr/~zimmerma/bignum Future: It might be possible to avoid a small number of MPN_COPYs by using a rotating temporary or two. Multiplications of unequal sized operands can be done with this code, but it needs a tighter test for identifying squaring (same sizes as well as same pointers). */#include <stdio.h>#include "gmp.h"#include "gmp-impl.h"/* Change this to "#define TRACE(x) x" for some traces. */#define TRACE(x)FFT_TABLE_ATTRS mp_size_t mpn_fft_table[2][MPN_FFT_TABLE_SIZE] = { MUL_FFT_TABLE, SQR_FFT_TABLE};static void mpn_mul_fft_internal_PROTO ((mp_ptr, mp_srcptr, mp_srcptr, mp_size_t, int, int, mp_ptr *, mp_ptr *, mp_ptr, mp_ptr, mp_size_t, mp_size_t, mp_size_t, int **, mp_ptr,int));/* Find the best k to use for a mod 2^(m*BITS_PER_MP_LIMB)+1 FFT with m >= n. sqr==0 if for a multiply, sqr==1 for a square.*/intmpn_fft_best_k (mp_size_t n, int sqr){ int i; for (i = 0; mpn_fft_table[sqr][i] != 0; i++) if (n < mpn_fft_table[sqr][i]) return i + FFT_FIRST_K; /* treat 4*last as one further entry */ if (i == 0 || n < 4 * mpn_fft_table[sqr][i - 1]) return i + FFT_FIRST_K; else return i + FFT_FIRST_K + 1;}/* Returns smallest possible number of limbs >= pl for a fft of size 2^k, i.e. smallest multiple of 2^k >= pl. */mp_size_tmpn_fft_next_size (mp_size_t pl, int k){ unsigned long K; K = 1 << k; pl = 1 + (pl - 1) / K; /* ceil(pl/K) */ return pl * K;}static voidmpn_fft_initl (int **l, int k){ int i, j, K; l[0][0] = 0; for (i = 1,K = 2; i <= k; i++,K *= 2) { for (j = 0; j < K / 2; j++) { l[i][j] = 2 * l[i - 1][j]; l[i][K / 2 + j] = 1 + l[i][j]; } }}/* a <- a*2^e mod 2^(n*BITS_PER_MP_LIMB)+1 */static voidmpn_fft_mul_2exp_modF (mp_ptr ap, int e, mp_size_t n, mp_ptr tp){ int d, sh, i; mp_limb_t cc; d = e % (n * BITS_PER_MP_LIMB); /* 2^e = (+/-) 2^d */ sh = d % BITS_PER_MP_LIMB; if (sh != 0) mpn_lshift (tp, ap, n + 1, sh); /* no carry here */ else MPN_COPY (tp, ap, n + 1); d /= BITS_PER_MP_LIMB; /* now shift of d limbs to the left */ if (d) { /* ap[d..n-1] = tp[0..n-d-1], ap[0..d-1] = -tp[n-d..n-1] */ /* mpn_xor would be more efficient here */ for (i = d - 1; i >= 0; i--) ap[i] = ~tp[n - d + i]; cc = 1 - mpn_add_1 (ap, ap, d, CNST_LIMB(1)); if (cc != 0) cc = mpn_sub_1 (ap + d, tp, n - d, CNST_LIMB(1)); else MPN_COPY (ap + d, tp, n - d); cc += mpn_sub_1 (ap + d, ap + d, n - d, tp[n]); if (cc != 0) ap[n] = mpn_add_1 (ap, ap, n, cc); else ap[n] = 0; } else if ((ap[n] = mpn_sub_1 (ap, tp, n, tp[n]))) { ap[n] = mpn_add_1 (ap, ap, n, CNST_LIMB(1)); } if ((e / (n * BITS_PER_MP_LIMB)) % 2) { mp_limb_t c; mpn_com_n (ap, ap, n); c = ap[n] + 2; ap[n] = 0; mpn_incr_u (ap, c); }}/* a <- a+b mod 2^(n*BITS_PER_MP_LIMB)+1 */static voidmpn_fft_add_modF (mp_ptr ap, mp_ptr bp, int n){ mp_limb_t c; c = ap[n] + bp[n] + mpn_add_n (ap, ap, bp, n); if (c > 1) /* subtract c-1 to both ap[0] and ap[n] */ { ap[n] = 1; mpn_decr_u (ap, c - 1); } else ap[n] = c;}/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where N=n*BITS_PER_MP_LIMB 2^omega is a primitive root mod 2^N+1 output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */static voidmpn_fft_fft_sqr (mp_ptr *Ap, mp_size_t K, int **ll, mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp){ if (K == 2) { mp_limb_t cy;#if HAVE_NATIVE_mpn_addsub_n cy = mpn_addsub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1;#else MPN_COPY (tp, Ap[0], n + 1); mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1); cy = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1);#endif if (Ap[0][n] > CNST_LIMB(1)) /* can be 2 or 3 */ Ap[0][n] = CNST_LIMB(1) - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - CNST_LIMB(1)); if (cy) /* Ap[inc][n] can be -1 or -2 */ Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + CNST_LIMB(1)); } else { int j, inc2 = 2 * inc; int *lk = *ll; mp_ptr tmp; TMP_DECL(marker); TMP_MARK(marker); tmp = TMP_ALLOC_LIMBS (n + 1); mpn_fft_fft_sqr (Ap, K/2,ll-1,2 * omega,n,inc2, tp); mpn_fft_fft_sqr (Ap+inc, K/2,ll-1,2 * omega,n,inc2, tp); /* A[2*j*inc] <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc] A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */ for (j = 0; j < K / 2; j++,lk += 2,Ap += 2 * inc) { MPN_COPY (tp, Ap[inc], n + 1); mpn_fft_mul_2exp_modF (Ap[inc], lk[1] * omega, n, tmp); mpn_fft_add_modF (Ap[inc], Ap[0], n); mpn_fft_mul_2exp_modF (tp, lk[0] * omega, n, tmp); mpn_fft_add_modF (Ap[0], tp, n); } TMP_FREE(marker); }}/* input: A[0] ... A[inc*(K-1)] are residues mod 2^N+1 where N=n*BITS_PER_MP_LIMB 2^omega is a primitive root mod 2^N+1 output: A[inc*l[k][i]] <- \sum (2^omega)^(ij) A[inc*j] mod 2^N+1 */static voidmpn_fft_fft (mp_ptr *Ap, mp_ptr *Bp, mp_size_t K, int **ll, mp_size_t omega, mp_size_t n, mp_size_t inc, mp_ptr tp){ if (K == 2) { mp_limb_t ca, cb;#if HAVE_NATIVE_mpn_addsub_n ca = mpn_addsub_n (Ap[0], Ap[inc], Ap[0], Ap[inc], n + 1) & 1; cb = mpn_addsub_n (Bp[0], Bp[inc], Bp[0], Bp[inc], n + 1) & 1;#else MPN_COPY (tp, Ap[0], n + 1); mpn_add_n (Ap[0], Ap[0], Ap[inc], n + 1); ca = mpn_sub_n (Ap[inc], tp, Ap[inc], n + 1); MPN_COPY (tp, Bp[0], n + 1); mpn_add_n (Bp[0], Bp[0], Bp[inc], n + 1); cb = mpn_sub_n (Bp[inc], tp, Bp[inc], n + 1);#endif if (Ap[0][n] > CNST_LIMB(1)) /* can be 2 or 3 */ Ap[0][n] = CNST_LIMB(1) - mpn_sub_1 (Ap[0], Ap[0], n, Ap[0][n] - CNST_LIMB(1)); if (ca) /* Ap[inc][n] can be -1 or -2 */ Ap[inc][n] = mpn_add_1 (Ap[inc], Ap[inc], n, ~Ap[inc][n] + CNST_LIMB(1)); if (Bp[0][n] > CNST_LIMB(1)) /* can be 2 or 3 */ Bp[0][n] = CNST_LIMB(1) - mpn_sub_1 (Bp[0], Bp[0], n, Bp[0][n] - CNST_LIMB(1)); if (cb) /* Bp[inc][n] can be -1 or -2 */ Bp[inc][n] = mpn_add_1 (Bp[inc], Bp[inc], n, ~Bp[inc][n] + CNST_LIMB(1)); } else { int j, inc2=2 * inc; int *lk = *ll; mp_ptr tmp; TMP_DECL(marker); TMP_MARK(marker); tmp = TMP_ALLOC_LIMBS (n + 1); mpn_fft_fft (Ap, Bp, K/2,ll-1,2 * omega,n,inc2, tp); mpn_fft_fft (Ap+inc, Bp+inc, K/2,ll-1,2 * omega,n,inc2, tp); /* A[2*j*inc] <- A[2*j*inc] + omega^l[k][2*j*inc] A[(2j+1)inc] A[(2j+1)inc] <- A[2*j*inc] + omega^l[k][(2j+1)inc] A[(2j+1)inc] */ for (j = 0; j < K / 2; j++,lk += 2,Ap += 2 * inc,Bp += 2 * inc) { MPN_COPY (tp, Ap[inc], n + 1); mpn_fft_mul_2exp_modF (Ap[inc], lk[1] * omega, n, tmp); mpn_fft_add_modF (Ap[inc], Ap[0], n); mpn_fft_mul_2exp_modF (tp, lk[0] * omega, n, tmp); mpn_fft_add_modF (Ap[0], tp, n); MPN_COPY (tp, Bp[inc], n + 1); mpn_fft_mul_2exp_modF (Bp[inc], lk[1] * omega, n, tmp); mpn_fft_add_modF (Bp[inc], Bp[0], n); mpn_fft_mul_2exp_modF (tp, lk[0] * omega, n, tmp); mpn_fft_add_modF (Bp[0], tp, n); } TMP_FREE(marker); }}/* Given ap[0..n] with ap[n]<=1, reduce it modulo 2^(n*BITS_PER_MP_LIMB)+1, by subtracting that modulus if necessary. If ap[0..n] is exactly 2^(n*BITS_PER_MP_LIMB) then the sub_1 produces a borrow and the limbs must be zeroed out again. This will occur very infrequently. */static voidmpn_fft_norm (mp_ptr ap, mp_size_t n){ ASSERT (ap[n] <= 1); if (ap[n]) { if ((ap[n] = mpn_sub_1 (ap, ap, n, CNST_LIMB(1)))) MPN_ZERO (ap, n); }}/* a[i] <- a[i]*b[i] mod 2^(n*BITS_PER_MP_LIMB)+1 for 0 <= i < K */static voidmpn_fft_mul_modF_K (mp_ptr *ap, mp_ptr *bp, mp_size_t n, int K){ int i; int sqr = (ap == bp); TMP_DECL(marker); TMP_MARK(marker); if (n >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD)) { int k, K2,nprime2,Nprime2,M2,maxLK,l,Mp2; int **_fft_l; mp_ptr *Ap,*Bp,A,B,T; k = mpn_fft_best_k (n, sqr); K2 = 1 << k; ASSERT_ALWAYS(n % K2 == 0); maxLK = (K2>BITS_PER_MP_LIMB) ? K2 : BITS_PER_MP_LIMB; M2 = n*BITS_PER_MP_LIMB/K2; l = n / K2; Nprime2 = ((2 * M2+k+2+maxLK)/maxLK)*maxLK; /* ceil((2*M2+k+3)/maxLK)*maxLK*/ nprime2 = Nprime2 / BITS_PER_MP_LIMB; /* we should ensure that nprime2 is a multiple of the next K */ if (nprime2 >= (sqr ? SQR_FFT_MODF_THRESHOLD : MUL_FFT_MODF_THRESHOLD)) { unsigned long K3; while (nprime2 % (K3 = 1 << mpn_fft_best_k (nprime2, sqr))) { nprime2 = ((nprime2 + K3 - 1) / K3) * K3; Nprime2 = nprime2 * BITS_PER_MP_LIMB; /* warning: since nprime2 changed, K3 may change too! */ } ASSERT(nprime2 % K3 == 0); } ASSERT_ALWAYS(nprime2 < n); /* otherwise we'll loop */ Mp2 = Nprime2 / K2; Ap = TMP_ALLOC_MP_PTRS (K2); Bp = TMP_ALLOC_MP_PTRS (K2); A = TMP_ALLOC_LIMBS (2 * K2 * (nprime2 + 1)); T = TMP_ALLOC_LIMBS (nprime2 + 1); B = A + K2 * (nprime2 + 1); _fft_l = TMP_ALLOC_TYPE (k + 1, int*); for (i = 0; i <= k; i++) _fft_l[i] = TMP_ALLOC_TYPE (1<<i, int); mpn_fft_initl (_fft_l, k);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -