⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bpa.java

📁 【源码共享系列-JSmartMiner.Eclipse】一个JAVA版本的数据挖掘基本框架体系
💻 JAVA
字号:
package org.scut.DataMining.Algorithm.NeuralNetwork.BP;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.Random;


import org.scut.DataMining.Core.MiningData;
import org.scut.DataMining.Core.MiningException;
import org.scut.DataMining.Core.MiningMetaData;
import org.scut.DataMining.Input.File.MiningArffStream;

/**
 * A feed-forward neural network with a single hidden layer, trained with the
 * back-propagation algorithm. Hidden and output units use the logistic
 * (sigmoid) activation.
 * <p>
 * Typical use: {@link #setParameter(Parameter)}, {@link #setInputSet(ArrayList)},
 * {@link #setTargetSet(ArrayList)}, {@link #train()} and then
 * {@link #work(double[])}. A network can be written to disk with
 * {@link #save(String)} and restored with {@link #load(String)}.
 * <p>
 * Training is done in batch mode: the weight changes of every sample of an
 * epoch are accumulated and applied once at the end of that epoch.
 */
public class BPA
{
    private static final String INPUT_SIZE_MESSAGE =
        "Training data input section data size not equal to the network input size";
    private static final String TARGET_SIZE_MESSAGE =
        "Training data target section data size not equal to the network output size";
    private static final String COUNT_MESSAGE_INPUT =
        "Traning data input and target count not equal";
    private static final String COUNT_MESSAGE_TARGET =
        "Traning data target and input count not equal";

    /** Number of input units. */
    private int ni;
    /** Number of hidden units. */
    private int nh;
    /** Number of output units. */
    private int no;
    /** Initial weights and biases are drawn uniformly from [-range, range). */
    private double range = 0.01;
    /** Learning rate. */
    private double enta = 0.5;
    /** Number of passes over the training set performed by {@link #train()}. */
    private int maxEpoch = 2000;

    /** Input-to-hidden weights, indexed [input][hidden]. */
    private double[][] wih;
    /** Accumulated changes for {@link #wih}. */
    private double[][] dwih;
    /** Hidden-to-output weights, indexed [hidden][output]. */
    private double[][] who;
    /** Accumulated changes for {@link #who}. */
    private double[][] dwho;

    /** Hidden unit biases. */
    private double[] bh;
    /** Accumulated changes for {@link #bh}. */
    private double[] dbh;
    /** Output unit biases. */
    private double[] bo;
    /** Accumulated changes for {@link #bo}. */
    private double[] dbo;

    /** Hidden unit activations of the most recent forward pass. */
    private double[] oh;
    /** Output unit activations of the most recent forward pass. */
    private double[] oo;
    /** Output unit error terms of the most recent backward pass. */
    private double[] so;
    /** Hidden unit error terms of the most recent backward pass. */
    private double[] sh;

    private ArrayList<double[]> inputSet;
    private ArrayList<double[]> targetSet;

    private final Random r = new Random();

    /**
     * Creates an unconfigured network. Call {@link #setParameter(Parameter)}
     * before supplying training data.
     */
    public BPA()
    {
        super();
    }

    /**
     * Re-initializes the network with random weights and trains it on the
     * stored input/target sets for {@code maxEpoch} epochs.
     *
     * @throws MiningException if the parameters are invalid, a training set is
     *         missing, or the stored samples no longer match the layer sizes
     */
    public void train() throws MiningException
    {
        this.initialize();
        if(this.inputSet == null || this.targetSet == null)
            throw new MiningException("input or target set not set yet!");
        // The parameters may have been changed after the sets were stored, so
        // the samples are validated again before they are used as array indices.
        if(this.inputSet.size() != this.targetSet.size())
            throw new MiningException(COUNT_MESSAGE_INPUT);
        checkSamples(this.inputSet, this.ni, INPUT_SIZE_MESSAGE);
        checkSamples(this.targetSet, this.no, TARGET_SIZE_MESSAGE);

        int size = this.inputSet.size();
        for(int epoch = 0; epoch < this.maxEpoch; epoch++)
        {
            for(int s = 0; s < size; s++)
            {
                double[] input = this.inputSet.get(s);
                this.forward(input);
                this.backward(input, this.targetSet.get(s));
            }
            // batch mode: apply the changes accumulated over the whole epoch
            this.updateWeights();
        }
    }

    /**
     * Runs one forward pass and returns the output activations.
     *
     * @param input the input vector; its first {@code ni} elements are used
     * @return a new array holding the {@code no} output activations
     * @throws IllegalStateException if the network has been neither trained nor loaded
     */
    public double[] work(double[] input)
    {
        if(this.wih == null)
            throw new IllegalStateException("BP, network not trained or loaded yet");
        this.forward(input);
        // Hand out a copy: the internal buffer is overwritten by the next pass.
        return this.oo.clone();
    }

    /**
     * Validates the layer sizes and the learning rate.
     *
     * @throws MiningException if any of them is not strictly positive
     */
    private void check() throws MiningException
    {
        if(this.ni <= 0)
            throw new MiningException("BP, input layer count<=0");
        if(this.nh <= 0)
            throw new MiningException("BP, hidden layer count<=0");
        if(this.no <= 0)
            throw new MiningException("BP, output layer count<=0");
        // written as a negated '>' so that NaN is rejected as well
        if(!(this.enta > 0))
            throw new MiningException("BP, learning rate<=0");
    }

    /**
     * Verifies that every sample is non-null and has the expected length.
     *
     * @param samples the samples to inspect
     * @param expectedLength the required length of each sample
     * @param sizeMessage the message used when a sample has a different length
     * @throws MiningException if a sample is null or has the wrong length
     */
    private static void checkSamples(ArrayList<double[]> samples, int expectedLength, String sizeMessage)
        throws MiningException
    {
        for(double[] sample : samples)
        {
            if(sample == null)
                throw new MiningException("Training data contains a null sample");
            if(sample.length != expectedLength)
                throw new MiningException(sizeMessage);
        }
    }

    /**
     * @return a random value uniformly distributed in [-range, range)
     */
    private double random()
    {
        return (r.nextDouble() - 0.5) * 2 * this.range;
    }

    /**
     * Sets the parameters for the BP algorithm: layer sizes, epoch count,
     * learning rate and the range of the random initial weights.
     * The values are validated later, when data is supplied or training starts.
     *
     * @param params the parameter values to copy
     */
    public void setParameter(Parameter params)
    {
        this.ni = params.input;
        this.nh = params.hidden;
        this.no = params.output;
        this.maxEpoch = params.maxEpoch;
        this.enta = params.learningRate;
        this.range = params.randomRange;
    }

    /**
     * Sets the input training set. The list is copied, the sample arrays are
     * not. The previously stored set is kept if validation fails.
     *
     * @param inputSet one array of length {@code ni} per training sample
     * @throws MiningException if the list is null, the parameters are invalid,
     *         a sample is null or has the wrong length, or the number of samples
     *         differs from an already stored target set
     */
    public void setInputSet(ArrayList<double[]> inputSet) throws MiningException
    {
        if(inputSet == null)
            throw new MiningException("inputSet passed into is null");
        this.check();
        checkSamples(inputSet, this.ni, INPUT_SIZE_MESSAGE);
        if(this.targetSet != null && this.targetSet.size() != inputSet.size())
            throw new MiningException(COUNT_MESSAGE_INPUT);
        // assign only after every check passed, so a rejected set changes nothing
        this.inputSet = new ArrayList<double[]>(inputSet);
    }

    /**
     * Sets the target training set. The list is copied, the sample arrays are
     * not. The previously stored set is kept if validation fails.
     *
     * @param targetSet one array of length {@code no} per training sample
     * @throws MiningException if the list is null, the parameters are invalid,
     *         a sample is null or has the wrong length, or the number of samples
     *         differs from an already stored input set
     */
    public void setTargetSet(ArrayList<double[]> targetSet) throws MiningException
    {
        if(targetSet == null)
            throw new MiningException("targetSet passed into is null");
        this.check();
        checkSamples(targetSet, this.no, TARGET_SIZE_MESSAGE);
        if(this.inputSet != null && targetSet.size() != this.inputSet.size())
            throw new MiningException(COUNT_MESSAGE_TARGET);
        // assign only after every check passed, so a rejected set changes nothing
        this.targetSet = new ArrayList<double[]>(targetSet);
    }

    /**
     * Allocates all working arrays for the current layer sizes and fills the
     * weights and biases with random values.
     *
     * @throws MiningException if the parameters are invalid
     */
    private void initialize() throws MiningException
    {
        this.check();

        this.oh = new double[this.nh];
        this.oo = new double[this.no];
        this.sh = new double[this.nh];
        this.so = new double[this.no];

        // accumulators start at zero (Java zero-fills new arrays)
        this.dwih = new double[this.ni][this.nh];
        this.dwho = new double[this.nh][this.no];
        this.dbh = new double[this.nh];
        this.dbo = new double[this.no];

        this.wih = new double[this.ni][this.nh];
        this.who = new double[this.nh][this.no];
        this.bh = new double[this.nh];
        this.bo = new double[this.no];

        this.randomize(this.wih);
        this.randomize(this.who);
        this.randomize(this.bh);
        this.randomize(this.bo);
    }

    /** Fills every cell of the matrix with a value from {@link #random()}. */
    private void randomize(double[][] matrix)
    {
        for(double[] row : matrix)
            this.randomize(row);
    }

    /** Fills every element of the vector with a value from {@link #random()}. */
    private void randomize(double[] vector)
    {
        for(int n = 0; n < vector.length; n++)
            vector[n] = this.random();
    }

    /**
     * Computes the hidden and output activations for one input vector and
     * stores them in {@link #oh} and {@link #oo}.
     *
     * @param input the input vector; must hold at least {@code ni} values
     */
    private void forward(double[] input)
    {
        for(int j = 0; j < this.nh; j++)
        {
            double net = this.bh[j];
            for(int i = 0; i < this.ni; i++)
                net += this.wih[i][j] * input[i];
            this.oh[j] = this.squash(net);
        }
        for(int k = 0; k < this.no; k++)
        {
            double net = this.bo[k];
            for(int j = 0; j < this.nh; j++)
                net += this.who[j][k] * this.oh[j];
            this.oo[k] = this.squash(net);
        }
    }

    /**
     * Computes the error terms for the activations left by the preceding
     * {@link #forward(double[])} call and adds the resulting weight and bias
     * changes to the accumulators.
     *
     * @param input the input vector that was fed forward
     * @param target the desired output for that input
     */
    private void backward(double[] input, double[] target)
    {
        // output error term: (target - output) times the sigmoid derivative o*(1-o)
        for(int k = 0; k < this.no; k++)
        {
            double out = this.oo[k];
            this.so[k] = (target[k] - out) * out * (1 - out);
        }
        // hidden error term: output error terms propagated back through who
        for(int j = 0; j < this.nh; j++)
        {
            double propagated = 0;
            for(int k = 0; k < this.no; k++)
                propagated += this.who[j][k] * this.so[k];
            this.sh[j] = this.oh[j] * (1 - this.oh[j]) * propagated;
        }

        //: accumulates the delta weights
        for(int j = 0; j < this.nh; j++)
        {
            for(int k = 0; k < this.no; k++)
                this.dwho[j][k] += this.enta * this.so[k] * this.oh[j];
            this.dbh[j] += this.enta * this.sh[j];
        }
        for(int i = 0; i < this.ni; i++)
            for(int j = 0; j < this.nh; j++)
                this.dwih[i][j] += this.enta * this.sh[j] * input[i];
        for(int k = 0; k < this.no; k++)
            this.dbo[k] += this.enta * this.so[k];
    }

    /**
     * Adds the accumulated changes to the weights and biases and resets the
     * accumulators to zero.
     */
    private void updateWeights()
    {
        for(int j = 0; j < this.nh; j++)
            applyAndReset(this.who[j], this.dwho[j]);
        for(int i = 0; i < this.ni; i++)
            applyAndReset(this.wih[i], this.dwih[i]);
        applyAndReset(this.bo, this.dbo);
        applyAndReset(this.bh, this.dbh);
    }

    /** Adds each delta to the matching value, then clears the delta. */
    private static void applyAndReset(double[] values, double[] deltas)
    {
        for(int n = 0; n < values.length; n++)
        {
            values[n] += deltas[n];
            deltas[n] = 0;
        }
    }

    /**
     * The logistic (sigmoid) activation function.
     *
     * @param value the net input of a unit
     * @return {@code 1 / (1 + e^-value)}
     */
    private double squash(double value)
    {
        return 1.0 / (1.0 + Math.exp(-value));
    }

    /**
     * Writes the network to a text file of five comma-separated lines: the
     * layer sizes, the input-to-hidden weights (row by row), the
     * hidden-to-output weights (row by row), the hidden biases and the output
     * biases. I/O errors are printed to standard error, not propagated.
     *
     * @param fileName the path of the file to write
     * @throws IllegalStateException if the network has been neither trained nor loaded
     */
    public void save(String fileName)
    {
        if(this.wih == null)
            throw new IllegalStateException("BP, network not trained or loaded yet");

        StringBuilder text = new StringBuilder();
        text.append(this.ni).append(',').append(this.nh).append(',').append(this.no).append('\n');
        appendLine(text, this.wih);
        appendLine(text, this.who);
        appendLine(text, new double[][] { this.bh });
        appendLine(text, new double[][] { this.bo });

        BufferedWriter bw = null;
        try
        {
            bw = new BufferedWriter(new FileWriter(fileName));
            bw.write(text.toString());
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
        finally
        {
            // close even when writing failed; closing also flushes the buffer
            if(bw != null)
            {
                try
                {
                    bw.close();
                }
                catch (IOException e)
                {
                    e.printStackTrace();
                }
            }
        }
    }

    /** Appends all cells of the matrix, row by row, as one comma-separated line. */
    private static void appendLine(StringBuilder out, double[][] matrix)
    {
        boolean first = true;
        for(double[] row : matrix)
        {
            for(double cell : row)
            {
                if(!first)
                    out.append(',');
                out.append(cell);
                first = false;
            }
        }
        out.append('\n');
    }

    /**
     * Reads a network from a file written by {@link #save(String)}. Only the
     * layer sizes, weights and biases are restored; the other parameters keep
     * their defaults.
     * <p>
     * Errors are printed to standard error, not propagated. If the file cannot
     * be read or is malformed, an unconfigured network is returned.
     *
     * @param fileName the path of the file to read
     * @return the restored network, or an unconfigured one on failure
     */
    public static BPA load(String fileName)
    {
        BPA bp = new BPA();
        BufferedReader br = null;
        try
        {
            br = new BufferedReader(new FileReader(fileName));
            String[] siho = readFields(br);
            String[] swih = readFields(br);
            String[] swho = readFields(br);
            String[] sbh = readFields(br);
            String[] sbo = readFields(br);

            if(siho.length < 3)
                throw new MiningException("BP, network file has a malformed layer size line");

            BPA loaded = new BPA();
            loaded.ni = Integer.parseInt(siho[0]);
            loaded.nh = Integer.parseInt(siho[1]);
            loaded.no = Integer.parseInt(siho[2]);
            loaded.initialize();

            if(swih.length != loaded.ni * loaded.nh || swho.length != loaded.nh * loaded.no
                || sbh.length != loaded.nh || sbo.length != loaded.no)
                throw new MiningException("BP, network file does not match its declared layer sizes");

            // both matrices are stored row-major, so the row length is the
            // size of the layer the weights lead to
            for(int n = 0; n < swih.length; n++)
                loaded.wih[n / loaded.nh][n % loaded.nh] = Double.parseDouble(swih[n]);
            for(int n = 0; n < swho.length; n++)
                loaded.who[n / loaded.no][n % loaded.no] = Double.parseDouble(swho[n]);
            for(int j = 0; j < loaded.nh; j++)
                loaded.bh[j] = Double.parseDouble(sbh[j]);
            for(int k = 0; k < loaded.no; k++)
                loaded.bo[k] = Double.parseDouble(sbo[k]);

            // publish only a completely restored network
            bp = loaded;
        }
        catch (Exception e)
        {
            e.printStackTrace();
        }
        finally
        {
            if(br != null)
            {
                try
                {
                    br.close();
                }
                catch (IOException e)
                {
                    e.printStackTrace();
                }
            }
        }
        return bp;
    }

    /**
     * Reads the next line and splits it at the commas.
     *
     * @throws IOException if reading fails
     * @throws MiningException if the end of the file has been reached
     */
    private static String[] readFields(BufferedReader reader) throws IOException, MiningException
    {
        String line = reader.readLine();
        if(line == null)
            throw new MiningException("BP, network file is truncated");
        return line.split("[,]");
    }

    /**
     * Demo: trains a network on {@code arff//pm.arff}, using attributes 6..17
     * as inputs and attributes 0..5 as targets, saves it to {@code tmp.txt}
     * and prints the elapsed time in seconds.
     *
     * @param args unused
     */
    public static void main(String[] args)
    {
        long start = System.currentTimeMillis();
        try
        {
            ArrayList<MiningData> data = new ArrayList<MiningData>();

            MiningArffStream arff = new MiningArffStream("arff//pm.arff");
            while(arff.next())
            {
                MiningData d = new MiningData(arff.getData());
                data.add(d);
                d.normalize();
            }
            MiningMetaData meta = arff.getMetaData();

            for(int i = 6; i < 18; i++)
                meta.addInput(i);
            for(int i = 0; i < 6; i++)
                meta.addTarget(i);

            ArrayList<double[]> inputSet = new ArrayList<double[]>();
            ArrayList<double[]> targetSet = new ArrayList<double[]>();
            for(MiningData d : data)
            {
                inputSet.add(d.getInput());
                targetSet.add(d.getTarget());
            }

            Parameter param = new Parameter();
            param.input = meta.getInputCount();
            param.output = meta.getTargetCount();
            // hidden layer size: mean of the input and output layer sizes
            param.hidden = (param.input + param.output) / 2;

            param.maxEpoch = 200;
            param.learningRate = 0.1;
            param.randomRange = 0.01;

            BPA bp = new BPA();
            bp.setParameter(param);
            bp.setInputSet(inputSet);
            bp.setTargetSet(targetSet);

            bp.train();
            bp.save("tmp.txt");
        }
        catch (MiningException e)
        {
            e.printStackTrace();
        }
        long end = System.currentTimeMillis();
        System.out.println("Time eclipsed[s]: " + (end - start) / 1000.0);
    }

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -