
test_levenberg_marquardt.cxx

From the DTMK software development kit, an open-source package and a useful resource for medical image development.
// @author fsm
#include <vnl/vnl_double_2.h>
#include <vcl_cassert.h>   // for assert() in linear_est's constructor
#include <vcl_cmath.h>
#include <vcl_iostream.h>

#include <testlib/testlib_test.h>
#include <vnl/vnl_least_squares_function.h>
#include <vnl/algo/vnl_levenberg_marquardt.h>

// Residuals of the classic Rosenbrock test problem:
//   y[0] = 10*(x[1] - x[0]^2),  y[1] = 1 - x[0]
// The sum of squared residuals vanishes only at x = (1, 1).
struct vnl_rosenbrock : public vnl_least_squares_function
{
  vnl_rosenbrock(bool with_grad): vnl_least_squares_function(2, 2, with_grad ? use_gradient : no_gradient) {}

  void f(vnl_vector<double> const& x, vnl_vector<double>& y) {
    //testlib_test_assert("size of x", x.size() == 2);
    //testlib_test_assert("size of y", y.size() == 2);
    y[0] = 10*(x[1] - x[0]*x[0]);
    y[1] = 1 - x[0];
  }

  void gradf(vnl_vector<double> const& x, vnl_matrix<double> &J) {
    //testlib_test_assert("size of x", x.size() == 2);
    //testlib_test_assert("size of J", J.rows() == 2 && J.cols() == 2);
    J[0][0] = -20 * x[0]; J[0][1] = 10;
    J[1][0] = -1;         J[1][1] = 0;
  }
};

// Linear least-squares residual f(x) = A*x - b with constant Jacobian A,
// used to check both the solution and the J^T J approximation below.
struct linear_est : public vnl_least_squares_function
{
  linear_est(vnl_matrix<double> const& A, vnl_vector<double> const& b, bool with_grad)
  : vnl_least_squares_function(A.cols(), A.rows(), with_grad ? use_gradient : no_gradient),
    A_(A), b_(b)
  { assert(A.rows() == b.size()); }

  void f(vnl_vector<double> const& x, vnl_vector<double>& y) {
    y = A_*x - b_;
  }

  void gradf(vnl_vector<double> const& /*x*/, vnl_matrix<double> &J) {
    J=A_;
  }

  vnl_matrix<double> A_;
  vnl_vector<double> b_;
};

// Minimize the Rosenbrock residuals from a poor starting point, with or
// without the analytic Jacobian, and verify convergence to (1, 1).
static
void do_rosenbrock_test(bool with_grad)
{
  vnl_rosenbrock f(with_grad);

  vnl_double_2 x0(2.7,-1.3);
  vcl_cout << "x0 = " << x0 << vcl_endl;

  vnl_levenberg_marquardt lm(f);

  vnl_vector<double> x1 = x0.as_vector();
  if (f.has_gradient())
    lm.minimize_using_gradient(x1);
  else
    lm.minimize_without_gradient(x1);
  lm.diagnose_outcome(vcl_cout);
  vcl_cout << "x1 = " << x1 << vcl_endl;

  double err = vcl_abs(x1[0] - 1) + vcl_abs(x1[1] - 1);
  vcl_cout << "err = " << err << vcl_endl;
  testlib_test_assert("converged to (1, 1)", err <= 1e-10);
}

// Fit a two-parameter linear model to six observations and compare the
// result with the closed-form least-squares solution and its J^T J matrix.
static
void do_linear_test( bool with_grad )
{
  vnl_matrix<double> A(6,2,1.0);
  vnl_vector<double> b(6);

  A(0,1) = 10;
  A(1,1) = 15;
  A(2,1) = 5.1;
  A(3,1) = 20.2;
  A(4,1) = -0.3;
  A(5,1) = 25;

  b(0) = 10;
  b(1) = 15.5;
  b(2) = 4.5;
  b(3) = 21;
  b(4) = 1;
  b(5) = 24.3;

  linear_est f(A, b, with_grad);
  vnl_levenberg_marquardt lm(f);
  vnl_vector<double> x(2, -1000.0);  // the initial estimate can be far off:
  // since the objective function is linear, high precision can be achieved
  lm.set_x_tolerance(1e-12);
  lm.set_f_tolerance(1e-12);
  lm.set_g_tolerance(1e-12);

  if (f.has_gradient())
    lm.minimize_using_gradient(x);
  else
    lm.minimize_without_gradient(x);
  lm.diagnose_outcome(vcl_cout);
  vcl_cout << "x = " << x << vcl_endl;

  vnl_vector<double> true_x(2);
  true_x[1]=0.969684757298943;
  true_x[0]=0.595607200429874;

  TEST_NEAR( "converged to true estimate", (true_x-x).two_norm(), 0, 1e-6 );

  // now check (inverse of) covariance approximation
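  // For the linear residual f(x) = A*x - b the Jacobian is simply J = A, so
  // the expected value of lm.get_JtJ() is A^T A.  With the data above:
  //   sum of ones (column 1)   = 6
  //   sum of column 2          = 75
  //   sum of squares, column 2 = 1384.14
  // which are exactly the entries of true_cov below.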
  vnl_matrix<double> true_cov(2,2);
  true_cov(0,0)=6;
  true_cov(1,0)=75;
  true_cov(0,1)=75;
  true_cov(1,1)=1384.14;

  vnl_matrix<double> covar = lm.get_JtJ();
  vcl_cout << "Cov(x) =\n" << covar << vcl_endl;
  TEST_NEAR( "covariance approximation", (true_cov-covar).array_two_norm(), 0, 1e-5 );
}

static
void test_levenberg_marquardt()
{
  do_rosenbrock_test(true);
  do_rosenbrock_test(false);

  do_linear_test(true);
  do_linear_test(false);
}

TESTMAIN(test_levenberg_marquardt);
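
For readers who want to try the optimizer outside the testlib harness, the sketch below drives vnl_levenberg_marquardt on the same Rosenbrock residuals defined above. It is an illustrative example only: the file name and main() wrapper are hypothetical and not part of this test, and it assumes the same VXL headers used in the listing.

// standalone_rosenbrock.cxx -- hypothetical usage sketch, not part of the test suite
#include <vcl_iostream.h>
#include <vnl/vnl_vector.h>
#include <vnl/vnl_least_squares_function.h>
#include <vnl/algo/vnl_levenberg_marquardt.h>

// Same residual definition as vnl_rosenbrock above, gradient-free variant.
struct rosenbrock_residual : public vnl_least_squares_function
{
  rosenbrock_residual() : vnl_least_squares_function(2, 2, no_gradient) {}
  void f(vnl_vector<double> const& x, vnl_vector<double>& y) {
    y[0] = 10*(x[1] - x[0]*x[0]);
    y[1] = 1 - x[0];
  }
};

int main()
{
  rosenbrock_residual fn;
  vnl_levenberg_marquardt lm(fn);

  vnl_vector<double> x(2);
  x[0] = 2.7; x[1] = -1.3;          // same starting point as the test above

  lm.minimize_without_gradient(x);  // should converge to (1, 1)
  lm.diagnose_outcome(vcl_cout);
  vcl_cout << "x = " << x << vcl_endl;
  return 0;
}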
