⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 algo_test.h

📁 一个通用的数学库
💻 H
📖 第 1 页 / 共 2 页
字号:
    }    T tmp = x[i] * y[j] + x[j] * y[i];    AA(i, j) += tmp;  }  // compare A's  pass = mtl::matrix_equal(A, AA);        if (pass)    cout << test.c_str() << " passed rank two " << endl;  else {    cout << "*** rank two " << test.c_str() << " failed" << endl;#if !defined(_MSVCPP_)    cout << "result" << endl;    mtl::print_all_matrix(A);    cout << "correct" << endl;    mtl::print_all_matrix(AA);#endif  }}template <class Matrix>void test_ranktwo(std::string test, Matrix&, banded_tag){  cout << test.c_str() << " skipping rank two update" << endl;}template <class Matrix>void test_ranktwo(std::string test, Matrix& A, symmetric_tag){  if (A.super() == int(A.nrows()) - 1)    test_ranktwo(test, A, rectangle_tag());  else    cout << test.c_str() << " skipping rank two update (symm banded)" << endl;}template <class Matrix>void test_ranktwo(std::string test, Matrix& A){  typedef typename mtl::matrix_traits<Matrix>::shape Shape;  test_ranktwo(test, A, Shape());}template <class Matrix, class Vector>void simple_tri_solve(const Matrix& A, Vector& X, bool is_unit, bool is_upper){  // this algorithm cooresponds to netlib dtrsv  typedef typename mtl::matrix_traits<Matrix>::value_type T;  typedef typename mtl::matrix_traits<Matrix>::size_type Int;  char DIAG, UPLO;  T TEMP;  int I, J;  Int N = A.ncols();    if (is_unit)    DIAG = 'U';  else    DIAG = 'N';  if (is_upper)    UPLO = 'U';  else    UPLO = 'L';    bool NOUNIT = DIAG == 'N';  if (UPLO == 'U') {    for (J = N - 1; J >= 0; --J) {      if (X[J] != T(0)) {        if (NOUNIT)          X[J] = X[J]/A(J,J);        TEMP = X[J];        for (I = J - 1; I >= 0; --I)          X[I] = X[I] - TEMP * A(I,J);      }    }  } else {    for (J = 0; J < int(N); ++J) {      if (X[J] != T(0) ) {        if (NOUNIT)          X[J] = X[J]/A(J, J);        TEMP = X[J];        for (I = J + 1; I < int(N); ++I)          X[I] = X[I] - TEMP*A(I, J);      }    }  }}template <class MatrixA>void test_tri_solve(std::string test, MatrixA& A, mtl::triangle_tag){  if (A.nrows() != A.ncols()) {    cout << test.c_str() << " skipping tri solve (A must be N x N)" << endl;    return;  }  bool success = true;  typedef typename mtl::matrix_traits<MatrixA>::value_type T;  typedef typename mtl::matrix_traits<MatrixA>::size_type Int;  dense1D<T> x(A.ncols()), y(A.ncols());  mtl::matrix<T>::type C(A.nrows(), A.ncols());  mtl::set_value(A, T());  mtl::set_value(C, T());  Int i,j;  if (A.is_upper()) {    for (i = 0; i < A.nrows(); ++i) {      for (j = i; j < A.ncols(); ++j) {        if ((!A.is_unit() && i == j) || j - i == 1) {          A(i,j) = T(1.5);        } else if (j > i) {          T t = T(i + j + 1);          A(i,j) = t;        }      }    }  } else { // lower    for (i = 0; i < A.nrows(); ++i) {      for (j = 0; j <= i; ++j) {        if ((!A.is_unit() && i == j) || i - j == 1) {          A(i,j) = T(1.5);        } else if (j < i) {          T t = T(i + j + 1);          A(i,j) = t;        }      }    }  }  mtl::copy(A, C);  mtl::set_value(x, T(1));  mtl::set_value(y, T(1));  // MTL  mtl::tri_solve(A, x);  simple_tri_solve(C, y, A.is_unit(), A.is_upper());  // COMPARE  bool isequal = true;  for (i = 0; i < A.ncols(); ++i)    if (MTL_ABS(x[i] - y[i]) > MTL_MIN(MTL_ABS(x[i]),MTL_ABS(y[i])) / 100) {#if !defined(_MSVCPP_)      cout << x[i] << " != " << y[i] << endl;#endif      isequal = false;      break;    }  if (! isequal) {    cout << "*** " << test.c_str() << " failed tri_solve(A,x,y)" << endl;#if !defined(_MSVCPP_)    cout << "result vector  ";    mtl::print_vector(x);    cout << "correct vector ";    mtl::print_vector(y);#endif    success = false;  }  if (success)    cout << test.c_str() << " passed matvec tri_solve" << endl; }template <class MatrixA>void test_tri_solve(std::string test, MatrixA&, mtl::rectangle_tag){  cout << test.c_str() << " skipping tri solve" << endl;}template <class MatrixA>void test_tri_solve(std::string test, MatrixA&, mtl::banded_tag){  cout << test.c_str() << " skipping tri solve" << endl;}template <class MatrixA>void test_tri_solve(std::string test, MatrixA&, mtl::symmetric_tag){  cout << test.c_str() << " skipping tri solve" << endl;}template <class MatrixA>void test_tri_solve(std::string test, MatrixA& A){  typedef typename mtl::matrix_traits<MatrixA>::shape Shape;  test_tri_solve(test, A, Shape());}template <class MatA, class MatC, class MatC2>bool test_add(std::string test, MatA& A, MatC& C, MatC2& C2){  typedef typename mtl::matrix_traits<MatA>::size_type Int;  typedef typename mtl::matrix_traits<MatA>::value_type T;  Int i, j;  bool passed = true;  mtl::add(A, C);  for (i = 0; i < A.nrows(); ++i) {    Int first = MTL_MAX(0, int(i) - int(A.sub()));    Int last = MTL_MIN(int(A.ncols()), int(i) + int(A.super()) + 1);    for (j = 0; j < A.ncols(); ++j)      if (j >= first && j < last)        C2(i,j) = C2(i,j) + A(i,j);  }  if (A.is_unit())    for (Int i = 0; i < MTL_MIN(A.nrows(), A.ncols()); ++i)      C2(i,i) = C2(i,i) + T(1);  if (! mtl::matrix_equal(C, C2)) {    passed = false;    cout << test.c_str() << " failed add" << endl;#ifndef _MSVCPP_    cout << "result" << endl;    mtl::print_all_matrix(C);    cout << "correct" << endl;    mtl::print_all_matrix(C2);#endif  }  cout << test.c_str() << " passed mat-mat add" << endl;  return passed;}/* *  matmat algorithms */template <class MatA, class MatB, class MatC, class MatC2>void test_mult(std::string test, MatA& A, MatB& B, MatC& C, MatC2& C2){  typedef typename mtl::matrix_traits<MatA>::size_type Int;  mtl::mult(A, B, C);    Int M = C.nrows();  Int N = C.ncols();  Int K = A.ncols();  for (Int i = 0; i != M; ++i)    for (Int j = 0; j != N; ++j)      for (Int k = 0; k != K; ++k) {        Int first = MTL_MAX(0, int(i) - int(A.sub()));        Int last = MTL_MIN(int(A.ncols()), int(i) + int(A.super()) + 1);        if (k >= first && k < last)          C2(i,j) += A(i,k) * B(k,j);      }  if (A.is_unit()) {    Int M = MTL_MIN(A.nrows(), A.ncols());    Int N = B.ncols();    for (Int i = 0; i < M; ++i)      for (Int j = 0; j < N; ++j)        C2(i,j) += B(i,j);  }  if (! mtl::matrix_equal(C, C2)) {    cout << test.c_str() << " failed mult" << endl;#if !defined(_MSVCPP_)    cout << "result" << endl;    mtl::print_all_matrix(C);    cout << "correct" << endl;    mtl::print_all_matrix(C2);#endif  } else    cout << test.c_str() << " passed mat-mat multiply test" << endl;}template <class MatA, class MatC, class MatC2>void test_copy(std::string test, MatA& A, MatC& C, MatC2& C2){  typedef typename mtl::matrix_traits<MatA>::size_type Int;  mtl::copy(A, C);  for (Int i = 0; i < A.nrows(); ++i)    for (Int j = 0; j < A.ncols(); ++j) {      Int first = MTL_MAX(0, int(i) - int(A.sub()));      Int last = MTL_MIN(int(A.ncols()), int(i) + int(A.super()) + 1);      if (j >= first && j < last)        C2(i,j) = A(i,j);    }  if (A.is_unit())    mtl::set_diagonal(C2,1);  if (! mtl::matrix_equal(C, C2)) {    cout << test.c_str() << " failed copy" << endl;#if !defined(_MSVCPP_)    cout << "result" << endl;    mtl::print_all_matrix(C);    cout << "correct" << endl;    mtl::print_all_matrix(C2);#endif  } else {    cout << test.c_str() << " passed mat-mat copy test" << endl;  }}#if 0template <class MatA, class MatB, class MatC, class MatC2>bool test_swap(std::string test, MatA& A, MatB& B, MatC& C, MatC2& C2){  bool passed = true;  fill_mats(A, B, C);  mtl::swap(A, C);  for (int i = 0; i < A.nrows(); ++i)    for (int j = 0; j < A.ncols(); ++j) {      Int first = MTL_MAX(0, int(i) - int(A.sub()));      Int last = MTL_MIN(int(A.ncols()), int(i) + int(A.super()) + 1);      if (j >= first && j < last)        C2(i,j) = A(i,j);    }  if (! mtl::matrix_equal(C, C2)) {    passed = false;    cout << test.c_str() << " failed swap" << endl;    cout << "result" << endl;    mtl::print_row(C);    cout << "correct" << endl;    mtl::print_all_matrix(C2);  }  return passed;}template <class MatA, class MatB, class MatC, class MatC2>bool test_ele_mult(std::string test, MatA& A, MatB& B, MatC& C, MatC2& C2){  bool passed = true;  int i, j;  fill_mats(A, B, C);  mtl::ele_mult(A, C);  mtl::set_value(C2, 0);  for (i = 0; i < C2.nrows(); ++++i)    for (j = 0; j < C2.ncols(); ++++j)      C2(i,j) = 1;  for (i = 0; i < A.nrows(); ++i)    for (j = 0; j < A.ncols(); ++j)      C2(i,j) *= A(i,j);  if (! mtl::matrix_equal(C, C2)) {    passed = false;    cout << test.c_str() << " failed" << endl;    cout << "result" << endl;    mtl::print_all_matrix(C);    cout << "correct" << endl;    mtl::print_all_matrix(C2);  }  return passed;}template <class MatA, class MatB, class MatC, class MatC2>bool test_transpose(std::string test, MatA& A, MatB& B, MatC& C, MatC2& C2){  bool passed = true;  // only transpose square matrices  if (A.nrows() == A.ncols() && C.nrows() == C.ncols()       && A.nrows() == A.ncols()) {    fill_mats(A, B, C);    mtl::transpose(A, C);    for (int i = 0; i < A.nrows(); ++i)      for (int j = 0; j < A.ncols(); ++j)        C2(j,i) = A(i,j);        if (! mtl::matrix_equal(C, C2)) {      passed = false;      cout << test.c_str() << " failed" << endl;      cout << "result" << endl;      mtl::print_all_matrix(C);      cout << "correct" << endl;      mtl::print_all_matrix(C2);    }  }  return passed;}#endif /* matmat algos */#endif /* MTL_ALGO_TEST_H */

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -