⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mat_zz.cpp

📁 数值算法库for Windows
💻 CPP
📖 第 1 页 / 共 2 页
字号:

#include <NTL/mat_ZZ.h>

#include <NTL/new.h>

NTL_START_IMPL

// Instantiate NTL's generic matrix machinery for ZZ entries (rows are
// vec_ZZ, storage is vec_vec_ZZ, matrix type is mat_ZZ).
// NOTE(review): judging by the macro names these provide the core mat_ZZ
// implementation, stream I/O, and equality operators; the exact contents
// live in NTL's matrix header, which is not visible here.
NTL_matrix_impl(ZZ,vec_ZZ,vec_vec_ZZ,mat_ZZ)
NTL_io_matrix_impl(ZZ,vec_ZZ,vec_vec_ZZ,mat_ZZ)
NTL_eq_matrix_impl(ZZ,vec_ZZ,vec_vec_ZZ,mat_ZZ)



// Entry-wise matrix sum: X = A + B.
// A and B must have the same shape, otherwise Error() is raised with
// "matrix add: dimension mismatch". X is resized to that shape.
void add(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
{
   const long rows = A.NumRows();
   const long cols = A.NumCols();

   if (B.NumRows() != rows || B.NumCols() != cols)
      Error("matrix add: dimension mismatch");

   X.SetDims(rows, cols);

   for (long r = 1; r <= rows; r++) {
      for (long c = 1; c <= cols; c++) {
         add(X(r, c), A(r, c), B(r, c));
      }
   }
}
  
// Entry-wise matrix difference: X = A - B.
// A and B must have the same shape, otherwise Error() is raised with
// "matrix sub: dimension mismatch". X is resized to that shape.
void sub(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
{
   const long rows = A.NumRows();
   const long cols = A.NumCols();

   const bool sameShape = (B.NumRows() == rows) && (B.NumCols() == cols);
   if (!sameShape)
      Error("matrix sub: dimension mismatch");

   X.SetDims(rows, cols);

   for (long r = 1; r <= rows; r++)
      for (long c = 1; c <= cols; c++)
         sub(X(r, c), A(r, c), B(r, c));
}
  
// Schoolbook matrix product X = A * B, written directly into X.
// Performs no aliasing protection: X must not be A or B (the mul() wrapper
// handles that case). Requires A.NumCols() == B.NumRows(), otherwise
// Error() is raised with "matrix mul: dimension mismatch".
// X becomes A.NumRows() x B.NumCols().
void mul_aux(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
{
   const long rows  = A.NumRows();
   const long inner = A.NumCols();
   const long cols  = B.NumCols();

   if (inner != B.NumRows())
      Error("matrix mul: dimension mismatch");

   X.SetDims(rows, cols);

   // Scratch values reused for every output entry.
   ZZ sum, prod;

   for (long r = 1; r <= rows; r++) {
      for (long c = 1; c <= cols; c++) {
         clear(sum);
         for (long t = 1; t <= inner; t++) {
            mul(prod, A(r, t), B(t, c));
            add(sum, sum, prod);
         }
         X(r, c) = sum;
      }
   }
}
  
  
// Matrix product X = A * B.
// If X is the very same object as A or B, the product is first built in a
// temporary, because mul_aux overwrites X while it is still reading A and B.
void mul(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
{
   const bool outputAliasesInput = (&X == &A) || (&X == &B);

   if (!outputAliasesInput) {
      mul_aux(X, A, B);
      return;
   }

   mat_ZZ scratch;
   mul_aux(scratch, A, B);
   X = scratch;
}
  
  
// Matrix-vector product x = A * b (b treated as a column vector), written
// directly into x with no aliasing protection.
// Requires A.NumCols() == b.length(), otherwise Error() is raised with
// "matrix mul: dimension mismatch". x gets length A.NumRows().
static
void mul_aux(vec_ZZ& x, const mat_ZZ& A, const vec_ZZ& b)
{
   const long rows  = A.NumRows();
   const long inner = A.NumCols();

   if (inner != b.length())
      Error("matrix mul: dimension mismatch");

   x.SetLength(rows);

   ZZ sum, prod;

   for (long r = 1; r <= rows; r++) {
      clear(sum);
      for (long t = 1; t <= inner; t++) {
         mul(prod, A(r, t), b(t));
         add(sum, sum, prod);
      }
      x(r) = sum;
   }
}
  
  
// Matrix-vector product x = A * b (b treated as a column vector).
//
// mul_aux writes x entry by entry while still reading A and b, so the
// product is formed in a temporary whenever x may share storage with an
// input:
//   - x is the same object as b: b(k) is reread for every output entry,
//     so writing x would change b mid-computation;
//   - x is one of A's rows, say row r: writing x(i) overwrites A(r, i),
//     which is still read when entry r is computed. For example, with
//     A = [[1,2],[5,4]], b = [1,1] and x = A's second row, a direct
//     computation yields [3,7] instead of [3,9].
// A row of A used as the *input* b needs no copy, since b is only read.
void mul(vec_ZZ& x, const mat_ZZ& A, const vec_ZZ& b)  
{  
   if (&b == &x || A.position(x) != -1) {
      vec_ZZ tmp;
      mul_aux(tmp, A, b);
      x = tmp;
   }
   else
      mul_aux(x, A, b);
}  

// Vector-matrix product x = a * B (a treated as a row vector), written
// directly into x with no aliasing protection.
// Requires a.length() == B.NumRows(), otherwise Error() is raised with
// "matrix mul: dimension mismatch". x gets length B.NumCols().
static
void mul_aux(vec_ZZ& x, const vec_ZZ& a, const mat_ZZ& B)  
{
   const long rows = B.NumRows();
   const long cols = B.NumCols();

   if (rows != a.length())
      Error("matrix mul: dimension mismatch");

   x.SetLength(cols);

   ZZ sum, prod;

   // Entry c of the result is the dot product of a with column c of B.
   for (long c = 1; c <= cols; c++) {
      clear(sum);
      for (long r = 1; r <= rows; r++) {
         mul(prod, a(r), B(r, c));
         add(sum, sum, prod);
      }
      x(c) = sum;
   }
}

// Vector-matrix product x = a * B (a treated as a row vector).
//
// mul_aux writes x entry by entry while still reading a and B, so the
// product is formed in a temporary whenever x may share storage with an
// input:
//   - x is the same object as a: a(k) is reread for every output entry,
//     so writing x would change a mid-computation;
//   - x is one of B's rows: checked defensively, in line with the
//     matrix-vector product, so the output never shares storage with B
//     while B is being read.
// A row of B used as the *input* a needs no copy, since a is only read.
void mul(vec_ZZ& x, const vec_ZZ& a, const mat_ZZ& B)
{
   if (&a == &x || B.position(x) != -1) {
      vec_ZZ tmp;
      mul_aux(tmp, a, B);
      x = tmp;
   }
   else
      mul_aux(x, a, B);
}

     
  
// Sets X to the n x n identity matrix: ones on the diagonal, zeros elsewhere.
void ident(mat_ZZ& X, long n)
{
   X.SetDims(n, n);

   for (long r = 1; r <= n; r++) {
      for (long c = 1; c <= n; c++) {
         if (r != c)
            clear(X(r, c));
         else
            set(X(r, c));
      }
   }
}

static
long DetBound(const mat_ZZ& a)
{
   long n = a.NumRows();
   long i;
   ZZ res, t1;

   set(res);

   for (i = 0; i < n; i++) {
      InnerProduct(t1, a[i], a[i]);
      if (t1 > 1) {
         SqrRoot(t1, t1);
         add(t1, t1, 1);
      }
      mul(res, res, t1);
   }

   return NumBits(res);
}



   

// Computes rres = det(a) for a square integer matrix by a multi-modular
// method. det(a) is computed modulo a sequence of primes, and the residues
// are combined by Chinese remaindering. This continues until the
// accumulated modulus has more bits than a row-norm bound on |det(a)|
// (from DetBound()) plus 2.
//
// a             - square matrix. A nonsquare matrix raises
//                 "determinant: nonsquare matrix" via Error().
// deterministic - when zero, the loop may stop early once the current
//                 value survives an extra large prime unchanged (see
//                 below). When nonzero, it always runs until the bound
//                 is covered.
//
// The current zz_p and ZZ_p moduli are saved on entry and restored before
// returning, since the loop reinitializes both.
void determinant(ZZ& rres, const mat_ZZ& a, long deterministic)
{
   long n = a.NumRows();
   if (a.NumCols() != n)
      Error("determinant: nonsquare matrix");

   // The 0 x 0 matrix has determinant 1.
   if (n == 0) {
      set(rres);
      return;
   }

   zz_pBak zbak;
   zbak.save();

   ZZ_pBak Zbak;
   Zbak.save();

   // Nonzero while the most recent CRT step changed res, i.e. res has not
   // yet been observed to stay fixed under a new prime.
   long instable = 1;

   // Counter feeding GenPrime's second argument for successive large primes.
   long gp_cnt = 0;

   // The loop stops once prod has more than this many bits.
   long bound = 2+DetBound(a);

   // res: current CRT value of det(a) modulo prod.
   // prod: product of all primes used so far.
   ZZ res, prod;

   clear(res);
   set(prod);


   long i;
   for (i = 0; ; i++) {
      if (NumBits(prod) > bound)
         break;

      // Probabilistic shortcut. This applies when the last small prime left
      // res unchanged and less than a quarter of the bound is covered. Test
      // res against one large prime; if that also leaves res unchanged,
      // accept it without covering the full bound.
      if (!deterministic &&
          !instable && bound > 1000 && NumBits(prod) < 0.25*bound) {
         ZZ P;


         long plen = 90 + NumBits(max(bound, NumBits(res)));
         // NOTE(review): per NTL's documentation GenPrime produces a random
         // prime of plen bits; the second argument controls its error bound.
         GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));

         ZZ_p::init(P);

         mat_ZZ_p A;
         conv(A, a);

         ZZ_p t;
         determinant(t, A);

         if (CRT(res, prod, rep(t), P))
            instable = 1;
         else
            break;
      }


      // Main step: det(a) modulo the prime selected by zz_p::FFTInit(i).
      zz_p::FFTInit(i);
      long p = zz_p::modulus();

      mat_zz_p A;
      conv(A, a);

      zz_p t;
      determinant(t, A);

      // CRT folds t into res (and p into prod); its return value is
      // treated as "res changed".
      instable = CRT(res, prod, rep(t), p);
   }

   rres = res;

   zbak.restore();
   Zbak.restore();
}




// Converts an integer matrix to a matrix over zz_p (the current small
// modulus), row by row. x takes a's dimensions.
void conv(mat_zz_p& x, const mat_ZZ& a)
{
   const long rows = a.NumRows();
   const long cols = a.NumCols();

   x.SetDims(rows, cols);

   for (long r = 0; r < rows; r++)
      conv(x[r], a[r]);
}

// Converts an integer matrix to a matrix over ZZ_p (the current big
// modulus), row by row. x takes a's dimensions.
void conv(mat_ZZ_p& x, const mat_ZZ& a)
{
   const long rows = a.NumRows();
   const long cols = a.NumCols();

   x.SetDims(rows, cols);

   long r = 0;
   while (r < rows) {
      conv(x[r], a[r]);
      r++;
   }
}

long IsIdent(const mat_ZZ& A, long n)
{
   if (A.NumRows() != n || A.NumCols() != n)
      return 0;

   long i, j;

   for (i = 1; i <= n; i++)
      for (j = 1; j <= n; j++)
         if (i != j) {
            if (!IsZero(A(i, j))) return 0;
         }
         else {
            if (!IsOne(A(i, j))) return 0;
         }

   return 1;
}


// Sets X to the transpose of A. X may be the same object as A.
// A square matrix is then transposed in place by swapping entries across
// the diagonal. A non-square matrix is rebuilt through a temporary.
void transpose(mat_ZZ& X, const mat_ZZ& A)
{
   const long rows = A.NumRows();
   const long cols = A.NumCols();

   if (&X != &A) {
      // Distinct output: write the transpose directly.
      X.SetDims(cols, rows);
      for (long r = 1; r <= rows; r++)
         for (long c = 1; c <= cols; c++)
            X(c, r) = A(r, c);
      return;
   }

   if (rows == cols) {
      // In place, square: swap each off-diagonal pair once.
      for (long r = 1; r <= rows; r++)
         for (long c = r + 1; c <= rows; c++)
            swap(X(r, c), X(c, r));
      return;
   }

   // In place, non-square: the shape changes, so build the result aside.
   mat_ZZ result;
   result.SetDims(cols, rows);
   for (long r = 1; r <= rows; r++)
      for (long c = 1; c <= cols; c++)
         result(c, r) = A(r, c);
   X.kill();
   X = result;
}

// Entry-wise Chinese-remainder update of an integer matrix.
//
// On entry, gg holds values known modulo a, and G holds residues modulo the
// current zz_p modulus p. Each entry gg[i][j] is replaced by a value that
// is congruent to gg[i][j] mod a and to G[i][j] mod p. The result is kept in
// a range roughly symmetric about zero. a is then replaced by a*p.
//
// Returns 1 if any entry of gg was modified, including an entry that first
// had to be reduced into range. Returns 0 otherwise. Raises
// "CRT: dimension mismatch" if G's shape differs from gg's.
//
// NOTE(review): InvMod below requires a to be invertible mod p, i.e.
// gcd(a, p) = 1 -- confirm callers never reuse a prime.
long CRT(mat_ZZ& gg, ZZ& a, const mat_zz_p& G)
{
   long n = gg.NumRows();
   long m = gg.NumCols();

   if (G.NumRows() != n || G.NumCols() != m)
      Error("CRT: dimension mismatch");

   long p = zz_p::modulus();

   // New combined modulus; committed to a only at the end.
   ZZ new_a;
   mul(new_a, a, p);

   // a^{-1} mod p
   long a_inv;
   a_inv = rem(a, p);
   a_inv = InvMod(a_inv, p);

   // Half-moduli used to pick symmetric representatives.
   long p1;
   p1 = p >> 1;

   ZZ a1;
   RightShift(a1, a, 1);

   long p_odd = (p & 1);

   long modified = 0;

   long h;
   ZZ ah;

   ZZ g;
   long i, j;

   for (i = 0; i < n; i++) {
      for (j = 0; j < m; j++) {
         // Bring the old value into range mod a if necessary. Assuming rem
         // returns a value in [0, a), g ends up in (-a/2, a/2].
         if (!CRTInRange(gg[i][j], a)) {
            modified = 1;
            rem(g, gg[i][j], a);
            if (g > a1) sub(g, g, a);
         }
         else
            g = gg[i][j];
      
         // h = (G[i][j] - g) * a^{-1} mod p, lifted to the symmetric range,
         // so that g + a*h is congruent to G[i][j] mod p.
         h = rem(g, p);
         h = SubMod(rep(G[i][j]), h, p);
         h = MulMod(h, a_inv, p);
         if (h > p1)
            h = h - p;
      
         if (h != 0) {
            modified = 1;
            mul(ah, a, h);
      
            // Even p with h == p/2: g - a*h and g + a*h are congruent
            // mod a*p, so pick the one nearer zero when g > 0.
            if (!p_odd && g > 0 && (h == p1))
               sub(g, g, ah);
            else
               add(g, g, ah);
         }
   
         gg[i][j] = g;
      }
   }

   a = new_a;

   return modified;

}


// Scales every entry of A by the integer b_in: X = b_in * A.
// X is resized to A's shape.
void mul(mat_ZZ& X, const mat_ZZ& A, const ZZ& b_in)
{
   // Take a private copy of the scalar: b_in could be an entry of X, which
   // the loop below overwrites.
   ZZ scalar = b_in;

   const long rows = A.NumRows();
   const long cols = A.NumCols();

   X.SetDims(rows, cols);

   for (long r = 0; r < rows; r++)
      for (long c = 0; c < cols; c++)
         mul(X[r][c], A[r][c], scalar);
}

// Scales every entry of A by the machine integer b: X = b * A.
// X is resized to A's shape.
void mul(mat_ZZ& X, const mat_ZZ& A, long b)
{
   const long rows = A.NumRows();
   const long cols = A.NumCols();

   X.SetDims(rows, cols);

   for (long r = 0; r < rows; r++) {
      for (long c = 0; c < cols; c++) {
         mul(X[r][c], A[r][c], b);
      }
   }
}


// Divides every entry of x by d_in, in place. Raises "inexact division"
// via Error() if some entry is not a multiple of the divisor.
//
// The divisor is copied first. If d_in referred to an entry of x, that
// entry would become 1 after being divided by itself, and every later
// entry would then be silently "divided" by 1.
static
void ExactDiv(vec_ZZ& x, const ZZ& d_in)
{
   ZZ d = d_in;

   long n = x.length();
   long i;

   for (i = 0; i < n; i++)
      if (!divide(x[i], x[i], d))
         Error("inexact division");
}

// Divides every entry of x by d_in, in place. Raises "inexact division"
// via Error() if some entry is not a multiple of the divisor.
//
// The divisor is copied first. If d_in referred to an entry of x, that
// entry would become 1 after being divided by itself, and every later
// entry would then be silently "divided" by 1.
static
void ExactDiv(mat_ZZ& x, const ZZ& d_in)
{
   ZZ d = d_in;

   long n = x.NumRows();
   long m = x.NumCols();
   
   long i, j;

   for (i = 0; i < n; i++)
      for (j = 0; j < m; j++)
         if (!divide(x[i][j], x[i][j], d))
            Error("inexact division");
}

// Sets X to the n x n diagonal matrix with every diagonal entry equal to
// d_in and zeros elsewhere.
void diag(mat_ZZ& X, long n, const ZZ& d_in)
{
   // Copy the diagonal value first: d_in may refer to an entry of X, which
   // the loop below clears or overwrites.
   ZZ value = d_in;

   X.SetDims(n, n);

   for (long r = 1; r <= n; r++) {
      for (long c = 1; c <= n; c++) {
         if (r != c)
            clear(X(r, c));
         else
            X(r, c) = value;
      }
   }
}

long IsDiag(const mat_ZZ& A, long n, const ZZ& d)
{
   if (A.NumRows() != n || A.NumCols() != n)
      return 0;

   long i, j;

   for (i = 1; i <= n; i++)
      for (j = 1; j <= n; j++)
         if (i != j) {
            if (!IsZero(A(i, j))) return 0;
         }
         else {
            if (A(i, j) != d) return 0;
         }

   return 1;
}




// Solves x*A = d*b over the integers, where d = det(A) and x is a row
// vector, using a multi-modular method.
//
// On return d_out = det(A).
// If the loop found d != 0, x_out is set to an integer vector with
// x_out*A == d_out*b. The relation is verified exactly over ZZ (y == b1
// below) before it is accepted.
// If the loop ends with d == 0, x_out is left unchanged.
//
// A must be square ("solve: nonsquare matrix"), and b must have length
// A.NumRows() ("solve: dimension mismatch").
// deterministic - when zero, the determinant refinement may stop early
//                 after one large prime leaves d unchanged.
//
// The current zz_p and ZZ_p moduli are saved on entry and restored on return.
void solve(ZZ& d_out, vec_ZZ& x_out,
           const mat_ZZ& A, const vec_ZZ& b,
           long deterministic)
{
   long n = A.NumRows();
   
   if (A.NumCols() != n)
      Error("solve: nonsquare matrix");

   if (b.length() != n)
      Error("solve: dimension mismatch");

   if (n == 0) {
      set(d_out);
      x_out.SetLength(0);
      return;
   }

   zz_pBak zbak;
   zbak.save();

   ZZ_pBak Zbak;
   Zbak.save();

   // x: CRT accumulator for the d-scaled solution vector.
   vec_ZZ x(INIT_SIZE, n);
   // d: CRT accumulator for det(A).
   // d1: the value of d at the moment x was verified.
   ZZ d, d1;

   // Separate moduli products. They can differ, because x is only updated
   // for primes where the determinant is nonzero.
   ZZ d_prod, x_prod;
   set(d_prod);
   set(x_prod);

   // Nonzero while the last CRT step changed d (respectively x).
   long d_instable = 1;
   long x_instable = 1;

   // Set once x*A == d*b has been verified over ZZ. After that, only d is
   // refined.
   long check = 0;

   // Counter feeding GenPrime's second argument for successive large primes.
   long gp_cnt = 0;

   vec_ZZ y, b1;

   long i;
   long bound = 2+DetBound(A);

   for (i = 0; ; i++) {
      // Termination / early-exit logic. This runs only when d is currently
      // stable and either x is already verified or d looks like zero.
      if ((check || IsZero(d)) && !d_instable) {
         if (NumBits(d_prod) > bound) {
            break;
         }
         else if (!deterministic &&
                  bound > 1000 && NumBits(d_prod) < 0.25*bound) {

            // Confirm d against one large prime; stop if it is unchanged.
            ZZ P;

            long plen = 90 + NumBits(max(bound, NumBits(d)));
            GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));

            ZZ_p::init(P);

            mat_ZZ_p AA;
            conv(AA, A);

            ZZ_p dd;
            determinant(dd, AA);

            if (CRT(d, d_prod, rep(dd), P))
               d_instable = 1;
            else 
               break;
         }
      }


      // One more small prime, selected by zz_p::FFTInit(i).
      zz_p::FFTInit(i);
      long p = zz_p::modulus();

      mat_zz_p AA;
      conv(AA, A);

      if (!check) {
         // Solve modulo p and fold both det and solution into the CRT state.
         vec_zz_p bb, xx;
         conv(bb, b);

         zz_p dd; 

         solve(dd, xx, AA, bb);

         d_instable = CRT(d, d_prod, rep(dd), p);
         if (!IsZero(dd)) {
            // Scale by dd so xx is the residue of the d-scaled integer
            // solution, which is what x accumulates.
            mul(xx, xx, dd);
            x_instable = CRT(x, x_prod, xx);
         }
         else
            x_instable = 1;

         // Once both look stable, verify x*A == d*b exactly over ZZ.
         if (!d_instable && !x_instable) {
            mul(y, x, A);
            mul(b1, b, d);
            if (y == b1) {
               d1 = d;
               check = 1;
            }
         }
      }
      else {
         // x is verified; keep refining only the determinant.
         zz_p dd;
         determinant(dd, AA);
         d_instable = CRT(d, d_prod, rep(dd), p);
      }
   }

   // x was verified against d1. If d has since changed, rescale x to
   // x * d / d1; ExactDiv raises an error if that is not exact.
   if (check && d1 != d) {
      mul(x, x, d);
      ExactDiv(x, d1);
   }

   d_out = d;
   if (check) x_out = x;

   zbak.restore();
   Zbak.restore();
}

void inv(ZZ& d_out, mat_ZZ& x_out, const mat_ZZ& A, long deterministic)
{
   long n = A.NumRows();
   
   if (A.NumCols() != n)
      Error("solve: nonsquare matrix");

   if (n == 0) {
      set(d_out);
      x_out.SetDims(0, 0);
      return;
   }

   zz_pBak zbak;
   zbak.save();

   ZZ_pBak Zbak;
   Zbak.save();

   mat_ZZ x(INIT_SIZE, n, n);
   ZZ d, d1;

   ZZ d_prod, x_prod;
   set(d_prod);
   set(x_prod);

   long d_instable = 1;
   long x_instable = 1;

   long gp_cnt = 0;

   long check = 0;


   mat_ZZ y;

   long i;
   long bound = 2+DetBound(A);

   for (i = 0; ; i++) {
      if ((check || IsZero(d)) && !d_instable) {
         if (NumBits(d_prod) > bound) {
            break;
         }
         else if (!deterministic &&
                  bound > 1000 && NumBits(d_prod) < 0.25*bound) {

            ZZ P;
   
            long plen = 90 + NumBits(max(bound, NumBits(d)));
            GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));
   
            ZZ_p::init(P);
   
            mat_ZZ_p AA;
            conv(AA, A);
   
            ZZ_p dd;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -