⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mlp.cc

📁 Torch3vision MLP training code — a good reference example
💻 CC
字号:
// mlp.cc — trains, tests, or inspects a two-class MLP classifier (Torch3vision).
// Three master modes selected on the command line:
//   (default)  train an MLP on positive/negative pattern lists
//   --test     run a saved MLP over pattern files and write per-file mean scores
//   --read     load a saved MLP and print its structure
//
// Help banner shown by CmdLine; the trailing backslashes are line continuations
// inside one string literal.
const char *help = "\
progname: mlp.cc\n\
code2html: This program trains a MLP for 2 class classification.\n\
version: Torch3 vision2.1, 2003-2006\n\
(c) Sebastien Marcel (marcel@idiap.ch)\n";

/* Torch core headers */
#include "Random.h"
#include "DiskXFile.h"

// datasets
#include "MatDataSet.h"

// criterions
#include "ClassNLLCriterion.h"
#include "TwoClassNLLCriterion.h"
#include "MSECriterion.h"

// class formats
#include "TwoClassFormat.h"

// measurers
#include "ClassMeasurer.h"
#include "MSEMeasurer.h"
#include "NLLMeasurer.h"

// trainers
#include "StochasticGradient.h"
#include "KFold.h"

// command-lines
#include "CmdLine.h"
#include "FileListCmdOption.h"

// string utils
#include "string_utils.h"

/* Torch3vision extensions */
// datasets
#include "FileBinDataSet.h"

// custom machines
#include "MyMLP.h"
#include "MeanVarNorm.h"
#include "MyMeanVarNorm.h"

// image processing
#include "ipHistoEqual.h"
#include "ipSmoothGaussian3.h"

using namespace Torch;

// Checks that each listed file exists and is a regular file; compacts the
// array in place and returns the number of remaining (usable) files.
int checkFiles(int n_files, char **file_names);

// Global verbosity flag, bound to -verbose in every command-line mode.
bool verbose;

int main(int argc, char **argv)
{
	// input dimension of each pattern
	int n_inputs;
	// training target value for the positive class (negative class uses -target)
	real the_target;
	// number of hidden units (0 => no hidden layer, i.e. a linear machine)
	int n_hu;
	// pattern geometry, used only when -imagenorm treats inputs as images
	int width_pattern;
	int height_pattern;
	// NOTE(review): max_load is parsed from -load but never read afterwards.
	int max_load;
	int the_seed;
	// stochastic-gradient hyper-parameters
	real accuracy;
	real learning_rate;
	real decay;
	int max_iter;
	int k_fold;
	real weight_decay;
	// criterion / architecture switches
	bool use_mse;
	bool use_nll;
	bool use_linear_output;
	bool image_normalize;
	// file/directory names filled in by the command line
	char *dir_name;
	char *model_file;
	char *output_file;

	// All heap objects below are placement-new'd into this allocator and
	// released together by "delete allocator" at the end.
	Allocator *allocator = new Allocator;
	DiskXFile::setLittleEndianMode();

	//=================== The command-line ==========================

	// File-list arguments: positive patterns, negative patterns (train mode),
	// and a single list of patterns (test mode).
	FileListCmdOption filelist_class1("file name", "the list files or one data file of positive patterns");
	filelist_class1.isArgument(true);

	FileListCmdOption filelist_class0("file name", "the list files or one data file of negative patterns");
	filelist_class0.isArgument(true);

	FileListCmdOption filelist_class("file name", "the list files or one data file of patterns");
	filelist_class.isArgument(true);

	// Construct the command line
	CmdLine cmd;
	cmd.setBOption("write log", false);

	// Put the help line at the beginning
	cmd.info(help);

	// Train mode (mode 0 — everything before the first master switch)
	cmd.addText("\nArguments:");
	cmd.addCmdOption(&filelist_class1);
	cmd.addCmdOption(&filelist_class0);
	cmd.addICmdArg("n_inputs", &n_inputs, "input dimension of the data");

	cmd.addText("\nModel Options:");
	cmd.addICmdOption("-nhu", &n_hu, 25, "number of hidden units");

	cmd.addText("\nLearning Options:");
	cmd.addICmdOption("-iter", &max_iter, 25, "max number of iterations");
	cmd.addRCmdOption("-lr", &learning_rate, 0.01, "learning rate");
	cmd.addRCmdOption("-e", &accuracy, 0.00001, "end accuracy");
	cmd.addRCmdOption("-lrd", &decay, 0, "learning rate decay");
	cmd.addICmdOption("-kfold", &k_fold, -1, "number of folds, if you want to do cross-validation");
	cmd.addRCmdOption("-wd", &weight_decay, 0, "weight decay");
	cmd.addBCmdOption("-mse", &use_mse, false, "use MSE criterion");
	cmd.addBCmdOption("-nll", &use_nll, false, "use NLL criterion");
	cmd.addBCmdOption("-linear", &use_linear_output, false, "use linear output (tanh otherwise)");
	cmd.addRCmdOption("-target", &the_target, 0.6, "the target value (overrided to 1 if the NLL criterion is chosen)");

	cmd.addText("\nMisc Options:");
	cmd.addICmdOption("-seed", &the_seed, -1, "the random seed");
	cmd.addICmdOption("-load", &max_load, -1, "max number of examples to load for train");
	cmd.addSCmdOption("-dir", &dir_name, ".", "directory to save measures");
	cmd.addSCmdOption("-save", &model_file, "", "the model file");
	cmd.addBCmdOption("-verbose", &verbose, false, "verbose");

	cmd.addText("\nImage Options:");
	cmd.addICmdOption("-width", &width_pattern, 64, "the width of the pattern");
	cmd.addICmdOption("-height", &height_pattern, 80, "the height of the pattern");
	cmd.addBCmdOption("-imagenorm", &image_normalize, false, "considers the input pattern as an image and performs a photometric normalization");

	// Test mode (mode 1)
	cmd.addMasterSwitch("--test");
	cmd.addText("\nArguments:");
	cmd.addSCmdArg("model", &model_file, "the model file");
	cmd.addCmdOption(&filelist_class);
	cmd.addSCmdArg("output file", &output_file, "the file to save the output of the MLP according to this format: <file name> <mlp output>");
	cmd.addICmdArg("n_inputs", &n_inputs, "input dimension of the data");
	cmd.addText("\nMisc Options:");
	cmd.addSCmdOption("-dir", &dir_name, ".", "directory to save measures");
	cmd.addBCmdOption("-verbose", &verbose, false, "verbose");
	cmd.addText("\nImage Options:");
	cmd.addICmdOption("-width", &width_pattern, 64, "the width of the pattern");
	cmd.addICmdOption("-height", &height_pattern, 80, "the height of the pattern");
	cmd.addBCmdOption("-imagenorm", &image_normalize, false, "considers the input pattern as an image and performs a photometric normalization");

	// Read mode (mode 2)
	cmd.addMasterSwitch("--read");
	cmd.addText("\nArguments:");
	cmd.addSCmdArg("model", &model_file, "the model file");
	cmd.addText("\nOptions:");
	cmd.addBCmdOption("-verbose", &verbose, false, "verbose");

	// Read the command line; returns which master switch was selected.
	int mode = cmd.read(argc, argv);

	// ============================ Training mode ============================
	if(mode == 0)
	{
		if(verbose) print("Training mode\n");

		// Exactly one of -mse / -nll must be chosen (equal flags means
		// neither or both were given).
		if(use_mse == use_nll) error("choose between MSE or NLL criterion.");

		//
		if(use_mse) if(verbose) print("Using MSE criterion ...\n");
		if(use_nll) if(verbose) print("Using NLL criterion ...\n");

		if(image_normalize)
		{
			if(verbose) print("Perform photometric normalization ...\n");
			// image geometry must match the declared input dimension
			if(width_pattern * height_pattern != n_inputs) error("incorrect image size.");
			if(verbose) print("The input pattern is an %dx%d image.\n", width_pattern, height_pattern);
		}

		// Seed the RNG: random seed unless one was supplied.
		if(the_seed == -1) Random::seed();
		else Random::manualSeed((long)the_seed);

		// Measure files (classerror/mse/nll) are written under this directory.
		cmd.setWorkingDirectory(dir_name);

		// Check data files; checkFiles() drops missing/non-regular entries
		// and returns the remaining count ("reminding" = "remaining" typo
		// kept from the original identifiers).
		int n_reminding_files_class1 = filelist_class1.n_files;
		n_reminding_files_class1 = checkFiles(filelist_class1.n_files, filelist_class1.file_names);

		int n_reminding_files_class0 = filelist_class0.n_files;
		n_reminding_files_class0 = checkFiles(filelist_class0.n_files, filelist_class0.file_names);

		//
		if(verbose)
		{
			print(" + class 1:\n");
			print("   n_filenames = %d\n", n_reminding_files_class1);
			for(int i = 0 ; i < n_reminding_files_class1 ; i++)
				print("   filename[%d] = %s\n", i, filelist_class1.file_names[i]);

			print(" + class 0:\n");
			print("   n_filenames = %d\n", n_reminding_files_class0);
			for(int i = 0 ; i < n_reminding_files_class0 ; i++)
				print("   filename[%d] = %s\n", i, filelist_class0.file_names[i]);
		}

		// Create the MLP: 2 layers (tanh hidden + chosen output) when n_hu != 0,
		// otherwise a single output layer directly on the inputs.
		MyMLP *mlp = NULL;
		if(n_hu != 0)
		{
			if(use_linear_output) mlp = new(allocator) MyMLP(2, n_inputs, "tanh", n_hu, "linear", 1);
			else mlp = new(allocator) MyMLP(2, n_inputs, "tanh", n_hu, "tanh", 1);
		}
		else
		{
			if(use_linear_output) mlp = new(allocator) MyMLP(n_inputs, "linear", 1);
			else mlp = new(allocator) MyMLP(n_inputs, "tanh", 1);
		}

		mlp->setWeightDecay(weight_decay);
		mlp->setPartialBackprop();

		if(verbose) mlp->info();

		// Create the training dataset (inputs are mean/variance normalized below)
		MeanVarNorm *mv_norm = NULL;

		// NLL training forces targets to +1/-1 regardless of -target.
		if(use_nll) the_target = 1.0;

		// Positive files get target +the_target, negative files -the_target.
		FileBinDataSet *bindata = NULL;
		bindata = new(allocator) FileBinDataSet(filelist_class1.file_names, n_reminding_files_class1, the_target, filelist_class0.file_names, n_reminding_files_class0, -the_target, n_inputs);

		if(verbose) bindata->info(false);

		// Optional photometric normalization: histogram equalization followed
		// by a 3x3 Gaussian smoothing, written back into each example in place.
		if(image_normalize)
		{
			ipHistoEqual *enhancing = new(allocator) ipHistoEqual(width_pattern, height_pattern, "float");
			ipCore *smoothing = new(allocator) ipSmoothGaussian3(width_pattern, height_pattern, "gray", 0.25);

			for(int i=0; i< bindata->n_examples; i++)
			{
				bindata->setExample(i);
				enhancing->process(bindata->inputs);
				smoothing->process(enhancing->seq_out);
				for(int j = 0 ; j < width_pattern * height_pattern ; j++)
					bindata->inputs->frames[0][j] = smoothing->seq_out->frames[0][j];
			}
		}

		// Normalization statistics are computed AFTER the optional image
		// preprocessing, so they describe the preprocessed data.
		mv_norm = new(allocator) MeanVarNorm(bindata);
		bindata->preProcess(mv_norm);

		// The list of measurers...
		MeasurerList measurers;

		// The class format
		TwoClassFormat *class_format = NULL;
		class_format = new(allocator) TwoClassFormat(bindata);

		// Measurers on the training dataset
		ClassMeasurer *class_meas = NULL;
		class_meas = new(allocator) ClassMeasurer(mlp->outputs, bindata, class_format, cmd.getXFile("classerror.measure"));
		measurers.addNode(class_meas);

		// MSE is always measured; NLL only when that criterion is active.
		MSEMeasurer *mse_meas = NULL;
		NLLMeasurer *nll_meas = NULL;

		mse_meas = new(allocator) MSEMeasurer(mlp->outputs, bindata, cmd.getXFile("mse.measure"));
		measurers.addNode(mse_meas);

		if(use_nll)
		{
			nll_meas = new(allocator) NLLMeasurer(mlp->outputs, bindata, cmd.getXFile("nll.measure"));
			measurers.addNode(nll_meas);
		}

		//=================== The Trainer ===============================

		// The criterion for the StochasticGradient (MSE criterion or NLL criterion)
		Criterion *criterion = NULL;
		if(use_mse) criterion = new(allocator) MSECriterion(1);
		//if(use_nll) criterion = new(allocator) ClassNLLCriterion(class_format);
		if(use_nll) criterion = new(allocator) TwoClassNLLCriterion(0.0);

		// The Gradient Machine Trainer
		StochasticGradient trainer(mlp, criterion);
		trainer.setIOption("max iter", max_iter);
		trainer.setROption("end accuracy", accuracy);
		trainer.setROption("learning rate", learning_rate);
		trainer.setROption("learning rate decay", decay);

		// Print the number of parameters of the MLP (just for fun)
		if(verbose) message("Number of parameters: %d", mlp->params->n_params);

		if(k_fold <= 0)
		{
			// Plain training on the whole set...
			trainer.train(bindata, &measurers);

			// ...then save the model (with its normalization) only when a
			// non-empty -save file name was given.
			if(strcmp(model_file, "")) mlp->save(model_file, mv_norm);
		}
		else
		{
			// ...or k-fold cross-validation (no model is saved in this branch).
			if(verbose) print("Go go KFold.\n");

			KFold k(&trainer, k_fold);
			k.crossValidate(bindata, NULL, &measurers);
		}
	}

	// ============================== Test mode ==============================
	if(mode == 1)
	{
		if(verbose) print("Test mode\n");

		// Preprocessing machinery, built only when -imagenorm is set.
		ipHistoEqual *enhancing = NULL;
		ipCore *smoothing = NULL;

		if(image_normalize)
		{
			if(verbose) print("Perform photometric normalization ...\n");
			if(width_pattern * height_pattern != n_inputs) error("incorrect image size.");
			if(verbose) print("The input pattern is an %dx%d image.\n", width_pattern, height_pattern);

			enhancing = new(allocator) ipHistoEqual(width_pattern, height_pattern, "float");
			smoothing = new(allocator) ipSmoothGaussian3(width_pattern, height_pattern, "gray", 0.25);
		}

		// check data files
		int n_reminding_files = filelist_class.n_files;
		n_reminding_files = checkFiles(filelist_class.n_files, filelist_class.file_names);

		//
		if(verbose)
		{
			print("n_filenames = %d\n", n_reminding_files);
			for(int i = 0 ; i < n_reminding_files ; i++)
				print("   filename[%d] = %s\n", i, filelist_class.file_names[i]);
		}

		//
		MyMLP mlp;

		// Load the model together with the normalization parameters stored
		// with it at training time.
		MyMeanVarNorm *mv_norm = NULL;
		mv_norm = new(allocator) MyMeanVarNorm(n_inputs, 1);

		mlp.load(model_file, mv_norm);
		if(verbose) mlp.info();

		DiskXFile output_xfile(output_file,"w");

		// For every input file: average the MLP output over all its examples
		// and write "<basename> <mean score>" to the output file.
		DataSet *bindata;
		for(int i = 0 ; i < n_reminding_files ; i++)
		{
			if(verbose) print("testing file %s\n", filelist_class.file_names[i]);

			real sum = 0.0;

			FileBinDataSet *bindata_ = new(allocator) FileBinDataSet(filelist_class.file_names[i], n_inputs);
			if(verbose) bindata_->info(false);

			bindata = bindata_;

			for(int t = 0; t < bindata->n_examples; t++)
			{
				bindata->setExample(t);

				if(image_normalize)
				{
					// Same pipeline as training: equalize, smooth, normalize.
					enhancing->process(bindata->inputs);
					smoothing->process(enhancing->seq_out);
					mv_norm->preProcessInputs(smoothing->seq_out);
					mlp.forward(smoothing->seq_out);
				}
				else
				{
					mv_norm->preProcessInputs(bindata->inputs);
					mlp.forward(bindata->inputs);
				}

				real output = mlp.outputs->frames[0][0];
				//print("  -> %g\n", output);
				sum += output;
			}

			// mean MLP output over the file's examples
			sum /= (real) bindata->n_examples;

			// Report under the file's base name without extension; retain
			// keeps the allocator from releasing the string early.
			char *temp = strBaseName(filelist_class.file_names[i]);
			char *basename = strRemoveSuffix(temp);
			allocator->retain(basename);

			if(verbose) print("> %s %g\n", basename, sum);

			output_xfile.printf("%s %g\n", basename, sum);

			// Release each per-file dataset before moving to the next file.
			allocator->free(bindata_);
		}
	}

	// ============================== Read mode ==============================
	if(mode == 2)
	{
		if(verbose) print("Read mode\n");

		// Load without normalization and just print the model structure.
		MyMLP mlp;
		mlp.load(model_file, NULL);
		mlp.info();
	}

	// Frees every object placement-new'd into the allocator above.
	delete allocator;

	return(0);
}

#include <sys/stat.h>

// Validates a list of file names with stat(2). Entries that cannot be
// stat'ed or are not regular files are removed by shifting the remaining
// pointers left (dropped slots at the tail are set to NULL, the dropped
// name pointers themselves are not freed). Returns the remaining count.
// ("reminding" is the original code's spelling of "remaining".)
int checkFiles(int n_files, char **file_names)
{
	int reminding_files = n_files;
	struct stat st;

	if(verbose) print("Checking files:\n");

	int i = 0;
	//for(int i = 0 ; i < n_files ; i++)
	// i only advances when the current entry is kept, because removal
	// shifts the next candidate into slot i.
	while(i < reminding_files)
	{
		if(verbose) print("Checking %s\n", file_names[i]);

		if(stat(file_names[i], &st) == -1)
		{
			warning("Couldn't stat file %s.", file_names[i]);
			for(int j = i ; j < reminding_files-1 ; j++) file_names[j] = file_names[j+1];
			file_names[reminding_files-1] = NULL;
			reminding_files--;
		}
		else
		{
			if(!S_ISREG (st.st_mode))
			{
				warning("not regular file %s.", file_names[i]);
				for(int j = i ; j < reminding_files-1 ; j++) file_names[j] = file_names[j+1];
				file_names[reminding_files-1] = NULL;
				reminding_files--;
			}
			else i++;
		}
	}

	if(verbose)
	{
		print("Checked files (%d):\n", reminding_files);
		for(int i = 0 ; i < reminding_files ; i++)
			print("-> %s\n", file_names[i]);
	}

	return reminding_files;
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -