⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mat_zz.c

📁 密码大家Shoup写的数论算法c语言实现
💻 C
📖 第 1 页 / 共 2 页
字号:
            determinant(dd, AA);               if (CRT(d, d_prod, rep(dd), P))               d_instable = 1;            else                break;         }      }      zz_p::FFTInit(i);      long p = zz_p::modulus();      mat_zz_p AA;      conv(AA, A);      if (!check) {         mat_zz_p xx;         zz_p dd;          inv(dd, xx, AA);         d_instable = CRT(d, d_prod, rep(dd), p);         if (!IsZero(dd)) {            mul(xx, xx, dd);            x_instable = CRT(x, x_prod, xx);         }         else            x_instable = 1;         if (!d_instable && !x_instable) {            mul(y, x, A);            if (IsDiag(y, n, d)) {               d1 = d;               check = 1;            }         }      }      else {         zz_p dd;         determinant(dd, AA);         d_instable = CRT(d, d_prod, rep(dd), p);      }   }   if (check && d1 != d) {      mul(x, x, d);      ExactDiv(x, d1);   }   d_out = d;   if (check) x_out = x;   zbak.restore();   Zbak.restore();}void negate(mat_ZZ& X, const mat_ZZ& A){   long n = A.NumRows();   long m = A.NumCols();   X.SetDims(n, m);   long i, j;   for (i = 1; i <= n; i++)      for (j = 1; j <= m; j++)         negate(X(i,j), A(i,j));}long IsZero(const mat_ZZ& a){   long n = a.NumRows();   long i;   for (i = 0; i < n; i++)      if (!IsZero(a[i]))         return 0;   return 1;}void clear(mat_ZZ& x){   long n = x.NumRows();   long i;   for (i = 0; i < n; i++)      clear(x[i]);}mat_ZZ operator+(const mat_ZZ& a, const mat_ZZ& b){   mat_ZZ res;   add(res, a, b);   NTL_OPT_RETURN(mat_ZZ, res);}mat_ZZ operator*(const mat_ZZ& a, const mat_ZZ& b){   mat_ZZ res;   mul_aux(res, a, b);   NTL_OPT_RETURN(mat_ZZ, res);}mat_ZZ operator-(const mat_ZZ& a, const mat_ZZ& b){   mat_ZZ res;   sub(res, a, b);   NTL_OPT_RETURN(mat_ZZ, res);}mat_ZZ operator-(const mat_ZZ& a){   mat_ZZ res;   negate(res, a);   NTL_OPT_RETURN(mat_ZZ, res);}vec_ZZ operator*(const mat_ZZ& a, const vec_ZZ& b){   vec_ZZ res;   mul_aux(res, a, b);   NTL_OPT_RETURN(vec_ZZ, 
res);}vec_ZZ operator*(const vec_ZZ& a, const mat_ZZ& b){   vec_ZZ res;   mul_aux(res, a, b);   NTL_OPT_RETURN(vec_ZZ, res);}void inv(mat_ZZ& X, const mat_ZZ& A){   ZZ d;   inv(d, X, A);   if (d == -1)      negate(X, X);   else if (d != 1)      Error("inv: non-invertible matrix");}void power(mat_ZZ& X, const mat_ZZ& A, const ZZ& e){   if (A.NumRows() != A.NumCols()) Error("power: non-square matrix");   if (e == 0) {      ident(X, A.NumRows());      return;   }   mat_ZZ T1, T2;   long i, k;   k = NumBits(e);   T1 = A;   for (i = k-2; i >= 0; i--) {      sqr(T2, T1);      if (bit(e, i))         mul(T1, T2, A);      else         T1 = T2;   }   if (e < 0)      inv(X, T1);   else      X = T1;}/***********************************************************   routines for solving a linear system vi Hensel lifting************************************************************/staticlong MaxBits(const mat_ZZ& A){   long m = 0;   long i, j;   for (i = 0; i < A.NumRows(); i++)      for (j = 0; j < A.NumCols(); j++)         m = max(m, NumBits(A[i][j]));   return m;}// Computes an upper bound on the numerators and denominators// to the solution x*A = b using Hadamard's bound and Cramer's rule. 
// If A contains a zero row, then sets both bounds to zero.
//
// Concretely, with r_i = <A[i], A[i]> (squared Euclidean length of row i):
//   den_bound = floor(sqrt(r_0 * r_1 * ... * r_{n-1}))
//               (Hadamard's bound on |det(A)|)
//   num_bound = floor(sqrt((r_0 * ... * r_{n-1}) / min_i r_i * <b, b>))
//               (by Cramer's rule, a numerator is the determinant of A with
//               one row replaced by b; dropping the shortest row gives the
//               largest Hadamard bound over all such replacements).
// Calls Error() if A has no rows.
static
void hadamard(ZZ& num_bound, ZZ& den_bound, 
              const mat_ZZ& A, const vec_ZZ& b)
{
   long n = A.NumRows();

   if (n == 0) Error("internal error: hadamard with n = 0");

   ZZ b_len, min_A_len, prod, t1;

   InnerProduct(min_A_len, A[0], A[0]);

   prod = min_A_len;

   long i;
   for (i = 1; i < n; i++) {
      InnerProduct(t1, A[i], A[i]);
      if (t1 < min_A_len)
         min_A_len = t1;
      mul(prod, prod, t1);
   }

   // A zero row makes A singular; signal that with zero bounds.
   if (min_A_len == 0) {
      num_bound = 0;
      den_bound = 0;
      return;
   }

   InnerProduct(b_len, b, b);

   // min_A_len is one of the factors of prod, so this division is exact.
   div(t1, prod, min_A_len);
   mul(t1, t1, b_len);
   SqrRoot(num_bound, t1);

   SqrRoot(den_bound, prod);
}

// Generic (arbitrary-size) product x = a*B, where the entries of the
// zz_p vector a are lifted to integers with rep().  Uses NTL's 1-based
// accessors: x(i) = sum_k B(k, i) * rep(a(k)).
// Calls Error() if a.length() != B.NumRows().
static
void MixedMul(vec_ZZ& x, const vec_zz_p& a, const mat_ZZ& B)
{
   long n = B.NumRows();
   long l = B.NumCols();

   if (n != a.length())
      Error("matrix mul: dimension mismatch");

   x.SetLength(l);

   long i, k;
   ZZ acc, tmp;

   for (i = 1; i <= l; i++) {
      clear(acc);
      for (k = 1; k <= n; k++) {
         mul(tmp, B(k, i), rep(a(k)));
         add(acc, acc, tmp);
      }
      x(i) = acc; 
   }
} 

// In place: e[i] = (e[i] - t[i]) / p, using NTL's ZZ-by-long div.
// In solve1's lifting loop t = h*A with h = e*A^{-1} mod p, so e - t is a
// multiple of p there and the division is exact.
// Calls Error() if the lengths of e and t differ.
static
void SubDiv(vec_ZZ& e, const vec_ZZ& t, long p)
{
   long n = e.length();
   if (t.length() != n) Error("SubDiv: dimension mismatch");

   ZZ s;
   long i;

   for (i = 0; i < n; i++) {
      sub(s, e[i], t[i]);
      div(e[i], s, p);
   }
}

// In place: x[i] += prod * rep(h[i]).
// Calls Error() if the lengths of x and h differ.
static
void MulAdd(vec_ZZ& x, const ZZ& prod, const vec_zz_p& h)
{
   long n = x.length();
   if (h.length() != n) Error("MulAdd: dimension mismatch");

   ZZ t;
   long i;

   for (i = 0; i < n; i++) {
      mul(t, prod, rep(h[i]));
      add(x[i], x[i], t);
   }
}

// Fast path: x[i] = sum_k B[i][k] * a[k] for an n x n array B, with the
// whole dot product accumulated in a double and then converted to ZZ.
// The result is exact only if every partial sum is exactly representable.
// solve1 selects this routine only when
// max_A_len + NTL_SP_NBITS + NumBits(n) <= NTL_DOUBLE_PRECISION-1,
// and it passes B as the transpose of A, so x = a*A.
static
void double_MixedMul1(vec_ZZ& x, double *a, double **B, long n)
{
   long i, k;
   double acc;

   for (i = 0; i < n; i++) {
      double *bp = B[i];
      acc = 0;
      for (k = 0; k < n; k++) {
         acc += bp[k] * a[k];
      }
      conv(x[i], acc);
   }
} 

// Like double_MixedMul1, but at most `limit` products are accumulated in
// the double before the partial sum is flushed into the ZZ accumulator
// acc1.  This keeps each floating-point partial sum small when a full
// length-n dot product could exceed the double's exact range.
static
void double_MixedMul2(vec_ZZ& x, double *a, double **B, long n, long limit)
{
   long i, k;
   double acc;
   ZZ acc1, t;
   long j;

   for (i = 0; i < n; i++) {
      double *bp = B[i];
      clear(acc1);
      acc = 0;
      j = 0;

      for (k = 0; k < n; k++) {
         acc += bp[k] * a[k];
         j++;
         // Flush after `limit` terms, before the double sum can grow too large.
         if (j == limit) {
            conv(t, acc);
            add(acc1, acc1, t);
            acc = 0;
            j = 0;
         }
      }

      // Flush the remaining (< limit) terms.
      if (j > 0) {
         conv(t, acc);
         add(acc1, acc1, t);
      }

      x[i] = acc1;
   }
} 

// Fast path using machine longs: x[i] = sum_k B[i][k] * a[k], with the
// whole dot product accumulated in a long.  solve1 selects this routine
// only when max_A_len + NTL_SP_NBITS + NumBits(n) <= NTL_BITS_PER_LONG-1,
// so the signed accumulation cannot overflow.
static
void long_MixedMul1(vec_ZZ& x, long *a, long **B, long n)
{
   long i, k;
   long acc;

   for (i = 0; i < n; i++) {
      long *bp = B[i];
      acc = 0;
      for (k = 0; k < n; k++) {
         acc += bp[k] * a[k];
      }
      conv(x[i], acc);
   }
} 

// Like long_MixedMul1, but at most `limit` products are accumulated in the
// long before the partial sum is flushed into the ZZ accumulator acc1.
// This avoids signed overflow when a full length-n dot product could
// exceed the range of a long.
static
void long_MixedMul2(vec_ZZ& x, long *a, long **B, long n, long limit)
{
   long i, k;
   long acc;
   ZZ acc1, t;
   long j;

   for (i = 0; i < n; i++) {
      long *bp = B[i];
      clear(acc1);
      acc = 0;
      j = 0;

      for (k = 0; k < n; k++) {
         acc += bp[k] * a[k];
         j++;
         // Flush after `limit` terms, before the long sum can overflow.
         if (j == limit) {
            conv(t, acc);
            add(acc1, acc1, t);
            acc = 0;
            j = 0;
         }
      }

      // Flush the remaining (< limit) terms.
      if (j > 0) {
         conv(t, acc);
         add(acc1, acc1, t);
      }

      x[i] = acc1;
   }
} 

// solve1: solves the system x*A = b for a square integer matrix A.
// It lifts a solution p-adically (Hensel lifting) modulo a single prime p
// for which A is invertible, then recovers the rational solution by
// rational reconstruction.
//
// On return:
//   - n == 0: d_out = 1 and x_out is empty.
//   - A singular (A has a zero row, or det(A) vanishes modulo primes whose
//     product exceeds the Hadamard bound on |det(A)|): d_out = 0 and
//     x_out is left unchanged.
//   - otherwise: x_out / d_out is the rational solution, i.e.
//     x_out * A = d_out * b, with d_out the product of the reconstructed
//     denominators.
// Calls Error() if A is not square, if b.length() != A.NumRows(), or if
// rational reconstruction fails (reported as an internal error).
void solve1(ZZ& d_out, vec_ZZ& x_out, const mat_ZZ& A, const vec_ZZ& b)
{
   long n = A.NumRows();

   if (A.NumCols() != n)
      Error("solve1: nonsquare matrix");

   if (b.length() != n)
      Error("solve1: dimension mismatch");

   if (n == 0) {
      set(d_out);
      x_out.SetLength(0);
      return;
   }

   ZZ num_bound, den_bound;

   hadamard(num_bound, den_bound, A, b);

   // Zero bounds mean A has a zero row, so it is singular.
   if (den_bound == 0) {
      clear(d_out);
      return;
   }

   // zz_p::FFTInit below changes the current zz_p modulus, so save it first.
   // There is no explicit zbak.restore() in this function. This relies on
   // zz_pBak's destructor restoring the saved modulus on every return path.
   // NOTE(review): confirm that behavior in lzz_p.h.
   zz_pBak zbak;
   zbak.save();

   long i;
   long j;

   ZZ prod;
   prod = 1;

   mat_zz_p B;


   // Find a prime p (the one installed by zz_p::FFTInit(i)) for which A is
   // invertible mod p.  inv() returns det(A) mod p in dd.  Each prime that
   // divides det(A) is multiplied into prod.  Once prod exceeds
   // den_bound >= |det(A)|, det(A) must be 0 (assuming the FFTInit primes
   // are distinct), so A is singular.
   for (i = 0; ; i++) {
      zz_p::FFTInit(i);

      mat_zz_p AA, BB;
      zz_p dd;

      conv(AA, A);
      inv(dd, BB, AA);

      if (dd != 0) {
         // B = transpose(A^{-1} mod p), so mul(h, B, ee) below yields
         // h = ee * A^{-1} mod p (the row-vector system x*A = b).
         transpose(B, BB);
         break;
      }

      mul(prod, prod, zz_p::modulus());
      
      if (prod > den_bound) {
         d_out = 0;
         return;
      }
   }

   // Choose how to compute t = h*A in the lifting loop.  Each product
   // A[k][i] * rep(h[k]) is below 2^(max_A_len + NTL_SP_NBITS) in magnitude.
   // NOTE(review): this assumes residues mod p fit in NTL_SP_NBITS bits;
   // confirm for the FFTInit primes.
   //  - *_mul1: a full n-term dot product stays below
   //    2^(max_A_len + NTL_SP_NBITS + NumBits(n)), so it is exact when that
   //    exponent is at most PRECISION-1.
   //  - *_mul2: at most limit = 2^(PRECISION-1-max_A_len-NTL_SP_NBITS) terms
   //    are summed before flushing to a ZZ, keeping partial sums below
   //    2^(PRECISION-1).  The "+2" in the test guarantees limit >= 4.
   // Here PRECISION is NTL_DOUBLE_PRECISION for doubles and
   // NTL_BITS_PER_LONG for longs.
   long max_A_len = MaxBits(A);

   long use_double_mul1 = 0;
   long use_double_mul2 = 0;
   long double_limit = 0;

   if (max_A_len + NTL_SP_NBITS + NumBits(n) <= NTL_DOUBLE_PRECISION-1)
      use_double_mul1 = 1;

   if (!use_double_mul1 && max_A_len+NTL_SP_NBITS+2 <= NTL_DOUBLE_PRECISION-1) {
      use_double_mul2 = 1;
      double_limit = (1L << (NTL_DOUBLE_PRECISION-1-max_A_len-NTL_SP_NBITS));
   }

   long use_long_mul1 = 0;
   long use_long_mul2 = 0;
   long long_limit = 0;

   if (max_A_len + NTL_SP_NBITS + NumBits(n) <= NTL_BITS_PER_LONG-1)
      use_long_mul1 = 1;

   if (!use_long_mul1 && max_A_len+NTL_SP_NBITS+2 <= NTL_BITS_PER_LONG-1) {
      use_long_mul2 = 1;
      long_limit = (1L << (NTL_BITS_PER_LONG-1-max_A_len-NTL_SP_NBITS));
   }

   // Keep at most one strategy:
   //  - double_mul1 wins over long_mul1;
   //  - any *_mul1 wins over any *_mul2;
   //  - between two *_mul2 strategies, the larger flush limit wins
   //    (a tie keeps double).
   // If no flag survives, the generic ZZ MixedMul is used.
   if (use_double_mul1 && use_long_mul1)
      use_long_mul1 = 0;
   else if (use_double_mul1 && use_long_mul2)
      use_long_mul2 = 0;
   else if (use_double_mul2 && use_long_mul1)
      use_double_mul2 = 0;
   else if (use_double_mul2 && use_long_mul2) {
      if (long_limit > double_limit)
         use_double_mul2 = 0;
      else
         use_long_mul2 = 0;
   }

   // NOTE(review): these pointers (and long_A/long_h below) stay
   // uninitialized when their strategy is not selected.  Every later use is
   // guarded by the same flags, but initializing them to 0 would silence
   // "maybe uninitialized" compiler warnings.
   double **double_A;
   double *double_h;

   typedef double *double_ptr;

   if (use_double_mul1 || use_double_mul2) {
      double_h = NTL_NEW_OP double[n];
      double_A = NTL_NEW_OP double_ptr[n];
      if (!double_h || !double_A) Error("solve1: out of mem");

      for (i = 0; i < n; i++) {
         double_A[i] = NTL_NEW_OP double[n];
         if (!double_A[i]) Error("solve1: out of mem");
      }

      // Store A transposed: row i of double_A is column i of A, so the
      // inner loop of double_MixedMul* walks contiguous memory.
      for (i = 0; i < n; i++)
         for (j = 0; j < n; j++)
            double_A[j][i] = to_double(A[i][j]);
   }

   long **long_A;
   long *long_h;

   typedef long *long_ptr;

   if (use_long_mul1 || use_long_mul2) {
      long_h = NTL_NEW_OP long[n];
      long_A = NTL_NEW_OP long_ptr[n];
      if (!long_h || !long_A) Error("solve1: out of mem");

      for (i = 0; i < n; i++) {
         long_A[i] = NTL_NEW_OP long[n];
         if (!long_A[i]) Error("solve1: out of mem");
      }

      // Transposed copy of A, as for double_A above.
      for (i = 0; i < n; i++)
         for (j = 0; j < n; j++)
            long_A[j][i] = to_long(A[i][j]);
   }

   vec_ZZ x;
   x.SetLength(n);

   vec_zz_p h;
   h.SetLength(n);

   vec_ZZ e;
   e = b;

   vec_zz_p ee;

   vec_ZZ t;
   t.SetLength(n);

   prod = 1;

   // Lift until prod > 2 * num_bound * den_bound, the modulus size that
   // rational reconstruction is given below.
   ZZ bound1;
   mul(bound1, num_bound, den_bound);
   mul(bound1, bound1, 2);

   // Hensel lifting.  Starting from x = 0, e = b, prod = 1, each step keeps
   // the invariant x*A + prod*e = b over the integers, so that
   // x*A == b (mod prod).  prod is always a power of the single prime
   // p = zz_p::modulus().
   while (prod <= bound1) {
      conv(ee, e);

      // h = e * A^{-1} mod p
      mul(h, B, ee);

      if (use_double_mul1) {
         for (i = 0; i < n; i++)
            double_h[i] = to_double(rep(h[i]));

         double_MixedMul1(t, double_h, double_A, n);
      }
      else if (use_double_mul2) {
         for (i = 0; i < n; i++)
            double_h[i] = to_double(rep(h[i]));

         double_MixedMul2(t, double_h, double_A, n, double_limit);
      }
      else if (use_long_mul1) {
         for (i = 0; i < n; i++)
            long_h[i] = to_long(rep(h[i]));

         long_MixedMul1(t, long_h, long_A, n);
      }
      else if (use_long_mul2) {
         for (i = 0; i < n; i++)
            long_h[i] = to_long(rep(h[i]));

         long_MixedMul2(t, long_h, long_A, n, long_limit);
      }
      else
         MixedMul(t, h, A); // t = h*A

      SubDiv(e, t, zz_p::modulus()); // e = (e-t)/p
      MulAdd(x, prod, h);  // x = x + prod*h

      mul(prod, prod, zz_p::modulus());
   }

   // Rational reconstruction, one coordinate at a time.  d accumulates the
   // product of the denominators found so far.  Each later coordinate is
   // multiplied by that product (mod prod) first, so its remaining
   // denominator is smaller.
   vec_ZZ num, denom;
   ZZ d, d_mod_prod, tmp1;

   num.SetLength(n);
   denom.SetLength(n);
 
   d = 1;
   d_mod_prod = 1;

   for (i = 0; i < n; i++) {
      // prod may have been reduced in an earlier iteration.
      rem(x[i], x[i], prod);
      MulMod(x[i], x[i], d_mod_prod, prod);

      if (!ReconstructRational(num[i], denom[i], x[i], prod, 
           num_bound, den_bound))
          Error("solve1 internal error: rat recon failed!");

      mul(d, d, denom[i]);

      if (i != n-1) {
         if (denom[i] != 1) {
            // The remaining denominators are bounded by den_bound/denom[i].
            // Shrink the bound, then drop factors of p from prod while
            // prod/p still exceeds 2*num_bound*den_bound.  Because prod is
            // a power of p, reducing x[] modulo the smaller prod stays
            // consistent.
            div(den_bound, den_bound, denom[i]); 
            mul(bound1, num_bound, den_bound);
            mul(bound1, bound1, 2);

            div(tmp1, prod, zz_p::modulus());
            while (tmp1 > bound1) {
               prod = tmp1;
               div(tmp1, prod, zz_p::modulus());
            }

            rem(tmp1, denom[i], prod);
            rem(d_mod_prod, d_mod_prod, prod);
            MulMod(d_mod_prod, d_mod_prod, tmp1, prod);
         }
      }
   }

   // Put every coordinate over the common denominator
   // d = denom[0] * ... * denom[n-1]:
   //    num[i] *= denom[i+1] * ... * denom[n-1].
   tmp1 = 1;
   for (i = n-1; i >= 0; i--) {
      mul(num[i], num[i], tmp1);
      mul(tmp1, tmp1, denom[i]);
   }
   
   x_out.SetLength(n);

   for (i = 0; i < n; i++) {
      x_out[i] = num[i];
   }

   d_out = d;

   // Release the fast-path arrays.
   // NOTE(review): if Error() can return or throw in this build, the Error
   // calls between allocation and this point would leak them.
   if (use_double_mul1 || use_double_mul2) {
      delete [] double_h;
      for (i = 0; i < n; i++) {
         delete [] double_A[i];
      }

      delete [] double_A;
   }

   if (use_long_mul1 || use_long_mul2) {
      delete [] long_h;
      for (i = 0; i < n; i++) {
         delete [] long_A[i];
      }

      delete [] long_A;
   }
}

NTL_END_IMPL

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -