⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm_c.cpp

📁 做回归很好
💻 CPP
📖 第 1 页 / 共 5 页
字号:

    // are all variables at the bound?
    SVMINT pos_abs;
    int bounded_pos=1;
    int bounded_neg=1;
    SVMINT pos=0;
    while((pos<working_set_size) && ((1 == bounded_pos) || (1 == bounded_neg))){
      pos_abs = working_set[pos];
      if(is_alpha_neg(pos_abs) > 0){
	if(all_alphas[pos_abs]-Cneg < -is_zero){
	  bounded_pos = 0;
	};
	if(all_alphas[pos_abs] > is_zero){
	  bounded_neg = 0;
	};
      }
      else{
	if(all_alphas[pos_abs]+Cneg > is_zero){
	  bounded_neg = 0;
	};
	if(all_alphas[pos_abs] < -is_zero){
	  bounded_pos = 0;
	};
      };
      pos++;
    };
    if(0 != bounded_pos){
      // all alphas are at upper bound
      // need alpha that can be moved upward
      // use alpha with smallest lambda
      SVMFLOAT max_lambda = infinity;
      SVMINT max_pos=examples_total;
      for(pos_abs=0;pos_abs<examples_total;pos_abs++){
	if(is_alpha_neg(pos_abs) > 0){
	  if(all_alphas[pos_abs]-Cneg < -is_zero){
	    if(lambda(pos_abs) < max_lambda){
	      max_lambda = lambda(pos_abs);
	      max_pos = pos_abs;
	    };
	  };
	}
	else{
	  if(all_alphas[pos_abs] < -is_zero){
	    if(lambda(pos_abs) < max_lambda){
	      max_lambda = lambda(pos_abs);
	      max_pos = pos_abs;
	    };
	  };
	};
      };
      if(max_pos<examples_total){
	if(working_set_size<parameters->working_set_size){
	  working_set_size++;
	};
	working_set[working_set_size-1] = max_pos;
      };
    }
    else if(0 != bounded_neg){
      // all alphas are at lower bound
      // need alpha that can be moved downward
      // use alpha with smallest lambda
      SVMFLOAT max_lambda = infinity;
      SVMINT max_pos=examples_total;
      for(pos_abs=0;pos_abs<examples_total;pos_abs++){
	if(is_alpha_neg(pos_abs) > 0){
	  if(all_alphas[pos_abs] > is_zero){
	    if(lambda(pos_abs) < max_lambda){
	      max_lambda = lambda(pos_abs);
	      max_pos = pos_abs;
	    };
	  };
	}
	else{
	  if(all_alphas[pos_abs]+Cneg > is_zero){
	    if(lambda(pos_abs) < max_lambda){
	      max_lambda = lambda(pos_abs);
	      max_pos = pos_abs;
	    };
	  };
	};
      };
      if(max_pos<examples_total){
	if(working_set_size<parameters->working_set_size){
	  working_set_size++;
	};
	working_set[working_set_size-1] = max_pos;
      };
    };
  };

  if((working_set_size<parameters->working_set_size) &&
     (working_set_size<examples_total)){
    // use full working set
    SVMINT pos = (SVMINT)((SVMFLOAT)examples_total*rand()/(RAND_MAX+1.0));
    int ok;
    while((working_set_size<parameters->working_set_size) &&
	  (working_set_size<examples_total)){
      // add pos into WS if it isn't already
      ok = 1;
      for(i=0;i<working_set_size;i++){
	if(working_set[i] == pos){
	  ok=0;
	  i = working_set_size;
	};
      };
      if(1 == ok){
	working_set[working_set_size] = pos;
	working_set_size++;
      };
      pos = (pos+1)%examples_total;
    };
  };

  SVMINT ipos;
  for(ipos=0;ipos<working_set_size;ipos++){
    which_alpha[ipos] = is_alpha_neg(working_set[ipos]);
  };

  time_calc += get_time() - time_start;
  return;
};


void svm_c::project_to_constraint(){
  // project alphas to match the constraint
  SVMFLOAT alpha_sum = sum_alpha;
  SVMINT SVcount=0;
  SVMFLOAT alpha;
  SVMINT i;
  for(i=0;i<examples_total;i++){
    alpha = all_alphas[i];
    alpha_sum += alpha;
    if(((alpha>is_zero) && (alpha-Cneg < -is_zero)) ||
       ((alpha<-is_zero) && (alpha+Cpos > is_zero))){
      SVcount++;
    };
  };
  if(SVcount > 0){
    // project
    alpha_sum /= (SVMFLOAT)SVcount;
    for(i=0;i<examples_total;i++){
      alpha = all_alphas[i];
      if(((alpha>is_zero) && (alpha-Cneg < -is_zero)) ||
	 ((alpha<-is_zero) && (alpha+Cpos > is_zero))){
	all_alphas[i] -= alpha_sum;
      };
    };
  };
};


void svm_c::init_working_set(){
  // calculate sum
  SVMINT i,j;

  project_to_constraint();
  // check bounds!
  if(examples->initialised_alpha()){
    if(parameters->verbosity >= 2){
      cout<<"Initialising variables, this may take some time."<<endl;
    };
    for(i=0; i<examples_total;i++){
      sum[i] = 0;
      at_bound[i] = 0;
      for(j=0; j<examples_total;j++){
	sum[i] += all_alphas[j]*kernel->calculate_K(i,j);
      };
    };
  }
  else{
    // skip kernel calculation as all alphas = 0
    for(i=0; i<examples_total;i++){
      sum[i] = 0;
      at_bound[i] = 0;
    };    
  };

  if(examples->initialised_alpha()){
    calculate_working_set();
  }
  else{
    // first working set is random
    j=0;
    i=0;
    while((i<working_set_size) && (j < examples_total)){
      working_set[i] = j;
      if(is_alpha_neg(j) > 0){
	which_alpha[i] = 1;
      }
      else{
	which_alpha[i] = -1;
      };
      i++;
      j++;
    };
  };   
  update_working_set();
};


void svm_c::put_optimizer_values(){
  // update nabla, sum, examples.
  // sum[i] += (primal_j^*-primal_j-alpha_j^*+alpha_j)K(i,j)
  // check for |nabla| < is_zero (nabla <-> nabla*)
  //  cout<<"put_optimizer_values()"<<endl;
  SVMINT i=0; 
  SVMINT j=0;
  SVMINT pos_i;
  SVMFLOAT the_new_alpha;
  SVMFLOAT* kernel_row;
  SVMFLOAT alpha_diff;

  long time_start = get_time();
  pos_i=working_set_size;
  while(pos_i>0){
    pos_i--;
    if(which_alpha[pos_i]>0){
      the_new_alpha = primal[pos_i];
    }
    else{
      the_new_alpha = -primal[pos_i];
    };
    // next three statements: keep this order!
    i = working_set[pos_i];
    alpha_diff = the_new_alpha-all_alphas[i];
    all_alphas[i] = the_new_alpha;

    if(alpha_diff != 0){
      // update sum ( => nabla)
      kernel_row = kernel->get_row(i);
      for(j=0;j<examples_total;j++){
	sum[j] += alpha_diff*kernel_row[j];
      };
    };
  };
  time_update += get_time() - time_start;
};


void svm_c::update_working_set(){
  long time_start = get_time();
  // setup subproblem
  SVMINT i,j;
  SVMINT pos_i, pos_j;
  SVMFLOAT* kernel_row;
  SVMFLOAT sum_WS;

  for(pos_i=0;pos_i<working_set_size;pos_i++){
    i = working_set[pos_i];

    // put row sort_i in hessian 
    kernel_row = kernel->get_row(i);
    sum_WS=0;
    //    for(pos_j=0;pos_j<working_set_size;pos_j++){
    for(pos_j=0;pos_j<pos_i;pos_j++){
      j = working_set[pos_j];
      // put all elements K(i,j) in hessian, where j in WS
      if(((which_alpha[pos_j] < 0) && (which_alpha[pos_i] < 0)) ||
	 ((which_alpha[pos_j] > 0) && (which_alpha[pos_i] > 0))){
	// both i and j positive or negative
	(qp.H)[pos_i*working_set_size+pos_j] = kernel_row[j];
	(qp.H)[pos_j*working_set_size+pos_i] = kernel_row[j];
      }
      else{
	// one of i and j positive, one negative
	(qp.H)[pos_i*working_set_size+pos_j] = -kernel_row[j];
	(qp.H)[pos_j*working_set_size+pos_i] = -kernel_row[j];
      };
    };
    for(pos_j=0;pos_j<working_set_size;pos_j++){
      j = working_set[pos_j];
      sum_WS+=all_alphas[j]*kernel_row[j];
    };
    // set main diagonal 
    (qp.H)[pos_i*working_set_size+pos_i] = kernel_row[i];

    // linear and box constraints
    if(which_alpha[pos_i]<0){
      // alpha
      (qp.A)[pos_i] = -1;
      // lin(alpha) = y_i+eps-sum_{i not in WS} alpha_i K_{ij}
      //            = y_i+eps-sum_i+sum_{i in WS}
      (qp.c)[pos_i] = all_ys[i]+epsilon_pos-sum[i]+sum_WS;
      primal[pos_i] = -all_alphas[i];
      (qp.u)[pos_i] = Cpos;
    }
    else{
      // alpha^*
      (qp.A)[pos_i] = 1;
      (qp.c)[pos_i] = -all_ys[i]+epsilon_neg+sum[i]-sum_WS;
      primal[pos_i] = all_alphas[i];
      (qp.u)[pos_i] = Cneg;
    };
  };
  if(parameters->quadraticLossNeg){
    for(pos_i=0;pos_i<working_set_size;pos_i++){
      if(which_alpha[pos_i]>0){
	(qp.H)[pos_i*(working_set_size+1)] += 1/Cneg;
	(qp.u)[pos_i] = infinity;
      };
    };
  };
  if(parameters->quadraticLossPos){
    for(pos_i=0;pos_i<working_set_size;pos_i++){
      if(which_alpha[pos_i]<0){
	(qp.H)[pos_i*(working_set_size+1)] += 1/Cpos;
	(qp.u)[pos_i] = infinity;
      };
    };
  };

  time_update += get_time() - time_start; 
};


// Number of characters needed to print a non-negative count in decimal.
// Zero (and anything below 10) takes one character. Used for the column
// layout of the confusion matrix; pure integer arithmetic, so a count of 0
// is handled without going through log10.
static int svm_count_decimal_digits(long value){
  int digits = 1;
  while(value >= 10){
    value /= 10;
    digits++;
  }
  return digits;
}


// Evaluate the model on test_examples.
//
// test_examples : examples to predict; also stored in the member test_set.
// verbose       : if non-zero, the statistics are printed to cout, plus a
//                 confusion matrix when is_pattern is set.
//
// Returns an svm_result with MAE, MSE, the average loss, and the average
// loss over the examples with y < prediction - epsilon_pos (loss_pos) and
// with y > prediction + epsilon_neg (loss_neg). number_svs and number_bsv
// are set to 0. accuracy, precision and recall are filled in only when
// is_pattern is set and are -1 otherwise.
//
// NOTE(review): with an empty test set, or in pattern mode with no positive
// example or no positive prediction, the ratios below divide by zero
// (presumably NaN/inf, assuming SVMFLOAT is a floating-point type) --
// behaviour kept as it was.
svm_result svm_c::test(example_set_c* test_examples, int verbose){
  svm_result the_result;
  test_set = test_examples;

  SVMINT i;
  SVMFLOAT MAE=0;
  SVMFLOAT MSE=0;
  SVMFLOAT actloss=0;
  SVMFLOAT theloss=0;
  SVMFLOAT theloss_pos=0;
  SVMFLOAT theloss_neg=0;
  SVMINT countpos=0;
  SVMINT countneg=0;
  // for pattern:
  SVMINT correct_pos=0;
  SVMINT correct_neg=0;
  SVMINT total_pos=0;
  SVMINT total_neg=0;

  SVMFLOAT prediction;
  SVMFLOAT y;
  svm_example example;
  for(i=0;i<test_set->size();i++){
    example = test_set->get_example(i);
    prediction = predict(example);
    y = examples->unscale_y(test_set->get_y(i));
    // fabs, not abs: unqualified abs may bind to the integer version and
    // silently truncate the error.
    MAE += fabs(y-prediction);
    MSE += (y-prediction)*(y-prediction);
    actloss=loss(prediction,y);
    theloss+=actloss;
    if(y < prediction-parameters->epsilon_pos){
      theloss_pos += actloss;
      countpos++;
    }
    else if(y > prediction+parameters->epsilon_neg){
      theloss_neg += actloss;
      countneg++;
    }
    // pattern mode: the sign of y / prediction is the class
    if(is_pattern){
      if(y>0){
        if(prediction>0){
          correct_pos++;
        }
        total_pos++;
      }
      else{
        if(prediction<=0){
          correct_neg++;
        }
        total_neg++;
      }
    }
  }
  // turn the one-sided loss sums into averages
  if(countpos != 0){
    theloss_pos /= (SVMFLOAT)countpos;
  }
  if(countneg != 0){
    theloss_neg /= (SVMFLOAT)countneg;
  }

  the_result.MAE =  MAE / (SVMFLOAT)test_set->size();
  the_result.MSE =  MSE / (SVMFLOAT)test_set->size();
  the_result.loss = theloss/test_set->size();
  the_result.loss_pos = theloss_pos;
  the_result.loss_neg = theloss_neg;
  the_result.number_svs = 0;
  the_result.number_bsv = 0;
  if(is_pattern){
    the_result.accuracy = ((SVMFLOAT)(correct_pos+correct_neg))/((SVMFLOAT)(total_pos+total_neg));
    // denominator = true positives + false positives
    the_result.precision = ((SVMFLOAT)correct_pos/((SVMFLOAT)(correct_pos+total_neg-correct_neg)));
    the_result.recall = ((SVMFLOAT)correct_pos/(SVMFLOAT)total_pos);
  }
  else{
    the_result.accuracy = -1;
    the_result.precision = -1;
    the_result.recall = -1;
  }

  if(verbose){
    cout << "Average loss  : "<<(theloss/test_set->size())<<endl;
    cout << "Avg. loss pos : "<<theloss_pos<<"\t ("<<countpos<<" occurences)"<<endl;
    cout << "Avg. loss neg : "<<theloss_neg<<"\t ("<<countneg<<" occurences)"<<endl;
    cout << "Mean absolute error : "<<the_result.MAE<<endl;
    cout << "Mean squared error  : "<<the_result.MSE<<endl;

    if(is_pattern){
      // output precision, recall and accuracy
      cout<<"Accuracy  : "<<the_result.accuracy<<endl;
      cout<<"Precision : "<<the_result.precision<<endl;
      cout<<"Recall    : "<<the_result.recall<<endl;

      // Confusion matrix. Column width follows the number of digits of the
      // total example count; each cell is left-padded accordingly.
      int rows = svm_count_decimal_digits(total_pos+total_neg);
      int now_digits = rows+2;
      int col;
      cout<<endl;
      cout<<"Predicted values:"<<endl;
      cout<<"   |";
      for(col=0;col<rows;col++){ cout<<" "; }
      cout<<"+  |";
      for(col=0;col<rows;col++){ cout<<" "; }
      cout<<"-"<<endl;

      cout<<"---+";
      for(col=0;col<now_digits;col++){ cout<<"-"; }
      cout<<"-+-";
      for(col=0;col<now_digits;col++){ cout<<"-"; }
      cout<<endl;

      cout<<" + |  ";
      now_digits=rows-svm_count_decimal_digits(correct_pos)-1;
      for(col=0;col<now_digits;col++){ cout<<" "; }
      cout<<correct_pos<<"  |  ";
      now_digits=rows-svm_count_decimal_digits(total_pos-correct_pos)-1;
      for(col=0;col<now_digits;col++){ cout<<" "; }
      cout<<total_pos-correct_pos<<"    (true pos)"<<endl;

      cout<<" - |  ";
      now_digits=rows-svm_count_decimal_digits(total_neg-correct_neg)-1;
      for(col=0;col<now_digits;col++){ cout<<" "; }
      cout<<(total_neg-correct_neg)<<"  |  ";
      now_digits=rows-svm_count_decimal_digits(correct_neg)-1;
      for(col=0;col<now_digits;col++){ cout<<" "; }
      cout<<correct_neg<<"    (true neg)"<<endl;
      cout<<endl;
    }
  }
  return the_result;
}


void svm_c::optimize(){
  // optimizer-specific call
  // get time
  long time_start = get_time();

  qp.n = working_set_size;

  SVMINT i;
  SVMINT j;

  // equality constraint
  qp.b[0]=0;
  for(i=0;i<working_set_size;i++){
    qp.b[0] += all_alphas[working_set[i]];
  };

  // set initial optimization parameters
  SVMFLOAT new_target=0;
  SVMFLOAT old_target=0;
  SVMFLOAT target_tmp;
  for(i=0;i<working_set_size;i++){
    target_tmp = primal[i]*qp.H[i*working_set_size+i]/2;
    for(j=0;j<i;j++){
      target_tmp+=primal[j]*qp.H[j*working_set_size+i];
    };
    target_tmp+=qp.c[i];
    old_target+=target_tmp*primal[i];
  };

  SVMFLOAT new_constraint_sum=0;
  SVMFLOAT my_is_zero = is_zero;
  SVMINT sv_count=working_set_size;

  qp.n = working_set_size;
  // optimize
  int KKTerror=1;
  int convError=0;

  smo.set_max_allowed_error(convergence_epsilon);

  // loop while some KKT condition is not valid (alpha=0)
  int result = smo.smo_solve(&qp,primal);
  lambda_WS = smo.get_lambda_eq();

  /////////// new
  SVMINT it=3;
  if(! is_pattern){
    SVMFLOAT lambda_lo;
    while(KKTerror && (it>0)){
      KKTerror = 0;
      it--;
      for(i=0;i<working_set_size;i++){
	if(primal[i]<is_zero){
	  lambda_lo =  epsilon_neg + epsilon_pos - qp.c[i];
	  for(j=0;j<working_set_size;j++){
	    lambda_lo -= primal[j]*qp.H[i*working_set_size+j];
	  };
	  if(qp.A[i] > 0){
	    lambda_lo -= lambda_WS;
	  }
	  else{
	    lambda_lo += lambda_WS;
	  };

	  //cout<<"primal["<<i<<"] = "<<primal[i]<<", lambda_lo = "<<lambda_lo<<endl;

	  if(lambda_lo<-convergence_epsilon){

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -