findbest.java.svn-base

From "AI course project: a small program implementing a neural network" · SVN-BASE source · 64 lines

package parameter;

import java.io.*;
import java.util.*;

import neuralnetworks.*;

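/**
 * Searches for a good network by repeatedly running cross-validation.
 * Whenever a round's test accuracy beats the best seen so far, the network
 * is saved under the name nnName (via saveNet). The search stops once the
 * best accuracy exceeds 0.9 or more than 50,000 rounds have run, and then
 * prints the best accuracy.
 * 
 * @author huicongwm
 * 
 */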
public class FindBest {

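	// Training inputs
	double[][] trainSet;

	// Label for each training input
	double[] labelSet;

	CrossValidation crossV;

	NeuralNetwork myNN;

	// Best test measure seen so far (null until the first round finishes)
	Measure nowMeasure = null;

	// Name passed to saveNet() whenever a new best network is saved
	String nnName;

	// Number of cross-validation rounds run so far
	int tms = 0;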

	public FindBest(NeuralNetwork nn, double[][] ts, double[] tl, String na) {
		trainSet = ts;
		labelSet = tl;
		myNN = nn;
		nnName = na;
		crossV = new CrossValidation(nn, trainSet, labelSet, 10);
		crossV.divide(0.1);
	}

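	// Called once per round: counts the round, then reports whether to stop
	// (best accuracy above 0.9, or more than 50,000 rounds run).
	boolean stop() {
		tms++;
		return nowMeasure.accuracy > 0.9 || tms > 50000;
	}

	// True if m should replace the current best: nothing has been saved yet,
	// or m's accuracy is strictly higher than the best so far.
	boolean save(Measure m) {
		return nowMeasure == null || m.accuracy > nowMeasure.accuracy;
	}

	public void run() {

		while (true) {
			// One cross-validation round; its result is read from crossV.testMeasure
			crossV.run();
			if (save(crossV.testMeasure)) {
				// New best: remember its measure and save the current network
				nowMeasure = crossV.testMeasure;
				myNN.saveNet(nnName);
			}
			// nowMeasure is non-null here: save() always accepts the first round
			if (stop()) {
				System.out.println(nowMeasure.getAccuracy());
				return;
			}
		}
	}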

}
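The core of FindBest.run() is a best-so-far search with checkpointing and early stopping. The sketch below reproduces that control flow without the project's neuralnetworks package, so it is an illustration under assumptions rather than the author's code: evaluateRound is a hypothetical stand-in for crossV.run() followed by reading crossV.testMeasure.accuracy, onImprove is a hypothetical stand-in for myNN.saveNet(nnName), and the accuracy values in main are made up.

```java
import java.util.function.DoubleSupplier;
import java.util.function.IntConsumer;

public class BestSoFarSearch {

    /** Outcome of a search: best accuracy seen and number of rounds run. */
    public static final class Result {
        public final double bestAccuracy;
        public final int rounds;

        Result(double bestAccuracy, int rounds) {
            this.bestAccuracy = bestAccuracy;
            this.rounds = rounds;
        }
    }

    /**
     * evaluateRound: hypothetical stand-in for one cross-validation round.
     * onImprove: hypothetical stand-in for saving the network; receives the round number.
     */
    public static Result search(DoubleSupplier evaluateRound, IntConsumer onImprove,
                                double target, int maxRounds) {
        double best = Double.NEGATIVE_INFINITY; // plays the role of nowMeasure == null
        int rounds = 0;
        while (true) {
            double acc = evaluateRound.getAsDouble();
            // Same rule as FindBest.save(): keep only strict improvements
            if (acc > best) {
                best = acc;
                onImprove.accept(rounds + 1); // checkpoint the current model
            }
            // Same rule as FindBest.stop(): count the round, then check
            // whether the target was beaten or the budget was exceeded
            rounds++;
            if (best > target || rounds > maxRounds) {
                return new Result(best, rounds);
            }
        }
    }

    public static void main(String[] args) {
        double[] fakeAccuracies = {0.62, 0.71, 0.68, 0.93, 0.95}; // made-up values
        int[] next = {0};
        Result r = search(() -> fakeAccuracies[next[0]++],
                round -> System.out.println("saved after round " + round),
                0.9, 50_000);
        System.out.println(r.bestAccuracy + " after " + r.rounds + " rounds");
    }
}
```

As in stop(), the round counter is incremented before the budget check, so a run that never beats the target ends after maxRounds + 1 rounds (50,001 with the original constants).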
