📄 ginitrain.cpp
            count++;
            nhits = (GINI_u32)atoi(argv[count++]);
            validarg = GINI_TRUE;
        }
        if ( strcmp(argv[count],"-srch") == 0 ) {
            count++;
            srch = (GINI_u32)atoi(argv[count++]);
            validarg = GINI_TRUE;
        }
        if ( strcmp(argv[count],"-csize") == 0 ) {
            count++;
            csize = (GINI_u32)atoi(argv[count++]);
            validarg = GINI_TRUE;
        }
        if ( strcmp(argv[count],"-fpass") == 0 ) {
            count++;
            fpass = (GINI_u32)atoi(argv[count++]);
            validarg = GINI_TRUE;
        }
        if ( validarg == GINI_FALSE ) {
            PrintHelp();
            printf("Invalid Argument: %s\n",argv[count]);
            return 1;
        }
    }

    if ( count > argc-2 ) {
        PrintHelp();
        return 1;
    }

    // Open the input training file and read its header.
    FILE *fp = fopen(argv[argc-2],"r");
    if ( fp == (FILE*) GINI_NULL ) {
        printf("Training file not present\n");
        return 1;
    }
    FILE *fout = fopen(argv[argc-1],"w");
    if ( fout == (FILE*) GINI_NULL ) {
        printf("Cannot open configuration file for writing\n");
        return 1;
    }

    GINI_float value;
    if ( ktype != GINISVMDTK ) {
        // First read in the dimensionality of the input vectors.
        fscanf(fp,"%f\n",&value);
        SVM_DIMENSION = (GINI_u32)value;
    } else {
        SVM_DIMENSION = 0;
    }
    // Number of classes.
    fscanf(fp,"%f\n",&value);
    SVM_CLASS = (GINI_u32)value;
    // Total number of training points.
    fscanf(fp,"%f\n",&value);
    SVM_DATA = (GINI_u32)value;

    if ( srch == 0 ) {
        srch = SVM_DATA;
    }
    if ( precomp == GINI_TRUE ) {
        if ( SVM_DATA > csize ) {
            printf("Cache size is less than the total number of data points\n");
            return 1;
        }
    }

    // Initialize the kernel depending on ktype.
    switch (ktype) {
        case GINISVMGAUSSIAN :
            kernel = new GINI_GaussianKernel(p1,sp);
            break;
        case GINISVMPOLY :
            kernel = new GINI_PolyKernel(p1,p2,(GINI_u32)p3,sp);
            break;
        case GINISVMDTK :
            kernel = new GINI_DTKKernel(p1,p2,p3,p4);
            SVM_DIMENSION = 0;
            break;
        case GINISVMTANH :
            kernel = new GINI_TanhKernel(p1,sp);
            break;
    }
    // Initialize the kernel cache with csize entries.
    kernel->InitializeCache(csize,SVM_DATA);

    // Now define the SVM block and initialize it with the kernel.
    GINI_SVMBlock *svmmachine = new GINI_SVMBlock(kernel);

    // Initialize the regularization constant array. For this
    // particular case all the values are fixed.
    GINI_double *inpC = new GINI_double[SVM_DATA];
    for ( GINI_u32 i = 0; i < SVM_DATA; i++ ) {
        inpC[i] = C;
    }

    // Initialize the rate distortion factor. The rate distortion
    // factor determines how stable the optimization is and also
    // controls the sparsity of the solution.
    if ( B == 0.0 ) {
        B = ((GINI_double)SVM_CLASS/(SVM_CLASS-1))*log((GINI_double)SVM_CLASS);
    }

    // Initialize the training.
    svmmachine->InitTraining(
        SVM_DATA,       // Total training data
        SVM_DIMENSION,  // Feature dimension
        SVM_CLASS,      // Total number of classes
        inpC,           // Regularization constant array
        B,              // Rate distortion factor
        aeps,           // Coefficient tolerance
        keps,           // KKT tolerance
        win,            // Random search window
        ceps,           // Tolerance for cost function decrease
        niter,          // Number of iterations after which
        nhits,          //   the cost function is computed
        srch,           // Maximum SV search window
        liter,          // Maximum SV search window
        fpass           // First level passes
    );

    // Data structures to read in the training data.
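    // Note on the default above (illustrative, not from the original sources):
    // with SVM_CLASS = K classes, the expression works out to
    //     B = (K / (K - 1)) * ln(K),
    // so for example K = 4 gives B = (4/3) * ln 4, roughly 1.85, and B grows
    // slowly as the number of classes increases.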
    GINI_double *testvec;
    GINI_double *label;
    GINI_double cval;

    for ( GINI_u32 i = 0; i < SVM_DATA; i++ ) {
        if ( pflag == GINI_FALSE ) {
            fscanf(fp,"%f\n",&value);
            if ( value >= SVM_CLASS ) {
                printf("Improper Label: Expected < %d, Got %d\n",SVM_CLASS,(GINI_u32)value);
                return 1;
            }
            // Allocate memory for the label vector and build a
            // one-hot encoding of the class label.
            label = new GINI_double[SVM_CLASS];
            for ( GINI_u32 j = 0; j < SVM_CLASS; j++ ) {
                if ( j != value ) {
                    label[j] = 0;
                } else {
                    label[j] = 1;
                }
            }
        } else {
            // Allocate memory for the label vector and read a
            // probability value for each class.
            label = new GINI_double[SVM_CLASS];
            for ( GINI_u32 j = 0; j < SVM_CLASS; j++ ) {
                fscanf(fp,"%f\n",&value);
                if (( value > 1.0 ) || ( value < 0.0 )) {
                    printf("%f is not a valid probability measure\n",value);
                    return 1;
                }
                label[j] = value;
            }
        }

        cval = 1.0;
        if ( cflag == GINI_TRUE ) {
            // Read in the C value for this data point.
            fscanf(fp,"%f\n",&value);
            cval = (GINI_double)value;
        }

        if ( sp == GINI_FALSE ) {
            // The data is in a non-sparse (dense) format.
            if ( ktype == GINISVMDTK ) {
                // If this is a DTK kernel then we have to
                // process variable-length data.
                fscanf(fp,"%f\n",&value);
                testvec = new GINI_double[(GINI_u32)value+1];
                testvec[0] = (GINI_double)value;
                for ( GINI_u32 j = 0; j < (GINI_u32)testvec[0]; j++ ) {
                    fscanf(fp,"%f\n",&value);
                    testvec[j+1] = (GINI_double)value;
                }
            } else {
                // Allocate memory for the training vector.
                testvec = new GINI_double[SVM_DIMENSION];
                for ( GINI_u32 j = 0; j < SVM_DIMENSION; j++ ) {
                    fscanf(fp,"%f\n",&value);
                    testvec[j] = (GINI_double)value;
                }
            }
        } else {
            // The data is in a sparse format: a count followed by
            // (index, value) pairs.
            fscanf(fp,"%f\n",&value);
            testvec = new GINI_double[(GINI_u32)(2*value)+1];
            testvec[0] = (GINI_double)value;
            for ( GINI_u32 j = 0; j < (GINI_u32)testvec[0]; j++ ) {
                // Read in the index and then the value.
                fscanf(fp,"%f\n",&value);
                testvec[2*j+1] = (GINI_double)value;
                fscanf(fp,"%f\n",&value);
                testvec[2*j+2] = (GINI_double)value;
            }
        }

        // Insert the training vector and its corresponding label.
        svmmachine->InsertTrainingData(label,testvec,cval);
    }
    fclose(fp);

    // Train the SVM machine, with an upper limit of 20000
    // optimization steps.
    if ( svmmachine->StartTraining(precomp,20000,verbose) == GINI_FALSE ) {
        printf("SVM training failed\n");
    }

    // End of training: clean up all the data structures.
    GINI_u32 unconverged = svmmachine->StopTraining();
    printf("Total KKT Violations = %d\n",unconverged);

    // Print out the parameters of the SVM.
    GINI_u32 nsv = svmmachine->GetSize();
    printf("Number of SVs=%d\n",nsv);
    printf("-------------------------------------------------------------------\n");
    printf(" BIAS Information for different classes \n");
    printf("-------------------------------------------------------------------\n");
    for ( GINI_u32 j = 0; j < SVM_CLASS; j++ ) {
        printf("value of bias for class %d = %3.7f\n",j,svmmachine->GetThreshold(j));
    }

    // Write the SVM output.
    if ( svmmachine->Write(fout) == GINI_FALSE ) {
        printf("Failed to save SVM parameters\n");
    }
    fclose(fout);

    // Free all the memory.
    delete svmmachine;
    delete kernel;
    delete [] inpC;

    return 0;
}
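Based on the fscanf sequence above, the training file is read as one number per line: the feature dimension, the number of classes, the number of points, then for each point a class label followed by its feature values (assuming the hard-label, dense, per-point-C-disabled path, i.e. pflag and cflag off and a non-DTK, non-sparse kernel). The helper below is a minimal sketch of a generator for such a file; the file name toy_train.txt, the helper itself, and the toy values are illustrative and not part of the GINI sources.

// toy_make_train.cpp -- illustrative helper, not part of ginitrain.cpp.
// Writes a dense, hard-labelled training file in the layout the parser
// above reads: dimension, classes, points, then per point a class label
// followed by the feature values, one number per line.
#include <cstdio>

int main() {
    FILE *fp = fopen("toy_train.txt", "w");   // hypothetical file name
    if (fp == NULL) return 1;

    const unsigned dim = 2, classes = 2, points = 4;
    // Each row: class label, then `dim` feature values.
    const double data[points][3] = {
        {0, 1.0,  1.0},
        {0, 0.5,  1.5},
        {1, 1.0, -1.0},
        {1, 0.5, -1.5},
    };

    fprintf(fp, "%u\n%u\n%u\n", dim, classes, points);
    for (unsigned i = 0; i < points; i++)
        for (unsigned j = 0; j <= dim; j++)
            fprintf(fp, "%g\n", data[i][j]);

    fclose(fp);
    return 0;
}

Assuming the program is built as ginitrain, an invocation along the lines of "ginitrain toy_train.txt toy_svm.cfg" (plus whichever kernel and tolerance options the full argument parser accepts, part of which lies outside this listing) would read the training data from toy_train.txt and write the trained SVM parameters to toy_svm.cfg.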