📄 learn.cpp
字号:
} else{ kernel->init(parameters->kernel_cache,training_set); svm->init(kernel,parameters); if(parameters->is_nu || parameters->is_distribution){ cout<<"Training started with nu = " <<parameters->nu <<"."<<endl; } else if(parameters->get_Cpos() == parameters->get_Cneg()){ cout<<"Training started with C = " <<parameters->get_Cpos() <<"."<<endl; } else{ cout<<"Training started with C = ("<<parameters->get_Cpos() <<","<<parameters->get_Cneg()<<")."<<endl; }; the_result = svm->train(training_set); }; return the_result;};inlineSVMFLOAT to_minimize(svm_result result){ // which value to minimize in calc_c if((parameters->cross_validation <= 0) && (1 == parameters->is_pattern)){ return (1.0-result.pred_accuracy); } else{ if(1 == parameters->is_pattern){ return (1.0-result.accuracy); } else{ return result.loss; }; };};svm_result calc_c(){ const SVMFLOAT lambda = 0.618033989; // (sqrt(5)-1)/2 SVMINT verbosity = parameters->verbosity; parameters->verbosity -= 2; svm_result the_result; SVMFLOAT c_min = parameters->c_min; SVMFLOAT c_max = parameters->c_max; SVMFLOAT c_delta = parameters->c_delta; SVMFLOAT oldC; SVMINT last_dec=0; // when did loss decrease? // setup s,t if(verbosity >= 3){ cout<<"starting search for C"<<endl; }; if((parameters->search_c == 'a') ||(parameters->search_c == 'm')){ SVMFLOAT minimal_value=infinity; SVMFLOAT minimal_C=c_min; svm_result minimal_result; SVMFLOAT result_value; oldC=c_min; training_set->clear_alpha(); while(c_min <= c_max){ if(verbosity>=3){ cout<<"C = "<<c_min<<" :"<<endl; }; parameters->realC = c_min; training_set->scale_alphas(c_min/oldC); // training_set->clear_alpha(); oldC = c_min; if(verbosity >= 4){ cout<<"C = "<<c_min<<endl; } the_result = train(); if(verbosity>=3){ cout<<"loss = "<<the_result.loss<<endl; if(parameters->is_pattern){ cout<<"accuracy = "<<the_result.accuracy<<endl; // cout<<"predicted loss = "<<the_result.pred_loss<<endl; cout<<"predicted accuracy = "<<the_result.pred_accuracy<<endl; }; // cout<<"VCdim <= "<<the_result.VCdim<<endl; }; result_value = to_minimize(the_result); // cout<<result_value<<endl; last_dec++; if((result_value<minimal_value) && (! isnan(result_value))){ minimal_value=result_value; minimal_C=c_min; minimal_result = the_result; last_dec=0; }; if(parameters->search_c == 'a'){ c_min += c_delta; } else{ c_min *= c_delta; }; if((parameters->search_stop > 0) && (last_dec >= parameters->search_stop)){ // no decrease in loss, stop c_min = 2*c_max; }; }; parameters->realC = minimal_C; the_result = minimal_result; } else{ // method of golden ratio SVMFLOAT s = lambda*c_min+(1-lambda)*c_max; SVMFLOAT t = (1-lambda)*c_min+lambda*c_max; SVMFLOAT phi_s; SVMFLOAT phi_t; parameters->realC = s; training_set->clear_alpha(); the_result = train(); phi_s = to_minimize(the_result); parameters->realC = t; training_set->scale_alphas(t/s); oldC = t; the_result = train(); phi_t = to_minimize(the_result); while(c_max - c_min > c_delta*c_min){ if(verbosity >= 3){ cout<<"C in ["<<c_min<<","<<c_max<<"]"<<endl; }; if(phi_s < phi_t){ c_max = t; t = s; phi_t = phi_s; // calc s s = lambda*c_min+(1-lambda)*c_max; parameters->realC = s; training_set->scale_alphas(s/oldC); oldC=s; the_result = train(); phi_s = to_minimize(the_result); } else{ c_min = s; s = t; phi_s = phi_t; // calc t t = (1-lambda)*c_min+lambda*c_max; parameters->realC = t; training_set->scale_alphas(t/oldC); oldC=t; the_result = train(); phi_t = to_minimize(the_result); }; }; // save last results if(phi_s < phi_t){ c_max = t; } else{ c_min = s; }; parameters->realC = (c_min+c_max)/2; }; // ouput result if(verbosity >= 1){ cout<<"*** Optimal C is "<<parameters->realC; if(parameters->search_c == 'g'){ cout<<" +/-"<<((c_max-c_min)/2); }; cout<<endl; }; if(verbosity>=2){ cout<<"result:"<<endl <<"Loss: "<<the_result.loss<<endl; if(parameters->Lpos != parameters->Lneg){ cout<<" Loss+: "<<the_result.loss_pos<<endl; cout<<" Loss-: "<<the_result.loss_neg<<endl; }; if(parameters->is_pattern){ cout<<"predicted Loss: "<<the_result.pred_loss<<endl; }; cout<<"MAE: "<<the_result.MAE<<endl; cout<<"MSE: "<<the_result.MSE<<endl; cout<<"VCdim <= "<<the_result.VCdim<<endl; if(parameters->is_pattern){ cout<<"Accuracy : "<<the_result.accuracy<<endl <<"Precision : "<<the_result.precision<<endl <<"Recall : "<<the_result.recall<<endl; if(parameters->cross_validation == 0){ cout<<"predicted Accuracy : "<<the_result.pred_accuracy<<endl <<"predicted Precision : "<<the_result.pred_precision<<endl <<"predicted Recall : "<<the_result.pred_recall<<endl; }; }; cout<<"Support Vectors : "<<the_result.number_svs<<endl; cout<<"Bounded SVs : "<<the_result.number_bsv<<endl; if(parameters->search_c == 'g'){ cout<<"(WARNING: this is the last result attained and may slightly differ from the result of the optimal C!)"<<endl; }; }; parameters->verbosity = verbosity; return the_result;};///////////////////////////////////////////////////////////////int main(int argc,char* argv[]){ cout<<"*** mySVM version "<<mysvmversion<<" ***"<<endl; cout.precision(8); // read objects try{ if(argc<2){ cout<<"Reading from STDIN"<<endl; // read vom cin read_input(cin,"mysvm"); } else{ char* s = argv[1]; if((0 == strcmp("-h",s)) || (0==strcmp("-help",s)) || (0==strcmp("--help",s))){ // print out command-line help print_help(); } else{ // read in all input files for(int i=1;i<argc;i++){ if(0 == strcmp(argv[i],"-")){ cout<<"Reading from STDIN"<<endl; // read vom cin read_input(cin,"mysvm"); } else{ cout<<"Reading "<<argv[i]<<endl; ifstream input_file(argv[i]); if(input_file.bad()){ cout<<"ERROR: Could not read file \""<<argv[i]<<"\", exiting."<<endl; exit(1); }; read_input(input_file,argv[i]); input_file.close(); }; }; }; }; } catch(general_exception &the_ex){ cout<<"*** Error while reading input: "<<the_ex.error_msg<<endl; exit(1); } catch(...){ cout<<"*** Program ended because of unknown error while reading input"<<endl; exit(1); }; if(0 == parameters){ parameters = new parameters_c(); if(training_set->initialised_pattern_y()){ parameters->is_pattern = 1; parameters->do_scale_y = 0; }; }; parameters->is_linear = is_linear; if(0 == kernel){ kernel = new kernel_dot_c(); }; if(0 == training_set){ cout << "*** ERROR: You did not enter the training set"<<endl; exit(1); }; if(2 > training_set->size()){ cout << "*** ERROR: Need at least two examples to learn."<<endl; exit(1); }; if(parameters->is_distribution){ svm = new svm_distribution_c(); cout<<"distribution estimation SVM generated"<<endl; } else if(parameters->is_nu){ if(parameters->is_pattern){ svm = new svm_nu_pattern_c(); if(! training_set->initialised_pattern_y()){ cout<<"WARNING: Parameters set a pattern SVM, but the training ys are not in {-1,1}."<<endl; } cout<<"nu-PSVM generated"<<endl; } else{ svm = new svm_nu_regression_c(); cout<<"nu-RSVM generated"<<endl; }; } else if(parameters->is_pattern){ svm = new svm_pattern_c(); if(! training_set->initialised_pattern_y()){ cout<<"WARNING: Parameters set a pattern SVM, but the training ys are not in {-1,1}."<<endl; } cout<<"PSVM generated"<<endl; } else{ svm = new svm_regression_c(); cout<<"RSVM generated"<<endl; }; // scale examples if(parameters->do_scale){ training_set->scale(parameters->do_scale_y); }; // training the svm if(parameters->search_c != 'n'){ calc_c(); cout<<"re-training without CV and C = "<<parameters->realC<<endl; parameters->cross_validation = 0; parameters->verbosity -= 1; train(); parameters->verbosity += 1; } else{ train(); }; if(0 == parameters->cross_validation){ // save results if(parameters->verbosity > 1){ cout<<"Saving trained SVM to "<<(training_set->get_filename())<<".svm"<<endl; }; char* outname = new char[MAXCHAR]; strcpy(outname,training_set->get_filename()); strcat(outname,".svm"); ofstream output_file(outname,ios::out|ios::trunc); output_file.precision(16); output_file<<*training_set; output_file.close(); delete []outname; }; // testing if((parameters->cross_validation > 0) && (0 != test_sets)){ // test result of cross validation: train new SVM on whole example set parameters->cross_validation = 0; cout<<"Re-training SVM on whole example set for testing"<<endl; train(); }; if(0 != test_sets){ cout<<"----------------------------------------"<<endl; cout<<"Starting tests"<<endl; example_set_c* next_test; SVMINT test_no = 0; char* outname = new char[MAXCHAR]; while(test_sets != 0){ test_no++; next_test = test_sets->the_set; if(parameters->do_scale){ next_test->scale(training_set->get_exp(), training_set->get_var(), training_set->get_dim()); }; if(next_test->initialised_y()){ cout<<"Testing examples from file "<<(next_test->get_filename())<<endl; svm->test(next_test,1); } else{ cout<<"Predicting examples from file "<<(next_test->get_filename())<<endl; svm->predict(next_test); // output to file .pred strcpy(outname,next_test->get_filename()); strcat(outname,".pred"); ofstream output_file(outname,ios::out); output_file<<"@examples"<<endl; output_file<<(*next_test); output_file.close(); }; test_sets = test_sets->next; // skip delete! }; delete []outname; }; if(kernel) delete kernel; delete svm; if(parameters->verbosity > 1){ cout << "mysvm ended successfully."<<endl; }; return(0);};
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -