⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 itksupervisedimageclassifiertest.cxx

📁 DTMK软件开发包,此为开源软件,是一款很好的医学图像开发资源.
💻 CXX
字号:
/*=========================================================================

  Program:   Insight Segmentation & Registration Toolkit
  Module:    $RCSfile: itkSupervisedImageClassifierTest.cxx,v $
  Language:  C++
  Date:      $Date: 2007-08-20 12:47:12 $
  Version:   $Revision: 1.12 $

  Copyright (c) Insight Software Consortium. All rights reserved.
  See ITKCopyright.txt or http://www.itk.org/HTML/Copyright.htm for details.

     This software is distributed WITHOUT ANY WARRANTY; without even 
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR 
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#if defined(_MSC_VER)
#pragma warning ( disable : 4786 )
#pragma warning ( disable : 4288 )
#endif
// Insight classes
#include "itkImage.h"
#include "itkVector.h"
#include "vnl/vnl_matrix_fixed.h"
#include "itkImageRegionIterator.h"
#include "itkLightProcessObject.h"

#include "itkImageGaussianModelEstimator.h"
#include "itkMahalanobisDistanceMembershipFunction.h"
#include "itkMinimumDecisionRule.h"
#include "itkImageClassifierBase.h"


//Data definitons 
#define   IMGWIDTH            2
#define   IMGHEIGHT           2
#define   NFRAMES             4
#define   NUMBANDS            2  
#define   NDIMENSION          3
#define   NUM_CLASSES         3
#define   MAX_NUM_ITER       50


// class to support progress feeback
class ShowProgressObject
{
public:
  ShowProgressObject(itk::LightProcessObject * o)
    {m_Process = o;}
  void ShowProgress()
    {std::cout << "Progress " << m_Process->GetProgress() << std::endl;}
  itk::LightProcessObject::Pointer m_Process;
};



int itkSupervisedImageClassifierTest(int, char* [] )
{

  //------------------------------------------------------
  //Create a simple test image with width, height, and 
  //depth 4 vectors each with each vector having data for 
  //2 bands.
  //------------------------------------------------------
  typedef itk::Image<itk::Vector<double,NUMBANDS>,NDIMENSION> VecImageType; 

  VecImageType::Pointer vecImage = VecImageType::New();

  VecImageType::SizeType vecImgSize = {{ IMGWIDTH , IMGHEIGHT, NFRAMES }};

  VecImageType::IndexType index;
  index.Fill(0);
  VecImageType::RegionType region;

  region.SetSize( vecImgSize );
  region.SetIndex( index );

  vecImage->SetLargestPossibleRegion( region );
  vecImage->SetBufferedRegion( region );
  vecImage->Allocate();

  // setup the iterators
  typedef VecImageType::PixelType VecImagePixelType;

  enum { VecImageDimension = VecImageType::ImageDimension };
  typedef
    itk::ImageRegionIterator< VecImageType > VecIterator;

  VecIterator outIt( vecImage, vecImage->GetBufferedRegion() );

  //--------------------------------------------------------------------------
  //Manually create and store each vector
  //--------------------------------------------------------------------------

  //Slice 1
  //Vector no. 1
  VecImagePixelType vec;
  vec.Fill(21); outIt.Set( vec ); ++outIt;
  //Vector no. 2
  vec.Fill(20); outIt.Set( vec ); ++outIt;
  //Vector no. 3
  vec.Fill(8); outIt.Set( vec ); ++outIt;
  //Vector no. 4
  vec.Fill(10); outIt.Set( vec ); ++outIt;
  //Slice 2
  //Vector no. 1
  vec.Fill(22); outIt.Set( vec ); ++outIt;
  //Vector no. 2
  vec.Fill(21); outIt.Set( vec ); ++outIt;
  //Vector no. 3
  vec.Fill(11); outIt.Set( vec ); ++outIt;
  //Vector no. 4
  vec.Fill(9); outIt.Set( vec ); ++outIt;
  
  //Slice 3
  //Vector no. 1 
  vec.Fill(19); outIt.Set( vec ); ++outIt;
  //Vector no. 2
  vec.Fill(19); outIt.Set( vec ); ++outIt;
  //Vector no. 3
  vec.Fill(11); outIt.Set( vec ); ++outIt;
  //Vector no. 4
  vec.Fill(11); outIt.Set( vec ); ++outIt;
  
  //Slice 4
  //Vector no. 1
  vec.Fill(18); outIt.Set( vec ); ++outIt;
  //Vector no. 2
  vec.Fill(18); outIt.Set( vec ); ++outIt;
  //Vector no. 3
  vec.Fill(12); outIt.Set( vec ); ++outIt;
  //Vector no. 4
  vec.Fill(14); outIt.Set( vec ); ++outIt;

  //---------------------------------------------------------------
  //Generate the training data
  //---------------------------------------------------------------
  typedef itk::Image<unsigned short,NDIMENSION> ClassImageType; 
  ClassImageType::Pointer classImage  = ClassImageType::New();

  ClassImageType::SizeType classImgSize = {{ IMGWIDTH , IMGHEIGHT, NFRAMES }};

  ClassImageType::IndexType classindex;
  classindex.Fill(0);

  ClassImageType::RegionType classregion;

  classregion.SetSize( classImgSize );
  classregion.SetIndex( classindex );

  classImage->SetLargestPossibleRegion( classregion );
  classImage->SetBufferedRegion( classregion );
  classImage->Allocate();

  // setup the iterators
  typedef ClassImageType::PixelType ClassImagePixelType;

  typedef
    itk::ImageRegionIterator<ClassImageType> ClassImageIterator;

  ClassImageIterator 
    classoutIt( classImage, classImage->GetBufferedRegion() );



  ClassImagePixelType outputPixel;
  //--------------------------------------------------------------------------
  //Manually create and store each vector
  //--------------------------------------------------------------------------
  //Slice 1
  //Pixel no. 1
  outputPixel = 2;
  classoutIt.Set( outputPixel );
  ++classoutIt;

  //Pixel no. 2 
  outputPixel = 2;
  classoutIt.Set( outputPixel );
  ++classoutIt;

  //Pixel no. 3
  outputPixel = 1;
  classoutIt.Set( outputPixel );
  ++classoutIt;

  //Pixel no. 4
  outputPixel = 1;
  classoutIt.Set( outputPixel );
  ++classoutIt;

  //Slice 2
  //Pixel no. 1
  outputPixel = 0;
  classoutIt.Set( outputPixel );
  ++classoutIt;
  
  //Pixel no. 2
  outputPixel = 0;
  classoutIt.Set( outputPixel );
  ++classoutIt;

  //Pixel no. 3
  outputPixel = 0;
  classoutIt.Set( outputPixel );
  ++classoutIt;

  //Pixel no. 4
  outputPixel = 0;
  classoutIt.Set( outputPixel );
  ++classoutIt;

  //Slice 3
  //Pixel no. 1 
  outputPixel = 2;
  classoutIt.Set( outputPixel );
  ++classoutIt;

  //Pixel no. 2
  outputPixel = 2;
  classoutIt.Set( outputPixel );
  ++classoutIt;
  
  //Pixel no. 3
  outputPixel = 1;
  classoutIt.Set( outputPixel );
  ++classoutIt;

  //Pixel no. 4
  outputPixel = 1;
  classoutIt.Set( outputPixel );
  ++classoutIt;

  //Slice 4
  //Pixel no. 1
  outputPixel = 0;
  classoutIt.Set( outputPixel );
  ++classoutIt;
  
  //Pixel no. 2
  outputPixel = 0;
  classoutIt.Set( outputPixel );
  ++classoutIt;

  //Pixel no. 3
  outputPixel = 0;
  classoutIt.Set( outputPixel );
  ++classoutIt;

  //Pixel no. 4
  outputPixel = 0;
  classoutIt.Set( outputPixel );
  ++classoutIt;

  //----------------------------------------------------------------------
  //Set membership function (Using the statistics objects)
  //----------------------------------------------------------------------
  namespace stat = itk::Statistics;

  typedef stat::MahalanobisDistanceMembershipFunction< VecImagePixelType > 
    MembershipFunctionType ;
  typedef MembershipFunctionType::Pointer MembershipFunctionPointer ;

  typedef std::vector< MembershipFunctionPointer > 
    MembershipFunctionPointerVector;

  //----------------------------------------------------------------------
  //Set the image model estimator
  //----------------------------------------------------------------------
  typedef itk::ImageGaussianModelEstimator<VecImageType,
    MembershipFunctionType, ClassImageType> 
    ImageGaussianModelEstimatorType;
  
  ImageGaussianModelEstimatorType::Pointer 
    applyEstimateModel = ImageGaussianModelEstimatorType::New();  

  applyEstimateModel->SetNumberOfModels(NUM_CLASSES);
  applyEstimateModel->SetInputImage(vecImage);
  applyEstimateModel->SetTrainingImage(classImage);  

  //Run the gaussian classifier algorithm
  applyEstimateModel->Update();
  applyEstimateModel->Print(std::cout); 

  MembershipFunctionPointerVector membershipFunctions = 
    applyEstimateModel->GetMembershipFunctions();  

  for(unsigned int idx=0; idx < membershipFunctions.size(); idx++ )
    {
    std::cout << "Number of samples for class " << idx << " is " <<
      membershipFunctions[ idx ]->GetNumberOfSamples() << std::endl;
    }

  //----------------------------------------------------------------------
  //Set the decision rule 
  //----------------------------------------------------------------------  
  typedef itk::DecisionRuleBase::Pointer DecisionRuleBasePointer;

  typedef itk::MinimumDecisionRule DecisionRuleType;
  DecisionRuleType::Pointer  
    myDecisionRule = DecisionRuleType::New();

  //----------------------------------------------------------------------
  // Test code for the supervised classifier algorithm
  //----------------------------------------------------------------------

  //---------------------------------------------------------------------
  // Multiband data is now available in the right format
  //---------------------------------------------------------------------
  typedef VecImagePixelType MeasurementVectorType;

  typedef itk::ImageClassifierBase< VecImageType,
    ClassImageType > SupervisedClassifierType;

  SupervisedClassifierType::Pointer 
    applyClassifier = SupervisedClassifierType::New();
 
  typedef ShowProgressObject 
    ProgressType;

  ProgressType progressWatch(applyClassifier);
  itk::SimpleMemberCommand<ProgressType>::Pointer command;
  command = itk::SimpleMemberCommand<ProgressType>::New();
  command->SetCallbackFunction(&progressWatch,
                               &ProgressType::ShowProgress);
  applyClassifier->AddObserver(itk::ProgressEvent(), command);

  // Set the Classifier parameters
  applyClassifier->SetNumberOfClasses(NUM_CLASSES);
  applyClassifier->SetInputImage(vecImage);

  // Set the decison rule 
  applyClassifier->
    SetDecisionRule((DecisionRuleBasePointer) myDecisionRule );

  //Add the membership functions
  for( unsigned int i=0; i<NUM_CLASSES; i++ )
    {
    applyClassifier->AddMembershipFunction( membershipFunctions[i] );
    }

  //Run the gaussian classifier algorithm
  applyClassifier->Update();

  //Get the classified image
  ClassImageType::Pointer 
    outClassImage = applyClassifier->GetClassifiedImage();

  applyClassifier->Print(std::cout); 

  //Print the gaussian classified image
  ClassImageIterator labeloutIt( outClassImage, 
                                 outClassImage->GetBufferedRegion() );

  int i=0;
  while(!labeloutIt.IsAtEnd())
    {
    //Print the classified index
    int classIndex = (int) labeloutIt.Get();
    std::cout << " Pixel No " << i << " Value " << classIndex << std::endl;
    ++i;
    ++labeloutIt;
    }//end while

  //Verify if the results were as per expectation
  labeloutIt.GoToBegin();
  bool passTest = true;

  //Loop through the data set
  while(!labeloutIt.IsAtEnd())
    {
    int classIndex = (int) labeloutIt.Get();
    if (classIndex != 2)
      {
      passTest = false;
      break;
      }
    ++labeloutIt;

    classIndex = (int) labeloutIt.Get();
    if (classIndex != 2)
      {
      passTest = false;
      break;
      }
    ++labeloutIt;

    classIndex = (int) labeloutIt.Get();
    if (classIndex != 1)
      {
      passTest = false;
      break;
      }
    ++labeloutIt;

    classIndex = (int) labeloutIt.Get();
    if (classIndex != 1)
      {
      passTest = false;
      break;
      }
    ++labeloutIt;

    }//end while

  if( passTest == true )
    {
    std::cout<< "Supervised Classifier Test Passed" << std::endl;
    }
  else
    {
    std::cout<< "Supervised Classifier Test failed" << std::endl;
    return EXIT_FAILURE;
    }

  return EXIT_SUCCESS;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -