📄 adaboostlearner.cpp
字号:
//LAB_LicenseBegin==============================================================
// Copyright (c) 2005-2006, Hicham GHORAYEB < ghorayeb@gmail.com >
// All rights reserved.
//
// This software is a Library for Adaptive Boosting. It provides a generic
// framework for the study of the Boosting algorithms. The framework provides
// the different tasks for boosting: Learning, Validation, Test, Profiling and
// Performance Analysis Tasks.
//
// This Library was developed during my PhD studies at:
// Ecole des Mines de Paris - Centre de Robotique( CAOR )
// http://caor.ensmp.fr
// under the supervision of Pr. Claude Laurgeau and Bruno Steux
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the distribution.
// * Neither the name of the Ecole des Mines de Paris nor the names of
// its contributors may be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
//================================================================LAB_LicenseEnd
#include "AdaBoostLearner.h"
#include "IWeakLearner.h"
#include "ILearner.h"
#include "samples/Sample.h"
#include "samples/learningsets/ILearningSet.h"
#include "classifiers/features/WeakClassifier.h"
#include "classifiers/WeightedSumClassifier.h"
#include "LibAdaBoost/modules/options/AdaBoostLearnerOptions.h"
using namespace modules::options;
#include <iostream>
#include <sstream>
#include <fstream>
#include <cmath>
#include <cstdlib>
using namespace std;
using namespace samples;
using namespace classifiers;
using namespace classifiers::features;
using namespace learners;
// Default constructor: no weak learner and no learning set are attached yet,
// the last error is 0.0 and the number of boosting cycles is 0.
// A fresh, empty WeightedSumClassifier is allocated and held by this learner
// (it is deleted in the destructor).
AdaBoostLearner::AdaBoostLearner()
    : m_WeakLearner( NULL ),
      m_wsClassifier( new WeightedSumClassifier() ),
      m_LearningSet( NULL ),
      m_err( 0.0 ),
      m_NbrCycles( 0 )
{
}
// Builds a learner bound to `learningSet`; the weak learner is left unset
// (NULL), the last error is 0.0 and the number of boosting cycles is 0.
// The pointer is stored as-is; this class never deletes the learning set.
// A fresh, empty WeightedSumClassifier is allocated and held by this learner.
AdaBoostLearner::AdaBoostLearner(ILearningSet *learningSet)
    : m_WeakLearner( NULL ),
      m_wsClassifier( new WeightedSumClassifier() ),
      m_LearningSet( learningSet ),
      m_err( 0.0 ),
      m_NbrCycles( 0 )
{
}
// Builds a learner bound to both a learning set and a weak learner.
// Both pointers are stored as-is; this class never deletes either of them.
// The last error starts at 0.0 and the number of boosting cycles at 0.
// A fresh, empty WeightedSumClassifier is allocated and held by this learner.
AdaBoostLearner::AdaBoostLearner(ILearningSet *learningSet, IWeakLearner *weakLearner)
    : m_WeakLearner( weakLearner ),
      m_wsClassifier( new WeightedSumClassifier() ),
      m_LearningSet( learningSet ),
      m_err( 0.0 ),
      m_NbrCycles( 0 )
{
}
// Releases the weighted-sum classifier allocated by the constructors.
// The weak learner and the learning set are only forgotten (set to NULL),
// never deleted here.
AdaBoostLearner::~AdaBoostLearner()
{
    // `delete` on a NULL pointer is a no-op, so no guard is required.
    delete m_wsClassifier;
    m_wsClassifier = NULL;

    m_WeakLearner = NULL;
    m_LearningSet = NULL;
}
void
AdaBoostLearner::Learn(void)
{
// Create Output log files for the learner and the Weak Learner
std::ofstream oLogLearner;
std::string learner_log_filename = "learner_log.txt";
oLogLearner.open( learner_log_filename.c_str() );
// Iterates M times: M is equal to m_NbrBoostingCycles
for(int iCycle = 0; iCycle < m_NbrCycles; iCycle++)
{
std::ofstream oLogWeakLearner;
ostringstream oStr;
oStr << "wlearner_log_step_"<< iCycle<< ".txt" <<std::ends;
std::string wlearner_log_filename = oStr.str();
oLogWeakLearner.open( wlearner_log_filename.c_str() );
std::cerr << "AdaBoost Step[ "<< iCycle <<" ]" << std::endl;
WeakClassifier *wc = NULL;
double error = 0.0;
double w = 0.0;
double z = 0.0;
double prevd = 0.0;
double newd = 0.0;
oStr.clear();
wc = m_WeakLearner->GetNewWeakClassifier(error);
m_err = error;
oLogLearner << error << std::endl;
if(error > 0.5){
std::cerr << "Error: WeakLearner Must return a classifier error with less than 0.5" << std::endl;
return;
}
// what happens if error == 0 ??
w = 0.5 * std::log((1.0 - error)/error);
m_wsClassifier->PushBack( wc, w );
z = 2.0 * std::sqrt( error * (1.0 - error));
for(unsigned i = 0; i < m_LearningSet->NbrOfPositiveSamples(); i++)
{
Sample * sample = m_LearningSet->GetPositiveSampleAt( i );
int vote = wc->Classify( sample );
prevd = sample->GetWeight();
if( vote == POSITIVE_CLASS_ID ){
newd = prevd * std::exp( -w ) / z;
}else{
newd = prevd * std::exp( w ) / z;
}
sample->SetWeight( newd );
}
for(unsigned i = 0; i < m_LearningSet->NbrOfNegativeSamples(); i++)
{
Sample * sample = m_LearningSet->GetNegativeSampleAt( i );
int vote = wc->Classify( sample );
prevd = sample->GetWeight();
if( vote == NEGATIVE_CLASS_ID ){
newd = prevd * std::exp( -w ) / z;
}else{
newd = prevd * std::exp( w ) / z;
}
sample->SetWeight( newd );
}
oLogWeakLearner << m_WeakLearner->GetWeakLearnerLog();
oLogWeakLearner.close();
setStateChanged();
notifySubscribers();
}
oLogLearner.close();
}
// Returns the weak learner currently attached to this learner
// (NULL when none has been set).
IWeakLearner * AdaBoostLearner::GetWeakLearner(void)
{
    return this->m_WeakLearner;
}
// Attaches `weakLearner` to this learner, replacing any previous one.
// Only the pointer is stored; the previously attached weak learner is not
// deleted.
void AdaBoostLearner::SetWeakLearner(IWeakLearner *weakLearner)
{
    this->m_WeakLearner = weakLearner;
}
void
AdaBoostLearner::Reset(void)
{
assert( m_LearningSet != NULL );
assert( m_wsClassifier != NULL );
assert( m_WeakLearner != NULL);
m_LearningSet->UniformWeight();
m_err = 0.0;
m_wsClassifier->Clear();
}
// Returns the weighted-sum classifier built so far, seen through the
// IClassifier interface. The object stays held by this learner: it is
// deleted by the destructor, so callers must not delete it.
//
// A named static_cast replaces the former C-style cast: the C-style form
// would silently fall back to a reinterpreting cast if the two types were
// ever unrelated, whereas static_cast turns that mistake into a compile error.
IClassifier *
AdaBoostLearner::GetResult(void)
{
    return static_cast<IClassifier *>( m_wsClassifier );
}
// Returns the learning set currently attached to this learner
// (NULL when none has been set).
ILearningSet * AdaBoostLearner::GetLearningSet(void)
{
    return this->m_LearningSet;
}
// Attaches `learningSet` to this learner, replacing any previous one.
// Only the pointer is stored; the previously attached learning set is not
// deleted.
void AdaBoostLearner::SetLearningSet(ILearningSet *learningSet)
{
    this->m_LearningSet = learningSet;
}
// Reads the number of boosting cycles from `options` and stores it in
// m_NbrCycles.
// `options` is downcast with static_cast, which performs no runtime check.
// NOTE(review): callers are presumably required to pass an
// AdaBoostLearnerOptions instance - confirm against the call sites.
void AdaBoostLearner::SetOptions(const modules::options::ModuleOptions &options)
{
    const modules::options::AdaBoostLearnerOptions &learnerOptions =
        static_cast<const AdaBoostLearnerOptions &>(options);

    // Go through an int first, then store it as the cycle count.
    const int requestedCycles = learnerOptions.GetNbrBoostingCycles();
    m_NbrCycles = requestedCycles;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -