⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 learn.cpp

📁 SMO工具箱
💻 CPP
📖 第 1 页 / 共 2 页
字号:
#include "stdafx.h"
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <math.h>
#include <time.h>
#include <limits.h>
#include <ctype.h>
#include "initializeTraining.h"
#include "utility.h"
//#include "learn.h"

/* Upper limit of the box constraint: takeStep clips multipliers with
   MIN(C, ...) and treats a value strictly between 0 and C as unbound. */
double C;
/* Threshold; calculateError subtracts it from the kernel sum and
   takeStep updates it after a successful step. */
double b;
/* Second argument passed to power() by the polynomial kernel (kernelType 1). */
int degree;
/* Kernel selector: 0 = linear, 1 = polynomial, 2 = rbf. */
int kernelType;
/* Not referenced in this part of the file; presumably the RBF width
   parameter from which rbfConstant is derived -- TODO confirm. */
double sigmaSqr;
/* Factor multiplied with the squared distance inside exp() for the
   rbf kernel: K = exp(-distance * rbfConstant). */
double rbfConstant;
/* Not referenced in this part of the file; presumably an iteration
   counter -- TODO confirm. */
double iteration;
/* Not referenced in this part of the file; presumably a running total of
   iterations -- TODO confirm. */
double totalIteration;
/* 0: dot products multiply the stored feature values.
   Non-zero: features are treated as binary, so a dot product counts matching
   ids and a self dot product is taken directly from nonZeroFeature[]. */
int binaryFeature;

/***private defines and data structures**********/
/* NOTE: MAX and MIN expand each argument twice; do not pass expressions
   with side effects. */
#define MAX(X,Y) ((X)>(Y)?(X):(Y))
#define MIN(X,Y) ((X)<(Y)?(X):(Y))
/* Not referenced in this part of the file; presumably the tolerance used
   when checking the KKT conditions -- TODO confirm. */
#define TOL 0.001
/* Smallest meaningful difference; takeStep uses it for its "equal" tests. */
#define EPS DBL_EPSILON
/* Not referenced in this part of the file. */
#define MAXNUM DBL_MAX

/* Multiplier values of the two examples as they were on entry to takeStep. */
static double lambda1;
static double lambda2;
/* Error of the first example; takeStep reads it but does not set it
   (presumably assigned by examineExample -- TODO confirm). */
static double E1;
/* Not referenced in this part of the file. */
static int unBoundPtr =-1;
/* Not referenced in this part of the file. */
static int errorPtr =-1;
/* Set by takeStep to 1 when the new multiplier of e1 (resp. e2) lies
   strictly between 0 and C, otherwise 0. */
static int unBound1 =0;
static int unBound2 =0;
/* Number of entries currently stored in nonZeroLambda[]. */
static int numNonZeroLambda =0 ;
/* Index of the last used slot in nonZeroLambda[]; incremented before each
   new entry is stored. */
static int lambdaPtr =-1;

/********declared functions****/
/* Sparse dot product of two feature-pointer arrays matched by id. */
static double dotProduct(FeaturePtr *x, int sizeX,FeaturePtr *y,int sizeY);
/* SVM output for example n minus the threshold b and minus target[n]. */
static double calculateError(int n);
/* Jointly updates the multipliers of examples e1 and e2; returns 0 when
   no step is taken. */
static int takeStep(int e1,int e2);
/* Defined outside this part of the file. */
static int examineExample(int e1);

/*
 * Sparse dot product of two feature vectors.
 *
 * x, y   : arrays of feature pointers; each element exposes ->id and ->value.
 * sizeX  : number of entries in x.
 * sizeY  : number of entries in y.
 *
 * Both arrays are walked in parallel, advancing whichever side currently
 * holds the smaller id. The walk only finds every common id if both arrays
 * are ordered by ascending id (presumably guaranteed when the examples are
 * loaded -- TODO confirm).
 *
 * When binaryFeature is 0 each common id contributes the product of the two
 * values; otherwise each common id contributes exactly 1 and ->value is never
 * read.
 *
 * Returns the accumulated sum, or 0 when either vector is empty.
 */
static double dotProduct(FeaturePtr *x, int sizeX,FeaturePtr *y,int sizeY)
{
	double sum = 0;
	int ix = 0;
	int iy = 0;

	/* An empty vector on either side leaves the loop unentered and sum at 0. */
	while (ix < sizeX && iy < sizeY) {
		int idX = x[ix]->id;
		int idY = y[iy]->id;

		if (idX < idY) {
			ix++;
		} else if (idX > idY) {
			iy++;
		} else {
			/* Same feature id present in both vectors. */
			if (binaryFeature == 0)
				sum += (x[ix]->value) * (y[iy]->value);
			else
				sum += 1;
			ix++;
			iy++;
		}
	}
	return sum;
}

/*
 * Prediction error of training example n under the current model.
 *
 * The SVM output is the sum, over every example listed in nonZeroLambda[]
 * (the examples whose multiplier is non-zero), of
 *     lambda[index] * target[index] * K(example[index], example[n])
 * where K is chosen by kernelType:
 *     0  linear      K = <xi, xn>
 *     1  polynomial  K = power(1 + <xi, xn>, degree)
 *     2  rbf         K = exp(-|xi - xn|^2 * rbfConstant)
 * Any other kernelType leaves the sum at 0.
 *
 * n : index of the example to evaluate.
 *
 * Returns (output - b - target[n]).
 */
static double calculateError(int n)
{
	double svmOut = 0;
	double interValue;
	double selfDotN;	/* <xn, xn>; only used by the rbf kernel */
	double crossDot;	/* <xi, xn> for the current support vector */
	int i,index;

	if( kernelType ==0){  //linear kernel
		for( i=0;i<numNonZeroLambda;i++){
			index =nonZeroLambda[i];
			svmOut +=lambda[index]*target[index]*dotProduct(example[index],
				nonZeroFeature[index],example[n],nonZeroFeature[n]);
		}
	}
	else if(kernelType ==1){ //polynomial kernel
		for( i=0;i<numNonZeroLambda;i++){
			index =nonZeroLambda[i];
			svmOut +=lambda[index]*target[index]*power(1+dotProduct(example[index],
				nonZeroFeature[index],example[n],nonZeroFeature[n]),degree);
		}
	}
	else if(kernelType ==2){ //rbf kernel
		/* <xn, xn> does not depend on the support vector, so compute it
		   once instead of once per loop iteration. For binary features
		   the self dot product equals the number of non-zero features. */
		if(binaryFeature ==0)
			selfDotN =dotProduct(example[n],nonZeroFeature[n],example[n],
				nonZeroFeature[n]);
		else
			selfDotN =nonZeroFeature[n];

		for( i=0;i<numNonZeroLambda;i++){
			index =nonZeroLambda[i];
			crossDot =dotProduct(example[index],nonZeroFeature[index],
				example[n],nonZeroFeature[n]);

			/* squared distance |xi - xn|^2 = <xn,xn> - 2<xi,xn> + <xi,xi> */
			if(binaryFeature ==0){
				interValue =selfDotN -2*crossDot +dotProduct(example[index],
					nonZeroFeature[index],example[index],nonZeroFeature[index]);
			}
			else{
				interValue =selfDotN -2*crossDot +nonZeroFeature[index];
			}
			svmOut +=lambda[index]*target[index]*exp(-interValue*rbfConstant);
		}
	}
	return (svmOut -b -target[n]);
}

/********jointly optimize the Lagrange multipliers of e1 and e2****/

int takeStep(int e1,int e2)
{
	double k11,k12,k22,eta;
	double a1,a2,f1,f2,L1,L2,H1,H2,Lobj,Hobj;
	int y1,y2,s,i,index,findPos,temp;
	double interValue1,interValue2;
	double E2,b1,b2,interValue,oldb;

    if( e1==e2)
		return 0;
	lambda1 = lambda[e1];
	lambda2 = lambda[e2];
	y1 =target[e1]; y2 =target[e2];
	s =y1*y2;

	if (nonBound[e2])
		E2 =error[e2];
	else
		E2 =calculateError(e2);

	if( y1 !=y2){
		L2 =MAX(0,lambda2 -lambda1);
		H2 =MIN(C,C+lambda2 -lambda1);
	}
	else{
		L2 =MAX(0,lambda1+lambda2-C);
		H2=MIN(C,lambda1+lambda2);
	}
	if(fabs(L2-H2)<EPS)  //L2 Equals H2
		return 0;
	if (kernelType == 0){ //linear
		if(binaryFeature ==0){
			k11 =dotProduct(example[e1],nonZeroFeature[e1],example[e1],
				nonZeroFeature[e1]);
			k22 =dotProduct(example[e2],nonZeroFeature[e2],example[e2],
				nonZeroFeature[e2]);
		}
		else {
			k11 =nonZeroFeature[e1];
			k22 =nonZeroFeature[e2];
		}
		k12 =dotProduct(example[e1],nonZeroFeature[e1],example[e2],
			nonZeroFeature[e2]);
	}
	else if(kernelType ==1){ //polynomial
         if(binaryFeature ==0){
			k11 =power(1+dotProduct(example[e1],nonZeroFeature[e1],example[e1],
				nonZeroFeature[e1]),degree);
			k22 =power(1+dotProduct(example[e2],nonZeroFeature[e2],example[e2],
				nonZeroFeature[e2]),degree);
		}
		else {
			k11 =power(1+nonZeroFeature[e1],degree);
			k22 =power(1+nonZeroFeature[e2],degree);
		}
		k12 =power(1+dotProduct(example[e1],nonZeroFeature[e1],example[e2],
			nonZeroFeature[e2]),degree);
	}
	else if(kernelType ==2){ //rbf
		k11 =1;
		k22 =1;
		if(binaryFeature ==0){
			interValue =dotProduct(example[e1],nonZeroFeature[e1],example[e1],
				nonZeroFeature[e1]) -2*dotProduct(example[e1],nonZeroFeature[e1],example[e2],
				nonZeroFeature[e2]) +dotProduct(example[e2],nonZeroFeature[e2],example[e2],
				nonZeroFeature[e2]);
		}
		else{
			interValue =nonZeroFeature[e1] -2*dotProduct(example[e1],nonZeroFeature[e1],example[e2]
				,nonZeroFeature[e2])+nonZeroFeature[e2];
		}
		k12 =exp(-interValue*rbfConstant);
	}
	eta =2*k12 -k11-k22;
	if(eta<0){
		a2 =lambda2 -y2*(E1-E2)/eta;

		//constrain a2 to within box
		if(a2<L2)
			a2=L2;
		else if(a2>H2)
			a2 =H2;
	}
	else {
		L1 =lambda1 +s*(lambda2 -L2);
		H1 =lambda1 +s*(lambda2 -H2);
		f1 =y1*(E1+b) -lambda1*k11 -s*lambda2*k12;
		f2 =y2*(E2+b) -lambda2*k22 -s*lambda1*k12;
		Lobj =-0.5*L1*L1*k11 -0.5*L2*L2*k22 -s*L1*L2*k12 -L1*f1 -L2*f2;
	    Hobj =-0.5*H1*H1*k11 -0.5*H2*H2*k22 -s*H1*H2*k12 -H1*f1 -H2*f2;
		if(Lobj>Hobj+EPS)
			a2 =L2;
		else if(Lobj <Hobj -EPS)
			a2 =H2;
		else 
			a2 =lambda2;
	}
	if( fabs(a2 -lambda2)<EPS*(a2+lambda2+EPS))
		return 0;

	/*****find the new lambda1*****/
	a1 =lambda1 +s*(lambda2 -a2);
	if( a1<0)
		a1 =0;
	/*******check e1,e2 for unbound lamdas*******/
	if( a1>0 && a1<C)
		unBound1 =1;
	else 
		unBound1 =0;
    if( a2>0 && a2<C)
		unBound2 =1;
	else 
		unBound2 =0;

	/*********update the number of non-zero lambda****/
	if (a1 >0){
		if( numNonZeroLambda ==0){
			lambdaPtr++;
			nonZeroLambda[lambdaPtr] =e1;
			numNonZeroLambda++;
		}
		else if(numNonZeroLambda ==1 &&nonZeroLambda[0]!=e1){
			lambdaPtr++;
			nonZeroLambda[lambdaPtr] =e1;
			numNonZeroLambda++;
			if(e1 <nonZeroLambda[0]){
				temp =e1;
				nonZeroLambda[1] =nonZeroLambda[0];
				nonZeroLambda[0] =e1;
			}
		}
		else if( numNonZeroLambda >1){
			if( binSearch(e1,nonZeroLambda,numNonZeroLambda) == -1){
				lambdaPtr++;
				nonZeroLambda[lambdaPtr] =e1;
				numNonZeroLambda++;
				quicksort(nonZeroLambda,0,lambdaPtr);
			}
		}
	}
	if (a2>0){
		if(numNonZeroLambda ==0){
            lambdaPtr++;
			nonZeroLambda[lambdaPtr] =e2;
			numNonZeroLambda++;
		}
		else if(numNonZeroLambda ==1 &&nonZeroLambda[0]!=e2){
			lambdaPtr++;
			nonZeroLambda[lambdaPtr] =e2;
			numNonZeroLambda++;
			if(e2 <nonZeroLambda[0]){
				temp =e2;
				nonZeroLambda[1] =nonZeroLambda[0];
				nonZeroLambda[0] =e2;
			}
		}
		else if( numNonZeroLambda >1){
			if( binSearch(e1,nonZeroLambda,numNonZeroLambda) == -1){
				lambdaPtr++;
				nonZeroLambda[lambdaPtr] =e2;
				numNonZeroLambda++;
				quicksort(nonZeroLambda,0,lambdaPtr);
			}
		}
	}
    /*****update the threshold b********/
	oldb =b;
	if(kernelType ==0){
		if(binaryFeature ==0){
			b1 =E1 +y1*(a1-lambda1)*dotProduct(example[e1],nonZeroFeature[e1],
				example[e1],nonZeroFeature[e1])+y2*(a2-lambda2)*dotProduct(example[e1],
				nonZeroFeature[e1],example[e2],nonZeroFeature[e2]) +oldb;
			
			b2 =E2 +y1*(a1-lambda1)*dotProduct(example[e1],nonZeroFeature[e1],
				example[e2],nonZeroFeature[e2])+y2*(a2-lambda2)*dotProduct(example[e2],
				nonZeroFeature[e2],example[e2],nonZeroFeature[e2]) +oldb;
		}
		else{
			b1 =E1 +y1*(a1 -lambda1)*nonZeroFeature[e1]+y2*(a2-lambda2)*dotProduct(example[e1],nonZeroFeature[e1],
				example[e2],nonZeroFeature[e2]) +oldb;
           	b2 =E2 +y1*(a1 -lambda1)*dotProduct(example[e1],nonZeroFeature[e1],
				example[e2],nonZeroFeature[e2])+y2*(a2-lambda2)*nonZeroFeature[e2] +oldb;
		}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -