⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 learn.cpp

📁 介绍支持向量机SVM介绍的参考文献以及程序源代码
💻 CPP
📖 第 1 页 / 共 2 页
字号:
};


// Runs one training pass with the current global parameters and returns its
// result.
// - If parameters->cross_validation > 0, the whole job is handed to do_cv().
// - Otherwise the kernel is initialised on training_set, the svm on
//   kernel+parameters, a one-line banner naming the regularisation value is
//   printed, and svm->train(training_set) is returned.
svm_result train(){
  if(parameters->cross_validation > 0){
    return do_cv();
  }

  kernel->init(parameters->kernel_cache,training_set);
  svm->init(kernel,parameters);

  // nu- and distribution-SVMs are parameterised by nu; the others by C,
  // which may differ between the positive and negative class.
  const bool uses_nu = parameters->is_nu || parameters->is_distribution;
  if(uses_nu){
    cout<<"Training started with nu = "<<parameters->nu<<"."<<endl;
  }
  else if(parameters->get_Cpos() != parameters->get_Cneg()){
    cout<<"Training started with C = ("<<parameters->get_Cpos()
        <<","<<parameters->get_Cneg()<<")."<<endl;
  }
  else{
    cout<<"Training started with C = "<<parameters->get_Cpos()<<"."<<endl;
  }

  return svm->train(training_set);
};


// Objective that calc_c() minimises while searching for C.
// - Without cross-validation on a pattern problem (as opposed to regression),
//   the training loss alone is not a useful model-selection criterion. The
//   predicted accuracy estimate is used instead.
// - Accuracy is "higher is better", so it is negated. Minimising it directly
//   would select the C with the WORST predicted accuracy.
// - In all other cases the loss reported by train() is minimised as-is.
inline
SVMFLOAT to_minimize(svm_result result){
  if((parameters->cross_validation <= 0) && (1 == parameters->is_pattern)){
    return -result.pred_accuracy;
  }
  return result.loss;
};

svm_result calc_c(){
  const SVMFLOAT lambda = 0.618033989; // (sqrt(5)-1)/2
  SVMINT verbosity = parameters->verbosity;
  parameters->verbosity -= 2;
  svm_result the_result;
  SVMFLOAT c_min = parameters->c_min;
  SVMFLOAT c_max = parameters->c_max;
  SVMFLOAT c_delta = parameters->c_delta;
  SVMFLOAT oldC;
  SVMINT last_dec=0; // when did loss decrease?
  // setup s,t
  if(verbosity >= 3){
    cout<<"starting search for C"<<endl;
  };
  if((parameters->search_c == 'a') ||(parameters->search_c == 'm')){
    SVMFLOAT minimal_value=infinity;
    SVMFLOAT minimal_C=c_min;
    svm_result minimal_result;
    SVMFLOAT result_value;
    oldC=c_min;
    training_set->clear_alpha();
    while(c_min <= c_max){
      if(verbosity>=3){
	cout<<"C = "<<c_min<<" :"<<endl;
      };
      parameters->realC = c_min;
      training_set->scale_alphas(c_min/oldC);
      // training_set->clear_alpha();
      oldC = c_min;
      if(verbosity >= 4){
	cout<<"C = "<<c_min<<endl;
      }
      the_result = train();
      if(verbosity>=3){
	cout<<"loss = "<<the_result.loss<<endl;
	if(parameters->is_pattern){
	  cout<<"predicted loss = "<<the_result.pred_loss<<endl;
	};
	cout<<"VCdim <= "<<the_result.VCdim<<endl;
      };
      result_value = to_minimize(the_result);
      //      cout<<result_value<<endl;
      last_dec++;
      if((result_value<minimal_value) && (! isnan(result_value))){
	minimal_value=result_value;
	minimal_C=c_min;
	minimal_result = the_result;
	last_dec=0;
      };
      if(parameters->search_c == 'a'){
	c_min += c_delta;
      }
      else{
	c_min *= c_delta;
      };
      if((parameters->search_stop > 0) && (last_dec >= parameters->search_stop)){
	// no decrease in loss, stop
	c_min = 2*c_max;
      };
    };
    parameters->realC = minimal_C;
    the_result = minimal_result;
  }
  else{
    // method of golden ratio
    
    SVMFLOAT s = lambda*c_min+(1-lambda)*c_max;
    SVMFLOAT t = (1-lambda)*c_min+lambda*c_max;
    SVMFLOAT phi_s;
    SVMFLOAT phi_t;
    
    parameters->realC = s;
    training_set->clear_alpha();
    the_result = train();
    phi_s = to_minimize(the_result);
    parameters->realC = t;
    training_set->scale_alphas(t/s);
    oldC = t;
    the_result = train();
    phi_t = to_minimize(the_result);
    while(c_max - c_min > c_delta*c_min){
      if(verbosity >= 3){
	cout<<"C in ["<<c_min<<","<<c_max<<"]"<<endl;
      };
      if(phi_s < phi_t){
	c_max = t;
	t = s;
	phi_t = phi_s;
	// calc s
	s = lambda*c_min+(1-lambda)*c_max;
	parameters->realC = s;
	training_set->scale_alphas(s/oldC);
	oldC=s;
	the_result = train();
	phi_s = to_minimize(the_result);
      }
      else{
	c_min = s;
	s = t;
	phi_s = phi_t;
	// calc t
	t = (1-lambda)*c_min+lambda*c_max;
	parameters->realC = t;
	training_set->scale_alphas(t/oldC);
	oldC=t;
	the_result = train();
	phi_t = to_minimize(the_result);
      };
    };
    // save last results
    if(phi_s < phi_t){
      c_max = t;
    }
    else{
      c_min = s;
    };
    parameters->realC = (c_min+c_max)/2;

  };

  // ouput result
  if(verbosity >= 1){
    cout<<"*** Optimal C is "<<parameters->realC;
    if(parameters->search_c == 'g'){
      cout<<" +/-"<<((c_max-c_min)/2);
    };
    cout<<endl;
  };
  if(verbosity>=2){
    cout<<"result:"<<endl
	<<"Loss: "<<the_result.loss<<endl;
    if(parameters->Lpos != parameters->Lneg){
      cout<<"  Loss+: "<<the_result.loss_pos<<endl;
      cout<<"  Loss-: "<<the_result.loss_neg<<endl;
    };
    if(parameters->is_pattern){
      cout<<"predicted Loss: "<<the_result.pred_loss<<endl;
    };
    cout<<"MAE: "<<the_result.MAE<<endl;
    cout<<"MSE: "<<the_result.MSE<<endl;
    cout<<"VCdim <= "<<the_result.VCdim<<endl;

    if(parameters->is_pattern){
      cout<<"Accuracy  : "<<the_result.accuracy<<endl
	  <<"Precision : "<<the_result.precision<<endl
	  <<"Recall    : "<<the_result.recall<<endl;
      if(parameters->cross_validation == 0){
	cout<<"predicted Accuracy  : "<<the_result.pred_accuracy<<endl
	    <<"predicted Precision : "<<the_result.pred_precision<<endl
	    <<"predicted Recall    : "<<the_result.pred_recall<<endl;
      };
    };
    cout<<"Support Vectors : "<<the_result.number_svs<<endl;
    cout<<"Bounded SVs     : "<<the_result.number_bsv<<endl;
    if(parameters->search_c == 'g'){
      cout<<"(WARNING: this is the last result attained and may slightly differ from the result of the optimal C!)"<<endl;
    };
  };
  parameters->verbosity = verbosity;
  return the_result;
};


///////////////////////////////////////////////////////////////


// Writes "<base><suffix>" into buffer, which holds MAXCHAR chars.
// Returns false, leaving buffer untouched, when the name would not fit. An
// unchecked strcpy/strcat would overflow the buffer in that case.
static bool mysvm_output_name(char* buffer, const char* base, const char* suffix){
  if(strlen(base) + strlen(suffix) + 1 > static_cast<size_t>(MAXCHAR)){
    return false;
  };
  strcpy(buffer,base);
  strcat(buffer,suffix);
  return true;
};


// mySVM driver.
// 1. Reads parameters, kernel, training and test sets from the files given
//    on the command line ("-" or no arguments means STDIN).
// 2. Builds the SVM type selected by the parameters and trains it, with a
//    search for C if requested.
// 3. Saves the trained SVM to "<training file>.svm" when no
//    cross-validation is used.
// 4. Tests or predicts every test set, writing predictions to
//    "<test file>.pred".
int main(int argc,char* argv[]){
  cout<<"*** mySVM version "<<mysvmversion<<" ***"<<endl;
  cout.precision(8);
  // read objects
  try{
    if(argc<2){
      cout<<"Reading from STDIN"<<endl;
      // read from cin
      read_input(cin,"mysvm");
    }
    else{
      char* s = argv[1];
      if((0 == strcmp("-h",s)) || (0==strcmp("-help",s)) || (0==strcmp("--help",s))){
        // print out command-line help
        print_help();
      }
      else{
        // read in all input files
        for(int i=1;i<argc;i++){
          if(0 == strcmp(argv[i],"-")){
            cout<<"Reading from STDIN"<<endl;
            // read from cin
            read_input(cin,"mysvm");
          }
          else{
            cout<<"Reading "<<argv[i]<<endl;
            ifstream input_file(argv[i]);
            // A failed open sets failbit, not badbit, so bad() would miss a
            // missing or unreadable file. Test the whole stream state.
            if(! input_file){
              cout<<"ERROR: Could not read file \""<<argv[i]<<"\", exiting."<<endl;
              exit(1);
            };
            read_input(input_file,argv[i]);
            input_file.close();
          };
        };
      };
    };
  }
  catch(general_exception &the_ex){
    cout<<"*** Error while reading input: "<<the_ex.error_msg<<endl;
    exit(1);
  }
  catch(...){
    cout<<"*** Program ended because of unknown error while reading input"<<endl;
    exit(1);
  };

  if(0 == parameters){
    cout << "*** ERROR: You did not enter the svm parameters"<<endl;
    exit(1);
  };
  if(0 == kernel){
    kernel = new kernel_dot_c();
  };
  if(0 == training_set){
    cout << "*** ERROR: You did not enter the training set"<<endl;
    exit(1);
  };
  if(2 > training_set->size()){
    cout << "*** ERROR: Need at least two examples to learn."<<endl;
    exit(1);
  };

  if(parameters->is_distribution){
    svm = new svm_distribution_c();
    cout<<"distribution estimation SVM generated"<<endl;
  }
  else if(parameters->is_nu){
    if(parameters->is_pattern){
      svm = new svm_nu_pattern_c();
      cout<<"nu-PSVM generated"<<endl;
    }
    else{
      svm = new svm_nu_regression_c();
      cout<<"nu-RSVM generated"<<endl;
    };
  }
  else if(parameters->is_pattern){
    svm = new svm_pattern_c();
    cout<<"PSVM generated"<<endl;
  }
  else{
    svm = new svm_regression_c();
    cout<<"RSVM generated"<<endl;
  };

  // scale examples
  if(parameters->do_scale){
    training_set->scale(parameters->do_scale_y);
  };

  // training the svm
  if(parameters->search_c != 'n'){
    calc_c();
    //    cout<<"re-training without CV and C = "<<parameters->realC<<endl; 
    //    parameters->cross_validation = 0;
    //    parameters->verbosity -= 1;
    //    train();
    //    parameters->verbosity += 1;
  }
  else{
    train();
  };


  if(0 == parameters->cross_validation){
    // save results
    if(parameters->verbosity > 1){
      cout<<"Saving trained SVM to "<<(training_set->get_filename())<<".svm"<<endl;
    };
    char* outname = new char[MAXCHAR];
    if(! mysvm_output_name(outname,training_set->get_filename(),".svm")){
      cout<<"*** ERROR: Output file name too long, trained SVM not saved"<<endl;
    }
    else{
      // openmode only: the old third "protection" argument is not standard C++
      ofstream output_file(outname,ios::out|ios::trunc);
      if(! output_file){
        cout<<"*** ERROR: Could not write file \""<<outname<<"\""<<endl;
      }
      else{
        output_file.precision(16);
        output_file<<*training_set;
        output_file.close();
      };
    };
    delete []outname;
  };

  // testing
  if((parameters->cross_validation > 0) && (0 != test_sets)){
    // test result of cross validation: train new SVM on whole example set
    parameters->cross_validation = 0;
    cout<<"Re-training SVM on whole example set for testing"<<endl;
    train();
  };
  if(0 != test_sets){
    cout<<"----------------------------------------"<<endl;
    cout<<"Starting tests"<<endl;
    example_set_c* next_test;
    SVMINT test_no = 0;
    char* outname = new char[MAXCHAR];
    while(test_sets != 0){
      test_no++;
      next_test = test_sets->the_set;
      if(parameters->do_scale){
        next_test->scale(training_set->get_exp(),
                         training_set->get_var(),
                         training_set->get_dim());
      };
      if(next_test->initialised_y()){
        cout<<"Testing examples from file "<<(next_test->get_filename())<<endl;
        svm->test(next_test,1);
      }
      else{
        cout<<"Predicting examples from file "<<(next_test->get_filename())<<endl;
        svm->predict(next_test);
        // output to file .pred
        if(! mysvm_output_name(outname,next_test->get_filename(),".pred")){
          cout<<"*** ERROR: Output file name too long, predictions not saved"<<endl;
        }
        else{
          ofstream output_file(outname,ios::out);
          if(! output_file){
            cout<<"*** ERROR: Could not write file \""<<outname<<"\""<<endl;
          }
          else{
            output_file<<"@examples"<<endl;
            output_file<<(*next_test);
            output_file.close();
          };
        };
      };
      test_sets = test_sets->next; // skip delete!
    };
    delete []outname;
  };

  if(kernel) delete kernel;
  delete svm;
  if(parameters->verbosity > 1){
    cout << "mysvm ended successfully."<<endl;
  };
  return(0);
};

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -