⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mat_zz.cpp

📁 数值算法库for Windows
💻 CPP
📖 第 1 页 / 共 2 页
字号:
            determinant(dd, AA);
   
            if (CRT(d, d_prod, rep(dd), P))
               d_instable = 1;
            else 
               break;
         }
      }


      zz_p::FFTInit(i);
      long p = zz_p::modulus();

      mat_zz_p AA;
      conv(AA, A);

      if (!check) {
         mat_zz_p xx;

         zz_p dd; 

         inv(dd, xx, AA);

         d_instable = CRT(d, d_prod, rep(dd), p);
         if (!IsZero(dd)) {
            mul(xx, xx, dd);
            x_instable = CRT(x, x_prod, xx);
         }
         else
            x_instable = 1;

         if (!d_instable && !x_instable) {
            mul(y, x, A);
            if (IsDiag(y, n, d)) {
               d1 = d;
               check = 1;
            }
         }
      }
      else {
         zz_p dd;
         determinant(dd, AA);
         d_instable = CRT(d, d_prod, rep(dd), p);
      }
   }

   if (check && d1 != d) {
      mul(x, x, d);
      ExactDiv(x, d1);
   }

   d_out = d;
   if (check) x_out = x;

   zbak.restore();
   Zbak.restore();
}

// Sets X to the entrywise negation of A (X is resized to A's dimensions).
void negate(mat_ZZ& X, const mat_ZZ& A)
{
   const long rows = A.NumRows();
   const long cols = A.NumCols();

   X.SetDims(rows, cols);

   for (long r = 0; r < rows; r++) {
      for (long c = 0; c < cols; c++) {
         negate(X[r][c], A[r][c]);
      }
   }
}



// Returns 1 if every row of a is the zero vector (vacuously true for a
// matrix with no rows), and 0 as soon as a nonzero row is found.
long IsZero(const mat_ZZ& a)
{
   const long rows = a.NumRows();

   for (long r = 0; r < rows; r++) {
      if (!IsZero(a[r])) return 0;
   }

   return 1;
}

// Zeroes every entry of x in place, row by row; dimensions are unchanged.
void clear(mat_ZZ& x)
{
   const long rows = x.NumRows();

   for (long r = 0; r < rows; r++) {
      clear(x[r]);
   }
}


// Returns the matrix sum a + b as a new matrix.
mat_ZZ operator+(const mat_ZZ& a, const mat_ZZ& b)
{
   mat_ZZ sum;
   add(sum, a, b);
   NTL_OPT_RETURN(mat_ZZ, sum);
}

// Returns the matrix product a * b as a new matrix.  The result lives in a
// fresh local that cannot alias a or b, presumably why the lower-level
// mul_aux is called here.
mat_ZZ operator*(const mat_ZZ& a, const mat_ZZ& b)
{
   mat_ZZ product;
   mul_aux(product, a, b);
   NTL_OPT_RETURN(mat_ZZ, product);
}

// Returns the matrix difference a - b as a new matrix.
mat_ZZ operator-(const mat_ZZ& a, const mat_ZZ& b)
{
   mat_ZZ diff;
   sub(diff, a, b);
   NTL_OPT_RETURN(mat_ZZ, diff);
}


// Returns the entrywise negation -a as a new matrix.
mat_ZZ operator-(const mat_ZZ& a)
{
   mat_ZZ negated;
   negate(negated, a);
   NTL_OPT_RETURN(mat_ZZ, negated);
}

// Returns the matrix-vector product a * b as a new vector.
vec_ZZ operator*(const mat_ZZ& a, const vec_ZZ& b)
{
   vec_ZZ product;
   mul_aux(product, a, b);
   NTL_OPT_RETURN(vec_ZZ, product);
}

// Returns the vector-matrix product a * b as a new vector.
vec_ZZ operator*(const vec_ZZ& a, const mat_ZZ& b)
{
   vec_ZZ product;
   mul_aux(product, a, b);
   NTL_OPT_RETURN(vec_ZZ, product);
}




// Sets X = A^{-1} for an integer matrix A whose determinant is +1 or -1.
// Any other determinant makes A non-invertible over the integers, and
// Error() is called.
void inv(mat_ZZ& X, const mat_ZZ& A)
{
   ZZ det_A;
   inv(det_A, X, A);   // presumably X = det(A) * A^{-1} here

   if (det_A == 1)
      return;

   if (det_A == -1) {
      // X holds -A^{-1}; flip the sign to get the true inverse.
      negate(X, X);
      return;
   }

   Error("inv: non-invertible matrix");
}

// Sets X = A^e for a square matrix A using left-to-right binary
// exponentiation.  e == 0 yields the identity.  For e < 0 the accumulated
// power is inverted with inv(), which calls Error() unless that power has
// determinant +1 or -1.
// X may alias A: A is only read until the final assignment to X.
void power(mat_ZZ& X, const mat_ZZ& A, const ZZ& e)
{
   if (A.NumRows() != A.NumCols()) Error("power: non-square matrix");

   if (e == 0) {
      ident(X, A.NumRows());
      return;
   }

   mat_ZZ acc, sq;

   // The leading bit of e is accounted for by starting from A itself.
   acc = A;

   for (long pos = NumBits(e) - 2; pos >= 0; pos--) {
      sqr(sq, acc);
      if (bit(e, pos))
         mul(acc, sq, A);
      else
         acc = sq;
   }

   if (e > 0)
      X = acc;
   else
      inv(X, acc);
}



/***********************************************************

   routines for solving a linear system via Hensel lifting

************************************************************/


// Returns the largest NumBits() over all entries of A (0 for an empty
// matrix or an all-zero matrix).
static
long MaxBits(const mat_ZZ& A)
{
   long widest = 0;

   for (long r = 0; r < A.NumRows(); r++) {
      for (long c = 0; c < A.NumCols(); c++) {
         const long nb = NumBits(A[r][c]);
         if (nb > widest)
            widest = nb;
      }
   }

   return widest;
}




// Computes upper bounds on the numerators and denominator of the rational
// solution to x*A = b (A is n x n with n >= 1), using Hadamard's
// inequality together with Cramer's rule:
//   den_bound = floor(sqrt( prod_i |A_i|^2 ))
//   num_bound = floor(sqrt( prod_i |A_i|^2 / min_i |A_i|^2 * |b|^2 ))
// where A_i are the rows of A.  If A has a zero row, both bounds are set
// to zero.  Calls Error() if A has no rows.
static
void hadamard(ZZ& num_bound, ZZ& den_bound, 
              const mat_ZZ& A, const vec_ZZ& b)
{
   const long rows = A.NumRows();

   if (rows == 0) Error("internal error: hadamard with n = 0");

   ZZ row_sq, shortest_sq, sq_prod;

   // Row 0 seeds both the running product and the running minimum.
   InnerProduct(shortest_sq, A[0], A[0]);
   sq_prod = shortest_sq;

   for (long r = 1; r < rows; r++) {
      InnerProduct(row_sq, A[r], A[r]);
      mul(sq_prod, sq_prod, row_sq);
      if (row_sq < shortest_sq)
         shortest_sq = row_sq;
   }

   if (shortest_sq == 0) {
      clear(num_bound);
      clear(den_bound);
      return;
   }

   ZZ b_sq, scaled;
   InnerProduct(b_sq, b, b);

   // Exact division: shortest_sq is one of the factors of sq_prod.
   div(scaled, sq_prod, shortest_sq);
   mul(scaled, scaled, b_sq);

   SqrRoot(num_bound, scaled);
   SqrRoot(den_bound, sq_prod);
}


// Computes x = a * B, where a is a vector of residues mod the current
// zz_p modulus and B is an integer matrix.  The representatives rep(a[k])
// are multiplied exactly over ZZ, with no reduction.  Calls Error() if
// a.length() differs from B.NumRows().
static
void MixedMul(vec_ZZ& x, const vec_zz_p& a, const mat_ZZ& B)
{
   const long rows = B.NumRows();
   const long cols = B.NumCols();

   if (rows != a.length())
      Error("matrix mul: dimension mismatch");

   x.SetLength(cols);

   ZZ sum, term;

   for (long c = 0; c < cols; c++) {
      clear(sum);
      for (long r = 0; r < rows; r++) {
         mul(term, B[r][c], rep(a[r]));
         add(sum, sum, term);
      }
      x[c] = sum;
   }
} 

// In place: e[k] = (e[k] - t[k]) / p for every k, using ZZ floor division.
// In solve1 the difference is always a multiple of p, so the division is
// exact there.  Calls Error() on a length mismatch.
static
void SubDiv(vec_ZZ& e, const vec_ZZ& t, long p)
{
   const long len = e.length();
   if (t.length() != len) Error("SubDiv: dimension mismatch");

   ZZ diff;

   for (long k = 0; k < len; k++) {
      sub(diff, e[k], t[k]);
      div(e[k], diff, p);
   }
}

// In place: x[k] += prod * rep(h[k]) for every k, i.e. appends the next
// p-adic digit vector h, scaled by prod.  Calls Error() on a length mismatch.
static
void MulAdd(vec_ZZ& x, const ZZ& prod, const vec_zz_p& h)
{
   const long len = x.length();
   if (h.length() != len) Error("MulAdd: dimension mismatch");

   ZZ scaled;

   for (long k = 0; k < len; k++) {
      mul(scaled, prod, rep(h[k]));
      add(x[k], x[k], scaled);
   }
}


// For each row i of the n x n array B: x[i] = sum_k B[i][k] * a[k],
// accumulated in a single double and then converted to ZZ.
// The caller must ensure every full dot product is exactly representable in
// a double; solve1 checks this through its use_double_mul1 bit-size test.
// x must already have length >= n.
static
void double_MixedMul1(vec_ZZ& x, double *a, double **B, long n)
{
   for (long row = 0; row < n; row++) {
      const double *coeffs = B[row];
      double dot = 0;

      for (long k = 0; k < n; k++)
         dot += coeffs[k] * a[k];

      conv(x[row], dot);
   }
} 


// Like double_MixedMul1, but each dot product is summed in chunks of at most
// `limit` terms.  Each chunk's double partial sum is flushed into a ZZ
// accumulator.  solve1 derives `limit` from the operand bit sizes so
// that every chunk stays exact in a double.  x must already have length >= n.
static
void double_MixedMul2(vec_ZZ& x, double *a, double **B, long n, long limit)
{
   ZZ total, chunk;

   for (long row = 0; row < n; row++) {
      const double *coeffs = B[row];

      clear(total);
      double partial = 0;
      long terms = 0;

      for (long k = 0; k < n; k++) {
         partial += coeffs[k] * a[k];
         if (++terms == limit) {
            conv(chunk, partial);
            add(total, total, chunk);
            partial = 0;
            terms = 0;
         }
      }

      // Flush the final, partially filled chunk (if any).
      if (terms != 0) {
         conv(chunk, partial);
         add(total, total, chunk);
      }

      x[row] = total;
   }
} 


// For each row i of the n x n array B: x[i] = sum_k B[i][k] * a[k],
// accumulated in a single long and then converted to ZZ.
// The caller must ensure no partial sum overflows a long; solve1 checks
// this through its use_long_mul1 bit-size test.  x must already have
// length >= n.
static
void long_MixedMul1(vec_ZZ& x, long *a, long **B, long n)
{
   for (long row = 0; row < n; row++) {
      const long *coeffs = B[row];
      long dot = 0;

      for (long k = 0; k < n; k++)
         dot += coeffs[k] * a[k];

      conv(x[row], dot);
   }
} 


// Like long_MixedMul1, but each dot product is summed in chunks of at most
// `limit` terms.  Each chunk's long partial sum is flushed into a ZZ
// accumulator.  solve1 derives `limit` from the operand bit sizes so
// that no chunk overflows a long.  x must already have length >= n.
static
void long_MixedMul2(vec_ZZ& x, long *a, long **B, long n, long limit)
{
   ZZ total, chunk;

   for (long row = 0; row < n; row++) {
      const long *coeffs = B[row];

      clear(total);
      long partial = 0;
      long terms = 0;

      for (long k = 0; k < n; k++) {
         partial += coeffs[k] * a[k];
         if (++terms == limit) {
            conv(chunk, partial);
            add(total, total, chunk);
            partial = 0;
            terms = 0;
         }
      }

      // Flush the final, partially filled chunk (if any).
      if (terms != 0) {
         conv(chunk, partial);
         add(total, total, chunk);
      }

      x[row] = total;
   }
} 


// solve1: solves the integer linear system x*A = b for a square n x n A.
//
// Outcomes:
//   * n == 0: d_out = 1 and x_out is empty.
//   * A detected singular: d_out = 0 and x_out is left unchanged.
//   * otherwise: d_out is the product of the denominators recovered by
//     rational reconstruction, and x_out*A == d_out*b.
// Calls Error() if A is not square or b.length() != n.
//
// Method: pick an FFT prime p with det(A) != 0 mod p, then lift a p-adic
// solution (Dixon/Hensel lifting) until the modulus exceeds
// 2*num_bound*den_bound (bounds from hadamard).  Each coordinate is then
// recovered by rational reconstruction, accumulating a common denominator.
void solve1(ZZ& d_out, vec_ZZ& x_out, const mat_ZZ& A, const vec_ZZ& b)
{
   long n = A.NumRows();

   if (A.NumCols() != n)
      Error("solve1: nonsquare matrix");

   if (b.length() != n)
      Error("solve1: dimension mismatch");

   if (n == 0) {
      set(d_out);
      x_out.SetLength(0);
      return;
   }

   ZZ num_bound, den_bound;

   // Bounds on the numerators / common denominator of the rational solution.
   hadamard(num_bound, den_bound, A, b);

   // hadamard signals a zero row of A (so det(A) = 0) with den_bound == 0.
   if (den_bound == 0) {
      clear(d_out);
      return;
   }

   // Save the caller's zz_p modulus, which FFTInit below changes.
   // NOTE(review): no explicit restore() is called, so this relies on
   // zz_pBak's destructor restoring the modulus on every return path.
   zz_pBak zbak;
   zbak.save();

   long i;
   long j;

   ZZ prod;
   prod = 1;

   mat_zz_p B;

   // Find an FFT prime p for which A is invertible mod p, and keep
   // B = transpose(A^{-1} mod p).  Each rejected prime divides det(A).
   // Once their product exceeds den_bound (a bound on |det(A)|), det(A)
   // must be 0.
   for (i = 0; ; i++) {
      zz_p::FFTInit(i);

      mat_zz_p AA, BB;
      zz_p dd;

      conv(AA, A);
      inv(dd, BB, AA);

      if (dd != 0) {
         transpose(B, BB);
         break;
      }

      mul(prod, prod, zz_p::modulus());
      
      if (prod > den_bound) {
         d_out = 0;
         return;
      }
   }

   // Choose the kernel for t = h*A in the lifting loop.  Each product of an
   // entry of A (at most max_A_len bits) and a residue h[k] (presumably
   // below 2^NTL_SP_NBITS) has at most max_A_len+NTL_SP_NBITS bits.
   //   *_mul1: a whole length-n dot product fits exactly in a double/long;
   //   *_mul2: only chunks of *_limit terms fit, so partial sums are
   //           flushed into a ZZ every *_limit terms;
   //   neither: fall back to the all-ZZ MixedMul.
   long max_A_len = MaxBits(A);

   long use_double_mul1 = 0;
   long use_double_mul2 = 0;
   long double_limit = 0;

   if (max_A_len + NTL_SP_NBITS + NumBits(n) <= NTL_DOUBLE_PRECISION-1)
      use_double_mul1 = 1;

   if (!use_double_mul1 && max_A_len+NTL_SP_NBITS+2 <= NTL_DOUBLE_PRECISION-1) {
      use_double_mul2 = 1;
      double_limit = (1L << (NTL_DOUBLE_PRECISION-1-max_A_len-NTL_SP_NBITS));
   }

   long use_long_mul1 = 0;
   long use_long_mul2 = 0;
   long long_limit = 0;

   if (max_A_len + NTL_SP_NBITS + NumBits(n) <= NTL_BITS_PER_LONG-1)
      use_long_mul1 = 1;

   if (!use_long_mul1 && max_A_len+NTL_SP_NBITS+2 <= NTL_BITS_PER_LONG-1) {
      use_long_mul2 = 1;
      long_limit = (1L << (NTL_BITS_PER_LONG-1-max_A_len-NTL_SP_NBITS));
   }

   // Keep at most one kernel.  An unchunked (*1) kernel beats a chunked (*2)
   // one.  Between two *1 kernels, double wins.  Between two *2 kernels, the
   // larger chunk size wins, with double taking a tie.
   if (use_double_mul1 && use_long_mul1)
      use_long_mul1 = 0;
   else if (use_double_mul1 && use_long_mul2)
      use_long_mul2 = 0;
   else if (use_double_mul2 && use_long_mul1)
      use_double_mul2 = 0;
   else if (use_double_mul2 && use_long_mul2) {
      if (long_limit > double_limit)
         use_double_mul2 = 0;
      else
         use_long_mul2 = 0;
   }

   // Machine-word copies of A, stored transposed (copy[j][i] = A[i][j]) so
   // that row i of the copy holds column i of A.  The pointers start out
   // null, so they are never indeterminate on paths that skip allocation.
   double **double_A = 0;
   double *double_h = 0;

   typedef double *double_ptr;

   if (use_double_mul1 || use_double_mul2) {
      double_h = NTL_NEW_OP double[n];
      double_A = NTL_NEW_OP double_ptr[n];
      if (!double_h || !double_A) Error("solve1: out of mem");

      for (i = 0; i < n; i++) {
         double_A[i] = NTL_NEW_OP double[n];
         if (!double_A[i]) Error("solve1: out of mem");
      }

      for (i = 0; i < n; i++)
         for (j = 0; j < n; j++)
            double_A[j][i] = to_double(A[i][j]);
   }

   long **long_A = 0;
   long *long_h = 0;

   typedef long *long_ptr;

   if (use_long_mul1 || use_long_mul2) {
      long_h = NTL_NEW_OP long[n];
      long_A = NTL_NEW_OP long_ptr[n];
      if (!long_h || !long_A) Error("solve1: out of mem");

      for (i = 0; i < n; i++) {
         long_A[i] = NTL_NEW_OP long[n];
         if (!long_A[i]) Error("solve1: out of mem");
      }

      for (i = 0; i < n; i++)
         for (j = 0; j < n; j++)
            long_A[j][i] = to_long(A[i][j]);
   }


   vec_ZZ x;
   x.SetLength(n);

   vec_zz_p h;
   h.SetLength(n);

   vec_ZZ e;
   e = b;

   vec_zz_p ee;

   vec_ZZ t;
   t.SetLength(n);

   prod = 1;

   // Rational reconstruction needs a modulus above 2 * num_bound * den_bound.
   ZZ bound1;
   mul(bound1, num_bound, den_bound);
   mul(bound1, bound1, 2);

   // p-adic lifting.  Invariant over the integers: x*A + prod*e == b.
   // Each step takes the next digit h = e*A^{-1} mod p, then updates
   // e <- (e - h*A)/p (exact), x <- x + prod*h and prod <- prod*p.
   while (prod <= bound1) {
      conv(ee, e);

      mul(h, B, ee);

      if (use_double_mul1) {
         for (i = 0; i < n; i++)
            double_h[i] = to_double(rep(h[i]));

         double_MixedMul1(t, double_h, double_A, n);
      }
      else if (use_double_mul2) {
         for (i = 0; i < n; i++)
            double_h[i] = to_double(rep(h[i]));

         double_MixedMul2(t, double_h, double_A, n, double_limit);
      }
      else if (use_long_mul1) {
         for (i = 0; i < n; i++)
            long_h[i] = to_long(rep(h[i]));

         long_MixedMul1(t, long_h, long_A, n);
      }
      else if (use_long_mul2) {
         for (i = 0; i < n; i++)
            long_h[i] = to_long(rep(h[i]));

         long_MixedMul2(t, long_h, long_A, n, long_limit);
      }
      else
         MixedMul(t, h, A); // t = h*A

      SubDiv(e, t, zz_p::modulus()); // e = (e-t)/p
      MulAdd(x, prod, h);  // x = x + prod*h

      mul(prod, prod, zz_p::modulus());
   }

   vec_ZZ num, denom;
   ZZ d, d_mod_prod, tmp1;

   num.SetLength(n);
   denom.SetLength(n);
 
   d = 1;
   d_mod_prod = 1;

   // Recover the solution coordinate by coordinate.  x[i] is first scaled
   // by d_mod_prod, the product of the denominators found so far taken
   // mod prod.  The reconstruction num[i]/denom[i] then only has to supply
   // the denominator factor that is still missing.  After each
   // nontrivial denominator:
   //   - den_bound is divided by it;
   //   - prod is cut back by factors of p while it still exceeds the new
   //     2*num_bound*den_bound (presumably to make later reconstructions
   //     cheaper).
   for (i = 0; i < n; i++) {
      rem(x[i], x[i], prod);
      MulMod(x[i], x[i], d_mod_prod, prod);

      if (!ReconstructRational(num[i], denom[i], x[i], prod, 
           num_bound, den_bound))
          Error("solve1 internal error: rat recon failed!");

      mul(d, d, denom[i]);

      if (i != n-1) {
         if (denom[i] != 1) {
            div(den_bound, den_bound, denom[i]); 
            mul(bound1, num_bound, den_bound);
            mul(bound1, bound1, 2);

            div(tmp1, prod, zz_p::modulus());
            while (tmp1 > bound1) {
               prod = tmp1;
               div(tmp1, prod, zz_p::modulus());
            }

            rem(tmp1, denom[i], prod);
            rem(d_mod_prod, d_mod_prod, prod);
            MulMod(d_mod_prod, d_mod_prod, tmp1, prod);
         }
      }
   }

   // Multiply each numerator by the denominators of the later coordinates.
   // This gives x_out[i] = x_i * d, where d is the product of all denominators.
   tmp1 = 1;
   for (i = n-1; i >= 0; i--) {
      mul(num[i], num[i], tmp1);
      mul(tmp1, tmp1, denom[i]);
   }
   
   x_out.SetLength(n);

   for (i = 0; i < n; i++) {
      x_out[i] = num[i];
   }

   d_out = d;

   if (use_double_mul1 || use_double_mul2) {
      delete [] double_h;

      for (i = 0; i < n; i++) {
         delete [] double_A[i];
      }

      delete [] double_A;
   }

   if (use_long_mul1 || use_long_mul2) {
      delete [] long_h;

      for (i = 0; i < n; i++) {
         delete [] long_A[i];
      }

      delete [] long_A;
   }

}

NTL_END_IMPL

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -