⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ensembleclassifier.java

📁 ensemble classifier example
💻 JAVA
字号:

import mslab.kddcup2008.roc.*;

import weka.core.Instances;
import weka.core.Instance;
import weka.core.Attribute;

import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Remove;

import weka.classifiers.Classifier;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.meta.FilteredClassifier;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.FileOutputStream;
import java.io.PrintWriter;

import java.util.*;


/**
 * Stacks three pre-computed base-model scores with a linear-regression meta-learner and
 * evaluates it with k-fold cross-validation in which folds are ranges of patient ids.
 *
 * <p>Reads from the working directory: {@code score_k.txt}, {@code label_k.txt},
 * {@code pid_k.txt}, {@code serial_k.txt} for k = 1..3, {@code TrainSet.arff}, and
 * {@code TrainSet_Ensemble_Empty.arff}. The empty ARFF is presumably TrainSet.arff's header
 * plus the four appended ensemble attributes — TODO confirm.
 *
 * <p>Writes one line per test instance, accumulated over all folds, to {@code score.txt},
 * {@code label.txt}, {@code pid.txt} and {@code serial.txt}.
 */
public class EnsembleClassifier {

	/** Number of cross-validation folds. */
	private static final int FOLDS = 10;

	/** Number of patient ids partitioned into folds. Should be 1602 when testing data. */
	private static final int PATIENT_NUMBER = 1712;

	/** Attribute index holding the patient id in TrainSet.arff rows. */
	private static final int PID_INDEX = 4;

	/** Empty dataset whose header is used for each fold's train and test sets. */
	private static final String EMPTY_ARFF = "TrainSet_Ensemble_Empty.arff";

	/**
	 * Runs the ensemble cross-validation and writes per-instance predictions.
	 *
	 * @param args unused
	 * @throws Exception if any input cannot be read or a Weka model fails to build/predict
	 */
	public static void main(String[] args) throws Exception {
		PrintWriter pwScore = null;
		PrintWriter pwLabel = null;
		PrintWriter pwPid = null;
		PrintWriter pwSerial = null;
		try {
			pwScore = new PrintWriter(new FileOutputStream("score.txt"));
			pwLabel = new PrintWriter(new FileOutputStream("label.txt"));
			pwPid = new PrintWriter(new FileOutputStream("pid.txt"));
			pwSerial = new PrintWriter(new FileOutputStream("serial.txt"));

			// Prepare scores to be ensembled. Only the third label file is kept; it becomes the
			// regression target.
			// NOTE(review): `label` keeps label_3.txt's file order, while the scores come from
			// getSN() after sortBySerialNumber(). Unless PointArray reorders the passed array in
			// place, or the files are already in serial order, labels and scores may be
			// misaligned. Verify against PointArray.
			double[] score1 = loadSortedScores(1, ROC.loadDataFromFile("label_1.txt"));
			double[] score2 = loadSortedScores(2, ROC.loadDataFromFile("label_2.txt"));
			double[] label = ROC.loadDataFromFile("label_3.txt");
			double[] score3 = loadSortedScores(3, label);

			// Contains the full dataset we want to create train/test sets from.
			Instances data = readArff("TrainSet.arff");

			// Append the ensemble attributes after the original ones. labelIndex is where the
			// first appended attribute lands; the original code hard-coded it as 129.
			int labelIndex = data.numAttributes();
			data.insertAttributeAt(new Attribute("numeric_label"), labelIndex);
			data.insertAttributeAt(new Attribute("score_1"), labelIndex + 1);
			data.insertAttributeAt(new Attribute("score_2"), labelIndex + 2);
			data.insertAttributeAt(new Attribute("score_3"), labelIndex + 3);

			// Row t of the ARFF is paired with element t of every array, so each array must
			// cover every row.
			int numRows = data.numInstances();
			requireLength("label", label, numRows);
			requireLength("score_1", score1, numRows);
			requireLength("score_2", score2, numRows);
			requireLength("score_3", score3, numRows);
			for (int t = 0; t < numRows; t++) {
				Instance row = data.instance(t);
				row.setValue(labelIndex, label[t]);
				row.setValue(labelIndex + 1, score1[t]);
				row.setValue(labelIndex + 2, score2[t]);
				row.setValue(labelIndex + 3, score3[t]);
			}
			System.out.println("Number of Attributes = " + data.numAttributes());

			// Remove's -R range is 1-based, so "1-<labelIndex>" drops every original attribute
			// and keeps numeric_label plus the three scores.
			// To also use original data, use "-R 1-11," + labelIndex instead.
			Remove remove = new Remove();
			remove.setOptions(weka.core.Utils.splitOptions("-R 1-" + labelIndex));
			remove.setInputFormat(data);
			data.setClassIndex(labelIndex); // numeric label

			// NOTE(review): randomize() only shuffles row order; fold membership below is still
			// a contiguous range of patient ids, not a random patient split.
			Instances randomData = new Instances(data);
			randomData.randomize(new Random(System.currentTimeMillis()));

			// The first (PATIENT_NUMBER % FOLDS) folds take one extra patient each.
			int patientPerFold = PATIENT_NUMBER / FOLDS;
			int restPatients = PATIENT_NUMBER - FOLDS * patientPerFold;
			int testLower = 0;
			int testUpper = patientPerFold;
			int totalTest = 0;

			for (int fold = 0; fold < FOLDS; fold++) {
				if (restPatients > 0) {
					testUpper++;
					restPatients--;
				}

				Instances train = readArff(EMPTY_ARFF);
				Instances test = readArff(EMPTY_ARFF);

				// Patients with id in [testLower, testUpper) form the test set. These ranges
				// assume ids 0..PATIENT_NUMBER-1; the original notes testing data uses ids
				// 10000-11601, which would need an offset — TODO confirm.
				for (int k = 0; k < randomData.numInstances(); k++) {
					Instance ins = randomData.instance(k);
					double sid = ins.value(PID_INDEX);
					if (sid >= testLower && sid < testUpper) {
						test.add(ins);
					} else {
						train.add(ins);
					}
				}

				System.out.println("Fold = " + fold + ", Train = " + train.numInstances() + ", Test = "+test.numInstances() + " ( " + testLower + " - " + testUpper + " )");

				totalTest = totalTest + testUpper - testLower;
				testLower = testUpper;
				testUpper += patientPerFold;

				test.setClassIndex(labelIndex); // numeric label
				train.setClassIndex(labelIndex); // numeric label

				FilteredClassifier fc = trainModel(train, remove);
				writePredictions(fc, test, pwLabel, pwScore, pwPid, pwSerial);
			}

			System.out.println("Total test = " + totalTest);
		} finally {
			closeIfOpen(pwLabel);
			closeIfOpen(pwScore);
			closeIfOpen(pwPid);
			closeIfOpen(pwSerial);
		}
	}

	/**
	 * Loads base-model k's scores and returns them as reordered by
	 * {@code PointArray.sortBySerialNumber()} / {@code getSN()}.
	 *
	 * @param k base-model number; selects the {@code *_k.txt} files
	 * @param label labels for model k, passed through to {@code PointArray}
	 */
	private static double[] loadSortedScores(int k, double[] label) throws Exception {
		double[] score = ROC.loadDataFromFile("score_" + k + ".txt");
		double[] pid = ROC.loadDataFromFile("pid_" + k + ".txt");
		double[] serial = ROC.loadDataFromFile("serial_" + k + ".txt");
		// NOTE(review): meaning of 0.2 / 0.3 is defined by PointArray — TODO document.
		PointArray pa = new PointArray(score, label, pid, serial, 0.2, 0.3);
		pa.sortBySerialNumber();
		return pa.getSN();
	}

	/** Reads an ARFF file into an {@link Instances} object, always closing the reader. */
	private static Instances readArff(String path) throws IOException {
		BufferedReader reader = new BufferedReader(new FileReader(path));
		try {
			return new Instances(reader);
		} finally {
			reader.close();
		}
	}

	/** Fails with a descriptive message if {@code values} cannot cover {@code needed} rows. */
	private static void requireLength(String name, double[] values, int needed) {
		if (values.length < needed) {
			throw new IllegalStateException(name + " has " + values.length
					+ " values but TrainSet.arff has " + needed + " instances");
		}
	}

	/** Builds the meta-learner: linear regression on the attributes kept by {@code remove}. */
	private static FilteredClassifier trainModel(Instances train, Remove remove) throws Exception {
		LinearRegression regression = new LinearRegression();
		regression.setOptions(weka.core.Utils.splitOptions("-S 0 -R 1.0E-8"));
		FilteredClassifier fc = new FilteredClassifier();
		fc.setFilter(remove);
		fc.setClassifier(regression);
		fc.buildClassifier(train);
		return fc;
	}

	/**
	 * Writes, for each test instance, its nominal label (attribute 0), the predicted score,
	 * its patient id, and its serial number, one line per file.
	 */
	private static void writePredictions(FilteredClassifier fc, Instances test,
			PrintWriter pwLabel, PrintWriter pwScore, PrintWriter pwPid, PrintWriter pwSerial)
			throws Exception {
		// Looked up once instead of per instance; a missing attribute fails with a clear message.
		Attribute serialAttr = test.attribute("serialNumber");
		if (serialAttr == null) {
			throw new IllegalStateException(EMPTY_ARFF + " has no 'serialNumber' attribute");
		}
		for (int i = 0; i < test.numInstances(); i++) {
			Instance ins = test.instance(i);
			// For a numeric class, element 0 of the distribution is the predicted value.
			double[] result = fc.distributionForInstance(ins);
			pwLabel.println(ins.attribute(0).value((int) ins.value(0)));
			pwScore.println(result[0]);
			pwPid.println((int) ins.value(PID_INDEX));
			pwSerial.println((int) ins.value(serialAttr));
		}
	}

	/** Closes a writer that may not have been opened yet. */
	private static void closeIfOpen(PrintWriter pw) {
		if (pw != null) {
			pw.close();
		}
	}

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -