📄 adaboost.java
字号:
package learner;
import java.util.Arrays;
/**
 * AdaBoost ensemble built from {@link Linear} weak classifiers.
 *
 * <p>Training happens in the constructor: the training samples of the supplied
 * {@link Data} are sorted, then weak classifiers are fitted one after another
 * while the per-sample weights are updated and re-normalized between rounds.
 * Prediction is a weighted vote over the thresholds/signs of the weak
 * classifiers.
 */
public class Adaboost implements Classifier {
    /** Data set whose {@code training} samples are sorted and re-weighted by {@link #adaboost(int)}. */
    public Data data;

    /**
     * Weak classifiers making up the ensemble, in the order they were trained.
     * After training this array contains no {@code null} entries.
     */
    Linear[] strong;

    // ============================================================ Constructor
    /**
     * Stores the data set and immediately trains the ensemble on it.
     *
     * @param data           data set providing the training samples
     * @param boostiteration maximum number of boosting rounds
     */
    Adaboost(Data data, int boostiteration) {
        this.data = data;
        adaboost(boostiteration);
    }

    // =============================================================== Adaboost
    /**
     * Trains up to {@code boostiterations} weak classifiers and stores them in
     * {@link #strong}.
     *
     * <p>Training stops early as soon as a weak classifier reports an error of
     * exactly zero. That classifier is kept, and the ensemble array is
     * truncated to the classifiers actually trained so that later iteration
     * over {@link #strong} never meets a {@code null} slot.
     *
     * <p>Side effects: sorts {@code data.training} and overwrites the
     * {@code weight} of every training sample.
     *
     * @param boostiterations maximum number of boosting rounds
     * @return the trained weak classifiers (same array as {@link #strong})
     */
    public Linear[] adaboost(int boostiterations) {
        strong = new Linear[boostiterations];
        // Natural-order sort of the samples; presumably a precondition of
        // Linear's threshold search -- not visible here, confirm against Linear.
        Arrays.sort(data.training);

        for (int i = 0; i < boostiterations; i++) {
            // -------------------------------- apply weak classifier
            strong[i] = new Linear(data, true);

            // ------------------------------------------ check error
            if (strong[i].error == 0) {
                // Drop the unused tail: leaving it null made classify() and
                // printthresholds() throw NullPointerException.
                strong = Arrays.copyOf(strong, i + 1);
                break;
            }

            // --------------------------------------- update weights
            double a = voteweight(strong[i].error);
            double sumofweights = 0;
            for (int j = 0; j < data.training.length; j++) {
                // exp(-a * prediction * label): the exponent is negative when
                // prediction and label have the same sign, positive otherwise.
                data.training[j].weight = data.training[j].weight
                        * Math.exp(-a
                                * strong[i].classify(data.training[j].data)
                                * data.training[j].label);
                sumofweights += data.training[j].weight;
            }

            // ------------------------------------ normalize weights
            for (int j = 0; j < data.training.length; j++) {
                data.training[j].weight = data.training[j].weight
                        / sumofweights;
            }
        }
        return strong;
    }

    /**
     * Vote weight of a weak classifier: {@code ln((1 - error) / error) / 2}.
     *
     * <p>For {@code error == 0} the result is positive infinity, which makes
     * that classifier decide the vote in {@link #classify(double)}.
     *
     * @param error error of the weak classifier
     * @return the weight of that classifier's vote
     */
    private static double voteweight(double error) {
        return Math.log((1 - error) / error) / 2;
    }

    /**
     * Classifies a single value by weighted vote of all weak classifiers.
     *
     * <p>A weak classifier votes positive when the value lies strictly above
     * its threshold and its sign is {@code +1}, or strictly below its
     * threshold and its sign is {@code -1}; in every other case (including a
     * value equal to the threshold) it votes negative.
     *
     * @param data value to classify
     * @return {@code 1} if the weighted vote is strictly positive, otherwise {@code -1}
     */
    public int classify(double data) {
        double majority = 0;
        for (Linear weak : strong) {
            double a = voteweight(weak.error);
            if ((data > weak.threshold && weak.sign == +1)
                    || (data < weak.threshold && weak.sign == -1)) {
                majority += a;
            } else {
                majority -= a;
            }
        }
        if (majority > 0) {
            return 1;
        }
        return -1;
    }

    /**
     * Measures accuracy of the ensemble on the given samples.
     *
     * @param testdata samples to classify; each prediction is compared with the sample's label
     * @return percentage (0-100) of samples classified correctly; {@code NaN}
     *         when {@code testdata} is empty (0.0 divided by 0)
     */
    public double test(Datastructure[] testdata) {
        int good = 0;
        for (int i = 0; i < testdata.length; i++) {
            if (classify(testdata[i].data) == testdata[i].label) {
                good++;
            }
        }
        return (good * 100.0) / testdata.length;
    }

    /**
     * Scores the ensemble on each fold produced by {@code data.split(i, folds)}.
     *
     * <p>NOTE(review): the ensemble is not retrained after each split; the
     * classifiers trained earlier are evaluated on every fold's
     * {@code data.test}. Confirm this is intended.
     *
     * @param folds number of folds
     * @return accuracy percentage per fold, indexed by fold number
     */
    public double[] crossvalidate(int folds) {
        double[] results = new double[folds];
        for (int i = 0; i < folds; i++) {
            data.split(i, folds);
            results[i] = test(data.test);
        }
        return results;
    }

    /**
     * Searches for the number of boosting rounds giving the best test accuracy.
     *
     * <p>After {@code data.split(0, 0)}, a fresh {@code Adaboost} is trained
     * with 0, 1, 2, ... rounds on the shared data and scored on
     * {@code data.test}. The first round count reaching a strictly higher
     * score is remembered. The search ends once the score has equalled the
     * best score in five consecutive iterations; there is no other upper
     * bound on the loop.
     *
     * <p>NOTE(review): every candidate re-weights the same
     * {@code data.training} samples, and no weight reset is visible here --
     * verify whether {@code Data} or {@code Linear} resets them.
     *
     * @return the best round count found (an integer value returned as {@code double})
     */
    public double findparameter() {
        int k = 0;
        int same = 0;
        double bestperformance = 0;
        data.split(0, 0);
        for (int i = 0;; i++) {
            Adaboost boost = new Adaboost(data, i);
            double performance = boost.test(data.test);
            // The plateau counter is compared against the best score *before*
            // that score is updated, so a new best resets the counter.
            if (performance == bestperformance) {
                same++;
            } else {
                same = 0;
            }
            if (performance > bestperformance) {
                bestperformance = performance;
                k = i;
            }
            if (same == 5) {
                break;
            }
        }
        System.out.println("Best parameter found: " + k);
        return k;
    }

    /**
     * Returns the data set this ensemble was constructed with.
     *
     * @return the {@link #data} field
     */
    public Data getdata() {
        return this.data;
    }

    // ======================= USEFUL FUNCTIONS ===============================
    /**
     * Prints the thresholds of all weak classifiers on one line, separated by
     * {@code |} and framed by two lines of {@code =} characters.
     */
    public void printthresholds() {
        System.out.println("====================");
        for (Linear weak : strong) {
            System.out.print(weak.threshold + "|");
        }
        System.out.println("\n====================");
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -