// svm.cpp
//
// From "OpenSVM was developed under Visual C++" · C++ code · 308 lines total
//
// CPP
// 308
// Font size
// SVM.cpp: implementation of the CSVM class.
//
//////////////////////////////////////////////////////////////////////

#include "stdafx.h"
#include "OpenSVM.h"

#include "SVM.h"

// Debug builds only: keep this file's name in THIS_FILE and make every
// `new` in this file expand to DEBUG_NEW (a macro defined outside this file).
#ifdef _DEBUG
#undef THIS_FILE
static char THIS_FILE[]=__FILE__;
#define new DEBUG_NEW
#endif

// Malloc(type,n): malloc() room for n objects of `type` and cast the result
// to type*. The memory is not initialised and the result is not checked for
// NULL here; callers release it with free().
#define Malloc(type,n) (type *)malloc((n)*sizeof(type))

//#include <svm.h>

//////////////////////////////////////////////////////////////////////
// Construction/Destruction
//////////////////////////////////////////////////////////////////////

/**
	Construct an empty CSVM.

	Every pointer that the destructor later hands to free() or to
	svm_destroy_param() is cleared here. Without this, destroying an
	object on which ReadProblem()/GetParam() were never called would
	release indeterminate pointers.
  */
CSVM::CSVM()
{
	// Training set: nothing loaded yet.
	prob.l = 0;
	prob.y = NULL;
	prob.x = NULL;
	x_space = NULL;

	// No model has been trained yet.
	model = NULL;

	// Per-class weight arrays: none supplied yet.
	param.nr_weight = 0;
	param.weight_label = NULL;
	param.weight = NULL;
}

/**
	Destructor.

	Frees the three buffers that ReadProblem() obtains with Malloc
	(x_space, prob.x, prob.y) and passes param to svm_destroy_param().
  */
CSVM::~CSVM()
{
	// Training-set buffers.
	free(x_space);
	free(prob.x);
	free(prob.y);

	// Storage held by the parameter block.
	svm_destroy_param(&param);
}


/**
	ReadProblem

	Load a training set from a text file into prob and x_space.

	Every line is expected to look like
		<label> <index>:<value> <index>:<value> ...
	Only lines terminated by '\n' are counted as samples.

	The file is read twice: the first pass counts newlines and ':'
	characters to size the buffers, the second pass parses the numbers.
	The nodes of sample i start at prob.x[i] and end with a node whose
	index is -1.

	After parsing, param.gamma and param.kernel_type are read:
	  - a gamma of 0 is replaced by 1/max_index (when max_index > 0);
	  - for PRECOMPUTED kernels the first node of every sample is checked
	    to be 0:<serial number in 1..max_index>.
	NOTE(review): param is therefore expected to hold meaningful values
	before this is called - confirm against the callers.

	filename	path of the file to read.

	Errors (file cannot be opened, out of memory, malformed line) are shown
	with AfxMessageBox and then end the process through exit(1).

	Note:
	2007-10-13: Change the error output. use AfxMessageBox instead of stderr.(Byron)

  */
void CSVM::ReadProblem(CString filename)
{
	CString msg;
	int elements, max_index, i, j;
	int ch;
	FILE *fp = fopen(filename,"r");

	if(fp == NULL)
	{
		// A CString passed through "..." has to be cast explicitly.
		msg.Format("can't open input file %s\n",(LPCTSTR)filename);
		AfxMessageBox(msg, MB_ICONERROR | MB_OK);
		exit(1);
	}

	// Pass 1: one sample per '\n'; one node per ':' plus one terminating
	// (index -1) node per line.
	prob.l = 0;
	elements = 0;
	while((ch = fgetc(fp)) != EOF)
	{
		if(ch == '\n')
		{
			++prob.l;
			++elements;
		}
		else if(ch == ':')
			++elements;
	}
	rewind(fp);

	prob.y = Malloc(double,prob.l);
	prob.x = Malloc(struct svm_node *,prob.l);
	x_space = Malloc(struct svm_node,elements);

	// malloc(0) is allowed to return NULL, so only a failed non-empty
	// request counts as an error.
	if((prob.l > 0 && (prob.y == NULL || prob.x == NULL)) ||
		(elements > 0 && x_space == NULL))
	{
		msg.Format("Out of memory while reading %s\n",(LPCTSTR)filename);
		AfxMessageBox(msg, MB_ICONERROR | MB_OK);
		exit(1);
	}

	// Pass 2: parse labels and index:value pairs.
	max_index = 0;
	j=0;
	for(i=0;i<prob.l;i++)
	{
		double label;
		prob.x[i] = &x_space[j];

		// Without this check a missing/unparsable label would store an
		// indeterminate value in prob.y[i].
		if(fscanf(fp,"%lf",&label) != 1)
		{
			msg.Format("Wrong input format at line %d\n", i+1);
			AfxMessageBox(msg,MB_ICONERROR | MB_OK);
			exit(1);
		}
		prob.y[i] = label;

		while(1)
		{
			int c;
			// Skip blanks; a newline ends the current sample.
			do {
				c = getc(fp);
				if(c=='\n') goto out2;
			} while(isspace(c));
			ungetc(c,fp);
			if (fscanf(fp,"%d:%lf",&(x_space[j].index),&(x_space[j].value)) < 2)
			{
				msg.Format("Wrong input format at line %d\n", i+1);
				AfxMessageBox(msg,MB_ICONERROR | MB_OK);
				exit(1);
			}
			++j;
		}
out2:
		// Only the last node read so far is compared, i.e. the largest
		// index is taken from the end of the line.
		if(j>=1 && x_space[j-1].index > max_index)
			max_index = x_space[j-1].index;
		x_space[j++].index = -1;	// end-of-sample marker
	}

	// Default gamma is 1/max_index; skip it when no feature index above 0
	// was seen, which would otherwise divide by zero.
	if(param.gamma == 0 && max_index > 0)
		param.gamma = 1.0/max_index;

	if(param.kernel_type == PRECOMPUTED)
		for(i=0;i<prob.l;i++)
		{
			if (prob.x[i][0].index != 0)
			{
				msg.Format("Wrong input format: first column must be 0:sample_serial_number\n");
				AfxMessageBox(msg, MB_ICONERROR | MB_OK);
				exit(1);
			}
			if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
			{
				msg.Format("Wrong input format: sample_serial_number out of range\n");
				AfxMessageBox(msg, MB_ICONERROR | MB_OK);
				exit(1);
			}
		}

	fclose(fp);
}

/**
	DoTrain

	Calls svm_train() with the current problem (prob) and parameters
	(param) and stores the returned model pointer in `model`.
  */
void CSVM::DoTrain()
{
	struct svm_model *trained = svm_train(&prob,&param);
	model = trained;
}

/**
	StoreModel

	Write the current model to "<filename>.model" with svm_save_model()
	and then destroy it with svm_destroy_model().

	filename	base name of the output file; ".model" is appended.

	A failed save is reported with AfxMessageBox. The model is destroyed
	either way, and `model` is reset to NULL afterwards so that a later
	call cannot save or destroy the already released model again.
  */
void CSVM::StoreModel(CString filename)
{
	if (model == NULL)
	{
		AfxMessageBox("No trained model to save.", MB_ICONERROR | MB_OK);
		return;
	}

	filename +=".model";

	// svm_save_model() reports failure with a non-zero return value.
	if (svm_save_model(filename,model) != 0)
	{
		CString msg;
		msg.Format("can't save model to file %s\n",(LPCTSTR)filename);
		AfxMessageBox(msg, MB_ICONERROR | MB_OK);
	}

	svm_destroy_model(model);
	model = NULL;
}

void CSVM::DoCrossValidation(int nr_fold)
{
	int i;
	int total_correct = 0;
	double total_error = 0;
	double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
	double *target = Malloc(double,prob.l);

	svm_cross_validation(&prob,&param,nr_fold,target);

	if(param.svm_type == EPSILON_SVR ||
		param.svm_type == NU_SVR)
	{
		for(i=0;i<prob.l;i++)
		{
			double y = prob.y[i];
			double v = target[i];
			total_error += (v-y)*(v-y);
			sumv += v;
			sumy += y;
			sumvv += v*v;
			sumyy += y*y;
			sumvy += v*y;
		}
		printf("Cross Validation Mean squared error = %g\n",total_error/prob.l);
		printf("Cross Validation Squared correlation coefficient = %g\n",
			((prob.l*sumvy-sumv*sumy)*(prob.l*sumvy-sumv*sumy))/
			((prob.l*sumvv-sumv*sumv)*(prob.l*sumyy-sumy*sumy))
			);
	}
	else
	{
		for(i=0;i<prob.l;i++)
			if(target[i] == prob.y[i])
				++total_correct;
			printf("Cross Validation Accuracy = %g%%\n",100.0*total_correct/prob.l);
	}
	free(target);
}

/**
	GetParam

	Validate _pass against the loaded problem with svm_check_parameter().
	If the check returns a message, it is shown with AfxMessageBox and
	param is left untouched; otherwise the fields of _pass are copied
	into param one by one.

	_pass	candidate parameter set.

	weight_label and weight are copied by plain assignment, so param
	ends up referring to the same arrays as _pass.
  */
void CSVM::GetParam(svm_parameter _pass)
{
	const char *error_msg = svm_check_parameter(&prob,&_pass);

	// Rejected parameter set: report and keep the old values.
	if(error_msg)
	{
		AfxMessageBox(error_msg);
		return;
	}

	// Machine and kernel selection.
	param.svm_type		= _pass.svm_type;
	param.kernel_type	= _pass.kernel_type;

	// Kernel coefficients.
	param.degree		= _pass.degree;
	param.gamma			= _pass.gamma;	// 1/k
	param.coef0			= _pass.coef0;

	// Formulation constants.
	param.C				= _pass.C;
	param.nu			= _pass.nu;
	param.p				= _pass.p;

	// Solver settings.
	param.cache_size	= _pass.cache_size;
	param.eps			= _pass.eps;
	param.shrinking		= _pass.shrinking;
	param.probability	= _pass.probability;

	// Per-class weights.
	param.nr_weight		= _pass.nr_weight;
	param.weight_label	= _pass.weight_label;
	param.weight		= _pass.weight;
}

/**
	StoreModelInfo

	Write a plain-text summary of the current parameters to
	"<modelfilename>.info".

	Layout of the file:
	  - "DidProbability" line, only when param.probability is non-zero;
	  - "SVM type:" followed by the type name and its parameters
	    (C_SVC: C; NU_SVC: NU; EPSILON_SVR: C, EPSILON; NU_SVR: C, NU);
	  - an empty line;
	  - "Kernel Type:" followed by the kernel name and its parameters
	    (POLY: Gamma, Coef0, Degree; RBF: Gamma; SIGMOID: Gamma, Coef0).
	Types and kernels not listed above get no name or parameter lines.

	modelfilename	base name of the output file; ".info" is appended.

	If the file cannot be created a message box is shown and nothing is
	written.
  */
void CSVM::StoreModelInfo(CString modelfilename)
{
	modelfilename +=".info";
	FILE *fp = fopen(modelfilename,"w");

	if (fp == NULL)
	{
		AfxMessageBox("Cannot creat info file.");
		return;
	}

	if (param.probability)
	{
		fprintf(fp,"DidProbability\n");
	}

	//////////////////////////////////////////////////////////////////////////
	fprintf(fp,"SVM type:");
	switch (param.svm_type)
	{
	case C_SVC:
		fprintf(fp,"C_SVC\n");
		fprintf(fp,"C:%lf\n",param.C);
		break;

	case NU_SVC:
		// Was labelled "NU_SVR" although the type is NU_SVC.
		fprintf(fp,"NU_SVC\n");
		fprintf(fp,"NU:%lf\n",param.nu);
		break;

	case EPSILON_SVR:
		fprintf(fp,"EPSILON_SVR\n");
		fprintf(fp,"C:%lf\n",param.C);
		fprintf(fp,"EPSILON:%lf\n",param.p);
		break;

	case NU_SVR:
		// Was tested against NU_SVC, so NU_SVR never produced output.
		fprintf(fp,"NU_SVR\n");
		fprintf(fp,"C :%lf\n",param.C);
		fprintf(fp,"NU:%lf\n",param.nu);
		break;

	default:
		break;
	}

	fprintf(fp,"\n");

	//////////////////////////////////////////////////////////////////////////
	fprintf(fp,"Kernel Type:");
	switch (param.kernel_type)
	{
	case LINEAR:
		fprintf(fp,"LINEAR\n");
		break;

	case POLY:
		fprintf(fp,"POLY\n");
		fprintf(fp,"Gamma :%lf\n",param.gamma);
		fprintf(fp,"Coef0 :%lf\n",param.coef0);
		fprintf(fp,"Degree:%d\n",param.degree);
		break;

	case RBF:
		fprintf(fp,"RBF\n");
		fprintf(fp,"Gamma :%lf\n",param.gamma);
		break;

	case SIGMOID:
		fprintf(fp,"SIGMOID\n");
		fprintf(fp,"Gamma :%lf\n",param.gamma);
		fprintf(fp,"Coef0 :%lf\n",param.coef0);
		break;

	default:
		break;
	}

	//////////////////////////////////////////////////////////////////////////

	fclose(fp);
}

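// Keyboard shortcuts
//
// Copy code: Ctrl + C
// Search code: Ctrl + F
// Full-screen mode: F11
// Increase font size: Ctrl + =
// Decrease font size: Ctrl + -
// Show shortcuts: ?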