⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 classify.cpp

📁 SMO工具箱
💻 CPP
字号:
#include <StdAfx.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <math.h>
#include "classify.h"
#include "initialize.h"


/* Forward declarations of the file-local (static) helpers defined below. */
static double dotProduct(FeaturePtr *x,int sizeX, FeaturePtr *y,int sizeY);
static double power(double x ,int n);
static double wtDotProduct(double *w,int sizeX,FeaturePtr *y,int sizeY);
static void freeMemory(void);

/*****private function*****/
/*
 * Sparse dot product of two feature lists x[0..sizeX-1] and y[0..sizeY-1].
 *
 * Both lists are walked in parallel: the side with the smaller id advances,
 * and a product value*value is accumulated only where the ids coincide.
 * This merge walk is only a true dot product when both lists are sorted by
 * ascending id — NOTE(review): assumed to be guaranteed by the loader; confirm.
 * Returns 0 when either list is empty.
 */
static double dotProduct(FeaturePtr *x,int sizeX,FeaturePtr *y,int sizeY)
{
	int i = 0;
	int j = 0;
	double sum = 0;

	for (; i < sizeX && j < sizeY; ) {
		int idX = x[i]->id;
		int idY = y[j]->id;

		if (idX < idY) {
			i++;
		} else if (idX > idY) {
			j++;
		} else {
			sum += (x[i]->value) * (y[j]->value);
			i++;
			j++;
		}
	}
	return sum;
}


/*
 * Dot product of a dense weight vector with a sparse feature list.
 *
 * w    : dense weights, indexed from 1 to sizeX (w[0] is never read).
 * y    : sparse features y[0..sizeY-1], expected in ascending id order.
 * Returns sum of w[id] * value over entries whose id lies in [1, sizeX].
 *
 * The original implementation advanced a weight cursor one index at a time,
 * costing O(sizeX + sizeY) per call (sizeX is typically maxFeature).
 * Indexing w[id] directly makes it O(sizeY). The 'next' cursor reproduces
 * the old merge semantics exactly, including for malformed input:
 * - an id smaller than the last matched id (out of order or duplicate) is skipped;
 * - ids <= 0 are skipped;
 * - the first id beyond sizeX ends the scan.
 */
static double wtDotProduct(double *w,int sizeX,FeaturePtr *y,int sizeY)
{
	int next = 1;   /* smallest weight index still eligible to match */
	int k;
	double dot = 0;

	if (sizeX < 1 || sizeY < 1)
		return 0;

	for (k = 0; k < sizeY; k++) {
		int id = y[k]->id;

		if (id > sizeX)
			break;      /* past the last weight: nothing further can match */
		if (id < next)
			continue;   /* out-of-order, duplicate or non-positive id */
		dot += (w[id]) * (y[k]->value);
		next = id + 1;
	}
	return dot;
}


/*
 * x raised to the integer power n, by n successive multiplications.
 * Any n <= 0 yields 1.0 (negative exponents are not inverted).
 */
static double power(double x ,int n)
{
	double result = 1.0;
	int remaining;

	/* Check before decrementing so n == INT_MIN cannot overflow. */
	for (remaining = n; remaining > 0; remaining--)
		result *= x;

	return result;
}

/*
 * Release the file-scope buffers lambda, svNonZeroFeature, nonZeroFeature,
 * target, weight and output.
 *
 * Each pointer is reset to NULL after being freed. Since free(NULL) is a
 * no-op, a repeated call cannot double-free, and a stray later access
 * fails fast instead of touching freed memory.
 * NOTE(review): example and sv are not released here, and this function is
 * not called anywhere in this file — confirm where ownership of those lies.
 */
static void freeMemory(void)
{
	free(lambda);
	lambda = NULL;
	free(svNonZeroFeature);
	svNonZeroFeature = NULL;
	free(nonZeroFeature);
	nonZeroFeature = NULL;
	free(target);
	target = NULL;
	free(weight);
	weight = NULL;
	free(output);
	output = NULL;
}



/******public function*****************/
/*
 * Classify the test set with the kernel selected by the global kernelType
 * and write the per-example results and accuracy to out:
 *   0 -> classifyLinear, 1 -> classifyPoly, 2 -> classifyRbf.
 *
 * Returns 1 if the selected classifier reported success, and 0 otherwise.
 * An unrecognised kernelType also returns 0 without classifying anything.
 * Previously 'result' was left uninitialised for such values, and reading
 * it was undefined behaviour.
 */
int writeResult(FILE *out)
{
	int result = 0;

	if (kernelType == 0)
		result = classifyLinear(out);
	else if (kernelType == 1)
		result = classifyPoly(out);
	else if (kernelType == 2)
		result = classifyRbf(out);
	else
		result = 0;   /* unknown kernel type: nothing was classified */

	return result ? 1 : 0;
}

int classifyLinear(FILE *out)
{
	int i,numCorrect=0;
	double startTime,svmOutput;
	
	printf("Start classifying ...\n");
	startTime =clock()/CLOCKS_PER_SEC;
	for(i =1;i<=numExample;i++)
		output[i] =wtDotProduct(weight,maxFeature,example[i],nonZeroFeature[i]);

	printf("Classifying time is %f seconds\n",clock()/CLOCKS_PER_SEC -startTime);
	printf("Finishing classifying.\n");
	for(i=1;i<=numExample;i++){
		fprintf(out,"target[%d] =%d; SVM output is Example[%d] =%6.5f\n",i,target[i],i,output[i]-b);
	}

	//compute the accuracy 
	for(i=1;i<=numExample;i++){
		   svmOutput =output[i]-b;
          if ( (svmOutput>=0 && target[i]>=0) || (svmOutput<0 && target[i]<0))
			  numCorrect++;
		  }
	fprintf(out,"test example number is %d, correct classification number is %d.\n",numExample,numCorrect);
	fprintf(out,"accuracy is %f\n",(float)numCorrect/(float)numExample);

	//
	return 1;
}


int classifyPoly(FILE *out)
{
	int i,j,numCorrect=0;
    double startTime,svmOutput;
		
	printf("Start classifying ...\n");
	startTime =clock()/CLOCKS_PER_SEC;
	for(i =1;i<=numExample;i++){
		output[i]=0;
		for(j=1;j<=numSv;j++)
			output[i] +=lambda[j]*power(1+dotProduct(sv[j],svNonZeroFeature[j],
				example[i],nonZeroFeature[i]),degree);
	}
	printf("Classifying time is %f seconds\n",clock()/CLOCKS_PER_SEC -startTime);
	printf("Finishing classifying.\n");
	for(i=1;i<=numExample;i++)
		fprintf(out,"target[%d] =%d; SVM output is Example[%d] =%6.5f\n",i,target[i],i,output[i]-b);

	for(i=1;i<=numExample;i++){
		   svmOutput =output[i]-b;
          if ( (svmOutput>=0 && target[i]>=0) || (svmOutput<0 && target[i]<0))
			  numCorrect++;
		  }
	fprintf(out,"test example number is %d, correct classification number is %d.\n",numExample,numCorrect);
	fprintf(out,"accuracy is %f\n",(float)numCorrect/(float)numExample);
	return 1;

}

int classifyRbf(FILE *out)
{
	int i,j,numCorrect=0;
	double devSqr;
	double startTime,svmOutput;

	printf("Start classifying ...\n");
	startTime =clock()/CLOCKS_PER_SEC;
	for(i =1;i<=numExample;i++){
		output[i]=0;
		for(j=1;j<=numSv;j++){
			devSqr = dotProduct(sv[j],svNonZeroFeature[j],sv[j],svNonZeroFeature[j])
				-2*dotProduct(sv[j],svNonZeroFeature[j],example[i],nonZeroFeature[i])
				+dotProduct(example[i],nonZeroFeature[i],example[i],nonZeroFeature[i]);
			output[i] +=lambda[j]*exp(-devSqr*rbfConstant);
		}
	}

	printf("Classifying time is %f seconds\n",clock()/CLOCKS_PER_SEC -startTime);
	printf("Finishing classifying.\n");
	for(i=1;i<=numExample;i++)
		fprintf(out,"target[%d] =%d; SVM output is Example[%d] =%6.5f\n",i,target[i],i,output[i]-b);

		for(i=1;i<=numExample;i++){
		   svmOutput =output[i]-b;
          if ( (svmOutput>=0 && target[i]>=0) || (svmOutput<0 && target[i]<0))
			  numCorrect++;
		  }
	fprintf(out,"test example number is %d, correct classification number is %d.\n",numExample,numCorrect);
	fprintf(out,"accuracy is %f\n",(float)numCorrect/(float)numExample);
	return 1;


}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -