📄 learn.cpp
字号:
#include <stdlib.h>#include <string.h>#include <fstream.h>#include <math.h>#include <float.h>#include "globals.h"#include "example_set.h"#include "svm_c.h"#include "parameters.h"#include "kernel.h"#include "svm_nu.h"#include "version.h"#include <time.h>#include <sys/types.h>#include <unistd.h>// global svm-objectskernel_c* kernel=0;parameters_c* parameters=0;svm_c* svm;example_set_c* training_set=0;int is_linear=1; // linear kernel?struct example_set_list{ example_set_c* the_set; example_set_list* next;};example_set_list* test_sets = 0;void print_help(){ cout<<endl; cout<<"my_svm: train a svm from the given parameters and examples."<<endl<<endl; cout<<"usage: my_svm"<<endl <<" my_svm <FILE>"<<endl <<" my_svm <FILE1> <FILE2> ..."<<endl<<endl; cout<<"The input has to consist of:"<<endl <<"- the svm parameters"<<endl <<"- the kernel definition"<<endl <<"- the training set"<<endl <<"- one or more test sets (optional)"<<endl; cout<<endl<<"See the documentation for the input format. The first example set to be entered is considered to be the training set, all others are test sets. Each input file can consist of one or more definitions. If no input file is specified, the input is read from <stdin>."<<endl<<endl; cout<<endl<<"This software is free only for non-commercial use. It must not be modified and distributed without prior permission of the author. The author is not responsible for implications from the use of this software."<<endl; exit(0);};void read_input(istream& input_stream, char* filename){ // returns number of examples sets read char* s = new char[MAXCHAR]; char next; next = input_stream.peek(); if(next == EOF){ // set stream to eof next = input_stream.get(); }; while(! input_stream.eof()){ if('#' == next){ // ignore comment input_stream.getline(s,MAXCHAR); } else if(('\n' == next) || (' ' == next) || ('\r' == next) || ('\f' == next) || ('\t' == next)){ // ignore next = input_stream.get(); } else if('@' == next){ // new section input_stream >> s; if(0 == strcmp("@parameters",s)){ // read parameters if(parameters == 0){ parameters = new parameters_c(); input_stream >> *parameters; } else{ cout <<"*** ERROR: Parameters multiply defined"<<endl; throw input_exception(); }; } else if(0==strcmp("@examples",s)){ if(0 == training_set){ // input training set training_set = new example_set_c(); if(0 != parameters){ training_set->set_format(parameters->default_example_format); }; input_stream >> *training_set; training_set->set_filename(filename); cout<<" read "<<training_set->size()<<" examples, format "<<training_set->my_format<<", dimension = "<<training_set->get_dim()<<"."<<endl; } else{ // input test sets example_set_list* test_set = new example_set_list; test_set->the_set = new example_set_c(); if(0 != parameters){ (test_set->the_set)->set_format(parameters->default_example_format); }; input_stream >> *(test_set->the_set); (test_set->the_set)->set_filename(filename); test_set->next = test_sets; test_sets = test_set; cout<<" read "<<(test_set->the_set)->size()<<" examples, format "<<(test_set->the_set)->my_format<<", dimension = "<<(test_set->the_set)->get_dim()<<"."<<endl; }; } else if(0==strcmp("@kernel",s)){ if(0 == kernel){ kernel_container_c k_cont; input_stream >> k_cont; kernel = k_cont.get_kernel(); is_linear = k_cont.is_linear; } else{ cout <<"*** ERROR: Kernel multiply defined"<<endl; throw input_exception(); }; }; } else{ // default = "@examples" if(0 == training_set){ // input training set training_set = new example_set_c(); if(0 != parameters){ training_set->set_format(parameters->default_example_format); }; input_stream >> *training_set; training_set->set_filename(filename); cout<<" read "<<training_set->size()<<" examples, format "<<training_set->my_format<<", dimension = "<<training_set->get_dim()<<"."<<endl; } else{ // input test sets example_set_list* test_set = new example_set_list; test_set->the_set = new example_set_c(); if(0 != parameters){ (test_set->the_set)->set_format(parameters->default_example_format); }; input_stream >> *(test_set->the_set); (test_set->the_set)->set_filename(filename); test_set->next = test_sets; test_sets = test_set; cout<<" read "<<(test_set->the_set)->size()<<" examples, format "<<(test_set->the_set)->my_format<<", dimension = "<<(test_set->the_set)->get_dim()<<"."<<endl; }; }; next = input_stream.peek(); if(next == EOF){ // set stream to eof next = input_stream.get(); }; }; delete []s;};svm_result do_cv(){ SVMINT number = parameters->cross_validation; SVMINT size = training_set->size(); int verbosity = parameters->verbosity; if((number > size) || (0 >= number)){ number = size; // leave-one-out testing }; // SVMINT cv_size = size / number; if(! parameters->cv_inorder){ training_set->permute(); }; training_set->clear_alpha(); example_set_c* cv_train=new example_set_c(); //=0; example_set_c* cv_test=new example_set_c(); //=0; svm_result train_result; svm_result test_result; svm_result train_sum; svm_result test_sum; train_sum.VCdim = 0; train_sum.pred_loss=0; train_sum.loss=0; train_sum.loss_pos=0; train_sum.loss_neg=0; train_sum.MAE = 0; train_sum.MSE = 0; train_sum.accuracy = 0; train_sum.precision = 0; train_sum.recall=0; train_sum.number_svs=0; train_sum.number_bsv=0; test_sum.VCdim = 0; test_sum.loss=0; test_sum.loss_pos=0; test_sum.loss_neg=0; test_sum.MAE = 0; test_sum.MSE = 0; test_sum.accuracy = 0; test_sum.precision = 0; test_sum.recall=0; test_sum.number_svs=0; test_sum.number_bsv=0; SVMINT j; if(verbosity>2){ if(parameters->cv_window>0){ cout<<"beginning "<<(number-parameters->cv_window)<<" sliding window steps"<<endl; } else{ cout<<"beginning "<<number<<"-fold crossvalidation"<<endl; }; }; SVMINT i; for(i=parameters->cv_window;i<number;i++){ // do cv if(verbosity >= 3){ cout<<"----------------------------------------"<<endl; cout<<(i+1); if(0 == i%10) cout<<"st"; else if(1==i%10) cout<<"nd"; else if(2==i%10) cout<<"rd"; else cout<<"th"; cout<<" step"<<endl; }; //cout<<"From "<<i*cv_size<<" to "<<(i+1)*cv_size<<endl; cv_train->clear(); cv_test->clear(); cv_train->set_dim(training_set->get_dim()); cv_test->set_dim(training_set->get_dim()); if(training_set->initialised_y()){ cv_train->set_initialised_y(); cv_test->set_initialised_y(); }; cv_train->put_Exp_Var(training_set->get_exp(),training_set->get_var()); cv_test->put_Exp_Var(training_set->get_exp(),training_set->get_var()); if(verbosity>=4){ cout<<"Initing examples sets"<<endl; }; if(parameters->cv_window>0){ // training window for(j=SVMINT((i-parameters->cv_window)*size/number);((j<(SVMINT)(i*size/number))&&(j<size));j++){ cv_train->put_example(training_set->get_example(j)); }; // test window for(j=(SVMINT)(i*size/number);((j<(SVMINT)((i+1)*size/number))&&(j<size));j++){ cv_test->put_example(training_set->get_example(j)); }; } else{ for(j=(SVMINT)(i*size/number);((j<(SVMINT)((i+1)*size/number))&&(j<size));j++){ cv_test->put_example(training_set->get_example(j)); }; for(j=0;j<(SVMINT)(i*size/number);j++){ cv_train->put_example(training_set->get_example(j)); }; for(j=(SVMINT)((i+1)*size/number);j<size;j++){ cv_train->put_example(training_set->get_example(j)); }; }; cv_train->clear_alpha(); cv_test->clear_alpha(); cv_train->compress(); cv_test->compress(); if(verbosity>=4){ cout<<"Setting up the SVM"<<endl; }; kernel->init(parameters->kernel_cache,cv_train); svm->init(kernel,parameters); // train & test the svm if(verbosity>=4){ cout<<"Training"<<endl; }; // cv_train->clear_alpha(); train_result = svm->train(cv_train); if(verbosity>=4){ cout<<"Testing"<<endl; }; test_result = svm->test(cv_test,0); train_sum.VCdim += train_result.VCdim; train_sum.loss += train_result.loss; train_sum.loss_pos += train_result.loss_pos; train_sum.loss_neg += train_result.loss_neg; train_sum.MAE += train_result.MAE; train_sum.MSE += train_result.MSE; train_sum.pred_loss += train_result.pred_loss; train_sum.accuracy += train_result.accuracy; train_sum.precision += train_result.precision; train_sum.recall += train_result.recall; train_sum.number_svs += train_result.number_svs; train_sum.number_bsv += train_result.number_bsv; test_sum.loss += test_result.loss; test_sum.loss_pos += test_result.loss_pos; test_sum.loss_neg += test_result.loss_neg; test_sum.MAE += test_result.MAE; test_sum.MSE += test_result.MSE; test_sum.accuracy += test_result.accuracy; test_sum.precision += test_result.precision; test_sum.recall += test_result.recall; if(verbosity>=4){ cout<<"Training set:"<<endl <<"Loss: "<<train_result.loss<<endl; if(parameters->Lpos != parameters->Lneg){ cout<<" Loss+: "<<train_result.loss_pos<<endl; cout<<" Loss-: "<<train_result.loss_neg<<endl; }; cout<<"MAE: "<<train_result.MAE<<endl; cout<<"MSE: "<<train_result.MSE<<endl; cout<<"VCdim: "<<train_result.VCdim<<endl; if(parameters->is_pattern){ cout<<"Accuracy : "<<train_result.accuracy<<endl <<"Precision : "<<train_result.precision<<endl <<"Recall : "<<train_result.recall<<endl; }; cout<<"Support Vectors : "<<train_result.number_svs<<endl; cout<<"Bounded SVs : "<<train_result.number_bsv<<endl; cout<<"Test set:"<<endl <<"Loss: "<<test_result.loss<<endl; if(parameters->Lpos != parameters->Lneg){ cout<<" Loss+: "<<test_result.loss_pos<<endl; cout<<" Loss-: "<<test_result.loss_neg<<endl; }; if(parameters->is_pattern){ cout<<"Accuracy : "<<test_result.accuracy<<endl <<"Precision : "<<test_result.precision<<endl <<"Recall : "<<test_result.recall<<endl; }; }; }; parameters->verbosity = verbosity; number-=parameters->cv_window; if(verbosity > 1){ cout<<"----------------------------------------"<<endl; cout<<"Results of "<<number<<"-fold cross-validation:"<<endl; cout<<"-- Training set: --"<<endl <<"Loss: "<<train_sum.loss/number<<endl; if(parameters->Lpos != parameters->Lneg){ cout<<" Loss+: "<<train_sum.loss_pos/number<<endl; cout<<" Loss-: "<<train_sum.loss_neg/number<<endl; }; cout<<"MAE: "<<train_sum.MAE/number<<endl; cout<<"MSE: "<<train_sum.MSE/number<<endl; cout<<"VCdim: "<<train_sum.VCdim/number<<endl; if(parameters->is_pattern){ cout<<"Accuracy : "<<train_sum.accuracy/number<<endl <<"Precision : "<<train_sum.precision/number<<endl <<"Recall : "<<train_sum.recall/number<<endl; }; cout<<"Support Vectors : "<<((SVMFLOAT)train_sum.number_svs)/number<<endl; cout<<"Bounded SVs : "<<((SVMFLOAT)train_sum.number_bsv)/number<<endl; cout<<"-- Test set: --"<<endl <<"Loss: "<<test_sum.loss/number<<endl; if(parameters->Lpos != parameters->Lneg){ cout<<" Loss+: "<<test_sum.loss_pos/number<<endl; cout<<" Loss-: "<<test_sum.loss_neg/number<<endl; }; cout<<"MAE: "<<test_sum.MAE/number<<endl; cout<<"MSE: "<<test_sum.MSE/number<<endl; if(parameters->is_pattern){ cout<<"Accuracy : "<<test_sum.accuracy/number<<endl <<"Precision : "<<test_sum.precision/number<<endl <<"Recall : "<<test_sum.recall/number<<endl; }; }; test_sum.VCdim = train_sum.VCdim/number; test_sum.loss /= number; test_sum.loss_pos /= number; test_sum.loss_neg /= number; test_sum.MAE /= number; test_sum.MSE /= number; test_sum.pred_loss = train_sum.pred_loss; test_sum.accuracy /= number; test_sum.precision /= number; test_sum.recall /= number; test_sum.number_svs = train_sum.number_svs/number; test_sum.number_bsv = train_sum.number_bsv/number; delete cv_test; delete cv_train; return test_sum;};svm_result train(){ svm_result the_result; if(parameters->cross_validation > 0){ the_result = do_cv();
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -