📄 logitboost.java
字号:
if (numFolds.length() != 0) { setNumFolds(Integer.parseInt(numFolds)); } else { setNumFolds(0); } String numRuns = Utils.getOption('R', options); if (numRuns.length() != 0) { setNumRuns(Integer.parseInt(numRuns)); } else { setNumRuns(1); } String thresholdString = Utils.getOption('P', options); if (thresholdString.length() != 0) { setWeightThreshold(Integer.parseInt(thresholdString)); } else { setWeightThreshold(100); } String precisionString = Utils.getOption('L', options); if (precisionString.length() != 0) { setLikelihoodThreshold(new Double(precisionString). doubleValue()); } else { setLikelihoodThreshold(-Double.MAX_VALUE); } String shrinkageString = Utils.getOption('H', options); if (shrinkageString.length() != 0) { setShrinkage(new Double(shrinkageString). doubleValue()); } else { setShrinkage(1.0); } setUseResampling(Utils.getFlag('Q', options)); if (m_UseResampling && (thresholdString.length() != 0)) { throw new Exception("Weight pruning with resampling"+ "not allowed."); } super.setOptions(options); } /** * Gets the current settings of the Classifier. 
* * @return an array of strings suitable for passing to setOptions */ public String [] getOptions() { String [] superOptions = super.getOptions(); String [] options = new String [superOptions.length + 10]; int current = 0; if (getUseResampling()) { options[current++] = "-Q"; } else { options[current++] = "-P"; options[current++] = "" + getWeightThreshold(); } options[current++] = "-F"; options[current++] = "" + getNumFolds(); options[current++] = "-R"; options[current++] = "" + getNumRuns(); options[current++] = "-L"; options[current++] = "" + getLikelihoodThreshold(); options[current++] = "-H"; options[current++] = "" + getShrinkage(); System.arraycopy(superOptions, 0, options, current, superOptions.length); current += superOptions.length; while (current < options.length) { options[current++] = ""; } return options; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String shrinkageTipText() { return "Shrinkage parameter (use small value like 0.1 to reduce " + "overfitting)."; } /** * Get the value of Shrinkage. * * @return Value of Shrinkage. */ public double getShrinkage() { return m_Shrinkage; } /** * Set the value of Shrinkage. * * @param newShrinkage Value to assign to Shrinkage. */ public void setShrinkage(double newShrinkage) { m_Shrinkage = newShrinkage; } /** * Returns the tip text for this property * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String likelihoodThresholdTipText() { return "Threshold on improvement in likelihood."; } /** * Get the value of Precision. * * @return Value of Precision. */ public double getLikelihoodThreshold() { return m_Precision; } /** * Set the value of Precision. * * @param newPrecision Value to assign to Precision. 
*/
public void setLikelihoodThreshold(double newPrecision) {
  this.m_Precision = newPrecision;
}

/**
 * Returns the tip text describing the number of cross-validation runs.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String numRunsTipText() {
  final String tip = "Number of runs for internal cross-validation.";
  return tip;
}

/**
 * Returns the number of runs held in m_NumRuns.
 *
 * @return Value of NumRuns.
 */
public int getNumRuns() {
  return this.m_NumRuns;
}

/**
 * Stores a new number of runs in m_NumRuns.
 *
 * @param newNumRuns Value to assign to NumRuns.
 */
public void setNumRuns(int newNumRuns) {
  this.m_NumRuns = newNumRuns;
}

/**
 * Returns the tip text describing the number of cross-validation folds.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String numFoldsTipText() {
  final String tip = "Number of folds for internal cross-validation (default 0 "
    + "means no cross-validation is performed).";
  return tip;
}

/**
 * Returns the number of folds held in m_NumFolds.
 *
 * @return Value of NumFolds.
 */
public int getNumFolds() {
  return this.m_NumFolds;
}

/**
 * Stores a new number of folds in m_NumFolds.
 *
 * @param newNumFolds Value to assign to NumFolds.
*/
public void setNumFolds(int newNumFolds) {
  this.m_NumFolds = newNumFolds;
}

/**
 * Returns the tip text describing the resampling switch.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String useResamplingTipText() {
  final String tip = "Whether resampling is used instead of reweighting.";
  return tip;
}

/**
 * Switches resampling mode on or off by storing the flag in m_UseResampling.
 *
 * @param r true if resampling should be done
 */
public void setUseResampling(boolean r) {
  this.m_UseResampling = r;
}

/**
 * Reports the resampling flag held in m_UseResampling.
 *
 * @return true if resampling is turned on
 */
public boolean getUseResampling() {
  return this.m_UseResampling;
}

/**
 * Returns the tip text describing the weight threshold.
 *
 * @return tip text for this property suitable for
 * displaying in the explorer/experimenter gui
 */
public String weightThresholdTipText() {
  final String tip = "Weight threshold for weight pruning (reduce to 90 "
    + "for speeding up learning process).";
  return tip;
}

/**
 * Stores a new weight threshold in m_WeightThreshold.
 *
 * @param threshold the percentage of weight mass used for training
 */
public void setWeightThreshold(int threshold) {
  this.m_WeightThreshold = threshold;
}

/**
 * Returns the weight threshold held in m_WeightThreshold.
 *
 * @return the percentage of weight mass used for training
 */
public int getWeightThreshold() {
  return this.m_WeightThreshold;
}

/**
 * Returns default capabilities of the classifier.
* * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); // class result.disableAllClasses(); result.disableAllClassDependencies(); result.enable(Capability.NOMINAL_CLASS); return result; } /** * Builds the boosted classifier * * @param data the data to train the classifier with * @throws Exception if building fails, e.g., can't handle data */ public void buildClassifier(Instances data) throws Exception { m_RandomInstance = new Random(m_Seed); int classIndex = data.classIndex(); if (m_Classifier == null) { throw new Exception("A base classifier has not been specified!"); } if (!(m_Classifier instanceof WeightedInstancesHandler) && !m_UseResampling) { m_UseResampling = true; } // can classifier handle the data? getCapabilities().testWithFail(data); if (m_Debug) { System.err.println("Creating copy of the training data"); } // remove instances with missing class data = new Instances(data); data.deleteWithMissingClass(); m_NumClasses = data.numClasses(); m_ClassAttribute = data.classAttribute(); // Create the base classifiers if (m_Debug) { System.err.println("Creating base classifiers"); } m_Classifiers = new Classifier [m_NumClasses][]; for (int j = 0; j < m_NumClasses; j++) { m_Classifiers[j] = Classifier.makeCopies(m_Classifier, getNumIterations()); } // Do we want to select the appropriate number of iterations // using cross-validation? 
int bestNumIterations = getNumIterations(); if (m_NumFolds > 1) { if (m_Debug) { System.err.println("Processing first fold."); } // Array for storing the results double[] results = new double[getNumIterations()]; // Iterate throught the cv-runs for (int r = 0; r < m_NumRuns; r++) { // Stratify the data data.randomize(m_RandomInstance); data.stratify(m_NumFolds); // Perform the cross-validation for (int i = 0; i < m_NumFolds; i++) { // Get train and test folds Instances train = data.trainCV(m_NumFolds, i, m_RandomInstance); Instances test = data.testCV(m_NumFolds, i); // Make class numeric Instances trainN = new Instances(train); trainN.setClassIndex(-1); trainN.deleteAttributeAt(classIndex); trainN.insertAttributeAt(new Attribute("'pseudo class'"), classIndex); trainN.setClassIndex(classIndex); m_NumericClassData = new Instances(trainN, 0); // Get class values int numInstances = train.numInstances(); double [][] trainFs = new double [numInstances][m_NumClasses]; double [][] trainYs = new double [numInstances][m_NumClasses]; for (int j = 0; j < m_NumClasses; j++) { for (int k = 0; k < numInstances; k++) { trainYs[k][j] = (train.instance(k).classValue() == j) ? 1.0 - m_Offset: 0.0 + (m_Offset / (double)m_NumClasses); } } // Perform iterations double[][] probs = initialProbs(numInstances); m_NumGenerated = 0; double sumOfWeights = train.sumOfWeights(); for (int j = 0; j < getNumIterations(); j++) { performIteration(trainYs, trainFs, probs, trainN, sumOfWeights); Evaluation eval = new Evaluation(train); eval.evaluateModel(this, test); results[j] += eval.correct(); } } } // Find the number of iterations with the lowest error double bestResult = -Double.MAX_VALUE; for (int j = 0; j < getNumIterations(); j++) { if (results[j] > bestResult) { bestResult = results[j]; bestNumIterations = j; } } if (m_Debug) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -