⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 learn.cpp

📁 svm的cpp源代码
💻 CPP
📖 第 1 页 / 共 2 页
字号:
      cout<<"Training started with nu = "	  <<parameters->nu	  <<"."<<endl;    }    else if(parameters->get_Cpos() == parameters->get_Cneg())
	{      cout<<"Training started with C = "	  <<parameters->get_Cpos()	  <<"."<<endl;    }    else
	{      cout<<"Training started with C = ("<<parameters->get_Cpos()	  <<","<<parameters->get_Cneg()<<")."<<endl;    };    the_result = svm->train(training_set);  };  return the_result;};inlineSVMFLOAT to_minimize(svm_result result)
{  // which value to minimize in calc_c  if((parameters->cross_validation <= 0) && (1 == parameters->is_pattern))
  {    return (1.0-result.pred_accuracy);  }  else
  {    if(1 == parameters->is_pattern)
	{      return (1.0-result.accuracy);    }    else
	{      return result.loss;    };  };};svm_result calc_c()
{  const SVMFLOAT lambda = 0.618033989; // (sqrt(5)-1)/2  SVMINT verbosity = parameters->verbosity;  parameters->verbosity -= 2;  svm_result the_result;  SVMFLOAT c_min = parameters->c_min;  SVMFLOAT c_max = parameters->c_max;  SVMFLOAT c_delta = parameters->c_delta;  SVMFLOAT oldC;  SVMINT last_dec=0; // when did loss decrease?  // setup s,t  if(verbosity >= 3)
  {    cout<<"starting search for C"<<endl;  };  if((parameters->search_c == 'a') ||(parameters->search_c == 'm'))
  {    SVMFLOAT minimal_value=infinity;    SVMFLOAT minimal_C=c_min;    svm_result minimal_result;    SVMFLOAT result_value;    oldC=c_min;    training_set->clear_alpha();    while(c_min <= c_max)
	{      if(verbosity>=3)
	  {		cout<<"C = "<<c_min<<" :"<<endl;      };      parameters->realC = c_min;      training_set->scale_alphas(c_min/oldC);      // training_set->clear_alpha();      oldC = c_min;      if(verbosity >= 4)
	  {		cout<<"C = "<<c_min<<endl;      }      the_result = train();      if(verbosity>=3)
	  {		cout<<"loss = "<<the_result.loss<<endl;		if(parameters->is_pattern)
		{			cout<<"accuracy = "<<the_result.accuracy<<endl;			// cout<<"predicted loss = "<<the_result.pred_loss<<endl;			cout<<"predicted accuracy = "<<the_result.pred_accuracy<<endl;		};		//	cout<<"VCdim <= "<<the_result.VCdim<<endl;      };      result_value = to_minimize(the_result);      //      cout<<result_value<<endl;      last_dec++;      if((result_value<minimal_value) && (! isnan(result_value)))
	  {		minimal_value=result_value;		minimal_C=c_min;		minimal_result = the_result;		last_dec=0;      };      if(parameters->search_c == 'a')
	  {		c_min += c_delta;      }      else
	  {		c_min *= c_delta;      };      if((parameters->search_stop > 0) && (last_dec >= parameters->search_stop))
	  {		// no decrease in loss, stop		c_min = 2*c_max;      };    };    parameters->realC = minimal_C;    the_result = minimal_result;  }  else
  {    // method of golden ratio    SVMFLOAT s = lambda*c_min+(1-lambda)*c_max;    SVMFLOAT t = (1-lambda)*c_min+lambda*c_max;    SVMFLOAT phi_s;    SVMFLOAT phi_t;        parameters->realC = s;    training_set->clear_alpha();    the_result = train();    phi_s = to_minimize(the_result);    parameters->realC = t;    training_set->scale_alphas(t/s);    oldC = t;    the_result = train();    phi_t = to_minimize(the_result);    while(c_max - c_min > c_delta*c_min)
	{      if(verbosity >= 3)
	  {		cout<<"C in ["<<c_min<<","<<c_max<<"]"<<endl;      };      if(phi_s < phi_t)
	  {		c_max = t;		t = s;		phi_t = phi_s;		// calc s		s = lambda*c_min+(1-lambda)*c_max;		parameters->realC = s;		training_set->scale_alphas(s/oldC);		oldC=s;		the_result = train();		phi_s = to_minimize(the_result);      }      else
	  {		c_min = s;		s = t;		phi_s = phi_t;		// calc t		t = (1-lambda)*c_min+lambda*c_max;		parameters->realC = t;		training_set->scale_alphas(t/oldC);		oldC=t;		the_result = train();		phi_t = to_minimize(the_result);      };    };    // save last results    if(phi_s < phi_t)
	{      c_max = t;    }    else
	{      c_min = s;    };    parameters->realC = (c_min+c_max)/2;  };  // ouput result  if(verbosity >= 1)
  {    cout<<"*** Optimal C is "<<parameters->realC;    if(parameters->search_c == 'g')
	{      cout<<" +/-"<<((c_max-c_min)/2);    };    cout<<endl;  };  if(verbosity>=2)
  {    cout<<"result:"<<endl	<<"Loss: "<<the_result.loss<<endl;    if(parameters->Lpos != parameters->Lneg)
	{      cout<<"  Loss+: "<<the_result.loss_pos<<endl;      cout<<"  Loss-: "<<the_result.loss_neg<<endl;    };    if(parameters->is_pattern)
	{      cout<<"predicted Loss: "<<the_result.pred_loss<<endl;    };    cout<<"MAE: "<<the_result.MAE<<endl;    cout<<"MSE: "<<the_result.MSE<<endl;    cout<<"VCdim <= "<<the_result.VCdim<<endl;    if(parameters->is_pattern)
	{      cout<<"Accuracy  : "<<the_result.accuracy<<endl	  <<"Precision : "<<the_result.precision<<endl	  <<"Recall    : "<<the_result.recall<<endl;      if(parameters->cross_validation == 0)
	  {		cout<<"predicted Accuracy  : "<<the_result.pred_accuracy<<endl	    <<"predicted Precision : "<<the_result.pred_precision<<endl	    <<"predicted Recall    : "<<the_result.pred_recall<<endl;      };    };    cout<<"Support Vectors : "<<the_result.number_svs<<endl;    cout<<"Bounded SVs     : "<<the_result.number_bsv<<endl;    if(parameters->search_c == 'g')
	{      cout<<"(WARNING: this is the last result attained and may slightly differ from the result of the optimal C!)"<<endl;    };  };  parameters->verbosity = verbosity;  return the_result;};///////////////////////////////////////////////////////////////int main(int argc,char* argv[])
{
  // mysvm entry point:
  //  1. read the input: from stdin when there is no argument (or for "-"),
  //     otherwise from every file named on the command line;
  //     -h / -help / --help prints the help text instead;
  //  2. fall back to parameters_c() defaults and a kernel_dot_c kernel when
  //     the input did not define them;
  //  3. build the SVM type selected by the parameters and train it,
  //     searching for C first (calc_c) when search_c != 'n';
  //  4. if cross_validation is 0, save the trained SVM to
  //     "<training file>.svm";
  //  5. run every test set: test it when it has y values, otherwise
  //     predict it and write "<test file>.pred".
  cout<<"*** mySVM version "<<mysvmversion<<" ***"<<endl;
  cout.precision(8);
  // read objects
  try
    {
      if(argc<2)
        {
          cout<<"Reading from STDIN"<<endl;
          // read from cin
          read_input(cin,"mysvm");
        }
      else
        {
          char* s = argv[1];
          if((0 == strcmp("-h",s)) || (0==strcmp("-help",s)) || (0==strcmp("--help",s)))
            {
              // print out command-line help. No input has been read, so
              // there is nothing to train on: stop here.
              print_help();
              return 0;
            }
          // read in all input files
          for(int i=1;i<argc;i++)
            {
              if(0 == strcmp(argv[i],"-"))
                {
                  cout<<"Reading from STDIN"<<endl;
                  // read from cin
                  read_input(cin,"mysvm");
                }
              else
                {
                  cout<<"Reading "<<argv[i]<<endl;
                  ifstream input_file(argv[i]);
                  // a failed open sets failbit, not badbit, so bad() misses
                  // missing/unreadable files; test the stream state instead
                  if(! input_file)
                    {
                      cout<<"ERROR: Could not read file \""<<argv[i]<<"\", exiting."<<endl;
                      exit(1);
                    }
                  read_input(input_file,argv[i]);
                  input_file.close();
                }
            }
        }
    }
  catch(general_exception &the_ex)
    {
      cout<<"*** Error while reading input: "<<the_ex.error_msg<<endl;
      exit(1);
    }
  catch(...)
    {
      cout<<"*** Program ended because of unknown error while reading input"<<endl;
      exit(1);
    }
  // everything below dereferences training_set, so check it first
  if(0 == training_set)
    {
      cout << "*** ERROR: You did not enter the training set"<<endl;
      exit(1);
    }
  if(0 == parameters)
    {
      // no parameters in the input: use the defaults, and switch to a
      // pattern SVM when initialised_pattern_y() holds for the training set
      parameters = new parameters_c();
      if(training_set->initialised_pattern_y())
        {
          parameters->is_pattern = 1;
          parameters->do_scale_y = 0;
        }
    }
  parameters->is_linear = is_linear;
  if(0 == kernel)
    {
      kernel = new kernel_dot_c();
    }
  if(2 > training_set->size())
    {
      cout << "*** ERROR: Need at least two examples to learn."<<endl;
      exit(1);
    }
  // create the SVM type selected by the parameters
  if(parameters->is_distribution)
    {
      svm = new svm_distribution_c();
      cout<<"distribution estimation SVM generated"<<endl;
    }
  else if(parameters->is_nu)
    {
      if(parameters->is_pattern)
        {
          svm = new svm_nu_pattern_c();
          if(! training_set->initialised_pattern_y())
            {
              cout<<"WARNING: Parameters set a pattern SVM, but the training ys are not in {-1,1}."<<endl;
            }
          cout<<"nu-PSVM generated"<<endl;
        }
      else
        {
          svm = new svm_nu_regression_c();
          cout<<"nu-RSVM generated"<<endl;
        }
    }
  else if(parameters->is_pattern)
    {
      svm = new svm_pattern_c();
      if(! training_set->initialised_pattern_y())
        {
          cout<<"WARNING: Parameters set a pattern SVM, but the training ys are not in {-1,1}."<<endl;
        }
      cout<<"PSVM generated"<<endl;
    }
  else
    {
      svm = new svm_regression_c();
      cout<<"RSVM generated"<<endl;
    }
  // scale examples
  if(parameters->do_scale)
    {
      training_set->scale(parameters->do_scale_y);
    }
  // training the svm
  if(parameters->search_c != 'n')
    {
      calc_c();
      cout<<"re-training without CV and C = "<<parameters->realC<<endl;
      parameters->cross_validation = 0;
      parameters->verbosity -= 1;
      train();
      parameters->verbosity += 1;
    }
  else
    {
      train();
    }
  if(0 == parameters->cross_validation)
    {
      // save results
      if(parameters->verbosity > 1)
        {
          cout<<"Saving trained SVM to "<<(training_set->get_filename())<<".svm"<<endl;
        }
      const char* train_name = training_set->get_filename();
      // size the buffer from the name: a fixed MAXCHAR buffer overflows
      // for long file names
      char* outname = new char[strlen(train_name)+5]; // + ".svm" + '\0'
      strcpy(outname,train_name);
      strcat(outname,".svm");
      ofstream output_file(outname,ios::out|ios::trunc);
      output_file.precision(16);
      output_file<<*training_set;
      output_file.close();
      delete []outname;
    }
  // testing
  if((parameters->cross_validation > 0) && (0 != test_sets))
    {
      // test result of cross validation: train new SVM on whole example set
      parameters->cross_validation = 0;
      cout<<"Re-training SVM on whole example set for testing"<<endl;
      train();
    }
  if(0 != test_sets)
    {
      cout<<"----------------------------------------"<<endl;
      cout<<"Starting tests"<<endl;
      example_set_c* next_test;
      while(test_sets != 0)
        {
          next_test = test_sets->the_set;
          if(parameters->do_scale)
            {
              // scale test examples with the training set's statistics
              next_test->scale(training_set->get_exp(),
                               training_set->get_var(),
                               training_set->get_dim());
            }
          if(next_test->initialised_y())
            {
              cout<<"Testing examples from file "<<(next_test->get_filename())<<endl;
              svm->test(next_test,1);
            }
          else
            {
              cout<<"Predicting examples from file "<<(next_test->get_filename())<<endl;
              svm->predict(next_test);
              // output to file <test file>.pred
              const char* test_name = next_test->get_filename();
              char* outname = new char[strlen(test_name)+6]; // + ".pred" + '\0'
              strcpy(outname,test_name);
              strcat(outname,".pred");
              ofstream output_file(outname,ios::out);
              output_file<<"@examples"<<endl;
              output_file<<(*next_test);
              output_file.close();
              delete []outname;
            }
          test_sets = test_sets->next; // skip delete!
        }
    }
  delete kernel;
  delete svm;
  if(parameters->verbosity > 1)
    {
      cout << "mysvm ended successfully."<<endl;
    }
  return(0);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -