⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 adaboost.java

📁 aDABOOST This package contains the following files: learner.jar - is a platform independent java
💻 JAVA
字号:
package learner;

import java.util.Arrays;

/**
 * Discrete AdaBoost ensemble built from {@link Linear} weak classifiers over a single
 * {@code double} feature.
 *
 * <p>Each weak classifier exposes a {@code threshold}, a {@code sign} and a weighted training
 * {@code error}. The ensemble gives each one the vote weight
 * {@code alpha = ln((1 - error) / error) / 2} and predicts the sign of the weighted vote.
 * Labels are presumably {@code +1} / {@code -1} (the weight update multiplies by the label and
 * {@link #classify(double)} returns ±1) — TODO confirm against {@code Datastructure}.
 *
 * <p>Training mutates {@code data.training} in place (it is sorted and re-weighted), so models
 * that share one {@link Data} instance must not be trained concurrently.
 */
public class Adaboost implements Classifier {

	/** Data set this ensemble trains on; shared with, and mutated by, every model built on it. */
	public Data data;

	/** The boosted weak classifiers in training order; never contains {@code null} entries. */
	Linear[] strong;

	/** Boosting rounds requested by the latest {@link #adaboost(int)} call; reused when retraining per fold. */
	private int iterations;

	/** Consecutive non-improving rounds after which {@link #findparameter()} stops searching. */
	private static final int DEFAULT_PATIENCE = 5;

	// ============================================================ Constructor
	/**
	 * Creates an ensemble and trains it immediately.
	 *
	 * @param data           data set to train on; its {@code training} array is sorted and re-weighted
	 * @param boostiteration maximum number of boosting rounds
	 */
	Adaboost(Data data, int boostiteration) {

		this.data = data;
		adaboost(boostiteration);
	}

	// =============================================================== Adaboost
	/**
	 * Trains the ensemble with discrete AdaBoost on {@code data.training}.
	 *
	 * <p>Sample weights are reset to a uniform {@code 1/n} before the first round. This makes every
	 * training run independent of earlier runs on the same {@link Data}. Without the reset, a model
	 * trained after another (as in {@link #findparameter()}) would inherit that model's final
	 * weights. Training stops early once a weak classifier reaches zero weighted error; that
	 * classifier is kept.
	 *
	 * @param boostiterations maximum number of boosting rounds
	 * @return the trained weak classifiers. The array is shorter than {@code boostiterations} when
	 *         training stopped early, and it never contains {@code null}
	 */
	public Linear[] adaboost(int boostiterations) {
		iterations = boostiterations;
		strong = new Linear[boostiterations];
		Arrays.sort(data.training);
		resetWeights();
		for (int i = 0; i < boostiterations; i++) {
			// -------------------------------- apply weak classifier
			Linear weak = new Linear(data, true);
			strong[i] = weak;
			// ------------------------------------------ check error
			if (weak.error == 0) {
				// Perfect weak classifier: keep it, but drop the unused trailing slots so that
				// classify() and printthresholds() never dereference null.
				strong = Arrays.copyOf(strong, i + 1);
				break;
			}

			// --------------------------------------- update weights
			// Misclassified samples (prediction * label < 0) are up-weighted, correct ones down-weighted.
			double a = alpha(weak.error);
			double sumofweights = 0;
			for (int j = 0; j < data.training.length; j++) {
				data.training[j].weight *= Math.exp(-a
						* weak.classify(data.training[j].data)
						* data.training[j].label);
				sumofweights += data.training[j].weight;
			}

			// ------------------------------------ normalize weights
			for (int j = 0; j < data.training.length; j++)
				data.training[j].weight /= sumofweights;
		}

		return strong;
	}

	/** Resets every training sample to the uniform starting weight {@code 1/n}. */
	private void resetWeights() {
		int n = data.training.length;
		for (int j = 0; j < n; j++)
			data.training[j].weight = 1.0 / n;
	}

	/**
	 * Vote weight of a weak classifier, {@code ln((1 - error) / error) / 2}.
	 * The weight is positive infinity for {@code error == 0}, zero for {@code error == 0.5}, and
	 * negative above 0.5.
	 */
	private static double alpha(double error) {
		return Math.log((1 - error) / error) / 2;
	}

	/**
	 * Classifies one value by the alpha-weighted vote of all weak classifiers.
	 *
	 * <p>A weak classifier votes {@code +1} when the value lies strictly on its {@code sign} side of
	 * its {@code threshold}, and {@code -1} otherwise (including ties with the threshold).
	 * NOTE(review): this repeats the threshold test instead of calling {@code Linear.classify},
	 * which training uses. Confirm the two agree, especially on ties.
	 *
	 * @param value the feature value to classify
	 * @return {@code 1} if the weighted vote is strictly positive, otherwise {@code -1}
	 */
	public int classify(double value) {
		double majority = 0;
		for (Linear weak : strong) {
			double a = alpha(weak.error);
			if ((value > weak.threshold && weak.sign == +1)
					|| (value < weak.threshold && weak.sign == -1))
				majority += a;
			else
				majority -= a;
		}
		return majority > 0 ? 1 : -1;
	}

	/**
	 * Percentage of {@code testdata} that this ensemble labels correctly.
	 *
	 * @param testdata samples to evaluate
	 * @return accuracy in {@code [0, 100]}, or {@code NaN} when {@code testdata} is empty
	 */
	public double test(Datastructure[] testdata) {
		int good = 0;
		for (int i = 0; i < testdata.length; i++)
			if (classify(testdata[i].data) == testdata[i].label)
				good++;
		return (good * 100.0) / testdata.length;
	}

	/**
	 * Runs k-fold cross-validation with this model's boosting-round count.
	 *
	 * <p>For each fold, {@code data} is re-split and a fresh ensemble is trained on that fold's
	 * training part. That ensemble is then scored on the fold's test part. This model's own
	 * {@link #strong} array is left unchanged, but {@code data} is left holding the last fold's split.
	 *
	 * @param folds number of folds
	 * @return accuracy percentage of each fold, indexed by fold
	 */
	public double[] crossvalidate(int folds) {
		double[] results = new double[folds];
		for (int i = 0; i < folds; i++) {
			data.split(i, folds);
			// Retrain on this fold; testing the pre-trained model would not be cross-validation.
			results[i] = new Adaboost(data, iterations).test(data.test);
		}
		return results;
	}

	/**
	 * Searches for the number of boosting rounds with the best test accuracy.
	 * The search stops after {@value #DEFAULT_PATIENCE} consecutive rounds without improvement.
	 *
	 * @return the best round count found
	 */
	public double findparameter() {
		return findparameter(DEFAULT_PATIENCE);
	}

	/**
	 * Searches for the number of boosting rounds with the best test accuracy.
	 *
	 * <p>The method splits {@code data} once, then trains a fresh ensemble with 0, 1, 2, ... rounds
	 * and scores each one on {@code data.test}. It stops after {@code patience} consecutive rounds
	 * that fail to beat the best score. Termination is guaranteed because accuracy takes finitely
	 * many values, so it can only strictly improve finitely often.
	 *
	 * @param patience consecutive non-improving rounds tolerated before stopping
	 * @return the smallest round count that achieved the best accuracy
	 * @throws IllegalArgumentException if {@code patience < 1}
	 */
	public double findparameter(int patience) {
		if (patience < 1)
			throw new IllegalArgumentException("patience must be >= 1: " + patience);
		int k = 0, stale = 0;
		double bestperformance = 0;
		data.split(0, 0);
		for (int i = 0; stale < patience; i++) {
			double performance = new Adaboost(data, i).test(data.test);
			if (performance > bestperformance) {
				bestperformance = performance;
				k = i;
				stale = 0;
			} else {
				// Covers ties, regressions and NaN (empty test set), so the loop cannot spin forever.
				stale++;
			}
		}
		System.out.println("Best parameter found: " + k);
		return k;
	}

	/** @return the data set this ensemble was trained on */
	public Data getdata() {
		return this.data;
	}

	// ======================= USEFUL FUNCTIONS ===============================

	/** Prints each weak classifier's threshold, separated by {@code |}, between two rule lines. */
	public void printthresholds() {
		System.out.println("====================");
		for (Linear weak : strong)
			System.out.print(weak.threshold + "|");
		System.out.println("\n====================");
	}

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -