📄 crossvalidation.java
字号:
/* * Created on 23/10/2004 * * TODO To change the template for this generated file go to * Window - Preferences - Java - Code Generation - Code and Comments */package neuralnetworktoolkit.validation;import neuralnetworktoolkit.*;import neuralnetworktoolkit.methods.*;import neuralnetworktoolkit.neuralnetwork.*;import neuralnetworktoolkit.neuralnetwork.weightinitialization.WeightInitialization;import neuralnetworktoolkit.normalization.*;/** * @author iver * * TODO To change the template for this generated type comment go to * Window - Preferences - Java - Code Generation - Code and Comments */public class CrossValidation extends Validation { /* (non-Javadoc) * @see neuralnetworktoolkit.validation.Validation#validate(neuralnetworktoolkit.neuralnetwork.INeuralNetwork) */ public ValidationStatistics validate(INeuralNetwork nn, WeightInitialization initialization, INormalization inputNormalization, INormalization outputNormalization, ITrainingMethod trainingMethod, TrainingParameters param) { ValidationStatistics result = new ValidationStatistics(); StatisticalResults trainStat; double error = 0; double trainError; double totalError = 0; double[][] inputs = param.getInputs(); double[][] outputs = param.getOutputs(); double[][] instantInputs = new double[inputs.length-1][inputs[0].length]; double[][] instantOutputs = new double[inputs.length-1][inputs[0].length]; double[][] results = new double[1][outputs[0].length]; double[][] instantOutput = new double[1][outputs[0].length]; int index = 0; if ( inputNormalization != null ) { System.out.println("Normalizei inputs!!!"); inputs = inputNormalization.normalize(inputs); if ( outputNormalization != null ) { System.out.println("Normalizei outputs!!!"); outputs = outputNormalization.normalize(outputs); } } for (int i = 0; i < inputs.length; i++) { System.out.println("Round " + (i+1)+ "---------------------------------------------------"); for (int j = 0; j < inputs.length; j++) { if (i!=j) { //System.out.println("Index: " + index); instantInputs[index] = inputs[j]; instantOutputs[index] = outputs[j]; index++; } } System.out.println(">>> Treinando"); initialization.initialize(nn); param.setInputs(inputs); param.setOutputs(outputs); trainStat = trainingMethod.train(nn, param); System.out.println(">>> Erro: "+trainStat.getError()); System.out.println(">>>>>>>>>>>>>"); nn.inputLayerSetup(inputs[i]); nn.propagateInput(); //System.out.println("entrada 1: "+ inputs[i][0] + " entrada 2: "+inputs[i][1]); results[0] = nn.retrieveFinalResults(); if ( outputNormalization != null ) { results = outputNormalization.unnormalize(results); instantOutput[0] = outputs[i]; instantOutput = outputNormalization.unnormalize(instantOutput); } for (int k = 0; k < nn.retrieveFinalResults().length; k++) { System.out.println("Saída rede: " + results[0][k] + " Original: " + instantOutput[0][k] + " Erro " + Math.abs( instantOutput[0][k] - results[0][k] ) ); error = error + ( Math.abs( instantOutput[0][k] - results[0][k] ) ); } totalError = totalError +error; error = 0; index = 0; System.out.println("----------------------------------------------------------"); } result.setTotalError(totalError/inputs.length); return result; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -