📄 Network.java
package net.openai.ai.nn.network;

import java.util.*;
import java.io.*;
import net.openai.ai.nn.architecture.*;
import net.openai.ai.nn.error.*;
import net.openai.ai.nn.learning.*;
import net.openai.ai.nn.input.*;
import net.openai.ai.nn.transfer.*;
import net.openai.ai.nn.training.*;
import net.openai.ai.nn.parser.*;

/**
 * This class encapsulates all global info on the neural network
 * and contains information on the architecture type and error type.
 * All layers for the network are contained here as well.
 */
public class Network implements Serializable {

    // The set of all layers for this network.
    private Vector layers = null;

    // The ordered list of layers for learning.
    private Vector learningOrder = null;

    // The architecture type used for this network.
    private Architecture architecture;

    // The error type used for this network.
    private ErrorType errorType = null;

    // The training set that this network is to be trained on.
    private TrainingSet trainingSet = null;

    // The current training element being processed.
    private TrainingElement trainingElement = null;

    // A flag to determine whether the network is in learning mode.
    private boolean learning = false;

    // A flag to tell whether the network is connected or not.
    private boolean connected = false;

    // A flag to tell whether this network will use a bias.
    private boolean useBias = false;

    // A debug flag.
    private static boolean debug = true;

    // The error of the network.
    private double error = 0;

    // The error criterion to train until.
    private double errorCriterion = 0;

    public Network() {
        learningOrder = new Vector();
    }

    /**
     * Constructor used to create a network from a configuration file.
     * @param networkFileName Name of the file to load.
     */
    public Network(String networkFileName) {
        // initialize the network
        this();

        // load the configuration for the network
        try {
            loadConfiguration(networkFileName);
        } catch (NetworkConfigurationException nce) {
            db("Could not load configuration: " + nce.getMessage());
            nce.printStackTrace();
        }
    }

    /**
     * Creates a network by loading a configuration file by name.
     * @param networkFileName The name of the configuration file to load.
     */
    public final void loadConfiguration(String networkFileName)
            throws NetworkConfigurationException {
        File networkFile = new File(networkFileName);
        if (!networkFile.exists()) {
            throw new NetworkConfigurationException("Configuration file "
                    + networkFileName + " not found.");
        }

        // create the network based on this file.
        loadConfiguration(networkFile);
    }

    /**
     * Creates a network by loading a configuration file.
     * @param networkFile The configuration file to load.
     */
    public final void loadConfiguration(File networkFile)
            throws NetworkConfigurationException {
    }

    /**
     * Loads the training data from the specified files.
     *
     * @param inputFileName The name of the input data file to load.
     * @param outputFileName The name of the output data file to load.
     */
    public final void loadTrainingData(String inputFileName, String outputFileName) {
        // load the data
        Data inputData = DataParser.parseData(inputFileName);
        Data outputData = DataParser.parseData(outputFileName);
        if (inputData == null || outputData == null) {
            db("Training data is corrupt or does not exist.");
            return;
        }

        // get the column names
        Vector inputColumns = inputData.getColumnNames();
        Vector outputColumns = outputData.getColumnNames();

        // get the data rows
        Vector inputRows = inputData.getRows();
        Vector outputRows = outputData.getRows();

        // create the training set
        trainingSet = new TrainingSet();
        trainingSet.setInputCategories(inputColumns);
        trainingSet.setOutputCategories(outputColumns);

        // check that the number of rows match
        if (inputRows.size() != outputRows.size())
            db("The number of input rows does not match the number "
                    + "of output rows. Some rows will be lost.");

        int size = Math.min(inputRows.size(), outputRows.size());
        for (int i = 0; i < size; i++) {
            TrainingElement element = new TrainingElement(
                    (Vector) inputRows.elementAt(i),
                    (Vector) outputRows.elementAt(i));
            trainingSet.addElement(element);
        }
    }

    /**
     * Iterates the network for a specified number of iterations.
     * @param iterations The number of times to iterate the network.
     */
    public final void iterate(int iterations) {
        if (trainingSet == null || trainingSet.isEmpty()) {
            db("No training set has been specified...");
            return;
        }

        // check that the network is connected
        if (!connected) {
            try {
                connect();
            } catch (NetworkConfigurationException nce) {
                db("Network could not be connected for iteration.");
                nce.printStackTrace();
            }
        }

        for (int i = 0; i < iterations; i++) {
            iterate();
            //db("Completed iteration " + (i + 1));
            System.err.println("Completed iteration " + (i + 1));
            System.err.println("Error: " + error);
        }
        System.err.println("Error: " + error);
    }

    /**
     * Iterates the network until the error criterion is reached.
     */
    public final void iterateToCriterion() {
        if (trainingSet == null || trainingSet.isEmpty()) {
            db("No training set has been specified...");
            return;
        }

        // check that the network is connected
        if (!connected) {
            try {
                connect();
            } catch (NetworkConfigurationException nce) {
                db("Network could not be connected for iteration.");
                nce.printStackTrace();
            }
        }

        int counter = 0;
        do {
            iterate();
            System.err.println("Completed iteration " + ++counter);
            System.err.println("Error: " + error);
        } while (error >= errorCriterion);

        db("errorCriterion: " + errorCriterion);
        db("Reached error criterion...");
    }

    /**
     * Iterates the network once.
     */
    public final void iterate() {
        // get the current training set
        Enumeration e = trainingSet.getElements();
        while (e.hasMoreElements()) {
            trainingElement = (TrainingElement) e.nextElement();

            // get the input layer and seed it with inputs
            Layer inputLayer = getInputLayer();
            inputLayer.seedNeurons(trainingElement.getInput());
            //db("processing training element: " + trainingElement.toString());

            // for each layer
            int size = layers.size();
            for (int i = 0; i < size; i++) {
                //db("Processing layer " + (i + 1));
                Layer layer = (Layer) layers.elementAt(i);

                // call the calculate method for this layer
                layer.calculate();
            }

            // set the output results for this training element
            Layer outputLayer = getOutputLayer();
            Vector neurons = outputLayer.getNeurons();
            Vector outputs = new Vector();
            size = neurons.size();
            for (int i = 0; i < size; i++) {
                Neuron neuron = (Neuron) neurons.elementAt(i);
                outputs.addElement(new Double(neuron.getOutput()));
            }
            //db("setting the output: " + outputs.toString());
            trainingElement.setOutput(outputs);

            // calculate the error for the network
            //db("calculating error...");
            errorType.calculateError(this);

            // if learning is turned on...
            if (learning) {
                // check if we have ordered learning
                if (learningOrder.isEmpty())
                    orderLearning();
                else
                    learning();
            }
        }
    }

    /**
     * Iterates through the layers setting the type of each layer.
     */
    private final void setLayerTypes() {
        int size = layers.size();
        for (int i = 0; i < size; i++) {
            Layer layer = (Layer) layers.elementAt(i);
            if (layer.equals(layers.firstElement()))
                layer.setLayerType(Layer.INPUT_LAYER);
            else if (layer.equals(layers.lastElement()))
                layer.setLayerType(Layer.OUTPUT_LAYER);
            else
                layer.setLayerType(Layer.HIDDEN_LAYER);
        }
    }

    /**
     * Iterates through the layers, ordering them as they are trained.
     */
    private final void orderLearning() {
        // copy all the layers over
        Vector layersCopy = new Vector();
        int size = layers.size();
        for (int i = 0; i < size; i++) {
            layersCopy.addElement(layers.elementAt(i));
        }

        db("ordering/learning...");
        while (!layersCopy.isEmpty()) {
            Layer layer = (Layer) layersCopy.firstElement();
            if (!layer.readyToLearn()) {
                // not ready yet: move it to the back of the queue and try again later
                db("layer not ready to learn: " + layer.toString());
                layersCopy.addElement(layersCopy.remove(0));
                continue;
            } else {
                db("found a layer ready to learn: " + layer.toString());

                // call the learn method
                try {
                    layer.learn(trainingElement);
                } catch (NetworkConfigurationException nce) {
                    db("Aborting learning: " + nce.getMessage());
                    nce.printStackTrace();
                    return;
                }
                learningOrder.addElement(layersCopy.remove(0));
            }
        }
        db("Ordered Layers: " + learningOrder.toString());
    }

    /**
     * Iterates through the layers, training them in the established order.
     */
    private final void learning() {
        //db("learning...");

        // for each layer
        int size = learningOrder.size();
        for (int i = 0; i < size; i++) {
            Layer layer = (Layer) learningOrder.elementAt(i);
//            if (layer.getLayerType() == Layer.INPUT_LAYER)
//                db("working on Input Layer");
//            if (layer.getLayerType() == Layer.HIDDEN_LAYER)
//                db("working on Hidden Layer");
//            if (layer.getLayerType() == Layer.OUTPUT_LAYER)
//                db("working on Output Layer");

            // call the learn method
            try {
                layer.learn(trainingElement);
            } catch (NetworkConfigurationException nce) {
                db("Aborting learning: " + nce.getMessage());
                nce.printStackTrace();
                return;
            }
        }
    }

    /**
     * Returns whether the network has been connected.
     * @return boolean Whether the network is connected.
     */
    public final boolean isConnected() {
        return connected;
    }

    /**
     * Sets the connected flag.
     * @param connected Whether this network is connected.
     */
    public final void setConnected(boolean connected) {
        this.connected = connected;
    }

    /**
     * Connect the network using the network's architecture rule.
     */
    public final void connect() throws NetworkConfigurationException {
        // set the layer types
        setLayerTypes();

        if (useBias) {
            int size = layers.size();
            for (int i = 0; i < size; i++) {
                Layer layer = (Layer) layers.elementAt(i);
                if (layer.getLayerType() == Layer.OUTPUT_LAYER) {
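
A minimal, hypothetical sketch of how the public methods shown above might be driven. The listing is truncated inside connect(), so helpers it calls such as db(), getInputLayer(), and getOutputLayer() do not appear here; the file names and iteration count below are placeholders, not part of the original source, and loadConfiguration(File) is an empty stub in this listing, so the network's layers would need to be populated elsewhere before iterating.

// Hypothetical usage sketch -- file names and counts are placeholders.
Network network = new Network("network.cfg");          // constructor delegates to loadConfiguration(...)
network.loadTrainingData("input.dat", "output.dat");   // parses both files into a TrainingSet
network.iterate(100);                                   // run a fixed number of passes over the training set
// network.iterateToCriterion();                        // or loop until error falls below errorCriterion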