📄 logitboost.java
字号:
"not allowed.");
}
super.setOptions(options);
}
/**
 * Gets the current settings of the Classifier.
 *
 * @return an array of strings suitable for passing to setOptions
 */
public String [] getOptions() {

  // This class's own options come first, followed by the superclass's.
  String[] superOptions = super.getOptions();
  String[] options = new String[superOptions.length + 10];
  int current = 0;

  // Resampling (-Q) and weight pruning (-P <threshold>) are mutually exclusive.
  if (getUseResampling()) {
    options[current++] = "-Q";
  } else {
    options[current++] = "-P";
    options[current++] = "" + getWeightThreshold();
  }
  options[current++] = "-F";
  options[current++] = "" + getNumFolds();
  options[current++] = "-R";
  options[current++] = "" + getNumRuns();
  options[current++] = "-L";
  options[current++] = "" + getLikelihoodThreshold();
  options[current++] = "-H";
  options[current++] = "" + getShrinkage();

  System.arraycopy(superOptions, 0, options, current, superOptions.length);
  current += superOptions.length;

  // Pad unused slots so no array entry is left null.
  while (current < options.length) {
    options[current++] = "";
  }
  return options;
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text suitable for displaying in the
 *         explorer/experimenter gui
 */
public String shrinkageTipText() {
  final String tip =
      "Shrinkage parameter (use small value like 0.1 to reduce overfitting).";
  return tip;
}
/**
 * Returns the current shrinkage parameter.
 *
 * @return the value of the Shrinkage property
 */
public double getShrinkage() {
  return m_Shrinkage;
}
/**
 * Sets the shrinkage parameter.
 *
 * @param newShrinkage the shrinkage value to use
 */
public void setShrinkage(double newShrinkage) {
  m_Shrinkage = newShrinkage;
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text suitable for displaying in the
 *         explorer/experimenter gui
 */
public String likelihoodThresholdTipText() {
  final String tip = "Threshold on improvement in likelihood.";
  return tip;
}
/**
 * Returns the threshold on the improvement of the likelihood.
 *
 * @return the value of the Precision property
 */
public double getLikelihoodThreshold() {
  return m_Precision;
}
/**
 * Sets the threshold on the improvement of the likelihood.
 *
 * @param newPrecision the threshold value to use
 */
public void setLikelihoodThreshold(double newPrecision) {
  m_Precision = newPrecision;
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text suitable for displaying in the
 *         explorer/experimenter gui
 */
public String numRunsTipText() {
  final String tip = "Number of runs for internal cross-validation.";
  return tip;
}
/**
 * Returns the number of runs used for internal cross-validation.
 *
 * @return the value of the NumRuns property
 */
public int getNumRuns() {
  return m_NumRuns;
}
/**
 * Sets the number of runs used for internal cross-validation.
 *
 * @param newNumRuns the number of runs to use
 */
public void setNumRuns(int newNumRuns) {
  m_NumRuns = newNumRuns;
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text suitable for displaying in the
 *         explorer/experimenter gui
 */
public String numFoldsTipText() {
  final String tip = "Number of folds for internal cross-validation (default 0 "
      + "means no cross-validation is performed).";
  return tip;
}
/**
 * Returns the number of folds used for internal cross-validation.
 *
 * @return the value of the NumFolds property
 */
public int getNumFolds() {
  return m_NumFolds;
}
/**
 * Sets the number of folds used for internal cross-validation.
 *
 * @param newNumFolds the number of folds to use
 */
public void setNumFolds(int newNumFolds) {
  m_NumFolds = newNumFolds;
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text suitable for displaying in the
 *         explorer/experimenter gui
 */
public String useResamplingTipText() {
  final String tip = "Whether resampling is used instead of reweighting.";
  return tip;
}
/**
 * Set resampling mode.
 *
 * @param r true if resampling should be done instead of reweighting
 */
public void setUseResampling(boolean r) {
m_UseResampling = r;
}
/**
 * Returns whether resampling is turned on.
 *
 * @return true if resampling is used instead of reweighting
 */
public boolean getUseResampling() {
  return m_UseResampling;
}
/**
 * Returns the tip text for this property.
 *
 * @return tip text suitable for displaying in the
 *         explorer/experimenter gui
 */
public String weightThresholdTipText() {
  final String tip = "Weight threshold for weight pruning (reduce to 90 "
      + "for speeding up learning process).";
  return tip;
}
/**
 * Set weight thresholding.
 *
 * @param threshold the percentage of weight mass used for training
 */
public void setWeightThreshold(int threshold) {
m_WeightThreshold = threshold;
}
/**
 * Returns the degree of weight thresholding.
 *
 * @return the percentage of weight mass used for training
 */
public int getWeightThreshold() {
  return m_WeightThreshold;
}
/**
* Builds the boosted classifier
*/
public void buildClassifier(Instances data) throws Exception {
m_RandomInstance = new Random(m_Seed);
Instances boostData, trainData;
int classIndex = data.classIndex();
if (data.classAttribute().isNumeric()) {
throw new UnsupportedClassTypeException("LogitBoost can't handle a numeric class!");
}
if (m_Classifier == null) {
throw new Exception("A base classifier has not been specified!");
}
if (!(m_Classifier instanceof WeightedInstancesHandler) &&
!m_UseResampling) {
m_UseResampling = true;
}
if (data.checkForStringAttributes()) {
throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
}
if (m_Debug) {
System.err.println("Creating copy of the training data");
}
m_NumClasses = data.numClasses();
m_ClassAttribute = data.classAttribute();
// Create a copy of the data
data = new Instances(data);
data.deleteWithMissingClass();
// Create the base classifiers
if (m_Debug) {
System.err.println("Creating base classifiers");
}
m_Classifiers = new Classifier [m_NumClasses][];
for (int j = 0; j < m_NumClasses; j++) {
m_Classifiers[j] = Classifier.makeCopies(m_Classifier,
getNumIterations());
}
// Do we want to select the appropriate number of iterations
// using cross-validation?
int bestNumIterations = getNumIterations();
if (m_NumFolds > 1) {
if (m_Debug) {
System.err.println("Processing first fold.");
}
// Array for storing the results
double[] results = new double[getNumIterations()];
// Iterate throught the cv-runs
for (int r = 0; r < m_NumRuns; r++) {
// Stratify the data
data.randomize(m_RandomInstance);
data.stratify(m_NumFolds);
// Perform the cross-validation
for (int i = 0; i < m_NumFolds; i++) {
// Get train and test folds
Instances train = data.trainCV(m_NumFolds, i, m_RandomInstance);
Instances test = data.testCV(m_NumFolds, i);
// Make class numeric
Instances trainN = new Instances(train);
trainN.setClassIndex(-1);
trainN.deleteAttributeAt(classIndex);
trainN.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
trainN.setClassIndex(classIndex);
m_NumericClassData = new Instances(trainN, 0);
// Get class values
int numInstances = train.numInstances();
double [][] trainFs = new double [numInstances][m_NumClasses];
double [][] trainYs = new double [numInstances][m_NumClasses];
for (int j = 0; j < m_NumClasses; j++) {
for (int k = 0; k < numInstances; k++) {
trainYs[k][j] = (train.instance(k).classValue() == j) ?
1.0 - m_Offset: 0.0 + (m_Offset / (double)m_NumClasses);
}
}
// Perform iterations
double[][] probs = initialProbs(numInstances);
m_NumGenerated = 0;
double sumOfWeights = train.sumOfWeights();
for (int j = 0; j < getNumIterations(); j++) {
performIteration(trainYs, trainFs, probs, trainN, sumOfWeights);
Evaluation eval = new Evaluation(train);
eval.evaluateModel(this, test);
results[j] += eval.correct();
}
}
}
// Find the number of iterations with the lowest error
double bestResult = -Double.MAX_VALUE;
for (int j = 0; j < getNumIterations(); j++) {
if (results[j] > bestResult) {
bestResult = results[j];
bestNumIterations = j;
}
}
if (m_Debug) {
System.err.println("Best result for " +
bestNumIterations + " iterations: " +
bestResult);
}
}
// Build classifier on all the data
int numInstances = data.numInstances();
double [][] trainFs = new double [numInstances][m_NumClasses];
double [][] trainYs = new double [numInstances][m_NumClasses];
for (int j = 0; j < m_NumClasses; j++) {
for (int i = 0, k = 0; i < numInstances; i++, k++) {
trainYs[i][j] = (data.instance(k).classValue() == j) ?
1.0 - m_Offset: 0.0 + (m_Offset / (double)m_NumClasses);
}
}
// Make class numeric
data.setClassIndex(-1);
data.deleteAttributeAt(classIndex);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -