⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 classifier.cpp

📁 MultiBoost 是用 C++ 实现的多类 AdaBoost 算法。与主要解决二类分类问题的传统 AdaBoost 算法不同，它直接支持多类分类。
💻 CPP
📖 第 1 页 / 共 2 页
字号:
/*
* This file is part of MultiBoost, a multi-class
* AdaBoost learner/classifier
*
* Copyright (C) 2005-2006 Norman Casagrande
* For informations write to nova77@gmail.com
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
*
*/

#include "IO/Serialization.h"
#include "IO/OutputInfo.h"

#include "Classifier.h"
#include "WeakLearners/SingleStumpLearner.h" // for saveSingleStumpFeatureData

#include <algorithm>  // for stable_sort
#include <functional> // for greater
#include <iomanip>    // for setw

namespace MultiBoost {

// -------------------------------------------------------------------------

/**
* Holds the results per example.
* This class holds all the results obtained with computeResults(),
* which is equivalent to \f${\bf g}(x)\f$. It also offers two methods
* that allow the quick evaluation of the results.
* @date 16/11/2005
*/
class ExampleResults {
public:

   /**
   * The constructor. Initialize the index and the votes vector
   * (one zero-initialized vote per class).
   * @param idx The index of the example.
   * @param numClasses The number of classes.
   * @date 16/11/2005
   */
   ExampleResults(int idx, int numClasses)
      : idx(idx), votesVector(numClasses, 0) {}

   int idx; //!< The index of the example

   /**
   * The vector with the results. Equivalent to what is returned by \f${\bf g}(x)\f$.
   * @remark It is public because the methods of Classifier will access it
   * directly.
   */
   vector<double> votesVector;

   /**
   * Get the class at the given rank.
   * Classes are ranked by decreasing vote; classes with equal votes keep
   * increasing class-index order, so ties are broken toward the lower index.
   * @param rank The rank. 0 = winner. 1 = second, etc..
   * @return A pair <\f$\ell\f$, \f$g_\ell(x)\f$>, where \f$\ell\f$ is the class index.
   * @throws std::out_of_range if rank is not smaller than the number of classes.
   * @date 16/11/2005
   */
   pair<int, double> getWinner(int rank = 0) const;

   /**
   * Checks if the given class is the winner class.
   * Example: if the ranking is 5 2 6 3 1 4 (in class indexes):
   * \code
   * isWinner(5,0); // -> true
   * isWinner(2,0); // -> false
   * isWinner(2,1); // -> true
   * isWinner(3,3); // -> true
   * \endcode
   * If atLeastRank is not smaller than the number of classes, every
   * class in the ranking counts as a "winner".
   * @param idxRealClass The index of the actual class.
   * @param atLeastRank The maximum rank in which the class must be
   * to be considered a "winner".
   * @date 16/11/2005
   */
   bool isWinner(int idxRealClass, int atLeastRank = 0) const;

private:

   /**
   * Create a sorted ranking list. It uses votesVector to build a vector
   * of pairs that contains the index of the class and the value of the votes
   * (that is a vector of <\f$\ell\f$, \f$g_\ell(x)\f$>), which is sorted
   * by the second element (decreasing), resulting in a ranking of the votes
   * per class. The sort is stable, so equal votes keep class-index order.
   * @param rankedList the vector that will be filled with the rankings.
   * @date 16/11/2005
   */
   void getRankedList( vector< pair<int, double> >& rankedList ) const;

}; // ExampleResults

// -------------------------------------------------------------------------
// -------------------------------------------------------------------------

pair<int, double> ExampleResults::getWinner(int rank) const
{
   assert(rank >= 0);

   vector< pair<int, double> > rankedList;
   getRankedList(rankedList); // get the sorted rankings

   // at() turns a rank past the last class into std::out_of_range
   // instead of an out-of-bounds read.
   return rankedList.at(rank);
}

// -------------------------------------------------------------------------

bool ExampleResults::isWinner(int idxRealClass, int atLeastRank) const
{
   assert(atLeastRank >= 0);

   vector< pair<int, double> > rankedList;
   getRankedList(rankedList); // get the sorted rankings

   // Never read past the last rank: when atLeastRank exceeds the number
   // of classes, all existing ranks are inspected.
   const int numRanks = static_cast<int>( rankedList.size() );
   for (int i = 0; i <= atLeastRank && i < numRanks; ++i)
   {
      if ( rankedList[i].first == idxRealClass )
         return true;
   }

   return false;
}

// -------------------------------------------------------------------------

void ExampleResults::getRankedList( vector< pair<int, double> >& rankedList ) const
{
   const int numClasses = static_cast<int>( votesVector.size() );
   rankedList.resize(numClasses);

   // pair each class index with its vote
   for (int l = 0; l < numClasses; ++l)
      rankedList[l] = make_pair(l, votesVector[l]);

   // stable_sort (rather than sort) keeps tied classes in increasing index
   // order, so the winner of a tie is deterministic across platforms.
   stable_sort( rankedList.begin(), rankedList.end(),
                nor_utils::comparePairOnSecond< int, double, greater<double> > );
}

// -------------------------------------------------------------------------
// -------------------------------------------------------------------------
// -------------------------------------------------------------------------
// -------------------------------------------------------------------------

/**
* Stores the arguments and the verbosity level. If the "outputinfo"
* argument is present, its first value is read into _outputInfoFile.
*/
Classifier::Classifier(nor_utils::Args &args, int verbose)
: _args(args), _verbose(verbose)
{
   // The file with the step-by-step information
   if ( args.hasArgument("outputinfo") )
      args.getValue("outputinfo", 0, _outputInfoFile);
}

// -------------------------------------------------------------------------
-------------------------------------------------------------------------void Classifier::run(const string& dataFileName, const string& shypFileName,                      const string& outResFileName, int numRanksEnclosed){   InputData* pData = loadInputData(dataFileName, shypFileName);   if (_verbose > 0)      cout << "Loading strong hypothesis..." << flush;   // The class that loads the weak hypotheses   UnSerialization us;   // Where to put the weak hypotheses   vector<BaseLearner*> weakHypotheses;   // loads them   us.loadHypotheses(shypFileName, weakHypotheses);   // where the results go   vector< ExampleResults* > results;   if (_verbose > 0)      cout << "Classifying..." << flush;   // get the results   computeResults( pData, weakHypotheses, results );   const int numClasses = ClassMappings::getNumClasses();   if (_verbose > 0)   {      // well.. if verbose = 0 no results are displayed! :)      cout << "Done!" << endl;      vector< vector<double> > rankedError(numRanksEnclosed);      // Get the per-class error for the numRanksEnclosed-th ranks      for (int i = 0; i < numRanksEnclosed; ++i)         getClassError( pData, results, rankedError[i], i );      // output it      cout << endl;      cout << "Error Summary" << endl;      cout << "=============" << endl;      for ( int l = 0; l < numClasses; ++l )      {         // first rank (winner): rankedError[0]         cout << "Class '" << ClassMappings::getClassNameFromIdx(l) << "': "              << setprecision(4) << rankedError[0][l] * 100 << "%";         // output the others on its side         if (numRanksEnclosed > 1 && _verbose > 1)         {            cout << " (";            for (int i = 1; i < numRanksEnclosed; ++i)               cout << " " << i+1 << ":[" << setprecision(4) << rankedError[i][l] * 100 << "%]";            cout << " )";         }         cout << endl;      }      // the overall error      cout << "\n--> Overall Error: "            << setprecision(4) << getOverallError(pData, results, 0) 
* 100 << "%";      // output the others on its side      if (numRanksEnclosed > 1 && _verbose > 1)      {         cout << " (";         for (int i = 1; i < numRanksEnclosed; ++i)            cout << " " << i+1 << ":[" << setprecision(4) << getOverallError(pData, results, i) * 100 << "%]";         cout << " )";      }      cout << endl;   } // verbose   // If asked output the results   if ( !outResFileName.empty() )   {
      const int numExamples = pData->getNumExamples();
      ofstream outRes(outResFileName.c_str());

      string exampleLabel;

      for (int i = 0; i < numExamples; ++i)
      {
         // output the label if it exists, otherwise the number
         // of the example
         exampleLabel = pData->getLabel(i);
         if ( exampleLabel.empty() )
            outRes << i << '\t';
         else
            outRes << exampleLabel << '\t';

         // output the predicted class
         outRes << ClassMappings::getClassNameFromIdx( results[i]->getWinner().first ) << endl;
      }

      if (_verbose > 0)
         cout << "\nPredictions written on file <" << outResFileName << ">!" << endl;

   }   // delete the input data file   if (pData)       delete pData;   vector<ExampleResults*>::iterator it;   for (it = results.begin(); it != results.end(); ++it)      delete (*it);}// -------------------------------------------------------------------------void Classifier::printConfusionMatrix(const string& dataFileName, const string& shypFileName){   InputData* pData = loadInputData(dataFileName, shypFileName);   if (_verbose > 0)      cout << "Loading strong hypothesis..." << flush;   // The class that loads the weak hypotheses   UnSerialization us;   // Where to put the weak hypotheses   vector<BaseLearner*> weakHypotheses;   // loads them   us.loadHypotheses(shypFileName, weakHypotheses);   // where the results go   vector< ExampleResults* > results;   if (_verbose > 0)      cout << "Classifying..." << flush;   // get the results   computeResults( pData, weakHypotheses, results );   const int numClasses = ClassMappings::getNumClasses();   const int numExamples = pData->getNumExamples();   if (_verbose > 0)      cout << "Done!" << endl;   const int colSize = 7;   if (_verbose > 0)   {      cout << "Raw Confusion Matrix:\n";      cout << setw(colSize) << "Truth       ";      for (int l = 0; l < numClasses; ++l)         cout << setw(colSize) << nor_utils::getAlphanumeric(l);      cout << "\nClassification\n";      for (int l = 0; l < numClasses; ++l)      {         vector<int> winnerCount(numClasses, 0);         for (int i = 0; i < numExamples; ++i)         {            if ( pData->getClass(i) == l )               ++winnerCount[ results[i]->getWinner().first ];         }         // class         cout << setw(colSize) << "           " << nor_utils::getAlphanumeric(l);         for (int j = 0; j < numClasses; ++j)            cout << setw(colSize) << winnerCount[j];         cout << endl;      }   }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -