📄 ide.h
字号:
// Now fill in the L and U matrices. L = A; for (uint i = 0; i < A.rows(); ++i) { for (uint j = i; j < A.rows(); ++j) { U(i,j) = A(i,j); L(i,j) = 0.; L(i,i) = 1.; } } // Change to scythe's rows-1 perm_vec format and c++ indexing // XXX Cutting off the last pivot term may be buggy if it isn't // always just pointing at itself if (perm_vec(perm_vec.size() - 1) != perm_vec.size()) SCYTHE_THROW(scythe_unexpected_default_error, "This is an unexpected error. Please notify the developers.") perm_vec = perm_vec(0, 0, perm_vec.rows() - 2, 0) - 1; // Finally, figure out the sign of perm_vec if (sum(perm_vec > 0) % 2 == 0) return 1; return -1; } /*! \brief The result of a QR decomposition. * * Objects of this type contain three matrices, \a QR, \a tau, and * \a pivot, representing the results of a QR decomposition of a * \f$m \times n\f$ matrix. After decomposition, the upper triangle * of \a QR contains the min(\f$m\f$, \f$n\f$) by \f$n\f$ upper * trapezoidal matrix \f$R\f$, while \a tau and the elements of \a * QR below the diagonal represent the orthogonal matrix \f$Q\f$ as * a product of min(\f$m\f$, \f$n\f$) elementary reflectors. The * vector \a pivot is a permutation vector containing information * about the pivoting strategy used in the factorization. * * \a QR is \f$m \times n\f$, tau is a vector of dimension * min(\f$m\f$, \f$n\f$), and pivot is a vector of dimension * \f$n\f$. * * \see qr_decomp (const Matrix<>& A) */ struct QRdecomp { Matrix<> QR; Matrix<> tau; Matrix<> pivot; }; /*! \brief QR decomposition of a matrix. * * This function performs QR decomposition. That is, given a * \f$m \times n \f$ matrix \a A, qr_decomp computes the QR factorization * of \a A with column pivoting, such that \f$A \cdot P = Q \cdot * R\f$. The resulting QRdecomp object contains three matrices, \a * QR, \a tau, and \a pivot. The upper triangle of \a QR contains the * min(\f$m\f$, \f$n\f$) by \f$n\f$ upper trapezoidal matrix * \f$R\f$, while \a tau and the elements of \a QR below the * diagonal represent the orthogonal matrix \f$Q\f$ as a product of * min(\f$m\f$, \f$n\f$) elementary reflectors. The vector \a pivot * is a permutation vector containing information about the pivoting * strategy used in the factorization. * * \note This function requires BLAS/LAPACK functionality and is * only available on machines that provide these libraries. Make * sure you enable the SCYTHE_LAPACK preprocessor flag if you wish * to use this function. Furthermore, note that this function takes * and returns only column-major concrete matrices. Future versions * of Scythe will provide a native C++ implementation of this * function with support for general matrix templates. * * \param A A matrix to decompose. * * \see QRdecomp * \see lu_decomp(Matrix<T,PO1,PS1>, Matrix<T,PO2,Concrete>&, Matrix<T,PO3,Concrete>&, Matrix<unsigned int, PO4, Concrete>&) * \see cholesky (const Matrix<T, PO, PS>&) * \see qr_solve (const Matrix<>& A, const Matrix<>& b, const QRdecomp& QR) * \see qr_solve (const Matrix<>& A, const Matrix<>& b); * * \throw scythe_null_error (Level 1) * \throw scythe_lapack_internal_error (Level 1) */ inline QRdecomp qr_decomp (const Matrix<>& A) { SCYTHE_DEBUG_MSG("Using lapack/blas for qr_decomp"); SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL"); // Set up working variables Matrix<> QR = A; double* QRarray = QR.getArray(); // input/output array pointer int rows = (int) QR.rows(); int cols = (int) QR.cols(); Matrix<unsigned int> pivot(cols, 1); // pivot vector int* parray = (int*) pivot.getArray(); // pivot vector array pointer Matrix<> tau = Matrix<>(rows < cols ? rows : cols, 1); double* tarray = tau.getArray(); // tau output array pointer double tmp, *work; // workspace vars int lwork, info; // workspace size var and error info var // Get workspace size lwork = -1; lapack::dgeqp3_(&rows, &cols, QRarray, &rows, parray, tarray, &tmp, &lwork, &info); SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, "Internal error in LAPACK routine dgeqp3"); lwork = (int) tmp; work = new double[lwork]; // run the routine for real lapack::dgeqp3_(&rows, &cols, QRarray, &rows, parray, tarray, work, &lwork, &info); SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, "Internal error in LAPACK routine dgeqp3"); delete[] work; pivot -= 1; QRdecomp result; result.QR = QR; result.tau = tau; result.pivot = pivot; return result; } /*! \brief Solve \f$Ax=b\f$ given a QR decomposition. * * This function solves the system of equations \f$Ax = b\f$ using * the results of a QR decomposition. This function requires the * actual QR decomposition to be performed ahead of time; by * qr_decomp() for example. * * This function is intended for repeatedly solving systems of * equations based on \a A. That is \a A stays constant while \a b * varies. * * \note This function requires BLAS/LAPACK functionality and is * only available on machines that provide these libraries. Make * sure you enable the SCYTHE_LAPACK preprocessor flag if you wish * to use this function. Furthermore, note that this function takes * and returns only column-major concrete matrices. Future versions * of Scythe will provide a native C++ implementation of this * function with support for general matrix templates. * * \param A A Matrix to decompose. * \param b A Matrix with as many rows as \a A. * \param QR A QRdecomp object containing the result of the QR decomposition of \a A. * * \see QRdecomp * \see qr_solve (const Matrix<>& A, const Matrix<>& b) * \see qr_decomp (const Matrix<>& A) * \see lu_solve (const Matrix<T,PO1,PS1>&, const Matrix<T,PO2,PS2>&, const Matrix<T,PO3,PS3>&, const Matrix<T,PO4,PS4>&, const Matrix<unsigned int, PO5, PS5>&) * \see lu_solve (Matrix<T,PO1,PS1>, const Matrix<T,PO2,PS2>&) * \see chol_solve(const Matrix<T,PO1,PS1> &, const Matrix<T,PO2,PS2> &) * \see chol_solve(const Matrix<T,PO1,PS1> &, const Matrix<T,PO2,PS2> &, const Matrix<T,PO3,PS3> &) * * \throw scythe_null_error (Level 1) * \throw scythe_conformation_error (Level 1) * \throw scythe_type_error (Level 1) * \throw scythe_lapack_internal_error (Level 1) */ inline Matrix<> qr_solve(const Matrix<>& A, const Matrix<>& b, const QRdecomp& QR) { SCYTHE_DEBUG_MSG("Using lapack/blas for qr_solve"); SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL") SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error, "A and b do not conform"); SCYTHE_CHECK_10(A.rows() != QR.QR.rows() || A.cols() != QR.QR.cols(), scythe_conformation_error, "A and QR do not conform"); int taudim = (int) (A.rows() < A.cols() ? A.rows() : A.cols()); SCYTHE_CHECK_10(QR.tau.size() != taudim, scythe_conformation_error, "A and tau do not conform"); SCYTHE_CHECK_10(QR.pivot.size() != A.cols(), scythe_conformation_error, "pivot vector is not the right length"); int rows = (int) QR.QR.rows(); int cols = (int) QR.QR.cols(); int nrhs = (int) b.cols(); int lwork, info; double *work, tmp; double* QRarray = QR.QR.getArray(); double* tarray = QR.tau.getArray(); Matrix<> bb = b; double* barray = bb.getArray(); // Get workspace size lwork = -1; lapack::dormqr_("L", "T", &rows, &nrhs, &taudim, QRarray, &rows, tarray, barray, &rows, &tmp, &lwork, &info); SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, "Internal error in LAPACK routine dormqr"); // And now for real lwork = (int) tmp; work = new double[lwork]; lapack::dormqr_("L", "T", &rows, &nrhs, &taudim, QRarray, &rows, tarray, barray, &rows, work, &lwork, &info); SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, "Internal error in LAPACK routine dormqr"); lapack::dtrtrs_("U", "N", "N", &taudim, &nrhs, QRarray, &rows, barray, &rows, &info); SCYTHE_CHECK_10(info > 0, scythe_type_error, "Matrix is singular"); SCYTHE_CHECK_10(info < 0, scythe_lapack_internal_error, "Internal error in LAPACK routine dtrtrs"); delete[] work; Matrix<> result(A.cols(), b.cols(), false); for (uint i = 0; i < QR.pivot.size(); ++i) result(i, _) = bb((uint) QR.pivot(i), _); return result; } /*! \brief Solve \f$Ax=b\f$ using QR decomposition. * * This function solves the system of equations \f$Ax = b\f$ using * QR decomposition. This function is intended for repeatedly * solving systems of equations based on \a A. That is \a A stays * constant while \a b varies. * * \note This function used BLAS/LAPACK support functionality and is * only available on machines that provide these libraries. Make * sure you enable the SCYTHE_LAPACK preprocessor flag if you wish * to use this function. Furthermore, note that the function takes * and returns only column-major concrete matrices. Future versions * of Scythe will provide a native C++ implementation of this * function with support for general matrix templates. * * \param A A Matrix to decompose. * \param b A Matrix with as many rows as \a A. * * \see QRdecomp * \see qr_solve (const Matrix<>& A, const Matrix<>& b, const QRdecomp& QR) * \see qr_decomp (const Matrix<>& A) * \see lu_solve (const Matrix<T,PO1,PS1>&, const Matrix<T,PO2,PS2>&, const Matrix<T,PO3,PS3>&, const Matrix<T,PO4,PS4>&, const Matrix<unsigned int, PO5, PS5>&) * \see lu_solve (Matrix<T,PO1,PS1>, const Matrix<T,PO2,PS2>&) * \see chol_solve(const Matrix<T,PO1,PS1> &, const Matrix<T,PO2,PS2> &) * \see chol_solve(const Matrix<T,PO1,PS1> &, const Matrix<T,PO2,PS2> &, const Matrix<T,PO3,PS3> &) * * \throw scythe_null_error (Level 1) * \throw scythe_conformation_error (Level 1) * \throw scythe_type_error (Level 1) * \throw scythe_lapack_internal_error (Level 1) */ inline Matrix<> qr_solve (const Matrix<>& A, const Matrix<>& b) { SCYTHE_DEBUG_MSG("Using lapack/blas for qr_solve"); SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL") SCYTHE_CHECK_10(A.rows() != b.rows(), scythe_conformation_error, "A and b do not conform"); /* Do decomposition */ // Set up working variables Matrix<> QR = A; double* QRarray = QR.getArray(); // input/output array pointer int rows = (int) QR.rows(); int cols = (int) QR.cols(); Matrix<unsigned int> pivot(cols, 1); // pivot vector int* parray = (int*) pivot.getArray(); // pivot vector array pointer Matrix<> tau = Matrix<>(rows < cols ? rows : cols, 1); double* tarray = tau.getArray(); // tau output array pointer double tmp, *work; // workspace vars int lwork, info; // workspace size var and error info var // Get workspace size lwork = -1; lapack::dgeqp3_(&rows, &cols, QRarray, &rows, parray, tarray, &tmp, &lwork, &info); SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, "Internal error in LAPACK routine dgeqp3"); lwork = (int) tmp; work = new double[lwork]; // run the routine for real lapack::dgeqp3_(&rows, &cols, QRarray, &rows, parray, tarray, work, &lwork, &info); SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, "Internal error in LAPACK routine dgeqp3"); delete[] work; pivot -= 1; /* Now solve the system */ // working vars int nrhs = (int) b.cols(); Matrix<> bb = b; double* barray = bb.getArray(); int taudim = (int) tau.size(); // Get workspace size lwork = -1; lapack::dormqr_("L", "T", &rows, &nrhs, &taudim, QRarray, &rows, tarray, barray, &rows, &tmp, &lwork, &info); SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, "Internal error in LAPACK routine dormqr"); // And now for real lwork = (int) tmp; work = new double[lwork]; lapack::dormqr_("L", "T", &rows, &nrhs, &taudim, QRarray, &rows, tarray, barray, &rows, work, &lwork, &info); SCYTHE_CHECK_10(info != 0, scythe_lapack_internal_error, "Internal error in LAPACK routine dormqr"); lapack::dtrtrs_("U", "N", "N", &taudim, &nrhs, QRarray, &rows, barray, &rows, &info); SCYTHE_CHECK_10(info > 0, scythe_type_error, "Matrix is singular"); SCYTHE_CHECK_10(info < 0, scythe_lapack_internal_error, "Internal error in LAPACK routine dtrtrs"); delete[] work; Matrix<> result(A.cols(), b.cols(), false); for (uint i = 0; i < pivot.size(); ++i) result(i, _) = bb(pivot(i), _); return result; } template<> inline Matrix<> invpd (const Matrix<>& A) { SCYTHE_DEBUG_MSG("Using lapack/blas for invpd"); SCYTHE_CHECK_10(A.isNull(), scythe_null_error, "A is NULL") SCYTHE_CHECK_10 (! A.isSquare(), scythe_dimension_error, "A is not square"); // We have to do an explicit copy within the func to match the
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -