logitboost.java
    // Replace the class attribute by a numeric 'pseudo class' so that the
    // base learners can regress on the working response z set in
    // performIteration()
    data.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
    data.setClassIndex(classIndex);
    m_NumericClassData = new Instances(data, 0);

    // Perform iterations
    double[][] probs = initialProbs(numInstances);
    double logLikelihood = logLikelihood(trainYs, probs);
    m_NumGenerated = 0;
    if (m_Debug) {
      System.err.println("Avg. log-likelihood: " + logLikelihood);
    }
    double sumOfWeights = data.sumOfWeights();
    for (int j = 0; j < bestNumIterations; j++) {
      double previousLoglikelihood = logLikelihood;
      performIteration(trainYs, trainFs, probs, data, sumOfWeights);
      logLikelihood = logLikelihood(trainYs, probs);
      if (m_Debug) {
        System.err.println("Avg. log-likelihood: " + logLikelihood);
      }
      // Stop early once the log-likelihood has effectively converged
      if (Math.abs(previousLoglikelihood - logLikelihood) < m_Precision) {
        return;
      }
    }
  }
  /**
   * Gets the initial class probabilities.
   */
  private double[][] initialProbs(int numInstances) {

    double[][] probs = new double[numInstances][m_NumClasses];
    for (int i = 0; i < numInstances; i++) {
      for (int j = 0; j < m_NumClasses; j++) {
        probs[i][j] = 1.0 / m_NumClasses;
      }
    }
    return probs;
  }
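  // With F scores of zero for every class, the softmax in probs() yields
  // exactly these uniform probabilities, so starting from p = 1/m_NumClasses
  // is consistent with starting the committee at F = 0, as in the original
  // LogitBoost formulation.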
  /**
   * Computes the negative log-likelihood, averaged over the training
   * instances, from the class memberships and the estimated class
   * probabilities.
   */
  private double logLikelihood(double[][] trainYs, double[][] probs) {

    double logLikelihood = 0;
    for (int i = 0; i < trainYs.length; i++) {
      for (int j = 0; j < m_NumClasses; j++) {
        if (trainYs[i][j] == 1.0 - m_Offset) {
          logLikelihood -= Math.log(probs[i][j]);
        }
      }
    }
    return logLikelihood / (double)trainYs.length;
  }
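  // In symbols, the value returned above is -(1/N) * sum_i log p_i(y_i):
  // the negative log-likelihood of the true classes, averaged over the N
  // training instances. It starts at log(m_NumClasses) for the uniform
  // initial probabilities and should decrease as boosting proceeds.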
  /**
   * Performs one boosting iteration.
   */
  private void performIteration(double[][] trainYs,
                                double[][] trainFs,
                                double[][] probs,
                                Instances data,
                                double origSumOfWeights) throws Exception {

    if (m_Debug) {
      System.err.println("Training classifier " + (m_NumGenerated + 1));
    }

    // Build the new models
    for (int j = 0; j < m_NumClasses; j++) {
      if (m_Debug) {
        System.err.println("\t...for class " + (j + 1)
                           + " (" + m_ClassAttribute.name()
                           + "=" + m_ClassAttribute.value(j) + ")");
      }

      // Make copy because we want to save the weights
      Instances boostData = new Instances(data);

      // Set instance pseudoclass and weights
      for (int i = 0; i < probs.length; i++) {

        // Compute response and weight
        double p = probs[i][j];
        double z, actual = trainYs[i][j];
        if (actual == 1 - m_Offset) {
          z = 1.0 / p;
          if (z > Z_MAX) { // threshold
            z = Z_MAX;
          }
        } else {
          z = -1.0 / (1.0 - p);
          if (z < -Z_MAX) { // threshold
            z = -Z_MAX;
          }
        }
        double w = (actual - p) / z;

        // Set values for instance
        Instance current = boostData.instance(i);
        current.setValue(boostData.classIndex(), z);
        current.setWeight(current.weight() * w);
      }
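      // Substituting either branch into w = (actual - p) / z shows that
      // w = p * (1 - p) in both cases, so z is the Newton-step working
      // response (y - p) / (p * (1 - p)) from Friedman, Hastie & Tibshirani's
      // LogitBoost, with the Z_MAX cap guarding against blow-up when p is
      // close to 0 or 1.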
      // Scale the weights (helps with some base learners)
      double sumOfWeights = boostData.sumOfWeights();
      double scalingFactor = (double)origSumOfWeights / sumOfWeights;
      for (int i = 0; i < probs.length; i++) {
        Instance current = boostData.instance(i);
        current.setWeight(current.weight() * scalingFactor);
      }

      // Select instances to train the classifier on
      Instances trainData = boostData;
      if (m_WeightThreshold < 100) {
        trainData = selectWeightQuantile(boostData,
                                         (double)m_WeightThreshold / 100);
      } else {
        if (m_UseResampling) {
          double[] weights = new double[boostData.numInstances()];
          for (int kk = 0; kk < weights.length; kk++) {
            weights[kk] = boostData.instance(kk).weight();
          }
          trainData = boostData.resampleWithWeights(m_RandomInstance,
                                                    weights);
        }
      }
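      // Both devices above are forms of weight trimming: well-fit instances
      // carry almost no weight, so training the base learner on the heaviest
      // m_WeightThreshold percent of the weight mass (or on a weight-based
      // resample) saves time at little cost in accuracy.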
      // Build the classifier
      m_Classifiers[j][m_NumGenerated].buildClassifier(trainData);
    }

    // Evaluate / increment trainFs from the classifier
    for (int i = 0; i < trainFs.length; i++) {
      double [] pred = new double [m_NumClasses];
      double predSum = 0;
      for (int j = 0; j < m_NumClasses; j++) {
        pred[j] = m_Shrinkage * m_Classifiers[j][m_NumGenerated]
          .classifyInstance(data.instance(i));
        predSum += pred[j];
      }
      predSum /= m_NumClasses;
      for (int j = 0; j < m_NumClasses; j++) {
        trainFs[i][j] += (pred[j] - predSum) * (m_NumClasses - 1)
          / m_NumClasses;
      }
    }
    m_NumGenerated++;

    // Compute the current probability estimates
    for (int i = 0; i < trainYs.length; i++) {
      probs[i] = probs(trainFs[i]);
    }
  }
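  // The (pred[j] - predSum) * (m_NumClasses - 1) / m_NumClasses update in
  // performIteration() is the symmetric multi-class LogitBoost step:
  // centring the per-class predictions keeps the F scores summing to zero
  // over the classes, and (K - 1) / K is the Newton scaling for the K-class
  // logistic model.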
  /**
   * Returns the array of classifiers that have been built.
   */
  public Classifier[][] classifiers() {

    Classifier[][] classifiers =
      new Classifier[m_NumClasses][m_NumGenerated];
    for (int j = 0; j < m_NumClasses; j++) {
      for (int i = 0; i < m_NumGenerated; i++) {
        classifiers[j][i] = m_Classifiers[j][i];
      }
    }
    return classifiers;
  }
  /**
   * Computes probabilities from F scores.
   */
  private double[] probs(double[] Fs) {

    double maxF = -Double.MAX_VALUE;
    for (int i = 0; i < Fs.length; i++) {
      if (Fs[i] > maxF) {
        maxF = Fs[i];
      }
    }
    double sum = 0;
    double[] probs = new double[Fs.length];
    for (int i = 0; i < Fs.length; i++) {
      probs[i] = Math.exp(Fs[i] - maxF);
      sum += probs[i];
    }
    Utils.normalize(probs, sum);
    return probs;
  }
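  // probs() is a numerically stabilised softmax: subtracting maxF before
  // exponentiating leaves p_j = exp(F_j) / sum_k exp(F_k) unchanged but
  // prevents Math.exp from overflowing for large F scores.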
  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if instance could not be classified
   * successfully
   */
  public double [] distributionForInstance(Instance instance)
    throws Exception {

    instance = (Instance)instance.copy();
    instance.setDataset(m_NumericClassData);
    double [] pred = new double [m_NumClasses];
    double [] Fs = new double [m_NumClasses];
    for (int i = 0; i < m_NumGenerated; i++) {
      double predSum = 0;
      for (int j = 0; j < m_NumClasses; j++) {
        // Apply the same shrinkage as during training, so that the F scores
        // match those the probability model was fitted to
        pred[j] = m_Shrinkage * m_Classifiers[j][i].classifyInstance(instance);
        predSum += pred[j];
      }
      predSum /= m_NumClasses;
      for (int j = 0; j < m_NumClasses; j++) {
        Fs[j] += (pred[j] - predSum) * (m_NumClasses - 1)
          / m_NumClasses;
      }
    }
    return probs(Fs);
  }
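  // A minimal usage sketch, assuming the standard WEKA classifier API
  // (buildClassifier on a weka.core.Instances object whose class attribute
  // has been set, then distributionForInstance per test instance):
  //
  //   LogitBoost lb = new LogitBoost();
  //   lb.setNumIterations(10);           // number of boosting rounds
  //   lb.buildClassifier(trainingData);
  //   double[] dist = lb.distributionForInstance(testInstance);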
  /**
   * Returns the boosted model as Java source code.
   *
   * @param className the name to use for the generated class
   * @return the boosted model as Java source code
   * @exception Exception if something goes wrong
   */
  public String toSource(String className) throws Exception {

    if (m_NumGenerated == 0) {
      throw new Exception("No model built yet");
    }
    if (!(m_Classifiers[0][0] instanceof Sourcable)) {
      throw new Exception("Base learner " + m_Classifier.getClass().getName()
                          + " is not Sourcable");
    }
    StringBuffer text = new StringBuffer("class ");
    text.append(className).append(" {\n\n");
    text.append("  private static double RtoP(double []R, int j) {\n"+
                "    double Rcenter = 0;\n"+
                "    for (int i = 0; i < R.length; i++) {\n"+
                "      Rcenter += R[i];\n"+
                "    }\n"+
                "    Rcenter /= R.length;\n"+
                "    double Rsum = 0;\n"+
                "    for (int i = 0; i < R.length; i++) {\n"+
                "      Rsum += Math.exp(R[i] - Rcenter);\n"+
                "    }\n"+
                // Centre the numerator as well, so the generated RtoP
                // computes the same softmax as probs() above
                "    return Math.exp(R[j] - Rcenter) / Rsum;\n"+
                "  }\n\n");
text.append(" public static double classify(Object [] i) {\n" +
" double [] d = distribution(i);\n" +
" double maxV = d[0];\n" +
" int maxI = 0;\n"+
" for (int j = 1; j < " + m_NumClasses + "; j++) {\n"+
" if (d[j] > maxV) { maxV = d[j]; maxI = j; }\n"+
" }\n return (double) maxI;\n }\n\n");
text.append(" public static double [] distribution(Object [] i) {\n");
text.append(" double [] Fs = new double [" + m_NumClasses + "];\n");
text.append(" double [] Fi = new double [" + m_NumClasses + "];\n");
text.append(" double Fsum;\n");
for (int i = 0; i < m_NumGenerated; i++) {
text.append(" Fsum = 0;\n");
for (int j = 0; j < m_NumClasses; j++) {
text.append(" Fi[" + j + "] = " + className + '_' +j + '_' + i
+ ".classify(i); Fsum += Fi[" + j + "];\n");
}
text.append(" Fsum /= " + m_NumClasses + ";\n");
text.append(" for (int j = 0; j < " + m_NumClasses + "; j++) {");
text.append(" Fs[j] += (Fi[j] - Fsum) * "
+ (m_NumClasses - 1) + " / " + m_NumClasses + "; }\n");
}
text.append(" double [] dist = new double [" + m_NumClasses + "];\n" +
" for (int j = 0; j < " + m_NumClasses + "; j++) {\n"+
" dist[j] = RtoP(Fs, j);\n"+
" }\n return dist;\n");
text.append(" }\n}\n");
for (int i = 0; i < m_Classifiers.length; i++) {
for (int j = 0; j < m_Classifiers[i].length; j++) {
text.append(((Sourcable)m_Classifiers[i][j])
.toSource(className + '_' + i + '_' + j));
}
}
return text.toString();
}
  /**
   * Returns description of the boosted classifier.
   *
   * @return description of the boosted classifier as a string
   */
  public String toString() {

    StringBuffer text = new StringBuffer();

    if (m_NumGenerated == 0) {
      text.append("LogitBoost: No model built yet.");
      //      text.append(m_Classifiers[0].toString()+"\n");
    } else {
      text.append("LogitBoost: Base classifiers and their weights: \n");
      for (int i = 0; i < m_NumGenerated; i++) {
        text.append("\nIteration " + (i + 1));
        for (int j = 0; j < m_NumClasses; j++) {
          text.append("\n\tClass " + (j + 1)
                      + " (" + m_ClassAttribute.name()
                      + "=" + m_ClassAttribute.value(j) + ")\n\n"
                      + m_Classifiers[j][i].toString() + "\n");
        }
      }
      text.append("Number of performed iterations: " +
                  m_NumGenerated + "\n");
    }
    return text.toString();
  }
  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {

    try {
      System.out.println(Evaluation.evaluateModel(new LogitBoost(), argv));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
}