cholesky.hpp

来自「矩阵运算源码最新版本」· HPP 代码 · 共 580 行 · 第 1/2 页

HPP
580
字号
	void operator() (MatrixSE & SE, const MatrixSW & SW)	{	    tri_schur_base(SE, SW);	}    };    struct schur_update_base_t    {	template < typename MatrixNE, typename MatrixNW, typename MatrixSW >	void operator() (MatrixNE & NE, const MatrixNW & NW, const MatrixSW & SW)	{	    schur_update_base(NE, NW, SW);	}    };} // namespace with_iterator// ==================================// Functor types for Cholesky visitor// ==================================template <typename BaseTest, typename CholeskyBase, typename TriSolveBase, typename TriSchur, typename SchurUpdate>struct recursive_cholesky_visitor_t{    typedef  BaseTest                   base_test;    template < typename Recursator >     bool is_base(const Recursator& recursator) const    {	return base_test()(recursator);    }    template < typename Matrix >     void cholesky_base(Matrix & matrix) const    {	CholeskyBase()(matrix);    }    template < typename MatrixSW, typename MatrixNW >     void tri_solve_base(MatrixSW & SW, const MatrixNW & NW) const    {	TriSolveBase()(SW, NW);    }    template < typename MatrixSE, typename MatrixSW >     void tri_schur_base(MatrixSE & SE, const MatrixSW & SW) const    {	TriSchur()(SE, SW);    }    template < typename MatrixNE, typename MatrixNW, typename MatrixSW >    void schur_update_base(MatrixNE & NE, const MatrixNW & NW, const MatrixSW & SW) const    {	SchurUpdate()(NE, NW, SW);    }};namespace detail {    // Compute schur update with external multiplication; must have Assign == minus_mult_assign_t !!!    template <typename MatrixMult>    struct mult_schur_update_t    {	template < typename MatrixNE, typename MatrixNW, typename MatrixSW >	void operator()(MatrixNE & NE, const MatrixNW & NW, const MatrixSW & SW)	{	    transposed_view<MatrixSW> trans_sw(const_cast<MatrixSW&>(SW)); 	    MatrixMult()(NW, trans_sw, NE);	}    };} // detailnamespace with_bracket {    typedef recursive_cholesky_visitor_t<recursion::bound_test_static<64>, cholesky_base_t, tri_solve_base_t, 					 tri_schur_base_t, schur_update_base_t >                recursive_cholesky_base_visitor_t;}namespace with_iterator {    typedef recursive_cholesky_visitor_t<recursion::bound_test_static<64>, 					 cholesky_base_t, tri_solve_base_t, tri_schur_base_t, schur_update_base_t>               recursive_cholesky_base_visitor_t;}typedef with_bracket::recursive_cholesky_base_visitor_t                    recursive_cholesky_default_visitor_t;namespace with_recursator {    template <typename Recursator, typename Visitor>    void schur_update(Recursator E, Recursator W, Recursator N, Visitor vis)    {	using namespace recursion;	if (E.is_empty() || W.is_empty() || N.is_empty())	    return;	if (vis.is_base(E)) {	    typedef typename Visitor::base_test  base_test;	    typedef typename base_case_matrix<typename Recursator::matrix_type, base_test>::type matrix_type;	    	    matrix_type  base_E(base_case_cast<base_test>(E.get_value())), 		         base_W(base_case_cast<base_test>(W.get_value())),		         base_N(base_case_cast<base_test>(N.get_value()));	    vis.schur_update_base(base_E, base_W, base_N);	} else{	    schur_update(     E.north_east(),W.north_west()     ,N.south_west()     , vis);	    schur_update(     E.north_east(),     W.north_east(),     N.south_east(), vis);    	    schur_update(E.north_west()     ,     W.north_east(),     N.north_east(), vis);	    schur_update(E.north_west()     ,W.north_west()     ,N.north_west()     , vis);    	    schur_update(E.south_west()     ,W.south_west()     ,N.north_west()     , vis);	    schur_update(E.south_west()     ,     W.south_east(),     N.north_east(), vis);    	    schur_update(     E.south_east(),     W.south_east(),     N.south_east(), vis);	    schur_update(     E.south_east(),W.south_west()     ,N.south_west()     , vis);	}    }    template <typename Recursator, typename Visitor>    void tri_solve(Recursator S, Recursator N, Visitor vis)    {	using namespace recursion;        if (S.is_empty())	    return;        if (vis.is_base(S)) {   	    typedef typename Visitor::base_test  base_test;	    typedef typename base_case_matrix<typename Recursator::matrix_type, base_test>::type matrix_type;	    	    matrix_type  base_S(base_case_cast<base_test>(S.get_value())), 		         base_N(base_case_cast<base_test>(N.get_value()));	    vis.tri_solve_base(base_S, base_N);        } else{     	    tri_solve(S.north_west()     ,N.north_west(), vis);	    schur_update(  S.north_east(),S.north_west()     ,N.south_west(), vis);	    tri_solve(     S.north_east(),     N.south_east(), vis);	    tri_solve(S.south_west()     ,N.north_west()     , vis);	    schur_update(  S.south_east(),S.south_west()     ,N.south_west(), vis);	    tri_solve(     S.south_east(),     N.south_east(), vis);	}    }    template <typename Recursator, typename Visitor>    void tri_schur(Recursator E, Recursator W, Visitor vis)    { 	using namespace recursion;        if (E.is_empty() || W.is_empty())	    return;        if (vis.is_base(W)) {	    typedef typename Visitor::base_test  base_test;	    typedef typename base_case_matrix<typename Recursator::matrix_type, base_test>::type matrix_type;	    	    matrix_type  base_E(base_case_cast<base_test>(E.get_value())),                		 base_W(base_case_cast<base_test>(W.get_value()));	    vis.tri_schur_base(base_E, base_W);        } else{          	    schur_update(E.south_west(),     W.south_west(),    W.north_west(), vis);	    schur_update(E.south_west(),     W.south_east(),    W.north_east(), vis);	    tri_schur(   E.south_east()     ,     W.south_east(), vis);	    tri_schur(   E.south_east()     ,W.south_west()     , vis);	    tri_schur(        E.north_west(),     W.north_east(), vis);	    tri_schur(        E.north_west(),W.north_west()     , vis);        }    }    template <typename Recursator, typename Visitor>    void cholesky(Recursator recursator, Visitor vis)    {	using namespace recursion;        if (recursator.is_empty())	    return;        if (vis.is_base (recursator)){    	    typedef typename Visitor::base_test  base_test;	    typedef typename base_case_matrix<typename Recursator::matrix_type, base_test>::type matrix_type;	    	    matrix_type  base_matrix(base_case_cast<base_test>(recursator.get_value()));	    vis.cholesky_base (base_matrix);              } else {	    cholesky(recursator.north_west(), vis);	    tri_solve(    recursator.south_west(),       recursator.north_west(), vis);	    tri_schur(    recursator.south_east(), recursator.south_west(), vis);	    cholesky(     recursator.south_east(), vis);        }    }        } // namespace with_recursatortemplate <typename Backup= with_bracket::cholesky_base_t>struct recursive_cholesky_t{    template <typename Matrix>    void operator()(Matrix& matrix)    {	(*this)(matrix, recursive_cholesky_default_visitor_t());    }    template <typename Matrix, typename Visitor>    void operator()(Matrix& matrix, Visitor vis)    {	apply(matrix, vis, typename traits::category<Matrix>::type());    }    private:    // If the matrix is not sub-dividable then take backup function    template <typename Matrix, typename Visitor>    void apply(Matrix& matrix, Visitor, tag::universe)    {	Backup()(matrix);    }    // Only if matrix is sub-dividable, otherwise backup    template <typename Matrix, typename Visitor>    void apply(Matrix& matrix, Visitor vis, tag::qsub_dividable)    {	matrix_recursator<Matrix>  recursator(matrix);	with_recursator::cholesky(recursator, vis);    }};template <typename Matrix, typename Visitor>inline void recursive_cholesky(Matrix& matrix, Visitor vis){    recursive_cholesky_t<>()(matrix, vis);}template <typename Matrix>inline void recursive_cholesky(Matrix& matrix){    recursive_cholesky(matrix, recursive_cholesky_default_visitor_t());}template <typename Matrix>void fill_matrix_for_cholesky(Matrix& matrix){    typename Matrix::value_type   x= 1.0;     for (int i=0; i<matrix.num_rows(); i++)        for (int j=0; j<=i; j++)	   if (i != j) {	       matrix[i][j]= x; matrix[j][i]= x; 	       x=x+1.0; 	   }      typename Matrix::value_type    rowsum;    for (int i=0; i < matrix.num_rows(); i++) {	rowsum= 0.0;	for (int j=0; j<matrix.num_cols(); j++)	    if (i!=j)		rowsum += matrix[i][j]; 	matrix[i][i]=rowsum*2;    }       }} // namespace mtl#endif // MTL_CHOLESKY_INCLUDE

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?