
neural_2h.java

Practical AI in Java
/**
 * Title:        Neural_2H<p>
 * Description:  Two hidden layer back propagation neural network model<p>
 * Copyright:    Copyright (c) 2000 Mark Watson. Artistic Open Source License<p>
 * @author Mark Watson
 * @version 1.2
 */

import java.util.*;
import java.io.*;

class Neural_2H implements Serializable {

    protected int numInputs;
    protected int numHidden1;
    protected int numHidden2;
    protected int numOutputs;

    protected int numTraining;

    // layer activation arrays (inputs and outputs are public so callers can inspect them)
    public float[] inputs;
    protected float[] hidden1;
    protected float[] hidden2;
    public float[] outputs;

    // weight matrices: input -> hidden1, hidden1 -> hidden2, hidden2 -> output
    protected float[][] W1;
    protected float[][] W2;
    protected float[][] W3;

    // per-layer error terms used during backpropagation
    protected float[] output_errors;
    protected float[] hidden1_errors;
    protected float[] hidden2_errors;

    // training examples; transient, so they are not written when the network is serialized
    protected transient Vector<float[]> inputTraining = new Vector<float[]>();
    protected transient Vector<float[]> outputTraining = new Vector<float[]>();

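    /**
     * Builds a network with the given layer sizes and small random initial weights.
     */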
    Neural_2H(int num_in, int num_hidden1, int num_hidden2, int num_output) {

        numInputs = num_in;
        numHidden1 = num_hidden1;
        numHidden2 = num_hidden2;
        numOutputs = num_output;
        inputs = new float[numInputs];
        hidden1 = new float[numHidden1];
        hidden2 = new float[numHidden2];
        outputs = new float[numOutputs];
        W1 = new float[numInputs][numHidden1];
        W2 = new float[numHidden1][numHidden2];
        W3 = new float[numHidden2][numOutputs];
        randomizeWeights();

        output_errors = new float[numOutputs];
        hidden1_errors = new float[numHidden1];
        hidden2_errors = new float[numHidden2];
    }

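    /**
     * Stores one input/output training pair; pairs whose array sizes do not
     * match the network dimensions are rejected with a console warning.
     */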
    public void addTrainingExample(float[] inputs, float[] outputs) {
        if (inputs.length != numInputs || outputs.length != numOutputs) {
            System.out.println("addTrainingExample(): array size is wrong");
            return;
        }
        inputTraining.addElement(inputs);
        outputTraining.addElement(outputs);
    }

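    /**
     * Recreates a previously serialized Neural_2H from a classpath resource.
     * The transient training vectors are re-created empty; returns null if
     * deserialization fails.
     */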
    public static Neural_2H Factory(String serialized_file_name) {
        Neural_2H nn = null;
        try {
            InputStream ins = ClassLoader.getSystemResourceAsStream(serialized_file_name);
            if (ins == null) {
                System.out.println("CachedExamples(): failed to open 'cache.dat' in JAR file");
                System.exit(1);
            } else {
                ObjectInputStream p = new ObjectInputStream(ins);
                nn = (Neural_2H) p.readObject();
                nn.inputTraining = new Vector<float[]>();
                nn.outputTraining = new Vector<float[]>();
                ins.close();
            }
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
        return nn;
    }

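    /**
     * Serializes this network (layer sizes and weights) to the named file.
     * The training vectors are transient and are not written.
     */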
    public void save(String file_name) {
        try {
            FileOutputStream ostream = new FileOutputStream(file_name);
            ObjectOutputStream p = new ObjectOutputStream(ostream);
            p.writeObject(this);
            p.flush();
            ostream.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public void randomizeWeights() {
        // initialize all weights to small random values in [-0.05, 0.05):
        for (int ii = 0; ii < numInputs; ii++)
            for (int hh = 0; hh < numHidden1; hh++)
                W1[ii][hh] =
                        0.1f * (float) Math.random() - 0.05f;
        for (int ii = 0; ii < numHidden1; ii++)
            for (int hh = 0; hh < numHidden2; hh++)
                W2[ii][hh] =
                        0.1f * (float) Math.random() - 0.05f;
        for (int hh = 0; hh < numHidden2; hh++)
            for (int oo = 0; oo < numOutputs; oo++)
                W3[hh][oo] =
                        0.1f * (float) Math.random() - 0.05f;
    }

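    /**
     * Copies the supplied values into the input layer, runs a forward pass,
     * and returns a copy of the resulting output activations.
     */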
    public float[] recall(float[] in) {
        for (int i = 0; i < numInputs; i++) inputs[i] = in[i];
        forwardPass();
        float[] ret = new float[numOutputs];
        for (int i = 0; i < numOutputs; i++) ret[i] = outputs[i];
        return ret;
    }

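    /**
     * Propagates the current inputs[] through the network. hidden1 and hidden2
     * hold raw weighted sums; the sigmoid is applied to hidden2 values when
     * feeding the output layer, and again to the accumulated output sums.
     */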
    public void forwardPass() {
        int i, h, o;
        for (h = 0; h < numHidden1; h++) {
            hidden1[h] = 0.0f;
        }
        for (h = 0; h < numHidden2; h++) {
            hidden2[h] = 0.0f;
        }
        for (i = 0; i < numInputs; i++) {
            for (h = 0; h < numHidden1; h++) {
                hidden1[h] +=
                        inputs[i] * W1[i][h];
            }
        }
        for (i = 0; i < numHidden1; i++) {
            for (h = 0; h < numHidden2; h++) {
                hidden2[h] +=
                        hidden1[i] * W2[i][h];
            }
        }
        for (o = 0; o < numOutputs; o++)
            outputs[o] = 0.0f;
        for (h = 0; h < numHidden2; h++) {
            for (o = 0; o < numOutputs; o++) {
                outputs[o] +=
                        sigmoid(hidden2[h]) * W3[h][o];
            }
        }
        for (o = 0; o < numOutputs; o++)
            outputs[o] = sigmoid(outputs[o]);
    }

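    /**
     * Performs one backpropagation step using the stored training examples.
     */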
    public float train() {
        return train(inputTraining, outputTraining);
    }

    private int current_example = 0;

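    /**
     * Performs a single backpropagation update on one example from the supplied
     * vectors (cycling through them on successive calls) with a fixed learning
     * rate of 0.5, and returns the summed absolute output error for that example.
     */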
    public float train(Vector<float[]> ins, Vector<float[]> v_outs) {
        int i, h, o;
        float error = 0.0f;
        int num_cases = ins.size();
        // this method trains on one example per call, cycling through
        // the training set on successive calls.
        // zero out the error arrays:
        for (h = 0; h < numHidden1; h++)
            hidden1_errors[h] = 0.0f;
        for (h = 0; h < numHidden2; h++)
            hidden2_errors[h] = 0.0f;
        for (o = 0; o < numOutputs; o++)
            output_errors[o] = 0.0f;
        // copy the input values:
        for (i = 0; i < numInputs; i++) {
            inputs[i] = ((float[]) ins.elementAt(current_example))[i];
        }
        // copy the output values:
        float[] outs = (float[]) v_outs.elementAt(current_example);

        // perform a forward pass through the network:

        forwardPass();

        for (o = 0; o < numOutputs; o++) {
            output_errors[o] =
                    (outs[o] -
                    outputs[o])
                    * sigmoidP(outputs[o]);
        }
        for (h = 0; h < numHidden2; h++) {
            hidden2_errors[h] = 0.0f;
            for (o = 0; o < numOutputs; o++) {
                hidden2_errors[h] +=
                        output_errors[o] * W3[h][o];
            }
        }
        for (h = 0; h < numHidden1; h++) {
            hidden1_errors[h] = 0.0f;
            for (o = 0; o < numHidden2; o++) {
                hidden1_errors[h] +=
                        hidden2_errors[o] * W2[h][o];
            }
        }
        for (h = 0; h < numHidden2; h++) {
            hidden2_errors[h] =
                    hidden2_errors[h] * sigmoidP(hidden2[h]);
        }
        for (h = 0; h < numHidden1; h++) {
            hidden1_errors[h] =
                    hidden1_errors[h] * sigmoidP(hidden1[h]);
        }
        // update the hidden2 to output weights:
        for (o = 0; o < numOutputs; o++) {
            for (h = 0; h < numHidden2; h++) {
                W3[h][o] +=
                        0.5 * output_errors[o] * hidden2[h];
            }
        }
        // update the hidden1 to hidden2 weights:
        for (o = 0; o < numHidden2; o++) {
            for (h = 0; h < numHidden1; h++) {
                W2[h][o] +=
                        0.5 * hidden2_errors[o] * hidden1[h];
            }
        }
        // update the input to hidden1 weights:
        for (h = 0; h < numHidden1; h++) {
            for (i = 0; i < numInputs; i++) {
                W1[i][h] +=
                        0.5 * hidden1_errors[h] * inputs[i];
            }
        }
        for (o = 0; o < numOutputs; o++) {
            error += Math.abs(outs[o] - outputs[o]);
            //error += Math.abs(output_errors[o]);
        }
        current_example++;
        if (current_example >= num_cases) current_example = 0;
        return error;
    }

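    /** Standard logistic sigmoid activation: 1 / (1 + e^-x). */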
    protected float sigmoid(float x) {
        return
                (float) (1.0f / (1.0f + Math.exp((double) (-x))));
    }

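    /** Sigmoid derivative expressed through the sigmoid itself: s(x) * (1 - s(x)). */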
    protected float sigmoidP(float x) {
        double z = sigmoid(x);
        return (float) (z * (1.0f - z));
    }

}
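
Below is a minimal usage sketch for the class above. It assumes the caller sits in the same (default) package, since the Neural_2H constructor is package-private; the XOR-style data, layer sizes, iteration count, and output file name are illustrative and not part of the original listing.

// NeuralDemo.java -- illustrative driver for Neural_2H (same default package assumed).
public class NeuralDemo {
    public static void main(String[] args) {
        // 2 inputs, two hidden layers of 4 units each, 1 output (sizes are arbitrary)
        Neural_2H nn = new Neural_2H(2, 4, 4, 1);

        // four training pairs approximating XOR; convergence is not guaranteed
        float[][] ins  = {{0, 0}, {0, 1}, {1, 0}, {1, 1}};
        float[][] outs = {{0.1f}, {0.9f}, {0.9f}, {0.1f}};
        for (int i = 0; i < ins.length; i++) {
            nn.addTrainingExample(ins[i], outs[i]);
        }

        // each call to train() updates the weights from one stored example and
        // returns that example's summed absolute output error
        for (int iter = 0; iter < 10000; iter++) {
            float error = nn.train();
            if (iter % 1000 == 0) {
                System.out.println("iteration " + iter + " error " + error);
            }
        }

        // recall() runs a forward pass and returns the output activations
        for (float[] in : ins) {
            float[] out = nn.recall(in);
            System.out.println(in[0] + " " + in[1] + " -> " + out[0]);
        }

        // optionally serialize the trained weights to a file (name is illustrative)
        nn.save("test_2h.ser");
    }
}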
