📄 joonetools.java
字号:
MemoryInputSynapse memTarget = null; if (desired != null) { memTarget = new MemoryInputSynapse(); memTarget.setInputArray(desired); memTarget.setAdvancedColumnSelector("1-"+desired[0].length); } Monitor mon = nnet.getMonitor(); nnet.getMonitor().setValidation(true); if (mon.getValidationPatterns() == 0) mon.setValidationPatterns(input.length); return compare_on_stream(nnet, memInput, memTarget); }

    /**
     * Permits to compare the output and target data of a trained neural network using
     * StreamInputSynapses as the input/desired data sources.
     * <p>
     * All the input and output synapses currently attached to {@code nnet} are removed;
     * {@code input} becomes the only input synapse and a ComparingSynapse fed by
     * {@code desired} becomes the only output synapse. The network is then run through
     * {@code train_complete} and the patterns emitted by the ComparingSynapse are collected.
     *
     * @param nnet The neural network to evaluate
     * @param input the StreamInputSynapse containing the input data. The advColumnSelector must
     *        be set according to the # of input nodes
     * @param desired the StreamInputSynapse containing the target data. The advColumnSelector
     *        must be set according to the # of output nodes
     * @return a 2D array of double containing the output+desired data for each pattern
     *         (one row per collected pattern); an empty array if no pattern was collected
     */
    public static double[][] compare_on_stream(NeuralNet nnet, StreamInputSynapse input, StreamInputSynapse desired) {
        nnet.removeAllInputs();
        nnet.removeAllOutputs();
        nnet.addInputSynapse(input);

        ComparingSynapse teacher = new ComparingSynapse();
        teacher.setDesired(desired);
        nnet.addOutputSynapse(teacher);

        // Collects in memory every pattern produced by the comparing synapse
        MemoryOutputSynapse outStream = new MemoryOutputSynapse();
        teacher.addResultSynapse(outStream);

        // NOTE(review): presumably a single silent epoch - see train_complete for the arguments
        train_complete(nnet, 1, 0, 0, null, false);

        Vector results = outStream.getAllPatterns();
        int rows = results.size();
        // Rows reference each pattern's own array, so the inner arrays are not pre-allocated;
        // this also avoids reading results.get(0) when nothing was collected.
        double[][] output = new double[rows][];
        for (int i = 0; i < rows; ++i) {
            output[i] = ((Pattern) results.get(i)).getArray();
        }
        return output;
    }

    /**
     * Extracts a subset of data from the StreamInputSynapse passed as parameter.
     * <p>
     * All the bounds are 1-based and inclusive: {@code firstRow == 1} addresses the first
     * pattern of the internal buffer and {@code firstCol == 1} the first value of each pattern.
     *
     * @param dataSet The input StreamInputSynapse. Must be buffered.
     * @param firstRow The first row (1-based, relative to the internal buffer) to extract
     * @param lastRow The last row (1-based, inclusive, relative to the internal buffer) to extract
     * @param firstCol The first column (1-based, relative to the internal buffer) to extract
     * @param lastCol The last column (1-based, inclusive, relative to the internal buffer) to extract
     * @return A 2D array of double containing the extracted data
     */
    public static double[][] getDataFromStream(StreamInputSynapse dataSet, int firstRow, int lastRow, int firstCol, int lastCol) {
        // Force the reading of all the input data
        dataSet.Inspections();
        Vector data = dataSet.getInputPatterns();
        int rows = lastRow - firstRow + 1;
        int columns = lastCol - firstCol + 1;
        double[][] array = new double[rows][columns];
        for (int r = 0; r < rows; ++r) {
            double[] source = ((Pattern) data.get(r + firstRow - 1)).getArray();
            // The bounds are 1-based, hence the -1 offset into the 0-based array
            System.arraycopy(source, firstCol - 1, array[r], 0, columns);
        }
        return array;
    }

    /**
     * Saves a neural network to a file
     * @param nnet The network to save
     * @param fileName the file name on which the network is saved
     * @throws java.io.FileNotFoundException if the file name is invalid
     * @throws java.io.IOException when an IO error occurs
     */
    public static void save(NeuralNet nnet, String fileName) throws FileNotFoundException, IOException {
        save(nnet, new File(fileName));
    }

    /**
     * Saves a neural network to a file
     * @param nnet The network to save
     * @param fileName the file on which the network is saved
     * @throws java.io.FileNotFoundException if the file name is invalid
     * @throws java.io.IOException when an IO error occurs
     */
    public static void save(NeuralNet nnet, File fileName) throws FileNotFoundException, IOException {
        FileOutputStream stream = new FileOutputStream(fileName);
        try {
            save_toStream(nnet, stream);
        } finally {
            // save_toStream already closes the stream on success (a second close has no effect);
            // this guarantees the file handle is released when serialization fails.
            stream.close();
        }
    }

    /**
     * Saves a neural network to an OutputStream.
     * The stream is closed when this method returns, whether or not the write succeeded.
     * @param nnet The neural network to save
     * @param stream The OutputStream on which the network is saved
     * @throws java.io.IOException when an IO error occurs
     */
    public static void save_toStream(NeuralNet nnet, OutputStream stream) throws IOException {
        ObjectOutput output = new ObjectOutputStream(stream);
        try {
            output.writeObject(nnet);
        } finally {
            // Closing the ObjectOutputStream also closes the underlying stream
            output.close();
        }
    }

    /**
     * Loads a neural network from a file
     * @param fileName the name of the file from which the network is loaded
     * @throws java.io.IOException when an IO error occurs
     * @throws java.io.FileNotFoundException if the file name is invalid
     * @throws java.lang.ClassNotFoundException if some neural network's object is not found in the classpath
     * @return The loaded neural network
     */
    public static NeuralNet load(String fileName) throws FileNotFoundException, IOException, ClassNotFoundException {
        FileInputStream fin = new FileInputStream(new File(fileName));
        try {
            return load_fromStream(fin);
        } finally {
            // load_fromStream already closes the stream on its own paths; this covers early failures
            fin.close();
        }
    }

    /**
     * Loads a neural network from an InputStream.
     * The stream is closed when this method returns, whether or not the read succeeded.
     * @param stream The InputStream from which the network is loaded
     * @throws java.io.IOException when an IO error occurs
     * @throws java.lang.ClassNotFoundException some neural network's object is not found in the classpath
     * @throws java.lang.ClassCastException if the serialized object is not a NeuralNet
     * @return The loaded neural network
     */
    public static NeuralNet load_fromStream(InputStream stream) throws IOException, ClassNotFoundException {
        // NOTE(review): Java native deserialization can be abused through crafted input; only load
        // networks from trusted sources (or install an ObjectInputFilter on Java 9+).
        ObjectInputStream oin = new ObjectInputStream(stream);
        try {
            return (NeuralNet) oin.readObject();
        } finally {
            oin.close();
        }
    }

    /**
     * Connects two layers with the given synapse
     * @param l1 The source layer
     * @param syn The synapse to use to connect the two layers
     * @param l2 The destination layer
     */
    protected static void connect(Layer l1, Synapse syn, Layer l2) {
        l1.addOutputSynapse(syn);
        l2.addInputSynapse(syn);
    }

    /**
     * Creates a stop pattern (i.e. a Pattern with counter = -1)
     * @param size The size of the Pattern's array
     * @return the created stop Pattern
     */
    protected static Pattern stopPattern(int size) {
        Pattern stop = new Pattern(new double[size]);
        stop.setCount(-1);
        return stop;
    }

    /**
     * Creates a listener for a NeuralNet object.
     * The listener writes the results to the stdOut object every 'interval' epochs.
     * If stdOut points to a NeuralNetListener instance, the corresponding methods are invoked,
     * after {@code nnet} has been set as the event's neural network.
     * If stdOut points to a PrintStream instance, a corresponding message is written
     * (errorChanged events are not printed). If stdOut is null, every event is ignored.
     * @param nnet The NeuralNetwork to which the listener will be attached
     * @param stdOut the NeuralNetListener, or the PrintStream instance to which the notifications will be made
     * @param interval The interval of epochs between two calls to the cyclic events cicleTerminated and
     *        errorChanged; 0 disables the cyclic notifications
     * @return The created listener
     */
    protected static NeuralNetListener createListener(final NeuralNet nnet, final Object stdOut, final int interval) {
        NeuralNetListener listener = new NeuralNetListener() {
            Object output = stdOut;
            NeuralNet neuralNet = nnet;

            /**
             * Tells whether the cyclic events of the given epoch must be notified.
             * interval is tested first so that 0 never reaches the modulo.
             */
            private boolean isNotifiedEpoch(int epoch) {
                return (interval != 0) && (epoch % interval <= 0);
            }

            public void netStarted(NeuralNetEvent e) {
                if (output == null) {
                    return;
                }
                if (output instanceof PrintStream) {
                    ((PrintStream) output).println("Network started");
                } else if (output instanceof NeuralNetListener) {
                    e.setNeuralNet(neuralNet);
                    ((NeuralNetListener) output).netStarted(e);
                }
            }

            public void cicleTerminated(NeuralNetEvent e) {
                if (output == null) {
                    return;
                }
                Monitor mon = (Monitor) e.getSource();
                int epoch = mon.getCurrentCicle() - 1;
                if (!isNotifiedEpoch(epoch)) {
                    return;
                }
                if (output instanceof PrintStream) {
                    PrintStream out = (PrintStream) output;
                    out.print("Epoch n." + (mon.getTotCicles() - epoch) + " terminated");
                    if (mon.isSupervised()) {
                        out.print(" - rmse: " + mon.getGlobalError());
                    }
                    out.println("");
                } else if (output instanceof NeuralNetListener) {
                    e.setNeuralNet(neuralNet);
                    ((NeuralNetListener) output).cicleTerminated(e);
                }
            }

            public void errorChanged(NeuralNetEvent e) {
                if (output == null) {
                    return;
                }
                Monitor mon = (Monitor) e.getSource();
                int epoch = mon.getCurrentCicle() - 1;
                if (!isNotifiedEpoch(epoch)) {
                    return;
                }
                // PrintStream targets are deliberately not notified of error changes
                if (output instanceof NeuralNetListener) {
                    e.setNeuralNet(neuralNet);
                    ((NeuralNetListener) output).errorChanged(e);
                }
            }

            public void netStopped(NeuralNetEvent e) {
                if (output == null) {
                    return;
                }
                if (output instanceof PrintStream) {
                    ((PrintStream) output).println("Network stopped");
                } else if (output instanceof NeuralNetListener) {
                    e.setNeuralNet(neuralNet);
                    ((NeuralNetListener) output).netStopped(e);
                }
            }

            public void netStoppedError(NeuralNetEvent e, String error) {
                if (output == null) {
                    return;
                }
                if (output instanceof PrintStream) {
                    ((PrintStream) output).println("Network stopped with error:" + error);
                } else if (output instanceof NeuralNetListener) {
                    e.setNeuralNet(neuralNet);
                    ((NeuralNetListener) output).netStoppedError(e, error);
                }
            }
        };
        return listener;
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -