⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm_c.cpp

📁 支持向量机(4)mySVM
💻 CPP
📖 第 1 页 / 共 5 页
字号:
		};
		
		if(parameters->verbosity>=5){
			cout<<"ws: feasible_epsilon = "<<feasible_epsilon<<endl;
		};
	};
	
	if((working_set_size<examples_total) && (working_set_size>0)){
		// convergence error on last iteration, 
		// some more tests on WS
		// unlikely to happen, so speed isn't so important
		
		// are all variables at the bound?
		SVMINT pos_abs;
		int bounded_pos=1;
		int bounded_neg=1;
		SVMINT pos=0;
		while((pos<working_set_size) && ((1 == bounded_pos) || (1 == bounded_neg))){
			pos_abs = working_set[pos];
			if(is_alpha_neg(pos_abs) > 0){
				if(all_alphas[pos_abs]-Cneg < -is_zero){
					bounded_pos = 0;
				};
				if(all_alphas[pos_abs] > is_zero){
					bounded_neg = 0;
				};
			}
			else{
				if(all_alphas[pos_abs]+Cneg > is_zero){
					bounded_neg = 0;
				};
				if(all_alphas[pos_abs] < -is_zero){
					bounded_pos = 0;
				};
			};
			pos++;
		};
		if(0 != bounded_pos){
			// need alpha that can be moved upward
			// use alpha with smallest lambda
			SVMFLOAT max_lambda = infinity;
			SVMINT max_pos=examples_total;
			for(pos_abs=0;pos_abs<examples_total;pos_abs++){
				if(is_alpha_neg(pos_abs) > 0){
					if(all_alphas[pos_abs]-Cneg < -is_zero){
						if(lambda(pos_abs) < max_lambda){
							max_lambda = lambda(pos_abs);
							max_pos = pos_abs;
						};
					};
				}
				else{
					if(all_alphas[pos_abs] < -is_zero){
						if(lambda(pos_abs) < max_lambda){
							max_lambda = lambda(pos_abs);
							max_pos = pos_abs;
						};
					};
				};
			};
			if(max_pos<examples_total){
				if(working_set_size<parameters->working_set_size){
					working_set_size++;
				};
				working_set[working_set_size-1] = max_pos;
			};
		}
		else if(0 != bounded_neg){
			// need alpha that can be moved downward
			// use alpha with smallest lambda
			SVMFLOAT max_lambda = infinity;
			SVMINT max_pos=examples_total;
			for(pos_abs=0;pos_abs<examples_total;pos_abs++){
				if(is_alpha_neg(pos_abs) > 0){
					if(all_alphas[pos_abs] > is_zero){
						if(lambda(pos_abs) < max_lambda){
							max_lambda = lambda(pos_abs);
							max_pos = pos_abs;
						};
					};
				}
				else{
					if(all_alphas[pos_abs]+Cneg > is_zero){
						if(lambda(pos_abs) < max_lambda){
							max_lambda = lambda(pos_abs);
							max_pos = pos_abs;
						};
					};
				};
			};
			if(max_pos<examples_total){
				if(working_set_size<parameters->working_set_size){
					working_set_size++;
				};
				working_set[working_set_size-1] = max_pos;
			};
		};
	};
	
	if((working_set_size<parameters->working_set_size) &&
		(working_set_size<examples_total)){
		// use full working set
		SVMINT pos = (SVMINT)((SVMFLOAT)examples_total*rand()/(RAND_MAX+1.0));
		int ok;
		while((working_set_size<parameters->working_set_size) &&
			(working_set_size<examples_total)){
			// add pos into WS if it isn't already
			ok = 1;
			for(i=0;i<working_set_size;i++){
				if(working_set[i] == pos){
					ok=0;
					i = working_set_size;
				};
			};
			if(1 == ok){
				working_set[working_set_size] = pos;
				working_set_size++;
			};
			pos = (pos+1)%examples_total;
		};
	};
	
	for(SVMINT ipos=0;ipos<working_set_size;ipos++){
		which_alpha[ipos] = is_alpha_neg(working_set[ipos]);
	};
	
	time_calc += get_time() - time_start;
	return;
};


void svm_c::project_to_constraint(){
	// project alphas to match the constraint
	SVMFLOAT alpha_sum = sum_alpha;
	SVMINT SVcount=0;
	SVMFLOAT alpha;
	for(SVMINT i=0;i<examples_total;i++){
		alpha = all_alphas[i];
		alpha_sum += alpha;
		if(((alpha>is_zero) && (alpha-Cneg < -is_zero)) ||
			((alpha<-is_zero) && (alpha+Cpos > is_zero))){
			SVcount++;
		};
	};
	if(SVcount > 0){
		// project
		alpha_sum /= (SVMFLOAT)SVcount;
		for(SVMINT i=0;i<examples_total;i++){
			alpha = all_alphas[i];
			if(((alpha>is_zero) && (alpha-Cneg < -is_zero)) ||
				((alpha<-is_zero) && (alpha+Cpos > is_zero))){
				all_alphas[i] -= alpha_sum;
			};
		};
	};
};


void svm_c::init_working_set(){
	// calculate sum
	SVMINT i,j;
	
	project_to_constraint();
	// check bounds!
	if(examples->initialised_alpha()){
		if(parameters->verbosity >= 2){
			cout<<"Initialising variables, this may take some time."<<endl;
		};
		for(i=0; i<examples_total;i++){
			sum[i] = 0;
			at_bound[i] = 0;
			for(j=0; j<examples_total;j++){
				sum[i] += all_alphas[j]*kernel->calculate_K(i,j);
			};
		};
	}
	else{
		// skip kernel calculation as all alphas = 0
		for(i=0; i<examples_total;i++){
			sum[i] = 0;
			at_bound[i] = 0;
		};    
	};
	
	if(examples->initialised_alpha()){
		calculate_working_set();
	}
	else{
		// first working set is random
		j=0;
		i=0;
		while((i<working_set_size) && (j < examples_total)){
			working_set[i] = j;
			if(is_alpha_neg(j) > 0){
				which_alpha[i] = 1;
			}
			else{
				which_alpha[i] = -1;
			};
			i++;
			j++;
		};
	};   
	update_working_set();
};


void svm_c::put_optimizer_values(){
	// update nabla, sum, examples.
	// sum[i] += (primal_j^*-primal_j-alpha_j^*+alpha_j)K(i,j)
	// check for |nabla| < is_zero (nabla <-> nabla*)
	//  cout<<"put_optimizer_values()"<<endl;
	SVMINT i=0; 
	SVMINT j=0;
	SVMINT pos_i;
	SVMFLOAT the_new_alpha;
	SVMFLOAT* kernel_row;
	SVMFLOAT alpha_diff;
	
	long time_start = get_time();
	pos_i=working_set_size;
	while(pos_i>0){
		pos_i--;
		if(which_alpha[pos_i]>0){
			the_new_alpha = primal[pos_i];
		}
		else{
			the_new_alpha = -primal[pos_i];
		};
		// next three statements: keep this order!
		i = working_set[pos_i];
		alpha_diff = the_new_alpha-all_alphas[i];
		all_alphas[i] = the_new_alpha;
		
		if(alpha_diff != 0){
			// update sum ( => nabla)
			kernel_row = kernel->get_row(i);
			for(j=0;j<examples_total;j++){
				sum[j] += alpha_diff*kernel_row[j];
			};
		};
	};
	time_update += get_time() - time_start;
};


void svm_c::update_working_set(){
	long time_start = get_time();
	
	// setup subproblem
	SVMINT i,j;
	SVMINT pos_i, pos_j;
	SVMFLOAT* kernel_row;
	SVMFLOAT sum_WS;
	
	for(pos_i=0;pos_i<working_set_size;pos_i++){
		i = working_set[pos_i];
		
		// put row sort_i in hessian 
		kernel_row = kernel->get_row(i);
		sum_WS=0;
		//    for(pos_j=0;pos_j<working_set_size;pos_j++){
		for(pos_j=0;pos_j<pos_i;pos_j++){
			j = working_set[pos_j];
			// put all elements K(i,j) in hessian, where j in WS
			if(((which_alpha[pos_j] < 0) && (which_alpha[pos_i] < 0)) ||
				((which_alpha[pos_j] > 0) && (which_alpha[pos_i] > 0))){
				// both i and j positive or negative
				(qp.H)[pos_i*working_set_size+pos_j] = kernel_row[j];
				(qp.H)[pos_j*working_set_size+pos_i] = kernel_row[j];
			}
			else{
				// one of i and j positive, one negative
				(qp.H)[pos_i*working_set_size+pos_j] = -kernel_row[j];
				(qp.H)[pos_j*working_set_size+pos_i] = -kernel_row[j];
			};
		};
		for(pos_j=0;pos_j<working_set_size;pos_j++){
			j = working_set[pos_j];
			sum_WS+=all_alphas[j]*kernel_row[j];
		};
		// set main diagonal 
		(qp.H)[pos_i*working_set_size+pos_i] = kernel_row[i];
		
		// linear and box constraints
		if(which_alpha[pos_i]<0){
			// alpha
			(qp.A)[pos_i] = -1;
			// lin(alpha) = y_i+eps-sum_{i not in WS} alpha_i K_{ij}
			//            = y_i+eps-sum_i+sum_{i in WS}
			(qp.c)[pos_i] = all_ys[i]+epsilon_pos-sum[i]+sum_WS;
			primal[pos_i] = -all_alphas[i];
			(qp.u)[pos_i] = Cpos;
		}
		else{
			// alpha^*
			(qp.A)[pos_i] = 1;
			(qp.c)[pos_i] = -all_ys[i]+epsilon_neg+sum[i]-sum_WS;
			primal[pos_i] = all_alphas[i];
			(qp.u)[pos_i] = Cneg;
		};
	};
	if(parameters->quadraticLossNeg){
		for(pos_i=0;pos_i<working_set_size;pos_i++){
			if(which_alpha[pos_i]>0){
				(qp.H)[pos_i*(working_set_size+1)] += 1/Cneg;
				(qp.u)[pos_i] = infinity;
			};
		};
	};
	if(parameters->quadraticLossPos){
		for(pos_i=0;pos_i<working_set_size;pos_i++){
			if(which_alpha[pos_i]<0){
				(qp.H)[pos_i*(working_set_size+1)] += 1/Cpos;
				(qp.u)[pos_i] = infinity;
			};
		};
	};
	
	time_update += get_time() - time_start; 
};


// Number of decimal digits needed to print a non-negative count (0 -> 1).
// Used to lay out the confusion matrix in svm_c::test. It replaces
// (int)(1+log10(x)), which is undefined for x == 0: log10(0) is -inf, and
// converting -inf to int is undefined behaviour.
static int svm_count_digits(long value){
	int digits = 1;
	while(value >= 10){
		value /= 10;
		digits++;
	};
	return digits;
}

// numerator/denominator, or 0 when the denominator is 0. Examples of a zero
// denominator: an empty test set, no positive examples, no positive
// predictions.
static double svm_safe_ratio(double numerator, double denominator){
	if(denominator == 0){
		return 0;
	};
	return numerator/denominator;
}

// Evaluate the model on test_examples (also stored in test_set).
// Returns, averaged over the test set:
//   MAE, MSE, and loss (loss() as defined by the model).
// Also returns:
//   loss_pos: average loss over examples with y < prediction - epsilon_pos
//   loss_neg: average loss over examples with y > prediction + epsilon_neg
//   accuracy, precision, recall: only when is_pattern (label y > 0 counts as
//     positive, prediction > 0 counts as predicted positive); -1 otherwise.
// Averages whose denominator would be 0 are reported as 0.
// If verbose is non-zero, the figures (and for pattern recognition a
// confusion matrix) are printed to cout.
svm_result svm_c::test(example_set_c* test_examples, int verbose){
	svm_result the_result;
	test_set = test_examples;

	SVMINT i;
	SVMFLOAT MAE=0;
	SVMFLOAT MSE=0;
	SVMFLOAT actloss=0;
	SVMFLOAT theloss=0;
	SVMFLOAT theloss_pos=0;
	SVMFLOAT theloss_neg=0;
	SVMINT countpos=0;
	SVMINT countneg=0;
	// confusion-matrix counts, only filled when is_pattern
	SVMINT correct_pos=0;
	SVMINT correct_neg=0;
	SVMINT total_pos=0;
	SVMINT total_neg=0;

	SVMFLOAT prediction;
	SVMFLOAT y;
	svm_example example;
	for(i=0;i<test_set->size();i++){
		example = test_set->get_example(i);
		prediction = predict(example);
		y = test_set->get_y(i);
		// fabs, not abs: with C-style headers abs() binds to int abs(int)
		// and would truncate the absolute error to an integer
		MAE += fabs(y-prediction);
		MSE += (y-prediction)*(y-prediction);
		actloss=loss(prediction,y);
		theloss+=actloss;
		if(y < prediction-parameters->epsilon_pos){
			theloss_pos += actloss;
			countpos++;
		}
		else if(y > prediction+parameters->epsilon_neg){
			theloss_neg += actloss;
			countneg++;
		};
		if(is_pattern){
			if(y>0){
				if(prediction>0){
					correct_pos++;
				};
				total_pos++;
			}
			else{
				if(prediction<=0){
					correct_neg++;
				};
				total_neg++;
			};
		};
	};
	theloss_pos = svm_safe_ratio(theloss_pos, (SVMFLOAT)countpos);
	theloss_neg = svm_safe_ratio(theloss_neg, (SVMFLOAT)countneg);

	const SVMFLOAT n_examples = (SVMFLOAT)test_set->size();
	the_result.MAE = svm_safe_ratio(MAE, n_examples);
	the_result.MSE = svm_safe_ratio(MSE, n_examples);
	the_result.loss = svm_safe_ratio(theloss, n_examples);
	the_result.loss_pos = theloss_pos;
	the_result.loss_neg = theloss_neg;
	the_result.number_svs = 0;
	the_result.number_bsv = 0;

	const SVMINT false_neg = total_pos-correct_pos; // positives predicted as negative
	const SVMINT false_pos = total_neg-correct_neg; // negatives predicted as positive
	if(is_pattern){
		the_result.accuracy = svm_safe_ratio((SVMFLOAT)(correct_pos+correct_neg),
		                                     (SVMFLOAT)(total_pos+total_neg));
		the_result.precision = svm_safe_ratio((SVMFLOAT)correct_pos,
		                                      (SVMFLOAT)(correct_pos+false_pos));
		the_result.recall = svm_safe_ratio((SVMFLOAT)correct_pos, (SVMFLOAT)total_pos);
	}
	else{
		the_result.accuracy = -1;
		the_result.precision = -1;
		the_result.recall = -1;
	};
	if(verbose){
		cout << "Average loss  : "<<the_result.loss<<endl;
		cout << "Avg. loss pos : "<<theloss_pos<<"\t ("<<countpos<<" occurences)"<<endl;
		cout << "Avg. loss neg : "<<theloss_neg<<"\t ("<<countneg<<" occurences)"<<endl;
		cout << "Mean absolute error : "<<the_result.MAE<<endl;
		cout << "Mean squared error  : "<<the_result.MSE<<endl;

		if(is_pattern){
			// output precision, recall and accuracy
			cout<<"Accuracy  : "<<the_result.accuracy<<endl;
			cout<<"Precision : "<<the_result.precision<<endl;
			cout<<"Recall    : "<<the_result.recall<<endl;
			// confusion matrix; column width derived from the total count
			int rows = svm_count_digits(total_pos+total_neg);
			int now_digits = rows+2;
			int k;
			cout<<endl;
			cout<<"Predicted values:"<<endl;
			cout<<"   |";
			for(k=0;k<rows;k++){ cout<<" "; };
			cout<<"+  |";
			for(k=0;k<rows;k++){ cout<<" "; };
			cout<<"-"<<endl;

			cout<<"---+";
			for(k=0;k<now_digits;k++){ cout<<"-"; };
			cout<<"-+-";
			for(k=0;k<now_digits;k++){ cout<<"-"; };
			cout<<endl;

			cout<<" + |  ";
			now_digits=rows-svm_count_digits(correct_pos)-1;
			for(k=0;k<now_digits;k++){ cout<<" "; };
			cout<<correct_pos<<"  |  ";
			now_digits=rows-svm_count_digits(false_neg)-1;
			for(k=0;k<now_digits;k++){ cout<<" "; };
			cout<<false_neg<<"    (true pos)"<<endl;

			cout<<" - |  ";
			now_digits=rows-svm_count_digits(false_pos)-1;
			for(k=0;k<now_digits;k++){ cout<<" "; };
			cout<<false_pos<<"  |  ";
			now_digits=rows-svm_count_digits(correct_neg)-1;
			for(k=0;k<now_digits;k++){ cout<<" "; };
			cout<<correct_neg<<"    (true neg)"<<endl;
			cout<<endl;
		};
	};
	return the_result;
};


void svm_c::predict(example_set_c* test_examples){
	// Label every example of test_examples with the model's prediction.
	// The predictions are stored via put_y and the set is marked as having
	// initialised y values. The model's offset (examples->get_b()) is copied
	// into the set. test_examples is also remembered in test_set.
	test_set = test_examples;
	svm_example current;
	for(SVMINT idx=0; idx<test_set->size(); ++idx){
		current = test_set->get_example(idx);
		const SVMFLOAT predicted_y = predict(current);
		test_set->put_y(idx, predicted_y);
	};
	test_set->set_initialised_y();
	test_set->put_b(examples->get_b());

	if(parameters->verbosity>=4){
		cout<<"Prediction generated"<<endl;
	};
};


void svm_c::optimize(){
	// optimizer-specific call
	// get time
	long time_start = get_time();
	
	qp.n = working_set_size;
	
	SVMINT i;
	SVMINT j;
	
	// equality constraint
	qp.b[0]=0;
	for(i=0;i<working_set_size;i++){
		qp.b[0] += all_alphas[working_set[i]];
	};
	
	// set initial optimization parameters
	SVMFLOAT new_target=0;
	SVMFLOAT old_target=0;
	SVMFLOAT target_tmp;
	for(i=0;i<working_set_size;i++){
		target_tmp = primal[i]*qp.H[i*working_set_size+i]/2;
		for(j=0;j<i;j++){
			target_tmp+=primal[j]*qp.H[j*working_set_size+i];
		};
		target_tmp+=qp.c[i];
		old_target+=target_tmp*primal[i];
	};
	
	SVMFLOAT new_constraint_sum=0;
	SVMFLOAT my_is_zero = is_zero;
	SVMINT sv_count=working_set_size;
	
	qp.n = working_set_size;
	// optimize
	int KKTerror=1;
	int convError=0;
	
	smo.set_max_allowed_error(feasible_epsilon);
	
	// loop while some KKT condition is not valid (alpha=0)
	int result = smo.smo_solve(&qp,primal);
	lambda_WS = smo.get_lambda_eq();
	
	/////////// new
	SVMINT it=3;
	if(! is_pattern){
		SVMFLOAT lambda_lo;
		while(KKTerror && (it>0)){
			KKTerror = 0;
			it--;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -