📄 ide.h
字号:
} /*!\brief Calculates the inverse of a non-singular square matrix, * given an LU decomposition. * * This function returns the inverse of an arbitrary, non-singular, * square matrix \a A when passed a permutation of an LU * decomposition, such as that returned by lu_decomp(). A * one-parameter version of this function exists that does not * require the user to pre-decompose the system. * * \param A The Matrix to be inverted. * \param L A Lower triangular matrix resulting from decomposition. * \param U An Upper triangular matrix resulting from decomposition. * \param perm_vec The permutation vector recording the row-wise permutation of \a A actually decomposed by the algorithm. * * \see inv (const Matrix<T, PO, PS>&) * \see invpd(const Matrix<T, PO, PS>&) * \see invpd(const Matrix<T, PO1, PS1>&, const Matrix<T, PO2, PS2>&) * \see lu_decomp(Matrix<T,PO1,PS1>, Matrix<T,PO2,Concrete>&, Matrix<T,PO3,Concrete>&, Matrix<unsigned int, PO4, Concrete>&) * * \throw scythe_null_error(Level 1) * \throw scythe_dimension_error (Level 1) * \throw scythe_conformation_error (Level 1) */ template<matrix_order RO, matrix_style RS, typename T, matrix_order PO1, matrix_style PS1, matrix_order PO2, matrix_style PS2, matrix_order PO3, matrix_style PS3, matrix_order PO4, matrix_style PS4> Matrix<T,RO,RS> inv (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& L, const Matrix<T,PO3,PS3>& U, const Matrix<unsigned int,PO4,PS4>& perm_vec) { SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL") SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error, "A is not square"); SCYTHE_CHECK_10(A.rows() != L.rows() || A.rows() != U.rows() || A.cols() != L.cols() || A.cols() != U.cols(), scythe_conformation_error, "A, L, and U do not conform"); SCYTHE_CHECK_10(perm_vec.rows() + 1 != A.rows() && !(A.isScalar() && perm_vec.isScalar()), scythe_conformation_error, "perm_vec does not have exactly one less row than A"); // For the final result Matrix<T,RO,Concrete> Ainv(A.rows(), A.rows(), false); // for the solve block T *y = new T[A.rows()]; T *x = new T[A.rows()]; Matrix<T, RO, Concrete> b(A.rows(), 1); // full of zeros Matrix<T,RO,Concrete> bb; for (uint k = 0; k < A.rows(); ++k) { b[k] = (T) 1; bb = row_interchange(b, perm_vec); solve(L, U, bb, x, y); b[k] = (T) 0; for (uint l = 0; l < A.rows(); ++l) Ainv(l,k) = x[l]; } delete[] y; delete[] x; SCYTHE_VIEW_RETURN(T, RO, RS, Ainv) } template<typename T, matrix_order PO1, matrix_style PS1, matrix_order PO2, matrix_style PS2, matrix_order PO3, matrix_style PS3, matrix_order PO4, matrix_style PS4> Matrix<T,PO1,Concrete> inv (const Matrix<T,PO1,PS1>& A, const Matrix<T,PO2,PS2>& L, const Matrix<T,PO3,PS3>& U, const Matrix<unsigned int,PO4,PS4>& perm_vec) { return inv<PO1,Concrete>(A, L, U, perm_vec); } /*!\brief Invert an arbitrary, non-singular, square matrix. * * This function returns the inverse of a non-singular square matrix, * using lu_decomp() to do the necessary decomposition. This method * is significantly slower than the inverse function for symmetric * positive definite matrices, invpd(). * * \param A The Matrix to be inverted. * * \see inv (const Matrix<T,PO1,PS1>&, const Matrix<T,PO2,PS2>&, const Matrix<T,PO3,PS3>&, const Matrix<unsigned int,PO4,PS4>&) * \see invpd(const Matrix<T, PO, PS>&) * \see invpd(const Matrix<T, PO1, PS1>&, const Matrix<T, PO2, PS2>&) * * \throw scythe_null_error(Level 1) * \throw scythe_dimension_error (Level 1) * \throw scythe_conformation_error (Level 1) * \throw scythe_type_error (Level 2) */ template <matrix_order RO, matrix_style RS, typename T, matrix_order PO, matrix_style PS> Matrix<T, RO, RS> inv (const Matrix<T, PO, PS>& A) { // Make a copy of A for the decomposition (do it with an explicit // copy to a concrete case A is a view) Matrix<T,RO,Concrete> AA = A; // step 1 compute the LU factorization Matrix<T, RO, Concrete> L, U; Matrix<uint, RO, Concrete> perm_vec; lu_decomp_alg(AA, L, U, perm_vec); return inv<RO,RS>(A, L, U, perm_vec); } template <typename T, matrix_order O, matrix_style S> Matrix<T, O, Concrete> inv (const Matrix<T, O, S>& A) { return inv<O,Concrete>(A); } /* Interchanges the rows of A with those in vector p */ /*!\brief Interchange the rows of a Matrix according to a * permutation vector. * * This function permutes the rows of Matrix \a A according to \a * perm_vec. Each element i of perm_vec contains a row-number, r. * For each row, i, in \a A, A[i] is interchanged with A[r]. * * \param A The matrix to permute. * \param p The column vector describing the permutations to perform * on \a A. * * \see lu_decomp(Matrix<T,PO1,PS1>, Matrix<T,PO2,Concrete>&, Matrix<T,PO3,Concrete>&, Matrix<unsigned int, PO4, Concrete>&) * * \throw scythe_dimension_error (Level 1) * \throw scythe_conformation_error (Level 1) */ template <matrix_order RO, matrix_style RS, typename T, matrix_order PO1, matrix_style PS1, matrix_order PO2, matrix_style PS2> Matrix<T,RO,RS> row_interchange (Matrix<T,PO1,PS1> A, const Matrix<unsigned int,PO2,PS2>& p) { SCYTHE_CHECK_10(! p.isColVector(), scythe_dimension_error, "p not a column vector"); SCYTHE_CHECK_10(p.rows() + 1 != A.rows() && ! p.isScalar(), scythe_conformation_error, "p must have one less row than A"); for (uint i = 0; i < A.rows() - 1; ++i) { Matrix<T,PO1,View> vec1 = A(i, _); Matrix<T,PO1,View> vec2 = A(p[i], _); std::swap_ranges(vec1.begin_f(), vec1.end_f(), vec2.begin_f()); } return A; } template <typename T, matrix_order PO1, matrix_style PS1, matrix_order PO2, matrix_style PS2> Matrix<T,PO1,Concrete> row_interchange (const Matrix<T,PO1,PS1>& A, const Matrix<unsigned int,PO2,PS2>& p) { return row_interchange<PO1,Concrete>(A, p); } /*! \brief Calculate the determinant of a square Matrix. * * This routine calculates the determinant of a square Matrix, using * LU decomposition. * * \param A The Matrix to calculate the determinant of. * * \see lu_decomp(Matrix<T,PO1,PS1>, Matrix<T,PO2,Concrete>&, Matrix<T,PO3,Concrete>&, Matrix<unsigned int, PO4, Concrete>&) * * \throws scythe_dimension_error (Level 1) * \throws scythe_null_error (Level 1) */ template <typename T, matrix_order PO, matrix_style PS> T det (const Matrix<T, PO, PS>& A) { SCYTHE_CHECK_10(! A.isSquare(), scythe_dimension_error, "Matrix is not square") SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "Matrix is NULL") // Make a copy of A for the decomposition (do it here instead of // at parameter pass in case A is a view) Matrix<T,PO,Concrete> AA = A; // step 1 compute the LU factorization Matrix<T, PO, Concrete> L, U; Matrix<uint, PO, Concrete> perm_vec; T sign = lu_decomp_alg(AA, L, U, perm_vec); // step 2 calculate the product of diag(U) and sign T det = (T) 1; for (uint i = 0; i < AA.rows(); ++i) det *= AA(i, i); return sign * det; }#ifdef SCYTHE_LAPACK template<> inline Matrix<> cholesky (const Matrix<>& A) { SCYTHE_DEBUG_MSG("Using lapack/blas for cholesky"); SCYTHE_CHECK_10(! A.isSquare(), scythe_dimension_error, "Matrix not square"); SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "Matrix is NULL"); // We have to do an explicit copy within the func to match the // template declaration of the more general template. Matrix<> AA = A; // Get a pointer to the internal array and set up some vars double* Aarray = AA.getArray(); // internal array pointer int rows = (int) AA.rows(); // the dim of the matrix int err = 0; // The output error condition // Cholesky decomposition step lapack::dpotrf_("L", &rows, Aarray, &rows, &err); SCYTHE_CHECK_10(err > 0, scythe_type_error, "Matrix is not positive definite") SCYTHE_CHECK_10(err < 0, scythe_invalid_arg, "The " << err << "th value of the matrix had an illegal value") // Zero out upper triangle for (uint j = 1; j < AA.cols(); ++j) for (uint i = 0; i < j; ++i) AA(i, j) = 0; return AA; } template<> inline Matrix<> chol_solve (const Matrix<>& A, const Matrix<>& b, const Matrix<>& M) { SCYTHE_DEBUG_MSG("Using lapack/blas for chol_solve"); SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL") SCYTHE_CHECK_10(! b.isColVector(), scythe_dimension_error, "b must be a column vector"); SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error, "A and b do not conform"); SCYTHE_CHECK_10(A.rows() != M.rows(), scythe_conformation_error, "A and M do not conform"); SCYTHE_CHECK_10(! M.isSquare(), scythe_dimension_error, "M must be square"); // The algorithm modifies b in place. We make a copy. Matrix<> bb = b; // Get array pointers and set up some vars const double* Marray = M.getArray(); double* barray = bb.getArray(); int rows = (int) bb.rows(); int cols = (int) bb.cols(); // currently always one, but generalizable int err = 0; // Solve the system lapack::dpotrs_("L", &rows, &cols, Marray, &rows, barray, &rows, &err); SCYTHE_CHECK_10(err > 0, scythe_type_error, "Matrix is not positive definite") SCYTHE_CHECK_10(err < 0, scythe_invalid_arg, "The " << err << "th value of the matrix had an illegal value") return bb; } template<> inline Matrix<> chol_solve (const Matrix<>& A, const Matrix<>& b) { SCYTHE_DEBUG_MSG("Using lapack/blas for chol_solve"); SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL") SCYTHE_CHECK_10(! b.isColVector(), scythe_dimension_error, "b must be a column vector"); SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error, "A and b do not conform"); // The algorithm modifies both A and b in place, so we make copies Matrix<> AA =A; Matrix<> bb = b; // Get array pointers and set up some vars double* Aarray = AA.getArray(); double* barray = bb.getArray(); int rows = (int) bb.rows(); int cols = (int) bb.cols(); // currently always one, but generalizable int err = 0; // Solve the system lapack::dposv_("L", &rows, &cols, Aarray, &rows, barray, &rows, &err); SCYTHE_CHECK_10(err > 0, scythe_type_error, "Matrix is not positive definite") SCYTHE_CHECK_10(err < 0, scythe_invalid_arg, "The " << err << "th value of the matrix had an illegal value") return bb; } template <matrix_order PO2, matrix_order PO3, matrix_order PO4> inline double lu_decomp_alg(Matrix<>& A, Matrix<double,PO2,Concrete>& L, Matrix<double,PO3,Concrete>& U, Matrix<unsigned int, PO4, Concrete>& perm_vec) { SCYTHE_DEBUG_MSG("Using lapack/blas for lu_decomp_alg"); SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL") SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error, "A is not square"); if (A.isRowVector()) { L = Matrix<double,PO2,Concrete> (1, 1, true, 1); // all 1s U = A; perm_vec = Matrix<uint, PO4, Concrete>(1, 1); // all 0s return 0.; } L = U = Matrix<double, PO2, Concrete>(A.rows(), A.cols(), false); perm_vec = Matrix<uint, PO3, Concrete> (A.rows(), 1, false); // Get a pointer to the internal array and set up some vars double* Aarray = A.getArray(); // internal array pointer int rows = (int) A.rows(); // the dim of the matrix int* ipiv = (int*) perm_vec.getArray(); // Holds the lu decomp pivot array int err = 0; // The output error condition // Do the decomposition lapack::dgetrf_(&rows, &rows, Aarray, &rows, ipiv, &err); SCYTHE_CHECK_10(err > 0, scythe_type_error, "Matrix is singular"); SCYTHE_CHECK_10(err < 0, scythe_lapack_internal_error, "The " << err << "th value of the matrix had an illegal value");
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -