cholesky.hpp
来自「矩阵运算源码最新版本」· HPP 代码 · 共 580 行 · 第 1/2 页
HPP
580 行
void operator() (MatrixSE & SE, const MatrixSW & SW) { tri_schur_base(SE, SW); } }; struct schur_update_base_t { template < typename MatrixNE, typename MatrixNW, typename MatrixSW > void operator() (MatrixNE & NE, const MatrixNW & NW, const MatrixSW & SW) { schur_update_base(NE, NW, SW); } };} // namespace with_iterator// ==================================// Functor types for Cholesky visitor// ==================================template <typename BaseTest, typename CholeskyBase, typename TriSolveBase, typename TriSchur, typename SchurUpdate>struct recursive_cholesky_visitor_t{ typedef BaseTest base_test; template < typename Recursator > bool is_base(const Recursator& recursator) const { return base_test()(recursator); } template < typename Matrix > void cholesky_base(Matrix & matrix) const { CholeskyBase()(matrix); } template < typename MatrixSW, typename MatrixNW > void tri_solve_base(MatrixSW & SW, const MatrixNW & NW) const { TriSolveBase()(SW, NW); } template < typename MatrixSE, typename MatrixSW > void tri_schur_base(MatrixSE & SE, const MatrixSW & SW) const { TriSchur()(SE, SW); } template < typename MatrixNE, typename MatrixNW, typename MatrixSW > void schur_update_base(MatrixNE & NE, const MatrixNW & NW, const MatrixSW & SW) const { SchurUpdate()(NE, NW, SW); }};namespace detail { // Compute schur update with external multiplication; must have Assign == minus_mult_assign_t !!! 
template <typename MatrixMult> struct mult_schur_update_t { template < typename MatrixNE, typename MatrixNW, typename MatrixSW > void operator()(MatrixNE & NE, const MatrixNW & NW, const MatrixSW & SW) { transposed_view<MatrixSW> trans_sw(const_cast<MatrixSW&>(SW)); MatrixMult()(NW, trans_sw, NE); } };} // detailnamespace with_bracket { typedef recursive_cholesky_visitor_t<recursion::bound_test_static<64>, cholesky_base_t, tri_solve_base_t, tri_schur_base_t, schur_update_base_t > recursive_cholesky_base_visitor_t;}namespace with_iterator { typedef recursive_cholesky_visitor_t<recursion::bound_test_static<64>, cholesky_base_t, tri_solve_base_t, tri_schur_base_t, schur_update_base_t> recursive_cholesky_base_visitor_t;}typedef with_bracket::recursive_cholesky_base_visitor_t recursive_cholesky_default_visitor_t;namespace with_recursator { template <typename Recursator, typename Visitor> void schur_update(Recursator E, Recursator W, Recursator N, Visitor vis) { using namespace recursion; if (E.is_empty() || W.is_empty() || N.is_empty()) return; if (vis.is_base(E)) { typedef typename Visitor::base_test base_test; typedef typename base_case_matrix<typename Recursator::matrix_type, base_test>::type matrix_type; matrix_type base_E(base_case_cast<base_test>(E.get_value())), base_W(base_case_cast<base_test>(W.get_value())), base_N(base_case_cast<base_test>(N.get_value())); vis.schur_update_base(base_E, base_W, base_N); } else{ schur_update( E.north_east(),W.north_west() ,N.south_west() , vis); schur_update( E.north_east(), W.north_east(), N.south_east(), vis); schur_update(E.north_west() , W.north_east(), N.north_east(), vis); schur_update(E.north_west() ,W.north_west() ,N.north_west() , vis); schur_update(E.south_west() ,W.south_west() ,N.north_west() , vis); schur_update(E.south_west() , W.south_east(), N.north_east(), vis); schur_update( E.south_east(), W.south_east(), N.south_east(), vis); schur_update( E.south_east(),W.south_west() ,N.south_west() , vis); } } template 
<typename Recursator, typename Visitor> void tri_solve(Recursator S, Recursator N, Visitor vis) { using namespace recursion; if (S.is_empty()) return; if (vis.is_base(S)) { typedef typename Visitor::base_test base_test; typedef typename base_case_matrix<typename Recursator::matrix_type, base_test>::type matrix_type; matrix_type base_S(base_case_cast<base_test>(S.get_value())), base_N(base_case_cast<base_test>(N.get_value())); vis.tri_solve_base(base_S, base_N); } else{ tri_solve(S.north_west() ,N.north_west(), vis); schur_update( S.north_east(),S.north_west() ,N.south_west(), vis); tri_solve( S.north_east(), N.south_east(), vis); tri_solve(S.south_west() ,N.north_west() , vis); schur_update( S.south_east(),S.south_west() ,N.south_west(), vis); tri_solve( S.south_east(), N.south_east(), vis); } } template <typename Recursator, typename Visitor> void tri_schur(Recursator E, Recursator W, Visitor vis) { using namespace recursion; if (E.is_empty() || W.is_empty()) return; if (vis.is_base(W)) { typedef typename Visitor::base_test base_test; typedef typename base_case_matrix<typename Recursator::matrix_type, base_test>::type matrix_type; matrix_type base_E(base_case_cast<base_test>(E.get_value())), base_W(base_case_cast<base_test>(W.get_value())); vis.tri_schur_base(base_E, base_W); } else{ schur_update(E.south_west(), W.south_west(), W.north_west(), vis); schur_update(E.south_west(), W.south_east(), W.north_east(), vis); tri_schur( E.south_east() , W.south_east(), vis); tri_schur( E.south_east() ,W.south_west() , vis); tri_schur( E.north_west(), W.north_east(), vis); tri_schur( E.north_west(),W.north_west() , vis); } } template <typename Recursator, typename Visitor> void cholesky(Recursator recursator, Visitor vis) { using namespace recursion; if (recursator.is_empty()) return; if (vis.is_base (recursator)){ typedef typename Visitor::base_test base_test; typedef typename base_case_matrix<typename Recursator::matrix_type, base_test>::type matrix_type; matrix_type 
base_matrix(base_case_cast<base_test>(recursator.get_value())); vis.cholesky_base (base_matrix); } else { cholesky(recursator.north_west(), vis); tri_solve( recursator.south_west(), recursator.north_west(), vis); tri_schur( recursator.south_east(), recursator.south_west(), vis); cholesky( recursator.south_east(), vis); } } } // namespace with_recursatortemplate <typename Backup= with_bracket::cholesky_base_t>struct recursive_cholesky_t{ template <typename Matrix> void operator()(Matrix& matrix) { (*this)(matrix, recursive_cholesky_default_visitor_t()); } template <typename Matrix, typename Visitor> void operator()(Matrix& matrix, Visitor vis) { apply(matrix, vis, typename traits::category<Matrix>::type()); } private: // If the matrix is not sub-dividable then take backup function template <typename Matrix, typename Visitor> void apply(Matrix& matrix, Visitor, tag::universe) { Backup()(matrix); } // Only if matrix is sub-dividable, otherwise backup template <typename Matrix, typename Visitor> void apply(Matrix& matrix, Visitor vis, tag::qsub_dividable) { matrix_recursator<Matrix> recursator(matrix); with_recursator::cholesky(recursator, vis); }};template <typename Matrix, typename Visitor>inline void recursive_cholesky(Matrix& matrix, Visitor vis){ recursive_cholesky_t<>()(matrix, vis);}template <typename Matrix>inline void recursive_cholesky(Matrix& matrix){ recursive_cholesky(matrix, recursive_cholesky_default_visitor_t());}template <typename Matrix>void fill_matrix_for_cholesky(Matrix& matrix){ typename Matrix::value_type x= 1.0; for (int i=0; i<matrix.num_rows(); i++) for (int j=0; j<=i; j++) if (i != j) { matrix[i][j]= x; matrix[j][i]= x; x=x+1.0; } typename Matrix::value_type rowsum; for (int i=0; i < matrix.num_rows(); i++) { rowsum= 0.0; for (int j=0; j<matrix.num_cols(); j++) if (i!=j) rowsum += matrix[i][j]; matrix[i][i]=rowsum*2; } }} // namespace mtl#endif // MTL_CHOLESKY_INCLUDE
⌨️ 快捷键说明
复制代码：Ctrl + C
搜索代码：Ctrl + F
全屏模式：F11
增大字号：Ctrl + =
减小字号：Ctrl + -
显示快捷键：?