LogitBoost.java
/**
 * Gets the current settings of the Classifier.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String [] getOptions() {

  String [] classifierOptions = new String [0];
  if ((m_Classifier != null) &&
      (m_Classifier instanceof OptionHandler)) {
    classifierOptions = ((OptionHandler)m_Classifier).getOptions();
  }

  String [] options = new String [classifierOptions.length + 17];
  int current = 0;
  if (getDebug()) {
    options[current++] = "-D";
  }
  if (getUseResampling()) {
    options[current++] = "-Q";
  } else {
    options[current++] = "-P";
    options[current++] = "" + getWeightThreshold();
  }
  if (getSeed() != 1) {
    options[current++] = "-S";
    options[current++] = "" + getSeed();
  }
  options[current++] = "-I";
  options[current++] = "" + getMaxIterations();
  options[current++] = "-F";
  options[current++] = "" + getNumFolds();
  options[current++] = "-R";
  options[current++] = "" + getNumRuns();
  options[current++] = "-L";
  options[current++] = "" + getLikelihoodThreshold();
  options[current++] = "-H";
  options[current++] = "" + getShrinkage();

  if (getClassifier() != null) {
    options[current++] = "-W";
    options[current++] = getClassifier().getClass().getName();
  }
  options[current++] = "--";
  System.arraycopy(classifierOptions, 0, options, current,
                   classifierOptions.length);
  current += classifierOptions.length;
  while (current < options.length) {
    options[current++] = "";
  }
  return options;
}

/**
 * Get the value of Shrinkage.
 *
 * @return Value of Shrinkage.
 */
public double getShrinkage() {
  return m_Shrinkage;
}

/**
 * Set the value of Shrinkage.
 *
 * @param newShrinkage Value to assign to Shrinkage.
 */
public void setShrinkage(double newShrinkage) {
  m_Shrinkage = newShrinkage;
}

/**
 * Get the value of Precision.
 *
 * @return Value of Precision.
 */
public double getLikelihoodThreshold() {
  return m_Precision;
}

/**
 * Set the value of Precision.
 *
 * @param newPrecision Value to assign to Precision.
 */
public void setLikelihoodThreshold(double newPrecision) {
  m_Precision = newPrecision;
}

/**
 * Get the value of NumRuns.
 *
 * @return Value of NumRuns.
 */
public int getNumRuns() {
  return m_NumRuns;
}

/**
 * Set the value of NumRuns.
 *
 * @param newNumRuns Value to assign to NumRuns.
 */
public void setNumRuns(int newNumRuns) {
  m_NumRuns = newNumRuns;
}

/**
 * Get the value of NumFolds.
 *
 * @return Value of NumFolds.
 */
public int getNumFolds() {
  return m_NumFolds;
}

/**
 * Set the value of NumFolds.
 *
 * @param newNumFolds Value to assign to NumFolds.
 */
public void setNumFolds(int newNumFolds) {
  m_NumFolds = newNumFolds;
}

/**
 * Set resampling mode
 *
 * @param r true if resampling should be done
 */
public void setUseResampling(boolean r) {
  m_UseResampling = r;
}

/**
 * Get whether resampling is turned on
 *
 * @return true if resampling is on
 */
public boolean getUseResampling() {
  return m_UseResampling;
}

/**
 * Set seed for resampling.
 *
 * @param seed the seed for resampling
 */
public void setSeed(int seed) {
  m_Seed = seed;
}

/**
 * Get seed for resampling.
 *
 * @return the seed for resampling
 */
public int getSeed() {
  return m_Seed;
}

/**
 * Set the classifier for boosting. The learner should be able to
 * handle numeric class attributes.
 *
 * @param newClassifier the Classifier to use.
 */
public void setClassifier(Classifier newClassifier) {
  m_Classifier = newClassifier;
}

/**
 * Get the classifier used as the base classifier
 *
 * @return the classifier used as the base classifier
 */
public Classifier getClassifier() {
  return m_Classifier;
}

/**
 * Set the maximum number of boost iterations
 *
 * @param maxIterations the maximum number of boost iterations
 */
public void setMaxIterations(int maxIterations) {
  m_MaxIterations = maxIterations;
}

/**
 * Get the maximum number of boost iterations
 *
 * @return the maximum number of boost iterations
 */
public int getMaxIterations() {
  return m_MaxIterations;
}

/**
 * Set weight thresholding
 *
 * @param threshold the percentage of weight mass used for training
 */
public void setWeightThreshold(int threshold) {
  m_WeightThreshold = threshold;
}

/**
 * Get the degree of weight thresholding
 *
 * @return the percentage of weight mass used for training
 */
public int getWeightThreshold() {
  return m_WeightThreshold;
}

/**
 * Set debugging mode
 *
 * @param debug true if debug output should be printed
 */
public void setDebug(boolean debug) {
  m_Debug = debug;
}

/**
 * Get whether debugging is turned on
 *
 * @return true if debugging output is on
 */
public boolean getDebug() {
  return m_Debug;
}

/**
 * Builds the boosted classifier
 */
public void buildClassifier(Instances data) throws Exception {

  m_RandomInstance = new Random(m_Seed);
  Instances boostData, trainData;
  int classIndex = data.classIndex();

  if (data.classAttribute().isNumeric()) {
    throw new UnsupportedClassTypeException("LogitBoost can't handle a numeric class!");
  }
  if (m_Classifier == null) {
    throw new Exception("A base classifier has not been specified!");
  }
  if (!(m_Classifier instanceof WeightedInstancesHandler) &&
      !m_UseResampling) {
    m_UseResampling = true;
  }
  if (data.checkForStringAttributes()) {
    throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
  }
  if (m_Debug) {
    System.err.println("Creating copy of the training data");
  }

  m_NumClasses = data.numClasses();
  m_ClassAttribute = data.classAttribute();

  // Create a copy of the data
  data = new Instances(data);
  data.deleteWithMissingClass();

  // Create the base classifiers
  if (m_Debug) {
    System.err.println("Creating base classifiers");
  }
  m_Classifiers = new Classifier [m_NumClasses][];
  for (int j = 0; j < m_NumClasses; j++) {
    m_Classifiers[j] = Classifier.makeCopies(m_Classifier,
                                             getMaxIterations());
  }

  // Do we want to select the appropriate number of iterations
  // using cross-validation?
  int bestNumIterations = getMaxIterations();
  if (m_NumFolds > 1) {
    if (m_Debug) {
      System.err.println("Processing first fold.");
    }

    // Array for storing the results
    double[] results = new double[getMaxIterations()];

    // Iterate through the cv-runs
    for (int r = 0; r < m_NumRuns; r++) {

      // Stratify the data
      data.randomize(m_RandomInstance);
      data.stratify(m_NumFolds);

      // Perform the cross-validation
      for (int i = 0; i < m_NumFolds; i++) {

        // Get train and test folds
        Instances train = data.trainCV(m_NumFolds, i);
        Instances test = data.testCV(m_NumFolds, i);

        // Make class numeric
        Instances trainN = new Instances(train);
        trainN.setClassIndex(-1);
        trainN.deleteAttributeAt(classIndex);
        trainN.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
        trainN.setClassIndex(classIndex);
        m_NumericClassData = new Instances(trainN, 0);

        // Get class values
        int numInstances = train.numInstances();
        double [][] trainFs = new double [numInstances][m_NumClasses];
        double [][] trainYs = new double [numInstances][m_NumClasses];
        for (int j = 0; j < m_NumClasses; j++) {
          for (int k = 0; k < numInstances; k++) {
            trainYs[k][j] = (train.instance(k).classValue() == j) ?
              1.0 - m_Offset : 0.0 + (m_Offset / (double)m_NumClasses);
          }
        }

        // Perform iterations
        double[][] probs = initialProbs(numInstances);
        m_NumIterations = 0;
        double sumOfWeights = train.sumOfWeights();
        for (int j = 0; j < getMaxIterations(); j++) {
          performIteration(trainYs, trainFs, probs, trainN, sumOfWeights);
          Evaluation eval = new Evaluation(train);
          eval.evaluateModel(this, test);
          results[j] += eval.correct();
        }
      }
    }

    // Find the number of iterations with the lowest error
    double bestResult = -Double.MAX_VALUE;
    for (int j = 0; j < getMaxIterations(); j++) {
      if (results[j] > bestResult) {
        bestResult = results[j];
        bestNumIterations = j;
      }
    }
    if (m_Debug) {
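The excerpt above breaks off partway through buildClassifier's cross-validation block. For orientation, here is a minimal usage sketch that relies only on the setters and buildClassifier shown in this listing; the package names (weka.classifiers.meta.LogitBoost, weka.classifiers.trees.DecisionStump), the iris.arff path, and the choice of DecisionStump as the base learner are illustrative assumptions, not part of the file.

import java.io.BufferedReader;
import java.io.FileReader;

import weka.classifiers.meta.LogitBoost;      // assumed package for the class in this listing
import weka.classifiers.trees.DecisionStump;  // assumed base learner; handles the numeric pseudo class
import weka.core.Instances;

public class LogitBoostUsageSketch {

  public static void main(String[] args) throws Exception {
    // Load a nominal-class dataset (path is illustrative).
    Instances data = new Instances(new BufferedReader(new FileReader("iris.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    LogitBoost lb = new LogitBoost();
    lb.setClassifier(new DecisionStump()); // -W: base classifier
    lb.setMaxIterations(50);               // -I: number of boosting iterations
    lb.setNumRuns(1);                      // -R: cross-validation runs
    lb.setNumFolds(0);                     // -F: values > 1 enable CV-based selection of the iteration count
    lb.setSeed(1);                         // -S: resampling seed
    lb.setWeightThreshold(100);            // -P: percentage of weight mass used for training
    lb.buildClassifier(data);              // trains one base model per class per iteration

    System.out.println(lb);
  }
}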