📄 networkcontroller.java
字号:
inferedData = new double[inferenceData.length][neuralNetwork.getOutputSize()]; if ((isNormalizer) && (!isInfered)) { inferenceData = inputNormalization.normalize(inferenceData); } for (int i = 0; i < inferenceData.length; i++) { neuralNetwork.inputLayerSetup(inferenceData[i]); neuralNetwork.propagateInput(); inferedData[i] = neuralNetwork.retrieveFinalResults(); } if ( isFullNormalizer ) { inferedData = outputNormalization.unnormalize(inferedData); } isInfered = true; /* double error = 0; double rightValues = 0; for (int i = 0; i < inferenceData.length; i++) { System.out.println(inferedData[i][0]); } */ } } } // infereWithNetwork() /** * Validate a network using the available data. * One must first use the setValidatationData() first. * * @return A class containing statistical data about the network validation. See TODO k * @throws NetworkControllerException */ public ValidationStatistics validateNetwork() throws NetworkControllerException { if (!isValidationDataLoaded) { throw new NetworkControllerException(resources.getString("validationDataNotLoaded")); } else { if ( isDataSplit ) { if ( inputs[0].length != neuralNetwork.getInputSize() ) { throw new NetworkControllerException(resources.getString("dataIncompatibleWithNetwork")); } else { ValidationStatistics result = new ValidationStatistics(); double rightValues = 0; inferedData = new double[outputs.length][outputs[0].length]; if ( isNormalizer ) { inputs = inputNormalization.normalize(inputs); if ( isFullNormalizer ) { outputs = outputNormalization.normalize(outputs); } } for (int i = 0; i < inputs.length; i++) { neuralNetwork.inputLayerSetup(inputs[i]); neuralNetwork.propagateInput(); inferedData[i] = neuralNetwork.retrieveFinalResults(); } if ( isFullNormalizer ) { inferedData = outputNormalization.unnormalize(inferedData); outputs = outputNormalization.unnormalize(outputs); } // TODO Implement more techniques double error = 0; for (int i = 0; i < outputs.length; i++) { if ( Math.abs(inferedData[i][0] - 
outputs[i][0]) < Math.sqrt(neuralNetwork.getError()) ) { rightValues++; } error = error + Math.abs(inferedData[i][0] - outputs[i][0]); System.out.println(inferedData[i][0] + " " +outputs[i][0]); } error = error / outputs.length; System.out.println(">>> Erro quadratico m??dio TR : " + neuralNetwork.getError() + "Sqrt erro: " + Math.sqrt(neuralNetwork.getError())); System.out.println(">>> Erro medio da valida????o : " + error); System.out.println(">>> Acertos : " + rightValues/outputs.length); return result; } } else { throw new NetworkControllerException(resources.getString("dataNotSplit")); } } } // validateNetwork() /** * Implements validation techniques. * * @param validationType The validation technique to be used. * * @param param Paremeters for the training method. */ public ValidationStatistics validateNetwork(String validationType, TrainingParameters param) throws NetworkControllerException { if (completeData == null) { throw new NetworkControllerException(resources.getString("trainingDataNotLoaded")); } else { Validation validation; try { validation = (Validation) Class .forName(VALIDATION_PACKAGE + validationType) .newInstance(); param.setInputs(inputs); param.setOutputs(outputs); return validation.validate(neuralNetwork, initialization, inputNormalization, outputNormalization, trainingMethod, param); } catch (InstantiationException e1) { throw new NetworkControllerException(e1); } catch (IllegalAccessException e1) { throw new NetworkControllerException(e1); } catch (ClassNotFoundException e1) { throw new NetworkControllerException(e1); } } } // validateNetwork() /** * * @return The neural model this controller represents. 
See TODO */ public NeuralModel getNeuralModel() { NeuralModel nm = new NeuralModel(); nm.setDate(date); nm.setInputNormalization(inputNormalization); nm.setOutputNormalization(outputNormalization); nm.setNeuralNetwork(neuralNetwork); nm.setMethodInfo(methodInfo); nm.setNetworkArchitecture(networkArchitecture); nm.setStatisticalResults(statisticalResults); return nm; } // getNeuralModel() /** * Sets the training data to be used by the NetworkController. * It's use is mandatory to train a neural network. * * @param data The training data. */ public void setTrainingData(double[][] data) { completeData = data; } // setTrainingData() /** * Sets the inference data to be used by the NetworkController. * It's use is mandatory to infere with a neural network. * * @param data The inference data. */ public void setInferenceData(double[][] data) { isInfered = false; if ((!isFromModel) && (data[0].length == inputs[0].length)) { inferenceData = data; } else if((!isFromModel) && (data[0].length == completeData[0].length)) { inferenceData = new double[data.length][completeData[0].length - outputIndex.length]; int index = -1, columnIndex = 0; int[] indexes = new int[data[0].length - outputIndex.length]; for (int m = 0; m < data[0].length; m++) { boolean isIn = false; for (int k = 0; k < outputIndex.length; k++) { index = outputIndex[k]; if (m == index) { isIn = true; } } if (isIn == false) { indexes[columnIndex] = m; columnIndex++; } } for (int k = 0; k < indexes.length; k++) { index = indexes[k]; for (int i = 0; i < data.length; i++) { inferenceData[i][k] = data[i][index]; } } } else { inferenceData = data; } } // setInferenceData() /** * Sets the validationdata to be used by the NetworkController. * It's use is mandatory to validate a neural network. * * @param data The validation data. 
*/ public void setValidationData(double[][] data) { completeData = data; isValidationDataLoaded = true; isDataSplit = false; } // setValidationData() /** * Splits the training or the validation data into the desired inputs and outputs. */ private void splitDataValues(double[][] data) throws NetworkControllerException { if (this.isOutputIndexSet) { if ( data != null ) { outputs = new double[data.length][outputIndex.length]; inputs = new double[data.length][completeData[0].length - outputIndex.length]; outputsDataHeader = new String[outputs[0].length]; inferenceDataHeader = new String[inputs[0].length]; inputsDataHeader = new String[inputs[0].length]; inferedDataHeader = new String[outputs[0].length]; int index = -1, columnIndex = 0; int[] indexes = new int[data[0].length - outputIndex.length]; for (int m=0; m < data[0].length; m++) { boolean isIn = false; for (int k = 0; k < outputIndex.length; k++ ) { index = outputIndex[k]; if ( m == index ) { isIn=true; } } if (isIn == false) { indexes[columnIndex] = m; columnIndex++; } } for (int k = 0; k < indexes.length; k++ ) { index = indexes[k]; inputsDataHeader[k] = completeDataHeader[index]; for(int i = 0; i < data.length; i++) { inputs[i][k] = data[i][index]; } } for (int k = 0; k < outputIndex.length; k++ ) { index = outputIndex[k]; outputsDataHeader[k] = completeDataHeader[index]; for(int i = 0; i < data.length; i++) { outputs[i][k] = data[i][index]; } } inferenceDataHeader = inputsDataHeader; inferedDataHeader = outputsDataHeader; isDataSplit = true; } else { throw new NetworkControllerException(resources.getString("dataIsNull")); } } else { throw new NetworkControllerException(resources.getString("outputIndexNotSet")); } } //splitDataValues() /** * Splits the training or the validation data into the desired inputs and outputs. * One must first set the output indexes or an exception will be thrown. 
* * @throws NetworkControllerException */ public void splitDataValues() throws NetworkControllerException { splitDataValues(this.completeData); } // splitDataValues() /** * @return Returns the inferenceData. */ public double[][] getInferenceData() { double[][] data; if(isInfered && isNormalizer) { data = inputNormalization.unnormalize(inferenceData); } else { data = inferenceData; } return data; } // getInferenceData() /** * @return Returns the inputs. */ public double[][] getInputs() { return inputs; } // getInputs() /** * * @return */ public int getInputSize() { return neuralNetwork.getInputSize(); } //getInputSize() /** * @return Returns the otputs. */ public double[][] getOutputs() { return outputs; } // getOutputs() /** * @return Returns the trainingData. */ public double[][] getTrainingData() { return completeData; } // getTrainingData() /** * @return Returns the inferedData. */ public double[][] getInferedData() { return inferedData; } // getInferedData() /** * @return Returns the network training error. */ public double getError() { return neuralNetwork.getError(); } // getError() /** * @return Returns the networkArchitecture. */ public String getNetworkType() { return networkArchitecture; } // getNetworkType() /** * @return Returns the date. */ public Date getDate() { return date; } // getDate() /** * @return Returns the completeDataHeader. */ public String[] getCompleteDataHeader() { return completeDataHeader; } //getCompleteDataHeader() /** * @param completeDataHeader The completeDataHeader to set. */ public void setCompleteDataHeader(String[] completeDataHeader) { this.completeDataHeader = completeDataHeader; } //setCompleteDataHeader() /** * @return Returns the inferedDataHeader. */ public String[] getInferedDataHeader() { String[] header; if(!isFromModel) { header = inferedDataHeader; } else { header = new String[inferedData[0].length]; } return header; } //getInferedDataHeader() /** * @param inferedDataHeader The inferedDataHeader to set. 
*/
    public void setInferedDataHeader(String[] inferedDataHeader) {
        this.inferedDataHeader = inferedDataHeader;
    } //setInferedDataHeader()

    /**
     * @return Returns the inferenceDataHeader, or, for a controller built from
     *         a model, a new array of null entries with one slot per network input.
     */
    public String[] getInferenceDataHeader() {
        if (isFromModel) {
            return new String[neuralNetwork.getInputSize()];
        }
        return inferenceDataHeader;
    } //getInferenceDataHeader()

    /**
     * @param inferenceDataHeader The inferenceDataHeader to set.
     */
    public void setInferenceDataHeader(String[] inferenceDataHeader) {
        this.inferenceDataHeader = inferenceDataHeader;
    } //setInferenceDataHeader()

    /**
     * @return Returns the completeData.
     */
    public double[][] getCompleteData() {
        return completeData;
    } //getCompleteData()

    /**
     * @return Returns the statisticalResults.
     */
    public StatisticalResults getStatisticalResults() {
        return statisticalResults;
    } //getStatisticalResults()

} // NetworkController()
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -