
MultipleValidationSample.java

Collection: A neural network written in pure Java (source code)
Language: Java
Page 1 of 2
/*
 * ValidationSample.java
 *
 * Created on 11 November 2002, 22.59
 * @author  pmarrone
 */
package org.joone.samples.engine.validation;

import org.joone.engine.*;
import org.joone.engine.learning.*;
import org.joone.net.*;
import org.joone.io.*;
import org.joone.util.*;
import java.io.*;

/**
 * This example shows how to check the training level of a neural network
 * using a validation data source.
 * The training and validation phases of the created network are executed
 * many times, showing the resulting RMSE for each run.
 * This program shows how to build the same kind of neural net as the one
 * contained in the org/joone/samples/editor/scripting/ValidationSample.ser
 * file using only Java code and the core engine's API. Open that net in
 * the GUI editor to see the architecture of the net built in this example.
 */
public class MultipleValidationSample implements NeuralValidationListener {

    NeuralNet nnet;
    boolean ready;
    int totNets = 10; // Number of neural nets to train & validate
    int returnedNets = 0;
    double totRMSE = 0;
    double minRMSE = 99;

    long mStart;
    int trainingLCP = 1;
    int validationLCP = 16;
    int totCycles = 1000;
    FileWriter wr = null;

    private static String filePath = "org/joone/samples/engine/validation";
    // Must point to a trained XOR network without I/O components
    String xorNet = filePath + "/trainedXOR.snet";

    /** Creates a new instance of SampleScript */
    public MultipleValidationSample() {
    }

    /**
     * @param args the command line arguments
     */
    public static void main(String[] args) {
        MultipleValidationSample sampleNet = new MultipleValidationSample();
        sampleNet.start();
    }

    private void start() {
        try {
            wr = new FileWriter(new File("/tmp/memory.txt"));
            while (trainingLCP <= validationLCP) {
                // Start the LC calculation
                startValidation(trainingLCP, validationLCP);
                trainingLCP += 1;
                wr.flush();
            }
            // Draws the error's curve
            wr.close();
        }
        catch (IOException ioe) { ioe.printStackTrace(); }
        System.out.println("Done.");
        System.exit(0);
    }

    private synchronized void startValidation(int trnP, int valP) {
        nnet = initializeModularParity(trnP, valP);
        //nnet = initializeSimpleParity(trnP, valP);
        //nnet = initializeNetworkI(trnP, valP);
        nnet.getMonitor().setTrainingPatterns(trnP);
        nnet.getMonitor().setValidationPatterns(valP);

        try {
            mStart = System.currentTimeMillis();
            returnedNets = 0;
            totRMSE = 0;
            minRMSE = 99;
            // n = total number of neural networks to create, train and validate
            int n = totNets;
            // First, start the initial number of neural
            // networks to be trained in parallel
            for (int i = 0; i < 1; ++i) {
                test(n--);
            }
            while (n > 0) {
                // Wait for a neural network's validation to terminate
                // before starting another one
                while (!ready) {
                    try {
                        wait();
                    } catch (InterruptedException doNothing) {}
                }
                ready = false;
                test(n--);
                long mem = getMemoryUse();
                wr.write(mem + "\r\n");
            }
            while (returnedNets < totNets) {
                try {
                    wait();
                } catch (InterruptedException doNothing) {}
            }
            // This code is executed when all the neural networks
            // have been trained and validated
            displayResults();
        } catch (IOException ioe) { ioe.printStackTrace(); }
    }
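
    /*
     * Note on the synchronization: startValidation() and netValidated()
     * (below) are both synchronized on this object. startValidation() blocks
     * in wait() until a trainer thread reports back through netValidated(),
     * which sets `ready` and calls notifyAll(). The net effect is that the
     * nets are trained one at a time, although the initial for loop is
     * written so that the number started in parallel could be raised above 1.
     */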
    // Run a new training & validation phase
    private void test(int n) {
        nnet.randomize(0.5);
        nnet.setParam("ID", new Integer(n)); // Set its param ID
        // Create the trainer object
        NeuralNetTrainer trainer = new NeuralNetTrainer(nnet);
        //NeuralNetTester trainer = new NeuralNetTester(nnet, true, 0);
        // Register this object as a listener of the trainer
        trainer.addValidationListener(this);
        // Run the training + validation tasks
        trainer.start();
    }

    /* This method is called by the trainers for each validated neural network.
     * The param ID is used to recognize the returned net.
     */
    public synchronized void netValidated(NeuralValidationEvent event) {
        // Shows the RMSE at the end of the validation phase
        NeuralNet NN = (NeuralNet) event.getSource();
        int n = ((Integer) NN.getParam("ID")).intValue();
        double rmse = NN.getMonitor().getGlobalError();
        //System.out.print("Returned NeuralNet #" + n);
        //System.out.println(" Validation RMSE: " + rmse);
        totRMSE += rmse;
        if (minRMSE > rmse)
            minRMSE = rmse;
        ++returnedNets;
        ready = true;
        notifyAll();
    }

    private void displayResults() {
        // This code is executed when all the neural networks have been trained and validated
        double aveRMSE = totRMSE / totNets;
        long mTot = System.currentTimeMillis() - mStart;
        System.out.println("---------------------------------------------------------");
        System.out.println("Training Patterns: " + trainingLCP);
        System.out.println("Average Generalization Error: " + aveRMSE);
        System.out.println("Minimum Generalization Error: " + minRMSE);
        System.out.println("Elapsed Time: " + mTot + " Milliseconds");
        System.out.println("---------------------------------------------------------");
    }
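
    /*
     * The helpers below sample the JVM heap (used = totalMemory - freeMemory)
     * once per trained net; the values written to /tmp/memory.txt make it
     * easy to check whether repeated training runs leak memory.
     */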
    /// Garbage collection ////
    private static long fSLEEP_INTERVAL = 20;

    private static long getMemoryUse() {
        //collectGarbage(); // NOTE: To obtain the memory allocation w/o GC, comment this line.
        long totalMemory = Runtime.getRuntime().totalMemory();
        long freeMemory = Runtime.getRuntime().freeMemory();
        return (totalMemory - freeMemory);
    }

    private static void collectGarbage() {
        try {
            System.gc();
            Thread.sleep(fSLEEP_INTERVAL); // sleep() is static; no need for currentThread()
            System.runFinalization();
            Thread.sleep(fSLEEP_INTERVAL);
        }
        catch (InterruptedException ex) {
            ex.printStackTrace();
        }
    }

    /** Configures & starts the SimpleParity network - class method
     * @param learningPatternNumber Number of learning patterns
     * @param testPatternNumber Number of test patterns
     */
    private NeuralNet initializeSimpleParity(int learningPatternNumber, int testPatternNumber) {
        // Initialize the neural network
        NeuralNet network = new NeuralNet();
        // Define & initialize the learning & test inputs
        double[][] learningData = constructLearningData(learningPatternNumber);
        double[][] testData = constructTestData(testPatternNumber);
        // Define & initialize the network layers and define the layer names
        LinearLayer input = new LinearLayer();
        SigmoidLayer hidden = new SigmoidLayer();
        SigmoidLayer output = new SigmoidLayer();
        input.setLayerName("Input Layer");
        hidden.setLayerName("Hidden Layer");
        output.setLayerName("Output Layer");
        // Define the number of neurons for each layer
        input.setRows(4);
        hidden.setRows(4);
        output.setRows(1);
        // Define the input -> hidden connection
        FullSynapse synapseIH = new FullSynapse();
        synapseIH.setName("IH Synapse");
        // Define the hidden -> output connection
        FullSynapse synapseHO = new FullSynapse();
        synapseHO.setName("HO Synapse");
        // Connect the input layer with the hidden layer
        NeuralNetFactory.connect(input, synapseIH, hidden);
        // Connect the hidden layer with the output layer
        NeuralNetFactory.connect(hidden, synapseHO, output);
        // Define & initialize the learning input synapse (columns 1-4 hold the input bits)
        MemoryInputSynapse learningInputSynapse = NeuralNetFactory.createInput("Learning Input Synapse", learningData, 1, 1, 4);
        // Define the test input synapse
        MemoryInputSynapse testInputSynapse = NeuralNetFactory.createInput("Test Input Synapse", testData, 1, 1, 4);
        // Initialize the input switch synapse
        LearningSwitch inputSwitch = NeuralNetFactory.createSwitch("Input Switch Synapse", learningInputSynapse, testInputSynapse);
        // Connect the input switch synapse to the input layer
        input.addInputSynapse(inputSwitch);
        // Define the trainer input switch (column 5 holds the desired output)
        MemoryInputSynapse learningDesiredSynapse = NeuralNetFactory.createInput("Learning Desired Synapse", learningData, 1, 5, 5);
        // Define the test desired synapse
        MemoryInputSynapse testDesiredSynapse = NeuralNetFactory.createInput("Test Desired Synapse", testData, 1, 5, 5);
        // Initialize the learning switch synapse
        LearningSwitch learningSwitch = NeuralNetFactory.createSwitch("Learning Switch Synapse", learningDesiredSynapse, testDesiredSynapse);
        // Define the trainer and link it to the monitor
        TeachingSynapse trainer = new TeachingSynapse();
        trainer.setName("Simple Parity Trainer Synapse");
        // Connect the teacher to the output layer
        output.addOutputSynapse(trainer);
        // Connect the learning switch synapse to the trainer
        trainer.setDesired(learningSwitch);
        // Define the output synapse memory (data)
        MemoryOutputSynapse outputMemoryData = new MemoryOutputSynapse();
        outputMemoryData.setName("Output Data");
        // Connect the output memory synapse (data) to the output layer
        output.addOutputSynapse(outputMemoryData);
        // Incorporate the network components into the NeuralNet object
        network.addLayer(input);
        network.addLayer(hidden);
        network.addLayer(output);
        network.setTeacher(trainer);
        network.getMonitor().setLearningRate(0.7);
        network.getMonitor().setMomentum(0.5);
        network.getMonitor().setTotCicles(totCycles); // sic: Joone's Monitor API spells it "Cicles"
        return network;
    }

    /** Constructs the network learning data based on the GUI options - class method
     * @param learningPatternNumber The int number of training patterns
     * @return The training patterns vector
     */
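
The listing stops here: page 2 of the original file carries constructLearningData(), constructTestData(), the remaining initialize* builders, and the closing brace of the class. What page 1 fully shows is the wait()/notifyAll() handshake between startValidation() and netValidated() that throttles the training runs. Below is a minimal, self-contained sketch of that same handshake with the Joone classes stripped out; all names in it (ThrottleDemo, JOBS, launch, jobDone) are hypothetical and exist only to illustrate the pattern.

// Standalone sketch of the wait()/notifyAll() throttle used above.
// Hypothetical names; no Joone dependency.
public class ThrottleDemo {
    private boolean ready = false;
    private int returned = 0;
    private static final int JOBS = 5;

    public static void main(String[] args) throws InterruptedException {
        new ThrottleDemo().runAll();
    }

    private synchronized void runAll() throws InterruptedException {
        for (int n = JOBS; n > 0; n--) {
            launch(n);
            while (!ready) {
                wait();        // block until the worker signals completion
            }
            ready = false;     // consume the signal before the next launch
        }
        while (returned < JOBS) {
            wait();            // drain any stragglers
        }
        System.out.println("All " + JOBS + " jobs finished.");
    }

    private void launch(final int id) {
        // Stand-in for NeuralNetTrainer: does its work, then reports back
        new Thread(new Runnable() {
            public void run() {
                jobDone(id);
            }
        }).start();
    }

    private synchronized void jobDone(int id) {
        System.out.println("Job #" + id + " done.");
        ++returned;
        ready = true;
        notifyAll();           // wake the launcher blocked in runAll()
    }
}

As in the sample, the launcher holds the monitor while it decides whether to start the next job, so a completion signal can never slip between the `ready` check and the wait().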

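The xorNet field declared near the top of the class points at a serialized, already-trained network (trainedXOR.snet) but is never touched in this half of the listing. For reference, a serialized Joone net of that kind is normally restored with org.joone.net.NeuralNetLoader; the sketch below assumes that standard loader API and reuses the classpath-relative path declared in the sample.

// Minimal sketch: restoring a serialized Joone network such as the
// trainedXOR.snet referenced by the xorNet field. Assumes the standard
// NeuralNetLoader(String) / getNeuralNet() API of the Joone engine.
import org.joone.net.NeuralNet;
import org.joone.net.NeuralNetLoader;

public class LoadTrainedXor {
    public static void main(String[] args) {
        String path = "org/joone/samples/engine/validation/trainedXOR.snet";
        NeuralNetLoader loader = new NeuralNetLoader(path);
        NeuralNet net = loader.getNeuralNet();
        System.out.println("Network restored: " + (net != null));
    }
}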