// File: mat_zz.c
// (code-viewer residue: "Font size:" control; not part of the source)
#include <NTL/mat_ZZ.h>
#include <NTL/new.h>

NTL_START_IMPL

NTL_matrix_impl(ZZ,vec_ZZ,vec_vec_ZZ,mat_ZZ)
NTL_io_matrix_impl(ZZ,vec_ZZ,vec_vec_ZZ,mat_ZZ)
NTL_eq_matrix_impl(ZZ,vec_ZZ,vec_vec_ZZ,mat_ZZ)


// X = A + B, entry by entry.
// Raises Error if A and B do not have identical dimensions.
void add(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
{
   long n = A.NumRows();
   long m = A.NumCols();

   if (B.NumRows() != n || B.NumCols() != m)
      Error("matrix add: dimension mismatch");

   X.SetDims(n, m);

   long i, j;
   for (i = 1; i <= n; i++)
      for (j = 1; j <= m; j++)
         add(X(i,j), A(i,j), B(i,j));
}

// X = A - B, entry by entry.
// Raises Error if A and B do not have identical dimensions.
void sub(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
{
   long n = A.NumRows();
   long m = A.NumCols();

   if (B.NumRows() != n || B.NumCols() != m)
      Error("matrix sub: dimension mismatch");

   X.SetDims(n, m);

   long i, j;
   for (i = 1; i <= n; i++)
      for (j = 1; j <= m; j++)
         sub(X(i,j), A(i,j), B(i,j));
}

// Schoolbook matrix product X = A*B.
// X is resized and written while A and B are still being read, so X must not
// be the same object as A or B; mul() below handles that case.
// Raises Error if A.NumCols() != B.NumRows().
void mul_aux(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
{
   long n = A.NumRows();
   long l = A.NumCols();
   long m = B.NumCols();

   if (l != B.NumRows())
      Error("matrix mul: dimension mismatch");

   X.SetDims(n, m);

   long i, j, k;
   ZZ acc, tmp;

   for (i = 1; i <= n; i++) {
      for (j = 1; j <= m; j++) {
         clear(acc);
         for(k = 1; k <= l; k++) {
            mul(tmp, A(i,k), B(k,j));
            add(acc, acc, tmp);
         }
         X(i,j) = acc;
      }
   }
}


// X = A*B.  When X is the same object as A or B the product is built in a
// temporary and then assigned to X.
void mul(mat_ZZ& X, const mat_ZZ& A, const mat_ZZ& B)
{
   if (&X == &A || &X == &B) {
      mat_ZZ tmp;
      mul_aux(tmp, A, B);
      X = tmp;
   }
   else
      mul_aux(X, A, B);
}


// x = A*b (matrix times column vector), written directly into x.
// x is resized and filled while A and b are still being read.
// Raises Error if A.NumCols() != b.length().
static
void mul_aux(vec_ZZ& x, const mat_ZZ& A, const vec_ZZ& b)
{
   long n = A.NumRows();
   long l = A.NumCols();

   if (l != b.length())
      Error("matrix mul: dimension mismatch");

   x.SetLength(n);

   long i, k;
   ZZ acc, tmp;

   for (i = 1; i <= n; i++) {
      clear(acc);
      for (k = 1; k <= l; k++) {
         mul(tmp, A(i,k), b(k));
         add(acc, acc, tmp);
      }
      x(i) = acc;
   }
}


// x = A*b.
// mul_aux stores x(i) while later iterations still read A and b, so a
// temporary is required when the OUTPUT x is b itself or is one of A's rows.
// The row test is therefore applied to x: b is only read, so b being a row
// of A needs no special handling.
void mul(vec_ZZ& x, const mat_ZZ& A, const vec_ZZ& b)
{
   if (&b == &x || A.position(x) != -1) {
      vec_ZZ tmp;
      mul_aux(tmp, A, b);
      x = tmp;
   }
   else
      mul_aux(x, A, b);
}

// x = a*B (row vector times matrix), written directly into x.
// Raises Error if a.length() != B.NumRows().
static
void mul_aux(vec_ZZ& x, const vec_ZZ& a, const mat_ZZ& B)
{
   long n = B.NumRows();
   long l = B.NumCols();

   if (n != a.length())
      Error("matrix mul: dimension mismatch");

   x.SetLength(l);

   long i, k;
   ZZ acc, tmp;

   for (i = 1; i <= l; i++) {
      clear(acc);
      for (k = 1; k <= n; k++) {
         mul(tmp, a(k), B(k,i));
         add(acc, acc, tmp);
      }
      x(i) = acc;
   }
}

// x = a*B.
// A temporary is used when the OUTPUT x is a itself or is one of B's rows,
// mirroring the matrix-times-vector overload; a is only read, so it is not
// tested against B's rows.
void mul(vec_ZZ& x, const vec_ZZ& a, const mat_ZZ& B)
{
   if (&a == &x || B.position(x) != -1) {
      vec_ZZ tmp;
      mul_aux(tmp, a, B);
      x = tmp;
   }
   else
      mul_aux(x, a, B);
}


// Sets X to the n x n identity matrix.
void ident(mat_ZZ& X, long n)
{
   X.SetDims(n, n);
   long i, j;

   for (i = 1; i <= n; i++)
      for (j = 1; j <= n; j++)
         if (i == j)
            set(X(i, j));
         else
            clear(X(i, j));
}

// Returns the bit length of the product, over all rows a[i], of
//    t = <a[i], a[i]>            if t <= 1
//    t = SqrRoot(<a[i], a[i]>) + 1   otherwise,
// i.e. a row-norm (Hadamard-style) size bound used by determinant(), solve()
// and inv() to decide when enough modular information has been gathered.
static
long DetBound(const mat_ZZ& a)
{
   long n = a.NumRows();
   long i;
   ZZ res, t1;

   set(res);

   for (i = 0; i < n; i++) {
      InnerProduct(t1, a[i], a[i]);
      if (t1 > 1) {
         SqrRoot(t1, t1);
         add(t1, t1, 1);
      }
      mul(res, res, t1);
   }

   return NumBits(res);
}


// rres = det(a), computed from residues modulo successive FFT primes
// (zz_p::FFTInit(i)) combined with CRT until the accumulated modulus `prod`
// exceeds `bound` bits.
//
// If `deterministic` is zero, bound > 1000, the last CRT step left `res`
// unchanged, and prod is still shorter than a quarter of the bound, one extra
// determinant is taken modulo a generated prime P; if that CRT step also
// reports no change the loop stops early.
//
// The zz_p and ZZ_p moduli are saved on entry and restored before returning.
// Raises Error if a is not square.  For a 0 x 0 matrix the result is 1.
void determinant(ZZ& rres, const mat_ZZ& a, long deterministic)
{
   long n = a.NumRows();
   if (a.NumCols() != n)
      Error("determinant: nonsquare matrix");

   if (n == 0) {
      set(rres);
      return;
   }

   zz_pBak zbak;
   zbak.save();

   ZZ_pBak Zbak;
   Zbak.save();

   // nonzero while the most recent CRT step reported a change to res
   long instable = 1;

   long gp_cnt = 0;

   long bound = 2+DetBound(a);

   ZZ res, prod;

   clear(res);
   set(prod);

   long i;
   for (i = 0; ; i++) {
      if (NumBits(prod) > bound)
         break;

      if (!deterministic &&
          !instable && bound > 1000 && NumBits(prod) < 0.25*bound) {
         ZZ P;

         long plen = 90 + NumBits(max(bound, NumBits(res)));
         GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));

         ZZ_p::init(P);

         mat_ZZ_p A;
         conv(A, a);

         ZZ_p t;
         determinant(t, A);

         if (CRT(res, prod, rep(t), P))
            instable = 1;
         else
            break;
      }

      zz_p::FFTInit(i);
      long p = zz_p::modulus();

      mat_zz_p A;
      conv(A, a);

      zz_p t;
      determinant(t, A);

      instable = CRT(res, prod, rep(t), p);
   }

   rres = res;

   zbak.restore();
   Zbak.restore();
}


// x = a converted row by row to mat_zz_p (uses the current zz_p modulus).
void conv(mat_zz_p& x, const mat_ZZ& a)
{
   long n = a.NumRows();
   long m = a.NumCols();
   long i;

   x.SetDims(n, m);
   for (i = 0; i < n; i++)
      conv(x[i], a[i]);
}

// x = a converted row by row to mat_ZZ_p (uses the current ZZ_p modulus).
void conv(mat_ZZ_p& x, const mat_ZZ& a)
{
   long n = a.NumRows();
   long m = a.NumCols();
   long i;

   x.SetDims(n, m);
   for (i = 0; i < n; i++)
      conv(x[i], a[i]);
}

// Returns 1 if A is exactly the n x n identity matrix, 0 otherwise.
long IsIdent(const mat_ZZ& A, long n)
{
   if (A.NumRows() != n || A.NumCols() != n)
      return 0;

   long i, j;

   for (i = 1; i <= n; i++)
      for (j = 1; j <= n; j++)
         if (i != j) {
            if (!IsZero(A(i, j))) return 0;
         }
         else {
            if (!IsOne(A(i, j))) return 0;
         }

   return 1;
}


// X = transpose of A.
// In-place use (&X == &A) is supported: a square matrix is transposed by
// swapping entries across the diagonal; a non-square one goes through a
// temporary, after which X is released with kill() and reassigned.
void transpose(mat_ZZ& X, const mat_ZZ& A)
{
   long n = A.NumRows();
   long m = A.NumCols();

   long i, j;

   if (&X == & A) {
      if (n == m)
         for (i = 1; i <= n; i++)
            for (j = i+1; j <= n; j++)
               swap(X(i, j), X(j, i));
      else {
         mat_ZZ tmp;
         tmp.SetDims(m, n);
         for (i = 1; i <= n; i++)
            for (j = 1; j <= m; j++)
               tmp(j, i) = A(i, j);
         X.kill();
         X = tmp;
      }
   }
   else {
      X.SetDims(m, n);
      for (i = 1; i <= n; i++)
         for (j = 1; j <= m; j++)
            X(j, i) = A(i, j);
   }
}

// Entrywise Chinese-remainder update of gg (held modulo a) with the residues
// G modulo the current zz_p modulus p.  On return a has been multiplied by p.
//
// Each entry is first reduced into range with respect to a if CRTInRange()
// says it is not, then corrected by a multiple a*h, where h is the
// difference of residues times a^{-1} mod p, moved into the range
// (-p/2, p/2].  For even p the tie h == p/2 is resolved by subtracting
// instead of adding when the entry is positive.
//
// Returns 1 if any entry of gg was changed, 0 otherwise.
// Raises Error if gg and G differ in dimensions.
long CRT(mat_ZZ& gg, ZZ& a, const mat_zz_p& G)
{
   long n = gg.NumRows();
   long m = gg.NumCols();

   if (G.NumRows() != n || G.NumCols() != m)
      Error("CRT: dimension mismatch");

   long p = zz_p::modulus();

   ZZ new_a;
   mul(new_a, a, p);

   long a_inv;
   a_inv = rem(a, p);
   a_inv = InvMod(a_inv, p);

   long p1;
   p1 = p >> 1;

   ZZ a1;
   RightShift(a1, a, 1);

   long p_odd = (p & 1);

   long modified = 0;

   long h;

   ZZ ah;

   ZZ g;
   long i, j;

   for (i = 0; i < n; i++) {
      for (j = 0; j < m; j++) {
         if (!CRTInRange(gg[i][j], a)) {
            modified = 1;
            rem(g, gg[i][j], a);
            if (g > a1) sub(g, g, a);
         }
         else
            g = gg[i][j];

         h = rem(g, p);
         h = SubMod(rep(G[i][j]), h, p);
         h = MulMod(h, a_inv, p);
         if (h > p1)
            h = h - p;

         if (h != 0) {
            modified = 1;
            mul(ah, a, h);

            if (!p_odd && g > 0 && (h == p1))
               sub(g, g, ah);
            else
               add(g, g, ah);
         }

         gg[i][j] = g;
      }
   }

   a = new_a;

   return modified;
}


// X = A * b_in (scalar multiple).
// b_in is copied to a local before X is touched -- presumably so that the
// call stays correct when b_in refers to an entry of X (NOTE(review): confirm
// intent).
void mul(mat_ZZ& X, const mat_ZZ& A, const ZZ& b_in)
{
   ZZ b = b_in;
   long n = A.NumRows();
   long m = A.NumCols();

   X.SetDims(n, m);

   long i, j;
   for (i = 0; i < n; i++)
      for (j = 0; j < m; j++)
         mul(X[i][j], A[i][j], b);
}

// X = A * b (scalar multiple by a machine integer).
void mul(mat_ZZ& X, const mat_ZZ& A, long b)
{
   long n = A.NumRows();
   long m = A.NumCols();

   X.SetDims(n, m);

   long i, j;
   for (i = 0; i < n; i++)
      for (j = 0; j < m; j++)
         mul(X[i][j], A[i][j], b);
}


// Divides every entry of x by d in place.
// Raises Error("inexact division") as soon as divide() reports failure.
static
void ExactDiv(vec_ZZ& x, const ZZ& d)
{
   long n = x.length();
   long i;

   for (i = 0; i < n; i++)
      if (!divide(x[i], x[i], d))
         Error("inexact division");
}

// Divides every entry of x by d in place.
// Raises Error("inexact division") as soon as divide() reports failure.
static
void ExactDiv(mat_ZZ& x, const ZZ& d)
{
   long n = x.NumRows();
   long m = x.NumCols();

   long i, j;

   for (i = 0; i < n; i++)
      for (j = 0; j < m; j++)
         if (!divide(x[i][j], x[i][j], d))
            Error("inexact division");
}

// Sets X to the n x n matrix with d_in on the diagonal and zero elsewhere.
// d_in is copied to a local before X is resized or written.
void diag(mat_ZZ& X, long n, const ZZ& d_in)
{
   ZZ d = d_in;
   X.SetDims(n, n);
   long i, j;

   for (i = 1; i <= n; i++)
      for (j = 1; j <= n; j++)
         if (i == j)
            X(i, j) = d;
         else
            clear(X(i, j));
}

// Returns 1 if A is the n x n matrix with d on the diagonal and zero
// elsewhere, 0 otherwise.
long IsDiag(const mat_ZZ& A, long n, const ZZ& d)
{
   if (A.NumRows() != n || A.NumCols() != n)
      return 0;

   long i, j;

   for (i = 1; i <= n; i++)
      for (j = 1; j <= n; j++)
         if (i != j) {
            if (!IsZero(A(i, j))) return 0;
         }
         else {
            if (A(i, j) != d) return 0;
         }

   return 1;
}


// Multi-modular solver for the row-vector system x*A = b*d.
//
// Phase 1 (check == 0): for successive FFT primes, solve modulo p, scale the
// modular solution by the modular determinant, and CRT both into d and x.
// Once neither changed in the last step, the candidate is verified exactly
// by testing x*A == b*d; on success d is remembered in d1 and check is set.
//
// Phase 2 (check != 0): only the determinant d keeps being refined, until
// d_prod exceeds the DetBound-derived bound or, when `deterministic` is zero,
// a determinant modulo a generated prime P leaves d unchanged.
//
// If d moved after the check succeeded, x is rescaled by d/d1.
//
// d_out always receives d; x_out is assigned only if the check succeeded.
// For n == 0, d_out = 1 and x_out is emptied.  The zz_p and ZZ_p moduli are
// saved on entry and restored before the normal return.
// Raises Error if A is not square or b.length() != n.
void solve(ZZ& d_out, vec_ZZ& x_out,
           const mat_ZZ& A, const vec_ZZ& b,
           long deterministic)
{
   long n = A.NumRows();

   if (A.NumCols() != n)
      Error("solve: nonsquare matrix");

   if (b.length() != n)
      Error("solve: dimension mismatch");

   if (n == 0) {
      set(d_out);
      x_out.SetLength(0);
      return;
   }

   zz_pBak zbak;
   zbak.save();

   ZZ_pBak Zbak;
   Zbak.save();

   vec_ZZ x(INIT_SIZE, n);
   ZZ d, d1;

   ZZ d_prod, x_prod;
   set(d_prod);
   set(x_prod);

   // nonzero while the most recent CRT step changed d / x respectively
   long d_instable = 1;
   long x_instable = 1;

   long check = 0;

   long gp_cnt = 0;

   vec_ZZ y, b1;

   long i;
   long bound = 2+DetBound(A);

   for (i = 0; ; i++) {
      if ((check || IsZero(d)) && !d_instable) {
         if (NumBits(d_prod) > bound) {
            break;
         }
         else if (!deterministic &&
                  bound > 1000 && NumBits(d_prod) < 0.25*bound) {

            ZZ P;

            long plen = 90 + NumBits(max(bound, NumBits(d)));
            GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));

            ZZ_p::init(P);

            mat_ZZ_p AA;
            conv(AA, A);

            ZZ_p dd;
            determinant(dd, AA);

            if (CRT(d, d_prod, rep(dd), P))
               d_instable = 1;
            else
               break;
         }
      }

      zz_p::FFTInit(i);
      long p = zz_p::modulus();

      mat_zz_p AA;
      conv(AA, A);

      if (!check) {
         vec_zz_p bb, xx;
         conv(bb, b);

         zz_p dd;

         solve(dd, xx, AA, bb);

         d_instable = CRT(d, d_prod, rep(dd), p);
         if (!IsZero(dd)) {
            mul(xx, xx, dd);
            x_instable = CRT(x, x_prod, xx);
         }
         else
            x_instable = 1;

         if (!d_instable && !x_instable) {
            // exact verification of the candidate: x*A == b*d
            mul(y, x, A);
            mul(b1, b, d);
            if (y == b1) {
               d1 = d;
               check = 1;
            }
         }
      }
      else {
         zz_p dd;
         determinant(dd, AA);
         d_instable = CRT(d, d_prod, rep(dd), p);
      }
   }

   if (check && d1 != d) {
      mul(x, x, d);
      ExactDiv(x, d1);
   }

   d_out = d;
   if (check) x_out = x;

   zbak.restore();
   Zbak.restore();
}

// Matrix-inverse counterpart of solve() (signature: determinant-like output
// d_out, matrix output x_out).
// NOTE: this listing is truncated partway through the function; only the
// part that is present is reproduced below, unchanged.
void inv(ZZ& d_out, mat_ZZ& x_out, const mat_ZZ& A, long deterministic)
{
   long n = A.NumRows();

   if (A.NumCols() != n)
      Error("solve: nonsquare matrix");

   if (n == 0) {
      set(d_out);
      x_out.SetDims(0, 0);
      return;
   }

   zz_pBak zbak;
   zbak.save();

   ZZ_pBak Zbak;
   Zbak.save();

   mat_ZZ x(INIT_SIZE, n, n);
   ZZ d, d1;

   ZZ d_prod, x_prod;
   set(d_prod);
   set(x_prod);

   long d_instable = 1;
   long x_instable = 1;

   long gp_cnt = 0;

   long check = 0;

   mat_ZZ y;

   long i;
   long bound = 2+DetBound(A);

   for (i = 0; ; i++) {
      if ((check || IsZero(d)) && !d_instable) {
         if (NumBits(d_prod) > bound) {
            break;
         }
         else if (!deterministic &&
                  bound > 1000 && NumBits(d_prod) < 0.25*bound) {

            ZZ P;

            long plen = 90 + NumBits(max(bound, NumBits(d)));
            GenPrime(P, plen, 90 + 2*NumBits(gp_cnt++));

            ZZ_p::init(P);

            mat_ZZ_p AA;
            conv(AA, A);

            ZZ_p dd;
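// Code-viewer residue (keyboard shortcut help), not part of the source:
//   Copy code:            Ctrl + C
//   Search code:          Ctrl + F
//   Full-screen mode:     F11
//   Toggle theme:         Ctrl + Shift + D
//   Show shortcuts:       ?
//   Increase font size:   Ctrl + =
//   Decrease font size:   Ctrl + -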