📄 mlp.cc
字号:
// Help banner shown by the command-line parser. The trailing backslashes
// are line continuations: this is ONE string literal.
const char *help = "\
progname: mlp.cc\n\
code2html: This program trains a MLP for 2 class classification.\n\
version: Torch3 vision2.1, 2003-2006\n\
(c) Sebastien Marcel (marcel@idiap.ch)\n";

/** Torch */
#include "Random.h"
#include "DiskXFile.h"
// datasets
#include "MatDataSet.h"
// criterions
#include "ClassNLLCriterion.h"
#include "TwoClassNLLCriterion.h"
#include "MSECriterion.h"
// class formats
#include "TwoClassFormat.h"
// measurers
#include "ClassMeasurer.h"
#include "MSEMeasurer.h"
#include "NLLMeasurer.h"
// trainers
#include "StochasticGradient.h"
#include "KFold.h"
// command-lines
#include "CmdLine.h"
#include "FileListCmdOption.h"
// string utils
#include "string_utils.h"

/** Torch3vision */
// datasets
#include "FileBinDataSet.h"
// custom machines
#include "MyMLP.h"
#include "MeanVarNorm.h"
#include "MyMeanVarNorm.h"
// image processing
#include "ipHistoEqual.h"
#include "ipSmoothGaussian3.h"

using namespace Torch;

// Checks that each of the n_files paths in file_names refers to an existing
// regular file; invalid entries are compacted out of file_names in place.
// Returns the number of remaining valid files. (Defined at end of file.)
int checkFiles(int n_files, char **file_names);

// Global verbosity flag, bound to the -verbose option of every mode.
bool verbose;

// Entry point. Three modes selected by CmdLine master switches:
//   mode 0 (default): train an MLP from positive/negative pattern files;
//   mode 1 (--test):  run a saved MLP over file lists and write mean outputs;
//   mode 2 (--read):  load a saved MLP and print its structure.
int main(int argc, char **argv)
{
  // input dimension of one pattern
  int n_inputs;
  // target value for the positive class (negated for the negative class)
  real the_target;
  // number of hidden units (0 selects a single-layer machine)
  int n_hu;
  // pattern geometry, used only when -imagenorm is given
  int width_pattern;
  int height_pattern;
  // misc options
  int max_load;
  int the_seed;
  // training hyper-parameters
  real accuracy;
  real learning_rate;
  real decay;
  int max_iter;
  int k_fold;
  real weight_decay;
  // criterion / architecture switches
  bool use_mse;
  bool use_nll;
  bool use_linear_output;
  bool image_normalize;
  // names filled in by the command-line parser
  char *dir_name;
  char *model_file;
  char *output_file;

  // Every Torch object below is created with new(allocator) and released
  // in one shot by "delete allocator" at the end of main().
  Allocator *allocator = new Allocator;
  DiskXFile::setLittleEndianMode();

  //=================== The command-line ==========================

  FileListCmdOption filelist_class1("file name", "the list files or one data file of positive patterns");
  filelist_class1.isArgument(true);

  FileListCmdOption filelist_class0("file name", "the list files or one data file of negative patterns");
  filelist_class0.isArgument(true);

  FileListCmdOption filelist_class("file name", "the list files or one data file of patterns");
  filelist_class.isArgument(true);

  // Construct the command line
  CmdLine cmd;
  cmd.setBOption("write log", false);

  // Put the help line at the beginning
  cmd.info(help);

  // Train mode (master switch 0: no explicit switch on the command line)
  cmd.addText("\nArguments:");
  cmd.addCmdOption(&filelist_class1);
  cmd.addCmdOption(&filelist_class0);
  cmd.addICmdArg("n_inputs", &n_inputs, "input dimension of the data");

  cmd.addText("\nModel Options:");
  cmd.addICmdOption("-nhu", &n_hu, 25, "number of hidden units");

  cmd.addText("\nLearning Options:");
  cmd.addICmdOption("-iter", &max_iter, 25, "max number of iterations");
  cmd.addRCmdOption("-lr", &learning_rate, 0.01, "learning rate");
  cmd.addRCmdOption("-e", &accuracy, 0.00001, "end accuracy");
  cmd.addRCmdOption("-lrd", &decay, 0, "learning rate decay");
  cmd.addICmdOption("-kfold", &k_fold, -1, "number of folds, if you want to do cross-validation");
  cmd.addRCmdOption("-wd", &weight_decay, 0, "weight decay");
  cmd.addBCmdOption("-mse", &use_mse, false, "use MSE criterion");
  cmd.addBCmdOption("-nll", &use_nll, false, "use NLL criterion");
  cmd.addBCmdOption("-linear", &use_linear_output, false, "use linear output (tanh otherwise)");
  cmd.addRCmdOption("-target", &the_target, 0.6, "the target value (overrided to 1 if the NLL criterion is chosen)");

  cmd.addText("\nMisc Options:");
  cmd.addICmdOption("-seed", &the_seed, -1, "the random seed");
  cmd.addICmdOption("-load", &max_load, -1, "max number of examples to load for train");
  cmd.addSCmdOption("-dir", &dir_name, ".", "directory to save measures");
  cmd.addSCmdOption("-save", &model_file, "", "the model file");
  cmd.addBCmdOption("-verbose", &verbose, false, "verbose");

  cmd.addText("\nImage Options:");
  cmd.addICmdOption("-width", &width_pattern, 64, "the width of the pattern");
  cmd.addICmdOption("-height", &height_pattern, 80, "the height of the pattern");
  cmd.addBCmdOption("-imagenorm", &image_normalize, false, "considers the input pattern as an image and performs a photometric normalization");

  // Test mode (master switch 1)
  cmd.addMasterSwitch("--test");
  cmd.addText("\nArguments:");
  cmd.addSCmdArg("model", &model_file, "the model file");
  cmd.addCmdOption(&filelist_class);
  cmd.addSCmdArg("output file", &output_file, "the file to save the output of the MLP according to this format: <file name> <mlp output>");
  cmd.addICmdArg("n_inputs", &n_inputs, "input dimension of the data");

  cmd.addText("\nMisc Options:");
  cmd.addSCmdOption("-dir", &dir_name, ".", "directory to save measures");
  cmd.addBCmdOption("-verbose", &verbose, false, "verbose");

  cmd.addText("\nImage Options:");
  cmd.addICmdOption("-width", &width_pattern, 64, "the width of the pattern");
  cmd.addICmdOption("-height", &height_pattern, 80, "the height of the pattern");
  cmd.addBCmdOption("-imagenorm", &image_normalize, false, "considers the input pattern as an image and performs a photometric normalization");

  // Read mode (master switch 2)
  cmd.addMasterSwitch("--read");
  cmd.addText("\nArguments:");
  cmd.addSCmdArg("model", &model_file, "the model file");
  cmd.addText("\nOptions:");
  cmd.addBCmdOption("-verbose", &verbose, false, "verbose");

  // Read the command line; mode is the index of the master switch used.
  int mode = cmd.read(argc, argv);

  // Training mode
  if(mode == 0)
  {
    if(verbose) print("Training mode\n");

    // Exactly one of -mse / -nll must be chosen (both set or both unset
    // is rejected here).
    if(use_mse == use_nll) error("choose between MSE or NLL criterion.");

    if(use_mse) if(verbose) print("Using MSE criterion ...\n");
    if(use_nll) if(verbose) print("Using NLL criterion ...\n");

    if(image_normalize)
    {
      if(verbose) print("Perform photometric normalization ...\n");
      // Image options must be consistent with the declared input dimension.
      if(width_pattern * height_pattern != n_inputs) error("incorrect image size.");
      if(verbose) print("The input pattern is an %dx%d image.\n", width_pattern, height_pattern);
    }

    // Seed the RNG (from the clock unless -seed was given).
    if(the_seed == -1) Random::seed();
    else Random::manualSeed((long)the_seed);

    // All measure files below are written relative to -dir.
    cmd.setWorkingDirectory(dir_name);

    // check data files (missing or non-regular files are dropped in place)
    int n_reminding_files_class1 = filelist_class1.n_files;
    n_reminding_files_class1 = checkFiles(filelist_class1.n_files, filelist_class1.file_names);
    int n_reminding_files_class0 = filelist_class0.n_files;
    n_reminding_files_class0 = checkFiles(filelist_class0.n_files, filelist_class0.file_names);

    if(verbose)
    {
      print(" + class 1:\n");
      print(" n_filenames = %d\n", n_reminding_files_class1);
      for(int i = 0 ; i < n_reminding_files_class1 ; i++)
        print(" filename[%d] = %s\n", i, filelist_class1.file_names[i]);
      print(" + class 0:\n");
      print(" n_filenames = %d\n", n_reminding_files_class0);
      for(int i = 0 ; i < n_reminding_files_class0 ; i++)
        print(" filename[%d] = %s\n", i, filelist_class0.file_names[i]);
    }

    // Create the MLP: one hidden layer of n_hu tanh units, or a plain
    // single-layer machine when -nhu 0 is given.
    MyMLP *mlp = NULL;
    if(n_hu != 0)
    {
      if(use_linear_output) mlp = new(allocator) MyMLP(2, n_inputs, "tanh", n_hu, "linear", 1);
      else mlp = new(allocator) MyMLP(2, n_inputs, "tanh", n_hu, "tanh", 1);
    }
    else
    {
      if(use_linear_output) mlp = new(allocator) MyMLP(n_inputs, "linear", 1);
      else mlp = new(allocator) MyMLP(n_inputs, "tanh", 1);
    }
    mlp->setWeightDecay(weight_decay);
    mlp->setPartialBackprop();

    if(verbose) mlp->info();

    // Create the training dataset (normalize inputs)
    MeanVarNorm *mv_norm = NULL;

    // The NLL criterion expects +/-1 targets, overriding -target.
    if(use_nll) the_target = 1.0;

    // Positive patterns get target +the_target, negatives -the_target.
    FileBinDataSet *bindata = NULL;
    bindata = new(allocator) FileBinDataSet(filelist_class1.file_names, n_reminding_files_class1, the_target, filelist_class0.file_names, n_reminding_files_class0, -the_target, n_inputs);
    if(verbose) bindata->info(false);

    if(image_normalize)
    {
      // Photometric normalization: histogram equalization followed by 3x3
      // Gaussian smoothing, written back into the dataset's input frames.
      ipHistoEqual *enhancing = new(allocator) ipHistoEqual(width_pattern, height_pattern, "float");
      ipCore *smoothing = new(allocator) ipSmoothGaussian3(width_pattern, height_pattern, "gray", 0.25);
      for(int i=0; i< bindata->n_examples; i++)
      {
        bindata->setExample(i);
        enhancing->process(bindata->inputs);
        smoothing->process(enhancing->seq_out);
        for(int j = 0 ; j < width_pattern * height_pattern ; j++)
          bindata->inputs->frames[0][j] = smoothing->seq_out->frames[0][j];
      }
    }

    // Normalize inputs (mean/variance computed on the training data);
    // note: computed AFTER the optional photometric normalization.
    mv_norm = new(allocator) MeanVarNorm(bindata);
    bindata->preProcess(mv_norm);

    // The list of measurers...
    MeasurerList measurers;

    // The class format
    TwoClassFormat *class_format = NULL;
    class_format = new(allocator) TwoClassFormat(bindata);

    // Measurers on the training dataset
    ClassMeasurer *class_meas = NULL;
    class_meas = new(allocator) ClassMeasurer(mlp->outputs, bindata, class_format, cmd.getXFile("classerror.measure"));
    measurers.addNode(class_meas);

    // the MSE measurer (always) and the NLL measurer (only with -nll)
    MSEMeasurer *mse_meas = NULL;
    NLLMeasurer *nll_meas = NULL;
    mse_meas = new(allocator) MSEMeasurer(mlp->outputs, bindata, cmd.getXFile("mse.measure"));
    measurers.addNode(mse_meas);
    if(use_nll)
    {
      nll_meas = new(allocator) NLLMeasurer(mlp->outputs, bindata, cmd.getXFile("nll.measure"));
      measurers.addNode(nll_meas);
    }

    //=================== The Trainer ===============================

    // The criterion for the StochasticGradient (MSE criterion or NLL criterion)
    Criterion *criterion = NULL;
    if(use_mse) criterion = new(allocator) MSECriterion(1);
    //if(use_nll) criterion = new(allocator) ClassNLLCriterion(class_format);
    if(use_nll) criterion = new(allocator) TwoClassNLLCriterion(0.0);

    // The Gradient Machine Trainer
    StochasticGradient trainer(mlp, criterion);
    trainer.setIOption("max iter", max_iter);
    trainer.setROption("end accuracy", accuracy);
    trainer.setROption("learning rate", learning_rate);
    trainer.setROption("learning rate decay", decay);

    // Print the number of parameter of the MLP (just for fun)
    if(verbose) message("Number of parameters: %d", mlp->params->n_params);

    if(k_fold <= 0)
    {
      // Plain training on the whole dataset; the model (together with its
      // input normalization) is saved only when -save was given.
      trainer.train(bindata, &measurers);
      if(strcmp(model_file, "")) mlp->save(model_file, mv_norm);
    }
    else
    {
      // K-fold cross-validation; no model is saved in this branch.
      if(verbose) print("Go go KFold.\n");
      KFold k(&trainer, k_fold);
      k.crossValidate(bindata, NULL, &measurers);
    }
  }

  // Test mode
  if(mode == 1)
  {
    if(verbose) print("Test mode\n");

    ipHistoEqual *enhancing = NULL;
    ipCore *smoothing = NULL;

    if(image_normalize)
    {
      // Same photometric normalization pipeline as in training mode.
      if(verbose) print("Perform photometric normalization ...\n");
      if(width_pattern * height_pattern != n_inputs) error("incorrect image size.");
      if(verbose) print("The input pattern is an %dx%d image.\n", width_pattern, height_pattern);

      enhancing = new(allocator) ipHistoEqual(width_pattern, height_pattern, "float");
      smoothing = new(allocator) ipSmoothGaussian3(width_pattern, height_pattern, "gray", 0.25);
    }

    // check data files (missing or non-regular files are dropped in place)
    int n_reminding_files = filelist_class.n_files;
    n_reminding_files = checkFiles(filelist_class.n_files, filelist_class.file_names);

    if(verbose)
    {
      print("n_filenames = %d\n", n_reminding_files);
      for(int i = 0 ; i < n_reminding_files ; i++)
        print(" filename[%d] = %s\n", i, filelist_class.file_names[i]);
    }

    MyMLP mlp;

    // Create the training dataset (normalize inputs)
    // NOTE(review): loading presumably also fills mv_norm with the
    // normalization statistics stored in the model file — confirm in MyMLP::load.
    MyMeanVarNorm *mv_norm = NULL;
    mv_norm = new(allocator) MyMeanVarNorm(n_inputs, 1);

    mlp.load(model_file, mv_norm);
    if(verbose) mlp.info();

    // One output line per input file: "<basename> <mean mlp output>".
    DiskXFile output_xfile(output_file,"w");

    DataSet *bindata;

    for(int i = 0 ; i < n_reminding_files ; i++)
    {
      if(verbose) print("testing file %s\n", filelist_class.file_names[i]);

      // Accumulates the MLP output over all examples of this file.
      real sum = 0.0;

      FileBinDataSet *bindata_ = new(allocator) FileBinDataSet(filelist_class.file_names[i], n_inputs);
      if(verbose) bindata_->info(false);
      bindata = bindata_;

      for(int t = 0; t < bindata->n_examples; t++)
      {
        bindata->setExample(t);

        if(image_normalize)
        {
          // enhance -> smooth -> normalize -> forward
          enhancing->process(bindata->inputs);
          smoothing->process(enhancing->seq_out);
          mv_norm->preProcessInputs(smoothing->seq_out);
          mlp.forward(smoothing->seq_out);
        }
        else
        {
          mv_norm->preProcessInputs(bindata->inputs);
          mlp.forward(bindata->inputs);
        }

        real output = mlp.outputs->frames[0][0];
        //print(" -> %g\n", output);
        sum += output;
      }

      // Mean output over the file's examples.
      sum /= (real) bindata->n_examples;

      // Report under the file's basename without extension.
      char *temp = strBaseName(filelist_class.file_names[i]);
      char *basename = strRemoveSuffix(temp);
      allocator->retain(basename);

      if(verbose) print("> %s %g\n", basename, sum);
      output_xfile.printf("%s %g\n", basename, sum);

      // Release this file's dataset before moving to the next one.
      allocator->free(bindata_);
    }
  }

  // Read mode
  if(mode == 2)
  {
    if(verbose) print("Read mode\n");

    // Load the model without normalization and dump its structure.
    MyMLP mlp;
    mlp.load(model_file, NULL);
    mlp.info();
  }

  // Frees every object created with new(allocator) above.
  delete allocator;

  return(0);
}

#include <sys/stat.h>

// Checks that each of the n_files paths in file_names exists and is a
// regular file. Entries that fail the check are removed by shifting the
// remaining pointers left (the freed tail slot is set to NULL), so
// file_names is modified in place. Returns the number of valid files.
int checkFiles(int n_files, char **file_names)
{
  int reminding_files = n_files;
  struct stat st;

  if(verbose) print("Checking files:\n");

  int i = 0;
  //for(int i = 0 ; i < n_files ; i++)
  // i advances only when the current entry is kept; after a removal the
  // next candidate has been shifted into slot i already.
  while(i < reminding_files)
  {
    if(verbose) print("Checking %s\n", file_names[i]);

    if(stat(file_names[i], &st) == -1)
    {
      // Path cannot be stat'ed: warn and compact it out.
      warning("Couldn't stat file %s.", file_names[i]);
      for(int j = i ; j < reminding_files-1 ; j++)
        file_names[j] = file_names[j+1];
      file_names[reminding_files-1] = NULL;
      reminding_files--;
    }
    else
    {
      if(!S_ISREG (st.st_mode))
      {
        // Exists but is not a regular file (directory, device, ...).
        warning("not regular file %s.", file_names[i]);
        for(int j = i ; j < reminding_files-1 ; j++)
          file_names[j] = file_names[j+1];
        file_names[reminding_files-1] = NULL;
        reminding_files--;
      }
      else i++;
    }
  }

  if(verbose)
  {
    print("Checked files (%d):\n", reminding_files);
    for(int i = 0 ; i < reminding_files ; i++)
      print("-> %s\n", file_names[i]);
  }

  return reminding_files;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -