📄 adacost.java
字号:
Capabilities capabilities = super.getCapabilities();
capabilities.disableAllClasses();
capabilities.disableAllClassDependencies();
if(super.getCapabilities().handles(weka.core.Capabilities.Capability.NOMINAL_CLASS))
capabilities.enable(weka.core.Capabilities.Capability.NOMINAL_CLASS);
if(super.getCapabilities().handles(weka.core.Capabilities.Capability.BINARY_CLASS))
capabilities.enable(weka.core.Capabilities.Capability.BINARY_CLASS);
return capabilities;
}
/**
 * Builds the boosted, cost-sensitive model.
 *
 * <p>The cost matrix is resolved first: when {@code m_MatrixSource == MATRIX_ON_DEMAND} it is
 * loaded from a file named {@code <relationName><CostMatrix.FILE_EXTENSION>} in the on-demand
 * directory; otherwise an already-set {@code m_CostMatrix} is used, or, if none is set, an
 * old-format matrix is read from {@code m_CostFile}. Unless expected-cost minimisation is
 * enabled, the training data is then adjusted with the cost matrix before boosting.
 *
 * @param instances the training data; a copy is made, the argument is not modified
 * @throws Exception if the data fails the capability check, no base classifier is set, no cost
 *     matrix can be found, or a cost file cannot be read
 */
public void buildClassifier(Instances instances)
    throws Exception
{
    getCapabilities().testWithFail(instances);
    instances = new Instances(instances);
    instances.deleteWithMissingClass();
    if (m_Classifier == null) {
        throw new Exception("No base classifier has been set!");
    }

    if (m_MatrixSource == MATRIX_ON_DEMAND) {
        String costName = instances.relationName() + CostMatrix.FILE_EXTENSION;
        File costFile = new File(getOnDemandDirectory(), costName);
        if (!costFile.exists()) {
            throw new Exception("On-demand cost file doesn't exist: " + costFile);
        }
        // Close the reader even if parsing fails (the original leaked the file handle).
        BufferedReader reader = new BufferedReader(new FileReader(costFile));
        try {
            setCostMatrix(new CostMatrix(reader));
        } finally {
            reader.close();
        }
    } else if (m_CostMatrix == null) {
        if (m_CostFile == null) {
            throw new Exception("No cost matrix has been set and no cost file was given!");
        }
        // Parse into a local first: assigning m_CostMatrix before a successful read would
        // leave a default matrix behind on failure and the file would never be retried.
        CostMatrix matrix = new CostMatrix(instances.numClasses());
        BufferedReader reader = new BufferedReader(new FileReader(m_CostFile));
        try {
            matrix.readOldFormat(reader);
        } finally {
            reader.close();
        }
        m_CostMatrix = matrix;
    }

    if (!m_MinimizeExpectedCost) {
        // A Random is only passed for base learners that are not WeightedInstancesHandlers;
        // per the CostMatrix docs this makes applyCostMatrix resample instead of reweight.
        Random random = null;
        if (!(m_Classifier instanceof WeightedInstancesHandler)) {
            random = new Random(m_Seed);
        }
        instances = m_CostMatrix.applyCostMatrix(instances, random);
    }
    startbuildClassifier(instances);
}
/**
 * Trains the boosting ensemble on data that has already been cost-adjusted (if applicable).
 *
 * <p>Delegates to the superclass build first (presumably to create the {@code m_Classifiers}
 * copies of the base learner — NOTE(review): confirm against the superclass), re-checks the
 * capabilities, drops instances whose class is missing, records the number of classes, and
 * then boosts either by reweighting or by resampling.
 *
 * @param instances the training data; a copy is made, the argument is not modified
 * @throws Exception if the superclass set-up, the capability check or training fails
 */
public void startbuildClassifier(Instances instances)
    throws Exception
{
    super.buildClassifier(instances);
    getCapabilities().testWithFail(instances);

    Instances training = new Instances(instances);
    training.deleteWithMissingClass();
    m_NumClasses = training.numClasses();

    // Reweighting is only possible when the base learner understands instance weights and
    // resampling has not been explicitly requested.
    boolean reweight = (m_Classifier instanceof WeightedInstancesHandler) && !m_UseResampling;
    if (reweight) {
        buildClassifierWithWeights(training);
    } else {
        buildClassifierUsingResampling(training);
    }
}
/**
 * Boosts by resampling: each round trains the next base model on a sample produced by
 * {@code resampleWithWeights} from the current instance weights.
 *
 * <p>Instance weights of the working copy are normalised to sum to 1 up front. Each model is
 * evaluated on the full weighted working copy. A zero-error model is re-sampled up to
 * {@code MAX_NUM_RESAMPLING_ITERATIONS} times. Boosting stops when the error is
 * {@code >= 0.5} or still 0; the first model is always kept. Each accepted model gets the
 * vote weight {@code log((1 - error) / error)}, and misclassified instances are up-weighted
 * by {@code (1 - error) / error} via {@link #setWeights}.
 *
 * @param instances the training data; weights are modified only on an internal copy
 * @throws Exception if sampling, training or evaluation of a base model fails
 */
protected void buildClassifierUsingResampling(Instances instances)
    throws Exception
{
    Random random = new Random(m_Seed);
    m_Betas = new double[m_Classifiers.length];
    m_NumIterationsPerformed = 0;

    // Working copy whose weights are rewritten below (assumes the Instances copy constructor
    // copies each Instance, as in Weka — confirm for this version).
    Instances training = new Instances(instances, 0, instances.numInstances());
    double totalWeight = training.sumOfWeights();
    for (int k = 0; k < training.numInstances(); k++) {
        training.instance(k).setWeight(training.instance(k).weight() / totalWeight);
    }

    for (m_NumIterationsPerformed = 0; m_NumIterationsPerformed < m_Classifiers.length;
            m_NumIterationsPerformed++) {
        if (m_Debug) {
            System.err.println("Training classifier " + (m_NumIterationsPerformed + 1));
        }

        Instances trainData;
        if (m_WeightThreshold < 100) {
            trainData = selectWeightQuantile(training, (double) m_WeightThreshold / 100D);
        } else {
            trainData = new Instances(training);
        }

        double[] weights = new double[trainData.numInstances()];
        for (int k = 0; k < weights.length; k++) {
            weights[k] = trainData.instance(k).weight();
        }

        // A zero-error model would get an infinite vote weight (log of division by zero),
        // so keep re-sampling for a while in the hope of a model that makes some errors.
        double error;
        int attempts = 0;
        do {
            Instances sample = trainData.resampleWithWeights(random, weights);
            m_Classifiers[m_NumIterationsPerformed].buildClassifier(sample);
            Evaluation evaluation = new Evaluation(instances);
            evaluation.evaluateModel(m_Classifiers[m_NumIterationsPerformed], training);
            error = evaluation.errorRate();
            attempts++;
        } while (Utils.eq(error, 0.0D) && attempts < MAX_NUM_RESAMPLING_ITERATIONS);

        if (Utils.grOrEq(error, 0.5D) || Utils.eq(error, 0.0D)) {
            // Unusable model: stop boosting, but never leave the ensemble empty.
            if (m_NumIterationsPerformed == 0) {
                m_NumIterationsPerformed = 1;
            }
            break;
        }

        double reweight = (1.0D - error) / error;
        m_Betas[m_NumIterationsPerformed] = Math.log(reweight);
        if (m_Debug) {
            System.err.println("\terror rate = " + error + " beta = "
                + m_Betas[m_NumIterationsPerformed]);
        }
        setWeights(training, reweight);
    }
}
/**
 * Up-weights the instances that the current model ({@code m_Classifiers[m_NumIterationsPerformed]})
 * misclassifies, then rescales every weight so the total weight is the same as before.
 *
 * @param instances the instances whose weights are updated in place
 * @param d the factor applied to the weight of each misclassified instance
 * @throws Exception if the current model fails to classify an instance
 */
protected void setWeights(Instances instances, double d)
    throws Exception
{
    final double sumBefore = instances.sumOfWeights();

    Enumeration scan = instances.enumerateInstances();
    while (scan.hasMoreElements()) {
        Instance inst = (Instance) scan.nextElement();
        double predicted = m_Classifiers[m_NumIterationsPerformed].classifyInstance(inst);
        if (!Utils.eq(predicted, inst.classValue())) {
            inst.setWeight(inst.weight() * d);
        }
    }

    // Renormalise so the total weight is unchanged by the update above.
    final double sumAfter = instances.sumOfWeights();
    Enumeration rescale = instances.enumerateInstances();
    while (rescale.hasMoreElements()) {
        Instance inst = (Instance) rescale.nextElement();
        inst.setWeight((inst.weight() * sumBefore) / sumAfter);
    }
}
/**
 * Boosts by reweighting: each round trains the next base model directly on the weighted
 * working copy (or on a weight quantile of it when {@code m_WeightThreshold < 100}).
 *
 * <p>Randomizable base models are re-seeded from a generator seeded with {@code m_Seed}.
 * Boosting stops when a model's error on the working copy is {@code >= 0.5} or exactly 0;
 * the first model is always kept. Accepted models get the vote weight
 * {@code log((1 - error) / error)}, and misclassified instances are up-weighted by
 * {@code (1 - error) / error} via {@link #setWeights}.
 *
 * @param instances the training data; weights are modified only on an internal copy
 * @throws Exception if training or evaluation of a base model fails
 */
protected void buildClassifierWithWeights(Instances instances)
    throws Exception
{
    final int numInstances = instances.numInstances();
    final Random seeds = new Random(m_Seed);
    m_Betas = new double[m_Classifiers.length];
    m_NumIterationsPerformed = 0;
    final Instances training = new Instances(instances, 0, numInstances);

    while (m_NumIterationsPerformed < m_Classifiers.length) {
        final int round = m_NumIterationsPerformed;
        if (m_Debug) {
            System.err.println("Training classifier " + (round + 1));
        }

        Instances trainData;
        if (m_WeightThreshold >= 100) {
            trainData = new Instances(training, 0, numInstances);
        } else {
            trainData = selectWeightQuantile(training, (double) m_WeightThreshold / 100D);
        }

        if (m_Classifiers[round] instanceof Randomizable) {
            ((Randomizable) m_Classifiers[round]).setSeed(seeds.nextInt());
        }
        m_Classifiers[round].buildClassifier(trainData);

        Evaluation eval = new Evaluation(instances);
        eval.evaluateModel(m_Classifiers[round], training);
        final double error = eval.errorRate();

        boolean unusable = Utils.grOrEq(error, 0.5D) || Utils.eq(error, 0.0D);
        if (unusable) {
            // Keep at least the first model so the ensemble is never empty.
            if (round == 0) {
                m_NumIterationsPerformed = 1;
            }
            return;
        }

        final double ratio = (1.0D - error) / error;
        m_Betas[round] = Math.log(ratio);
        if (m_Debug) {
            System.err.println("\terror rate = " + error + " beta = " + m_Betas[round]);
        }
        setWeights(training, ratio);
        m_NumIterationsPerformed++;
    }
}
/**
 * Returns the class distribution for an instance.
 *
 * <p>Without expected-cost minimisation this is the boosted ensemble's distribution. With it,
 * the ensemble's distribution is converted to expected misclassification costs using the cost
 * matrix, and a one-hot distribution selecting the cheapest class is returned.
 *
 * @param instance the instance to classify
 * @return a class distribution
 * @throws Exception if no model has been built or a base model fails
 */
public double[] distributionForInstance(Instance instance)
    throws Exception
{
    // Always use the trained ensemble. The build methods in this class only train the
    // m_Classifiers copies; the template m_Classifier is not trained there, so querying it
    // (as the original did in the minimise-expected-cost branch) used an unbuilt model.
    double[] probs = AgaindistributionForInstance(instance);
    if (!m_MinimizeExpectedCost) {
        return probs;
    }
    double[] costs = m_CostMatrix.expectedCosts(probs);
    int cheapest = Utils.minIndex(costs);
    // Fresh array: probs may be an array owned by a base classifier.
    double[] result = new double[probs.length];
    result[cheapest] = 1.0D;
    return result;
}
/**
 * Returns the boosted ensemble's class distribution for an instance.
 *
 * <p>With a single model its own distribution is returned. Otherwise each model votes for its
 * predicted class with weight {@code m_Betas[i]}, and the vote totals are turned into
 * probabilities with {@code Utils.logs2probs}.
 *
 * @param instance the instance to classify
 * @return the class distribution
 * @throws Exception if no model has been built or a base model fails
 */
public double[] AgaindistributionForInstance(Instance instance)
    throws Exception
{
    if (m_NumIterationsPerformed == 0) {
        throw new Exception("No model built");
    }
    if (m_NumIterationsPerformed == 1) {
        return m_Classifiers[0].distributionForInstance(instance);
    }
    double[] votes = new double[instance.numClasses()];
    for (int i = 0; i < m_NumIterationsPerformed; i++) {
        double predicted = m_Classifiers[i].classifyInstance(instance);
        // A base model that makes no prediction returns a missing value (NaN). Casting NaN
        // to int yields 0, which would silently credit class 0, so such a model abstains.
        if (Double.isNaN(predicted)) {
            continue;
        }
        votes[(int) predicted] += m_Betas[i];
    }
    return Utils.logs2probs(votes);
}
/**
 * Reports the graph type of the base classifier template.
 *
 * @return the template's {@code graphType()} when it is {@code Drawable}, otherwise {@code 0}
 */
public int graphType()
{
    return (m_Classifier instanceof Drawable) ? ((Drawable) m_Classifier).graphType() : 0;
}
/**
 * Returns the graph description of the base classifier template.
 *
 * <p>NOTE(review): this graphs {@code m_Classifier}, the template; the build methods in this
 * class only train the {@code m_Classifiers} copies, so the template may be unbuilt — confirm.
 *
 * @return the template's graph description
 * @throws Exception if the template is not {@code Drawable} or cannot produce a graph
 */
public String graph()
    throws Exception
{
    if (!(m_Classifier instanceof Drawable)) {
        throw new Exception("Classifier: " + getClassifierSpec() + " cannot be graphed");
    }
    return ((Drawable) m_Classifier).graph();
}
/**
 * Generates Java source for a class named {@code s} that reproduces the ensemble's prediction.
 *
 * <p>The generated {@code classify} method either delegates to the single model
 * {@code s_0} or sums {@code m_Betas[i]} into the class predicted by each model {@code s_i}
 * and returns the index of the largest sum. The source of each referenced model is appended.
 *
 * @param s the name of the generated class (also the prefix of the per-model classes)
 * @return the generated source code
 * @throws Exception if no model has been built, the base learner is not {@code Sourcable}, or
 *     a base model fails to generate its source
 */
public String toSource(String s)
    throws Exception
{
    if (m_NumIterationsPerformed == 0) {
        throw new Exception("No model built yet");
    }
    if (!(m_Classifiers[0] instanceof Sourcable)) {
        throw new Exception("Base learner " + m_Classifier.getClass().getName()
            + " is not Sourcable");
    }

    StringBuffer text = new StringBuffer("class ");
    text.append(s).append(" {\n\n");
    text.append(" public static double classify(Object [] i) {\n");
    if (m_NumIterationsPerformed == 1) {
        text.append(" return ").append(s).append("_0.classify(i);\n");
    } else {
        text.append(" double [] sums = new double [").append(m_NumClasses).append("];\n");
        for (int i = 0; i < m_NumIterationsPerformed; i++) {
            text.append(" sums[(int) ").append(s).append('_').append(i)
                .append(".classify(i)] += ").append(m_Betas[i]).append(";\n");
        }
        text.append(" double maxV = sums[0];\n int maxI = 0;\n for (int j = 1; j < ")
            .append(m_NumClasses).append("; j++) {\n")
            .append(" if (sums[j] > maxV) { maxV = sums[j]; maxI = j; }\n")
            .append(" }\n return (double) maxI;\n");
    }
    text.append(" }\n}\n");

    // classify() above only references models 0 .. m_NumIterationsPerformed-1. Boosting can
    // stop early, leaving later slots untrained, so asking them for source could fail.
    for (int j = 0; j < m_NumIterationsPerformed; j++) {
        text.append(((Sourcable) m_Classifiers[j]).toSource(s + '_' + j));
    }
    return text.toString();
}
/**
 * Returns a textual description of the ensemble, the cost-handling mode, the base classifier
 * specification and template, and the cost matrix.
 *
 * <p>Safe to call before the model is built: a cost matrix that has not been set or loaded
 * yet is reported instead of causing a {@code NullPointerException}.
 *
 * @return the description
 */
public String toString()
{
    StringBuffer text = new StringBuffer();
    if (m_NumIterationsPerformed == 0) {
        text.append("AdaCost: No model built yet.\n");
    } else if (m_NumIterationsPerformed == 1) {
        text.append("AdaCost: No boosting possible, one classifier used!\n");
        text.append(m_Classifiers[0].toString()).append("\n");
    } else {
        text.append("AdaCost: Base classifiers and their weights: \n\n");
        for (int i = 0; i < m_NumIterationsPerformed; i++) {
            text.append(m_Classifiers[i].toString()).append("\n\n");
            text.append("Weight: ").append(Utils.roundDouble(m_Betas[i], 2)).append("\n\n");
        }
        text.append("Number of performed Iterations: ")
            .append(m_NumIterationsPerformed).append("\n");
    }

    text.append("AdaCost using ");
    text.append(m_MinimizeExpectedCost
        ? "minimized expected misclasification cost\n"
        : "reweighted training instances\n");

    text.append("\n").append(getClassifierSpec())
        .append("\n\nClassifier Model\n").append(m_Classifier.toString())
        .append("\n\nCost Matrix\n")
        // m_CostMatrix is only guaranteed to be set once buildClassifier has run.
        .append(m_CostMatrix == null ? "(not set)" : m_CostMatrix.toString());
    return text.toString();
}
/** Maximum number of resampling attempts per boosting round while a model has zero error. */
private static final int MAX_NUM_RESAMPLING_ITERATIONS = 10;
/** Vote weight of each boosted model, {@code log((1 - error) / error)}, indexed by round. */
protected double m_Betas[];
/** Number of boosted models actually used for prediction. */
protected int m_NumIterationsPerformed;
/** Weight-mass percentage; values below 100 train on a weight quantile of the data. */
protected int m_WeightThreshold;
/** When true, boost by resampling even if the base learner handles instance weights. */
protected boolean m_UseResampling;
/** Number of classes seen at training time (used when generating source). */
protected int m_NumClasses;
/** Matrix source: load the cost matrix from a file named after the relation at build time. */
public static final int MATRIX_ON_DEMAND = 1;
/** Matrix source: use the explicitly supplied cost matrix. */
public static final int MATRIX_SUPPLIED = 2;
/** Selectable cost-matrix sources. */
public static final Tag TAGS_MATRIX_SOURCE[] = {
    new Tag(MATRIX_ON_DEMAND, "Load cost matrix on demand"),
    new Tag(MATRIX_SUPPLIED, "Use explicit cost matrix")
};
/** Either {@code MATRIX_ON_DEMAND} or {@code MATRIX_SUPPLIED}. */
protected int m_MatrixSource;
/** Directory searched for on-demand cost files (presumably returned by getOnDemandDirectory — confirm). */
protected File m_OnDemandDirectory;
/** Path of an old-format cost file, read when no cost matrix has been set. */
protected String m_CostFile;
/** The cost matrix used for reweighting training data or for expected-cost prediction. */
protected CostMatrix m_CostMatrix;
/** When true, predict the class of minimum expected cost instead of reweighting the data. */
protected boolean m_MinimizeExpectedCost;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -