⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svmfutrain.cpp

📁 This is SvmFu, a package for training and testing support vector machines (SVMs). It's written in C
💻 CPP
字号:
// Copyright (C) 2000 Ryan M. Rifkin <rif@mit.edu>//  // This program is free software; you can redistribute it and/or// modify it under the terms of the GNU General Public License as// published by the Free Software Foundation; either version 2 of the// License, or (at your option) any later version.//  // This program is distributed in the hope that it will be useful, but// WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU// General Public License for more details.//  // You should have received a copy of the GNU General Public License// along with this program; if not, write to the Free Software// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA// 02111-1307, USA.//  #include "clientincludes.h"#include <sys/time.h>// #include <unistd.h>// Prototypesvoid out_of_store();void parseArgs(int argc, char **argv);// file scope varsdouble posC=1, negC=1;double tol=10E-4;double eps=10E-12;char inputFile[255] = "inputfile";char kernelFile[255];bool useKernelFileP = false;bool useOutputFileP = false;bool useAsciiP = false;bool usePosNegC = false;char outputFile[255];char svmFile[255];bool computeLOO;char looFile[255];int extraCacheRows = 0;int chunkSize = 500;void out_of_store() {  cerr << "op new failed: out of store.\n";  exit(1);}// long optionsstruct option optionNames[] ={  {"compute-loo", required_argument, 0, 'l'},  {"c",           required_argument, 0, 'C'},  {"neg-c",       required_argument, 0, 'N'},  {"input-file",  required_argument, 0, 'f'},  {"kernel-file", required_argument, 0, 'i'},  {"output-file", required_argument, 0, 's'},  {"chunk-size",  required_argument, 0, 'h'},  {"tolerance",   required_argument, 0, 't'},  {"epsilon",     required_argument, 0, 'e'},  {"rep",         required_argument, 0, 'r'},  {"degree",      required_argument, 0, 'd'},  {"offset",      required_argument, 0, 'b'},  {"sigma",       required_argument, 0, 'g'},  {"normalizer",  required_argument, 0, 
'n'},  {"kernel",      required_argument, 0, 'k'},  {"use-ascii",   no_argument, 0, 'a'},  {"elt-type",    required_argument, 0, 'D'},  {"kern-type",   required_argument, 0, 'K'},  {"cache-rows",  required_argument, 0, 'c'},  {"version",     no_argument, 0, 'v'},  {"help",        no_argument, 0, '?'},  {0, 0, 0, 0}};void usage(char *basename) {  cerr<<"usage: " << basename << " ..." << endl;  cerr<<"  -l, --compute-loo=filename   compute leave-one-out\n";  cerr<<"  -C, --c=value                value for 'C'\n";  cerr<<"  -N, --neg-c=value            value for negative 'C'\n";  cerr<<"  -f, --input-file=filename    training set input file\n";  cerr<<"  -s, --output-file=filename   where to save the svm data\n";  cerr<<"  -h, --chunk-size=value       \n";  cerr<<"  -t, --tolerance=value        \n";  cerr<<"  -e, --epsilon=value          \n";  cerr<<"  -d, --degree=value           degree (for polynomial kernels)\n";  cerr<<"  -b, --offset=value           offset/bias (for polynomial)\n";  cerr<<"  -g, --sigma=value            sigma (for gaussian kernels)\n";  cerr<<"  -n, --normalizer=value       normalizer (for all kernels)\n";  cerr<<"  -r, --rep=[01,N]             sparse01 or sparseN (default dense)\n";  cerr<<"  -k, --kernel=[polynomial,gaussian,linear]\n";  cerr<<"  -a, --use-ascii              use portable ascii for save file\n";  cerr<<"  -c, --cache-rows=value       cache rows\n";  cerr<<"  -D, --elt-type=type          type for data elements (see below)\n";  cerr<<"  -K, --kern-type=type         type for kernel values (see below)\n";  cerr<<"  -v, --version                print SvmFu version ("<<VERSION<<")\n";  cerr<<"  -?, --help                   this message\n";  cout << "Supported types:" << endl;#define IterateTypes(datatype,kerntype) \  cout << "  --elt-type " << #datatype \       << ", --kern-type " << #kerntype << endl;#include "SvmFuSvmTypes.h"  exit(0);}/*! Parses command line arguments into global variables. 
* \todo use getoptlong */void parseArgs(int argc, char **argv) {  char c;  int opti=0;  char *tmp_datatype=strdup("double");  char *tmp_kerntype=strdup("double");  while ((c = getopt_long(argc, argv,			  "l:C:N:f:i:s:t:e:d:b:g:n:k:r:h:aD:K:c:?",			  optionNames, &opti)) != -1) {	switch(c) {    case 'l':      computeLOO = true;      strcpy(looFile,optarg);      break;    case 'C':      posC = atof(optarg);      break;    case 'N':      negC = atof(optarg);      usePosNegC = true;      break;    case 'f':      strcpy(inputFile, optarg);      break;    case 'i':      strcpy(kernelFile, optarg);      useKernelFileP = true;      break;    case 's':      useOutputFileP = true;      strcpy(outputFile, optarg);      break;    case 'h':      chunkSize = atoi(optarg);      break;    case 't':      tol = atof(optarg);      break;    case 'e':      eps = atof(optarg);      break;    case 'r':      if (!strcmp(optarg, "01")) {		machType = sparse01;		break;      } else if (!strcmp(optarg, "N")) {		machType = sparseN;		break;      } else {	cout << "Argument to -r option must be '01' or 'N'.  Aborting." << endl;	exit(1);      }    case 'd':      degree = atoi(optarg);      break;    case 'b':      offset = atof(optarg);      break;    case 'g':      sigma = atof(optarg);      break;    case 'n':      normalizer = atof(optarg);      break;    case 'k':      if (!strcmp(optarg, "linear")) {		kernType = linear;      } else if (!strcmp(optarg, "polynomial")) {		kernType = polynomial;      } else if (!strcmp(optarg, "gaussian")) {		kernType = gaussian;      }  else {		cout << "Error: Unknown kernel function " << optarg << 		  ".  Aborting."  
<< endl;		exit(-1);      }      break;    case 'a':      useAsciiP = true;      break;    case 'D':      free(tmp_datatype);      tmp_datatype = strdup(optarg);      break;    case 'K':      free(tmp_kerntype);      tmp_kerntype = strdup(optarg);      break;    case 'c':      extraCacheRows = atoi(optarg);      break;    case 'v':      cout << "svmfu " << VERSION << endl;      exit(0);    case '?':    default:      usage(*argv);    }  }  // Iterate over the known types and figure out which one they wanted#define IterateTypes(datatype,kerntype) \  if((strcmp(tmp_datatype,#datatype)==0) \  && (strcmp(tmp_kerntype,#kerntype)==0)) { \    g_eltType = type_##datatype; \    g_kernType = type_##kerntype; \    goto gotType; \  }#include "SvmFuSvmTypes.h"  cout << "Error: dataElt type '" << tmp_datatype;  cout << "' and kernVal type '" << tmp_kerntype << "' are not" << endl;  cout << "supported by this compilation of SvmFu.  Supported types:" << endl;#define IterateTypes(datatype,kerntype) \  cout << "  dataElt " << #datatype << ", kernVal " << #kerntype << endl;#include "SvmFuSvmTypes.h"  exit(1);gotType:  cout << "inputFile: " << inputFile << endl;    if (usePosNegC) {    cout <<"\tposC: " << posC << ", negC: " << negC << endl;  } else {    cout << "\tC: " << posC << endl;  }  if (useOutputFileP) {    cout << "\toutputFile: " << outputFile << endl;  }  cout << "chunkSize: " << chunkSize << endl;    cout << "\ttol: " << tol << endl;  cout << "\teps: " << eps << endl;  cout << "Machine Type: ";  switch (machType) {  case dense:    cout << "Dense" << endl;    break;  case sparse01:     cout << "Sparse01" << endl;    break;  case sparseN:    cout << "SparseN" << endl;    break;  default:    cout << "Unknown machine type.  Aborting." 
<< endl;    exit(1);  }    cout << "Kernel Type: ";  switch (kernType) {  case linear:    cout << "Linear" << endl;    break;  case polynomial:    cout << "Polynomial, Degree=" << degree << ", Offset=" << offset << endl;    break;  case gaussian:    cout << "Gaussian, Sigma=" << sigma << endl;    break;  default:    cout << "Unknown Kernel Type, Aborting." << endl;    exit(1);  }  cout << "Kernel Value Type: " << tmp_kerntype << endl;  cout << "Data Element Value Type: " << tmp_datatype << endl;  cout << "Kernel Normalization Term: " << normalizer << endl;  free(tmp_kerntype);  free(tmp_datatype);}//! Templatized wrapper for main() functiontemplate<class DataElt, class KernVal> class templateMain{ public:  static int main () {    // local vars    int *y;        FileIO<DataElt, KernVal> fileIO;    DataSet<DataElt> data = fileIO.readDataSet(inputFile);        KernelFuncs<KernVal, DataElt>      kernFunctions(kernType, machType);            // default is to use    if (chunkSize > data.size) { chunkSize = data.size; }    if (extraCacheRows == -1)      extraCacheRows = data.size - chunkSize;    cout << "Extra cache rows: " << extraCacheRows << endl;        SvmLargeOpt<DataPoint<DataElt>, KernVal> svm(data.size,   // number of points						 data.y,      // the y values						 data.points, // the points						 extraCacheRows,						 kernFunctions.pfunc,						 chunkSize, posC,						 tol, eps, 5);        if (kernType == linear) {      svm.useLinearKernel(data.dim, kernFunctions.afunc, kernFunctions.mfunc);    }    if (usePosNegC) {      for (int i = 0; i < data.size; i++) {	if (data.y[i] == -1) {	  svm.setC(i, negC);	}      }    }        svm.optimize();        cout << "Optimization Complete, Obj=" << svm.dualObjFunc() << endl;    cout << "# SV's == " << svm.getNumSupVecs() << ", #USV's == "	 << svm.getNumUnbndSupVecs() << endl;    if (computeLOO) {      double *LOOVals = new double[data.size];      double *funcVals = new double[data.size];      int i;      for (i = 0; i < 
data.size; i++) {	funcVals[i] = svm.outputAtTrainingExample(i);      }      ofstream os(looFile);      if (!os) {  	cerr << "Error: Cannot open " << looFile << " for writing." << endl;  	exit(1);      }      cout << "LOO errors: "	   << svm.computeLeaveOneOutErrors(LOOVals) << endl;            for (int i = 0; i < data.size; i++) {	int y = data.y[i];  	os << i << ", y=" << y << ", a=" << svm.getAlpha(i) << ", f(x)=" <<   	  funcVals[i] << ", f_L(x)=" << LOOVals[i];  	if (y*LOOVals[i] < 0) { os << "  LOO-ERROR"; }  	os << endl;      }            os.close();    }    // This can actually be quite SLOW in the current implementation.    // This is unfortunate.    // svm->computeTrainingSetPerf(true);            if (useOutputFileP) {      fileIO.saveSvm(&svm, outputFile, useAsciiP);    }    // cleanup    for (int i = 0; i < data.size; i++)      {	if(machType==sparse01 || machType==sparseN)	  delete data.points[i].index;	if(machType==dense || machType==sparseN)	  delete data.points[i].value;      }    delete data.points;    delete data.y;    return 0;   // success  }};/*! \todo attempt to autodetect data types from input file */int main(int argc, char **argv){  set_new_handler(&out_of_store);  parseArgs(argc, argv);#define IterateTypes(datatype,kerntype) \  if(g_eltType==type_##datatype && g_kernType==type_##kerntype) { \    templateMain<datatype,kerntype> m; \    return m.main(); \  }#include "SvmFuSvmTypes.h"}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -