📄 ginisvm.cpp
字号:
// Looks for a KKT violation among the per-class coefficients of data point
// `dataind`.
//
// Pass 1 computes the reference level z as the mean of (-E - bias[c]) over
//   * classes with a stored coefficient whose alpha is below
//     C[dataind]*Y[dataind][c] - alphaeps (the stored E is used), and
//   * classes without a stored coefficient whose Y[dataind][c] exceeds
//     alphaeps (E comes from evaluateCache).
// If no class qualifies, z falls back to 2*rdist/classes.
//
// Pass 2 walks all classes once, starting from a random class, and returns at
// the first class whose beta lies outside the kkteps tolerance.
//
// On GINI_TRUE: *decision holds the class index, *grad the E value that was
// tested and *direc either GINI_UP_DOWN or GINI_DOWN.
// On GINI_FALSE: *decision and *grad are untouched, but *direc may have been
// overwritten with GINI_DOWN, because coefficients at their bound set it
// before being tested. Those coefficients also get E refreshed from
// evaluateCache and their shrinklevel updated.
GINI_bool GINI_SVMBlock::_kktcondition( GINI_u32 dataind, GINI_u32* decision, GINI_double* grad, GINI_status* direc )
{
    GINI_Set* const noentry = (GINI_Set*) GINI_NULL;

    // ---- Pass 1: reference level z -------------------------------------
    GINI_double z = 0.0;
    GINI_u32 contributors = 0;

    for ( GINI_u32 cls = 0; cls < classes; cls++ )
    {
        GINI_Set* entry = svmap[dataind][cls];

        if ( entry == noentry )
        {
            // No stored coefficient: only classes with a non-zero label count.
            if ( Y[dataind][cls] > alphaeps )
            {
                const GINI_double e = evaluateCache( dataind, cls );
                z += - e - bias[cls];
                contributors++;
            }
            continue;
        }

        // Stored coefficient: counts only while strictly below its bound.
        if ( entry->alpha < C[dataind]*Y[dataind][cls] - alphaeps )
        {
            z += - entry->E - bias[cls];
            contributors++;
        }
    }

    if ( contributors != 0 )
    {
        z /= contributors;
    }
    else
    {
        z = 2*rdist/classes;
    }

    // ---- Pass 2: find a violating class --------------------------------
    // The scan starts at a random class so that classes which have no
    // corresponding data points do not systematically bias the choice.
    const GINI_u32 first = rand()%classes;

    for ( GINI_u32 step = 0; step < classes; step++ )
    {
        const GINI_u32 cls = (first+step)%classes;
        GINI_Set* entry = svmap[dataind][cls];

        if ( entry == noentry )
        {
            // Coefficient is implicitly zero.
            const GINI_double e = evaluateCache( dataind, cls );
            const GINI_double beta = 2*rdist*Y[dataind][cls] - e - bias[cls] - z;
            const bool labelled = ( Y[dataind][cls] > alphaeps );

            // A labelled class may violate in either direction; an
            // unlabelled one only when beta is below -kkteps.
            if ( ( beta < -kkteps ) || ( labelled && ( beta > kkteps ) ) )
            {
                *decision = cls;
                *grad = e;
                if ( labelled )
                {
                    *direc = GINI_UP_DOWN;
                }
                else
                {
                    *direc = GINI_DOWN;
                }
                return GINI_TRUE;
            }
            continue;
        }

        if ( entry->alpha < C[dataind]*Y[dataind][cls] - alphaeps )
        {
            // Coefficient strictly below its bound: beta must stay within
            // +/- kkteps.
            const GINI_double beta = - entry->E - bias[cls] - z;
            if ( ( beta < -kkteps ) || ( beta > kkteps ) )
            {
                *decision = cls;
                *grad = entry->E;
                *direc = GINI_UP_DOWN;
                return GINI_TRUE;
            }
            continue;
        }

        // Coefficient has reached its bound: refresh E and test the one
        // remaining direction. *direc is written before the test.
        entry->E = evaluateCache( dataind, cls );
        const GINI_double beta = - entry->E - bias[cls] - z;
        *direc = GINI_DOWN;
        if ( beta < -kkteps )
        {
            *decision = cls;
            *grad = entry->E;
            // Violating coefficient: shrink level 0.
            entry->shrinklevel = 0;
            return GINI_TRUE;
        }

        // At the bound with a non-violating beta: shrink level 1.
        entry->shrinklevel = 1;
    }

    return GINI_FALSE;
}

/*****************************************************************************/
// FUNCTION     : _kktsvcondition
//
// DESCRIPTION  : Computes the KKT condition for a data point and returns
//                true if a KKT violation is found, otherwise false.
//                We are going to use the theorem that for GiniSVM, for
//                each data point there exists one unbounded variable.
//                One has to be careful here because if there exist
//                a class where there are no corresponding labels then
//                this algorithm might not progress or the progress
//                becomes very slow.
Therefore the point and the class// is chosen which violates the KKT condition the most.//// INPUT ://// OUTPUT ://///*****************************************************************************/GINI_bool GINI_SVMBlock::_kktsvcondition( GINI_u32 dataind, GINI_u32* decision, GINI_double* grad, GINI_status* direc ){ GINI_u32 i,j,startpoint,count; GINI_double betaval,z; GINI_Set *currptr; GINI_bool found = GINI_FALSE; z = 0.0; count = 0; for ( i = 0; i < classes; i++ ) { if ( (currptr = svmap[dataind][i]) != (GINI_Set*) GINI_NULL ) { if ( currptr->alpha < C[dataind]*Y[dataind][i] - alphaeps ) { z += - currptr->E - bias[i]; count++; } } } if (count == 0) { return GINI_FALSE; } else { z /= count; } // Randomly select one of the classes that violates the kkt // condition. This prevents unneccessary bias when there // exist classes which have no corresponding data points // //GINI_double minbeta = -1; found = GINI_FALSE; startpoint = rand()%classes; for ( j = 0; j < classes; j++ ) { i = (startpoint+j)%classes; if ( (currptr = svmap[dataind][i]) != (GINI_Set*) GINI_NULL ) { if ( currptr->alpha < C[dataind]*Y[dataind][i] - alphaeps ) { betaval = - currptr->E - bias[i] - z; if (( betaval < -kkteps ) || ( betaval > kkteps )) { //printf("KKT Violation for (%d,%d), alpha=%3.7f, Y=%1.1f, beta=%3.7f, z=%2.4f\n",dataind,i,currptr->alpha,Y[dataind][i],betaval,z); *decision = i; *grad = currptr->E; *direc = GINI_UP_DOWN; return GINI_TRUE; } } } } return found;} /*****************************************************************************/// FUNCTION :// // DESCRIPTION ://// INPUT ://// OUTPUT ://///*****************************************************************************/GINI_bool GINI_SVMBlock::StartTraining( GINI_bool precomp , GINI_u32 iter, GINI_bool verbose ){ // flag the training flag mode = GINISVM_BLKTRN; GINI_u32 numchanged = 0; GINI_u32 examineAll = 1; GINI_u32 i,classcount; GINI_double currentcost,previouscost; GINI_bool costflag = GINI_FALSE; GINI_u32 totalphase; 
struct timeval tv; struct timezone tz; struct timeval tv1; struct timezone tz1; (void) gettimeofday(&tv1,&tz1); // Number of truncations floorcount = 0; classcount = 0; globalptr = (GINI_Set*) GINI_NULL; // If precompute is true then the block computes all the // kernel values for making the implementation faster. if ( precomp == GINI_TRUE ) { kernel->ComputeAll(traindata, dimension); } previouscost = CostFunction(); currentcost = previouscost; printf("Starting Cost = %f\n",previouscost); fflush(stdout); // This is the outer loop for selecting the // first coefficient. // While number of iterations is less than specified // or the total number of decrease in steps is greater // or we have to examine all the data points. while (( iterations < iter ) && (( numchanged > 0 ) || (examineAll)) ) { numchanged = 0; if ( examineAll ) { printf("\nIteration = %d\n",iterations); numofdel = 0; numofadd = 0; phase1 = 0; phase2 = 0; phase3 = 0; phase4 = 0; for ( i = 0; i < totaldata ; i++ ) { (void) gettimeofday(&tv,&tz); numchanged += _examineExample(i); timer2 += currtimeval(&tv); if ( numchanged%100 == 0 ) { if ( verbose == GINI_TRUE) { totalphase = phase1+phase2+phase3+phase4+1; printf("Number changed = %5d, #add = %5d , #del = %5d, #kswap = %5d, #phase1 = %3.0f, #phase2 = %3.0f, #phase3 = %3.0f, #phase4 = %3.0f\n", numchanged,numofadd,numofdel,kernel->Swapcount(),(GINI_float)phase1/totalphase*100,(GINI_float)phase2/totalphase*100, (GINI_float)phase3/totalphase*100,(GINI_float)phase4/totalphase*100); fflush(stdout); printtimers(); } else { printf("#"); fflush(stdout); } } } totalphase = phase1+phase2+phase3+phase4+1; if ( (GINI_float)phase4/totalphase*100 > (GINI_float)kktiter) { kkteps *= 2; printf("\nIncreasing the KKT tolerance to %f\n",kkteps); } } else { if ( iterations < fpass ) { //printf("Examining only svsets\n"); // Iterate over sv sets for ( i = 0; i < classes; i++ ) { // Now iterate of the set i classcount = numchanged; numofdel = 0; numofadd = 0; globalptr = 
svset[i]; while ( globalptr != (GINI_Set*) GINI_NULL ) { if ( globalptr->shrinklevel == 0 ) { (void) gettimeofday(&tv,&tz); numchanged += _examineExample(globalptr->dataind); timer3 += currtimeval(&tv); } if (globalptr != (GINI_Set*) GINI_NULL) { globalptr = globalptr->next; } if ( (numchanged+1)%100 == 0 ) { if ( verbose == GINI_FALSE ) { printf("#"); fflush(stdout); } } } //_purgesvlist(); if ( verbose == GINI_TRUE ) { totalphase = phase1+phase2+phase3+phase4+1; printf("Classid = %4d, Number changed = %5d, #sv = %5d #add = %5d , #del = %5d, #kswap = %5d, #phase1 = %3.0f, #phase2 = %3.0f, #phase3 = %3.0f, #phase4 = %3.0f\n", i,classcount,setsize[i],numofadd,numofdel,kernel->Swapcount(),(GINI_float)phase1/totalphase*100,(GINI_float)phase2/totalphase*100, (GINI_float)phase3/totalphase*100,(GINI_float)phase4/totalphase*100); if ( setsize[i] > 0 ) { printf("Max coeff = (%3.4f,%3.4f), Min coeff = (%3.4f,%3.4f)\n",maxE[i]->alpha,Y[maxE[i]->dataind][i],minE[i]->alpha,Y[minE[i]->dataind][i]); } fflush(stdout); phase1 = 0; phase2 = 0; phase3 = 0; phase4 = 0; printtimers(); } else { printf("#"); fflush(stdout); } } //_purgesvlist(); globalptr = (GINI_Set*) GINI_NULL; classcount = numchanged - classcount; if ( verbose == GINI_TRUE ) { totalphase = phase1+phase2+phase3+phase4+1; printf("Number changed = %5d, #add = %5d , #del = %5d, #kswap = %5d, #phase1 = %3.0f, #phase2 = %3.0f, #phase3 = %3.0f, #phase4 = %3.0f\n", classcount,numofadd,numofdel,kernel->Swapcount(),(GINI_float)phase1/totalphase*100,(GINI_float)phase2/totalphase*100, (GINI_float)phase3/totalphase*100,(GINI_float)phase4/totalphase*100); fflush(stdout); phase1 = 0; phase2 = 0; phase3 = 0; phase4 = 0; printtimers(); } else { printf("#"); fflush(stdout); } } else { for ( i = 0; i < classes; i++ ) { // Now iterate of the set i classcount = numchanged; numofdel = 0; numofadd = 0; globalptr = svset[i]; while ( globalptr != (GINI_Set*) GINI_NULL ) { if ( globalptr->shrinklevel == 0 ) { (void) gettimeofday(&tv,&tz); 
numchanged += _examinesvExample(globalptr->dataind); timer3 += currtimeval(&tv); } if (globalptr != (GINI_Set*) GINI_NULL) { globalptr = globalptr->next; } if ( (numchanged+1)%100 == 0 ) { if ( verbose == GINI_FALSE ) { printf("#"); fflush(stdout); } } } //_purgesvlist(); } //_purgesvlist(); globalptr = (GINI_Set*) GINI_NULL; classcount = numchanged - classcount; if ( verbose == GINI_TRUE ) { totalphase = phase1+phase2+phase3+phase4+1; printf("Number changed = %5d, #add = %5d , #del = %5d, #phase1 = %3.0f, #phase2 = %3.0f, #phase3 = %3.0f, #phase4 = %3.0f\n", classcount,numofadd,numofdel,(GINI_float)phase1/totalphase*100,(GINI_float)phase2/totalphase*100, (GINI_float)phase3/totalphase*100,(GINI_float)phase4/totalphase*100); fflush(stdout); phase1 = 0; phase2 = 0; phase3 = 0; phase4 = 0; printtimers(); } else { printf("#"); fflush(stdout); } } } //printf("Number changed = %d\n",numchanged); if (iterations%costwindow == 0) { currentcost = CostFunction(); // If the current cost function is costeps times les // that the previous cost then all the changes // made during the previous updates dont count. if ( previouscost - currentcost < costeps*previouscost ) { numchanged = 0; if ( costflag == GINI_FALSE ) { costflag = GINI_TRUE; } else { break; } } else { costflag = GINI_FALSE; }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -