// File: crossvalidation.java.svn-base
// Project: AI course assignment, a small program implementing a neural network
package parameter;

import neuralnetworks.*;

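/**
 * Trains a NeuralNetwork on one part of the data and measures it on a
 * held-out part (a single train/test split).
 */
public class CrossValidation {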

	NeuralNetwork myNN;

	double[][] allTrainData;

	double[] allTrainLabel;

	double[][] trainSet, testSet;

	double[] trainLabel, testLabel;

	double[] trainPredict, testPredict;

	int trainTms;

	public Measure testMeasure, trainMeasure;

	public CrossValidation(NeuralNetwork nn, double[][] atd, double[] tl,
			int tms) {
		myNN = nn;
		allTrainData = atd;
		allTrainLabel = tl;
		trainTms = tms;
	}

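	/**
	 * Splits the data into a training set (leading samples) and a test set
	 * (trailing samples). Intended to be overridable by subclasses.
	 * 
	 * @param crossRate fraction of all samples reserved for the test set
	 */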
	public void divide(double crossRate) {
		int len = allTrainData.length;
		int testLen = (int) (len * crossRate); // truncates toward zero
		int trainLen = len - testLen;
		trainSet = new double[trainLen][];
		trainLabel = new double[trainLen];
		int i;
		for (i = 0; i < trainLen; i++) {
			trainSet[i] = allTrainData[i];
			trainLabel[i] = allTrainLabel[i];
		}
		testSet = new double[testLen][];
		testLabel = new double[testLen];
		for (; i < len; i++) {
			testSet[i - trainLen] = allTrainData[i];
			testLabel[i - trainLen] = allTrainLabel[i];
		}
	}

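	/**
	 * Makes trainTms passes over the training set, then classifies both
	 * splits and records their measures. divide() must be called first.
	 */
	public void run() {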
		for (int i = 0; i < trainTms; i++) {
			for (int j = 0; j < trainSet.length; j++) {
				myNN.train(trainSet[j], trainLabel[j]);
			}
		}

		double[] ptl = new double[trainSet.length];
		for (int j = 0; j < trainSet.length; j++)
			ptl[j] = myNN.classify(trainSet[j]);
		trainPredict = ptl;

		double[] pcl = new double[testSet.length];
		for (int j = 0; j < testSet.length; j++)
			pcl[j] = myNN.classify(testSet[j]);
		testPredict = pcl;

		trainMeasure = new Measure(trainLabel, ptl);
		testMeasure = new Measure(testLabel, pcl);
	}

	public Measure getMeasure() {
		return testMeasure;
	}

	public static void main(String[] args) {
		double[][] trainSet = { { 1, 0, 1 }, { 0, 0, 1 }, { 0, 1, 1 },
				{ 1, 1, 1 } };
		double[] d = { 0, 1, 0, 1 };

		MultiClassNeuralNetwork nn = new MultiClassNeuralNetwork();
		nn.loadNet();
		CrossValidation cv = new CrossValidation(nn, trainSet, d, 100);

		System.out.println("old:" + nn.classify(trainSet[0]));
		System.out.println("old:" + nn.classify(trainSet[1]));
		System.out.println("old:" + nn.classify(trainSet[2]));
		System.out.println("old:" + nn.classify(trainSet[3]));

		cv.divide(0.25);
		cv.run();

		System.out.println("train ac: " + cv.trainMeasure.accuracy);
		System.out.println("test ac:  " + cv.testMeasure.accuracy);

		System.out.println(nn.classify(trainSet[0]));
		System.out.println(nn.classify(trainSet[1]));
		System.out.println(nn.classify(trainSet[2]));
		System.out.println(nn.classify(trainSet[3]));

	}

}
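
// The class above evaluates on a single held-out tail slice of the data.
// What follows is only a minimal sketch, not part of the original file, of
// how test-slice boundaries for k folds could be computed so that every
// sample lands in exactly one test slice. The names KFoldSketch and
// foldBounds are hypothetical; nothing in the original code calls them.
class KFoldSketch {

	/**
	 * Returns {start, end} (end exclusive) of the test slice for one fold.
	 * The first (len % k) folds receive one extra sample, so fold sizes
	 * differ by at most one.
	 */
	static int[] foldBounds(int len, int k, int fold) {
		if (k <= 0 || k > len || fold < 0 || fold >= k) {
			throw new IllegalArgumentException("need 0 <= fold < k <= len");
		}
		int base = len / k; // minimum size of every fold
		int extra = len % k; // number of folds that get one more sample
		int start = fold * base + Math.min(fold, extra);
		int size = base + (fold < extra ? 1 : 0);
		return new int[] { start, start + size };
	}
}
// Possible use (assumption): samples with index in [start, end) would form
// the test set and the remaining samples the training set, repeated for
// fold = 0 .. k - 1 with a freshly initialised network each time.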
