// example1.java
package neuralNetwork;
/*
Coded by Aydin Gurel, 2003
The code is free, but please contact me if you wish to use the code entirely or partially in any kind of project so that I can reference it and please don't delete these lines so that other people can reach this information. Also, please inform me if you encounter a bug.
aydingurel@hotmail.com
http://aydingurel.brinkster.com/neural
The purpose of this example is to show you how this package should be used, rather than to give you clues about how to train and use a neural network in general.
In this example, we:
-create a generalized feed forward net using a previously saved configuration,
-create a patternset, reserve 10% of it for cross validation, 10% for testing and the rest for training
-train it using batch training until the cross validation error ratio is no more than 0.002, so that it learns how to compute a simple function such as y = x1 + x2,
-save its weights,
-recreate the net using the previously saved configuration and weights,
-test it
*/
import java.io.*;
public class example1 {
public static void main(String[] args) {
// construct a Randomizer object using a seed
Randomizer randomizer = new Randomizer(4);
// create a net using a configuration file
System.out.println("Creating the net");
NeuralNet mynet = new NeuralNet("example1.nnc", randomizer);
// create a pattern set with 2 input and 1 output values.
// randomly choose 80% of data for training, 10% for cross validation, 10% for testing.
// the function to be learned is: y = x1 + x2
PatternSet mypatterns = new PatternSet("example1.csv", 2, 1, .8, .1, .1, randomizer);
// display the error rate before training
System.out.println("\n\nError ratio before training: " + mynet.CrossValErrorRatio(mypatterns) );
// train the net using batch training, until the cross validation error ratio is no more than 0.002
while ( mynet.CrossValErrorRatio(mypatterns) > 0.002 ) {
mynet.BatchTrainPatterns(mypatterns.trainingpatterns, .8);
System.out.println("Training the net. Error ratio: " + mynet.CrossValErrorRatio(mypatterns) );
}
// check the error using test data
System.out.println("Error ratio of the test data: " + mynet.TestErrorRatio(mypatterns) );
System.out.println("Training is over");
// note that y = x1 + x2 is about the easiest function to learn. a harder function
// would need much more training.
// now that the training is over, save the weights of the net.
System.out.println("Saving the weights\n");
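try {
mynet.SaveWeights("example1.nnw");
} catch (IOException e) {
// report the failure instead of silently ignoring it
System.err.println("Could not save the weights: " + e.getMessage());
}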
// clean up the objects
mypatterns = null;
mynet = null;
randomizer = null;
// ----------------------------------------------------------------
// now recreate the net using previously saved data and
// use the trained net to calculate some numbers.
// You could use this part in a separate java class.
// recreate the net
randomizer = new Randomizer();
System.out.println("Recreating the net");
mynet = new NeuralNet("example1.nnc", randomizer);
mynet.LoadWeights("example1.nnw");
// and test it
double[] inputs = {-.5, .25};
double[] outputs = mynet.Output(inputs); // the net has a single output, so only outputs[0] is used
System.out.println("-0.5 + 0.25 = " + outputs[0]);
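// A rough check, added here as a sketch and not part of the original example:
// compare the net's answer with the exact sum of the two inputs.
// The name expectedSum is illustrative; only inputs and outputs come from the code above.
double expectedSum = inputs[0] + inputs[1];
System.out.println("Exact sum: " + expectedSum + ", absolute error: " + Math.abs(outputs[0] - expectedSum));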
}
}