📄 svmfutrainmulti.cpp
字号:
// Copyright (C) 2000 Ryan M. Rifkin <rif@mit.edu>// // This program is free software; you can redistribute it and/or// modify it under the terms of the GNU General Public License as// published by the Free Software Foundation; either version 2 of the// License, or (at your option) any later version.// // This program is distributed in the hope that it will be useful, but// WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU// General Public License for more details.// // You should have received a copy of the GNU General Public License// along with this program; if not, write to the Free Software// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA// 02111-1307, USA.// #include "clientincludes.h"// Prototypesvoid out_of_store();void parseArgs(int argc, char **argv);// Globalsdouble bigC=1, smC=1;bool useEqualCs = true;double tol=10E-4;double eps=10E-12;int *origY; char inputFile[255];char kernelTrainFile[255];char kernelTestFile[255];bool useKernelFileP = false;bool useKernelTrainP = false;bool useKernelTestP = false;bool useSvmFilesP = false;char testFile[255];char testOutputFile[255] = "test.out";char splitsFile[255];char svmFileBase[255];char looFile[255] = "loo.out";bool computeLOO;int extraCacheRows = 0;int chunkSize = 500;int verbosity = 0;bool testFileSetP = false;bool splitsFileSetP = false;void out_of_store() { cerr << "op new failed: out of store.\n"; exit(1);}// long optionsstruct option optionNames[] ={ {"compute-loo", required_argument, 0, 'l'}, {"verbosity", required_argument, 0, 'v'}, {"c", required_argument, 0, 'C'}, {"use-equal-cs", required_argument, 0, 'u'}, {"input-file", required_argument, 0, 'f'}, {"kernel-file", required_argument, 0, 'i'}, {"test-file", required_argument, 0, 'F'}, {"kernel-test-file", required_argument, 0, 'I'}, {"output-base", required_argument, 0, 's'}, {"test-output-file", required_argument, 0, 'O'}, {"splits-file", required_argument, 0, 'p'}, {"chunk-size", required_argument, 0, 'h'}, {"tolerance", required_argument, 0, 't'}, {"epsilon", required_argument, 0, 'e'}, {"rep", required_argument, 0, 'r'}, {"degree", required_argument, 0, 'd'}, {"offset", required_argument, 0, 'b'}, {"sigma", required_argument, 0, 'g'}, {"normalizer", required_argument, 0, 'n'}, {"kernel", required_argument, 0, 'k'}, {"elt-type", required_argument, 0, 'D'}, {"kern-type", required_argument, 0, 'K'}, {"cache-rows", required_argument, 0, 'c'}, {0, 0, 0, 0}};/*! Parses command line arguments into global variables. * \todo use getoptlong */void parseArgs(int argc, char **argv) { char c; int opti; char *tmp_datatype=strdup("double"); char *tmp_kerntype=strdup("double"); while ((c = getopt_long(argc, argv, "v:C:uf:s:i:F:I:O:p:t:e:d:b:g:n:k:r:h:aD:K:c:l:", optionNames, &opti)) != -1) { switch(c) { case 'v': verbosity = atoi(optarg); break; case 'C': bigC = atof(optarg); break; case 'u': useEqualCs = false; break; case 'f': strcpy(inputFile, optarg); break; case 'i': strcpy(kernelTrainFile, optarg); useKernelFileP = true; useKernelTrainP = true; break; case 'F': strcpy(testFile, optarg); testFileSetP = true; break; case 's': strcpy(svmFileBase, optarg); useSvmFilesP = true; case 'I': strcpy(kernelTestFile, optarg); useKernelFileP = true; useKernelTestP = true; break; case 'l': computeLOO = true; strcpy(looFile, optarg); break; case 'O': strcpy(testOutputFile, optarg); break; case 'p': strcpy(splitsFile, optarg); splitsFileSetP = true; break; case 'h': chunkSize = atoi(optarg); break; case 't': tol = atof(optarg); break; case 'e': eps = atof(optarg); break; case 'r': if (!strcmp(optarg, "01")) { // linearKernelFunction = sparse01BinaryProduct; // polynomialKernelFunction = sparse01PolynomialProduct; // gaussianKernelFunction = sparse01GaussianProduct; machType = sparse01; break; } else if (!strcmp(optarg, "N")) { // linearKernelFunction = sparseNLinearProduct; // polynomialKernelFunction = sparseNPolynomialProduct; // gaussianKernelFunction = sparseNGaussianProduct; machType = sparseN; break; } else { cout << "Argument to -r option must be '01' or 'N'. Aborting." << endl; exit(1); } case 'd': degree = atoi(optarg); break; case 'b': offset = atof(optarg); break; case 'g': sigma = atof(optarg); break; case 'n': normalizer = atof(optarg); break; case 'k': if (!strcmp(optarg, "linear")) { // kernelFunctionPtr = &linearKernelFunction; kernType = linear; } else if (!strcmp(optarg, "polynomial")) { // kernelFunctionPtr = &polynomialKernelFunction; kernType = polynomial; } else if (!strcmp(optarg, "gaussian")) { // kernelFunctionPtr = &gaussianKernelFunction; kernType = gaussian; } else { cout << "Error: Unknown kernel function " << optarg << ". Aborting." << endl; exit(-1); } break; case 'D': free(tmp_datatype); tmp_datatype = strdup(optarg); break; case 'K': free(tmp_kerntype); tmp_kerntype = strdup(optarg); break; case 'c': extraCacheRows = atoi(optarg); break; default: abort(); } } // Iterate over the known types and figure out which one they wanted#define IterateTypes(datatype,kerntype) \ if((strcmp(tmp_datatype,#datatype)==0) \ && (strcmp(tmp_kerntype,#kerntype)==0)) { \ g_eltType = type_##datatype; \ g_kernType = type_##kerntype; \ goto gotType; \ }#include "SvmFuSvmTypes.h" cout << "Error: dataElt type '" << tmp_datatype; cout << "' and kernVal type '" << tmp_kerntype << "' are not" << endl; cout << "supported by this compilation of SvmFu. Supported types:" << endl;#define IterateTypes(datatype,kerntype) \ cout << " dataElt " << #datatype << ", kernVal " << #kerntype << endl;#include "SvmFuSvmTypes.h" exit(1);gotType: cout << "inputFile: " << inputFile << endl; if (testFileSetP) { cout << "testFile: " << testFile << endl; } if (splitsFileSetP) { cout << "splitsFile: " << splitsFile << endl; } else { cerr << "ERROR: splits-file (-p) is a required argument." << endl; exit(-1); } if (useKernelTestP && !useKernelTrainP) { cerr << "ERROR: Cannot use kernel test file without kernel train file." << endl; } cout << "\tC: " << bigC; if (useEqualCs) { cout << " (Using Equal C's.)" << endl; } else { cout << " (Using Class-Size Weighted C's.)" << endl; } cout << "\ttol: " << tol << endl; cout << "\teps: " << eps << endl; cout << "Machine Type: "; switch (machType) { case dense: cout << "Dense" << endl; break; case sparse01: cout << "Sparse01" << endl; break; case sparseN: cout << "SparseN" << endl; break; default: cout << "Unknown machine type. Aborting." << endl; exit(1); } cout << "Kernel Type: "; switch (kernType) { case linear: cout << "Linear" << endl; break; case polynomial: cout << "Polynomial, Degree=" << degree << ", Offset=" << offset << endl; break; case gaussian: cout << "Gaussian, Sigma=" << sigma << endl; break; default: cout << "Unknown Kernel Type, Aborting." << endl; exit(1); } cout << "Kernel Value Type: " << tmp_kerntype << endl; cout << "Data Element Value Type: " << tmp_datatype << endl; cout << "Kernel Normalization Term: " << normalizer << endl; free(tmp_kerntype); free(tmp_datatype);}template<class DataElt, class KernVal> class templateMain{ public: static int main () { int i, j, k; double *cVec; double *alphaVec, *outputVec; int *origY, *svmY; int posEx, negEx; FileIO<DataElt, KernVal> fileIO; DataSet<DataElt> data = fileIO.readDataSet(inputFile); SplitsMatrix splitsMat = fileIO.readSplitsMatrix(splitsFile); int numSplits = splitsMat.numSplits; int *numSupVecs = new int[numSplits]; int **supVecs = new (int *)[numSplits]; double **SValphas = new (double *)[numSplits]; double *Bs = new double[numSplits]; double **LOOvals = new (double *)[numSplits]; //Check chunkSize if ( chunkSize > data.size ) { chunkSize = data.size; } if (useKernelTrainP) { chunkSize = data.size; cout << "Reading kernel matrix from file. Fixing chunksize to data size." << endl; } cout << "chunkSize: " << chunkSize << endl; //Make a copy of the y values origY = new int[data.size]; for (i=0; i < data.size; i++) { origY[i] = data.y[i]; } cout << endl; KernelFuncs<KernVal, DataElt> kernFunctions(kernType, machType); if (extraCacheRows == -1) extraCacheRows = data.size - chunkSize; SvmKernCache<DataPoint<DataElt>, KernVal> *kC = new SvmKernCache<DataPoint<DataElt>, KernVal>(chunkSize, extraCacheRows, data.size, data.points, kernFunctions.pfunc); if (useKernelTrainP) { kC->readFromFile(kernelTrainFile); } //Train SVMs // //splits file: each row is an SVM // each col is if examples of that class are positive, negative, // or not in that SVMs training set. cVec = new double[data.size]; alphaVec = new double[data.size]; outputVec = new double[data.size]; for (i=0; i < splitsMat.numSplits; i++) { posEx = 0; negEx = 0; // Create a new class vector, and working set vectors, // alpha and output vectors. for (j=0; j < data.size; j++) { // use the y values from data as indexes into the current splits row // to get the proper y vector for this svm. data.y[j] = splitsMat.splits[i][origY[j]]; // cout << j << ": " << origY[j] << " " << splitsMat.splits[i][origY[j]] << endl; if ( data.y[j] == 1 ) { posEx++; } if ( data.y[j] == -1 ) { negEx++; } alphaVec[j] = 0; outputVec[j] = 0; } int bigEx = (posEx > negEx ? posEx : negEx); int smEx = (posEx < negEx ? posEx : negEx); double smC = (useEqualCs ? bigC : bigC*smEx/(double)bigEx ); double posC = (posEx > negEx ? smC : bigC); double negC = (posEx > negEx ? bigC : smC); for (j = 0; j < data.size; j++) { if (data.y[j] == 1) { cVec[j] = posC; } else if (data.y[j] == -1) { cVec[j] = negC; } else { cVec[j] = 0; // Set "dummy" y value to get cached values to compute data.y[j] = -1; } } if (verbosity > 0) { cout << "posC = " << posC << ", negC = " << negC << endl; } //Build SVM SvmLargeOpt<DataPoint<DataElt>, KernVal> svm(data.size, data.y, data.points, kernFunctions.pfunc, cVec, alphaVec, outputVec, 0.0, chunkSize, kC, tol, eps, verbosity); if (kernType == linear) { svm.useLinearKernel(data.dim, kernFunctions.afunc, kernFunctions.mfunc); } //Train SVM if (verbosity > 0) { cout << "Pos Examples: " << posEx << ", Neg Examples: " << negEx << endl; } cout << "Training SVM " << i+1 << " of " << splitsMat.numSplits << "..." << endl; svm.optimize(); if (useSvmFilesP) { char svmFile[255]; sprintf(svmFile, "%s.%i", svmFileBase, i); fileIO.saveSvm(&svm, svmFile, true); } numSupVecs[i] = svm.getNumSupVecs(); supVecs[i] = svm.getSupVecIDs(); SValphas[i] = svm.getSupVecAlphas(); Bs[i] = svm.getB(); // Compute leave one out values for this svm if (computeLOO) { cout << "Computing leave one out values..." << endl; LOOvals[i] = new double[data.size]; //loo vals for this svm svm.computeLeaveOneOutErrors(LOOvals[i]); } } delete kC; // Perform Testing, if necessary. if (testFileSetP) { DataSet<DataElt> testData = fileIO.readDataSet(testFile); KernVal *testKerns; ifstream *testKernFile; if (useKernelTestP) { testKernFile = new ifstream(kernelTestFile); testKerns = new KernVal[data.size]; int dummy; *testKernFile >> dummy >> dummy; } // For each test point, compute the output from all SVMs simultaneously // so we can save on kernel products. ofstream to(testOutputFile); if (!to.good()) { cerr << "ERROR: Can't open " << testOutputFile << " for writing. Death." << endl; exit(-1); } bool *prodComped = new bool[data.size]; double *prod = new double[data.size]; for (i = 0; i < testData.size; i++) { to << testData.y[i] << " "; if (useKernelTestP) { for (j = 0; j < data.size; j++){ *testKernFile >> testKerns[j]; } } for (j = 0; j < data.size; j++) { prodComped[j] = false; prod[j] = 0; } for (j = 0; j < numSplits; j++) { double fv = Bs[j]; for (k = 0; k < numSupVecs[j]; k++) { int id = supVecs[j][k]; if (!prodComped[id]) { if (useKernelTestP) { prod[id] = testKerns[id]; } else { prod[id] = kernFunctions.pfunc(data.points[id], testData.points[i]); } prodComped[id] = true; } fv += (splitsMat.splits[j][origY[id]]*SValphas[j][k]*prod[id]); } to << fv << " "; } to << endl; } to.close(); if (useKernelTestP) { delete testKernFile; } delete[] prodComped; delete[] prod; } //write leave one out vals if (computeLOO) { ofstream looout(looFile); if (!looout.good()) { cerr << "ERROR: Can't open " << looFile << " for writing." << endl; exit(-1); } for (i=0; i < data.size; i++) { looout << origY[i] << " "; for (j=0; j < numSplits; j++) { looout << LOOvals[j][i] << " "; } looout << endl; } looout.close(); } return 0; }};int main(int argc, char **argv){ set_new_handler(&out_of_store); parseArgs(argc, argv);#define IterateTypes(datatype,kerntype) \ if(g_eltType==type_##datatype && g_kernType==type_##kerntype) { \ templateMain<datatype,kerntype> m; \ return m.main(); \ }#include "SvmFuSvmTypes.h"}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -