📄 libs-learner.c
字号:
/*** Copyright (C) 2006 Thai Computational Linguistics Laboratory (TCL)** National Institute of Information and Communications Technology (NICT)** Canasai Kruengkrai <canasai xx gmail yy com, where xx=at and yy=dot>**** This file is part of the `libs' library.**** This library is free software; you can redistribute it and/or modify** it under the terms of the GNU General Public License as published by** the Free Software Foundation; either version 2 of the License, or** (at your option) any later version.**** This program is distributed in the hope that it will be useful,** but WITHOUT ANY WARRANTY; without even the implied warranty of** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the** GNU General Public License for more details.**** You should have received a copy of the GNU General Public License** along with this program; if not, write to the Free Software** Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.*/#include <stdio.h>#include <stdlib.h>#include <string.h>#include <ctype.h>#include <math.h>#include <unistd.h>#include "libsvm-string-2.71/svm.h"#define file_error( msg ) fprintf( stderr, "\nFILE ERROR: Could not open the file: %s\n", msg ), exit( 0 )#define fatal_error( msg ) fprintf( stderr, "\n%s\n", msg ), exit( 0 )#define malloc_error( str1, str2 ) fprintf( stderr, "ERROR: In function `%s': Could not allocate memory for `%s'\n", str1, str2 ), exit( 1 )#define Malloc( type, n ) ( type * )malloc( ( n ) * sizeof( type ) )#define MAX_STRING_LEN 1024struct svm_parameter param;struct svm_problem prob;struct svm_model *model;struct svm_node *x_space;void exit_with_help( char *prog_name ){ fprintf( stderr, "\nUsage:" "\n------" "\n%s config_file\n\n", prog_name ); exit( 0 );}void read_config_file( char *conf_file ){ // Set default values for libsvm param.svm_type = C_SVC; param.kernel_type = STRING; param.degree = 5.0; param.gamma = 2.0; param.coef0 = 1.5; param.nu = 0.5; param.cache_size = 40; param.C = 100; param.eps = 1e-3; param.p = 0.1; param.shrinking = 1; param.probability = 0; param.nr_weight = 0; param.weight_label = NULL; param.weight = NULL; param.mc_method = 0; param.cross_validation = 0; param.nr_fold = 0; int num_tokens = 0; int num_lines = 0; int longest_line = 0; text_scan( conf_file, &num_tokens, &num_lines, &longest_line ); char *line; if( ( line = Malloc( char, longest_line ) ) == NULL ) malloc_error( "read_config_file", "line" ); char name[MAX_STRING_LEN]; char value[MAX_STRING_LEN]; FILE *file_ptr = fopen( conf_file, "r" ); while( fgets( line, longest_line, file_ptr ) != NULL ) { if( text_not_blank( line ) && line[0] != '#' ) { if( line[ strlen( line ) - 1 ] == '\n' ) line[ strlen( line ) - 1 ] = '\0'; sscanf( line, "%[^=]=%[^=]", name, value ); if( strcmp( name, "tagged_file" ) == 0 ) { if( access( value, R_OK ) != 0 ) file_error( value ); param.tagged_file = text_copy( value ); } else if( strcmp( name, "label_file" ) == 0 ) { if( access( value, R_OK ) != 0 ) file_error( value ); param.label_file = text_copy( value ); } else if( strcmp( name, "model_file" ) == 0 ) { param.model_file = text_copy( value ); } else if( strcmp( name, "cv_result_file" ) == 0 ) { param.cv_result_file = text_copy( value ); } else if( strcmp( name, "degree" ) == 0 ) { param.degree = atof( value ); } else if( strcmp( name, "gamma" ) == 0 ) { param.gamma = atof( value ); if( param.gamma < 0 || param.gamma > 3 ) fatal_error( "ERROR: In function `read_config_file': gamma is out of range [0,3]" ); } else if( strcmp( name, "coef0" ) == 0 ) { param.coef0 = atof( value ); } else if( strcmp( name, "mc_method" ) == 0 ) { param.mc_method = atoi( value ); } else if( strcmp( name, "cross_validation" ) == 0 ) { param.cross_validation = atoi( value ); } else if( strcmp( name, "nr_fold" ) == 0 ) { param.nr_fold = atoi( value ); } } } /* Print param values */ fprintf( stderr, "# tagged_file = [%s]\n", param.tagged_file ); fprintf( stderr, "# label_file = [%s]\n", param.label_file ); fprintf( stderr, "# model_file = [%s]\n", param.model_file ); fprintf( stderr, "# cv_result_file = [%s]\n", param.cv_result_file ); fprintf( stderr, "# degree = [%g]\n", param.degree ); fprintf( stderr, "# gamma = [%g]\n", param.gamma ); fprintf( stderr, "# coef0 = [%g]\n", param.coef0 ); fprintf( stderr, "# C = [%g]\n", param.C ); fprintf( stderr, "# mc_method = [%d]\n", param.mc_method ); fprintf( stderr, "# cross_validation = [%d]\n", param.cross_validation ); fprintf( stderr, "# nr_fold = [%d]\n\n", param.nr_fold );}void read_string_problem( char *tagged_file ){ int num_tokens = 0; int num_lines = 0; int longest_line = 0; int elements = 0; text_string_scan( tagged_file, &num_tokens, &num_lines, &longest_line, &elements ); char *line = ( char * )malloc( longest_line * sizeof( char ) ); prob.l = num_lines; if( ( prob.y = Malloc( double, prob.l ) ) == NULL ) malloc_error( "read_string_problem", "prob.y" ); if( ( prob.x = Malloc( struct svm_node *, prob.l ) ) == NULL ) malloc_error( "read_string_problem", "prob.x" ); if( ( x_space = Malloc( struct svm_node, elements ) ) == NULL ) malloc_error( "read_string_problem", "x_space" ); int i = 0; int j = 0; FILE *file_ptr = fopen( tagged_file, "r" ); while( fgets( line, longest_line, file_ptr ) != NULL ) { if( text_not_blank( line ) ) { line[ strlen( line ) - 1 ] = '\0'; char *tmp_ptr = text_copy4( line ); char *tmp_ptr2 = tmp_ptr; tmp_ptr2 = strchr( tmp_ptr2, ' ' ); if( tmp_ptr2 == NULL ) fatal_error( "ERROR: In function `read_string_problem': Every sample must has a class label" ); *tmp_ptr2 = '\0'; ++tmp_ptr2; double label = atof( tmp_ptr ); char *string = text_copy( tmp_ptr2 ); prob.x[i] = &x_space[j]; prob.y[i] = label; int str_length = strlen( string ); int k; for( k = 0; k < str_length; k++ ) { x_space[j].index = k; x_space[j].value = ( unsigned char )string[k]; ++j; } x_space[j++].index = -1; i++; free( string ); free( tmp_ptr ); } } fclose( file_ptr );}void do_cross_validation( char *label_file ){ int i; int total_correct = 0; double *target = Malloc( double, prob.l ); char **label_name = NULL; int max_classname_length = 0; int nr_class; label_name = load_label( label_file, &max_classname_length ); sk_svm_cross_validation( &prob, ¶m, target, label_name, max_classname_length, &nr_class ); int **conf_mat = ( int ** )malloc( nr_class * sizeof( int * ) ); int j, k; for( j = 0; j < nr_class; j++ ) { conf_mat[j] = ( int * )malloc( nr_class * sizeof( int ) ); for( k = 0; k < nr_class; k++ ) conf_mat[j][k]=0; } FILE *file_ptr = fopen( param.cv_result_file, "w" ); for( i = 0; i < prob.l; i++ ) { if( target[i] == prob.y[i] ) ++total_correct; fprintf( file_ptr, "%g %g\n", prob.y[i], target[i] ); conf_mat[( int )prob.y[i]][( int )target[i]]++; } fclose( file_ptr ); double *precision = Malloc( double, nr_class ); double *recall = Malloc( double, nr_class ); double *F_1 = Malloc( double, nr_class ); for( j = 0; j < nr_class; j++ ) recall[j] = precision[j] = F_1[j] = 0.0; for( j = 0; j < nr_class; j++ ) { for( k = 0; k < nr_class; k++ ) { recall[j] += conf_mat[j][k]; precision[k] += conf_mat[j][k]; } } double macro_avg_precision = 0.0; double macro_avg_recall = 0.0; double macro_avg_F_1 = 0.0; for( j = 0; j < nr_class; j++ ) { if( precision[j] != 0.0 ) precision[j] = 100 * ( double )conf_mat[j][j] / precision[j]; if( recall[j] != 0.0 ) recall[j] = 100 * ( double )conf_mat[j][j] / recall[j]; if( ( precision[j] + recall[j] ) != 0.0 ) F_1[j] = ( 2 * precision[j] * recall[j] ) / ( precision[j] + recall[j] ); //printf ("prec=%1.2f, rec=%1.2f, F_1=%1.2f\n", precision[j], recall[j], F_1[j] ); macro_avg_precision += precision[j]; macro_avg_recall += recall[j]; macro_avg_F_1 += F_1[j]; } macro_avg_precision /= nr_class; macro_avg_recall /= nr_class; macro_avg_F_1 /= nr_class; double acc = ( 100.0 * total_correct / prob.l ); printf( "\nSummary\n" ); printf( "===============================================================\n" ); printf (" %*s & %*s & %*s & %*s \\\\\n", 9, "Precision", 9, "Recall", 9, "F_1", 9, "Accuracy" ); printf (" %9.2f & %9.2f & %9.2f & %9.2f \\\\\n\n", macro_avg_precision, macro_avg_recall, macro_avg_F_1, acc ); printf( "Note for accuracy: total correct %d out of %d (%g %%)\n\n", total_correct, prob.l, acc ); free( target ); free( conf_mat ); free( label_name );}void run( int argc, char **argv ){ read_config_file( argv[1] ); fprintf( stderr, "# Read tagged file `%s'...\n", param.tagged_file ); read_string_problem( param.tagged_file ); const char *error_msg = svm_check_parameter( &prob, ¶m ); if( error_msg ) { fprintf( stderr, "ERROR: %s\n", error_msg ); exit( 1 ); } if( param.cross_validation ) { fprintf( stderr, "# Do stratified cross validation...\n" ); do_cross_validation( param.label_file ); } else { fprintf( stderr, "# Train libs...\n" ); model = svm_train( &prob, ¶m ); fprintf( stderr, "# Save model file `%s'\n", param.model_file ); svm_save_model( param.model_file, model ); svm_destroy_model( model ); } svm_destroy_param( ¶m ); free( prob.y ); free( prob.x ); free( x_space );}int main( int argc, char **argv ){ if( argc < 2 ) exit_with_help( argv[0] ); run( argc, argv ); return( 0 );}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -