
adaboostlearner.cpp
Directory: WeakLearner (weak classifier)
Language: C++
//LAB_LicenseBegin==============================================================
//  Copyright (c) 2005-2006, Hicham GHORAYEB < ghorayeb@gmail.com >
//  All rights reserved.
//
//	This software is a Library for Adaptive Boosting. It provides a generic
//	framework for the study of boosting algorithms. The framework provides
//	the different boosting tasks: Learning, Validation, Test, Profiling and
//	Performance Analysis.
//
//	This Library was developed during my PhD studies at:
//	Ecole des Mines de Paris - Centre de Robotique( CAOR )
//	http://caor.ensmp.fr
//	under the supervision of Pr. Claude Laurgeau and Bruno Steux
//
//  Redistribution and use in source and binary forms, with or without
//  modification, are permitted provided that the following conditions are met:
//
//      * Redistributions of source code must retain the above copyright
//        notice, this list of conditions and the following disclaimer.
//      * Redistributions in binary form must reproduce the above copyright
//        notice, this list of conditions and the following disclaimer
//        in the documentation and/or other materials provided with the distribution.
//      * Neither the name of the Ecole des Mines de Paris nor the names of
//        its contributors may be used to endorse or promote products
//        derived from this software without specific prior written permission.
//
//  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
//  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, 
//  THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 
//  PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 
//  CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 
//  EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 
//  PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 
//  PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 
//  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING 
//  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 
//  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
//================================================================LAB_LicenseEnd
#include "AdaBoostLearner.h"
#include "IWeakLearner.h"
#include "ILearner.h"

#include "samples/Sample.h"
#include "samples/learningsets/ILearningSet.h"
#include "classifiers/features/WeakClassifier.h"
#include "classifiers/WeightedSumClassifier.h"

#include "LibAdaBoost/modules/options/AdaBoostLearnerOptions.h"
using namespace modules::options;

#include <iostream>
#include <sstream>
#include <fstream>
#include <cmath>
#include <cstdlib>
using namespace std;

using namespace samples;
using namespace classifiers;
using namespace classifiers::features;
using namespace learners;

AdaBoostLearner::AdaBoostLearner()
:m_WeakLearner( NULL )
,m_wsClassifier( NULL )
,m_LearningSet( NULL )
,m_err( 0.0 )
,m_NbrCycles( 0 )
{
	m_wsClassifier = new WeightedSumClassifier();
}

AdaBoostLearner::AdaBoostLearner(ILearningSet *learningSet)
:m_WeakLearner( NULL )
,m_wsClassifier( NULL )
,m_LearningSet( learningSet )
,m_err( 0.0 )
,m_NbrCycles( 0 )
{
	m_wsClassifier = new WeightedSumClassifier();
}

AdaBoostLearner::AdaBoostLearner(ILearningSet *learningSet, IWeakLearner *weakLearner)
:m_WeakLearner( weakLearner )
,m_wsClassifier( NULL )
,m_LearningSet( learningSet )
,m_err( 0.0 )
,m_NbrCycles( 0 )
{
	m_wsClassifier = new WeightedSumClassifier();
}
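
// Ownership note (editorial comment): the learner allocates and owns only the
// WeightedSumClassifier; the weak learner and the learning set are merely
// referenced. The destructor below deletes m_wsClassifier and just resets the
// other pointers, so their lifetime is assumed to be managed by the caller.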

AdaBoostLearner::~AdaBoostLearner()
{
	if(m_wsClassifier != NULL) delete m_wsClassifier;
	m_WeakLearner = NULL;
	m_wsClassifier = NULL;
	m_LearningSet = NULL;
}

void
AdaBoostLearner::Learn(void)
{
	// Create Output log files for the learner and the Weak Learner
	std::ofstream oLogLearner;
	std::string learner_log_filename  = "learner_log.txt";
	oLogLearner.open( learner_log_filename.c_str() );

	// Run m_NbrCycles boosting rounds
	for(int iCycle = 0; iCycle < m_NbrCycles; iCycle++)
	{
		std::ofstream oLogWeakLearner;
		ostringstream oStr;
		oStr << "wlearner_log_step_" << iCycle << ".txt";
		std::string wlearner_log_filename = oStr.str();
		oLogWeakLearner.open( wlearner_log_filename.c_str() );

		std::cerr << "AdaBoost Step[ "<< iCycle <<" ]" << std::endl;

		WeakClassifier *wc = NULL;
		double error = 0.0;
		double     w = 0.0;
		double     z = 0.0;
		double prevd = 0.0;
		double  newd = 0.0;

		wc = m_WeakLearner->GetNewWeakClassifier(error);
		m_err = error;

		oLogLearner     << error << std::endl;

		if(error > 0.5){
			std::cerr << "Error: the WeakLearner must return a weak classifier with error less than 0.5" << std::endl;
			return;
		}

		// Note: if error == 0, the weight w below diverges (log of +infinity) and
		// z becomes 0, so the weight update would divide by zero; guarding error
		// with a small epsilon would be needed to handle a perfect weak classifier.
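		// The rest of this boosting round implements the standard discrete AdaBoost
		// update (illustrative summary, not in the original source):
		//   w = 0.5 * ln((1 - error) / error)     voting weight of this weak classifier
		//   z = 2 * sqrt(error * (1 - error))     normalizer keeping the sample weights a distribution
		//   d_new = d_old * exp(-w) / z           for correctly classified samples
		//   d_new = d_old * exp(+w) / z           for misclassified samples
		// Worked example: error = 0.2 gives w = 0.5*ln(4) ≈ 0.693 and z = 0.8, so
		// correctly classified samples are scaled by exp(-0.693)/0.8 ≈ 0.625 and
		// misclassified ones by exp(0.693)/0.8 ≈ 2.5, concentrating weight on the
		// samples the current weak classifier gets wrong.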

		w = 0.5 * std::log((1.0 - error)/error);
		
		m_wsClassifier->PushBack( wc, w );

		z = 2.0 * std::sqrt( error * (1.0 - error));

		for(unsigned i = 0; i < m_LearningSet->NbrOfPositiveSamples(); i++)
		{
			Sample * sample = m_LearningSet->GetPositiveSampleAt( i );
			int vote = wc->Classify( sample );

			prevd = sample->GetWeight();
			if( vote == POSITIVE_CLASS_ID ){
				newd = prevd * std::exp( -w ) / z;
			}else{
				newd = prevd * std::exp(  w ) / z;
			}
			sample->SetWeight( newd );
		}

		for(unsigned i = 0; i < m_LearningSet->NbrOfNegativeSamples(); i++)
		{
			Sample * sample = m_LearningSet->GetNegativeSampleAt( i );
			int vote = wc->Classify( sample );

			prevd = sample->GetWeight();
			if( vote == NEGATIVE_CLASS_ID ){
				newd = prevd * std::exp( -w ) / z;
			}else{
				newd = prevd * std::exp(  w ) / z;
			}
			sample->SetWeight( newd );
		}

		oLogWeakLearner << m_WeakLearner->GetWeakLearnerLog();

		oLogWeakLearner.close();

		setStateChanged();
		notifySubscribers();
	}
	
	oLogLearner.close();
}

IWeakLearner *
AdaBoostLearner::GetWeakLearner(void)
{
	return m_WeakLearner;
}

void
AdaBoostLearner::SetWeakLearner(IWeakLearner *weakLearner)
{
	m_WeakLearner = weakLearner;
}

void
AdaBoostLearner::Reset(void)
{
	assert( m_LearningSet != NULL );
	assert( m_wsClassifier != NULL );
	assert( m_WeakLearner != NULL);

	m_LearningSet->UniformWeight();
	m_err = 0.0;
	m_wsClassifier->Clear();
}

IClassifier *
AdaBoostLearner::GetResult(void)
{
	return (IClassifier *)m_wsClassifier;
}

ILearningSet *
AdaBoostLearner::GetLearningSet(void)
{
	return m_LearningSet;
}

void
AdaBoostLearner::SetLearningSet(ILearningSet *learningSet)
{
	m_LearningSet = learningSet;
}

void 
AdaBoostLearner::SetOptions(const modules::options::ModuleOptions &options)
{
	const modules::options::AdaBoostLearnerOptions &opts = static_cast<const AdaBoostLearnerOptions &>(options);
	int nbrCycles;

	nbrCycles = opts.GetNbrBoostingCycles();

	// set the number of boosting cycles
	m_NbrCycles = nbrCycles;
}
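
//==============================================================================
// Illustrative usage sketch (editorial addition, not part of the original
// library). MyLearningSet and MyStumpLearner stand for hypothetical concrete
// implementations of ILearningSet and IWeakLearner, and SetNbrBoostingCycles()
// is an assumed setter matching the GetNbrBoostingCycles() accessor read in
// SetOptions() above; only AdaBoostLearner members defined in this file are
// relied upon.
//
//   ILearningSet *set    = new MyLearningSet("samples.dat");   // hypothetical
//   IWeakLearner *stumps = new MyStumpLearner(set);            // hypothetical
//
//   AdaBoostLearner learner(set, stumps);
//
//   AdaBoostLearnerOptions opts;
//   opts.SetNbrBoostingCycles(100);      // assumed setter
//   learner.SetOptions(opts);
//
//   learner.Reset();                     // uniform sample weights, empty strong classifier
//   learner.Learn();                     // runs the boosting rounds implemented above
//
//   IClassifier *strong = learner.GetResult();
//   // ... evaluate "strong", then delete set and stumps (the learner does not own them).
//==============================================================================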
