⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 measure.java.svn-base

📁 AI大作业实现神经网络的小程序
💻 SVN-BASE
字号:
package parameter;

import java.util.*;

/**
 * Computes classification quality measures (accuracy, per-label precision /
 * recall / F1, and their macro- and micro-averages) from an array of true
 * labels and an array of predicted values.
 *
 * <p>The set of known labels is the set of distinct values in the true-label
 * array. Every predicted value is first snapped to the numerically nearest
 * known label (see {@link #getNearest(double)}) and only then compared with
 * the true label.
 *
 * <p>All measures are computed once, in the constructor. Any ratio whose
 * denominator is zero is reported as {@code 0} rather than {@code NaN}.
 */
public class Measure {

	/** Number of samples per true label. */
	Hashtable<Double, Integer> realNum = new Hashtable<Double, Integer>();

	/** Number of samples per (snapped) predicted label. */
	Hashtable<Double, Integer> predictNum = new Hashtable<Double, Integer>();

	/** Number of samples per label where the snapped prediction equals the true label. */
	Hashtable<Double, Integer> rpMath = new Hashtable<Double, Integer>();

	/** Per-label precision. (Field name keeps its historical spelling.) */
	Hashtable<Double, Double> precition = new Hashtable<Double, Double>();

	/** Per-label recall. */
	Hashtable<Double, Double> recall = new Hashtable<Double, Double>();

	/** Per-label F1 score. */
	Hashtable<Double, Double> F1 = new Hashtable<Double, Double>();

	/** Distinct true labels, in order of first appearance. */
	Vector<Double> allLabel = new Vector<Double>();

	double[] realLabel;

	double[] predictLabel;

	/** Unweighted means over labels; macroF1 is the harmonic mean of macroP and macroR. */
	public double macroP, macroR, macroF1;

	/** Measures computed from counts pooled over all labels. */
	public double microP, microR, microF1;

	/** Fraction of samples whose snapped prediction equals the true label. */
	public double accuracy;

	/**
	 * Builds the measures for the given labels.
	 *
	 * @param rl true labels, one per sample
	 * @param pl predicted values, aligned by index with {@code rl}; entries
	 *           beyond {@code rl.length} are ignored
	 * @throws NullPointerException     if either array is {@code null}
	 * @throws IllegalArgumentException if {@code pl} is shorter than {@code rl}
	 */
	public Measure(double[] rl, double[] pl) {
		if (rl == null) {
			throw new NullPointerException("rl (real labels) must not be null");
		}
		if (pl == null) {
			throw new NullPointerException("pl (predicted labels) must not be null");
		}
		if (pl.length < rl.length) {
			throw new IllegalArgumentException("predicted labels (" + pl.length
					+ ") are fewer than real labels (" + rl.length + ")");
		}
		realLabel = rl;
		predictLabel = pl;
		init();
	}

	/**
	 * Returns the known label closest (by absolute difference) to {@code pl}.
	 * On ties the label that appeared first in the true-label array wins.
	 * Must only be called when {@link #allLabel} is non-empty.
	 */
	private double getNearest(double pl) {
		double bestDistance = 0;
		int bestIndex = -1;
		for (int i = 0; i < allLabel.size(); i++) {
			double distance = Math.abs(allLabel.get(i) - pl);
			if (bestIndex == -1 || distance < bestDistance) {
				bestDistance = distance;
				bestIndex = i;
			}
		}
		return allLabel.get(bestIndex);
	}

	/** Adds one to the counter stored under {@code key}, starting from zero. */
	private static void increment(Hashtable<Double, Integer> counts, double key) {
		Integer old = counts.get(key);
		if (old == null) {
			counts.put(key, 1);
		} else {
			counts.put(key, old + 1);
		}
	}

	/** Returns the counter stored under {@code key}, or 0 when there is none. */
	private static int countOf(Hashtable<Double, Integer> counts, double key) {
		Integer count = counts.get(key);
		if (count == null) {
			return 0;
		}
		return count;
	}

	/** Returns {@code numerator / denominator}, or 0 when the denominator is 0. */
	private static double ratio(int numerator, int denominator) {
		if (denominator == 0) {
			return 0;
		}
		return (double) numerator / denominator;
	}

	/** Returns {@code 2pr / (p + r)}, or 0 when {@code p + r} is 0. */
	private static double harmonicMean(double p, double r) {
		double sum = p + r;
		if (sum == 0) {
			return 0;
		}
		return (2 * p * r) / sum;
	}

	/** Returns the value stored under {@code key}, or 0 when there is none. */
	private static double valueOrZero(Hashtable<Double, Double> values, double key) {
		Double value = values.get(key);
		if (value == null) {
			return 0;
		}
		return value;
	}

	/** Collects the label set and the per-label counts, then derives all measures. */
	private void init() {

		for (int i = 0; i < realLabel.length; i++) {
			double rl = realLabel[i];
			if (!allLabel.contains(rl)) {
				allLabel.add(rl);
			}
		}

		int correct = 0;
		for (int i = 0; i < realLabel.length; i++) {
			double rl = realLabel[i];
			// Predictions are raw numeric outputs; map them onto a known label.
			double pl = getNearest(predictLabel[i]);

			increment(realNum, rl);
			increment(predictNum, pl);

			if (pl == rl) {
				correct++;
				increment(rpMath, pl);
			}
		}

		// Avoid 0/0 = NaN on empty input.
		accuracy = ratio(correct, realLabel.length);

		calc();
	}

	/** Derives per-label, macro-averaged and micro-averaged measures from the counts. */
	private void calc() {
		macroP = 0;
		macroR = 0;

		int totalCorrect = 0;
		int totalPredicted = 0;
		int totalReal = 0;

		for (int i = 0; i < allLabel.size(); i++) {
			double ll = allLabel.get(i);

			// Any of these may be absent (e.g. a label that was never
			// predicted); absent means a count of zero.
			int hit = countOf(rpMath, ll);
			int predicted = countOf(predictNum, ll);
			int real = countOf(realNum, ll);

			double pre = ratio(hit, predicted);
			double rec = ratio(hit, real);
			double f1 = harmonicMean(pre, rec);

			precition.put(ll, pre);
			recall.put(ll, rec);
			F1.put(ll, f1);

			macroP += pre;
			macroR += rec;

			totalCorrect += hit;
			totalPredicted += predicted;
			totalReal += real;
		}

		int labelCount = allLabel.size();
		if (labelCount > 0) {
			macroP /= labelCount;
			macroR /= labelCount;
		}
		macroF1 = harmonicMean(macroP, macroR);

		microP = ratio(totalCorrect, totalPredicted);
		microR = ratio(totalCorrect, totalReal);
		microF1 = harmonicMean(microP, microR);
	}

	/**
	 * Returns the precision of label {@code l}, or 0 if {@code l} is not one
	 * of the true labels. (Method name keeps its historical spelling.)
	 */
	public double getPrecition(double l) {
		return valueOrZero(precition, l);
	}

	/** Returns the recall of label {@code l}, or 0 if {@code l} is not one of the true labels. */
	public double getRecall(double l) {
		return valueOrZero(recall, l);
	}

	/** Returns the F1 score of label {@code l}, or 0 if {@code l} is not one of the true labels. */
	public double getF1(double l) {
		return valueOrZero(F1, l);
	}

	/** Returns the mean of the per-label precisions. */
	public double getMacroP() {
		return macroP;
	}

	/** Returns the mean of the per-label recalls. */
	public double getMacroR() {
		return macroR;
	}

	/** Returns the harmonic mean of {@link #getMacroP()} and {@link #getMacroR()}. */
	public double getMacroF1() {
		return macroF1;
	}

	/** Returns the precision computed from counts pooled over all labels. */
	public double getMicroP() {
		return microP;
	}

	/** Returns the recall computed from counts pooled over all labels. */
	public double getMicroR() {
		return microR;
	}

	/** Returns the harmonic mean of {@link #getMicroP()} and {@link #getMicroR()}. */
	public double getMicroF1() {
		return microF1;
	}

	/** Returns the fraction of samples whose snapped prediction equals the true label. */
	public double getAccuracy() {
		return accuracy;
	}

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -