📄 svmfutest.cpp
字号:
// Copyright (C) 2000 Ryan M. Rifkin <rif@mit.edu>// // This program is free software; you can redistribute it and/or// modify it under the terms of the GNU General Public License as// published by the Free Software Foundation; either version 2 of the// License, or (at your option) any later version.// // This program is distributed in the hope that it will be useful, but// WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU// General Public License for more details.// // You should have received a copy of the GNU General Public License// along with this program; if not, write to the Free Software// Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA// 02111-1307, USA.// #include "clientincludes.h"// Prototypesvoid out_of_store();void parseArgs(int argc, char **argv);// const int ptSizeFunc(DataPt);// Globalschar *testingFile=NULL;char *svmFile=NULL;bool printPointByPointResultsP = false; bool useAsciiP = false;void out_of_store() { cerr << "op new failed: out of store.\n"; exit(1);}// long optionsstruct option optionNames[] ={ {"test-file", required_argument, 0, 'f'}, {"svm-file", required_argument, 0, 's'}, {"point-by-point", no_argument, 0, 'p'}, {"version", no_argument, 0, 'v'}, {"help", no_argument, 0, '?'}, {0, 0, 0, 0}};void usage(char *basename) { cerr<<"usage: " << basename << " -f testfile -s svmfile [-p]" << endl; cerr<<" -f, --test-file=filename file with testing points\n"; cerr<<" -s, --svm-file=filename save file from svmfutrain\n"; cerr<<" -p, --point-by-point display point-by-point results\n"; cerr<<" -v, --version print SvmFu version ("<<VERSION<<")\n"; cerr<<" -?, --help this message\n"; exit(0);}void parseArgs(int argc, char **argv) { char c; int opti=0; while ((c = getopt_long(argc, argv, "f:s:pv?",optionNames, &opti))!=-1) { switch(c) { case 'f': if(testingFile) free(testingFile); testingFile=strdup(optarg); break; case 's': if(svmFile) free(svmFile); svmFile=strdup(optarg); break; case 'p': printPointByPointResultsP = true; break; case 'v': cout << "svmfu " << VERSION << endl; exit(0); case '?': default: usage(*argv); } } if (!testingFile || !svmFile) { cout << "Error: testing-file and svm-file are required arguments." << endl; exit(1); } // cout << "\ttestingFile: " << testingFile << endl; // cout << "\tsvmFile: " << svmFile << endl; // Now things start to get fun // Read parameters from the SVM file ifstream in(svmFile); if(!in.good()) { cout << "Error: can't open " << svmFile << " for reading." << endl; exit(1); } char line[1024]; in.getline(line,1024); if(strcmp(line,SVM_HEADER)!=0) { cout << "Error: " << svmFile << " is not a valid SvmFu save." << endl; exit(1); } char *arg; char *tmp_datatype=strdup("double"); char *tmp_kerntype=strdup("double"); int datasize=-1, kernsize=-1; int linecount=1; for(;;) { in.getline(line,1024); linecount++; if(in.eof() || strcmp(line,"data")==0) break; if(strncmp(line,"machtype ",9)==0) { arg=line+9; if(strcmp(arg,"sparseN")==0) machType=sparseN; else if(strcmp(arg,"sparse01")==0) machType=sparse01; else if(strcmp(arg,"dense")==0) machType=dense; else { cout << "Error: unknown data type " << arg << endl; exit(1); } continue; } if(strncmp(line,"datatype ",9)==0) { arg=line+9; free(tmp_datatype); tmp_datatype=strdup(arg); continue; } if(strncmp(line,"kerntype ",9)==0) { arg=line+9; free(tmp_kerntype); tmp_kerntype=strdup(arg); continue; } if(strncmp(line,"kernfunc ",9)==0) { arg=line+9; if(strcmp(arg,"linear")==0) { kernType=linear; } else if(strcmp(arg,"polynomial")==0) { kernType=polynomial; } else if(strcmp(arg,"gaussian")==0) { kernType=gaussian; } else { cout << "Error: unknown kernel function " << arg << endl; exit(1); } continue; } if(strncmp(line,"degree ",7)==0) { degree=atoi(line+7); continue; } if(strncmp(line,"bias ",5)==0) { offset=atof(line+5); continue; } if(strncmp(line,"sigma ",6)==0) { sigma=atof(line+6); continue; } if(strncmp(line,"normalizer ",11)==0) { normalizer=atof(line+11); continue; } if(strncmp(line,"format ",7)==0) { arg=line+7; if(strcmp(arg,"ascii")==0) { useAsciiP=true; } else if(strcmp(arg,"binary")==0) { useAsciiP=false; } else { cout << "Error: unknown format " << arg << endl; exit(1); } continue; } if(strncmp(line,"endian ",7)==0) { arg=line+7; int test=0x11223344; bool big=false; if(*(char *)&test==0x11) big=true; if(strcmp(arg,"big")==0 && big) continue; if(strcmp(arg,"little")==0 && !big) continue; cout << "Error: endian-ness '" << arg << "' doesn't match this machine" << endl; exit(1); } if(strncmp(line,"int ",4)==0) { if(atoi(line+4)==sizeof(int)) continue; cout << "Error: size of int doesn't match this machine" << endl; exit(1); } if(strncmp(line,"double ",7)==0) { if(atoi(line+7)==sizeof(double)) continue; cout << "Error: size of int doesn't match this machine" << endl; exit(1); } if(strncmp(line,"datasize ",9)==0) { datasize=atoi(line+9); continue; } if(strncmp(line,"kernsize ",9)==0) { kernsize=atoi(line+9); continue; } cout << "Error: unrecognized line " << linecount << " in svm save file:" << endl << line << endl; exit(1); } if(strcmp(line,"data")!=0) { cout << "Error: svm save file ended with no data at line " << linecount << endl; exit(1); } // Iterate over the known types and figure out which one they wanted#define IterateTypes(datatype,kerntype) \ if((strcmp(tmp_datatype,#datatype)==0) \ && (strcmp(tmp_kerntype,#kerntype)==0)) { \ if(datasize>0 && sizeof(datatype)!=datasize) goto badSize; \ if(kernsize>0 && sizeof(kerntype)!=kernsize) goto badSize; \ g_eltType = type_##datatype; \ g_kernType = type_##kerntype; \ goto gotType; \ }#include "SvmFuSvmTypes.h" cout << "Error: dataElt type '" << tmp_datatype; cout << "' and kernVal type '" << tmp_kerntype << "' are not" << endl; cout << "supported by this compilation of SvmFu. Supported types:" << endl;#define IterateTypes(datatype,kerntype) \ cout << " dataElt " << #datatype << ", kernVal " << #kerntype << endl;#include "SvmFuSvmTypes.h" exit(1);badSize: cout << "Error: kernVal and/or dataElt sizes do not match the" << endl; cout << "machine that generated this svm save file." << endl; exit(1);gotType: // cout << "Machine Type: "; switch (machType) { case dense: //<< "Dense" << endl; break; case sparse01: //<< "Sparse01" << endl; break; case sparseN: //<< "SparseN" << endl; break; default: //<< "Unknown machine type. Aborting." << endl; exit(1); } //<< "Kernel Type: "; switch (kernType) { case linear: //<< "Linear" << endl; break; case polynomial: //<< "Polynomial, Degree=" << degree << ", Offset=" << offset << endl; break; case gaussian: //<< "Gaussian, Sigma=" << sigma << endl; break; default: //<< "Unknown Kernel Type, Aborting." << endl; exit(1); } //<< "Kernel Value Type: " << tmp_kerntype << endl; //<< "Data Element Value Type: " << tmp_datatype << endl; //<< "Kernel Normalization Term: " << normalizer << endl; free(tmp_kerntype); free(tmp_datatype);}//! Templatized wrapper for main() functiontemplate<class DataElt, class KernVal> class templateMain { public: static int main () { ifstream ifs(testingFile); if (!ifs) { cerr << "Error: Cannot open file " << testingFile << " for reading." << endl; exit(-1); } int *y; int trainingSetSize; FileIO<DataElt, KernVal> fileIO; TrainedSvm<DataElt> svm = fileIO.loadSvm(svmFile, useAsciiP); KernelFuncs<KernVal, DataElt> kernFunctions(kernType, machType); SvmTest<DataPoint<DataElt>, KernVal> testSvm(svm.size, svm.alphas, svm.svs, svm.b, svm.y, kernFunctions.pfunc); DataSet<DataElt> testingSet = fileIO.readDataSet(testingFile); int correctlyClassifiedPoints = 0; for (int i = 0; i < testingSet.size; i++) { int yt = testSvm.classifyDataPt(testingSet[i]); if (yt == testingSet.y[i]) correctlyClassifiedPoints++; if (printPointByPointResultsP) { // cout << i << ":" << yt << "(" << testSvm.outputAtDataPt(testingSet[i]) << "), TC= " << testingSet.y[i] << " " << endl; cout << testSvm.outputAtDataPt(testingSet[i]) << " " << testingSet.y[i] << endl; } else if (((i+1)%500)==0) { // cout << i+1 << ": " << correctlyClassifiedPoints << endl; } } cout << "Correctly Classified Points: " << correctlyClassifiedPoints << " out of " << testingSet.size << " (" << correctlyClassifiedPoints/(double)testingSet.size << ")" << endl; // cleanup for (int i = 0; i < testingSet.size; i++) { if(machType==sparse01 || machType==sparseN) delete testingSet.points[i].index; if(machType==dense || machType==sparseN) delete testingSet.points[i].value; } delete testingSet.points; delete testingSet.y; return 0; // success }};/*! \todo attempt to autodetect data types from input file */int main(int argc, char **argv){ set_new_handler(&out_of_store); parseArgs(argc, argv);#define IterateTypes(datatype,kerntype) \ if(g_eltType==type_##datatype && g_kernType==type_##kerntype) { \ templateMain<datatype,kerntype> m; \ return m.main(); \ }#include "SvmFuSvmTypes.h" free(testingFile); free(svmFile);}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -