kfold.c
来自「ADaM is a data mining and image processi」· C语言 代码 · 共 256 行
C
256 行
/* Logistic Regression using Truncated Iteratively Re-weighted Least Squares (includes several programs) Copyright (C) 2005 Paul Komarek This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA Author: Paul Komarek, komarek@cmu.edu Alternate contact: Andrew Moore, awm@cs.cmu.edu*//* File: kfold.c Author: Paul Komarek Created: Thu Jun 12 03:30:15 EDT 2003 Description: k-fold cross-validation. Copyright 2003, The Auton Lab, CMU*/#include <stdio.h>#include <string.h>#include <time.h>#include "amiv.h"#include "amdyv.h"#include "amdym.h"#include "spardat.h"#include "lrutils.h"#include "lr.h"#include "train.h"#include "predict.h"#include "kfoldsumm.h"#include "kfold.h"/**************************************************************************//* KFOLD *//**************************************************************************/dyv *mk_fold_predictions( const spardat *sp, const dym *factors, const dyv *outputs, ivec *train_rows, ivec *test_rows, lr_options *opts){ int numtrain, numtest; dyv *preds, *trainout, *testout; spardat *trainsp, *testsp; dym *trainfactors, *testfactors; lr_predict *lrp; numtrain = ivec_size( train_rows); numtest = ivec_size( test_rows); trainsp = NULL; testsp = NULL; trainfactors = NULL; testfactors = NULL; trainout = NULL; testout = NULL; if (sp != NULL) { trainsp = mk_spardat_from_subset_of_rows( sp, train_rows); testsp = mk_spardat_from_subset_of_rows( sp, test_rows); } else { trainfactors = mk_dym_subset( factors, train_rows, NULL); testfactors = mk_dym_subset( factors, test_rows, NULL); trainout = mk_dyv_subset( outputs, train_rows); testout = mk_dyv_subset( outputs, test_rows); } lrp = mk_train_lr_predict( trainsp, trainfactors, trainout, opts); preds = mk_predicts( testsp, testfactors, testout, lrp); free_lr_predict( lrp); if (trainout != NULL) free_dyv( trainout); if (testout != NULL) free_dyv( testout); if (trainsp != NULL) free_spardat( trainsp); if (testsp != NULL) free_spardat( testsp); if (trainfactors != NULL) free_dym( trainfactors); if (testfactors != NULL) free_dym( testfactors); return preds;}void kfold_run_folds( kfoldsumm *kfs, const spardat *sp, const dym *factors, const dyv *outputs, int folds, lr_options *opts){ int fold, size; time_t start, stop; ivec *train_rows, *test_rows; dyv *predictions; if (sp != NULL) size = spardat_num_rows( sp); else size = dym_rows( factors); for (fold=0; fold<folds; ++fold) { /* Get list of train rows and test rows. */ make_kfold_rows( NULL, size, folds, fold, &train_rows, &test_rows); /* Make predictions. */ start =time( NULL); predictions = mk_fold_predictions( sp, factors, outputs, train_rows, test_rows, opts); stop = time( NULL); /* Done with training rows. */ free_ivec(train_rows); /* Write predictions into data structure. */ kfoldsumm_set_fold_time( kfs, fold, stop-start); kfoldsumm_set_subfoldnums( kfs, test_rows, fold); kfoldsumm_set_subpredicts( kfs, test_rows, predictions); /* Free remaining per-iteration stuff. */ free_ivec(test_rows); free_dyv(predictions); } return;}void run_kfold( char *inname, int folds, char *pout, char *fout, char *rout, int argc, char **argv){ int csv, numrows; ivec *outputs; dym *factors; dyv *dvoutputs; spardat *sp; lr_options *opts; kfoldsumm *kfs; csv = string_has_suffix( inname, ".csv"); csv |= string_has_suffix( inname, ".csv.gz"); /* Parse lr options. */ opts = mk_lr_options(); parse_lr_options( opts, argc, argv); check_lr_options( opts, argc, argv); /* Load full data file. */ sp = NULL; factors = NULL; dvoutputs = NULL; if (!csv) { sp = mk_spardat_from_pfilename( inname, argc, argv); numrows = spardat_num_rows( sp); outputs = mk_copy_ivec( sp->row_to_outval); } else { mk_read_dym_for_csv( inname, &factors, &dvoutputs); if (!dyv_is_binary( dvoutputs)) { my_error( "run_kfold: Error: csv output column is not binary.\n"); } numrows = dym_rows( factors); outputs = mk_ivec_from_dyv( dvoutputs); } /* Run folds. */ kfs = mk_kfoldsumm( folds, numrows); kfold_run_folds( kfs, sp, factors, dvoutputs, folds, opts); if (sp != NULL) free_spardat( sp); if (factors != NULL) free_dym( factors); if (dvoutputs != NULL) free_dyv( dvoutputs); free_lr_options( opts); /* Done. */ kfoldsumm_update_stats( kfs, outputs); free_ivec( outputs); fprintf_kfoldsumm_stats( stdout, "", kfs); if (fout != NULL) kfoldsumm_save_foldnums( kfs, fout); if (pout != NULL) kfoldsumm_save_predictions( kfs, pout); if (rout != NULL) kfoldsumm_save_roc( kfs, rout, inname); free_kfoldsumm( kfs); return;}/**************************************************************************//* MAIN *//**************************************************************************/static void usage( char *progname){ printf( "\n"); printf( "Usage:\n"); printf( "%s in <train_datafile> folds <int> [options]\n", progname); printf( "\n"); printf( "Options:\n"); printf( " fout <filename> Save fold assignments.\n"); printf( " pout <filename> Save aggregate predictions.\n"); printf( " rout <filename> Save ROC curve.\n"); printf( "\n"); printf( "This program runs a k-fold cross-validation on the specified\n"); printf( "data file. Statistics are reported at the end. Use the\n"); printf( "option verbosity <int> to increase the amount of information\n"); printf( "printed at the end.\n"); printf( "\n");}void kfold_main( int argc, char **argv){ char *inname, *poutname, *foutname, *routname; int folds; inname = string_from_args( "in", argc, argv, ""); folds = int_from_args( "folds", argc, argv, -1); foutname = string_from_args( "fout", argc, argv, ""); poutname = string_from_args( "pout", argc, argv, ""); routname = string_from_args( "rout", argc, argv, ""); if (!strcmp( inname, "")) { fprintf( stderr, "\nkfold_main: no datafile was specified with 'in'.\n\n"); usage( argv[0]); exit(-1); } if (folds<0) { fprintf( stderr, "\nkfold_main: number of folds not specified, or specified but\n" "negative.\n\n"); usage( argv[0]); exit(-1); } if (!strcmp( foutname, "")) foutname = NULL; if (!strcmp( poutname, "")) poutname = NULL; if (!strcmp( routname, "")) routname = NULL; run_kfold( inname, folds, poutname, foutname, routname, argc, argv); return;}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?