⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 learn.cpp

📁 cvm的c语言版
💻 CPP
📖 第 1 页 / 共 2 页
字号:
  }  else{    kernel->init(parameters->kernel_cache,training_set);    svm->init(kernel,parameters);    if(parameters->is_nu || parameters->is_distribution){      cout<<"Training started with nu = "	  <<parameters->nu	  <<"."<<endl;    }    else if(parameters->get_Cpos() == parameters->get_Cneg()){      cout<<"Training started with C = "	  <<parameters->get_Cpos()	  <<"."<<endl;    }    else{      cout<<"Training started with C = ("<<parameters->get_Cpos()	  <<","<<parameters->get_Cneg()<<")."<<endl;    };    the_result = svm->train(training_set);  };  return the_result;};inlineSVMFLOAT to_minimize(svm_result result){  // which value to minimize in calc_c  if((parameters->cross_validation <= 0) && (1 == parameters->is_pattern)){    return (1.0-result.pred_accuracy);  }  else{    if(1 == parameters->is_pattern){      return (1.0-result.accuracy);    }    else{      return result.loss;    };  };};svm_result calc_c(){  const SVMFLOAT lambda = 0.618033989; // (sqrt(5)-1)/2  SVMINT verbosity = parameters->verbosity;  parameters->verbosity -= 2;  svm_result the_result;  SVMFLOAT c_min = parameters->c_min;  SVMFLOAT c_max = parameters->c_max;  SVMFLOAT c_delta = parameters->c_delta;  SVMFLOAT oldC;  SVMINT last_dec=0; // when did loss decrease?  // setup s,t  if(verbosity >= 3){    cout<<"starting search for C"<<endl;  };  if((parameters->search_c == 'a') ||(parameters->search_c == 'm')){    SVMFLOAT minimal_value=infinity;    SVMFLOAT minimal_C=c_min;    svm_result minimal_result;    SVMFLOAT result_value;    oldC=c_min;    training_set->clear_alpha();    while(c_min <= c_max){      if(verbosity>=3){	cout<<"C = "<<c_min<<" :"<<endl;      };      parameters->realC = c_min;      training_set->scale_alphas(c_min/oldC);      // training_set->clear_alpha();      oldC = c_min;      if(verbosity >= 4){	cout<<"C = "<<c_min<<endl;      }      the_result = train();      if(verbosity>=3){	cout<<"loss = "<<the_result.loss<<endl;	if(parameters->is_pattern){	  cout<<"accuracy = "<<the_result.accuracy<<endl;	  // cout<<"predicted loss = "<<the_result.pred_loss<<endl;	  cout<<"predicted accuracy = "<<the_result.pred_accuracy<<endl;	};	//	cout<<"VCdim <= "<<the_result.VCdim<<endl;      };      result_value = to_minimize(the_result);      //      cout<<result_value<<endl;      last_dec++;      if((result_value<minimal_value) && (! isnan(result_value))){	minimal_value=result_value;	minimal_C=c_min;	minimal_result = the_result;	last_dec=0;      };      if(parameters->search_c == 'a'){	c_min += c_delta;      }      else{	c_min *= c_delta;      };      if((parameters->search_stop > 0) && (last_dec >= parameters->search_stop)){	// no decrease in loss, stop	c_min = 2*c_max;      };    };    parameters->realC = minimal_C;    the_result = minimal_result;  }  else{    // method of golden ratio        SVMFLOAT s = lambda*c_min+(1-lambda)*c_max;    SVMFLOAT t = (1-lambda)*c_min+lambda*c_max;    SVMFLOAT phi_s;    SVMFLOAT phi_t;        parameters->realC = s;    training_set->clear_alpha();    the_result = train();    phi_s = to_minimize(the_result);    parameters->realC = t;    training_set->scale_alphas(t/s);    oldC = t;    the_result = train();    phi_t = to_minimize(the_result);    while(c_max - c_min > c_delta*c_min){      if(verbosity >= 3){	cout<<"C in ["<<c_min<<","<<c_max<<"]"<<endl;      };      if(phi_s < phi_t){	c_max = t;	t = s;	phi_t = phi_s;	// calc s	s = lambda*c_min+(1-lambda)*c_max;	parameters->realC = s;	training_set->scale_alphas(s/oldC);	oldC=s;	the_result = train();	phi_s = to_minimize(the_result);      }      else{	c_min = s;	s = t;	phi_s = phi_t;	// calc t	t = (1-lambda)*c_min+lambda*c_max;	parameters->realC = t;	training_set->scale_alphas(t/oldC);	oldC=t;	the_result = train();	phi_t = to_minimize(the_result);      };    };    // save last results    if(phi_s < phi_t){      c_max = t;    }    else{      c_min = s;    };    parameters->realC = (c_min+c_max)/2;  };  // ouput result  if(verbosity >= 1){    cout<<"*** Optimal C is "<<parameters->realC;    if(parameters->search_c == 'g'){      cout<<" +/-"<<((c_max-c_min)/2);    };    cout<<endl;  };  if(verbosity>=2){    cout<<"result:"<<endl	<<"Loss: "<<the_result.loss<<endl;    if(parameters->Lpos != parameters->Lneg){      cout<<"  Loss+: "<<the_result.loss_pos<<endl;      cout<<"  Loss-: "<<the_result.loss_neg<<endl;    };    if(parameters->is_pattern){      cout<<"predicted Loss: "<<the_result.pred_loss<<endl;    };    cout<<"MAE: "<<the_result.MAE<<endl;    cout<<"MSE: "<<the_result.MSE<<endl;    cout<<"VCdim <= "<<the_result.VCdim<<endl;    if(parameters->is_pattern){      cout<<"Accuracy  : "<<the_result.accuracy<<endl	  <<"Precision : "<<the_result.precision<<endl	  <<"Recall    : "<<the_result.recall<<endl;      if(parameters->cross_validation == 0){	cout<<"predicted Accuracy  : "<<the_result.pred_accuracy<<endl	    <<"predicted Precision : "<<the_result.pred_precision<<endl	    <<"predicted Recall    : "<<the_result.pred_recall<<endl;      };    };    cout<<"Support Vectors : "<<the_result.number_svs<<endl;    cout<<"Bounded SVs     : "<<the_result.number_bsv<<endl;    if(parameters->search_c == 'g'){      cout<<"(WARNING: this is the last result attained and may slightly differ from the result of the optimal C!)"<<endl;    };  };  parameters->verbosity = verbosity;  return the_result;};///////////////////////////////////////////////////////////////int main(int argc,char* argv[]){  cout<<"*** mySVM version "<<mysvmversion<<" ***"<<endl;  cout.precision(8);  // read objects  try{    if(argc<2){      cout<<"Reading from STDIN"<<endl;      // read vom cin      read_input(cin,"mysvm");    }    else{      char* s = argv[1];      if((0 == strcmp("-h",s)) || (0==strcmp("-help",s)) || (0==strcmp("--help",s))){	// print out command-line help	print_help();      }      else{	// read in all input files	for(int i=1;i<argc;i++){	  if(0 == strcmp(argv[i],"-")){	    cout<<"Reading from STDIN"<<endl;	    // read vom cin	    read_input(cin,"mysvm");	  }	  else{	    cout<<"Reading "<<argv[i]<<endl;	    ifstream input_file(argv[i]);	    if(input_file.bad()){	      cout<<"ERROR: Could not read file \""<<argv[i]<<"\", exiting."<<endl;	      exit(1);	    };	    read_input(input_file,argv[i]);	    input_file.close();	  };	};      };    };  }  catch(general_exception &the_ex){    cout<<"*** Error while reading input: "<<the_ex.error_msg<<endl;    exit(1);  }  catch(...){    cout<<"*** Program ended because of unknown error while reading input"<<endl;    exit(1);  };  if(0 == parameters){    parameters = new parameters_c();    if(training_set->initialised_pattern_y()){      parameters->is_pattern = 1;      parameters->do_scale_y = 0;    };  };  parameters->is_linear = is_linear;  if(0 == kernel){    kernel = new kernel_dot_c();  };  if(0 == training_set){    cout << "*** ERROR: You did not enter the training set"<<endl;    exit(1);  };  if(2 > training_set->size()){    cout << "*** ERROR: Need at least two examples to learn."<<endl;    exit(1);  };  if(parameters->is_distribution){    svm = new svm_distribution_c();    cout<<"distribution estimation SVM generated"<<endl;  }  else if(parameters->is_nu){    if(parameters->is_pattern){      svm = new svm_nu_pattern_c();      if(! training_set->initialised_pattern_y()){	cout<<"WARNING: Parameters set a pattern SVM, but the training ys are not in {-1,1}."<<endl;      }      cout<<"nu-PSVM generated"<<endl;    }    else{      svm = new svm_nu_regression_c();      cout<<"nu-RSVM generated"<<endl;    };  }  else if(parameters->is_pattern){    svm = new svm_pattern_c();    if(! training_set->initialised_pattern_y()){      cout<<"WARNING: Parameters set a pattern SVM, but the training ys are not in {-1,1}."<<endl;    }    cout<<"PSVM generated"<<endl;  }  else{    svm = new svm_regression_c();    cout<<"RSVM generated"<<endl;  };  // scale examples  if(parameters->do_scale){    training_set->scale(parameters->do_scale_y);  };  // training the svm  if(parameters->search_c != 'n'){    calc_c();    cout<<"re-training without CV and C = "<<parameters->realC<<endl;     parameters->cross_validation = 0;    parameters->verbosity -= 1;    train();    parameters->verbosity += 1;  }  else{    train();  };  if(0 == parameters->cross_validation){    // save results    if(parameters->verbosity > 1){      cout<<"Saving trained SVM to "<<(training_set->get_filename())<<".svm"<<endl;    };    char* outname = new char[MAXCHAR];    strcpy(outname,training_set->get_filename());    strcat(outname,".svm");    ofstream output_file(outname,ios::out|ios::trunc);    output_file.precision(16);    output_file<<*training_set;    output_file.close();    delete []outname;  };  // testing  if((parameters->cross_validation > 0) && (0 != test_sets)){    // test result of cross validation: train new SVM on whole example set    parameters->cross_validation = 0;    cout<<"Re-training SVM on whole example set for testing"<<endl;    train();  };  if(0 != test_sets){    cout<<"----------------------------------------"<<endl;    cout<<"Starting tests"<<endl;    example_set_c* next_test;    SVMINT test_no = 0;    char* outname = new char[MAXCHAR];    while(test_sets != 0){      test_no++;      next_test = test_sets->the_set;      if(parameters->do_scale){	next_test->scale(training_set->get_exp(),			 training_set->get_var(),			 training_set->get_dim());      };      if(next_test->initialised_y()){	cout<<"Testing examples from file "<<(next_test->get_filename())<<endl;	svm->test(next_test,1);      }      else{	cout<<"Predicting examples from file "<<(next_test->get_filename())<<endl;	svm->predict(next_test);	// output to file .pred	strcpy(outname,next_test->get_filename());	strcat(outname,".pred");	ofstream output_file(outname,ios::out);	output_file<<"@examples"<<endl;	output_file<<(*next_test);	output_file.close();	      };      test_sets = test_sets->next; // skip delete!    };    delete []outname;  };  if(kernel) delete kernel;  delete svm;  if(parameters->verbosity > 1){    cout << "mysvm ended successfully."<<endl;  };  return(0);};

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -