📄 bvdecompose.java
字号:
* @param newClassifier the Classifier to use.
*/
public void setClassifier(Classifier newClassifier) {
  this.m_Classifier = newClassifier;
}
/**
 * Returns the classifier whose bias and variance are being analysed.
 *
 * @return the classifier being analysed (may be null if none was set)
 */
public Classifier getClassifier() {
  return this.m_Classifier;
}
/**
 * Turns debugging output on or off.
 *
 * @param debug true if progress messages should be printed to standard error
 */
public void setDebug(boolean debug) {
  this.m_Debug = debug;
}
/**
 * Reports whether debugging output is enabled.
 *
 * @return true if debugging output is on
 */
public boolean getDebug() {
  return this.m_Debug;
}
/**
 * Sets the seed for the random number generator used to shuffle the data.
 *
 * @param seed the random number seed
 */
public void setSeed(int seed) {
  this.m_Seed = seed;
}
/**
 * Returns the seed for the random number generator.
 *
 * @return the random number seed
 */
public int getSeed() {
  return this.m_Seed;
}
/**
 * Sets the number of training iterations, i.e. how many times the
 * classifier is built on a random half of the training pool during the
 * decomposition.
 *
 * @param trainIterations the number of training iterations
 */
public void setTrainIterations(int trainIterations) {
  this.m_TrainIterations = trainIterations;
}
/**
 * Returns the number of training iterations used by the decomposition.
 *
 * @return the number of training iterations
 */
public int getTrainIterations() {
  return this.m_TrainIterations;
}
/**
 * Sets the name of the data file used for the decomposition.
 *
 * @param dataFileName the name of the data file to read
 */
public void setDataFileName(String dataFileName) {
  this.m_DataFileName = dataFileName;
}
/**
 * Returns the name of the data file used for the decomposition.
 *
 * @return the name of the data file
 */
public String getDataFileName() {
  return this.m_DataFileName;
}
/**
 * Returns the 1-based index of the attribute used as the class.
 * A value below 1 means no explicit index is set, in which case the
 * decomposition uses the last attribute.
 *
 * @return the 1-based index of the class attribute
 */
public int getClassIndex() {
  // The field is stored 0-based; convert to the 1-based public convention.
  int oneBasedIndex = this.m_ClassIndex + 1;
  return oneBasedIndex;
}
/**
 * Sets the attribute to use as the class.
 *
 * @param classIndex the 1-based index of the class attribute; values below 1
 * make the decomposition use the last attribute
 */
public void setClassIndex(int classIndex) {
  // Stored 0-based internally.
  this.m_ClassIndex = classIndex - 1;
}
/**
 * Returns the bias squared computed by the last decomposition.
 *
 * @return the bias squared
 */
public double getBias() {
  return this.m_Bias;
}
/**
 * Returns the variance computed by the last decomposition.
 *
 * @return the variance
 */
public double getVariance() {
  return this.m_Variance;
}
/**
 * Returns the sigma squared computed by the last decomposition.
 *
 * @return the sigma squared
 */
public double getSigma() {
  return this.m_Sigma;
}
/**
 * Returns the error rate computed by the last decomposition.
 *
 * @return the error rate
 */
public double getError() {
  return this.m_Error;
}
/**
 * Carries out the bias-variance decomposition.
 *
 * The data file is read and the class attribute is set (the last attribute
 * if no class index was given). The shuffled data is then split into a
 * training pool of m_TrainPoolSize instances and a test set holding the
 * rest. On each of m_TrainIterations iterations the classifier is trained
 * on a random half of the pool, and its predictions on every test instance
 * are tallied. Finally, the tallies are turned into the error rate, bias
 * squared, variance and sigma squared, averaged over the test set.
 *
 * @exception Exception if the decomposition couldn't be carried out, e.g.
 * fewer than two training iterations are configured, the class is not
 * nominal, the data has string attributes, or there are too few instances
 */
public void decompose() throws Exception {
  // The error average divides by m_TrainIterations and the bias term divides
  // by (m_TrainIterations - 1); fewer than 2 iterations would silently yield
  // NaN or infinite results.
  if (m_TrainIterations < 2) {
    throw new Exception("At least 2 training iterations are required");
  }

  Instances data;
  Reader dataReader = new BufferedReader(new FileReader(m_DataFileName));
  try {
    data = new Instances(dataReader);
  } finally {
    // Release the file handle even if the data cannot be parsed.
    dataReader.close();
  }

  if (m_ClassIndex < 0) {
    data.setClassIndex(data.numAttributes() - 1);
  } else {
    data.setClassIndex(m_ClassIndex);
  }
  if (data.classAttribute().type() != Attribute.NOMINAL) {
    throw new Exception("Class attribute must be nominal");
  }
  int numClasses = data.numClasses();
  data.deleteWithMissingClass();
  if (data.checkForStringAttributes()) {
    throw new Exception("Can't handle string attributes!");
  }
  if (data.numInstances() < 2 * m_TrainPoolSize) {
    throw new Exception("The dataset must contain at least "
        + (2 * m_TrainPoolSize) + " instances");
  }

  // After shuffling, the first m_TrainPoolSize instances form the training
  // pool and the remainder is the fixed test set.
  Random random = new Random(m_Seed);
  data.randomize(random);
  Instances trainPool = new Instances(data, 0, m_TrainPoolSize);
  Instances test = new Instances(data, m_TrainPoolSize,
      data.numInstances() - m_TrainPoolSize);
  int numTest = test.numInstances();

  // instanceProbs[j][c] counts how often class c was predicted for test
  // instance j across all iterations.
  double [][] instanceProbs = new double [numTest][numClasses];
  double errors = 0;
  for (int i = 0; i < m_TrainIterations; i++) {
    if (m_Debug) {
      System.err.println("Iteration " + (i + 1));
    }
    // Train on a fresh random half of the training pool.
    trainPool.randomize(random);
    Instances train = new Instances(trainPool, 0, m_TrainPoolSize / 2);
    m_Classifier.buildClassifier(train);

    // Evaluate the classifier on the test set, updating the BVD tallies.
    for (int j = 0; j < numTest; j++) {
      Instance current = test.instance(j);
      // NOTE(review): if classifyInstance can return a missing value (NaN),
      // the int cast maps it to class 0 -- confirm against the Classifier API.
      int pred = (int) m_Classifier.classifyInstance(current);
      if (pred != current.classValue()) {
        errors++;
      }
      instanceProbs[j][pred]++;
    }
  }
  // Multiply in double so the denominator cannot overflow int.
  m_Error = errors / ((double) m_TrainIterations * numTest);

  // Average the per-instance bias, variance and sigma terms over the test set.
  double bias = 0;
  double variance = 0;
  double sigma = 0;
  for (int i = 0; i < numTest; i++) {
    Instance current = test.instance(i);
    double [] predProbs = instanceProbs[i];
    double bsum = 0, vsum = 0, ssum = 0;
    for (int j = 0; j < numClasses; j++) {
      // The actual class distribution is a 0/1 indicator, so ssum ends up
      // as 1 and the sigma squared contribution is always 0.
      double pActual = (current.classValue() == j) ? 1 : 0;
      // Fraction of iterations in which class j was predicted.
      double pPred = predProbs[j] / m_TrainIterations;
      // Squared difference, minus a term that presumably corrects for pPred
      // being estimated from only m_TrainIterations samples.
      bsum += (pActual - pPred) * (pActual - pPred)
          - pPred * (1 - pPred) / (m_TrainIterations - 1);
      vsum += pPred * pPred;
      ssum += pActual * pActual;
    }
    bias += bsum;
    variance += (1 - vsum);
    sigma += (1 - ssum);
  }
  m_Bias = bias / (2 * numTest);
  m_Variance = variance / (2 * numTest);
  m_Sigma = sigma / (2 * numTest);
  if (m_Debug) {
    System.err.println("Decomposition finished");
  }
}
/**
 * Returns a description of the decomposition setup and its results.
 *
 * The description lists the classifier (with its options when it is an
 * OptionHandler), the data file, the class index ("last" when none was set),
 * the training pool size, the iteration count, the seed, and the error,
 * sigma squared, bias squared and variance values.
 *
 * @return the bias-variance decomposition results as a string, or
 * "Invalid setup" if no classifier has been set
 */
public String toString() {
  if (getClassifier() == null) {
    return "Invalid setup";
  }
  StringBuffer result = new StringBuffer("\nBias-Variance Decomposition\n");
  result.append("\nClassifier : ").append(getClassifier().getClass().getName());
  if (getClassifier() instanceof OptionHandler) {
    String options =
        Utils.joinOptions(((OptionHandler) getClassifier()).getOptions());
    // Separate the option string from the class name with a space.
    if (options.length() > 0) {
      result.append(' ').append(options);
    }
  }
  result.append("\nData File : ").append(getDataFileName());
  result.append("\nClass Index : ");
  // getClassIndex() is 1-based; 0 means no explicit class index was set.
  if (getClassIndex() == 0) {
    result.append("last");
  } else {
    result.append(getClassIndex());
  }
  result.append("\nTraining Pool: ").append(getTrainPoolSize());
  result.append("\nIterations : ").append(getTrainIterations());
  result.append("\nSeed : ").append(getSeed());
  result.append("\nError : ").append(Utils.doubleToString(getError(), 6, 4));
  result.append("\nSigma^2 : ").append(Utils.doubleToString(getSigma(), 6, 4));
  result.append("\nBias^2 : ").append(Utils.doubleToString(getBias(), 6, 4));
  result.append("\nVariance : ").append(Utils.doubleToString(getVariance(), 6, 4));
  result.append("\n");
  return result.toString();
}
/**
 * Command-line entry point. Parses the options in args, runs the
 * decomposition and prints the results to standard output.
 *
 * If option parsing fails, the error message is followed by a synopsis of
 * all supported options. Any failure is reported on standard error.
 *
 * @param args the command line arguments
 */
public static void main(String [] args) {
  try {
    BVDecompose bvd = new BVDecompose();
    try {
      bvd.setOptions(args);
      Utils.checkForRemainingOptions(args);
    } catch (Exception ex) {
      // Fall back to toString() for exceptions that carry no message.
      String reason = (ex.getMessage() != null) ? ex.getMessage() : ex.toString();
      String result = reason + "\nBVDecompose Options:\n\n";
      Enumeration em = bvd.listOptions();
      while (em.hasMoreElements()) {
        Option option = (Option) em.nextElement();
        result += option.synopsis() + "\n" + option.description() + "\n";
      }
      throw new Exception(result);
    }
    bvd.decompose();
    System.out.println(bvd.toString());
  } catch (Exception ex) {
    // Exceptions such as NullPointerException (e.g. no classifier set) have
    // a null message; print toString() so the user sees more than "null".
    String message = ex.getMessage();
    System.err.println((message != null) ? message : ex.toString());
  }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -