
📄 crossvalidation.java

📁 Neural network toolkit implemented in Java
💻 JAVA
/*
 * Created on 23/10/2004
 */
package neuralnetworktoolkit.validation;

import neuralnetworktoolkit.*;
import neuralnetworktoolkit.methods.*;
import neuralnetworktoolkit.neuralnetwork.*;
import neuralnetworktoolkit.neuralnetwork.weightinitialization.WeightInitialization;
import neuralnetworktoolkit.normalization.*;

/**
 * Leave-one-out cross-validation: each sample is held out once, the network
 * is retrained on the remaining samples, and the absolute error on the
 * held-out sample is accumulated.
 *
 * @author iver
 */
public class CrossValidation extends Validation {

	/* (non-Javadoc)
	 * @see neuralnetworktoolkit.validation.Validation#validate(neuralnetworktoolkit.neuralnetwork.INeuralNetwork)
	 */
	public ValidationStatistics validate(INeuralNetwork nn,
			WeightInitialization initialization,
			INormalization inputNormalization,
			INormalization outputNormalization,
			ITrainingMethod trainingMethod,
			TrainingParameters param) {
		ValidationStatistics result = new ValidationStatistics();
		StatisticalResults trainStat;
		double error = 0;
		double totalError = 0;
		double[][] inputs = param.getInputs();
		double[][] outputs = param.getOutputs();
		// Training set for one round: every sample except the held-out one.
		double[][] instantInputs = new double[inputs.length - 1][inputs[0].length];
		// Fixed: the column count must come from outputs, not inputs.
		double[][] instantOutputs = new double[outputs.length - 1][outputs[0].length];
		double[][] results = new double[1][outputs[0].length];
		double[][] instantOutput = new double[1][outputs[0].length];
		int index = 0;

		if (inputNormalization != null) {
			System.out.println("Normalized inputs!!!");
			inputs = inputNormalization.normalize(inputs);
		}
		// Fixed: output normalization no longer depends on input normalization,
		// which matches the unnormalize check further down.
		if (outputNormalization != null) {
			System.out.println("Normalized outputs!!!");
			outputs = outputNormalization.normalize(outputs);
		}

		for (int i = 0; i < inputs.length; i++) {
			System.out.println("Round " + (i + 1) + "---------------------------------------------------");

			// Copy every sample except sample i into the training arrays.
			for (int j = 0; j < inputs.length; j++) {
				if (i != j) {
					instantInputs[index] = inputs[j];
					instantOutputs[index] = outputs[j];
					index++;
				}
			}

			System.out.println(">>> Training");
			initialization.initialize(nn);
			// Fixed: train on the leave-one-out subset. The original passed the
			// full data set, so the held-out sample leaked into training.
			param.setInputs(instantInputs);
			param.setOutputs(instantOutputs);
			trainStat = trainingMethod.train(nn, param);
			System.out.println(">>> Error: " + trainStat.getError());
			System.out.println(">>>>>>>>>>>>>");

			// Evaluate the trained network on the held-out sample.
			nn.inputLayerSetup(inputs[i]);
			nn.propagateInput();
			results[0] = nn.retrieveFinalResults();

			// Fixed: always record the expected output. The original set it only
			// when outputNormalization was non-null and otherwise compared
			// against zeros.
			instantOutput[0] = outputs[i];
			if (outputNormalization != null) {
				results = outputNormalization.unnormalize(results);
				instantOutput = outputNormalization.unnormalize(instantOutput);
			}

			for (int k = 0; k < results[0].length; k++) {
				double absoluteError = Math.abs(instantOutput[0][k] - results[0][k]);
				System.out.println("Network output: " + results[0][k]
						+ " Original: " + instantOutput[0][k]
						+ " Error " + absoluteError);
				error = error + absoluteError;
			}

			totalError = totalError + error;
			error = 0;
			index = 0;

			System.out.println("----------------------------------------------------------");
		}

		// Mean absolute error per held-out sample (summed over output units).
		result.setTotalError(totalError / inputs.length);

		return result;
	}
}
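The leave-one-out procedure that validate follows can be tried without the toolkit. The sketch below is only an illustration under stated assumptions: the class LeaveOneOutSketch and its methods withoutRow, meanTarget and leaveOneOutError are hypothetical names that are not part of neuralnetworktoolkit, and the "model" is a stand-in that predicts the mean of the training targets rather than a trained network.

```java
// Hypothetical, toolkit-independent sketch of leave-one-out cross-validation.
class LeaveOneOutSketch {

    // Returns a new array holding every row except the one at index heldOut.
    // Rows are shared by reference, as in the CrossValidation class.
    static double[][] withoutRow(double[][] rows, int heldOut) {
        double[][] rest = new double[rows.length - 1][];
        int index = 0;
        for (int j = 0; j < rows.length; j++) {
            if (j != heldOut) {
                rest[index++] = rows[j];
            }
        }
        return rest;
    }

    // Stand-in for training: predicts the column-wise mean of the targets.
    static double[] meanTarget(double[][] targets) {
        double[] mean = new double[targets[0].length];
        for (double[] row : targets) {
            for (int k = 0; k < row.length; k++) {
                mean[k] += row[k] / targets.length;
            }
        }
        return mean;
    }

    // Holds out each sample once and averages the absolute error on it.
    static double leaveOneOutError(double[][] outputs) {
        double total = 0;
        for (int i = 0; i < outputs.length; i++) {
            double[] predicted = meanTarget(withoutRow(outputs, i));
            for (int k = 0; k < predicted.length; k++) {
                total += Math.abs(outputs[i][k] - predicted[k]);
            }
        }
        return total / outputs.length;
    }

    public static void main(String[] args) {
        double[][] outputs = {{1}, {2}, {3}};
        System.out.println("Leave-one-out error: " + leaveOneOutError(outputs));
    }
}
```

The point of the sketch is that the model for round i never sees sample i, so the reported figure estimates error on unseen data.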
