📄 algo_test.h
字号:
} T tmp = x[i] * y[j] + x[j] * y[i]; AA(i, j) += tmp; } // compare A's pass = mtl::matrix_equal(A, AA); if (pass) cout << test.c_str() << " passed rank two " << endl; else { cout << "*** rank two " << test.c_str() << " failed" << endl;#if !defined(_MSVCPP_) cout << "result" << endl; mtl::print_all_matrix(A); cout << "correct" << endl; mtl::print_all_matrix(AA);#endif }}template <class Matrix>void test_ranktwo(std::string test, Matrix&, banded_tag){ cout << test.c_str() << " skipping rank two update" << endl;}template <class Matrix>void test_ranktwo(std::string test, Matrix& A, symmetric_tag){ if (A.super() == int(A.nrows()) - 1) test_ranktwo(test, A, rectangle_tag()); else cout << test.c_str() << " skipping rank two update (symm banded)" << endl;}template <class Matrix>void test_ranktwo(std::string test, Matrix& A){ typedef typename mtl::matrix_traits<Matrix>::shape Shape; test_ranktwo(test, A, Shape());}template <class Matrix, class Vector>void simple_tri_solve(const Matrix& A, Vector& X, bool is_unit, bool is_upper){ // this algorithm cooresponds to netlib dtrsv typedef typename mtl::matrix_traits<Matrix>::value_type T; typedef typename mtl::matrix_traits<Matrix>::size_type Int; char DIAG, UPLO; T TEMP; int I, J; Int N = A.ncols(); if (is_unit) DIAG = 'U'; else DIAG = 'N'; if (is_upper) UPLO = 'U'; else UPLO = 'L'; bool NOUNIT = DIAG == 'N'; if (UPLO == 'U') { for (J = N - 1; J >= 0; --J) { if (X[J] != T(0)) { if (NOUNIT) X[J] = X[J]/A(J,J); TEMP = X[J]; for (I = J - 1; I >= 0; --I) X[I] = X[I] - TEMP * A(I,J); } } } else { for (J = 0; J < int(N); ++J) { if (X[J] != T(0) ) { if (NOUNIT) X[J] = X[J]/A(J, J); TEMP = X[J]; for (I = J + 1; I < int(N); ++I) X[I] = X[I] - TEMP*A(I, J); } } }}template <class MatrixA>void test_tri_solve(std::string test, MatrixA& A, mtl::triangle_tag){ if (A.nrows() != A.ncols()) { cout << test.c_str() << " skipping tri solve (A must be N x N)" << endl; return; } bool success = true; typedef typename mtl::matrix_traits<MatrixA>::value_type T; typedef typename mtl::matrix_traits<MatrixA>::size_type Int; dense1D<T> x(A.ncols()), y(A.ncols()); mtl::matrix<T>::type C(A.nrows(), A.ncols()); mtl::set_value(A, T()); mtl::set_value(C, T()); Int i,j; if (A.is_upper()) { for (i = 0; i < A.nrows(); ++i) { for (j = i; j < A.ncols(); ++j) { if ((!A.is_unit() && i == j) || j - i == 1) { A(i,j) = T(1.5); } else if (j > i) { T t = T(i + j + 1); A(i,j) = t; } } } } else { // lower for (i = 0; i < A.nrows(); ++i) { for (j = 0; j <= i; ++j) { if ((!A.is_unit() && i == j) || i - j == 1) { A(i,j) = T(1.5); } else if (j < i) { T t = T(i + j + 1); A(i,j) = t; } } } } mtl::copy(A, C); mtl::set_value(x, T(1)); mtl::set_value(y, T(1)); // MTL mtl::tri_solve(A, x); simple_tri_solve(C, y, A.is_unit(), A.is_upper()); // COMPARE bool isequal = true; for (i = 0; i < A.ncols(); ++i) if (MTL_ABS(x[i] - y[i]) > MTL_MIN(MTL_ABS(x[i]),MTL_ABS(y[i])) / 100) {#if !defined(_MSVCPP_) cout << x[i] << " != " << y[i] << endl;#endif isequal = false; break; } if (! isequal) { cout << "*** " << test.c_str() << " failed tri_solve(A,x,y)" << endl;#if !defined(_MSVCPP_) cout << "result vector "; mtl::print_vector(x); cout << "correct vector "; mtl::print_vector(y);#endif success = false; } if (success) cout << test.c_str() << " passed matvec tri_solve" << endl; }template <class MatrixA>void test_tri_solve(std::string test, MatrixA&, mtl::rectangle_tag){ cout << test.c_str() << " skipping tri solve" << endl;}template <class MatrixA>void test_tri_solve(std::string test, MatrixA&, mtl::banded_tag){ cout << test.c_str() << " skipping tri solve" << endl;}template <class MatrixA>void test_tri_solve(std::string test, MatrixA&, mtl::symmetric_tag){ cout << test.c_str() << " skipping tri solve" << endl;}template <class MatrixA>void test_tri_solve(std::string test, MatrixA& A){ typedef typename mtl::matrix_traits<MatrixA>::shape Shape; test_tri_solve(test, A, Shape());}template <class MatA, class MatC, class MatC2>bool test_add(std::string test, MatA& A, MatC& C, MatC2& C2){ typedef typename mtl::matrix_traits<MatA>::size_type Int; typedef typename mtl::matrix_traits<MatA>::value_type T; Int i, j; bool passed = true; mtl::add(A, C); for (i = 0; i < A.nrows(); ++i) { Int first = MTL_MAX(0, int(i) - int(A.sub())); Int last = MTL_MIN(int(A.ncols()), int(i) + int(A.super()) + 1); for (j = 0; j < A.ncols(); ++j) if (j >= first && j < last) C2(i,j) = C2(i,j) + A(i,j); } if (A.is_unit()) for (Int i = 0; i < MTL_MIN(A.nrows(), A.ncols()); ++i) C2(i,i) = C2(i,i) + T(1); if (! mtl::matrix_equal(C, C2)) { passed = false; cout << test.c_str() << " failed add" << endl;#ifndef _MSVCPP_ cout << "result" << endl; mtl::print_all_matrix(C); cout << "correct" << endl; mtl::print_all_matrix(C2);#endif } cout << test.c_str() << " passed mat-mat add" << endl; return passed;}/* * matmat algorithms */template <class MatA, class MatB, class MatC, class MatC2>void test_mult(std::string test, MatA& A, MatB& B, MatC& C, MatC2& C2){ typedef typename mtl::matrix_traits<MatA>::size_type Int; mtl::mult(A, B, C); Int M = C.nrows(); Int N = C.ncols(); Int K = A.ncols(); for (Int i = 0; i != M; ++i) for (Int j = 0; j != N; ++j) for (Int k = 0; k != K; ++k) { Int first = MTL_MAX(0, int(i) - int(A.sub())); Int last = MTL_MIN(int(A.ncols()), int(i) + int(A.super()) + 1); if (k >= first && k < last) C2(i,j) += A(i,k) * B(k,j); } if (A.is_unit()) { Int M = MTL_MIN(A.nrows(), A.ncols()); Int N = B.ncols(); for (Int i = 0; i < M; ++i) for (Int j = 0; j < N; ++j) C2(i,j) += B(i,j); } if (! mtl::matrix_equal(C, C2)) { cout << test.c_str() << " failed mult" << endl;#if !defined(_MSVCPP_) cout << "result" << endl; mtl::print_all_matrix(C); cout << "correct" << endl; mtl::print_all_matrix(C2);#endif } else cout << test.c_str() << " passed mat-mat multiply test" << endl;}template <class MatA, class MatC, class MatC2>void test_copy(std::string test, MatA& A, MatC& C, MatC2& C2){ typedef typename mtl::matrix_traits<MatA>::size_type Int; mtl::copy(A, C); for (Int i = 0; i < A.nrows(); ++i) for (Int j = 0; j < A.ncols(); ++j) { Int first = MTL_MAX(0, int(i) - int(A.sub())); Int last = MTL_MIN(int(A.ncols()), int(i) + int(A.super()) + 1); if (j >= first && j < last) C2(i,j) = A(i,j); } if (A.is_unit()) mtl::set_diagonal(C2,1); if (! mtl::matrix_equal(C, C2)) { cout << test.c_str() << " failed copy" << endl;#if !defined(_MSVCPP_) cout << "result" << endl; mtl::print_all_matrix(C); cout << "correct" << endl; mtl::print_all_matrix(C2);#endif } else { cout << test.c_str() << " passed mat-mat copy test" << endl; }}#if 0template <class MatA, class MatB, class MatC, class MatC2>bool test_swap(std::string test, MatA& A, MatB& B, MatC& C, MatC2& C2){ bool passed = true; fill_mats(A, B, C); mtl::swap(A, C); for (int i = 0; i < A.nrows(); ++i) for (int j = 0; j < A.ncols(); ++j) { Int first = MTL_MAX(0, int(i) - int(A.sub())); Int last = MTL_MIN(int(A.ncols()), int(i) + int(A.super()) + 1); if (j >= first && j < last) C2(i,j) = A(i,j); } if (! mtl::matrix_equal(C, C2)) { passed = false; cout << test.c_str() << " failed swap" << endl; cout << "result" << endl; mtl::print_row(C); cout << "correct" << endl; mtl::print_all_matrix(C2); } return passed;}template <class MatA, class MatB, class MatC, class MatC2>bool test_ele_mult(std::string test, MatA& A, MatB& B, MatC& C, MatC2& C2){ bool passed = true; int i, j; fill_mats(A, B, C); mtl::ele_mult(A, C); mtl::set_value(C2, 0); for (i = 0; i < C2.nrows(); ++++i) for (j = 0; j < C2.ncols(); ++++j) C2(i,j) = 1; for (i = 0; i < A.nrows(); ++i) for (j = 0; j < A.ncols(); ++j) C2(i,j) *= A(i,j); if (! mtl::matrix_equal(C, C2)) { passed = false; cout << test.c_str() << " failed" << endl; cout << "result" << endl; mtl::print_all_matrix(C); cout << "correct" << endl; mtl::print_all_matrix(C2); } return passed;}template <class MatA, class MatB, class MatC, class MatC2>bool test_transpose(std::string test, MatA& A, MatB& B, MatC& C, MatC2& C2){ bool passed = true; // only transpose square matrices if (A.nrows() == A.ncols() && C.nrows() == C.ncols() && A.nrows() == A.ncols()) { fill_mats(A, B, C); mtl::transpose(A, C); for (int i = 0; i < A.nrows(); ++i) for (int j = 0; j < A.ncols(); ++j) C2(j,i) = A(i,j); if (! mtl::matrix_equal(C, C2)) { passed = false; cout << test.c_str() << " failed" << endl; cout << "result" << endl; mtl::print_all_matrix(C); cout << "correct" << endl; mtl::print_all_matrix(C2); } } return passed;}#endif /* matmat algos */#endif /* MTL_ALGO_TEST_H */
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -