📄 svmfutrain.cpp
字号:
// Copyright (C) 2000 Ryan M. Rifkin <rif@mit.edu>// // This program is free software; you can redistribute it and/or// modify it under the terms of the GNU General Public License as// published by the Free Software Foundation; either version 2 of the// License, or (at your option) any later version.// // This program is distributed in the hope that it will be useful, but// WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU// General Public License for more details.// // You should have received a copy of the GNU General Public License// along with this program; if not, write to the Free Software// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA// 02111-1307, USA.// #include "clientincludes.h"#include <sys/time.h>// #include <unistd.h>// Prototypesvoid out_of_store();void parseArgs(int argc, char **argv);// file scope varsdouble posC=1, negC=1;double tol=10E-4;double eps=10E-12;char inputFile[255] = "inputfile";char kernelFile[255];bool useKernelFileP = false;bool useOutputFileP = false;bool useAsciiP = false;bool usePosNegC = false;char outputFile[255];char svmFile[255];bool computeLOO;char looFile[255];int extraCacheRows = 0;int chunkSize = 500;void out_of_store() { cerr << "op new failed: out of store.\n"; exit(1);}// long optionsstruct option optionNames[] ={ {"compute-loo", required_argument, 0, 'l'}, {"c", required_argument, 0, 'C'}, {"neg-c", required_argument, 0, 'N'}, {"input-file", required_argument, 0, 'f'}, {"kernel-file", required_argument, 0, 'i'}, {"output-file", required_argument, 0, 's'}, {"chunk-size", required_argument, 0, 'h'}, {"tolerance", required_argument, 0, 't'}, {"epsilon", required_argument, 0, 'e'}, {"rep", required_argument, 0, 'r'}, {"degree", required_argument, 0, 'd'}, {"offset", required_argument, 0, 'b'}, {"sigma", required_argument, 0, 'g'}, {"normalizer", required_argument, 0, 'n'}, {"kernel", required_argument, 0, 'k'}, {"use-ascii", no_argument, 0, 'a'}, {"elt-type", required_argument, 0, 'D'}, {"kern-type", required_argument, 0, 'K'}, {"cache-rows", required_argument, 0, 'c'}, {"version", no_argument, 0, 'v'}, {"help", no_argument, 0, '?'}, {0, 0, 0, 0}};void usage(char *basename) { cerr<<"usage: " << basename << " ..." << endl; cerr<<" -l, --compute-loo=filename compute leave-one-out\n"; cerr<<" -C, --c=value value for 'C'\n"; cerr<<" -N, --neg-c=value value for negative 'C'\n"; cerr<<" -f, --input-file=filename training set input file\n"; cerr<<" -s, --output-file=filename where to save the svm data\n"; cerr<<" -h, --chunk-size=value \n"; cerr<<" -t, --tolerance=value \n"; cerr<<" -e, --epsilon=value \n"; cerr<<" -d, --degree=value degree (for polynomial kernels)\n"; cerr<<" -b, --offset=value offset/bias (for polynomial)\n"; cerr<<" -g, --sigma=value sigma (for gaussian kernels)\n"; cerr<<" -n, --normalizer=value normalizer (for all kernels)\n"; cerr<<" -r, --rep=[01,N] sparse01 or sparseN (default dense)\n"; cerr<<" -k, --kernel=[polynomial,gaussian,linear]\n"; cerr<<" -a, --use-ascii use portable ascii for save file\n"; cerr<<" -c, --cache-rows=value cache rows\n"; cerr<<" -D, --elt-type=type type for data elements (see below)\n"; cerr<<" -K, --kern-type=type type for kernel values (see below)\n"; cerr<<" -v, --version print SvmFu version ("<<VERSION<<")\n"; cerr<<" -?, --help this message\n"; cout << "Supported types:" << endl;#define IterateTypes(datatype,kerntype) \ cout << " --elt-type " << #datatype \ << ", --kern-type " << #kerntype << endl;#include "SvmFuSvmTypes.h" exit(0);}/*! Parses command line arguments into global variables. * \todo use getoptlong */void parseArgs(int argc, char **argv) { char c; int opti=0; char *tmp_datatype=strdup("double"); char *tmp_kerntype=strdup("double"); while ((c = getopt_long(argc, argv, "l:C:N:f:i:s:t:e:d:b:g:n:k:r:h:aD:K:c:?", optionNames, &opti)) != -1) { switch(c) { case 'l': computeLOO = true; strcpy(looFile,optarg); break; case 'C': posC = atof(optarg); break; case 'N': negC = atof(optarg); usePosNegC = true; break; case 'f': strcpy(inputFile, optarg); break; case 'i': strcpy(kernelFile, optarg); useKernelFileP = true; break; case 's': useOutputFileP = true; strcpy(outputFile, optarg); break; case 'h': chunkSize = atoi(optarg); break; case 't': tol = atof(optarg); break; case 'e': eps = atof(optarg); break; case 'r': if (!strcmp(optarg, "01")) { machType = sparse01; break; } else if (!strcmp(optarg, "N")) { machType = sparseN; break; } else { cout << "Argument to -r option must be '01' or 'N'. Aborting." << endl; exit(1); } case 'd': degree = atoi(optarg); break; case 'b': offset = atof(optarg); break; case 'g': sigma = atof(optarg); break; case 'n': normalizer = atof(optarg); break; case 'k': if (!strcmp(optarg, "linear")) { kernType = linear; } else if (!strcmp(optarg, "polynomial")) { kernType = polynomial; } else if (!strcmp(optarg, "gaussian")) { kernType = gaussian; } else { cout << "Error: Unknown kernel function " << optarg << ". Aborting." << endl; exit(-1); } break; case 'a': useAsciiP = true; break; case 'D': free(tmp_datatype); tmp_datatype = strdup(optarg); break; case 'K': free(tmp_kerntype); tmp_kerntype = strdup(optarg); break; case 'c': extraCacheRows = atoi(optarg); break; case 'v': cout << "svmfu " << VERSION << endl; exit(0); case '?': default: usage(*argv); } } // Iterate over the known types and figure out which one they wanted#define IterateTypes(datatype,kerntype) \ if((strcmp(tmp_datatype,#datatype)==0) \ && (strcmp(tmp_kerntype,#kerntype)==0)) { \ g_eltType = type_##datatype; \ g_kernType = type_##kerntype; \ goto gotType; \ }#include "SvmFuSvmTypes.h" cout << "Error: dataElt type '" << tmp_datatype; cout << "' and kernVal type '" << tmp_kerntype << "' are not" << endl; cout << "supported by this compilation of SvmFu. Supported types:" << endl;#define IterateTypes(datatype,kerntype) \ cout << " dataElt " << #datatype << ", kernVal " << #kerntype << endl;#include "SvmFuSvmTypes.h" exit(1);gotType: cout << "inputFile: " << inputFile << endl; if (usePosNegC) { cout <<"\tposC: " << posC << ", negC: " << negC << endl; } else { cout << "\tC: " << posC << endl; } if (useOutputFileP) { cout << "\toutputFile: " << outputFile << endl; } cout << "chunkSize: " << chunkSize << endl; cout << "\ttol: " << tol << endl; cout << "\teps: " << eps << endl; cout << "Machine Type: "; switch (machType) { case dense: cout << "Dense" << endl; break; case sparse01: cout << "Sparse01" << endl; break; case sparseN: cout << "SparseN" << endl; break; default: cout << "Unknown machine type. Aborting." << endl; exit(1); } cout << "Kernel Type: "; switch (kernType) { case linear: cout << "Linear" << endl; break; case polynomial: cout << "Polynomial, Degree=" << degree << ", Offset=" << offset << endl; break; case gaussian: cout << "Gaussian, Sigma=" << sigma << endl; break; default: cout << "Unknown Kernel Type, Aborting." << endl; exit(1); } cout << "Kernel Value Type: " << tmp_kerntype << endl; cout << "Data Element Value Type: " << tmp_datatype << endl; cout << "Kernel Normalization Term: " << normalizer << endl; free(tmp_kerntype); free(tmp_datatype);}//! Templatized wrapper for main() functiontemplate<class DataElt, class KernVal> class templateMain{ public: static int main () { // local vars int *y; FileIO<DataElt, KernVal> fileIO; DataSet<DataElt> data = fileIO.readDataSet(inputFile); KernelFuncs<KernVal, DataElt> kernFunctions(kernType, machType); // default is to use if (chunkSize > data.size) { chunkSize = data.size; } if (extraCacheRows == -1) extraCacheRows = data.size - chunkSize; cout << "Extra cache rows: " << extraCacheRows << endl; SvmLargeOpt<DataPoint<DataElt>, KernVal> svm(data.size, // number of points data.y, // the y values data.points, // the points extraCacheRows, kernFunctions.pfunc, chunkSize, posC, tol, eps, 5); if (kernType == linear) { svm.useLinearKernel(data.dim, kernFunctions.afunc, kernFunctions.mfunc); } if (usePosNegC) { for (int i = 0; i < data.size; i++) { if (data.y[i] == -1) { svm.setC(i, negC); } } } svm.optimize(); cout << "Optimization Complete, Obj=" << svm.dualObjFunc() << endl; cout << "# SV's == " << svm.getNumSupVecs() << ", #USV's == " << svm.getNumUnbndSupVecs() << endl; if (computeLOO) { double *LOOVals = new double[data.size]; double *funcVals = new double[data.size]; int i; for (i = 0; i < data.size; i++) { funcVals[i] = svm.outputAtTrainingExample(i); } ofstream os(looFile); if (!os) { cerr << "Error: Cannot open " << looFile << " for writing." << endl; exit(1); } cout << "LOO errors: " << svm.computeLeaveOneOutErrors(LOOVals) << endl; for (int i = 0; i < data.size; i++) { int y = data.y[i]; os << i << ", y=" << y << ", a=" << svm.getAlpha(i) << ", f(x)=" << funcVals[i] << ", f_L(x)=" << LOOVals[i]; if (y*LOOVals[i] < 0) { os << " LOO-ERROR"; } os << endl; } os.close(); } // This can actually be quite SLOW in the current implementation. // This is unfortunate. // svm->computeTrainingSetPerf(true); if (useOutputFileP) { fileIO.saveSvm(&svm, outputFile, useAsciiP); } // cleanup for (int i = 0; i < data.size; i++) { if(machType==sparse01 || machType==sparseN) delete data.points[i].index; if(machType==dense || machType==sparseN) delete data.points[i].value; } delete data.points; delete data.y; return 0; // success }};/*! \todo attempt to autodetect data types from input file */int main(int argc, char **argv){ set_new_handler(&out_of_store); parseArgs(argc, argv);#define IterateTypes(datatype,kerntype) \ if(g_eltType==type_##datatype && g_kernType==type_##kerntype) { \ templateMain<datatype,kerntype> m; \ return m.main(); \ }#include "SvmFuSvmTypes.h"}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -