📄 ensembleclassifier.java
字号:
import mslab.kddcup2008.roc.*;
import weka.core.Instances;
import weka.core.Instance;
import weka.core.Attribute;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;
import weka.classifiers.Classifier;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.meta.FilteredClassifier;
import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.FileOutputStream;
import java.io.PrintWriter;
import java.util.*;
/**
 * Fits a linear-regression ensemble over the scores of several base models and writes
 * cross-validated predictions, with folds formed from contiguous patient-id ranges.
 *
 * <p>Inputs (working directory): {@code score_k.txt}, {@code label_k.txt},
 * {@code pid_k.txt}, {@code serial_k.txt} for k = 1..{@value #NUM_BASE_MODELS},
 * {@code TrainSet.arff} (the full data set), and {@code TrainSet_Ensemble_Empty.arff}
 * (header used for every fold's train/test set).
 *
 * <p>Outputs: {@code label.txt}, {@code score.txt}, {@code pid.txt}, {@code serial.txt},
 * one line per held-out instance, in the same order across the four files.
 */
public class EnsembleClassifier {

    /** Number of base models whose scores are ensembled (score_1.txt .. score_N.txt). */
    private static final int NUM_BASE_MODELS = 3;

    /** 0-based index of the patient-id attribute in TrainSet.arff (written to pid.txt). */
    private static final int PATIENT_ID_INDEX = 4;

    /**
     * Smallest patient id covered by the folds. The original notes say the ids run
     * 10000..11601 for the testing data.
     */
    private static final int FIRST_PATIENT_ID = 0;

    /** Number of patients split across the folds. Should be 1602 when testing data. */
    private static final int PATIENT_NUMBER = 1712;

    /** Number of cross-validation folds. */
    private static final int FOLDS = 10;

    /**
     * Runs the ensemble cross-validation; see the class comment for the files involved.
     *
     * @param args ignored
     * @throws Exception if an input file cannot be read, the inputs are inconsistent,
     *     a Weka operation fails, or an output file cannot be written
     */
    public static void main(String[] args) throws Exception {
        // --- Prepare scores to be ensembled ---------------------------------------------
        double[][] scores = new double[NUM_BASE_MODELS][];
        double[] label = null;
        for (int k = 1; k <= NUM_BASE_MODELS; k++) {
            double[] score = ROC.loadDataFromFile("score_" + k + ".txt");
            label = ROC.loadDataFromFile("label_" + k + ".txt");
            double[] pid = ROC.loadDataFromFile("pid_" + k + ".txt");
            double[] serial = ROC.loadDataFromFile("serial_" + k + ".txt");
            PointArray pa = new PointArray(score, label, pid, serial, 0.2, 0.3);
            pa.sortBySerialNumber();
            // NOTE(review): presumably getSN() returns the scores reordered by serial
            // number -- confirm against PointArray.
            scores[k - 1] = pa.getSN();
        }
        // NOTE(review): `label` is the raw array read from the last label file. Unlike the
        // scores, it is not taken back from the sorted PointArray. It therefore only lines up
        // with the sorted scores if that file is already in serial-number order (or if
        // PointArray sorts its input arrays in place) -- verify.

        // --- Append numeric_label and score_k attributes to the full data set -----------
        Instances data = readArff("TrainSet.arff");
        int labelIndex = data.numAttributes(); // 129 for the expected TrainSet.arff
        data.insertAttributeAt(new Attribute("numeric_label"), data.numAttributes());
        for (int k = 1; k <= NUM_BASE_MODELS; k++) {
            data.insertAttributeAt(new Attribute("score_" + k), data.numAttributes());
        }
        int numInstances = data.numInstances();
        requireLength("label", label, numInstances);
        for (int k = 0; k < NUM_BASE_MODELS; k++) {
            requireLength("score_" + (k + 1), scores[k], numInstances);
        }
        for (int t = 0; t < numInstances; t++) {
            Instance row = data.instance(t);
            row.setValue(labelIndex, label[t]);
            for (int k = 0; k < NUM_BASE_MODELS; k++) {
                row.setValue(labelIndex + 1 + k, scores[k][t]);
            }
        }
        System.out.println("Number of Attributes = " + data.numAttributes());

        // Keep only numeric_label and the score_k attributes. Weka ranges are 1-based, so
        // "1-<labelIndex>" drops every original attribute. If we want to use the original
        // data as well, use e.g. "-R 1-11,129" instead.
        Remove remove = new Remove();
        remove.setOptions(weka.core.Utils.splitOptions("-R 1-" + labelIndex));
        remove.setInputFormat(data);
        data.setClassIndex(labelIndex); // numeric label

        // Shuffling only changes the order of rows inside each fold's train/test set; fold
        // membership is decided below by contiguous patient-id ranges.
        Instances randomData = new Instances(data);
        randomData.randomize(new Random(System.currentTimeMillis()));

        // Read the fold header once; each fold's train/test set starts as a copy of it.
        Instances emptyHeader = readArff("TrainSet_Ensemble_Empty.arff");
        Attribute serialAttr = emptyHeader.attribute("serialNumber");
        if (serialAttr == null) {
            throw new IllegalStateException(
                    "TrainSet_Ensemble_Empty.arff has no serialNumber attribute");
        }

        // Output files are opened only after every input has loaded, so earlier results
        // are not truncated when an input is missing.
        PrintWriter pwScore = null;
        PrintWriter pwLabel = null;
        PrintWriter pwPid = null;
        PrintWriter pwSerial = null;
        try {
            pwScore = new PrintWriter(new FileOutputStream("score.txt"));
            pwLabel = new PrintWriter(new FileOutputStream("label.txt"));
            pwPid = new PrintWriter(new FileOutputStream("pid.txt"));
            pwSerial = new PrintWriter(new FileOutputStream("serial.txt"));

            int patientPerFold = PATIENT_NUMBER / FOLDS;
            int restPatients = PATIENT_NUMBER - FOLDS * patientPerFold;
            int testLower = FIRST_PATIENT_ID;
            int testUpper = FIRST_PATIENT_ID + patientPerFold;
            int totalTest = 0;
            for (int n = 0; n < FOLDS; n++) {
                // Spread the remainder: the first `restPatients` folds get one extra patient.
                if (restPatients > 0) {
                    testUpper++;
                    restPatients--;
                }
                Instances train = new Instances(emptyHeader);
                Instances test = new Instances(emptyHeader);
                for (int k = 0; k < randomData.numInstances(); k++) {
                    Instance ins = randomData.instance(k);
                    double sid = ins.value(PATIENT_ID_INDEX);
                    if (sid >= testLower && sid < testUpper) {
                        test.add(ins);
                    } else {
                        train.add(ins);
                    }
                }
                System.out.println("Fold = " + n + ", Train = " + train.numInstances()
                        + ", Test = " + test.numInstances()
                        + " ( " + testLower + " - " + testUpper + " )");
                totalTest = totalTest + testUpper - testLower;
                testLower = testUpper;
                testUpper += patientPerFold;
                test.setClassIndex(labelIndex); // numeric label
                train.setClassIndex(labelIndex); // numeric label

                // Linear regression on the score attributes (the Remove filter drops the rest).
                LinearRegression classifier = new LinearRegression();
                classifier.setOptions(weka.core.Utils.splitOptions("-S 0 -R 1.0E-8"));
                FilteredClassifier fc = new FilteredClassifier();
                fc.setFilter(remove);
                fc.setClassifier(classifier);
                fc.buildClassifier(train);

                for (int i = 0; i < test.numInstances(); i++) {
                    Instance ins = test.instance(i);
                    double[] result = fc.distributionForInstance(ins);
                    // Attribute 0 is written as its nominal value string.
                    pwLabel.println(ins.attribute(0).value((int) ins.value(0)));
                    pwScore.println(result[0]);
                    pwPid.println((int) ins.value(PATIENT_ID_INDEX));
                    pwSerial.println((int) ins.value(serialAttr));
                }
            }
            System.out.println("Total test = " + totalTest);

            // PrintWriter swallows I/O errors; surface them instead of leaving truncated files.
            // Non-short-circuit '|' so every writer is flushed and checked.
            if (pwScore.checkError() | pwLabel.checkError()
                    | pwPid.checkError() | pwSerial.checkError()) {
                throw new IOException("Error while writing prediction output files");
            }
        } finally {
            closeIfOpen(pwLabel);
            closeIfOpen(pwScore);
            closeIfOpen(pwPid);
            closeIfOpen(pwSerial);
        }
    }

    /** Reads an ARFF file into an {@link Instances} object, always closing the reader. */
    private static Instances readArff(String path) throws IOException {
        BufferedReader reader = new BufferedReader(new FileReader(path));
        try {
            return new Instances(reader);
        } finally {
            reader.close();
        }
    }

    /** Fails with a descriptive message if {@code values} has fewer than {@code needed} entries. */
    private static void requireLength(String what, double[] values, int needed) {
        if (values.length < needed) {
            throw new IllegalStateException(what + " has " + values.length
                    + " values but TrainSet.arff has " + needed + " instances");
        }
    }

    /** Closes {@code writer} if it was opened; {@link PrintWriter#close()} never throws. */
    private static void closeIfOpen(PrintWriter writer) {
        if (writer != null) {
            writer.close();
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -