📄 .#networkcontroller.java.1.28
字号:
} else { //NeuralMath.printMatrix(completeData); System.out.println("Numero de sinapses: " + neuralNetwork.numberOfSynapses()); if ( isNormalizer ) { inputNormalization.setupParameters(inputs); inputs = inputNormalization.normalize(inputs); if ( isFullNormalizer ) { outputNormalization.setupParameters(outputs); outputs = outputNormalization.normalize(outputs); } System.out.println("normalizei"); } //NeuralMath.printMatrix(inputs); //NeuralMath.printMatrix(outputs); param.setInputs(inputs); param.setOutputs(outputs); results = trainingMethod.train(neuralNetwork, param); date = new Date(); isNNTrained = true; statisticalResults = results; return results; } } else { throw new NetworkControllerException(resources.getString("dataNotSplit")); } } } // trainNeuralNetwork() /** * Uses the existent neural network to infere the results using the available inputs. * One must first use the setInferenceData() first. * * @throws NetworkControllerException */ public void infereWithNetwork() throws NetworkControllerException { if (inferenceData == null) { throw new NetworkControllerException(resources.getString("inferenceDataNotLoaded")); } else { if ( inferenceData[0].length != neuralNetwork.getInputSize() ) { throw new NetworkControllerException(resources.getString("dataIncompatibleWithNetwork")); } else { inferedData = new double[inferenceData.length][neuralNetwork.getOutputSize()]; if ( isNormalizer ) { inferenceData = inputNormalization.normalize(inferenceData); } for (int i = 0; i < inferenceData.length; i++) { neuralNetwork.inputLayerSetup(inferenceData[i]); neuralNetwork.propagateInput(); inferedData[i] = neuralNetwork.retrieveFinalResults(); } if ( isFullNormalizer ) { inferedData = outputNormalization.unnormalize(inferedData); } double error = 0; double rightValues = 0; for (int i = 0; i < inferenceData.length; i++) { System.out.println(inferedData[i][0]); } } } } // infereWithNetwork() /** * Validate a network using the available data. 
* One must first use the setValidatationData() first. * * @return A class containing statistical data about the network validation. See TODO k * @throws NetworkControllerException */ public ValidationStatistics validateNetwork() throws NetworkControllerException { if (!isValidationDataLoaded) { throw new NetworkControllerException(resources.getString("validationDataNotLoaded")); } else { if ( isDataSplit ) { if ( inputs[0].length != neuralNetwork.getInputSize() ) { throw new NetworkControllerException(resources.getString("dataIncompatibleWithNetwork")); } else { ValidationStatistics result = new ValidationStatistics(); double rightValues = 0; inferedData = new double[outputs.length][outputs[0].length]; if ( isNormalizer ) { inputs = inputNormalization.normalize(inputs); if ( isFullNormalizer ) { outputs = outputNormalization.normalize(outputs); } } for (int i = 0; i < inputs.length; i++) { neuralNetwork.inputLayerSetup(inputs[i]); neuralNetwork.propagateInput(); inferedData[i] = neuralNetwork.retrieveFinalResults(); } if ( isFullNormalizer ) { inferedData = outputNormalization.unnormalize(inferedData); outputs = outputNormalization.unnormalize(outputs); } // TODO Implement more techniques double error = 0; for (int i = 0; i < outputs.length; i++) { if ( Math.abs(inferedData[i][0] - outputs[i][0]) < Math.sqrt(neuralNetwork.getError()) ) { rightValues++; } error = error + Math.abs(inferedData[i][0] - outputs[i][0]); System.out.println(inferedData[i][0] + " " +outputs[i][0]); } error = error / outputs.length; System.out.println(">>> Erro quadratico m??dio TR : " + neuralNetwork.getError() + "Sqrt erro: " + Math.sqrt(neuralNetwork.getError())); System.out.println(">>> Erro medio da valida????o : " + error); System.out.println(">>> Acertos : " + rightValues/outputs.length); return result; } } else { throw new NetworkControllerException(resources.getString("dataNotSplit")); } } } // validateNetwork() /** * Implements validation techniques. 
* * @param validationType The validation technique to be used. * * @param param Paremeters for the training method. */ public ValidationStatistics validateNetwork(String validationType, TrainingParameters param) throws NetworkControllerException { if (completeData == null) { throw new NetworkControllerException(resources.getString("trainingDataNotLoaded")); } else { Validation validation; try { validation = (Validation) Class .forName(VALIDATION_PACKAGE + validationType) .newInstance(); param.setInputs(inputs); param.setOutputs(outputs); return validation.validate(neuralNetwork, initialization, inputNormalization, outputNormalization, trainingMethod, param); } catch (InstantiationException e1) { throw new NetworkControllerException(e1); } catch (IllegalAccessException e1) { throw new NetworkControllerException(e1); } catch (ClassNotFoundException e1) { throw new NetworkControllerException(e1); } } } // validateNetwork() /** * * @return The neural model this controller represents. See TODO */ public NeuralModel getNeuralModel() { NeuralModel nm = new NeuralModel(); nm.setDate(date); nm.setInputNormalization(inputNormalization); nm.setOutputNormalization(outputNormalization); nm.setNeuralNetwork(neuralNetwork); nm.setMethodInfo(methodInfo); nm.setNetworkArchitecture(networkArchitecture); nm.setStatisticalResults(statisticalResults); return nm; } // getNeuralModel() /** * Sets the training data to be used by the NetworkController. * It's use is mandatory to train a neural network. * * @param data The training data. */ public void setTrainingData(double[][] data) { completeData = data; } // setTrainingData() /** * Sets the inference data to be used by the NetworkController. * It's use is mandatory to infere with a neural network. * * @param data The inference data. */ public void setInferenceData(double[][] data) { inferenceData = data; } // setInferenceData() /** * Sets the validationdata to be used by the NetworkController. 
* It's use is mandatory to validate a neural network. * * @param data The validation data. */ public void setValidationData(double[][] data) { completeData = data; isValidationDataLoaded = true; isDataSplit = false; } // setValidationData() /** * Splits the training or the validation data into the desired inputs and outputs. */ private void splitDataValues(double[][] data) throws NetworkControllerException { if (this.isOutputIndexSet) { if ( data != null ) { outputs = new double[data.length][outputIndex.length]; inputs = new double[data.length][completeData[0].length - outputIndex.length]; int index = -1, columnIndex = 0; int[] indexes = new int[data[0].length - outputIndex.length]; for (int m=0; m < data[0].length; m++) { boolean isIn = false; for (int k = 0; k < outputIndex.length; k++ ) { index = outputIndex[k]; if ( m == index ) { isIn=true; } } if (isIn == false) { indexes[columnIndex] = m; columnIndex++; } } for (int k = 0; k < indexes.length; k++ ) { index = indexes[k]; for(int i = 0; i < data.length; i++) { inputs[i][k] = data[i][index]; } } for (int k = 0; k < outputIndex.length; k++ ) { index = outputIndex[k]; for(int i = 0; i < data.length; i++) { outputs[i][k] = data[i][index]; } } isDataSplit = true; } else { throw new NetworkControllerException(resources.getString("dataIsNull")); } } else { throw new NetworkControllerException(resources.getString("outputIndexNotSet")); } } //splitDataValues() /** * Splits the training or the validation data into the desired inputs and outputs. * One must first set the output indexes or an exception will be thrown. * * @throws NetworkControllerException */ public void splitDataValues() throws NetworkControllerException { splitDataValues(this.completeData); } // splitDataValues() /** * @return Returns the inferenceData. */ public double[][] getInferenceData() { return inferenceData; } // getInferenceData() /** * @return Returns the inputs. 
*/
    public double[][] getInputs() {
        return this.inputs;
    } // getInputs()

    /**
     * @return the {@code outputs} matrix currently held by this controller.
     */
    public double[][] getOutputs() {
        return this.outputs;
    } // getOutputs()

    /**
     * @return the matrix last passed to setTrainingData() or setValidationData();
     *         both setters store into the same field.
     */
    public double[][] getTrainingData() {
        return this.completeData;
    } // getTrainingData()

    /**
     * @return the network outputs stored by the most recent inference or validation run.
     */
    public double[][] getInferedData() {
        return this.inferedData;
    } // getInferedData()

    /**
     * @return the training error reported by the underlying neural network.
     */
    public double getError() {
        return this.neuralNetwork.getError();
    } // getError()

    /**
     * @return the network architecture name held in {@code networkArchitecture}.
     */
    public String getNetworkType() {
        return this.networkArchitecture;
    } // getNetworkType()

    /**
     * @return the {@code date} field; trainNeuralNetwork() sets it to the current time
     *         once training completes.
     */
    public Date getDate() {
        return this.date;
    } // getDate()
} // NetworkController()
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -