📄 LogitBoost.java
  /**
   * Set the classifier for boosting.
   *
   * @param newClassifier the Classifier to use.
   */
  public void setClassifier(Classifier newClassifier) {
    m_Classifier = newClassifier;
  }

  /**
   * Get the classifier used as the base classifier.
   *
   * @return the classifier used as the base classifier
   */
  public Classifier getClassifier() {
    return m_Classifier;
  }

  /**
   * Set the maximum number of boost iterations.
   *
   * @param maxIterations the maximum number of boost iterations
   */
  public void setMaxIterations(int maxIterations) {
    m_MaxIterations = maxIterations;
  }

  /**
   * Get the maximum number of boost iterations.
   *
   * @return the maximum number of boost iterations
   */
  public int getMaxIterations() {
    return m_MaxIterations;
  }

  /**
   * Set weight thresholding.
   *
   * @param threshold the percentage of weight mass used for training
   */
  public void setWeightThreshold(int threshold) {
    m_WeightThreshold = threshold;
  }

  /**
   * Get the degree of weight thresholding.
   *
   * @return the percentage of weight mass used for training
   */
  public int getWeightThreshold() {
    return m_WeightThreshold;
  }

  /**
   * Set debugging mode.
   *
   * @param debug true if debug output should be printed
   */
  public void setDebug(boolean debug) {
    m_Debug = debug;
  }

  /**
   * Get whether debugging is turned on.
   *
   * @return true if debugging output is on
   */
  public boolean getDebug() {
    return m_Debug;
  }

  /**
   * Boosting method. Boosts any classifier that can handle weighted
   * instances.
   *
   * @param data the training data to be used for generating the
   * boosted classifier.
   * @exception Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances data) throws Exception {

    Random randomInstance = new Random(m_Seed);
    Instances boostData, trainData;
    int classIndex = data.classIndex();

    if (data.classAttribute().isNumeric()) {
      throw new Exception("LogitBoost can't handle a numeric class!");
    }
    if (m_Classifier == null) {
      throw new Exception("A base classifier has not been specified!");
    }
    if (!(m_Classifier instanceof WeightedInstancesHandler) &&
        !m_UseResampling) {
      m_UseResampling = true;
    }
    if (data.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }
    if (m_Debug) {
      System.err.println("Creating copy of the training data");
    }

    m_NumClasses = data.numClasses();
    m_ClassAttribute = data.classAttribute();

    // Create a copy of the data with the class transformed into numeric
    boostData = new Instances(data);
    boostData.deleteWithMissingClass();
    int numInstances = boostData.numInstances();

    // Temporarily unset the class index, then replace the nominal class
    // with a numeric "pseudo class" that will hold the working response
    boostData.setClassIndex(-1);
    boostData.deleteAttributeAt(classIndex);
    boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
    boostData.setClassIndex(classIndex);
    m_NumericClassData = new Instances(boostData, 0);

    double [][] trainFs = new double [numInstances][m_NumClasses];
    double [][] trainYs = new double [numInstances][m_NumClasses];

    // One-hot encode the class: trainYs[i][j] is 1 iff instance i has
    // class j; instances with a missing class were dropped from boostData,
    // so the inner while loop keeps the two index sequences aligned
    for (int j = 0; j < m_NumClasses; j++) {
      for (int i = 0, k = 0; i < numInstances; i++, k++) {
        while (data.instance(k).classIsMissing()) k++;
        trainYs[i][j] = (data.instance(k).classValue() == j) ? 1 : 0;
      }
    }
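    // Each boosting round below performs one quasi-Newton step on the
    // multinomial log-likelihood (LogitBoost; Friedman, Hastie & Tibshirani,
    // 2000, "Additive Logistic Regression"): with
    //   p_ij = exp(F_ij) / sum_k exp(F_ik)
    // each class j gets one regression fit to the working response
    //   z_ij = (y_ij - p_ij) / (p_ij * (1 - p_ij)),  clipped to +/- Z_MAX,
    // using case weights w_ij = p_ij * (1 - p_ij), floored at VERY_SMALL
    // (and rescaled by numInstances in the code below).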
    if (m_Debug) {
      System.err.println("Creating base classifiers");
    }

    // Create the base classifiers
    m_Classifiers = new Classifier [m_NumClasses][];
    for (int j = 0; j < m_NumClasses; j++) {
      m_Classifiers[j] = Classifier.makeCopies(m_Classifier,
                                               getMaxIterations());
    }

    // Do boosting iterations
    for (m_NumIterations = 0; m_NumIterations < getMaxIterations();
         m_NumIterations++) {
      if (m_Debug) {
        System.err.println("Training classifier " + (m_NumIterations + 1));
      }
      for (int j = 0; j < m_NumClasses; j++) {
        if (m_Debug) {
          System.err.println("\t...for class " + (j + 1)
                             + " (" + m_ClassAttribute.name()
                             + "=" + m_ClassAttribute.value(j) + ")");
        }

        // Set instance pseudoclass and weights
        for (int i = 0; i < numInstances; i++) {
          double p = RtoP(trainFs[i], j);
          Instance current = boostData.instance(i);
          double z, actual = trainYs[i][j];
          if (actual == 1) {
            z = 1.0 / p;       // (1 - p) / (p * (1 - p))
            if (z > Z_MAX) {   // threshold
              z = Z_MAX;
            }
          } else if (actual == 0) {
            z = -1.0 / (1.0 - p);
            if (z < -Z_MAX) {  // threshold
              z = -Z_MAX;
            }
          } else {
            z = (actual - p) / (p * (1 - p));
          }
          double w = Math.max(p * (1 - p), VERY_SMALL);
          current.setValue(classIndex, z);
          current.setWeight(numInstances * w);
        }

        // Select instances to train the classifier on
        if (m_WeightThreshold < 100) {
          trainData = selectWeightQuantile(boostData,
                                           (double) m_WeightThreshold / 100);
        } else {
          trainData = new Instances(boostData, 0, numInstances);
          if (m_UseResampling) {
            double[] weights = new double[boostData.numInstances()];
            for (int kk = 0; kk < weights.length; kk++) {
              weights[kk] = boostData.instance(kk).weight();
            }
            trainData = boostData.resampleWithWeights(randomInstance,
                                                      weights);
          }
        }

        // Build the classifier
        m_Classifiers[j][m_NumIterations].buildClassifier(trainData);
      }

      // Evaluate / increment trainFs from the classifier: center the
      // per-class predictions and scale by (J - 1) / J so that the F_j
      // sum to zero (symmetric multi-logistic parameterization)
      for (int i = 0; i < numInstances; i++) {
        double [] pred = new double [m_NumClasses];
        double predSum = 0;
        for (int j = 0; j < m_NumClasses; j++) {
          pred[j] = m_Classifiers[j][m_NumIterations]
            .classifyInstance(boostData.instance(i));
          predSum += pred[j];
        }
        predSum /= m_NumClasses;
        for (int j = 0; j < m_NumClasses; j++) {
          trainFs[i][j] += (pred[j] - predSum) * (m_NumClasses - 1)
            / m_NumClasses;
        }
      }
    }
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if instance could not be classified
   * successfully
   */
  public double [] distributionForInstance(Instance instance)
    throws Exception {

    instance = (Instance) instance.copy();
    instance.setDataset(m_NumericClassData);
    double [] Fs = new double [m_NumClasses];

    // Replay the additive model: accumulate the centered, scaled vote of
    // every base classifier, exactly as during training
    for (int i = 0; i < m_NumIterations; i++) {
      double [] Fi = new double [m_NumClasses];
      double Fsum = 0;
      for (int j = 0; j < m_NumClasses; j++) {
        Fi[j] = m_Classifiers[j][i].classifyInstance(instance);
        Fsum += Fi[j];
      }
      Fsum /= m_NumClasses;
      for (int j = 0; j < m_NumClasses; j++) {
        Fs[j] += (Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses;
      }
    }

    // Map the F scores to probabilities
    double [] distribution = new double [m_NumClasses];
    for (int j = 0; j < m_NumClasses; j++) {
      distribution[j] = RtoP(Fs, j);
    }
    Utils.normalize(distribution);
    return distribution;
  }
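  // toSource() below writes the boosted model out as standalone Java: a
  // driver class containing RtoP(), classify() and distribution(), followed
  // by one generated class per (class, iteration) pair, named
  // <className>_<class>_<iteration>. Every base model must therefore
  // implement the Sourcable interface.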
  /**
   * Returns the boosted model as Java source code.
   *
   * @param className the name that should be given to the generated class
   * @return the model as Java source code
   * @exception Exception if something goes wrong
   */
  public String toSource(String className) throws Exception {

    if (m_NumIterations == 0) {
      throw new Exception("No model built yet");
    }
    if (!(m_Classifiers[0][0] instanceof Sourcable)) {
      throw new Exception("Base learner " + m_Classifier.getClass().getName()
                          + " is not Sourcable");
    }

    StringBuffer text = new StringBuffer("class ");
    text.append(className).append(" {\n\n");

    // Numerator centered like the denominator, so the returned
    // values form a proper distribution and cannot overflow
    text.append("  private static double RtoP(double []R, int j) {\n" +
                "    double Rcenter = 0;\n" +
                "    for (int i = 0; i < R.length; i++) {\n" +
                "      Rcenter += R[i];\n" +
                "    }\n" +
                "    Rcenter /= R.length;\n" +
                "    double Rsum = 0;\n" +
                "    for (int i = 0; i < R.length; i++) {\n" +
                "      Rsum += Math.exp(R[i] - Rcenter);\n" +
                "    }\n" +
                "    return Math.exp(R[j] - Rcenter) / Rsum;\n" +
                "  }\n\n");

    text.append("  public static double classify(Object [] i) {\n" +
                "    double [] d = distribution(i);\n" +
                "    double maxV = d[0];\n" +
                "    int maxI = 0;\n" +
                "    for (int j = 1; j < " + m_NumClasses + "; j++) {\n" +
                "      if (d[j] > maxV) { maxV = d[j]; maxI = j; }\n" +
                "    }\n    return (double) maxI;\n  }\n\n");

    text.append("  public static double [] distribution(Object [] i) {\n");
    text.append("    double [] Fs = new double [" + m_NumClasses + "];\n");
    text.append("    double [] Fi = new double [" + m_NumClasses + "];\n");
    text.append("    double Fsum;\n");

    // One block per boosting iteration, replaying the training-time vote
    for (int i = 0; i < m_NumIterations; i++) {
      text.append("    Fsum = 0;\n");
      for (int j = 0; j < m_NumClasses; j++) {
        text.append("    Fi[" + j + "] = " + className + '_' + j + '_' + i
                    + ".classify(i); Fsum += Fi[" + j + "];\n");
      }
      text.append("    Fsum /= " + m_NumClasses + ";\n");
      text.append("    for (int j = 0; j < " + m_NumClasses + "; j++) {");
      text.append(" Fs[j] += (Fi[j] - Fsum) * " + (m_NumClasses - 1)
                  + " / " + m_NumClasses + "; }\n");
    }

    text.append("    double [] dist = new double [" + m_NumClasses + "];\n" +
                "    for (int j = 0; j < " + m_NumClasses + "; j++) {\n" +
                "      dist[j] = RtoP(Fs, j);\n" +
                "    }\n    return dist;\n");
    text.append("  }\n}\n");

    // Append the source of every base model
    for (int i = 0; i < m_Classifiers.length; i++) {
      for (int j = 0; j < m_Classifiers[i].length; j++) {
        text.append(((Sourcable) m_Classifiers[i][j])
                    .toSource(className + '_' + i + '_' + j));
      }
    }
    return text.toString();
  }

  /**
   * Returns description of the boosted classifier.
   *
   * @return description of the boosted classifier as a string
   */
  public String toString() {

    StringBuffer text = new StringBuffer();

    if (m_NumIterations == 0) {
      text.append("LogitBoost: No model built yet.");
    } else {
      text.append("LogitBoost: Base classifiers and their weights: \n");
      for (int i = 0; i < m_NumIterations; i++) {
        text.append("\nIteration " + (i + 1));
        for (int j = 0; j < m_NumClasses; j++) {
          text.append("\n\tClass " + (j + 1)
                      + " (" + m_ClassAttribute.name()
                      + "=" + m_ClassAttribute.value(j) + ")\n\n"
                      + m_Classifiers[j][i].toString() + "\n");
        }
      }
      text.append("Number of performed iterations: "
                  + m_NumIterations + "\n");
    }
    return text.toString();
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {

    try {
      System.out.println(Evaluation.evaluateModel(new LogitBoost(), argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}
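For orientation, here is a minimal usage sketch; it is not part of the original file. It assumes a training set "train.arff" whose last attribute is the nominal class, and uses DecisionStump as the base learner. The import paths match the older WEKA releases this file belongs to; newer releases moved DecisionStump to weka.classifiers.trees.DecisionStump.

// Usage sketch -- the file name and base learner choice are illustrative,
// not prescribed by LogitBoost.java itself.
import java.io.BufferedReader;
import java.io.FileReader;

import weka.classifiers.DecisionStump; // weka.classifiers.trees.DecisionStump in newer WEKA
import weka.core.Instances;

public class LogitBoostDemo {

  public static void main(String[] args) throws Exception {
    // Load the data and mark the last attribute as the (nominal) class
    Instances data = new Instances(
        new BufferedReader(new FileReader("train.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    LogitBoost lb = new LogitBoost();
    lb.setClassifier(new DecisionStump()); // regression stump fit to the working response
    lb.setMaxIterations(10);
    lb.buildClassifier(data);

    // Class membership probabilities for the first training instance
    double[] dist = lb.distributionForInstance(data.instance(0));
    for (int j = 0; j < dist.length; j++) {
      System.out.println(data.classAttribute().value(j) + ": " + dist[j]);
    }
  }
}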