⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ginitrain.cpp

📁 一种新的SVM算法
💻 CPP
📖 第 1 页 / 共 2 页
字号:
         count++;	 nhits = (GINI_u32)atoi(argv[count++]);	 validarg = GINI_TRUE;      }      if ( strcmp(argv[count],"-srch") == 0 )      {         count++;	 srch = (GINI_u32)atoi(argv[count++]);	 validarg = GINI_TRUE;      }      if ( strcmp(argv[count],"-csize") == 0 )      {         count++;	 csize = (GINI_u32)atoi(argv[count++]);	 validarg = GINI_TRUE;      }      if ( strcmp(argv[count],"-fpass") == 0 )      {         count++;	 fpass = (GINI_u32)atoi(argv[count++]);	 validarg = GINI_TRUE;      }      if ( validarg == GINI_FALSE )      {          PrintHelp();	  printf("Invalid Argument: %s\n",argv[count]);	  return 1;      }   }   if ( count > argc-2 )   {       PrintHelp();       return 1;   }   // Read the input training file and read the header.   // Read in the data from the file   FILE *fp = fopen(argv[argc-2],"r");   if ( fp == (FILE*) GINI_NULL )   {      printf("Training File not Present\n");      return 1;   }   FILE *fout = fopen(argv[argc-1],"w");   if ( fout == (FILE*) GINI_NULL )   {      printf("Cannot Open Configuration File for writing\n");      return 1;   }   GINI_float value;   if ( ktype != GINISVMDTK )   {      // First read in the dimensionality of the input      // vectors.      fscanf(fp,"%f\n",&value);      SVM_DIMENSION = (GINI_u32)value;   }   else   {      SVM_DIMENSION = 0;   }   // Number of classes   fscanf(fp,"%f\n",&value);   SVM_CLASS = (GINI_u32)value;   // Total Number of training points   fscanf(fp,"%f\n",&value);   SVM_DATA = (GINI_u32)value;   if ( srch == 0 )   {      srch = SVM_DATA;   }   if ( precomp == GINI_TRUE )   {      if ( SVM_DATA > csize )      {         printf("Cache size less than the total data points\n");	 return 1;      }   }   // Initialize the kernel depending on the ktype.   switch (ktype)   {	   case GINISVMGAUSSIAN :		  kernel = new GINI_GaussianKernel(p1,sp);		  break;	   case GINISVMPOLY :		  kernel = new GINI_PolyKernel(p1,p2,(GINI_u32)p3,sp);		  break;	   case GINISVMDTK :		  kernel = new GINI_DTKKernel(p1,p2,p3,p4);		  SVM_DIMENSION = 0;		  break;	   case GINISVMTANH :		  kernel = new GINI_TanhKernel(p1,sp);		  break;   }   // Initialize the kernel cache with csize    kernel->InitializeCache(csize,SVM_DATA);     // Now define svm block and initialize it with the kernel   GINI_SVMBlock *svmmachine = new GINI_SVMBlock(kernel);   // Initialize the regularization constant array. For   // this particular case all the values are fixed.   GINI_double *inpC = new (GINI_double)[SVM_DATA];   for ( GINI_u32 i = 0; i < SVM_DATA; i++ )   {      inpC[i] = C;   }    // Initialize the rate distortion factor.  Rate   // distortion factor determines how stable the optimization   // is and also controls sparsity of solution.   if ( B == 0.0 )   {      B = ((GINI_double)SVM_CLASS/(SVM_CLASS-1))*log((GINI_double)SVM_CLASS);    }   // Initialize the training    svmmachine->InitTraining( SVM_DATA,       // Total training data		             SVM_DIMENSION,  // Feature dimension			     SVM_CLASS,      // Total number of classes			     inpC,           // Regularization constant array			     B,              // Rate distortion factor			     aeps,           // coefficient tolerance			     keps,           // KKT tolerance			     win,            // Random search window.			     ceps,           // Tolerance for Cost function decrease			     niter,          // Number of Iterations after which the			     nhits,          // the cost function is computed.			     srch,            // Maximum sv search window			     liter,           // Maximum sv search window			     fpass           // First level passes.			 );   // Data structures to read in the training data.   GINI_double *testvec;   GINI_double *label;   GINI_double cval;   for ( GINI_u32 i = 0; i < SVM_DATA; i++ )   {      if ( pflag == GINI_FALSE )      {         fscanf(fp,"%f\n",&value);         if ( value >= SVM_CLASS)         {            printf("Improper Label: Expected < %d, Got %d\n",SVM_CLASS,(GINI_u32)value);            return 1;         }         // Allocate memory for the label vector         label = new GINI_double[SVM_CLASS];         for ( GINI_u32 j = 0; j< SVM_CLASS; j++ )         {            if ( j != value)            {               label[j]= 0;            }            else            {               label[j] = 1;            }         }      }      else      {         // Allocate memory for the label vector         label = new GINI_double[SVM_CLASS];         for ( GINI_u32 j = 0; j< SVM_CLASS; j++ )         {            fscanf(fp,"%f\n",&value);            if (( value > 1.0) || ( value < 0.0 ))            {               printf("%f is not a valid probability measure\n",value);               return 1;            }	    label[j] = value;         }      }      cval = 1.0;      if ( cflag == GINI_TRUE )      {         // Read in the C value for this data point         fscanf(fp,"%f\n",&value);         cval = (GINI_double)value;      }      if ( sp == GINI_FALSE )      {         // If the data is in a non-sparse format	 if ( ktype == GINISVMDTK )         {            // If this is a DTK kernel then we have to	    // process variable length data.            fscanf(fp,"%f\n",&value);            testvec = new GINI_double[(GINI_u32)value+1];	    testvec[0] = (GINI_double)value;            for ( GINI_u32 j = 0; j< (GINI_u32)testvec[0]; j++ )            {               fscanf(fp,"%f\n",&value);               testvec[j+1] = (GINI_double)value;            }         }	 else         {            // Allocate memory for the training vector            testvec = new GINI_double[SVM_DIMENSION];            for ( GINI_u32 j = 0; j< SVM_DIMENSION; j++ )            {               fscanf(fp,"%f\n",&value);               testvec[j] = (GINI_double)value;            }         }      }      else      {         // If the data is in a sparse format.         fscanf(fp,"%f\n",&value);         testvec = new GINI_double[(GINI_u32)(2*value)+1];	 testvec[0] = (GINI_double)value;         for ( GINI_u32 j = 0; j< (GINI_u32)testvec[0]; j++ )         {            // Read in the Index and then the value.            fscanf(fp,"%f\n",&value);            testvec[2*j+1] = (GINI_double)value;            fscanf(fp,"%f\n",&value);            testvec[2*j+2] = (GINI_double)value;         }      }      // Insert training vector and its corresponding label.      svmmachine->InsertTrainingData(label,testvec,cval);   }   fclose(fp);   // Train the svm machine, with an upper limit of 20000   // optimization steps.   if ( svmmachine->StartTraining(precomp,20000,verbose) == GINI_FALSE )   {      printf(" SVM training Failed \n");    }   // End of training to clean up all the data structures.   GINI_u32 unconverged = svmmachine->StopTraining();   printf("Total KKT Violations = %d\n",unconverged);   // Print out the parameters of SVM   GINI_u32 nsv = svmmachine->GetSize();   printf("Number of SVs=%d\n",nsv);   printf("-------------------------------------------------------------------\n");   printf("    BIAS Information for different classes                         \n");   printf("-------------------------------------------------------------------\n");   for ( GINI_u32 j = 0; j< SVM_CLASS; j++ )   {        printf("value of bias for class %d = %3.7f\n",j,svmmachine->GetThreshold(j));   }   // Write the SVM output    if ( svmmachine->Write(fout) == GINI_FALSE )   {      printf("Failed to save SVM parameter\n");   }   fclose(fout);   // Free all the memory.   delete svmmachine;   delete kernel;   delete [] inpC;   return 0;}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -