📄 learn.cpp
字号:
#include "stdafx.h"
#include <stdio.h>
#include <stdlib.h>
#include <float.h>
#include <math.h>
#include <time.h>
#include <limits.h>
#include <ctype.h>
#include "initializeTraining.h"
#include "utility.h"
//#include "learn.h"
/*
 * Training parameters and shared state with external linkage.
 * NOTE(review): presumably set/read by other translation units (e.g. the code
 * behind initializeTraining.h) -- confirm before making any of these static.
 */
double C;               /* upper bound of the box 0 <= lambda <= C enforced in takeStep */
double b;               /* threshold: calculateError returns svmOut - b - target[n]; takeStep updates it */
int degree;             /* exponent passed to power() for the polynomial kernel (kernelType == 1) */
int kernelType;         /* 0 = linear, 1 = polynomial, 2 = rbf */
double sigmaSqr;        /* not used in this section; presumably the RBF width sigma^2 -- TODO confirm */
double rbfConstant;     /* rbf kernel is exp(-rbfConstant * ||u - v||^2); presumably 1/(2*sigmaSqr) -- TODO confirm */
double iteration;       /* not used in this section; presumably an iteration counter -- TODO confirm */
double totalIteration;  /* not used in this section -- TODO confirm meaning */
int binaryFeature;      /* 0: dotProduct uses stored feature values; non-zero: each shared feature counts as 1 */
/* File-private helpers, tolerances and SMO working state. */
#define MAX(X,Y) ((X)>(Y)?(X):(Y))   /* NOTE: evaluates the selected argument twice -- pass side-effect-free expressions */
#define MIN(X,Y) ((X)<(Y)?(X):(Y))   /* NOTE: same double-evaluation caveat as MAX */
#define TOL 0.001                    /* not used in this section; presumably the KKT tolerance -- TODO confirm */
#define EPS DBL_EPSILON              /* tolerance for the "equal" comparisons in takeStep */
#define MAXNUM DBL_MAX               /* not used in this section */
static double lambda1;               /* lambda[e1] as it was on entry to takeStep */
static double lambda2;               /* lambda[e2] as it was on entry to takeStep */
static double E1;                    /* error of example e1 read by takeStep; NOTE(review): never assigned in this section -- presumably set by examineExample before calling takeStep */
static int unBoundPtr =-1;           /* not used in this section */
static int errorPtr =-1;             /* not used in this section */
static int unBound1 =0;              /* set by takeStep: 1 if the new multiplier for e1 lies strictly inside (0, C) */
static int unBound2 =0;              /* set by takeStep: 1 if the new multiplier for e2 lies strictly inside (0, C) */
static int numNonZeroLambda =0 ;     /* number of valid entries in nonZeroLambda[]; takeStep keeps them sorted ascending */
static int lambdaPtr =-1;            /* index of the last filled slot of nonZeroLambda[] (-1 when empty); incremented together with numNonZeroLambda */
/* Forward declarations of the file-local routines. */
static double dotProduct(FeaturePtr *x, int sizeX,FeaturePtr *y,int sizeY);
static double calculateError(int n);
static int takeStep(int e1,int e2);
static int examineExample(int e1);
/*
 * Sparse dot product of two feature vectors.
 *
 * x holds sizeX feature pointers and y holds sizeY; the two arrays are walked
 * in a single merge pass keyed on ->id.  That pass only finds every shared id
 * if both arrays are sorted by ascending id (presumably guaranteed by the
 * loader -- TODO confirm).
 *
 * For each id present in both vectors the sum gains:
 *   - x->value * y->value   when binaryFeature == 0
 *   - 1                     otherwise (result = number of shared ids)
 *
 * Returns 0 when either vector is empty.
 */
static double dotProduct(FeaturePtr *x, int sizeX,FeaturePtr *y,int sizeY)
{
    int ix = 0;
    int iy = 0;
    int weighted;
    double sum = 0;

    if (sizeX == 0 || sizeY == 0)
        return 0;

    weighted = (binaryFeature == 0);

    while (ix < sizeX && iy < sizeY) {
        int idX = x[ix]->id;
        int idY = y[iy]->id;

        if (idX < idY) {
            ix++;                       /* x is behind: advance it */
        } else if (idX > idY) {
            iy++;                       /* y is behind: advance it */
        } else {
            /* shared feature id: accumulate, then advance both cursors */
            sum += weighted ? (x[ix]->value) * (y[iy]->value) : 1;
            ix++;
            iy++;
        }
    }
    return sum;
}
/*
 * Returns the prediction error E_n = u_n - target[n] for training example n,
 * where
 *     u_n = sum over k in nonZeroLambda[0 .. numNonZeroLambda-1] of
 *           lambda[k] * target[k] * K(example[k], example[n])   -   b
 *
 * The kernel K is selected by the global kernelType:
 *     0  linear      K(u,v) = <u,v>
 *     1  polynomial  K(u,v) = power(1 + <u,v>, degree)
 *     2  rbf         K(u,v) = exp(-rbfConstant * ||u - v||^2)
 * For any other kernelType nothing is summed and the result is -b - target[n].
 *
 * <u,v> is dotProduct(), which already honours binaryFeature.  For rbf the
 * squared distance is expanded as <u,u> - 2<u,v> + <v,v>; with binary
 * features <u,u> is taken directly from nonZeroFeature[] of that example.
 */
static double calculateError(int n)
{
    double svmOut = 0;
    int i, index;

    if (kernelType == 0) {                  /* linear kernel */
        for (i = 0; i < numNonZeroLambda; i++) {
            index = nonZeroLambda[i];
            svmOut += lambda[index] * target[index] *
                      dotProduct(example[index], nonZeroFeature[index],
                                 example[n], nonZeroFeature[n]);
        }
    }
    else if (kernelType == 1) {             /* polynomial kernel */
        for (i = 0; i < numNonZeroLambda; i++) {
            index = nonZeroLambda[i];
            svmOut += lambda[index] * target[index] *
                      power(1 + dotProduct(example[index], nonZeroFeature[index],
                                           example[n], nonZeroFeature[n]),
                            degree);
        }
    }
    else if (kernelType == 2) {             /* rbf kernel */
        /* <x_n, x_n> does not depend on k: compute it once, outside the loop. */
        double selfN;
        if (binaryFeature == 0)
            selfN = dotProduct(example[n], nonZeroFeature[n],
                               example[n], nonZeroFeature[n]);
        else
            selfN = nonZeroFeature[n];

        for (i = 0; i < numNonZeroLambda; i++) {
            double selfK, distSqr;
            index = nonZeroLambda[i];
            if (binaryFeature == 0)
                selfK = dotProduct(example[index], nonZeroFeature[index],
                                   example[index], nonZeroFeature[index]);
            else
                selfK = nonZeroFeature[index];
            /* ||x_k - x_n||^2 = <x_n,x_n> - 2<x_k,x_n> + <x_k,x_k> */
            distSqr = selfN
                    - 2 * dotProduct(example[index], nonZeroFeature[index],
                                     example[n], nonZeroFeature[n])
                    + selfK;
            svmOut += lambda[index] * target[index] * exp(-distSqr * rbfConstant);
        }
    }
    return svmOut - b - target[n];
}
/******** jointly optimize the Lagrange multipliers of examples e1 and e2 ****/
int takeStep(int e1,int e2)
{
double k11,k12,k22,eta;
double a1,a2,f1,f2,L1,L2,H1,H2,Lobj,Hobj;
int y1,y2,s,i,index,findPos,temp;
double interValue1,interValue2;
double E2,b1,b2,interValue,oldb;
if( e1==e2)
return 0;
lambda1 = lambda[e1];
lambda2 = lambda[e2];
y1 =target[e1]; y2 =target[e2];
s =y1*y2;
if (nonBound[e2])
E2 =error[e2];
else
E2 =calculateError(e2);
if( y1 !=y2){
L2 =MAX(0,lambda2 -lambda1);
H2 =MIN(C,C+lambda2 -lambda1);
}
else{
L2 =MAX(0,lambda1+lambda2-C);
H2=MIN(C,lambda1+lambda2);
}
if(fabs(L2-H2)<EPS) //L2 Equals H2
return 0;
if (kernelType == 0){ //linear
if(binaryFeature ==0){
k11 =dotProduct(example[e1],nonZeroFeature[e1],example[e1],
nonZeroFeature[e1]);
k22 =dotProduct(example[e2],nonZeroFeature[e2],example[e2],
nonZeroFeature[e2]);
}
else {
k11 =nonZeroFeature[e1];
k22 =nonZeroFeature[e2];
}
k12 =dotProduct(example[e1],nonZeroFeature[e1],example[e2],
nonZeroFeature[e2]);
}
else if(kernelType ==1){ //polynomial
if(binaryFeature ==0){
k11 =power(1+dotProduct(example[e1],nonZeroFeature[e1],example[e1],
nonZeroFeature[e1]),degree);
k22 =power(1+dotProduct(example[e2],nonZeroFeature[e2],example[e2],
nonZeroFeature[e2]),degree);
}
else {
k11 =power(1+nonZeroFeature[e1],degree);
k22 =power(1+nonZeroFeature[e2],degree);
}
k12 =power(1+dotProduct(example[e1],nonZeroFeature[e1],example[e2],
nonZeroFeature[e2]),degree);
}
else if(kernelType ==2){ //rbf
k11 =1;
k22 =1;
if(binaryFeature ==0){
interValue =dotProduct(example[e1],nonZeroFeature[e1],example[e1],
nonZeroFeature[e1]) -2*dotProduct(example[e1],nonZeroFeature[e1],example[e2],
nonZeroFeature[e2]) +dotProduct(example[e2],nonZeroFeature[e2],example[e2],
nonZeroFeature[e2]);
}
else{
interValue =nonZeroFeature[e1] -2*dotProduct(example[e1],nonZeroFeature[e1],example[e2]
,nonZeroFeature[e2])+nonZeroFeature[e2];
}
k12 =exp(-interValue*rbfConstant);
}
eta =2*k12 -k11-k22;
if(eta<0){
a2 =lambda2 -y2*(E1-E2)/eta;
//constrain a2 to within box
if(a2<L2)
a2=L2;
else if(a2>H2)
a2 =H2;
}
else {
L1 =lambda1 +s*(lambda2 -L2);
H1 =lambda1 +s*(lambda2 -H2);
f1 =y1*(E1+b) -lambda1*k11 -s*lambda2*k12;
f2 =y2*(E2+b) -lambda2*k22 -s*lambda1*k12;
Lobj =-0.5*L1*L1*k11 -0.5*L2*L2*k22 -s*L1*L2*k12 -L1*f1 -L2*f2;
Hobj =-0.5*H1*H1*k11 -0.5*H2*H2*k22 -s*H1*H2*k12 -H1*f1 -H2*f2;
if(Lobj>Hobj+EPS)
a2 =L2;
else if(Lobj <Hobj -EPS)
a2 =H2;
else
a2 =lambda2;
}
if( fabs(a2 -lambda2)<EPS*(a2+lambda2+EPS))
return 0;
/*****find the new lambda1*****/
a1 =lambda1 +s*(lambda2 -a2);
if( a1<0)
a1 =0;
/*******check e1,e2 for unbound lamdas*******/
if( a1>0 && a1<C)
unBound1 =1;
else
unBound1 =0;
if( a2>0 && a2<C)
unBound2 =1;
else
unBound2 =0;
/*********update the number of non-zero lambda****/
if (a1 >0){
if( numNonZeroLambda ==0){
lambdaPtr++;
nonZeroLambda[lambdaPtr] =e1;
numNonZeroLambda++;
}
else if(numNonZeroLambda ==1 &&nonZeroLambda[0]!=e1){
lambdaPtr++;
nonZeroLambda[lambdaPtr] =e1;
numNonZeroLambda++;
if(e1 <nonZeroLambda[0]){
temp =e1;
nonZeroLambda[1] =nonZeroLambda[0];
nonZeroLambda[0] =e1;
}
}
else if( numNonZeroLambda >1){
if( binSearch(e1,nonZeroLambda,numNonZeroLambda) == -1){
lambdaPtr++;
nonZeroLambda[lambdaPtr] =e1;
numNonZeroLambda++;
quicksort(nonZeroLambda,0,lambdaPtr);
}
}
}
if (a2>0){
if(numNonZeroLambda ==0){
lambdaPtr++;
nonZeroLambda[lambdaPtr] =e2;
numNonZeroLambda++;
}
else if(numNonZeroLambda ==1 &&nonZeroLambda[0]!=e2){
lambdaPtr++;
nonZeroLambda[lambdaPtr] =e2;
numNonZeroLambda++;
if(e2 <nonZeroLambda[0]){
temp =e2;
nonZeroLambda[1] =nonZeroLambda[0];
nonZeroLambda[0] =e2;
}
}
else if( numNonZeroLambda >1){
if( binSearch(e1,nonZeroLambda,numNonZeroLambda) == -1){
lambdaPtr++;
nonZeroLambda[lambdaPtr] =e2;
numNonZeroLambda++;
quicksort(nonZeroLambda,0,lambdaPtr);
}
}
}
/*****update the threshold b********/
oldb =b;
if(kernelType ==0){
if(binaryFeature ==0){
b1 =E1 +y1*(a1-lambda1)*dotProduct(example[e1],nonZeroFeature[e1],
example[e1],nonZeroFeature[e1])+y2*(a2-lambda2)*dotProduct(example[e1],
nonZeroFeature[e1],example[e2],nonZeroFeature[e2]) +oldb;
b2 =E2 +y1*(a1-lambda1)*dotProduct(example[e1],nonZeroFeature[e1],
example[e2],nonZeroFeature[e2])+y2*(a2-lambda2)*dotProduct(example[e2],
nonZeroFeature[e2],example[e2],nonZeroFeature[e2]) +oldb;
}
else{
b1 =E1 +y1*(a1 -lambda1)*nonZeroFeature[e1]+y2*(a2-lambda2)*dotProduct(example[e1],nonZeroFeature[e1],
example[e2],nonZeroFeature[e2]) +oldb;
b2 =E2 +y1*(a1 -lambda1)*dotProduct(example[e1],nonZeroFeature[e1],
example[e2],nonZeroFeature[e2])+y2*(a2-lambda2)*nonZeroFeature[e2] +oldb;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -