📄 mat_zz.c
字号:
determinant(dd, AA); if (CRT(d, d_prod, rep(dd), P)) d_instable = 1; else break; } } zz_p::FFTInit(i); long p = zz_p::modulus(); mat_zz_p AA; conv(AA, A); if (!check) { mat_zz_p xx; zz_p dd; inv(dd, xx, AA); d_instable = CRT(d, d_prod, rep(dd), p); if (!IsZero(dd)) { mul(xx, xx, dd); x_instable = CRT(x, x_prod, xx); } else x_instable = 1; if (!d_instable && !x_instable) { mul(y, x, A); if (IsDiag(y, n, d)) { d1 = d; check = 1; } } } else { zz_p dd; determinant(dd, AA); d_instable = CRT(d, d_prod, rep(dd), p); } } if (check && d1 != d) { mul(x, x, d); ExactDiv(x, d1); } d_out = d; if (check) x_out = x; zbak.restore(); Zbak.restore();}void negate(mat_ZZ& X, const mat_ZZ& A){ long n = A.NumRows(); long m = A.NumCols(); X.SetDims(n, m); long i, j; for (i = 1; i <= n; i++) for (j = 1; j <= m; j++) negate(X(i,j), A(i,j));}long IsZero(const mat_ZZ& a){ long n = a.NumRows(); long i; for (i = 0; i < n; i++) if (!IsZero(a[i])) return 0; return 1;}void clear(mat_ZZ& x){ long n = x.NumRows(); long i; for (i = 0; i < n; i++) clear(x[i]);}mat_ZZ operator+(const mat_ZZ& a, const mat_ZZ& b){ mat_ZZ res; add(res, a, b); NTL_OPT_RETURN(mat_ZZ, res);}mat_ZZ operator*(const mat_ZZ& a, const mat_ZZ& b){ mat_ZZ res; mul_aux(res, a, b); NTL_OPT_RETURN(mat_ZZ, res);}mat_ZZ operator-(const mat_ZZ& a, const mat_ZZ& b){ mat_ZZ res; sub(res, a, b); NTL_OPT_RETURN(mat_ZZ, res);}mat_ZZ operator-(const mat_ZZ& a){ mat_ZZ res; negate(res, a); NTL_OPT_RETURN(mat_ZZ, res);}vec_ZZ operator*(const mat_ZZ& a, const vec_ZZ& b){ vec_ZZ res; mul_aux(res, a, b); NTL_OPT_RETURN(vec_ZZ, res);}vec_ZZ operator*(const vec_ZZ& a, const mat_ZZ& b){ vec_ZZ res; mul_aux(res, a, b); NTL_OPT_RETURN(vec_ZZ, res);}void inv(mat_ZZ& X, const mat_ZZ& A){ ZZ d; inv(d, X, A); if (d == -1) negate(X, X); else if (d != 1) Error("inv: non-invertible matrix");}void power(mat_ZZ& X, const mat_ZZ& A, const ZZ& e){ if (A.NumRows() != A.NumCols()) Error("power: non-square matrix"); if (e == 0) { ident(X, A.NumRows()); return; } mat_ZZ T1, T2; long i, k; k = NumBits(e); T1 = A; for (i = k-2; i >= 0; i--) { sqr(T2, T1); if (bit(e, i)) mul(T1, T2, A); else T1 = T2; } if (e < 0) inv(X, T1); else X = T1;}/*********************************************************** routines for solving a linear system vi Hensel lifting************************************************************/staticlong MaxBits(const mat_ZZ& A){ long m = 0; long i, j; for (i = 0; i < A.NumRows(); i++) for (j = 0; j < A.NumCols(); j++) m = max(m, NumBits(A[i][j])); return m;}// Computes an upper bound on the numerators and denominators// to the solution x*A = b using Hadamard's bound and Cramer's rule. // If A contains a zero row, then sets both bounds to zero.staticvoid hadamard(ZZ& num_bound, ZZ& den_bound, const mat_ZZ& A, const vec_ZZ& b){ long n = A.NumRows(); if (n == 0) Error("internal error: hadamard with n = 0"); ZZ b_len, min_A_len, prod, t1; InnerProduct(min_A_len, A[0], A[0]); prod = min_A_len; long i; for (i = 1; i < n; i++) { InnerProduct(t1, A[i], A[i]); if (t1 < min_A_len) min_A_len = t1; mul(prod, prod, t1); } if (min_A_len == 0) { num_bound = 0; den_bound = 0; return; } InnerProduct(b_len, b, b); div(t1, prod, min_A_len); mul(t1, t1, b_len); SqrRoot(num_bound, t1); SqrRoot(den_bound, prod);}staticvoid MixedMul(vec_ZZ& x, const vec_zz_p& a, const mat_ZZ& B){ long n = B.NumRows(); long l = B.NumCols(); if (n != a.length()) Error("matrix mul: dimension mismatch"); x.SetLength(l); long i, k; ZZ acc, tmp; for (i = 1; i <= l; i++) { clear(acc); for (k = 1; k <= n; k++) { mul(tmp, B(k, i), rep(a(k))); add(acc, acc, tmp); } x(i) = acc; }} staticvoid SubDiv(vec_ZZ& e, const vec_ZZ& t, long p){ long n = e.length(); if (t.length() != n) Error("SubDiv: dimension mismatch"); ZZ s; long i; for (i = 0; i < n; i++) { sub(s, e[i], t[i]); div(e[i], s, p); }}staticvoid MulAdd(vec_ZZ& x, const ZZ& prod, const vec_zz_p& h){ long n = x.length(); if (h.length() != n) Error("MulAdd: dimension mismatch"); ZZ t; long i; for (i = 0; i < n; i++) { mul(t, prod, rep(h[i])); add(x[i], x[i], t); }}staticvoid double_MixedMul1(vec_ZZ& x, double *a, double **B, long n){ long i, k; double acc; for (i = 0; i < n; i++) { double *bp = B[i]; acc = 0; for (k = 0; k < n; k++) { acc += bp[k] * a[k]; } conv(x[i], acc); }} staticvoid double_MixedMul2(vec_ZZ& x, double *a, double **B, long n, long limit){ long i, k; double acc; ZZ acc1, t; long j; for (i = 0; i < n; i++) { double *bp = B[i]; clear(acc1); acc = 0; j = 0; for (k = 0; k < n; k++) { acc += bp[k] * a[k]; j++; if (j == limit) { conv(t, acc); add(acc1, acc1, t); acc = 0; j = 0; } } if (j > 0) { conv(t, acc); add(acc1, acc1, t); } x[i] = acc1; }} staticvoid long_MixedMul1(vec_ZZ& x, long *a, long **B, long n){ long i, k; long acc; for (i = 0; i < n; i++) { long *bp = B[i]; acc = 0; for (k = 0; k < n; k++) { acc += bp[k] * a[k]; } conv(x[i], acc); }} staticvoid long_MixedMul2(vec_ZZ& x, long *a, long **B, long n, long limit){ long i, k; long acc; ZZ acc1, t; long j; for (i = 0; i < n; i++) { long *bp = B[i]; clear(acc1); acc = 0; j = 0; for (k = 0; k < n; k++) { acc += bp[k] * a[k]; j++; if (j == limit) { conv(t, acc); add(acc1, acc1, t); acc = 0; j = 0; } } if (j > 0) { conv(t, acc); add(acc1, acc1, t); } x[i] = acc1; }} void solve1(ZZ& d_out, vec_ZZ& x_out, const mat_ZZ& A, const vec_ZZ& b){ long n = A.NumRows(); if (A.NumCols() != n) Error("solve1: nonsquare matrix"); if (b.length() != n) Error("solve1: dimension mismatch"); if (n == 0) { set(d_out); x_out.SetLength(0); return; } ZZ num_bound, den_bound; hadamard(num_bound, den_bound, A, b); if (den_bound == 0) { clear(d_out); return; } zz_pBak zbak; zbak.save(); long i; long j; ZZ prod; prod = 1; mat_zz_p B; for (i = 0; ; i++) { zz_p::FFTInit(i); mat_zz_p AA, BB; zz_p dd; conv(AA, A); inv(dd, BB, AA); if (dd != 0) { transpose(B, BB); break; } mul(prod, prod, zz_p::modulus()); if (prod > den_bound) { d_out = 0; return; } } long max_A_len = MaxBits(A); long use_double_mul1 = 0; long use_double_mul2 = 0; long double_limit = 0; if (max_A_len + NTL_SP_NBITS + NumBits(n) <= NTL_DOUBLE_PRECISION-1) use_double_mul1 = 1; if (!use_double_mul1 && max_A_len+NTL_SP_NBITS+2 <= NTL_DOUBLE_PRECISION-1) { use_double_mul2 = 1; double_limit = (1L << (NTL_DOUBLE_PRECISION-1-max_A_len-NTL_SP_NBITS)); } long use_long_mul1 = 0; long use_long_mul2 = 0; long long_limit = 0; if (max_A_len + NTL_SP_NBITS + NumBits(n) <= NTL_BITS_PER_LONG-1) use_long_mul1 = 1; if (!use_long_mul1 && max_A_len+NTL_SP_NBITS+2 <= NTL_BITS_PER_LONG-1) { use_long_mul2 = 1; long_limit = (1L << (NTL_BITS_PER_LONG-1-max_A_len-NTL_SP_NBITS)); } if (use_double_mul1 && use_long_mul1) use_long_mul1 = 0; else if (use_double_mul1 && use_long_mul2) use_long_mul2 = 0; else if (use_double_mul2 && use_long_mul1) use_double_mul2 = 0; else if (use_double_mul2 && use_long_mul2) { if (long_limit > double_limit) use_double_mul2 = 0; else use_long_mul2 = 0; } double **double_A; double *double_h; typedef double *double_ptr; if (use_double_mul1 || use_double_mul2) { double_h = NTL_NEW_OP double[n]; double_A = NTL_NEW_OP double_ptr[n]; if (!double_h || !double_A) Error("solve1: out of mem"); for (i = 0; i < n; i++) { double_A[i] = NTL_NEW_OP double[n]; if (!double_A[i]) Error("solve1: out of mem"); } for (i = 0; i < n; i++) for (j = 0; j < n; j++) double_A[j][i] = to_double(A[i][j]); } long **long_A; long *long_h; typedef long *long_ptr; if (use_long_mul1 || use_long_mul2) { long_h = NTL_NEW_OP long[n]; long_A = NTL_NEW_OP long_ptr[n]; if (!long_h || !long_A) Error("solve1: out of mem"); for (i = 0; i < n; i++) { long_A[i] = NTL_NEW_OP long[n]; if (!long_A[i]) Error("solve1: out of mem"); } for (i = 0; i < n; i++) for (j = 0; j < n; j++) long_A[j][i] = to_long(A[i][j]); } vec_ZZ x; x.SetLength(n); vec_zz_p h; h.SetLength(n); vec_ZZ e; e = b; vec_zz_p ee; vec_ZZ t; t.SetLength(n); prod = 1; ZZ bound1; mul(bound1, num_bound, den_bound); mul(bound1, bound1, 2); while (prod <= bound1) { conv(ee, e); mul(h, B, ee); if (use_double_mul1) { for (i = 0; i < n; i++) double_h[i] = to_double(rep(h[i])); double_MixedMul1(t, double_h, double_A, n); } else if (use_double_mul2) { for (i = 0; i < n; i++) double_h[i] = to_double(rep(h[i])); double_MixedMul2(t, double_h, double_A, n, double_limit); } else if (use_long_mul1) { for (i = 0; i < n; i++) long_h[i] = to_long(rep(h[i])); long_MixedMul1(t, long_h, long_A, n); } else if (use_long_mul2) { for (i = 0; i < n; i++) long_h[i] = to_long(rep(h[i])); long_MixedMul2(t, long_h, long_A, n, long_limit); } else MixedMul(t, h, A); // t = h*A SubDiv(e, t, zz_p::modulus()); // e = (e-t)/p MulAdd(x, prod, h); // x = x + prod*h mul(prod, prod, zz_p::modulus()); } vec_ZZ num, denom; ZZ d, d_mod_prod, tmp1; num.SetLength(n); denom.SetLength(n); d = 1; d_mod_prod = 1; for (i = 0; i < n; i++) { rem(x[i], x[i], prod); MulMod(x[i], x[i], d_mod_prod, prod); if (!ReconstructRational(num[i], denom[i], x[i], prod, num_bound, den_bound)) Error("solve1 internal error: rat recon failed!"); mul(d, d, denom[i]); if (i != n-1) { if (denom[i] != 1) { div(den_bound, den_bound, denom[i]); mul(bound1, num_bound, den_bound); mul(bound1, bound1, 2); div(tmp1, prod, zz_p::modulus()); while (tmp1 > bound1) { prod = tmp1; div(tmp1, prod, zz_p::modulus()); } rem(tmp1, denom[i], prod); rem(d_mod_prod, d_mod_prod, prod); MulMod(d_mod_prod, d_mod_prod, tmp1, prod); } } } tmp1 = 1; for (i = n-1; i >= 0; i--) { mul(num[i], num[i], tmp1); mul(tmp1, tmp1, denom[i]); } x_out.SetLength(n); for (i = 0; i < n; i++) { x_out[i] = num[i]; } d_out = d; if (use_double_mul1 || use_double_mul2) { delete [] double_h; for (i = 0; i < n; i++) { delete [] double_A[i]; } delete [] double_A; } if (use_long_mul1 || use_long_mul2) { delete [] long_h; for (i = 0; i < n; i++) { delete [] long_A[i]; } delete [] long_A; }}NTL_END_IMPL
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -