📄 airwin.java
字号:
/*
Coded by Aydin Gurel, 2003
The code is free, but please contact me if you wish to use the code entirely or partially in any kind of project so that I can reference it and please don't delete these lines so that other people can reach this information. Also, please inform me if you encounter a bug.
aydingurel@hotmail.com
http://aydingurel.brinkster.com/neural
The purpose of this example is to show you how this package should be used, rather than to give you clues about how to train and use a neural network in general.
In this example, we:
-aim to resolve a classification problem,
-create a multilayer perceptron,
-save the configuration of the net,
-create three patternsets;
-one for training, one for cross validation and one for testing.
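-train the net using mini batch until the cross-validation error ratio is at most 0.3, so that it learns to tell four drawn shapes apart (circle, cross, horizontal bars, vertical bars) from the Fourier spectrum of the mouse tracker's binary map,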
-save its weights,
-recreate the net using the previously saved configuration and weights,
-test it
training data was taken from a tutorial of the "NeuroSolutions" software
http://www.neurosolutions.com/
*/
import java.io.IOException;
import jnt.FFT.ComplexDouble2DFFT;
import neuralNetwork.*;
import mouse.*;
/**
 * Trains a multilayer perceptron to recognise four drawn shapes, then classifies the mouse
 * tracker's binary map in an endless loop.
 *
 * <p>Training phase: builds the net, saves its configuration to {@value #CONFIG_FILE}, trains it
 * with mini batches on {@code AirWin_training.csv} until the cross-validation error ratio on
 * {@code AirWin_crossval.csv} is at most {@value #TARGET_ERROR_RATIO}, reports the error ratio on
 * {@code AirWin_test.csv} and saves the weights to {@value #WEIGHTS_FILE}.
 *
 * <p>Recognition phase: recreates the net from the saved files, then repeatedly fills the
 * tracker's binary map, takes its 2D FFT, feeds the low-frequency corners of the spectrum to the
 * net and prints the label of every output unit whose value is positive.
 *
 * <p>NOTE(review): the AirWin_*.csv pattern files must have been generated with exactly the
 * feature extraction implemented in {@code extractLowFrequencies} — confirm.
 */
public class AirWin {

    /** Side length of the square block of low-frequency FFT coefficients fed to the net; must be even. */
    private static final int SIZE = 34;

    /** Number of net inputs: SIZE x SIZE complex coefficients, each stored as a (re, im) pair. */
    private static final int INPUT_COUNT = SIZE * SIZE * 2;

    /** Label printed for each output unit, in output-unit order; its length is the output count. */
    private static final String[] CLASS_LABELS = {
        "Rond", "croix", "barres horizontales", "barres verticales"
    };

    /** File the net configuration is saved to and later reloaded from. */
    private static final String CONFIG_FILE = "AirWin.nnc";

    /** File the trained weights are saved to and later reloaded from. */
    private static final String WEIGHTS_FILE = "AirWin.nnw";

    /** Mini batch training stops once the cross-validation error ratio is at most this value. */
    private static final double TARGET_ERROR_RATIO = 0.3;

    /**
     * Runs the training phase and then the endless recognition phase.
     *
     * <p>Exits with status 1 if the configuration or the weights cannot be written. The
     * recognition phase reloads both files from disk, so continuing would either fail there or
     * silently pick up stale files left over from an earlier run.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        mouseTracker tracker = new mouseTracker();
        try {
            trainAndSave();
        } catch (IOException e) {
            System.err.println("Could not save the network: " + e);
            System.exit(1);
        }
        recognizeForever(tracker);
    }

    /**
     * Creates the net, saves its configuration, trains it on the AirWin_*.csv pattern files and
     * saves the resulting weights.
     *
     * @throws IOException if the configuration or the weights cannot be saved
     */
    private static void trainAndSave() throws IOException {
        Randomizer randomizer = new Randomizer();

        // Three layers: INPUT_COUNT input units, one hidden layer of 3000 tanh ('t') units and
        // one tanh output unit per class label. Except for the unit count, the input layer's
        // entries in these arrays are ineffectual.
        int[] noofneurons = { INPUT_COUNT, 3000, CLASS_LABELS.length };
        double[] learnratecoeff = { .8, .8, .8 };
        char[] axonfamily = { 't', 't', 't' };
        // NOTE(review): momentum rates above 1 are unusual; confirm how NeuralNet interprets them.
        double[] momentumrate = { 0, 1.2, 1.1 };
        double[] flatness = { 1, 1.2, 1 };
        System.out.println("Creating the net");
        NeuralNet mynet = new NeuralNet(noofneurons, learnratecoeff,
                axonfamily, momentumrate, flatness, randomizer);

        System.out.println("Saving the configuration");
        mynet.SaveConfig(CONFIG_FILE);

        // Each set has INPUT_COUNT inputs and one output per class label. The last three numbers
        // appear to be the share of each file used for training / cross validation / test —
        // confirm against PatternSet.
        System.out.println("Loading patterns");
        PatternSet trainingpatterns = new PatternSet("AirWin_training.csv",
                INPUT_COUNT, CLASS_LABELS.length, 1, 0, 0, randomizer);
        PatternSet crossvalpatterns = new PatternSet("AirWin_crossval.csv",
                INPUT_COUNT, CLASS_LABELS.length, 0, 1, 0, randomizer);
        PatternSet testpatterns = new PatternSet("AirWin_test.csv",
                INPUT_COUNT, CLASS_LABELS.length, 0, 0, 1, randomizer);

        System.out.println("Error ratio before training: "
                + mynet.CrossValErrorRatio(crossvalpatterns));

        System.out.println("Beginning mini batch training");
        double errorRatio = mynet.CrossValErrorRatio(crossvalpatterns);
        while (errorRatio > TARGET_ERROR_RATIO) {
            System.out.println("Training the net. Error ratio: " + errorRatio);
            mynet.MinibatchTrainPatterns(trainingpatterns.trainingpatterns, .3, 30);
            errorRatio = mynet.CrossValErrorRatio(crossvalpatterns);
        }

        System.out.println("Error ratio of the test data: "
                + mynet.TestErrorRatio(testpatterns));
        System.out.println("Training is over");

        System.out.println("Saving the weights\n");
        mynet.SaveWeights(WEIGHTS_FILE);
    }

    /**
     * Reloads the saved net and classifies the tracker's binary map forever, printing
     * "Result: <label>" for every output unit whose value is positive. Never returns.
     *
     * @param tracker source of the binary map; its resolution is read once, before the loop
     */
    private static void recognizeForever(mouseTracker tracker) {
        Randomizer randomizer = new Randomizer();
        System.out.println("Recreating the net");
        NeuralNet mynet = new NeuralNet(CONFIG_FILE, randomizer);
        mynet.LoadWeights(WEIGHTS_FILE);

        // The buffers are sized once from the resolution, so it is read only once as well.
        int rows = tracker.yResolution;
        int cols = tracker.xResolution;
        double[] spectrum = new double[rows * cols * 2];
        double[] features = new double[INPUT_COUNT];
        // ComplexDouble2DFFT takes (rows, columns) with a row-major layout. The buffer below
        // holds one row per y value, so y is the row dimension.
        ComplexDouble2DFFT fft = new ComplexDouble2DFFT(rows, cols);

        while (true) {
            tracker.fill();
            int pos = 0;
            for (int y = 0; y < rows; y++) {
                for (int x = 0; x < cols; x++) {
                    spectrum[pos++] = tracker.binaryMap[x][y]; // real part
                    spectrum[pos++] = 0;                       // imaginary part
                }
            }
            fft.transform(spectrum);
            extractLowFrequencies(spectrum, rows, cols, SIZE, features);

            double[] outputs = mynet.Output(features);
            for (int c = 0; c < CLASS_LABELS.length; c++) {
                if (outputs[c] > 0) {
                    System.out.println("Result: " + CLASS_LABELS[c]);
                }
            }
        }
    }

    /**
     * Copies the low-frequency corners of a 2D complex spectrum into {@code out}.
     *
     * <p>{@code spectrum} is row-major, with {@code rows} rows of {@code cols} complex values,
     * each stored as an adjacent (re, im) pair. From each of the first and last {@code size / 2}
     * rows, the first and last {@code size / 2} complex values are copied as whole pairs, in row
     * order. Assuming the unshifted DFT ordering, these are the coefficients nearest zero
     * frequency along both axes.
     *
     * @param spectrum row-major interleaved complex data, length at least {@code 2 * rows * cols}
     * @param rows number of rows in {@code spectrum}
     * @param cols number of complex values per row
     * @param size side length of the kept block; only {@code 2 * (size / 2)} is used
     * @param out receives {@code 8 * (size / 2) * (size / 2)} doubles, starting at index 0
     * @throws IllegalArgumentException if the spectrum has fewer than {@code 2 * (size / 2)}
     *     rows or columns, which would make the two corners overlap
     */
    private static void extractLowFrequencies(double[] spectrum, int rows, int cols, int size,
            double[] out) {
        int half = size / 2;
        if (rows < 2 * half || cols < 2 * half) {
            throw new IllegalArgumentException("spectrum " + rows + "x" + cols
                    + " is smaller than the " + (2 * half) + "x" + (2 * half) + " block to keep");
        }
        int rowSpan = 2 * cols;   // doubles per row
        int edge = 2 * half;      // doubles kept at each end of a row (half complex values)
        int pos = 0;
        for (int r = 0; r < rows; r++) {
            if (r >= half && r < rows - half) {
                continue; // high vertical frequencies
            }
            int rowStart = r * rowSpan;
            System.arraycopy(spectrum, rowStart, out, pos, edge);
            pos += edge;
            System.arraycopy(spectrum, rowStart + rowSpan - edge, out, pos, edge);
            pos += edge;
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -