
📄 trainMLP.cc

📁 Torch3vision training code: trains an MLP for two-class classification of fixed-size patterns, with optional photometric normalization for image inputs. A usage sketch follows the listing.
const char *help = "\
progname: trainMLP.cc\n\
code2html: This program trains a MLP with sigmoid outputs for 2 class classification.\n\
version: Torch3 vision2.0, 2003-2005\n\
(c) Sebastien Marcel (marcel@idiap.ch)\n";

/*
** Torch
*/
#include "Random.h"
#include "DiskXFile.h"

// criterions
#include "ClassNLLCriterion.h"
#include "TwoClassNLLCriterion.h"
#include "MSECriterion.h"

// class formats
#include "TwoClassFormat.h"

// measurers
#include "ClassMeasurer.h"
#include "MSEMeasurer.h"
#include "NLLMeasurer.h"

// trainers
#include "StochasticGradient.h"
#include "KFold.h"

// command-lines
#include "CmdLine.h"
#include "FileListCmdOption.h"

/*
** Torch3vision
*/
// datasets
#include "FileBinDataSet.h"

// custom machines
#include "MyMLP.h"
#include "MeanVarNorm.h"

// image processing
#include "ipHistoEqual.h"
#include "ipSmoothGaussian3.h"

using namespace Torch;

int main(int argc, char **argv)
{
	// network dimensions
	int n_inputs;
	// target value for the positive class
	real the_target;
	// number of hidden units
	int n_hu;
	// pattern geometry (when the inputs are images)
	int width_pattern;
	int height_pattern;
	// misc options
	int max_load;
	int the_seed;
	// learning options
	real accuracy;
	real learning_rate;
	real decay;
	int max_iter;
	int k_fold;
	real weight_decay;
	// criterion and architecture flags
	bool use_mse;
	bool use_nll;
	bool use_linear_output;
	bool image_normalize;
	// output locations
	char *dir_name;
	char *model_file;

	Allocator *allocator = new Allocator;
	DiskXFile::setLittleEndianMode();

	//=================== The command-line ==========================

	FileListCmdOption filelist_class1("file name", "the list of files or one data file of positive patterns");
	filelist_class1.isArgument(true);

	FileListCmdOption filelist_class0("file name", "the list of files or one data file of negative patterns");
	filelist_class0.isArgument(true);

	// Construct the command line
	CmdLine cmd;
	cmd.setBOption("write log", false);

	// Put the help line at the beginning
	cmd.info(help);

	// Train mode
	cmd.addText("\nArguments:");
	cmd.addCmdOption(&filelist_class1);
	cmd.addCmdOption(&filelist_class0);
	cmd.addICmdArg("n_inputs", &n_inputs, "input dimension of the data");

	cmd.addText("\nModel Options:");
	cmd.addICmdOption("-nhu", &n_hu, 25, "number of hidden units");

	cmd.addText("\nLearning Options:");
	cmd.addICmdOption("-iter", &max_iter, 25, "max number of iterations");
	cmd.addRCmdOption("-lr", &learning_rate, 0.01, "learning rate");
	cmd.addRCmdOption("-e", &accuracy, 0.00001, "end accuracy");
	cmd.addRCmdOption("-lrd", &decay, 0, "learning rate decay");
	cmd.addICmdOption("-kfold", &k_fold, -1, "number of folds, if you want to do cross-validation");
	cmd.addRCmdOption("-wd", &weight_decay, 0, "weight decay");
	cmd.addBCmdOption("-mse", &use_mse, false, "use MSE criterion");
	cmd.addBCmdOption("-nll", &use_nll, false, "use NLL criterion");
	cmd.addBCmdOption("-linear", &use_linear_output, false, "use linear output (tanh otherwise)");
	cmd.addRCmdOption("-target", &the_target, 0.6, "the target value (overridden if NLL)");

	cmd.addText("\nMisc Options:");
	cmd.addICmdOption("-seed", &the_seed, -1, "the random seed");
	cmd.addICmdOption("-load", &max_load, -1, "max number of examples to load for train");
	cmd.addSCmdOption("-dir", &dir_name, ".", "directory to save measures");
	cmd.addSCmdOption("-save", &model_file, "", "the model file");

	cmd.addText("\nImage Options:");
	cmd.addICmdOption("-width", &width_pattern, 19, "the width of the pattern");
	cmd.addICmdOption("-height", &height_pattern, 19, "the height of the pattern");
	cmd.addBCmdOption("-imagenorm", &image_normalize, false, "considers the input pattern as an image and performs a photometric normalization");

	// Read the command line
	cmd.read(argc, argv);

	if(use_mse == use_nll) error("choose between MSE and NLL criterion.");

	if(use_mse) print("Using MSE criterion ...\n");
	if(use_nll) print("Using NLL criterion ...\n");

	if(image_normalize)
	{
		print("Perform photometric normalization ...\n");
		if(width_pattern * height_pattern != n_inputs) error("incorrect image size.");

		print("The input pattern is a %dx%d image.\n", width_pattern, height_pattern);
	}

	if(the_seed == -1) Random::seed();
	else Random::manualSeed((long)the_seed);

	cmd.setWorkingDirectory(dir_name);

	print(" + class 1:\n");
	print("   n_filenames = %d\n", filelist_class1.n_files);
	for(int i = 0 ; i < filelist_class1.n_files ; i++)
		print("   filename[%d] = %s\n", i, filelist_class1.file_names[i]);

	print(" + class 0:\n");
	print("   n_filenames = %d\n", filelist_class0.n_files);
	for(int i = 0 ; i < filelist_class0.n_files ; i++)
		print("   filename[%d] = %s\n", i, filelist_class0.file_names[i]);

	// Create the MLP
	MyMLP *mlp = NULL;
	if(n_hu != 0)
	{
		if(use_linear_output) mlp = new(allocator) MyMLP(2, n_inputs, "tanh", n_hu, "linear", 1);
		else mlp = new(allocator) MyMLP(2, n_inputs, "tanh", n_hu, "tanh", 1);
	}
	else
	{
		if(use_linear_output) mlp = new(allocator) MyMLP(n_inputs, "linear", 1);
		else mlp = new(allocator) MyMLP(n_inputs, "tanh", 1);
	}
	mlp->setWeightDecay(weight_decay);
	mlp->setPartialBackprop();
	mlp->info();

	// Create the training dataset (normalize inputs)
	MeanVarNorm *mv_norm = NULL;

	if(use_nll) the_target = 1.0;

	FileBinDataSet *bindata = NULL;
	bindata = new(allocator) FileBinDataSet(filelist_class1.file_names, filelist_class1.n_files, the_target,
			filelist_class0.file_names, filelist_class0.n_files, -the_target, n_inputs);
	bindata->info(false);

	if(image_normalize)
	{
		ipHistoEqual *enhancing = new(allocator) ipHistoEqual(width_pattern, height_pattern, "float");
		ipCore *smoothing = new(allocator) ipSmoothGaussian3(width_pattern, height_pattern, "gray", 0.25);

		for(int i = 0 ; i < bindata->n_examples ; i++)
		{
			bindata->setExample(i);
			enhancing->process(bindata->inputs);
			smoothing->process(enhancing->seq_out);
			for(int j = 0 ; j < width_pattern * height_pattern ; j++)
				bindata->inputs->frames[0][j] = smoothing->seq_out->frames[0][j];
		}
	}

	mv_norm = new(allocator) MeanVarNorm(bindata);
	bindata->preProcess(mv_norm);

	// The list of measurers...
	MeasurerList measurers;

	// The class format
	TwoClassFormat *class_format = NULL;
	class_format = new(allocator) TwoClassFormat(bindata);

	// Measurers on the training dataset
	ClassMeasurer *class_meas = NULL;
	class_meas = new(allocator) ClassMeasurer(mlp->outputs, bindata, class_format, cmd.getXFile("classerror.measure"));
	measurers.addNode(class_meas);

	// the measurer
	MSEMeasurer *mse_meas = NULL;
	NLLMeasurer *nll_meas = NULL;

	mse_meas = new(allocator) MSEMeasurer(mlp->outputs, bindata, cmd.getXFile("mse.measure"));
	measurers.addNode(mse_meas);

	if(use_nll)
	{
		nll_meas = new(allocator) NLLMeasurer(mlp->outputs, bindata, cmd.getXFile("nll.measure"));
		measurers.addNode(nll_meas);
	}

	//=================== The Trainer ===============================

	// The criterion for the StochasticGradient (MSE criterion or NLL criterion)
	Criterion *criterion = NULL;
	if(use_mse) criterion = new(allocator) MSECriterion(1);
	//if(use_nll) criterion = new(allocator) ClassNLLCriterion(class_format);
	if(use_nll) criterion = new(allocator) TwoClassNLLCriterion(0.0);

	// The Gradient Machine Trainer
	StochasticGradient trainer(mlp, criterion);
	trainer.setIOption("max iter", max_iter);
	trainer.setROption("end accuracy", accuracy);
	trainer.setROption("learning rate", learning_rate);
	trainer.setROption("learning rate decay", decay);

	// Print the number of parameters of the MLP (just for fun)
	message("Number of parameters: %d", mlp->params->n_params);

	if(k_fold <= 0)
	{
		trainer.train(bindata, &measurers);

		if(strcmp(model_file, "")) mlp->save(model_file, mv_norm);
	}
	else
	{
		print("Go go KFold.\n");

		KFold k(&trainer, k_fold);
		k.crossValidate(bindata, NULL, &measurers);
	}

	delete allocator;

	return(0);
}
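For reference, here is a hypothetical invocation of the compiled program; the .bindata file names are placeholders, and 361 = 19x19 matches the default pattern width and height:

	./trainMLP face.bindata nonface.bindata 361 -nhu 50 -iter 100 -nll -imagenorm -dir ./measures -save mlp.model

Note that the two criteria imply different targets: with -nll the code forces the_target to 1.0 and trains against a TwoClassNLLCriterion, while with -mse the positive/negative targets default to +0.6/-0.6, which keeps them out of the saturated region of the tanh output. Exactly one of -mse or -nll must be given; otherwise the program exits with an error.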
