⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm_c.cpp

📁 支持向量机(SVM)的VC源代码
💻 CPP
📖 第 1 页 / 共 5 页
字号:

  // equality constraint
  qp.b[0]=0;
  for(i=0;i<working_set_size;i++){
    qp.b[0] += all_alphas[working_set[i]];
  };

  // set initial optimization parameters
  SVMFLOAT new_target=0;
  SVMFLOAT old_target=0;
  SVMFLOAT target_tmp;
  for(i=0;i<working_set_size;i++){
    target_tmp = primal[i]*qp.H[i*working_set_size+i]/2;
    for(j=0;j<i;j++){
      target_tmp+=primal[j]*qp.H[j*working_set_size+i];
    };
    target_tmp+=qp.c[i];
    old_target+=target_tmp*primal[i];
  };

  SVMFLOAT new_constraint_sum=0;
  SVMFLOAT my_is_zero = is_zero;
  SVMINT sv_count=working_set_size;

  qp.n = working_set_size;
  // optimize
  int KKTerror=1;
  int convError=0;

  smo.set_max_allowed_error(feasible_epsilon);

  // loop while some KKT condition is not valid (alpha=0)
  int result = smo.smo_solve(&qp,primal);
  lambda_WS = smo.get_lambda_eq();

  /////////// new
  SVMINT it=3;
  if(! is_pattern){
    SVMFLOAT lambda_lo;
    while(KKTerror && (it>0)){
      KKTerror = 0;
      it--;
      for(SVMINT i=0;i<working_set_size;i++){
	if(primal[i]<is_zero){
	  lambda_lo =  epsilon_neg + epsilon_pos - qp.c[i];
	  for(SVMINT j=0;j<working_set_size;j++){
	    lambda_lo -= primal[j]*qp.H[i*working_set_size+j];
	  };
	  if(qp.A[i] > 0){
	    lambda_lo -= lambda_WS;
	  }
	  else{
	    lambda_lo += lambda_WS;
	  };

	  //cout<<"primal["<<i<<"] = "<<primal[i]<<", lambda_lo = "<<lambda_lo<<endl;

	  if(lambda_lo<-convergence_epsilon){
	    // change sign of i
	    KKTerror=1;
	    qp.A[i] = -qp.A[i];
	    which_alpha[i] = -which_alpha[i];
	    primal[i] = -primal[i];
	    qp.c[i] = epsilon_neg + epsilon_pos - qp.c[i];
	    if(qp.A[i]>0){
	      qp.u[i] = Cneg;
	    }
	    else{
	      qp.u[i] = Cpos;
	    };
	    for(SVMINT j=0;j<working_set_size;j++){
	      qp.H[i*working_set_size+j] = -qp.H[i*working_set_size+j];
	      qp.H[j*working_set_size+i] = -qp.H[j*working_set_size+i];
	    };
	    if(parameters->quadraticLossNeg){
	      if(which_alpha[i]>0){
		(qp.H)[i*(working_set_size+1)] += 1/Cneg;
		(qp.u)[i] = infinity;
	      }
	      else{
		// previous was neg
		(qp.H)[i*(working_set_size+1)] -= 1/Cneg;
	      };
	    };
	    if(parameters->quadraticLossPos){
	      if(which_alpha[i]<0){
		(qp.H)[i*(working_set_size+1)] += 1/Cpos;
		(qp.u)[i] = infinity;
	      }
	      else{
		//previous was pos
		(qp.H)[i*(working_set_size+1)] -= 1/Cpos;
	      };
	    };
	  };
	};
      };
      result = smo.smo_solve(&qp,primal);
      lambda_WS = smo.get_lambda_eq();
      
    };
  };



  KKTerror = 1;
  //////////////////////



  if(parameters->verbosity>=5){
    cout<<"smo ended with result "<<result<<endl;
    cout<<"lambda_WS = "<<lambda_WS<<endl;
    cout<<"smo: Resulting values:"<<endl;
    for(i=0;i<working_set_size;i++){
      cout<<i<<": "<<primal[i]<<endl; 
    };
  };

  while(KKTerror){
    // clip
    sv_count=working_set_size;
    new_constraint_sum=qp.b[0];
    for(i=0;i<working_set_size;i++){
      // check if at bound
      if(primal[i] <= my_is_zero){
	// at lower bound
	primal[i] = qp.l[i];
	sv_count--;
      }
      else if(qp.u[i]-primal[i] <= my_is_zero){
	// at upper bound
	primal[i] = qp.u[i];
	sv_count--;
      };
      new_constraint_sum -= qp.A[i]*primal[i];
    };

    // enforce equality constraint
    if(sv_count>0){
      new_constraint_sum /= (SVMFLOAT)sv_count;
      if(parameters->verbosity>=5){
	cout<<"adjusting "<<sv_count<<" alphas by "<<new_constraint_sum<<endl;
      };
      for(i=0;i<working_set_size;i++){
	if((primal[i] > qp.l[i]) && 
	   (primal[i] < qp.u[i])){
	  // real sv
	  primal[i] += qp.A[i]*new_constraint_sum;
	};
      };
    }
    else if(abs(new_constraint_sum)>(SVMFLOAT)working_set_size*is_zero){
      // error, can't get feasible point
      if(parameters->verbosity>=5){
	cout<<"WARNING: No SVs, constraint_sum = "<<new_constraint_sum<<endl;
      };
      old_target = -infinity; 
      //is_ok=0;
      convError=1;
    };
    // test descend
    new_target=0;
    for(i=0;i<working_set_size;i++){
      // attention: loqo changes one triangle of H!
      target_tmp = primal[i]*qp.H[i*working_set_size+i]/2;
      for(j=0;j<i;j++){
	target_tmp+=primal[j]*qp.H[j*working_set_size+i];
      };
      target_tmp+=qp.c[i];
      new_target+=target_tmp*primal[i];
    };

    if(new_target < old_target){
      KKTerror = 0;
      if(parameters->descend < old_target - new_target){
	target_count=0;
      }
      else{
	convError=1;
      };
      if(parameters->verbosity>=5){
	cout<<"descend = "<<old_target-new_target<<endl;
      };
    }
    else if(sv_count > 0){
      // less SVs
      // set my_is_zero to min_i(primal[i]-qp.l[i], qp.u[i]-primal[i])
      my_is_zero = 1e20;
      for(i=0;i<working_set_size;i++){
	if((primal[i] > qp.l[i]) && (primal[i] < qp.u[i])){
	  if(primal[i] - qp.l[i] < my_is_zero){
	    my_is_zero = primal[i]-qp.l[i];
	  };
	  if(qp.u[i]  - primal[i]  < my_is_zero){
	    my_is_zero = qp.u[i] - primal[i];
	  };
	};
      };
      if(target_count == 0){
      	my_is_zero *= 2;
      };
      if(parameters->verbosity>=5){
	cout<<"WARNING: no descend ("<<old_target-new_target
	    <<" <= "<<parameters->descend
	  //	    <<", alpha_diff = "<<alpha_diff
	    <<"), adjusting is_zero to "<<my_is_zero<<endl;
	cout<<"new_target = "<<new_target<<endl;
      };
    }
    else{
      // nothing we can do
      if(parameters->verbosity>=5){
	cout<<"WARNING: no descend ("<<old_target-new_target
	    <<" <= "<<parameters->descend<<"), stopping."<<endl;
      };
      KKTerror=0;
      convError=1;
    };
  };

  if(1 == convError){
    target_count++;
    //    sigfig_max+=0.05;
    if(old_target < new_target){
      for(i=0;i<working_set_size;i++){
	primal[i] = qp.A[i]*all_alphas[working_set[i]];
      };                              
      if(parameters->verbosity>=5){	
	cout<<"WARNING: Convergence error, restoring old primals"<<endl; //, setting sigfig = "<<sigfig_max<<endl;
      };
    };                                          
  };

  if(target_count>50){
    // non-recoverable numerical error
    feasible_epsilon=1;
    convergence_epsilon*=2;
    //    sigfig_max=-log10(is_zero);
    if(parameters->verbosity>=1)
      cout<<"WARNING: reducing KKT precision to "<<convergence_epsilon<<endl;
    target_count=0;
  };

  time_optimize += get_time() - time_start;
};


// Evaluates the learned function on an arbitrary example:
//   examples->get_b() + sum_i all_alphas[i] * K(example_i, example)
// Training examples whose alpha is exactly zero are skipped, so no kernel
// evaluation is spent on them.
SVMFLOAT svm_c::predict(svm_example example){
  svm_example train_example;
  SVMFLOAT result = examples->get_b();

  for(SVMINT idx = 0; idx < examples_total; idx++){
    const SVMFLOAT alpha = all_alphas[idx];
    if(alpha == 0){
      // contributes nothing to the expansion
      continue;
    }
    train_example = examples->get_example(idx);
    result += alpha*kernel->calculate_K(train_example,example);
  }

  return result;
};


// Predicts the value of training example i by re-evaluating the full
// kernel expansion. The cached alternative, sum[i]+examples->get_b(),
// is deliberately not used: the original author notes that recomputing
// is numerically more stable.
SVMFLOAT svm_c::predict(SVMINT i){
  svm_example ith_example = examples->get_example(i);
  return predict(ith_example);
};

// Loss of training example i: compares the current prediction for that
// example with its stored target all_ys[i].
SVMFLOAT svm_c::loss(SVMINT i){
  const SVMFLOAT prediction = predict(i);
  const SVMFLOAT target = all_ys[i];
  return loss(prediction, target);
};


// Loss of a single prediction against its target value.
//
// The deviation (prediction - value) is penalised only outside the band
// [-epsilon_neg, epsilon_pos]:
//   - above the band: Lpos * excess   (or Lpos * excess^2 if quadraticLossPos)
//   - below the band: Lneg * excess   (or Lneg * excess^2 if quadraticLossNeg)
//   - inside the band: 0
// In pattern mode (is_pattern) the deviation is first reset to 0 whenever
// prediction and value lie on the same side of zero (zero counts as the
// non-positive side).
SVMFLOAT svm_c::loss(SVMFLOAT prediction, SVMFLOAT value){
  SVMFLOAT deviation = prediction - value;

  if(is_pattern){
    const bool both_positive = (value > 0) && (prediction > 0);
    const bool both_non_positive = (value <= 0) && (prediction <= 0);
    if(both_positive || both_non_positive){
      // same side of zero: no deviation to penalise
      deviation = 0;
    }
  }

  // overestimate beyond the positive tolerance
  if(deviation > parameters->epsilon_pos){
    const SVMFLOAT excess = deviation-parameters->epsilon_pos;
    if(parameters->quadraticLossPos){
      return parameters->Lpos*excess*excess;
    }
    return parameters->Lpos*excess;
  }

  // inside the insensitive band
  if(deviation >= -parameters->epsilon_neg){
    return 0;
  }

  // underestimate beyond the negative tolerance
  const SVMFLOAT excess = -deviation-parameters->epsilon_neg;
  if(parameters->quadraticLossNeg){
    return parameters->Lneg*excess*excess;
  }
  return parameters->Lneg*excess;
};


// Extra statistics output. print_statistics() invokes this after its own
// report, and only when parameters->verbosity >= 2. This implementation
// intentionally prints nothing.
// NOTE(review): presumably an extension point that other variants
// override -- confirm against the class declaration.
void svm_c::print_special_statistics(){
  // nothing special here!
};


svm_result svm_c::print_statistics(){
  // # SV, # BSV, pos&neg, Loss, VCdim
  // Pattern: Acc, Rec, Pred

  if(parameters->verbosity>=2){
    cout<<"----------------------------------------"<<endl;
  };

  svm_result the_result;
  if(test_set->size() <= 0){
    if(parameters->verbosity>= 0){
      cout << "No training set given" << endl;
    };
    the_result.MAE = 0;
    the_result.MSE = 0;
    the_result.loss = 0;
    the_result.loss_pos = 0;
    the_result.loss_neg = 0;
    the_result.number_svs = 0;
    the_result.number_bsv = 0;
    the_result.accuracy = 0;
    the_result.precision = 0;
    the_result.recall = 0;
    return the_result;
  };
  SVMINT i;
  SVMINT svs = 0;
  SVMINT bsv = 0;
  SVMFLOAT actloss = 0;
  SVMFLOAT theloss = 0;
  SVMFLOAT theloss_pos=0;
  SVMFLOAT theloss_neg=0;
  SVMINT countpos=0;
  SVMINT countneg=0;
  SVMFLOAT min_alpha=infinity;
  SVMFLOAT max_alpha=-infinity;
  SVMFLOAT norm_w=0;
  SVMFLOAT max_norm_x=0;
  SVMFLOAT min_norm_x=1e20;
  SVMFLOAT norm_x=0;
  SVMFLOAT loo_loss_estim=0;
  // for pattern:
  SVMFLOAT correct_pos=0;
  SVMINT correct_neg=0;
  SVMINT total_pos=0;
  SVMINT total_neg=0;
  SVMINT estim_pos=0;
  SVMINT estim_neg=0;
  SVMFLOAT MSE=0;
  SVMFLOAT MAE=0;
  SVMFLOAT alpha;
  SVMFLOAT prediction;
  SVMFLOAT y;
  SVMFLOAT xi;

  for(i=0;i<examples_total;i++){
    // needed before test-loop for performance estimators
    norm_w+=all_alphas[i]*sum[i];

    alpha=all_alphas[i];
    if(alpha!=0){
      norm_x = kernel->calculate_K(i,i);
      if(norm_x>max_norm_x){
	max_norm_x = norm_x;
      };
      if(norm_x<min_norm_x){
	min_norm_x = norm_x;
      };
    };
  };

  for(i=0;i<examples_total;i++){
    alpha=all_alphas[i];
    if(alpha<min_alpha) min_alpha = alpha;
    if(alpha>max_alpha) max_alpha = alpha;
    prediction = predict(i);
    y = all_ys[i];
    actloss=loss(prediction,y);
    theloss+=actloss;
    MAE += abs(prediction-y);
    MSE += (prediction-y)*(prediction-y);
    if(y < prediction-parameters->epsilon_pos){
      theloss_pos += actloss;
      countpos++;
    }
    else if(y > prediction+parameters->epsilon_neg){
      theloss_neg += actloss;
      countneg++;
    };
    if(abs(alpha)>is_zero){ 
      if(is_alpha_neg(i)>=0){
	loo_loss_estim += loss(prediction-(abs(alpha)*(2*kernel->calculate_K(i,i)+max_norm_x)+2*epsilon_neg),y);
      }
      else{
	loo_loss_estim += loss(prediction+(abs(alpha)*(2*kernel->calculate_K(i,i)+max_norm_x)+2*epsilon_pos),y);
      };

      // a support vector
      svs++; 
      if((alpha-Cneg >= -is_zero) || (alpha+Cpos <= is_zero)){ 
	bsv++; 
      };
    }
    else{
      // loss doesn't change if non-SV is omitted
      loo_loss_estim += actloss;
    };
    if(is_pattern){
      if(y>0){
	if(prediction>0){
	  correct_pos++;
	};
	if(prediction>1){
	  xi=0;
	}
	else{
	  xi=1-prediction;
	};
	if(2*alpha*(max_norm_x-min_norm_x)+xi >= 1){
	  estim_pos++;
	};
	total_pos++;
      }
      else{
	if(prediction<=0){
	  correct_neg++;
	};
	if(prediction<-1){
	  xi=0;
	}
	else{
	  xi=1+prediction;
	};
	if(2*(-alpha)*(max_norm_x-min_norm_x)+xi >= 1){
	  estim_neg++;
	};
	total_neg++;
      };
    };    
  };
  if(countpos != 0){
    theloss_pos /= (SVMFLOAT)countpos;
  };
  if(countneg != 0){
    theloss_neg /= (SVMFLOAT)countneg;
  };

  the_result.MAE = MAE / (SVMFLOAT)examples_total;
  the_result.MSE = MSE / (SVMFLOAT)examples_total;
  the_result.VCdim = 1+norm_w*max_norm_x;
  the_result.loss = theloss/((SVMFLOAT)examples_total);
  the_result.pred_loss = loo_loss_estim/((SVMFLOAT)examples_total);
  the_result.loss_pos = theloss_pos;
  the_result.loss_neg = theloss_neg;
  the_result.number_svs = svs;
  the_result.number_bsv = bsv;
  if(is_pattern){
    the_result.accuracy = ((SVMFLOAT)(correct_pos+correct_neg))/((SVMFLOAT)(total_pos+total_neg));
    the_result.precision = ((SVMFLOAT)correct_pos/((SVMFLOAT)(correct_pos+total_neg-correct_neg)));
    the_result.recall = ((SVMFLOAT)correct_pos/(SVMFLOAT)total_pos);
    the_result.pred_accuracy = (1-((SVMFLOAT)(estim_pos+estim_neg))/((SVMFLOAT)(total_pos+total_neg)));
    the_result.pred_precision = ((SVMFLOAT)(total_pos-estim_pos))/((SVMFLOAT)(total_pos-estim_pos+estim_neg));
    the_result.pred_recall = (1-(SVMFLOAT)estim_pos/((SVMFLOAT)total_pos));
  }
  else{
    the_result.accuracy = -1;
    the_result.precision = -1;
    the_result.recall = -1;
    the_result.pred_accuracy = -1;
    the_result.pred_precision = -1;
    the_result.pred_recall = -1;
  };
  if(convergence_epsilon > parameters->convergence_epsilon){
    cout<<"WARNING: The results were obtained using a relaxed epsilon of "<<convergence_epsilon<<" on the KKT conditions!"<<endl;
  }
  else if(parameters->verbosity>=2){
    cout<<"The results are valid with an epsilon of "<<convergence_epsilon<<" on the KKT conditions."<<endl;
  };
  if(parameters->verbosity >= 2){
    cout << "Average loss  : "<<the_result.loss<<" (loo-estim: "<< the_result.pred_loss<<")"<<endl;
    cout << "Avg. loss pos : "<<theloss_pos<<"\t ("<<countpos<<" occurences)"<<endl;
    cout << "Avg. loss neg : "<<theloss_neg<<"\t ("<<countneg<<" occurences)"<<endl;
    cout << "Mean absolute error : "<<the_result.MAE<<endl;
    cout << "Mean squared error  : "<<the_result.MSE<<endl;
    cout << "Support Vectors : "<<svs<<endl;
    cout << "Bounded SVs     : "<<bsv<<endl;
    cout<<"min SV: "<<min_alpha<<endl
	<<"max SV: "<<max_alpha<<endl;
    cout<<"|w| = "<<sqrt(norm_w)<<endl;
    cout<<"max |x| = "<<sqrt(max_norm_x)<<endl;
    cout<<"VCdim <= "<<the_result.VCdim<<endl;

    print_special_statistics();

    if((is_pattern) && (! parameters->is_distribution)){
      // output precision, recall and accuracy
      cout<<"performance (+estimators):"<<endl;
      cout<<"Accuracy  : "<<the_result.accuracy<<" ("<<the_result.pred_accuracy<<")"<<endl;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -