⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ginisvm.cpp

📁 一种新的SVM算法
💻 CPP
📖 第 1 页 / 共 5 页
字号:
// Checks the KKT conditions of data point `dataind` across all classes and
// reports the first violating class encountered.  The class scan starts at
// a random offset so that no class is systematically preferred.
//
// INPUT   : dataind  - index of the data point (indexes svmap, C and Y).
//
// OUTPUT  : decision - class index of the violating pair; written only
//                      when GINI_TRUE is returned.
//           grad     - E value of that pair (the stored currptr->E or a
//                      fresh evaluateCache() result); written only when
//                      GINI_TRUE is returned.
//           direc    - GINI_UP_DOWN or GINI_DOWN for the violating pair.
//                      NOTE: it is also overwritten with GINI_DOWN every
//                      time a mapped coefficient that is not below its
//                      bound is visited, even if the function finally
//                      returns GINI_FALSE.
//
// RETURNS : GINI_TRUE if a violation was found, GINI_FALSE otherwise
//           (including when there are no classes at all).
//
// SIDE EFFECTS : for mapped coefficients that are not below their bound,
//           currptr->E is refreshed from evaluateCache() and shrinklevel
//           is set to 0 (violation) or 1 (no violation).
GINI_bool GINI_SVMBlock::_kktcondition( GINI_u32 dataind,
                                        GINI_u32* decision,
                                        GINI_double* grad,
                                        GINI_status* direc
                                      )
{
   GINI_u32 i,j,startpoint,count;
   GINI_double curre,betaval,z;
   GINI_Set *currptr;

   // Nothing to test without classes; this also keeps rand()%classes
   // below from being a modulo by zero.
   if ( classes == 0 )
   {
      return GINI_FALSE;
   }

   // First pass: z is the mean of (-E - bias) over the classes that are
   // either mapped with alpha below C*Y - alphaeps, or unmapped with a
   // label Y above alphaeps.
   z = 0.0;
   count = 0;
   for ( i = 0; i < classes; i++ )
   {
      if ( (currptr = svmap[dataind][i]) != (GINI_Set*) GINI_NULL )
      {
         if ( currptr->alpha < C[dataind]*Y[dataind][i] - alphaeps )
         {
            z +=  - currptr->E - bias[i];
            count++;
         }
      }
      else
      {
         if ( Y[dataind][i] > alphaeps )
         {
            curre = evaluateCache( dataind, i );
            z +=  - curre - bias[i];
            count++;
         }
      }
   }

   if (count == 0)
   {
      // No class contributed: fall back to a fixed value.
      z = 2*rdist/classes;
   }
   else
   {
      z /= count;
   }

   // Randomly select one of the classes that violates the kkt
   // condition. This prevents unneccessary bias when there
   // exist classes which have no corresponding data points
   //
   startpoint = rand()%classes;
   for ( j = 0; j < classes; j++ )
   {
      i = (startpoint+j)%classes;

      if ( (currptr = svmap[dataind][i]) != (GINI_Set*) GINI_NULL )
      {
         if ( currptr->alpha < C[dataind]*Y[dataind][i] - alphaeps )
         {
            // Mapped and below the bound: a residual outside the
            // tolerance in either direction is a violation.
            betaval =  - currptr->E - bias[i] - z;
            if (( betaval < -kkteps ) || ( betaval > kkteps ))
            {
               //printf("KKT Violation for (%d,%d), alpha=%3.7f, Y=%1.1f, beta=%3.7f, z=%2.4f\n",dataind,i,currptr->alpha,Y[dataind][i],betaval,z);
               *decision = i;
               *grad = currptr->E;
               *direc = GINI_UP_DOWN;
               return GINI_TRUE;
            }
         }
         else
         {
            // Mapped and not below the bound: refresh E, then only a
            // negative residual counts as a violation.
            currptr->E = evaluateCache( dataind, i );
            betaval = - currptr->E - bias[i] - z;
            *direc = GINI_DOWN;   // written before the test, as in the original
            if ( betaval < -kkteps )
            {
               //printf("KKT Violation for (%d,%d), alpha=%3.7f, Y=%1.1f, beta=%3.7f, z=%2.4f\n",dataind,i,currptr->alpha,Y[dataind][i],betaval,z);
               *decision = i;
               *grad = currptr->E;
               // Set its shrink level = 0.
               currptr->shrinklevel = 0;
               return GINI_TRUE;
            }
            else
            {
               // The variable has reached the bound and its lagrangian is +ve
               // set its shrink level = 1
               currptr->shrinklevel = 1;
               //cachefull = kernel->ResetActivity(dataind);
            }
         }
      }
      else
      {
         // Unmapped pair: the residual carries the extra 2*rdist*Y term.
         curre = evaluateCache( dataind, i );
         betaval = 2*rdist*Y[dataind][i] - curre - bias[i] - z;
         if ( Y[dataind][i] > alphaeps )
         {
            if (( betaval < -kkteps ) || ( betaval > kkteps ))
            {
               //printf("KKT Violation for (%d,%d), alpha=0, Y=%1.1f, beta=%3.7f, z=%2.4f\n",dataind,i,Y[dataind][i],betaval,z);
               *decision = i;
               *grad = curre;
               *direc = GINI_UP_DOWN;
               return GINI_TRUE;
            }
         }
         else
         {
            if ( betaval < -kkteps )
            {
               //printf("KKT Violation for (%d,%d), alpha=0, Y=%1.1f, beta=%3.7f, z=%2.4f\n",dataind,i,Y[dataind][i],betaval,z);
               *decision = i;
               *grad = curre;
               *direc = GINI_DOWN;
               return GINI_TRUE;
            }
         }
      }
   }

   // Every class was scanned without finding a violation.
   return GINI_FALSE;
}

/*****************************************************************************/
// FUNCTION  :     _kktcondition
//
// DESCRIPTION :   Computes the KKT condition for a data point and returns
//                 true if it satisfies otherwise false.
//                 We are going to use the theorem that for GiniSVM, for
//                 each data point there exists one unbounded variable.
//                 One has to be careful here because if there exist
//                 a class where there are no corresponding labels then
//                 this algorithm might not progress or the progress
//                 becomes very slow.
Therefore the point and the class//                 is chosen which violates the KKT condition the most.//// INPUT ://// OUTPUT ://///*****************************************************************************/GINI_bool GINI_SVMBlock::_kktsvcondition( GINI_u32 dataind, 		                        GINI_u32* decision,		                        GINI_double* grad,		                        GINI_status* direc                                      ){   GINI_u32 i,j,startpoint,count;   GINI_double betaval,z;   GINI_Set *currptr;   GINI_bool found = GINI_FALSE;    z = 0.0;    count = 0;    for ( i = 0; i < classes; i++ )    {        if ( (currptr = svmap[dataind][i]) != (GINI_Set*) GINI_NULL )        {            if ( currptr->alpha < C[dataind]*Y[dataind][i] - alphaeps )            {                 z +=  - currptr->E - bias[i];                 count++;            }        }    }    if (count == 0)    {         return GINI_FALSE;    }    else    {         z /= count;    }   // Randomly select one of the classes that violates the kkt   // condition. 
This prevents unneccessary bias when there   // exist classes which have no corresponding data points   //   //GINI_double minbeta = -1;   found = GINI_FALSE;    startpoint = rand()%classes;   for ( j = 0; j < classes; j++ )   {      i = (startpoint+j)%classes;      if ( (currptr = svmap[dataind][i]) != (GINI_Set*) GINI_NULL )      {	 if ( currptr->alpha < C[dataind]*Y[dataind][i] - alphaeps )         {            betaval =  - currptr->E - bias[i] - z;            if (( betaval < -kkteps ) || ( betaval > kkteps ))            {               //printf("KKT Violation for (%d,%d), alpha=%3.7f, Y=%1.1f, beta=%3.7f, z=%2.4f\n",dataind,i,currptr->alpha,Y[dataind][i],betaval,z);               *decision = i;	       *grad = currptr->E;	       *direc = GINI_UP_DOWN;	       return GINI_TRUE;            }         }      }   }   return found;}       /*****************************************************************************/// FUNCTION  :// // DESCRIPTION ://// INPUT ://// OUTPUT ://///*****************************************************************************/GINI_bool GINI_SVMBlock::StartTraining( GINI_bool precomp , GINI_u32 iter, GINI_bool verbose ){   // flag the training flag   mode = GINISVM_BLKTRN;   GINI_u32 numchanged = 0;   GINI_u32 examineAll = 1;   GINI_u32 i,classcount;    GINI_double currentcost,previouscost;   GINI_bool costflag =  GINI_FALSE;   GINI_u32 totalphase;   struct timeval tv;   struct timezone tz;   struct timeval tv1;   struct timezone tz1;   (void) gettimeofday(&tv1,&tz1);   // Number of truncations   floorcount = 0;   classcount = 0;   globalptr = (GINI_Set*) GINI_NULL;   // If precompute is true then the block computes all the   // kernel values for making the implementation faster.   
if ( precomp == GINI_TRUE )   {      kernel->ComputeAll(traindata, dimension);   }   previouscost = CostFunction();   currentcost = previouscost;   printf("Starting Cost = %f\n",previouscost);   fflush(stdout);   // This is the outer loop for selecting the    // first coefficient.   // While number of iterations is less than specified   // or the total number of decrease in steps is greater   // or we have to examine all the data points.   while (( iterations < iter ) && (( numchanged > 0 ) ||          (examineAll))   )   {      numchanged = 0;      if ( examineAll )      {            printf("\nIteration =  %d\n",iterations);	    numofdel = 0;	    numofadd = 0;	    phase1 = 0;	    phase2 = 0;	    phase3 = 0;	    phase4 = 0;	    for ( i = 0; i < totaldata ; i++ )            {               (void) gettimeofday(&tv,&tz);	       numchanged += _examineExample(i);               timer2 += currtimeval(&tv);	       if ( numchanged%100 == 0 )               {                  if ( verbose == GINI_TRUE)                  {	             totalphase = phase1+phase2+phase3+phase4+1;                     printf("Number changed = %5d, #add = %5d , #del = %5d, #kswap = %5d, #phase1 = %3.0f, #phase2 = %3.0f, #phase3 = %3.0f, #phase4 = %3.0f\n",	   	          numchanged,numofadd,numofdel,kernel->Swapcount(),(GINI_float)phase1/totalphase*100,(GINI_float)phase2/totalphase*100,	   	          (GINI_float)phase3/totalphase*100,(GINI_float)phase4/totalphase*100);                     fflush(stdout);	             printtimers();                  }		  else                  {                     printf("#");                     fflush(stdout);                  }               }            }	    totalphase = phase1+phase2+phase3+phase4+1;	    if ( (GINI_float)phase4/totalphase*100 > (GINI_float)kktiter)            {               kkteps *= 2;               printf("\nIncreasing the KKT tolerance to %f\n",kkteps);            }      }      else      {         if ( iterations < fpass )         {         
   //printf("Examining only svsets\n");            // Iterate over sv sets	    for ( i = 0; i < classes; i++ )            {               // Now iterate of the set i	       classcount = numchanged;	       numofdel = 0;	       numofadd = 0;               globalptr = svset[i];	       while ( globalptr != (GINI_Set*) GINI_NULL )               {                  if ( globalptr->shrinklevel == 0 )                  {                     (void) gettimeofday(&tv,&tz);                     numchanged += _examineExample(globalptr->dataind);                     timer3 += currtimeval(&tv);                  }	          if (globalptr != (GINI_Set*) GINI_NULL)	          { 	             globalptr = globalptr->next;                  }	          if ( (numchanged+1)%100 == 0 )                  {	             if ( verbose == GINI_FALSE )                     {                        printf("#");                        fflush(stdout);                     }                  }               }	       //_purgesvlist();	       if ( verbose == GINI_TRUE )               {	          totalphase = phase1+phase2+phase3+phase4+1;                  printf("Classid = %4d, Number changed = %5d, #sv = %5d #add = %5d , #del = %5d, #kswap = %5d, #phase1 = %3.0f, #phase2 = %3.0f, #phase3 = %3.0f, #phase4 = %3.0f\n",   	          i,classcount,setsize[i],numofadd,numofdel,kernel->Swapcount(),(GINI_float)phase1/totalphase*100,(GINI_float)phase2/totalphase*100,     	          (GINI_float)phase3/totalphase*100,(GINI_float)phase4/totalphase*100);		  if ( setsize[i] > 0 )                  {		     printf("Max coeff = (%3.4f,%3.4f), Min coeff = (%3.4f,%3.4f)\n",maxE[i]->alpha,Y[maxE[i]->dataind][i],minE[i]->alpha,Y[minE[i]->dataind][i]);		  }                   fflush(stdout);	          phase1 = 0;	          phase2 = 0;	          phase3 = 0;	          phase4 = 0;	          printtimers();               }	       else               {                  printf("#");                  fflush(stdout);               }          
  }	    //_purgesvlist();	    globalptr = (GINI_Set*) GINI_NULL;            classcount = numchanged - classcount;	    if ( verbose == GINI_TRUE )            {	       totalphase = phase1+phase2+phase3+phase4+1;               printf("Number changed = %5d, #add = %5d , #del = %5d, #kswap = %5d, #phase1 = %3.0f, #phase2 = %3.0f, #phase3 = %3.0f, #phase4 = %3.0f\n",   	       classcount,numofadd,numofdel,kernel->Swapcount(),(GINI_float)phase1/totalphase*100,(GINI_float)phase2/totalphase*100,     	       (GINI_float)phase3/totalphase*100,(GINI_float)phase4/totalphase*100);               fflush(stdout);	       phase1 = 0;	       phase2 = 0;	       phase3 = 0;	       phase4 = 0;	       printtimers();            }	    else            {               printf("#");               fflush(stdout);            }	 }	 else         {	    for ( i = 0; i < classes; i++ )            {               // Now iterate of the set i	       classcount = numchanged;	       numofdel = 0;	       numofadd = 0;               globalptr = svset[i];	       while ( globalptr != (GINI_Set*) GINI_NULL )               {                  if ( globalptr->shrinklevel == 0 )                  {                     (void) gettimeofday(&tv,&tz);                     numchanged += _examinesvExample(globalptr->dataind);                     timer3 += currtimeval(&tv);                  }	          if (globalptr != (GINI_Set*) GINI_NULL)	          { 	             globalptr = globalptr->next;                  }	          if ( (numchanged+1)%100 == 0 )                  {	             if ( verbose == GINI_FALSE )                     {                        printf("#");                        fflush(stdout);                     }                  }               }	       //_purgesvlist();            }	    //_purgesvlist();	    globalptr = (GINI_Set*) GINI_NULL;            classcount = numchanged - classcount;	    if ( verbose == GINI_TRUE )            {	       totalphase = phase1+phase2+phase3+phase4+1;               
printf("Number changed = %5d, #add = %5d , #del = %5d, #phase1 = %3.0f, #phase2 = %3.0f, #phase3 = %3.0f, #phase4 = %3.0f\n",   	             classcount,numofadd,numofdel,(GINI_float)phase1/totalphase*100,(GINI_float)phase2/totalphase*100,   		 (GINI_float)phase3/totalphase*100,(GINI_float)phase4/totalphase*100);               fflush(stdout);	       phase1 = 0;	       phase2 = 0;	       phase3 = 0;	       phase4 = 0;	       printtimers();            }	    else            {                printf("#");                fflush(stdout);            }         }      }      //printf("Number changed = %d\n",numchanged);      if (iterations%costwindow == 0)      {         currentcost = CostFunction();	 // If the current cost function is costeps times les	 // that the previous cost then all the changes	 // made during the previous updates dont count.	 if ( previouscost - currentcost < costeps*previouscost )         {            numchanged = 0;	    if ( costflag == GINI_FALSE )            {               costflag = GINI_TRUE;            }	    else            {               break;            }         }	 else         {            costflag = GINI_FALSE;         }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -