package coforest;
/**
* Description: CoForest is a semi-supervised algorithm which exploits the power of ensemble learning and the
* large amount of available unlabeled data to produce a hypothesis with better performance.
*
* Reference: M. Li, Z.-H. Zhou. Improve computer-aided diagnosis with machine learning techniques using undiagnosed
* samples. IEEE Transactions on Systems, Man and Cybernetics - Part A: Systems and Humans, 2007, 37(6).
*
* ATTN: This package is free for academic usage. You use it at your own risk.
* For other purposes, please contact Prof. Zhi-Hua Zhou (zhouzh@nju.edu.cn).
*
* Requirement: To use this package, the whole WEKA environment (ver 3.4) must be available.
* Refer to: I.H. Witten and E. Frank. Data Mining: Practical Machine Learning
* Tools and Techniques with Java Implementations. Morgan Kaufmann,
* San Francisco, CA, 2000.
*
* Data format: Both the input and output formats are the same as those used by WEKA.
*
* ATTN2: This package was developed by Mr. Ming Li (lim@lamda.nju.edu.cn). A
* ReadMe file is provided that roughly explains the code. For any
* problem concerning the code, please feel free to contact Mr. Li.
*
*/
import java.io.*;
import java.text.*;
import java.util.*;
import weka.core.*;
import weka.classifiers.*;
import weka.classifiers.trees.*;
public class CoForest
{
/** The component trees of the random forest */
protected Classifier[] m_classifiers = null;
/** The number of component classifiers */
protected int m_numClassifiers = 10;
/** The random seed */
protected int m_seed = 1;
/** Number of features to consider in random feature selection.
If less than 1, int(log2(M) + 1) is used, where M is the number of attributes */
protected int m_numFeatures = 0;
/** Final number of features that were considered in the last build. */
protected int m_KValue = 0;
/** confidence threshold */
protected double m_threshold = 0.75;
private int m_numOriginalLabeledInsts = 0;
/**
* The constructor
*/
public CoForest()
{
}
/**
* Set the seed for initializing the random object used inside this class
*
* @param s int -- The seed
*/
public void setSeed(int s)
{
m_seed = s;
}
/**
* Set the number of trees used in Random Forest.
*
* @param n int -- Value to assign to m_numClassifiers.
*/
public void setNumClassifiers(int n)
{
m_numClassifiers = n;
}
/**
* Get the number of trees used in Random Forest
*
* @return int -- The number of trees.
*/
public int getNumClassifiers()
{
return m_numClassifiers;
}
/**
* Set the number of features to use in random selection.
*
* @param n int -- Value to assign to m_numFeatures.
*/
public void setNumFeatures(int n)
{
m_numFeatures = n;
}
/**
* Get the number of features to use in random selection.
*
* @return int -- The number of features
*/
public int getNumFeatures()
{
return m_numFeatures;
}
/**
* Resample instances w.r.t. the instance weights
*
* @param data Instances -- the original data set
* @param random Random -- the random object
* @param sampled boolean[] -- the output parameter, indicating whether the instance is sampled
* @return Instances
*/
public final Instances resampleWithWeights(Instances data,
Random random,
boolean[] sampled)
{
double[] weights = new double[data.numInstances()];
for (int i = 0; i < weights.length; i++) {
weights[i] = data.instance(i).weight();
}
Instances newData = new Instances(data, data.numInstances());
if (data.numInstances() == 0) {
return newData;
}
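// Draw numInstances() sampling points on (0, sumOfWeights] in sorted order:
// a running sum of uniform variates, rescaled so the final point equals the
// total instance weight, yields an increasing sequence in a single pass.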
double[] probabilities = new double[data.numInstances()];
double sumProbs = 0, sumOfWeights = Utils.sum(weights);
for (int i = 0; i < data.numInstances(); i++) {
sumProbs += random.nextDouble();
probabilities[i] = sumProbs;
}
Utils.normalize(probabilities, sumProbs / sumOfWeights);
// Make sure that rounding errors don't mess things up
probabilities[data.numInstances() - 1] = sumOfWeights;
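// Merge-walk the sorted sampling points against the running sum of instance
// weights: instance l is copied into the sample once for every sampling
// point that falls inside its weight interval.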
int k = 0; int l = 0;
sumProbs = 0;
while ((k < data.numInstances() && (l < data.numInstances()))) {
if (weights[l] < 0) {
throw new IllegalArgumentException("Weights have to be nonnegative.");
}
sumProbs += weights[l];
while ((k < data.numInstances()) &&
(probabilities[k] <= sumProbs)) {
newData.add(data.instance(l));
sampled[l] = true;
newData.instance(k).setWeight(1);
k++;
}
l++;
}
return newData;
}
/**
* Returns the class probability distribution for a given instance,
* averaged over all component trees
*
* @param inst Instance -- The instance
* @return double[] -- The class probability distribution
* @throws Exception -- Some exception
*/
public double[] distributionForInstance(Instance inst) throws Exception
{
double[] res = new double[inst.numClasses()];
for(int i = 0; i < m_classifiers.length; i++)
{
double[] distr = m_classifiers[i].distributionForInstance(inst);
for(int j = 0; j < res.length; j++)
res[j] += distr[j];
}
Utils.normalize(res);
return res;
}
/**
* Classifies a given instance
*
* @param inst Instance -- The instance
* @return double -- The class value
* @throws Exception -- Some Exception
*/
public double classifyInstance(Instance inst) throws Exception
{
double[] distr = distributionForInstance(inst);
return Utils.maxIndex(distr);
}
/**
* Build the classifiers using Co-Forest algorithm
*
* @param labeled Instances -- Labeled training set
* @param unlabeled Instances -- unlabeled training set
* @throws Exception -- certain exception
*/
public void buildClassifier(Instances labeled, Instances unlabeled) throws Exception
{
double[] err = new double[m_numClassifiers];
double[] err_prime = new double[m_numClassifiers];
double[] s_prime = new double[m_numClassifiers];
boolean[][] inbags = new boolean[m_numClassifiers][];
Random rand = new Random(m_seed);
m_numOriginalLabeledInsts = labeled.numInstances();
RandomTree rTree = new RandomTree();
// set up the random tree options
m_KValue = m_numFeatures;
if (m_KValue < 1) m_KValue = (int) Utils.log2(labeled.numAttributes())+1;
rTree.setKValue(m_KValue);
m_classifiers = Classifier.makeCopies(rTree, m_numClassifiers);
Instances[] labeleds = new Instances[m_numClassifiers];
int[] randSeeds = new int[m_numClassifiers];
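// Initialization: train each tree on its own weighted bootstrap sample of
// the labeled data, recording in-bag flags for later out-of-bag error
// estimation. As in the CoForest paper, the previous-round error e'_i
// (err_prime) starts at 0.5 and the previous weight W'_i (s_prime) at 0.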
for(int i = 0; i < m_numClassifiers; i++)
{
randSeeds[i] = rand.nextInt();
((RandomTree)m_classifiers[i]).setSeed(randSeeds[i]);
inbags[i] = new boolean[labeled.numInstances()];
labeleds[i] = resampleWithWeights(labeled, rand, inbags[i]);
m_classifiers[i].buildClassifier(labeleds[i]);
err_prime[i] = 0.5;
s_prime[i] = 0;
}
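// Iterative phase: repeat until no tree changes. In each round, every tree
// hi is scored by the out-of-bag error of its concomitant ensemble Hi (all
// trees except hi) on the labeled data; if the error dropped, hi is offered
// confidently labeled examples drawn from the unlabeled set.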
boolean bChanged = true;
while(bChanged)
{
bChanged = false;
boolean[] bUpdate = new boolean[m_classifiers.length];
Instances[] Li = new Instances[m_numClassifiers];
for(int i = 0; i < m_numClassifiers; i++)
{
err[i] = measureError(labeled, inbags, i);
Li[i] = new Instances(labeled, 0);
/** if (e_i < e'_i) */
if(err[i] < err_prime[i])
{
if(s_prime[i] == 0)
s_prime[i] = Math.min(unlabeled.sumOfWeights() / 10, 100);
/** Subsample U for each hi */
double weight = 0;
unlabeled.randomize(rand);
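// Cap the total weight W_i of the subsampled pool at
// ceil(e'_i * W'_i / e_i - 1), so that e_i * W_i < e'_i * W'_i, the
// worst-case condition under which the newly labeled data cannot hurt
// (see the derivation in the paper).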
int numWeightsAfterSubsample = (int) Math.ceil(err_prime[i] * s_prime[i] / err[i] - 1);
for(int k = 0; k < unlabeled.numInstances(); k++)
{
weight += unlabeled.instance(k).weight();
if (weight > numWeightsAfterSubsample)
break;
Li[i].add((Instance)unlabeled.instance(k).copy());
}
/** for every x in U' do */
for(int j = Li[i].numInstances() - 1; j >= 0; j--)
{
Instance curInst = Li[i].instance(j);
if(!isHighConfidence(curInst, i)) // isHighConfidence assigns the label and weight as a side effect
Li[i].delete(j);
}//end of j
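// Accept Li only if it grew relative to the previous round and the
// weighted-error condition e_i * W_i < e'_i * W'_i still holds.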
if(s_prime[i] < Li[i].numInstances())
{
if(err[i] * Li[i].sumOfWeights() < err_prime[i] * s_prime[i])
bUpdate[i] = true;
}
}
}//end of for i
//update
Classifier[] newClassifier = Classifier.makeCopies(rTree, m_numClassifiers);
for(int i = 0; i < m_numClassifiers; i++)
{
if(bUpdate[i])
{
double size = Li[i].sumOfWeights();
bChanged = true;
m_classifiers[i] = newClassifier[i];
((RandomTree)m_classifiers[i]).setSeed(randSeeds[i]);
m_classifiers[i].buildClassifier(combine(labeled, Li[i]));
err_prime[i] = err[i];
s_prime[i] = size;
}
}
}//end of while
}
/**
* Judge whether the confidence of H* for a given instance is high enough,
* as governed by the confidence threshold. If the example is a confident
* one, assign the predicted label to it and weight the example with the confidence
*
* @param inst Instance -- The instance
* @param idExcluded int -- the index of the individual should be excluded from H*
* @return boolean -- true for high
* @throws Exception - some exception
*/
protected boolean isHighConfidence(Instance inst, int idExcluded) throws Exception
{
double[] distr = distributionForInstanceExcluded(inst, idExcluded);
double confidence = getConfidence(distr);
if(confidence > m_threshold)
{
double classval = Utils.maxIndex(distr);
inst.setClassValue(classval); //assign label
inst.setWeight(confidence); //set instance weight
return true;
}
else return false;
}
private Instances combine(Instances L, Instances Li)
{
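// Note: appends every instance of L to Li and returns Li itself,
// so the caller's Li is modified in place.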
for(int i = 0; i < L.numInstances(); i++)
Li.add(L.instance(i));
return Li;
}
private double measureError(Instances data, boolean[][] inbags, int id) throws Exception
{
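// Out-of-bag estimate of the error of the concomitant ensemble Hi on the
// originally labeled data: only predictions whose confidence exceeds the
// threshold are counted, mirroring how unlabeled examples are selected.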
double err = 0;
double count = 0;
for(int i = 0; i < data.numInstances() && i < m_numOriginalLabeledInsts; i++)
{
Instance inst = data.instance(i);
double[] distr = outOfBagDistributionForInstanceExcluded(inst, i, inbags, id);
if(getConfidence(distr) > m_threshold)
{
count += inst.weight();
if(Utils.maxIndex(distr) != inst.classValue())
err += inst.weight();
}
}
if(count == 0) // no confident out-of-bag predictions; report worst-case error so hi is not updated
return 1;
err /= count;
return err;
}
private double getConfidence(double[] p)
{
int maxIndex = Utils.maxIndex(p);
return p[maxIndex];
}
private double[] distributionForInstanceExcluded(Instance inst, int idExcluded) throws Exception
{
double[] distr = new double[inst.numClasses()];
for(int i = 0; i < m_numClassifiers; i++)
{
if(i == idExcluded)
continue;
double[] d = m_classifiers[i].distributionForInstance(inst);
for(int iClass = 0; iClass < inst.numClasses(); iClass++)
distr[iClass] += d[iClass];
}
Utils.normalize(distr);
return distr;
}
private double[] outOfBagDistributionForInstanceExcluded(Instance inst, int idxInst, boolean[][] inbags, int idExcluded) throws Exception
{
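// Vote over trees that neither hold this instance in their bootstrap bag
// nor are the excluded tree; the distribution may stay all-zero if no such
// tree exists, hence the guarded normalization below.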
double[] distr = new double[inst.numClasses()];
for(int i = 0; i < m_numClassifiers; i++)
{
if(inbags[i][idxInst] == true || i == idExcluded)
continue;
double[] d = m_classifiers[i].distributionForInstance(inst);
for(int iClass = 0; iClass < inst.numClasses(); iClass++)
distr[iClass] += d[iClass];
}
if(Utils.sum(distr) != 0)
Utils.normalize(distr);
return distr;
}
/**
* The main method only for demonstrating the simple use of this package
*
* @param args String[]
*/
public static void main(String[] args)
{
try
{
int seed = 0;
int numFeatures = 0;
Random rand = new Random(seed);
final int NUM_CLASSIFIERS = 6;
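// The demo expects three WEKA ARFF files in the working directory:
// labeled.arff (labeled training data), unlabeled.arff (unlabeled training
// data with the same schema), and test.arff (held-out test data).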
BufferedReader r = new BufferedReader(new FileReader("labeled.arff"));
Instances labeled = new Instances(r);
labeled.setClassIndex(labeled.numAttributes()-1);
r.close();
r = new BufferedReader(new FileReader("unlabeled.arff"));
Instances unlabeled = new Instances(r);
unlabeled.setClassIndex(unlabeled.numAttributes()-1);
r.close();
r = new BufferedReader(new FileReader("test.arff"));
Instances test = new Instances(r);
test.setClassIndex(test.numAttributes()-1);
r.close();
CoForest forest = new CoForest();
forest.setNumClassifiers(NUM_CLASSIFIERS);
forest.setNumFeatures(numFeatures);
forest.setSeed(rand.nextInt());
forest.buildClassifier(labeled, unlabeled);
double err = 0;
for(int i = 0; i < test.numInstances(); i++)
{
if(forest.classifyInstance(test.instance(i)) != test.instance(i).classValue())
err++;
}
System.out.println("Error Rate = " + (err/test.numInstances()));
}
catch(Exception e)
{
e.printStackTrace();
}
}
}