⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 perceptron.java.svn-base

📁 AI大作业实现神经网络的小程序
💻 SVN-BASE
字号:
package perceptron;

import java.util.*;
import java.io.*;
import java.math.*;

/**
 * weight=weight+learnRate*delta*input
 * 
 * output=f(s)
 * 
 * @author huicongwm
 * 
 */
/**
 * A single neuron with a threshold, trained with the delta rule.
 *
 * <p>Forward pass: {@code s = (weight . input) - threshold}, {@code output = f(s)}.
 *
 * <p>Weight update: {@code weight[i] = weight[i] + learnRate * delta * input[i]}.
 *
 * <p>{@link #f(double)} defaults to the identity and {@link #differential(double)} to its
 * derivative (1); subclasses may override both to use a different activation function.
 *
 * @author huicongwm
 */
public class Perceptron {

	/** Connection weights, one per input. */
	public double[] weight;

	/** Most recent input vector (set by {@link #in(double[])}). */
	public double[] input;

	/** Most recent output computed by {@link #out()}. */
	public double output;

	/** Learning rate that scales every weight update. */
	public double learnRate;

	/** Threshold subtracted from the weighted sum before applying {@code f}. */
	public double threshold;

	/**
	 * Neighbouring units set by {@link #link}. {@code postP == null} marks an output unit;
	 * {@code preP} is stored but not read within this class.
	 */
	public Perceptron[] preP, postP;

	/**
	 * Error term computed by the most recent {@code delta(...)} call. Upstream units read it
	 * when computing their own hidden-layer delta.
	 */
	public double myDelta;

	/**
	 * Layer number (used for display and serialization) and unit index. {@code id} is used as
	 * the index of this unit's weight inside each downstream unit's {@code weight} array.
	 */
	public int layer, id;

	/**
	 * Creates a unit with {@code nin} inputs and weights drawn uniformly from [-0.5, 0.5).
	 *
	 * @param layer     layer number
	 * @param id        index of this unit; downstream units use it to look up its weight
	 * @param nin       number of inputs
	 * @param lrate     learning rate
	 * @param threshold threshold subtracted from the weighted sum
	 */
	public Perceptron(int layer, int id, int nin, double lrate, double threshold) {
		// One generator for all weights instead of a fresh Random per weight.
		Random rng = new Random();
		weight = new double[nin];
		for (int i = 0; i < nin; i++) {
			weight[i] = rng.nextDouble() - 0.5;
		}

		input = new double[nin];
		output = 0;

		learnRate = lrate;

		this.threshold = threshold;
		this.layer = layer;
		this.id = id;
	}

	/**
	 * Creates a unit from the binary format written by {@link #save(OutputStream)}.
	 *
	 * @param f stream positioned at the start of a saved unit
	 */
	public Perceptron(InputStream f) {
		// NOTE(review): load() is public and overridable; a subclass override would run
		// before the subclass's own field initializers.
		load(f);
	}

	/**
	 * Replaces the weight array. The given array is used directly, not copied.
	 *
	 * @param w new weights
	 */
	public void initWeight(double[] w) {
		weight = w;
	}

	/**
	 * Connects this unit to its neighbouring layers.
	 *
	 * @param pre  units of the previous layer
	 * @param post units of the next layer, or {@code null} if this is an output unit
	 */
	public void link(Perceptron[] pre, Perceptron[] post) {
		preP = pre;
		postP = post;
	}

	/**
	 * Returns the dot product of {@code a} and {@code b} over their common length, minus the
	 * threshold.
	 */
	double getS(double[] a, double[] b) {
		double c = 0;
		for (int i = 0; i < a.length && i < b.length; i++) {
			c += a[i] * b[i];
		}
		return c - threshold;
	}

	/**
	 * Activation function; the identity by default. Overridable.
	 *
	 * @param s weighted input sum minus threshold
	 * @return activation of {@code s}
	 */
	double f(double s) {
		return s;
	}

	/**
	 * Derivative of {@link #f(double)}; 1 by default, matching the identity. Overridable.
	 *
	 * @param s weighted input sum minus threshold
	 * @return f'(s)
	 */
	double differential(double s) {
		return 1;
	}

	/**
	 * Computes and stores the output-unit delta {@code (d - f(s)) * f'(s)}. Overridable.
	 *
	 * @param s weighted input sum minus threshold
	 * @param d desired output
	 * @return the new {@link #myDelta}
	 */
	double delta(double s, double d) {
		myDelta = (d - f(s)) * differential(s);
		return myDelta;
	}

	/**
	 * Computes and stores the hidden-unit delta: the sum of each downstream unit's delta
	 * times its weight for this unit, multiplied by {@code f'(s)}.
	 *
	 * <p>Reads {@code postP[l].myDelta}, so downstream units must be updated first.
	 * NOTE(review): by then their weights have also been updated, so this uses the
	 * post-update weights rather than the ones used in the forward pass.
	 *
	 * @param s weighted input sum minus threshold
	 * @return the new {@link #myDelta}
	 */
	double delta(double s) {
		double ud = 0;
		for (Perceptron next : postP) {
			ud += next.myDelta * next.weight[id];
		}
		myDelta = ud * differential(s);
		return myDelta;
	}

	/**
	 * Performs one training step: the output-unit rule if this unit has no downstream
	 * layer, otherwise the hidden-unit rule.
	 *
	 * @param d desired output (used only by output units)
	 */
	public void updata(double d) {
		if (postP == null) {
			finalUpdata(d);
		} else {
			intermediateUpdata();
		}
	}

	/**
	 * Output-unit update using the desired output {@code d}.
	 *
	 * @param d desired output
	 */
	public void finalUpdata(double d) {
		double s = getS(weight, input);
		adjustWeights(learnRate * delta(s, d));
	}

	/** Hidden-unit update using the deltas of the downstream layer. */
	public void intermediateUpdata() {
		double s = getS(weight, input);
		adjustWeights(learnRate * delta(s));
	}

	/** Adds {@code crate * input[i]} to every weight that has a matching input. */
	private void adjustWeights(double crate) {
		// Same bound as getS: inputs missing from the shorter array contribute nothing to s,
		// so their weights receive no update (instead of an index-out-of-bounds error).
		int n = Math.min(weight.length, input.length);
		for (int i = 0; i < n; i++) {
			weight[i] += crate * input[i];
		}
	}

	/**
	 * Sets the input vector. The array is used directly, not copied.
	 *
	 * @param x input values
	 */
	public void in(double[] x) {
		input = x;
	}

	/**
	 * Computes, stores and returns {@code f(weight . input - threshold)}.
	 *
	 * @return the new {@link #output}
	 */
	public double out() {
		double s = getS(weight, input);
		output = f(s);
		return output;
	}

	@Override
	public String toString() {
		StringBuilder sb = new StringBuilder("Perceptron ");
		sb.append("(Layer:").append(layer).append(')');
		sb.append("(index:").append(id).append(')');
		sb.append("(lRate:").append(learnRate).append(')');
		sb.append("(Threshold:").append(threshold).append(')');
		for (int i = 0; i < weight.length; i++) {
			sb.append("(w[").append(i).append("]=").append(weight[i]).append(')');
		}
		sb.append("(NowDelta:").append(myDelta).append(')');
		return sb.toString();
	}

	/**
	 * Returns a one-line, space-separated text record:
	 * {@code layer id learnRate threshold n w0 ... w(n-1) myDelta}, terminated by a newline.
	 *
	 * @return the text record
	 */
	public String toFile() {
		StringBuilder sb = new StringBuilder();
		sb.append(layer).append(' ');
		sb.append(id).append(' ');
		// Separator after learnRate was previously missing, fusing it with threshold.
		sb.append(learnRate).append(' ');
		sb.append(threshold).append(' ');
		sb.append(weight.length).append(' ');
		for (double w : weight) {
			sb.append(w).append(' ');
		}
		sb.append(myDelta);
		sb.append('\n');
		return sb.toString();
	}

	/**
	 * Writes this unit in binary form: layer, id, learnRate, threshold, weight count, the
	 * weights, and myDelta. I/O errors are printed, not rethrown.
	 *
	 * @param fw destination stream; flushed but left open so several units can share it
	 */
	public void save(OutputStream fw) {
		try {
			DataOutputStream dos = new DataOutputStream(fw);
			dos.writeInt(layer);
			dos.writeInt(id);
			dos.writeDouble(learnRate);
			dos.writeDouble(threshold);

			dos.writeInt(weight.length);
			for (double w : weight) {
				dos.writeDouble(w);
			}
			dos.writeDouble(myDelta);
			dos.flush();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	/**
	 * Reads a unit in the format written by {@link #save(OutputStream)}, replacing this
	 * unit's state. The input array is reset to zeros of the saved length. I/O errors are
	 * printed, not rethrown, and may leave the unit partially loaded.
	 *
	 * @param fw source stream; not closed
	 */
	public void load(InputStream fw) {
		try {
			DataInputStream dis = new DataInputStream(fw);
			layer = dis.readInt();
			id = dis.readInt();
			learnRate = dis.readDouble();
			threshold = dis.readDouble();

			int nin = dis.readInt();
			input = new double[nin];
			output = 0;
			weight = new double[nin];
			for (int i = 0; i < nin; i++) {
				weight[i] = dis.readDouble();
			}

			myDelta = dis.readDouble();
		} catch (IOException e) {
			e.printStackTrace();
		}
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -