
📄 network.java

📁 Neural network source code: implements a BP (backpropagation) neural network and carries out the BP-based learning algorithm.
💻 JAVA
📖 Page 1 of 2
package net.openai.ai.nn.network;

import java.util.*;
import java.io.*;
import net.openai.ai.nn.architecture.*;
import net.openai.ai.nn.error.*;
import net.openai.ai.nn.learning.*;
import net.openai.ai.nn.input.*;
import net.openai.ai.nn.transfer.*;
import net.openai.ai.nn.training.*;
import net.openai.ai.nn.parser.*;

/**
 *  This class encapsulates all global info on the neural network
 *  and contains information on the architecture type and error type.
 *  All layers for the network are contained here as well...
 */
public class Network implements Serializable {

    //  The set of all layers for this network.
    private Vector layers = null;

    //  The ordered list of layers for learning
    private Vector learningOrder = null;

    //  The architecture type used for this network
    private Architecture architecture;

    //  The error type used for this network.
    private ErrorType errorType = null;

    //  the training set that this network is to be trained on
    private TrainingSet trainingSet = null;

    //  the current training element being processed
    private TrainingElement trainingElement = null;

    //  A flag to determine whether the network is in learning mode
    private boolean learning = false;

    //  a flag to tell whether the network is connected or not
    private boolean connected = false;

    //  a flag to tell whether this network will use a bias
    private boolean useBias = false;

    //  a debug flag
    private static boolean debug = true;

    //  the error of the network
    private double error = 0;

    //  the error criterion to train until
    private double errorCriterion = 0;

    public Network() {
        learningOrder = new Vector();
    }

    /**
     * Constructor used to create a network from a configuration file.
     * @param networkFileName Name of the file to load.
     */
    public Network(String networkFileName) {
        //  initialize the network
        this();
        //  load the configuration for the network
        try {
            loadConfiguration(networkFileName);
        } catch (NetworkConfigurationException nce) {
            db("Could not load configuration: " + nce.getMessage());
            nce.printStackTrace();
        }
    }

    /**
     * Creates a network by loading a configuration file by name.
     * @param networkFileName The name of the configuration file to load.
     */
    public final void loadConfiguration(String networkFileName)
        throws NetworkConfigurationException {
        File networkFile = new File(networkFileName);
        if(!networkFile.exists()) {
            throw new NetworkConfigurationException("Configuration file "
                                                    + networkFileName
                                                    + " not found.");
        }
        //  create the network based on this file.
        loadConfiguration(networkFile);
    }

    /**
     * Creates a network by loading a configuration file.
     * @param networkFile The configuration file to load.
     */
    public final void loadConfiguration(File networkFile)
        throws NetworkConfigurationException {
    }

    /**
     * Loads the data from the specified file.
     *
     * @param inputFileName The name of the input data file to load.
     * @param outputFileName The name of the output data file to load.
     */
    public final void loadTrainingData(String inputFileName,
                                       String outputFileName) {
        //  load the data
        Data inputData = DataParser.parseData(inputFileName);
        Data outputData = DataParser.parseData(outputFileName);
        if(inputData == null || outputData == null) {
            db("Training data is corrupt or does not exist.");
            return;
        }
        //  get the column names
        Vector inputColumns = inputData.getColumnNames();
        Vector outputColumns = outputData.getColumnNames();
        //  get the data rows
        Vector inputRows = inputData.getRows();
        Vector outputRows = outputData.getRows();
        //  create the training set
        trainingSet = new TrainingSet();
        trainingSet.setInputCategories(inputColumns);
        trainingSet.setOutputCategories(outputColumns);
        //  check that the size of the rows match
        if(inputRows.size() != outputRows.size())
            db("The number of input rows does not match the number "
               + "of output rows. Some rows will be lost.");
        int size = Math.min(inputRows.size(), outputRows.size());
        for(int i = 0; i < size; i++) {
            TrainingElement element
                = new TrainingElement((Vector) inputRows.elementAt(i),
                                      (Vector) outputRows.elementAt(i));
            trainingSet.addElement(element);
        }
    }

    /**
     * Iterates the network for a specified number of iterations
     * @param iterations The number of times to iterate the network.
     */
    public final void iterate(int iterations) {
        if(trainingSet == null || trainingSet.isEmpty()) {
            db("No training set has been specified...");
            return;
        }
        //  check that the network is connected
        if(!connected)
            try {
                connect();
            } catch (NetworkConfigurationException nce) {
                db("Network could not be connected for iteration.");
                nce.printStackTrace();
            }
        for(int i = 0; i < iterations; i++) {
            iterate();
            //db("Completed iteration " + (i + 1));
            System.err.println("Completed iteration " + (i + 1));
            System.err.println("Error: " + error);
        }
        System.err.println("Error: " + error);
    }

    /**
     * Iterates the network until the error criterion is reached.
     */
    public final void iterateToCriterion() {
        if(trainingSet == null || trainingSet.isEmpty()) {
            db("No training set has been specified...");
            return;
        }
        //  check that the network is connected
        if(!connected)
            try {
                connect();
            } catch (NetworkConfigurationException nce) {
                db("Network could not be connected for iteration.");
                nce.printStackTrace();
            }
        int counter = 0;
        do {
            iterate();
            System.err.println("Completed iteration " + ++counter);
            System.err.println("Error: " + error);
        } while(error >= errorCriterion);
        db("errorcriterion: " + errorCriterion);
        db("Reached error criterion...");
    }

    /**
     * Iterates the network once.
     */
    public final void iterate() {
        //  get the current training set
        Enumeration e = trainingSet.getElements();
        while(e.hasMoreElements()) {
            trainingElement = (TrainingElement) e.nextElement();
            //  get the input layer and seed it with inputs
            Layer inputLayer = getInputLayer();
            inputLayer.seedNeurons(trainingElement.getInput());
            //db("processing training element: "
            //   + trainingElement.toString());
            //  for each layer
            int size = layers.size();
            for(int i = 0; i < size; i++) {
                //db("Processing layer " + (i + 1));
                Layer layer = (Layer) layers.elementAt(i);
                //  call the calculate method for this layer
                layer.calculate();
            }

            //  set the output results for this training element
            Layer outputLayer = getOutputLayer();
            Vector neurons = outputLayer.getNeurons();
            Vector outputs = new Vector();
            size = neurons.size();
            for(int i = 0; i < size; i++) {
                Neuron neuron = (Neuron) neurons.elementAt(i);
                outputs.addElement(new Double(neuron.getOutput()));
            }
            //db("setting the output: " + outputs.toString());
            trainingElement.setOutput(outputs);
            //  calculate the error for the network
            //db("calculating error...");
            errorType.calculateError(this);

            //  if learning is turned on...
            if(learning) {
                //  check if we have ordered learning
                if(learningOrder.isEmpty())
                    orderLearning();
                else
                    learning();
            }
        }
    }

    /**
     * Iterates through the layers setting the type of layer.
     */
    private final void setLayerTypes() {
        int size = layers.size();
        for(int i = 0; i < size; i++) {
            Layer layer = (Layer) layers.elementAt(i);
            if(layer.equals(layers.firstElement()))
                layer.setLayerType(Layer.INPUT_LAYER);
            else if(layer.equals(layers.lastElement()))
                layer.setLayerType(Layer.OUTPUT_LAYER);
            else
                layer.setLayerType(Layer.HIDDEN_LAYER);
        }
    }

    /**
     * Iterates through the layers ordering them as they are trained.
     */
    private final void orderLearning() {
        // copy all the layers over
        Vector layersCopy = new Vector();
        int size = layers.size();
        for(int i = 0; i < size; i++) {
            layersCopy.addElement(layers.elementAt(i));
        }
        db("ordering/learning...");
        while(!layersCopy.isEmpty()) {
            Layer layer = (Layer) layersCopy.firstElement();
            if(!layer.readyToLearn()) {
                db("layer not ready to learn: " + layer.toString());
                layersCopy.addElement(layersCopy.remove(0));
                continue;
            } else {
                db("found a layer ready to learn: " + layer.toString());
                //  call the learn method
                try {
                    layer.learn(trainingElement);
                } catch (NetworkConfigurationException nce) {
                    db("Aborting learning: " + nce.getMessage());
                    nce.printStackTrace();
                    return;
                }
                learningOrder.addElement(layersCopy.remove(0));
            }
        }
        db("Ordered Layers: " + learningOrder.toString());
    }

    /**
     * Iterates through the layers training them.
     */
    private final void learning() {
        //db("learning...");
        //  for each layer
        int size = learningOrder.size();
        for(int i = 0; i < size; i++) {
            Layer layer = (Layer) learningOrder.elementAt(i);
//          if(layer.getLayerType() == Layer.INPUT_LAYER)
//              db("working on Input Layer");
//          if(layer.getLayerType() == Layer.HIDDEN_LAYER)
//              db("working on Hidden Layer");
//          if(layer.getLayerType() == Layer.OUTPUT_LAYER)
//              db("working on Output Layer");

            //  call the learn method
            try {
                layer.learn(trainingElement);
            } catch (NetworkConfigurationException nce) {
                db("Aborting learning: " + nce.getMessage());
                nce.printStackTrace();
                return;
            }
        }
    }

    /**
     * Returns whether the network has been connected.
     * @return boolean Whether the network is connected.
     */
    public final boolean isConnected() {
        return connected;
    }

    /**
     * Set the connected flag.
     * @param connected Whether this network is connected.
     */
    public final void setConnected(boolean connected) {
        this.connected = connected;
    }

    /**
     * Connect the network using the network's architecture rule.
     */
    public final void connect()
        throws NetworkConfigurationException {
        //  set the layer types
        setLayerTypes();
        if(useBias) {
            int size = layers.size();
            for(int i = 0; i < size; i++) {
                Layer layer = (Layer) layers.elementAt(i);
                if(layer.getLayerType() == Layer.OUTPUT_LAYER) {
                    //  (source truncated here: the remainder of connect()
                    //   is on page 2 of 2)
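The `Network` class above delegates the actual backpropagation math to its `Layer` and `Neuron` objects, which are not shown on this page. For readers who want to see the algorithm end to end, the following is a minimal self-contained sketch of sigmoid backpropagation for a fixed 2-2-1 topology. The class name `TinyBP`, the topology, the seed, and the learning rate are all illustrative assumptions; none of this is part of the `net.openai` API.

```java
import java.util.Random;

/**
 * Minimal sketch of online backpropagation on a 2-input, 2-hidden,
 * 1-output sigmoid network (illustrative only, not the net.openai API).
 */
public class TinyBP {
    private final double[][] w1 = new double[2][2]; // input -> hidden weights
    private final double[] b1 = new double[2];      // hidden biases
    private final double[] w2 = new double[2];      // hidden -> output weights
    private double b2;                              // output bias
    private final double lr = 0.5;                  // learning rate (assumed)

    public TinyBP(long seed) {
        Random r = new Random(seed);
        for (int i = 0; i < 2; i++) {
            b1[i] = r.nextDouble() - 0.5;
            w2[i] = r.nextDouble() - 0.5;
            for (int j = 0; j < 2; j++) w1[i][j] = r.nextDouble() - 0.5;
        }
        b2 = r.nextDouble() - 0.5;
    }

    private static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /** Forward pass: compute the network output for one input pair. */
    public double forward(double x0, double x1) {
        double h0 = sigmoid(w1[0][0] * x0 + w1[0][1] * x1 + b1[0]);
        double h1 = sigmoid(w1[1][0] * x0 + w1[1][1] * x1 + b1[1]);
        return sigmoid(w2[0] * h0 + w2[1] * h1 + b2);
    }

    /** One backprop step on a single training element; returns squared error. */
    public double train(double x0, double x1, double target) {
        // forward pass, keeping the intermediate activations
        double h0 = sigmoid(w1[0][0] * x0 + w1[0][1] * x1 + b1[0]);
        double h1 = sigmoid(w1[1][0] * x0 + w1[1][1] * x1 + b1[1]);
        double out = sigmoid(w2[0] * h0 + w2[1] * h1 + b2);

        // backward pass: output delta first, then hidden deltas (chain rule)
        double dOut = (out - target) * out * (1 - out);
        double dH0 = dOut * w2[0] * h0 * (1 - h0);
        double dH1 = dOut * w2[1] * h1 * (1 - h1);

        // gradient-descent weight updates
        w2[0] -= lr * dOut * h0;   w2[1] -= lr * dOut * h1;   b2 -= lr * dOut;
        w1[0][0] -= lr * dH0 * x0; w1[0][1] -= lr * dH0 * x1; b1[0] -= lr * dH0;
        w1[1][0] -= lr * dH1 * x0; w1[1][1] -= lr * dH1 * x1; b1[1] -= lr * dH1;

        return (out - target) * (out - target);
    }
}
```

This is the same structure the `Network` class expresses at a higher level: `iterate()` corresponds to the forward pass plus error calculation, and `learning()` corresponds to the backward pass, with the output layer learning before the hidden layers, which is what the `readyToLearn()` ordering in `orderLearning()` enforces.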
