⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 xor_static_rbf.java

📁 一个纯java写的神经网络源代码
💻 JAVA
字号:
package org.joone.samples.engine.xor.rbf;import org.joone.engine.*;import org.joone.engine.learning.*;import org.joone.io.*;import org.joone.net.*;import java.util.Vector;/** * Very simple example of the Gaussian (static and random centers) RBF solving the XOR problem. * * @author Boris Jansen */public class XOR_static_RBF implements NeuralNetListener {        /** The neural network. */    private NeuralNet nnet = null;        /** The RBF hidden layer. */    RbfGaussianLayer hidden = null;        /** Synapses. */    private MemoryInputSynapse inputSynapse, desiredOutputSynapse;    private MemoryOutputSynapse outputSynapse;        /** If the following flag is true, the centers will be chosen randomly.     * Otherwise it will use predefined, fixed centers (able to solve the XOR     * problem.     */    private boolean randomCenters = false;        // XOR input    private double[][] inputArray = new double[][] {        {0.0, 0.0},        {0.0, 1.0},        {1.0, 0.0},        {1.0, 1.0}    };        // XOR desired output    private double[][] desiredOutputArray = new double[][] {        {1.0},        {0.0},        {0.0},        {1.0}    };        /**     * Main     *     * @param args the command line arguments     */    public static void main(String args[]) {        XOR_static_RBF xor = new XOR_static_RBF();                xor.initNeuralNet();        xor.train();        xor.test();    }        /**     * Method declaration     */    public void train() {                // set the inputs        inputSynapse.setInputArray(inputArray);        inputSynapse.setAdvancedColumnSelector("1,2");                // set the desired outputs        desiredOutputSynapse.setInputArray(desiredOutputArray);        desiredOutputSynapse.setAdvancedColumnSelector("1");                // get the monitor object to train or feed forward        Monitor monitor = nnet.getMonitor();                // set the monitor parameters        monitor.setLearningRate(0.3);        monitor.setMomentum(0.8);        monitor.setTrainingPatterns(inputArray.length);        monitor.setTotCicles(200);                // RPROP parameters (uncomment if you want to use the RPROP learning algorithm)        //monitor.getLearners().add(0, "org.joone.engine.RpropLearner");        //monitor.setBatchSize(4);        //monitor.setLearningMode(0);                monitor.setLearning(true);        nnet.addNeuralNetListener(this);        nnet.go(true);    }        /**     * Create and init the neural network.     */    protected void initNeuralNet() {        // First create the three layers        LinearLayer input = new LinearLayer();        hidden = new RbfGaussianLayer();        //SigmoidLayer output = new SigmoidLayer(); // you can try it (not a traditional RBF network)        BiasedLinearLayer output = new BiasedLinearLayer();                // set the dimensions of the layers        input.setRows(2);        hidden.setRows(2);        output.setRows(1);                if(!randomCenters) {            // Use static Gaussian RBFs            RbfGaussianParameters[] myParameters = new RbfGaussianParameters[2];            double[] myMean0 = {0.0, 0.0};            myParameters[0] = new RbfGaussianParameters(myMean0, Math.sqrt(.5));            double[] myMean1 = {1.0, 1.0};            myParameters[1] = new RbfGaussianParameters(myMean1, Math.sqrt(.5));            hidden.setGaussianParameters(myParameters);        }                // Now create the two synapses        RbfInputSynapse synapse_IH = new RbfInputSynapse(); /* input -> hidden conn. */        FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */                // Connect the input layer whit the hidden layer        input.addOutputSynapse(synapse_IH);        hidden.addInputSynapse(synapse_IH);                // Connect the hidden layer whit the output layer        hidden.addOutputSynapse(synapse_HO);        output.addInputSynapse(synapse_HO);                // the input to the neural net        inputSynapse = new MemoryInputSynapse();        input.addInputSynapse(inputSynapse);        if(randomCenters) {            hidden.useRandomCenter(inputSynapse);        }                // The Trainer and its desired output        desiredOutputSynapse = new MemoryInputSynapse();        TeachingSynapse trainer = new TeachingSynapse();        trainer.setDesired(desiredOutputSynapse);                // Now we add this structure to a NeuralNet object        nnet = new NeuralNet();                nnet.addLayer(input, NeuralNet.INPUT_LAYER);        nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);        nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);        nnet.setTeacher(trainer);        output.addOutputSynapse(trainer);    }        public void test() {        // attach a MemoryOutputSynapse to the output of the neural net        outputSynapse = new MemoryOutputSynapse();        nnet.getOutputLayer().addOutputSynapse(outputSynapse);        nnet.getMonitor().setTotCicles(1);        nnet.getMonitor().setTrainingPatterns(4);        nnet.getMonitor().setLearning(false);        nnet.removeAllListeners();        nnet.go();                System.out.println("Outputs");        System.out.println("-------");        for(int i = 0; i < 4; i++) {            double[] myPattern = outputSynapse.getNextPattern();            System.out.println("Output: " + myPattern[0]);        }                System.out.println("Centers RBF neurons: ");        RbfGaussianParameters[] myParams = hidden.getGaussianParameters();        for(int i = 0; i < myParams.length; i++) {            String myText = (i+1) + ": [center: ";            for(int j = 0; j < myParams[i].getMean().length; j++) {                myText += myParams[i].getMean()[j] + ", ";            }            myText += "Std dev: " + myParams[i].getStdDeviation() + "]";            System.out.println(myText);        }    }        public void cicleTerminated(NeuralNetEvent e) {    }        public void errorChanged(NeuralNetEvent e) {        Monitor mon = (Monitor)e.getSource();        if (mon.getCurrentCicle() % 100 == 0)            System.out.println("Epoch: "+(mon.getTotCicles()-mon.getCurrentCicle())+" RMSE:"+mon.getGlobalError());    }        public void netStarted(NeuralNetEvent e) {    }        public void netStopped(NeuralNetEvent e) {    }        public void netStoppedError(NeuralNetEvent e, String error) {    }    }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -