⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm_c.cpp

📁 支持向量机(4)mySVM
💻 CPP
📖 第 1 页 / 共 5 页
字号:
			for(SVMINT i=0;i<working_set_size;i++){
				if(primal[i]<is_zero){
					lambda_lo =  epsilon_neg + epsilon_pos - qp.c[i];
					for(SVMINT j=0;j<working_set_size;j++){
						lambda_lo -= primal[j]*qp.H[i*working_set_size+j];
					};
					if(qp.A[i] > 0){
						lambda_lo -= lambda_WS;
					}
					else{
						lambda_lo += lambda_WS;
					};
					
					//cout<<"primal["<<i<<"] = "<<primal[i]<<", lambda_lo = "<<lambda_lo<<endl;
					
					if(lambda_lo<-convergence_epsilon){
						// change sign of i
						KKTerror=1;
						qp.A[i] = -qp.A[i];
						which_alpha[i] = -which_alpha[i];
						primal[i] = -primal[i];
						qp.c[i] = epsilon_neg + epsilon_pos - qp.c[i];
						if(qp.A[i]>0){
							qp.u[i] = Cneg;
						}
						else{
							qp.u[i] = Cpos;
						};
						for(SVMINT j=0;j<working_set_size;j++){
							qp.H[i*working_set_size+j] = -qp.H[i*working_set_size+j];
							qp.H[j*working_set_size+i] = -qp.H[j*working_set_size+i];
						};
						if(parameters->quadraticLossNeg){
							if(which_alpha[i]>0){
								(qp.H)[i*(working_set_size+1)] += 1/Cneg;
								(qp.u)[i] = infinity;
							}
							else{
								// previous was neg
								(qp.H)[i*(working_set_size+1)] -= 1/Cneg;
							};
						};
						if(parameters->quadraticLossPos){
							if(which_alpha[i]<0){
								(qp.H)[i*(working_set_size+1)] += 1/Cpos;
								(qp.u)[i] = infinity;
							}
							else{
								//previous was pos
								(qp.H)[i*(working_set_size+1)] -= 1/Cpos;
							};
						};
					};
				};
			};
			result = smo.smo_solve(&qp,primal);
			lambda_WS = smo.get_lambda_eq();
			
		};
	};
	
	
	
	KKTerror = 1;
	//////////////////////
	
	
	
	if(parameters->verbosity>=5){
		cout<<"smo ended with result "<<result<<endl;
		cout<<"lambda_WS = "<<lambda_WS<<endl;
		cout<<"smo: Resulting values:"<<endl;
		for(i=0;i<working_set_size;i++){
			cout<<i<<": "<<primal[i]<<endl; 
		};
	};
	
	while(KKTerror){
		// clip
		sv_count=working_set_size;
		new_constraint_sum=qp.b[0];
		for(i=0;i<working_set_size;i++){
			// check if at bound
			if(primal[i] <= my_is_zero){
				// at lower bound
				primal[i] = qp.l[i];
				sv_count--;
			}
			else if(qp.u[i]-primal[i] <= my_is_zero){
				// at upper bound
				primal[i] = qp.u[i];
				sv_count--;
			};
			new_constraint_sum -= qp.A[i]*primal[i];
		};
		
		// enforce equality constraint
		if(sv_count>0){
			new_constraint_sum /= (SVMFLOAT)sv_count;
			if(parameters->verbosity>=5){
				cout<<"adjusting "<<sv_count<<" alphas by "<<new_constraint_sum<<endl;
			};
			for(i=0;i<working_set_size;i++){
				if((primal[i] > qp.l[i]) && 
					(primal[i] < qp.u[i])){
					// real sv
					primal[i] += qp.A[i]*new_constraint_sum;
				};
			};
		}
		else if(abs(new_constraint_sum)>(SVMFLOAT)working_set_size*is_zero){
			// error, can't get feasible point
			if(parameters->verbosity>=5){
				cout<<"WARNING: No SVs, constraint_sum = "<<new_constraint_sum<<endl;
			};
			old_target = -infinity; 
			//is_ok=0;
			convError=1;
		};
		// test descend
		new_target=0;
		for(i=0;i<working_set_size;i++){
			// attention: loqo changes one triangle of H!
			target_tmp = primal[i]*qp.H[i*working_set_size+i]/2;
			for(j=0;j<i;j++){
				target_tmp+=primal[j]*qp.H[j*working_set_size+i];
			};
			target_tmp+=qp.c[i];
			new_target+=target_tmp*primal[i];
		};
		
		if(new_target < old_target){
			KKTerror = 0;
			if(parameters->descend < old_target - new_target){
				target_count=0;
			}
			else{
				convError=1;
			};
			if(parameters->verbosity>=5){
				cout<<"descend = "<<old_target-new_target<<endl;
			};
		}
		else if(sv_count > 0){
			// less SVs
			// set my_is_zero to min_i(primal[i]-qp.l[i], qp.u[i]-primal[i])
			my_is_zero = 1e20;
			for(i=0;i<working_set_size;i++){
				if((primal[i] > qp.l[i]) && (primal[i] < qp.u[i])){
					if(primal[i] - qp.l[i] < my_is_zero){
						my_is_zero = primal[i]-qp.l[i];
					};
					if(qp.u[i]  - primal[i]  < my_is_zero){
						my_is_zero = qp.u[i] - primal[i];
					};
				};
			};
			if(target_count == 0){
				my_is_zero *= 2;
			};
			if(parameters->verbosity>=5){
				cout<<"WARNING: no descend ("<<old_target-new_target
					<<" <= "<<parameters->descend
					//	    <<", alpha_diff = "<<alpha_diff
					<<"), adjusting is_zero to "<<my_is_zero<<endl;
				cout<<"new_target = "<<new_target<<endl;
			};
		}
		else{
			// nothing we can do
			if(parameters->verbosity>=5){
				cout<<"WARNING: no descend ("<<old_target-new_target
					<<" <= "<<parameters->descend<<"), stopping."<<endl;
			};
			KKTerror=0;
			convError=1;
		};
	};
	
	if(1 == convError){
		target_count++;
		//    sigfig_max+=0.05;
		if(old_target < new_target){
			for(i=0;i<working_set_size;i++){
				primal[i] = qp.A[i]*all_alphas[working_set[i]];
			};                              
			if(parameters->verbosity>=5){	
				cout<<"WARNING: Convergence error, restoring old primals"<<endl; //, setting sigfig = "<<sigfig_max<<endl;
			};
		};                                          
	};
	
	if(target_count>50){
		// non-recoverable numerical error
		feasible_epsilon=1;
		convergence_epsilon*=2;
		//    sigfig_max=-log10(is_zero);
		if(parameters->verbosity>=1)
			cout<<"WARNING: reducing KKT precision to "<<convergence_epsilon<<endl;
		target_count=0;
	};
	
	time_optimize += get_time() - time_start;
};


SVMFLOAT svm_c::predict(svm_example example){
	// Decision value for an arbitrary example:
	//   f(x) = b + sum_k alpha_k * K(x_k, x)
	// Only stored examples with a non-zero alpha contribute, so all others
	// are skipped without evaluating the kernel.
	SVMFLOAT decision = examples->get_b();
	for(SVMINT idx=0; idx<examples_total; ++idx){
		if(all_alphas[idx] == 0){
			continue;
		}
		svm_example support = examples->get_example(idx);
		decision += all_alphas[idx]*kernel->calculate_K(support,example);
	}
	return decision;
};


SVMFLOAT svm_c::predict(SVMINT i){
	// Decision value for stored example i. Rather than reusing the cached
	// sum[i] + examples->get_b(), the value is recomputed through the kernel,
	// which the original author notes is numerically more stable.
	svm_example the_example = examples->get_example(i);
	return predict(the_example);
};

SVMFLOAT svm_c::loss(SVMINT i){
	// Loss of the current model on stored example i, measured against its
	// target value all_ys[i].
	const SVMFLOAT predicted = predict(i);
	return loss(predicted, all_ys[i]);
};


SVMFLOAT svm_c::loss(SVMFLOAT prediction, SVMFLOAT value){
	// Asymmetric epsilon-insensitive loss of `prediction` against `value`.
	//   deviation = prediction - value
	//   deviation >  epsilon_pos            -> Lpos * excess   (or Lpos * excess^2
	//                                          when quadraticLossPos is set)
	//   -epsilon_neg <= deviation <= eps_pos -> 0
	//   deviation < -epsilon_neg            -> Lneg * shortfall (or Lneg * shortfall^2
	//                                          when quadraticLossNeg is set)
	SVMFLOAT deviation = prediction - value;

	// Pattern mode: a prediction on the same side of zero as the target
	// counts as zero deviation. Both conjuncts are spelled out so that NaN
	// inputs fall through unchanged.
	if(is_pattern){
		const bool both_positive = (value > 0) && (prediction > 0);
		const bool both_non_positive = (value <= 0) && (prediction <= 0);
		if(both_positive || both_non_positive){
			deviation = 0;
		}
	}

	// Over-prediction beyond the positive tube boundary.
	if(deviation > parameters->epsilon_pos){
		const SVMFLOAT excess = deviation - parameters->epsilon_pos;
		return parameters->quadraticLossPos
			? parameters->Lpos*excess*excess
			: parameters->Lpos*excess;
	}

	// Inside the tube: no loss.
	if(deviation >= -parameters->epsilon_neg){
		return 0;
	}

	// Under-prediction beyond the negative tube boundary.
	const SVMFLOAT shortfall = -deviation - parameters->epsilon_neg;
	return parameters->quadraticLossNeg
		? parameters->Lneg*shortfall*shortfall
		: parameters->Lneg*shortfall;
};


void svm_c::print_special_statistics(){
	/* Hook for printing additional, model-specific statistics.
	   svm_c itself has nothing extra to report, so this is intentionally
	   a no-op (presumably overridden by derived SVM variants — confirm). */
};


svm_result svm_c::print_statistics(){
	// # SV, # BSV, pos&neg, Loss, VCdim
	// Pattern: Acc, Rec, Pred
	
	if(parameters->verbosity>=2){
		cout<<"----------------------------------------"<<endl;
	};
	
	svm_result the_result;
	if(test_set->size() <= 0){
		if(parameters->verbosity>= 0){
			cout << "No training set given" << endl;
		};
		the_result.MAE = 0;
		the_result.MSE = 0;
		the_result.loss = 0;
		the_result.loss_pos = 0;
		the_result.loss_neg = 0;
		the_result.number_svs = 0;
		the_result.number_bsv = 0;
		the_result.accuracy = 0;
		the_result.precision = 0;
		the_result.recall = 0;
		return the_result;
	};
	SVMINT i;
	SVMINT svs = 0;
	SVMINT bsv = 0;
	SVMFLOAT actloss = 0;
	SVMFLOAT theloss = 0;
	SVMFLOAT theloss_pos=0;
	SVMFLOAT theloss_neg=0;
	SVMINT countpos=0;
	SVMINT countneg=0;
	SVMFLOAT min_alpha=infinity;
	SVMFLOAT max_alpha=-infinity;
	SVMFLOAT norm_w=0;
	SVMFLOAT max_norm_x=0;
	SVMFLOAT min_norm_x=1e20;
	SVMFLOAT norm_x=0;
	SVMFLOAT loo_loss_estim=0;
	// for pattern:
	SVMFLOAT correct_pos=0;
	SVMINT correct_neg=0;
	SVMINT total_pos=0;
	SVMINT total_neg=0;
	SVMINT estim_pos=0;
	SVMINT estim_neg=0;
	SVMFLOAT MSE=0;
	SVMFLOAT MAE=0;
	SVMFLOAT alpha;
	SVMFLOAT prediction;
	SVMFLOAT y;
	SVMFLOAT xi;
	
	for(i=0;i<examples_total;i++){
		// needed before test-loop for performance estimators
		norm_w+=all_alphas[i]*sum[i];
		
		alpha=all_alphas[i];
		if(alpha!=0){
			norm_x = kernel->calculate_K(i,i);
			if(norm_x>max_norm_x){
				max_norm_x = norm_x;
			};
			if(norm_x<min_norm_x){
				min_norm_x = norm_x;
			};
		};
	};
	
	for(i=0;i<examples_total;i++){
		alpha=all_alphas[i];
		if(alpha<min_alpha) min_alpha = alpha;
		if(alpha>max_alpha) max_alpha = alpha;
		prediction = predict(i);
		y = all_ys[i];
		actloss=loss(prediction,y);
		theloss+=actloss;
		MAE += abs(prediction-y);
		MSE += (prediction-y)*(prediction-y);
		if(y < prediction-parameters->epsilon_pos){
			theloss_pos += actloss;
			countpos++;
		}
		else if(y > prediction+parameters->epsilon_neg){
			theloss_neg += actloss;
			countneg++;
		};
		if(abs(alpha)>is_zero){ 
			if(is_alpha_neg(i)>=0){
				loo_loss_estim += loss(prediction-(abs(alpha)*(2*kernel->calculate_K(i,i)+max_norm_x)+2*epsilon_neg),y);
			}
			else{
				loo_loss_estim += loss(prediction+(abs(alpha)*(2*kernel->calculate_K(i,i)+max_norm_x)+2*epsilon_pos),y);
			};
			
			// a support vector
			svs++; 
			if((alpha-Cneg >= -is_zero) || (alpha+Cpos <= is_zero)){ 
				bsv++; 
			};
		}
		else{
			// loss doesn't change if non-SV is omitted
			loo_loss_estim += actloss;
		};
		if(is_pattern){
			if(y>0){
				if(prediction>0){
					correct_pos++;
				};
				if(prediction>1){
					xi=0;
				}
				else{
					xi=1-prediction;
				};
				if(2*alpha*(max_norm_x-min_norm_x)+xi >= 1){
					estim_pos++;
				};
				total_pos++;
			}
			else{
				if(prediction<=0){
					correct_neg++;
				};
				if(prediction<-1){
					xi=0;
				}
				else{
					xi=1+prediction;
				};
				if(2*(-alpha)*(max_norm_x-min_norm_x)+xi >= 1){
					estim_neg++;
				};
				total_neg++;
			};
		};    
	};
	if(countpos != 0){
		theloss_pos /= (SVMFLOAT)countpos;
	};
	if(countneg != 0){
		theloss_neg /= (SVMFLOAT)countneg;
	};
	
	the_result.MAE = MAE / (SVMFLOAT)examples_total;
	the_result.MSE = MSE / (SVMFLOAT)examples_total;
	the_result.VCdim = 1+norm_w*max_norm_x;
	the_result.loss = theloss/((SVMFLOAT)examples_total);
	the_result.pred_loss = loo_loss_estim/((SVMFLOAT)examples_total);
	the_result.loss_pos = theloss_pos;
	the_result.loss_neg = theloss_neg;
	the_result.number_svs = svs;
	the_result.number_bsv = bsv;
	if(is_pattern){
		the_result.accuracy = ((SVMFLOAT)(correct_pos+correct_neg))/((SVMFLOAT)(total_pos+total_neg));
		the_result.precision = ((SVMFLOAT)correct_pos/((SVMFLOAT)(correct_pos+total_neg-correct_neg)));
		the_result.recall = ((SVMFLOAT)correct_pos/(SVMFLOAT)total_pos);
		the_result.pred_accuracy = (1-((SVMFLOAT)(estim_pos+estim_neg))/((SVMFLOAT)(total_pos+total_neg)));
		the_result.pred_precision = ((SVMFLOAT)(total_pos-estim_pos))/((SVMFLOAT)(total_pos-estim_pos+estim_neg));
		the_result.pred_recall = (1-(SVMFLOAT)estim_pos/((SVMFLOAT)total_pos));
	}
	else{
		the_result.accuracy = -1;
		the_result.precision = -1;
		the_result.recall = -1;
		the_result.pred_accuracy = -1;
		the_result.pred_precision = -1;
		the_result.pred_recall = -1;
	};
	if(convergence_epsilon > parameters->convergence_epsilon){
		cout<<"WARNING: The results were obtained using a relaxed epsilon of "<<convergence_epsilon<<" on the KKT conditions!"<<endl;
	}
	else if(parameters->verbosity>=2){
		cout<<"The results are valid with an epsilon of "<<convergence_epsilon<<" on the KKT conditions."<<endl;
	};
	if(parameters->verbosity >= 2){
		cout << "Average loss  : "<<the_result.loss<<" (loo-estim: "<< the_result.pred_loss<<")"<<endl;
		cout << "Avg. loss pos : "<<theloss_pos<<"\t ("<<countpos<<" occurences)"<<endl;
		cout << "Avg. loss neg : "<<theloss_neg<<"\t ("<<countneg<<" occurences)"<<endl;
		cout << "Mean absolute error : "<<the_result.MAE<<endl;
		cout << "Mean squared error  : "<<the_result.MSE<<endl;
		cout << "Support Vectors : "<<svs<<endl;
		cout << "Bounded SVs     : "<<bsv<<endl;
		cout<<"min SV: "<<min_alpha<<endl
			<<"max SV: "<<max_alpha<<endl;
		cout<<"|w| = "<<sqrt(norm_w)<<endl;
		cout<<"max |x| = "<<sqrt(max_norm_x)<<endl;
		cout<<"VCdim <= "<<the_result.VCdim<<endl;
		
		print_special_statistics();
		
		if((is_pattern) && (! parameters->is_distribution)){
			// output precision, recall and accuracy
			cout<<"performance (+estimators):"<<endl;
			cout<<"Accuracy  : "<<the_result.accuracy<<" ("<<the_result.pred_accuracy<<")"<<endl;
			cout<<"Precision : "<<the_result.precision<<" ("<<the_result.pred_precision<<")"<<endl;
			cout<<"Recall    : "<<the_result.recall<<" ("<<the_result.pred_recall<<")"<<endl;
			if(parameters->verbosity>= 2){
				// nice printout ;-)
				int rows = (int)(1+log10((SVMFLOAT)(total_pos+total_neg)));
				int now_digits = rows+2;
				int i,j;
				cout<<endl;
				cout<<"Predicted values:"<<endl;
				cout<<"   |";
				for(i=0;i<rows;i++){ cout<<" "; };
				cout<<"+  |";
				for(j=0;j<rows;j++){ cout<<" "; };
				cout<<"-"<<endl;
				
				cout<<"---+";
				for(i=0;i<now_digits;i++){ cout<<"-"; };
				cout<<"-+-";
				for(i=0;i<now_digits;i++){ cout<<"-"; };
				cout<<endl;
				
				cout<<" + |  ";
				now_digits=rows-(int)(1+log10((SVMFLOAT)correct_pos))-1;
				for(i=0;i<now_digits;i++){ cout<<" "; };
				cout<<correct_pos<<"  |  ";
				now_digits=rows-(int)(1+log10((SVMFLOAT)(total_pos-correct_pos)))-1;
				for(i=0;i<now_digits;i++){ cout<<" "; };
				cout<<total_pos-correct_pos<<"    (true pos)"<<endl;
				
				cout<<" - |  ";
				now_digits=rows-(int)(1+log10((SVMFLOAT)(total_neg-correct_neg)))-1;
				for(i=0;i<now_digits;i++){ cout<<" "; };
				cout<<(total_neg-correct_neg)<<"  |  ";
				now_digits=rows-(int)(1+log10((SVMFLOAT)correct_neg))-1;
				for(i=0;i<now_digits;i++){ cout<<" "; };
				cout<<correct_neg<<"    (true neg)"<<endl;
				cout<<endl;
			};
		};
	};
	
	SVMINT dim = examples->get_dim();
	if((dim<30) && (parameters->verbosity>= 2)){ // && (parameters->is_linear != 0)){
		// print hyperplane
		SVMINT j;
		svm_example example;
		SVMFLOAT* w = new SVMFLOAT[dim];
		SVMFLOAT b = examples->get_b();
		for(j=0;j<dim;j++) w[j] = 0;
		for(i=0;i<examples_total;i++){
			example = examples->get_example(i);
			alpha = examples->get_alpha(i);
			for(j=0;j<example.length;j++){
				w[((example.example)[j]).index] += alpha*((example.example)[j]).att;
			};
		};

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -