📄 classifier.cpp
字号:
/*
* This file is part of MultiBoost, a multi-class 
* AdaBoost learner/classifier
*
* Copyright (C) 2005-2006 Norman Casagrande
* For informations write to nova77@gmail.com
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
*
*/

#include "IO/Serialization.h"
#include "IO/OutputInfo.h"
#include "Classifier.h"

#include "WeakLearners/SingleStumpLearner.h" // for saveSingleStumpFeatureData

#include <iomanip> // for setw

namespace MultiBoost {

// -------------------------------------------------------------------------

/**
* Holds the results per example.
* This class holds all the results obtained with computeResults(), 
* which is equivalent to \f${\bf g}(x)\f$. It also offers two methods
* that allow the quick evaluation of the results.
* @date 16/11/2005
*/
class ExampleResults {
public:

   /**
   * The constructor. Initialize the index and the votes vector.
   * @param idx The index of the example.
   * @param numClasses The number of classes.
   * @date 16/11/2005
   */
   ExampleResults(int idx, int numClasses)
      : idx(idx), votesVector(numClasses, 0) {}

   int idx; //!< The index of the example

   /**
   * The vector with the results. Equivalent to what returned by \f${\bf g}(x)\f$.
   * @remark It is public because the methods of Classifier will access it
   * directly.
   */
   vector<double> votesVector;

   /**
   * Get the winner.
   * @param rank The rank. 0 = winner. 1 = second, etc..
   * Must be in the range [0, number of classes).
   * @return A pair <\f$\ell\f$, \f$g_\ell(x)\f$>, where \f$\ell\f$ is the class index.
   * @remark An out-of-range rank is rejected by the checked element access
   * (std::out_of_range) instead of reading past the end of the ranking.
   * @date 16/11/2005
   */
   pair<int, double> getWinner(int rank = 0) const;

   /**
   * Checks if the given class is the winner class.
   * Example: if the ranking is 5 2 6 3 1 4 (in class indexes):
   * \code
   * isWinner(5,0); // -> true
   * isWinner(2,0); // -> false
   * isWinner(2,1); // -> true
   * isWinner(3,3); // -> true
   * \endcode
   * @param idxRealClass The index of the actual class.
   * @param atLeastRank The maximum rank in which the class must be
   * to be considered a "winner". Values beyond the last available rank
   * are treated as the last rank.
   * @date 16/11/2005
   */
   bool isWinner(int idxRealClass, int atLeastRank = 0) const;

private:

   /**
   * Create a sorted ranking list. It uses votesVector to build a vector
   * of pairs that contains the index of the class and the value of the votes
   * (that is a vector of <\f$\ell\f$, \f$g_\ell(x)\f$>), which is sorted
   * by the second element, resulting in a ranking of the votes per class.
   * @param rankedList the vector that will be filled with the rankings.
   * @date 16/11/2005
   */
   void getRankedList( vector< pair<int, double> >& rankedList ) const;

}; // ExampleResults

// -------------------------------------------------------------------------
// -------------------------------------------------------------------------

pair<int, double> ExampleResults::getWinner(int rank) const
{
   assert(rank >= 0);

   vector< pair<int, double> > rankedList;
   getRankedList(rankedList); // get the sorted rankings

   // Checked access: asserts vanish in release builds, and nothing else
   // guards against a rank that is >= the number of classes.
   return rankedList.at(rank);
}

// -------------------------------------------------------------------------

bool ExampleResults::isWinner(int idxRealClass, int atLeastRank) const
{
   assert(atLeastRank >= 0);

   vector< pair<int, double> > rankedList;
   getRankedList(rankedList); // get the sorted rankings

   // Never look past the last available rank: asking for "within the
   // top k" with k >= number of classes simply inspects every class.
   const int numRanked = static_cast<int>( rankedList.size() );
   int lastRank = atLeastRank;
   if ( lastRank >= numRanked )
      lastRank = numRanked - 1;

   for (int i = 0; i <= lastRank; ++i)
   {
      if ( rankedList[i].first == idxRealClass )
         return true;
   }

   return false;
}

// -------------------------------------------------------------------------

void ExampleResults::getRankedList( vector< pair<int, double> >& rankedList ) const
{
   rankedList.resize(votesVector.size());

   // pair every vote with the index of its class
   vector<double>::const_iterator vIt;
   const vector<double>::const_iterator votesVectorEnd = votesVector.end();
   int i;
   for (vIt = votesVector.begin(), i = 0; vIt != votesVectorEnd; ++vIt, ++i )
      rankedList[i] = make_pair(i, *vIt);

   // order the pairs on the vote value (the second element)
   sort( rankedList.begin(), rankedList.end(),
         nor_utils::comparePairOnSecond< int, double, greater<double> > );
}

// -------------------------------------------------------------------------
// -------------------------------------------------------------------------
// -------------------------------------------------------------------------
// -------------------------------------------------------------------------

/**
* The constructor. Stores the arguments and the verbosity level, and reads
* the name of the "outputinfo" file when that argument has been given.
* @param args The arguments provided by the user.
* @param verbose The level of verbosity.
*/
Classifier::Classifier(nor_utils::Args &args, int verbose)
   : _args(args), _verbose(verbose)
{
   // The file with the step-by-step information
   if ( args.hasArgument("outputinfo") )
      args.getValue("outputinfo", 0, _outputInfoFile);
}

// -------------------------------------------------------------------------

/**
* Classifies the given data with the given strong hypothesis.
* If the verbosity is greater than 0 an error summary (per class and
* overall) is printed on the standard output. If outResFileName is not
* empty, one line per example (label or index, a tab, and the predicted
* class name) is written to that file.
* @param dataFileName The file with the data to classify.
* @param shypFileName The file with the strong hypothesis.
* @param outResFileName The file where the predictions are written.
* Nothing is written if it is empty.
* @param numRanksEnclosed The number of ranks for which the error is
* reported. Values smaller than 1 are treated as 1.
*/
void Classifier::run(const string& dataFileName, const string& shypFileName,
                     const string& outResFileName, int numRanksEnclosed)
{
   InputData* pData = loadInputData(dataFileName, shypFileName);

   if (_verbose > 0)
      cout << "Loading strong hypothesis..." << flush;

   // The class that loads the weak hypotheses
   UnSerialization us;

   // Where to put the weak hypotheses
   // NOTE(review): these pointers are never deleted in this function.
   // Ownership is not visible from here -- confirm whether they leak.
   vector<BaseLearner*> weakHypotheses;

   // loads them
   us.loadHypotheses(shypFileName, weakHypotheses);

   // where the results go
   vector< ExampleResults* > results;

   if (_verbose > 0)
      cout << "Classifying..." << flush;

   // get the results
   computeResults( pData, weakHypotheses, results );

   const int numClasses = ClassMappings::getNumClasses();

   if (_verbose > 0)
   {
      // well.. if verbose = 0 no results are displayed! :)
      cout << "Done!" << endl;

      // The summary always reads rank 0, so at least one rank must be
      // computed even if the caller asked for zero (or fewer) ranks.
      const int numRanks = (numRanksEnclosed > 1) ? numRanksEnclosed : 1;

      vector< vector<double> > rankedError(numRanks);

      // Get the per-class error for the numRanks-th ranks
      for (int i = 0; i < numRanks; ++i)
         getClassError( pData, results, rankedError[i], i );

      // output it
      cout << endl;
      cout << "Error Summary" << endl;
      cout << "=============" << endl;

      for ( int l = 0; l < numClasses; ++l )
      {
         // first rank (winner): rankedError[0]
         cout << "Class '" << ClassMappings::getClassNameFromIdx(l) << "': "
              << setprecision(4) << rankedError[0][l] * 100 << "%";

         // output the others on its side
         if (numRanks > 1 && _verbose > 1)
         {
            cout << " (";
            for (int i = 1; i < numRanks; ++i)
               cout << " " << i+1 << ":[" << setprecision(4) << rankedError[i][l] * 100 << "%]";
            cout << " )";
         }

         cout << endl;
      }

      // the overall error
      cout << "\n--> Overall Error: "
           << setprecision(4) << getOverallError(pData, results, 0) * 100 << "%";

      // output the others on its side
      if (numRanks > 1 && _verbose > 1)
      {
         cout << " (";
         for (int i = 1; i < numRanks; ++i)
            cout << " " << i+1 << ":[" << setprecision(4) << getOverallError(pData, results, i) * 100 << "%]";
         cout << " )";
      }

      cout << endl;

   } // verbose

   // If asked output the results
   if ( !outResFileName.empty() )
   {
      ofstream outRes(outResFileName.c_str());

      if ( !outRes.is_open() )
      {
         // do not claim success if the file could not be created
         cerr << "ERROR: Cannot open output file <" << outResFileName << ">!" << endl;
      }
      else
      {
         const int numExamples = pData->getNumExamples();

         string exampleLabel;

         for (int i = 0; i < numExamples; ++i)
         {
            // output the label if it exists, otherwise the number
            // of the example
            exampleLabel = pData->getLabel(i);
            if ( exampleLabel.empty() )
               outRes << i << '\t';
            else
               outRes << exampleLabel << '\t';

            // output the predicted class
            // ('\n' instead of endl: no need to flush the file on every line)
            outRes << ClassMappings::getClassNameFromIdx( results[i]->getWinner().first ) << '\n';
         }

         if (_verbose > 0)
            cout << "\nPredictions written on file <" << outResFileName << ">!" << endl;
      }
   }

   // delete the input data file (deleting a null pointer is a no-op)
   delete pData;

   vector<ExampleResults*>::iterator it;
   for (it = results.begin(); it != results.end(); ++it)
      delete (*it);
}

// -------------------------------------------------------------------------

// NOTE(review): only the first part of this function is visible in this
// chunk of the file; it is reproduced below without any change of logic.
void Classifier::printConfusionMatrix(const string& dataFileName, const string& shypFileName)
{
   InputData* pData = loadInputData(dataFileName, shypFileName);

   if (_verbose > 0)
      cout << "Loading strong hypothesis..." << flush;

   // The class that loads the weak hypotheses
   UnSerialization us;

   // Where to put the weak hypotheses
   vector<BaseLearner*> weakHypotheses;

   // loads them
   us.loadHypotheses(shypFileName, weakHypotheses);

   // where the results go
   vector< ExampleResults* > results;

   if (_verbose > 0)
      cout << "Classifying..." << flush;

   // get the results
   computeResults( pData, weakHypotheses, results );

   const int numClasses = ClassMappings::getNumClasses();
   const int numExamples = pData->getNumExamples();

   if (_verbose > 0)
      cout << "Done!" << endl;

   const int colSize = 7;

   if (_verbose > 0)
   {
      cout << "Raw Confusion Matrix:\n";
      cout << setw(colSize) << "Truth ";

      for (int l = 0; l < numClasses; ++l)
         cout << setw(colSize) << nor_utils::getAlphanumeric(l);

      cout << "\nClassification\n";

      for (int l = 0; l < numClasses; ++l)
      {
         vector<int> winnerCount(numClasses, 0);
         for (int i = 0; i < numExamples; ++i)
         {
            if ( pData->getClass(i) == l )
               ++winnerCount[ results[i]->getWinner().first ];
         }

         // class
         cout << setw(colSize) << " " << nor_utils::getAlphanumeric(l);

         for (int j = 0; j < numClasses; ++j)
            cout << setw(colSize) << winnerCount[j];

         cout << endl;
      }

   }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -