📄 crossvalidation.java.svn-base
字号:
package parameter;
import neuralnetworks.*;
/**
 * Hold-out validation for a {@link NeuralNetwork}: splits a labelled data set into a training
 * part and a test part, trains the network on the training part for a fixed number of passes,
 * and records a {@link Measure} of the network's predictions on each part.
 *
 * <p>Usage order: construct, call {@link #divide(double)}, then {@link #run()}, then read
 * {@link #trainMeasure} / {@link #testMeasure} (or {@link #getMeasure()}).
 */
public class CrossValidation {
    /** Network being evaluated; {@link #run()} trains this instance in place. */
    NeuralNetwork myNN;
    /** Full data set, one sample per row (stored by reference, not copied). */
    double[][] allTrainData;
    /** Labels parallel to {@link #allTrainData} (stored by reference, not copied). */
    double[] allTrainLabel;
    /** Sample rows of each split; null until {@link #divide(double)} has been called. */
    double[][] trainSet, testSet;
    /** Labels of each split; null until {@link #divide(double)} has been called. */
    double[] trainLabel, testLabel;
    /** Network outputs for each split; null until {@link #run()} has been called. */
    double[] trainPredict, testPredict;
    /** Number of passes over the training split performed by {@link #run()}. */
    int trainTms;
    /** Measures of test / training predictions; null until {@link #run()} has been called. */
    public Measure testMeasure, trainMeasure;

    /**
     * Creates a validator. No work is done until {@link #divide(double)} and {@link #run()}.
     *
     * @param nn  network to train and evaluate (mutated by {@link #run()})
     * @param atd all samples, one per row
     * @param tl  label of each sample in {@code atd}, in the same order
     * @param tms number of passes over the training split
     */
    public CrossValidation(NeuralNetwork nn, double[][] atd, double[] tl,
            int tms) {
        myNN = nn;
        allTrainData = atd;
        allTrainLabel = tl;
        trainTms = tms;
    }

    /**
     * Splits the data, in its original order (no shuffling), into a leading training part and a
     * trailing test part. The test part has {@code (int) (length * crossRate)} samples, rounded
     * toward zero; the training part gets the rest. Rows are shared with the original array,
     * not copied.
     *
     * <p>Subclasses may override this to use a different split strategy.
     *
     * @param crossRate fraction of samples to hold out for testing, in [0, 1]
     * @throws IllegalArgumentException if {@code crossRate} is outside [0, 1] or is NaN
     * @throws IllegalStateException    if the data and label arrays differ in length
     */
    public void divide(double crossRate) {
        // Written as a negated range check so that NaN is rejected too.
        if (!(crossRate >= 0.0 && crossRate <= 1.0)) {
            throw new IllegalArgumentException("crossRate must be in [0, 1]: " + crossRate);
        }
        int len = allTrainData.length;
        if (allTrainLabel.length != len) {
            throw new IllegalStateException("data/label length mismatch: " + len
                    + " samples vs " + allTrainLabel.length + " labels");
        }
        // Truncates toward zero, exactly like the former Double.intValue() conversion.
        int testLen = (int) (len * crossRate);
        int trainLen = len - testLen;

        trainSet = new double[trainLen][];
        trainLabel = new double[trainLen];
        System.arraycopy(allTrainData, 0, trainSet, 0, trainLen);
        System.arraycopy(allTrainLabel, 0, trainLabel, 0, trainLen);

        testSet = new double[testLen][];
        testLabel = new double[testLen];
        System.arraycopy(allTrainData, trainLen, testSet, 0, testLen);
        System.arraycopy(allTrainLabel, trainLen, testLabel, 0, testLen);
    }

    /**
     * Trains the network for {@link #trainTms} passes over the training split (samples in order),
     * then classifies both splits and records {@link #trainMeasure} and {@link #testMeasure}.
     *
     * @throws IllegalStateException if {@link #divide(double)} has not been called yet
     */
    public void run() {
        if (trainSet == null || testSet == null) {
            throw new IllegalStateException("divide() must be called before run()");
        }
        for (int pass = 0; pass < trainTms; pass++) {
            for (int j = 0; j < trainSet.length; j++) {
                myNN.train(trainSet[j], trainLabel[j]);
            }
        }
        // Training split is classified before the test split, as before.
        trainPredict = classifyAll(trainSet);
        testPredict = classifyAll(testSet);
        trainMeasure = new Measure(trainLabel, trainPredict);
        testMeasure = new Measure(testLabel, testPredict);
    }

    /** Classifies every row of {@code rows} with the current network, preserving order. */
    private double[] classifyAll(double[][] rows) {
        double[] predictions = new double[rows.length];
        for (int j = 0; j < rows.length; j++) {
            predictions[j] = myNN.classify(rows[j]);
        }
        return predictions;
    }

    /**
     * Returns the measure on the test split.
     *
     * @return {@link #testMeasure}; null if {@link #run()} has not been called
     */
    public Measure getMeasure() {
        return testMeasure;
    }

    /**
     * Demo on a four-sample toy data set. It prints the network's predictions before training,
     * then the training and test accuracy after a 75/25 split and 100 passes, and finally the
     * predictions after training. {@code nn.loadNet()} presumably restores a saved network
     * (confirm in MultiClassNeuralNetwork).
     */
    public static void main(String args[]) {
        double trainSet[][] = { { 1, 0, 1 }, { 0, 0, 1 }, { 0, 1, 1 },
                { 1, 1, 1 } };
        double d[] = { 0, 1, 0, 1 };
        MultiClassNeuralNetwork nn = new MultiClassNeuralNetwork();
        nn.loadNet();
        CrossValidation cv = new CrossValidation(nn, trainSet, d, 100);
        for (double[] sample : trainSet) {
            System.out.println("old:" + nn.classify(sample));
        }
        cv.divide(0.25);
        cv.run();
        System.out.println("ac: " + cv.trainMeasure.accuracy);
        System.out.println("ac: " + cv.testMeasure.accuracy);
        for (double[] sample : trainSet) {
            System.out.println(nn.classify(sample));
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -