⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bp.java

📁 数据挖掘。数据仓库
💻 JAVA
字号:
package org.scut.DataMining.Algorithm.NeuralNetwork.BP;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;

import org.scut.DataMining.Algorithm.NeuralNetwork.Core.*;
import org.scut.DataMining.Core.MiningData;
import org.scut.DataMining.Core.MiningException;
import org.scut.DataMining.Core.MiningMetaData;
import org.scut.DataMining.Input.File.MiningArffStream;

/**
 * Back-propagation (BP) feed-forward neural network with a single hidden
 * layer.
 *
 * <p>Typical use: {@link #setParameter(Parameter)}, then
 * {@link #setInputSet(ArrayList)} and {@link #setTargetSet(ArrayList)},
 * then {@link #train()}, and finally {@link #work(double[])} to query the
 * network. A network can be written with {@link #save(String)} and restored
 * with {@link #load(String)}.
 *
 * <p>The input and the hidden layer each carry one extra neuron at index 0
 * that is registered as the layer's bias neuron, so the internal layer sizes
 * are one larger than the sizes supplied through {@link Parameter}.
 */
public final class BP
{
    /** Number of input-layer neurons, including the bias neuron. */
    private int inputCount;
    /** Number of hidden-layer neurons, including the bias neuron. */
    private int hiddenCount;
    /** Number of output-layer neurons. */
    private int outputCount;
    /** Number of passes over the training set performed by {@link #train()}. */
    private int maxEpoch;
    /** Learning rate handed to the synapses leaving the input and hidden layers. */
    private double learningRate;
    /** Initial random range handed to the synapses leaving the input and hidden layers. */
    private double randomRange;

    /** Input layer of the BP algorithm */
    private Layer inputLayer;
    /** Hidden layer of the BP algorithm */
    private Layer hiddenLayer;
    /** Output layer of the BP algorithm */
    private Layer outputLayer;
    /** Training inputs, each already widened by the bias slot at index 0 */
    private ArrayList<double[]> inputSet;
    /** Training targets, one per training input */
    private ArrayList<double[]> targetSet;

    public BP()
    {
        super();
    }

    /**
     * Propagates one input vector through the network and returns the
     * activity of the output layer.
     *
     * <p>If {@code input} has the raw width (network input size without the
     * bias slot) a bias slot is prepended; any other width is handed to the
     * input layer unchanged. A {@link MiningException} raised while
     * propagating is printed and the output layer's current activity is
     * returned anyway.
     *
     * @param input the input vector, with or without the leading bias slot
     * @return the activity of the output layer
     * @throws IllegalStateException if the network has not been built by
     *         {@link #train()} or {@link #load(String)} yet
     */
    public double[] work(double[] input)
    {
        if (this.inputLayer == null || this.hiddenLayer == null || this.outputLayer == null)
        {
            throw new IllegalStateException("BP network not built; call train() or load() first");
        }
        double[] pass = input;
        if (input.length == this.inputCount - 1)
        {
            pass = prependBias(input);
        }
        try
        {
            this.forward(pass);
        }
        catch (MiningException e)
        {
            e.printStackTrace();
        }
        return this.outputLayer.getOutActivity();
    }

    /**
     * Sets the parameter for the BP algorithm
     *
     * @param params layer sizes (without bias neurons), epoch limit,
     *        learning rate and random range
     */
    public void setParameter(Parameter params)
    {
        this.inputCount = params.input + 1; //: 1 neuron for bias
        this.hiddenCount = params.hidden + 1; //: 1 neuron for bias
        this.outputCount = params.output;
        this.maxEpoch = params.maxEpoch;
        this.learningRate = params.learningRate;
        this.randomRange = params.randomRange;
    }

    /**
     * Sets the input training set. Every sample is copied into a new array
     * with a bias slot prepended.
     *
     * <p>The stored set is only replaced once every sample has passed the
     * width check, so a rejected sample does not leave a half-filled set
     * behind. The size comparison against the target set happens after the
     * replacement, which allows both sets to be swapped for differently
     * sized ones one after the other.
     *
     * @param inputSet the raw input samples (without bias slot)
     * @throws MiningException if {@code inputSet} is null, the parameters are
     *         invalid, a sample is null or has the wrong width, or the
     *         sample count differs from an already stored target set
     */
    public void setInputSet(ArrayList<double[]> inputSet) throws MiningException
    {
        if (inputSet == null)
        {
            throw new MiningException("inputSet passed into is null");
        }
        this.checkParameter();

        final int rawWidth = this.inputCount - 1;
        ArrayList<double[]> prepared = new ArrayList<double[]>(inputSet.size());
        for (double[] input : inputSet)
        {
            if (input == null || input.length != rawWidth)
            {
                throw new MiningException("Training data input section data size not equal to the network input size");
            }
            prepared.add(prependBias(input));
        }
        this.inputSet = prepared;

        if (this.targetSet != null && this.targetSet.size() != this.inputSet.size())
        {
            throw new MiningException("Traning data input and target count not equal");
        }
    }

    /**
     * Sets the target training set. The target arrays are stored by
     * reference, not copied.
     *
     * <p>The stored set is only replaced once every target has passed the
     * width check; the size comparison against the input set happens after
     * the replacement.
     *
     * @param targetSet the target vectors, one per training input
     * @throws MiningException if {@code targetSet} is null, the parameters
     *         are invalid, a target is null or has the wrong width, or the
     *         target count differs from an already stored input set
     */
    public void setTargetSet(ArrayList<double[]> targetSet) throws MiningException
    {
        if (targetSet == null)
        {
            throw new MiningException("targetSet passed into is null");
        }
        this.checkParameter();

        ArrayList<double[]> accepted = new ArrayList<double[]>(targetSet.size());
        for (double[] target : targetSet)
        {
            if (target == null || target.length != this.outputCount)
            {
                throw new MiningException("Training data target section data size not equal to the network output size");
            }
            accepted.add(target);
        }
        this.targetSet = accepted;

        if (this.inputSet != null && this.targetSet.size() != this.inputSet.size())
        {
            throw new MiningException("Traning data target and input count not equal");
        }
    }

    /**
     * Checks the validation of the parameter
     *
     * @throws MiningException if a layer size or the learning rate is not
     *         positive
     */
    private void checkParameter() throws MiningException
    {
        if (this.inputCount <= 0)
        {
            throw new MiningException("BP, input layer count<=0");
        }
        if (this.hiddenCount <= 0)
        {
            throw new MiningException("BP, hidden layer count<=0");
        }
        if (this.outputCount <= 0)
        {
            throw new MiningException("BP, output layer count<=0");
        }
        if (this.learningRate <= 0)
        {
            throw new MiningException("BP, learning rate<=0");
        }
    }

    /**
     * Creates the three layers from the current layer sizes, registers
     * neuron 0 of the input and hidden layer as bias neuron and links
     * input to hidden and hidden to output.
     *
     * @throws MiningException if the layer construction reports an error
     */
    private void buildLayers() throws MiningException
    {
        this.inputLayer = new Layer(this.inputCount, Layer.LayerType.Input);
        this.inputLayer.setBiasNeuron(0);

        this.hiddenLayer = new Layer(this.hiddenCount, Layer.LayerType.Hidden);
        this.hiddenLayer.setBiasNeuron(0);

        this.outputLayer = new Layer(this.outputCount, Layer.LayerType.Output);

        Layer.linkLayer(this.inputLayer, this.hiddenLayer);
        Layer.linkLayer(this.hiddenLayer, this.outputLayer);
    }

    /**
     * Initializes the BP network: validates the parameters, builds fresh
     * layers and hands the random range and learning rate to the synapses
     * leaving the input and the hidden layer.
     *
     * @throws MiningException if the parameters are invalid
     */
    private void initialize() throws MiningException
    {
        this.checkParameter();
        this.buildLayers();

        this.inputLayer.setOutSynapseInitRandomRange(this.randomRange);
        this.inputLayer.setOutSynapseLearningRate(this.learningRate);
        this.hiddenLayer.setOutSynapseInitRandomRange(this.randomRange);
        this.hiddenLayer.setOutSynapseLearningRate(this.learningRate);
    }

    /**
     * Trains the BP network. The network is rebuilt from scratch, then for
     * {@code maxEpoch} epochs every sample is propagated forward and
     * backward; the synapse weights are updated once at the end of each
     * epoch.
     *
     * <p>The training data is checked before the network is rebuilt, so a
     * rejected call leaves a previously built network untouched.
     *
     * @throws MiningException if the training sets are missing or differ in
     *         size, or the parameters are invalid
     */
    public void train() throws MiningException
    {
        if (this.inputSet == null || this.targetSet == null)
        {
            throw new MiningException("Input and target set not already set");
        }
        // The setters replace their set before comparing sizes, so the two
        // sets can differ here after a rejected setter call.
        if (this.inputSet.size() != this.targetSet.size())
        {
            throw new MiningException("Traning data input and target count not equal");
        }
        this.initialize();

        final int sampleCount = this.inputSet.size();
        for (int epoch = 0; epoch < this.maxEpoch; epoch++)
        {
            for (int i = 0; i < sampleCount; i++)
            {
                this.forward(this.inputSet.get(i));
                this.backward(this.targetSet.get(i));
            }
            this.updateWeights();
        }
    }

    /**
     * Forward propagates the input through the network
     *
     * @param input the input-layer activity, including the bias slot
     * @throws MiningException if a layer rejects the data
     */
    private void forward(double[] input) throws MiningException
    {
        this.inputLayer.lockOutActivity(input);
        this.hiddenLayer.activate();
        this.outputLayer.activate();
    }

    /**
     * Backward propagates the target through the network. The feedback fed
     * into the output layer is {@code target - output} per output neuron.
     *
     * @param target the expected output for the sample propagated last
     * @throws MiningException if a layer rejects the data
     */
    private void backward(double[] target) throws MiningException
    {
        double[] out = this.outputLayer.getOutActivity();
        double[] error = new double[target.length];
        for (int i = 0; i < error.length; i++)
        {
            error[i] = target[i] - out[i];
        }

        this.outputLayer.lockInFeedback(error);

        // Order matters: the output layer is processed before the hidden layer.
        this.outputLayer.feedback();
        this.outputLayer.updateBackwardDeltaWeights();
        this.hiddenLayer.feedback();
        this.hiddenLayer.updateBackwardDeltaWeights();
    }

    /**
     * Updates the weights of all the synapses
     */
    private void updateWeights()
    {
        this.outputLayer.updateBackwardWeights();
        this.hiddenLayer.updateBackwardWeights();
    }

    /**
     * Writes the network to a three-line text file: the layer sizes
     * ({@code input,hidden,output}, bias neurons included), the weights of
     * the synapses leaving the input layer, and the weights of the synapses
     * leaving the hidden layer. Values are separated by commas and every
     * line ends with {@code \n}.
     *
     * <p>I/O errors are printed, not propagated. The writer is closed even
     * when a write fails.
     *
     * @param fileName path of the file to (over)write
     * @throws IllegalStateException if the network has not been built yet;
     *         checked before the file is opened so an existing file is not
     *         truncated
     */
    public void save(String fileName)
    {
        if (this.inputLayer == null || this.hiddenLayer == null)
        {
            throw new IllegalStateException("BP network not built; call train() or load() first");
        }
        try
        {
            BufferedWriter bw = new BufferedWriter(new FileWriter(fileName));
            try
            {
                bw.write(this.inputCount + "," + this.hiddenCount + "," + this.outputCount + "\n");
                writeWeights(bw, this.inputLayer);
                writeWeights(bw, this.hiddenLayer);
            }
            finally
            {
                bw.close();
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }

    /**
     * Restores a network from a file written by {@link #save(String)}.
     *
     * <p>Only the layer sizes and synapse weights are restored; the epoch
     * limit, learning rate and random range keep their default value.
     * Any failure is printed and the instance is returned in whatever state
     * it had reached, so callers should not assume a usable network when an
     * error was reported.
     *
     * @param fileName path of the file to read
     * @return the restored (possibly partially initialised) network
     */
    public static BP load(String fileName)
    {
        BP bp = new BP();
        try
        {
            String iho;
            String wih;
            String who;
            BufferedReader br = new BufferedReader(new FileReader(fileName));
            try
            {
                iho = br.readLine();
                wih = br.readLine();
                who = br.readLine();
            }
            finally
            {
                br.close();
            }
            if (iho == null || wih == null || who == null)
            {
                throw new IOException("Network file has fewer than three lines: " + fileName);
            }

            String[] siho = iho.split("[,]");
            bp.inputCount = Integer.parseInt(siho[0].trim());
            bp.hiddenCount = Integer.parseInt(siho[1].trim());
            bp.outputCount = Integer.parseInt(siho[2].trim());

            bp.buildLayers();

            applyWeights(bp.inputLayer, wih.split("[,]"));
            applyWeights(bp.hiddenLayer, who.split("[,]"));
        }
        catch (Exception e)
        {
            e.printStackTrace();
        }
        return bp;
    }

    /** Returns a copy of {@code raw} widened by one leading slot for the bias neuron. */
    private static double[] prependBias(double[] raw)
    {
        double[] widened = new double[raw.length + 1];
        widened[0] = 1; //: slot of the bias neuron; original author's note: any value is ok
        System.arraycopy(raw, 0, widened, 1, raw.length);
        return widened;
    }

    /** Writes the weights of all synapses leaving {@code layer} as one comma-separated line. */
    private static void writeWeights(BufferedWriter out, Layer layer) throws IOException
    {
        StringBuilder line = new StringBuilder();
        int count = layer.getOutSynapseCount();
        for (int i = 0; i < count; i++)
        {
            if (i > 0)
            {
                line.append(',');
            }
            line.append(layer.getOutSynapse(i).getWeight());
        }
        line.append('\n');
        out.write(line.toString());
    }

    /** Assigns the parsed {@code values} to the synapses leaving {@code layer}, in synapse order. */
    private static void applyWeights(Layer layer, String[] values) throws MiningException
    {
        int count = layer.getOutSynapseCount();
        for (int i = 0; i < count; i++)
        {
            layer.getOutSynapse(i).setWeight(Double.parseDouble(values[i]));
        }
    }

    /*********************************************************************/
    /**
     * Demo driver: reads {@code arff//vowel.arff}, trains a network on the
     * attributes {@code 'feld4'} to {@code 'feld13'} against {@code 'class'},
     * saves it to {@code tmp.txt} and prints the wall-clock time used.
     *
     * @param args unused
     */
    public static void main(String[] args)
    {
        long start = System.currentTimeMillis();
        try
        {
            ArrayList<MiningData> data = new ArrayList<MiningData>();

            MiningArffStream arff = new MiningArffStream("arff//vowel.arff");
            while (arff.next())
            {
                MiningData d = new MiningData(arff.getData());
                data.add(d);
                d.normalize();
            }
            MiningMetaData meta = arff.getMetaData();
            meta.addTarget("'class'");
            for (int field = 4; field <= 13; field++)
            {
                meta.addInput("'feld" + field + "'");
            }

            ArrayList<double[]> inputSet = new ArrayList<double[]>();
            ArrayList<double[]> targetSet = new ArrayList<double[]>();
            for (MiningData d : data)
            {
                inputSet.add(d.getInput());
                targetSet.add(d.getTarget());
            }

            Parameter param = new Parameter();
            param.input = meta.getInputCount();
            param.output = meta.getTargetCount();
            param.hidden = (param.input + param.output) / 2;

            param.maxEpoch = 200;
            param.learningRate = 0.5;
            param.randomRange = 0.05;

            BP bp = new BP();
            bp.setParameter(param);
            bp.setInputSet(inputSet);
            bp.setTargetSet(targetSet);

            bp.train();
            bp.save("tmp.txt");
        }
        catch (MiningException e)
        {
            e.printStackTrace();
        }
        long end = System.currentTimeMillis();
        System.out.println("Time eclipsed[s]: " + (end - start) / 1000.0);
    }
    /*********************************************************************/

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -