// example2.java
package neuralNetwork;
/*
Coded by Aydin Gurel, 2003
The code is free, but please contact me if you wish to use the code entirely or partially in any kind of project so that I can reference it and please don't delete these lines so that other people can reach this information. Also, please inform me if you encounter a bug.
aydingurel@hotmail.com
http://aydingurel.brinkster.com/neural
The purpose of this example is to show you how this package should be used, rather than to give you clues about how to train and use a neural network in general.
In this example, we:
-aim to resolve a classification problem,
-create a multilayer perceptron,
-save the configuration of the net,
-create three pattern sets:
 one for training, one for cross-validation and one for testing,
-train the net with mini-batch training until the cross-validation error ratio drops to 0.02 or below, so that it learns to distinguish female crabs from male ones using a few measurements,
-save its weights,
-recreate the net using the previously saved configuration and weights,
-test it
The training data was taken from a tutorial for the "NeuroSolutions" software:
http://www.neurosolutions.com/
*/
import java.io.*;
/**
 * Usage example for the neuralNetwork package.
 * Trains a multilayer perceptron to tell female crabs from male crabs and saves
 * its configuration and weights. It then rebuilds the net from those files and
 * classifies one sample.
 */
public class example2 {

    /** File the net's configuration is written to and later rebuilt from. */
    private static final String CONFIG_FILE = "example2.nnc";

    /** File the trained weights are written to and later reloaded from. */
    private static final String WEIGHTS_FILE = "example2.nnw";

    /** Training continues while the cross-validation error ratio is above this value. */
    private static final double TARGET_ERROR = .02;

    /**
     * Upper bound on the number of mini-batch training rounds. Without it the
     * training loop never terminates if the net cannot reach TARGET_ERROR.
     */
    private static final int MAX_TRAINING_ROUNDS = 10000;

    /**
     * Runs the example: train and save the net, then reload it and classify a sample.
     * Exits with status 1 if the configuration or the weights could not be saved,
     * because the reload step depends on both files.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        if (!trainAndSave()) {
            System.exit(1);
        }
        reloadAndClassify();
    }

    /**
     * Builds the perceptron, saves its configuration, trains it with mini-batch
     * training, reports the test error and saves the trained weights.
     *
     * @return false if the configuration or the weights could not be written
     */
    private static boolean trainAndSave() {
        Randomizer randomizer = new Randomizer();

        // Four layers: 7 input units, two hidden layers of 10 units and 1 output
        // unit, all using axon family 't' (tanh). Except for the unit count, the
        // parameters given for the input layer have no effect.
        int[] neuronsPerLayer = {7, 10, 10, 1};
        double[] learnRateCoeff = {1, 1, 1, 1};
        char[] axonFamily = {'t', 't', 't', 't'};
        double[] momentumRate = {0, .6, .5, .4};
        double[] flatness = {1, 1.2, 1.1, 1};

        System.out.println("Creating the net");
        NeuralNet net = new NeuralNet(neuronsPerLayer, learnRateCoeff, axonFamily,
                momentumRate, flatness, randomizer);

        System.out.println("Saving the configuration");
        try {
            net.SaveConfig(CONFIG_FILE);
        } catch (IOException e) {
            // The reload step needs this file, so training would be wasted work.
            System.err.println("Could not save the configuration to " + CONFIG_FILE + ": " + e);
            return false;
        }

        // Three pattern sets, each with 7 input values and 1 output value.
        // NOTE(review): the three trailing 0/1 arguments presumably route the
        // patterns to training / cross-validation / testing. TODO confirm in PatternSet.
        System.out.println("Loading patterns");
        PatternSet trainingPatterns = new PatternSet("example2_training.csv", 7, 1, 1, 0, 0, randomizer);
        PatternSet crossValPatterns = new PatternSet("example2_crossval.csv", 7, 1, 0, 1, 0, randomizer);
        PatternSet testPatterns = new PatternSet("example2_test.csv", 7, 1, 0, 0, 1, randomizer);

        // Evaluate once; the same value is reported and then seeds the training loop.
        double error = net.CrossValErrorRatio(crossValPatterns);
        System.out.println("Error ratio before training: " + error);

        System.out.println("Beginning mini batch training");
        int rounds = 0;
        while (error > TARGET_ERROR && rounds < MAX_TRAINING_ROUNDS) {
            System.out.println("Training the net. Error ratio: " + error);
            // NOTE(review): .1 and 20 are passed straight through; presumably the
            // learning rate and the mini-batch size. TODO confirm in NeuralNet.
            // Incremental training can be used instead of this call:
            //     net.IncrementalTrainPatterns(trainingPatterns.trainingpatterns, .01);
            net.MinibatchTrainPatterns(trainingPatterns.trainingpatterns, .1, 20);
            error = net.CrossValErrorRatio(crossValPatterns);
            rounds++;
        }
        if (error > TARGET_ERROR) {
            System.err.println("Stopped after " + rounds
                    + " training rounds without reaching the target error ratio; last error ratio: "
                    + error);
        }

        // Finally, measure the error on data the net has not been trained on.
        System.out.println("Error ratio of the test data: " + net.TestErrorRatio(testPatterns));
        System.out.println("Training is over");

        System.out.println("Saving the weights\n");
        try {
            net.SaveWeights(WEIGHTS_FILE);
        } catch (IOException e) {
            // Without the weights file the reload step cannot rebuild the trained net.
            System.err.println("Could not save the weights to " + WEIGHTS_FILE + ": " + e);
            return false;
        }
        return true;
    }

    /**
     * Rebuilds the net from the saved configuration and weights.
     * It then feeds the net one sample taken from a female crab and prints the
     * predicted class.
     */
    private static void reloadAndClassify() {
        Randomizer randomizer = new Randomizer();
        System.out.println("Recreating the net");
        NeuralNet net = new NeuralNet(CONFIG_FILE, randomizer);
        net.LoadWeights(WEIGHTS_FILE);

        double[] inputs = {-19.5, 0.5, -0.683, 0.4615, -2.0055, -0.8145, -2.0305};
        System.out.println("Feeding the net using data taken from a female crab");
        double[] outputs = net.Output(inputs);
        // The single tanh output unit is thresholded at 0: negative means male,
        // anything else means female.
        System.out.println(outputs[0] < 0 ? "Result: Male" : "Result: Female");
    }
}