📄 bvdecompose.java
字号:
}

  /**
   * Gets the classifier being analysed.
   *
   * @return the classifier being analysed
   */
  public Classifier getClassifier() {

    return this.m_Classifier;
  }

  /**
   * Turns debugging output on or off.
   *
   * @param debug true if debug output should be printed
   */
  public void setDebug(boolean debug) {

    this.m_Debug = debug;
  }

  /**
   * Tells whether debugging output is turned on.
   *
   * @return true if debugging output is on
   */
  public boolean getDebug() {

    return this.m_Debug;
  }

  /**
   * Sets the seed of the random number generator created by decompose().
   *
   * @param seed the random number seed
   */
  public void setSeed(int seed) {

    this.m_Seed = seed;
  }

  /**
   * Gets the random number seed.
   *
   * @return the random number seed
   */
  public int getSeed() {

    return this.m_Seed;
  }

  /**
   * Sets the number of training iterations. decompose() builds and
   * evaluates the classifier once per iteration.
   *
   * @param trainIterations the number of training iterations
   */
  public void setTrainIterations(int trainIterations) {

    this.m_TrainIterations = trainIterations;
  }

  /**
   * Gets the number of training iterations.
   *
   * @return the number of training iterations
   */
  public int getTrainIterations() {

    return this.m_TrainIterations;
  }

  /**
   * Sets the name of the data file used for the decomposition.
   *
   * @param dataFileName the name of the data file
   */
  public void setDataFileName(String dataFileName) {

    this.m_DataFileName = dataFileName;
  }

  /**
   * Gets the name of the data file used for the decomposition.
   *
   * @return the name of the data file
   */
  public String getDataFileName() {

    return this.m_DataFileName;
  }

  /**
   * Get the index (starting from 1) of the attribute used as the class.
* * @return the index of the class attribute */ public int getClassIndex() { return m_ClassIndex + 1; } /** * Sets index of attribute to discretize on * * @param index the index (starting from 1) of the class attribute */ public void setClassIndex(int classIndex) { m_ClassIndex = classIndex - 1; } /** * Get the calculated bias squared * * @return the bias squared */ public double getBias() { return m_Bias; } /** * Get the calculated variance * * @return the variance */ public double getVariance() { return m_Variance; } /** * Get the calculated sigma squared * * @return the sigma squared */ public double getSigma() { return m_Sigma; } /** * Get the calculated error rate * * @return the error rate */ public double getError() { return m_Error; } /** * Carry out the bias-variance decomposition * * @exception Exception if the decomposition couldn't be carried out */ public void decompose() throws Exception { Reader dataReader = new BufferedReader(new FileReader(m_DataFileName)); Instances data = new Instances(dataReader); if (m_ClassIndex < 0) { data.setClassIndex(data.numAttributes() - 1); } else { data.setClassIndex(m_ClassIndex); } if (data.classAttribute().type() != Attribute.NOMINAL) { throw new Exception("Class attribute must be nominal"); } int numClasses = data.numClasses(); data.deleteWithMissingClass(); if (data.checkForStringAttributes()) { throw new Exception("Can't handle string attributes!"); } if (data.numInstances() < 2 * m_TrainPoolSize) { throw new Exception("The dataset must contain at least " + (2 * m_TrainPoolSize) + " instances"); } Random random = new Random(m_Seed); data.randomize(random); Instances trainPool = new Instances(data, 0, m_TrainPoolSize); Instances test = new Instances(data, m_TrainPoolSize, data.numInstances() - m_TrainPoolSize); int numTest = test.numInstances(); double [][] instanceProbs = new double [numTest][numClasses]; m_Error = 0; for (int i = 0; i < m_TrainIterations; i++) { if (m_Debug) { System.err.println("Iteration " + 
(i + 1)); } trainPool.randomize(random); Instances train = new Instances(trainPool, 0, m_TrainPoolSize / 2); m_Classifier.buildClassifier(train); //// Evaluate the classifier on test, updating BVD stats for (int j = 0; j < numTest; j++) { int pred = (int)m_Classifier.classifyInstance(test.instance(j)); if (pred != test.instance(j).classValue()) { m_Error++; } instanceProbs[j][pred]++; } } m_Error /= (m_TrainIterations * numTest); // Average the BV over each instance in test. m_Bias = 0; m_Variance = 0; m_Sigma = 0; for (int i = 0; i < numTest; i++) { Instance current = test.instance(i); double [] predProbs = instanceProbs[i]; double pActual, pPred; double bsum = 0, vsum = 0, ssum = 0; for (int j = 0; j < numClasses; j++) { pActual = (current.classValue() == j) ? 1 : 0; // Or via 1NN from test data? pPred = predProbs[j] / m_TrainIterations; bsum += (pActual - pPred) * (pActual - pPred) - pPred * (1 - pPred) / (m_TrainIterations - 1); vsum += pPred * pPred; ssum += pActual * pActual; } m_Bias += bsum; m_Variance += (1 - vsum); m_Sigma += (1 - ssum); } m_Bias /= (2 * numTest); m_Variance /= (2 * numTest); m_Sigma /= (2 * numTest); if (m_Debug) { System.err.println("Decomposition finished"); } } /** * Returns description of the bias-variance decomposition results. 
*
   * @return the bias-variance decomposition results as a string, or
   * "Invalid setup" if no classifier has been set
   */
  public String toString() {

    // Nothing meaningful to report without a classifier.
    if (getClassifier() == null) {
      return "Invalid setup";
    }

    StringBuffer result = new StringBuffer("\nBias-Variance Decomposition\n");

    result.append("\nClassifier : ");
    result.append(getClassifier().getClass().getName());
    if (getClassifier() instanceof OptionHandler) {
      String options =
        Utils.joinOptions(((OptionHandler)m_Classifier).getOptions());
      // Separate the options from the class name so the first option does
      // not run straight into it.
      if (options != null && options.length() > 0) {
        result.append(' ').append(options);
      }
    }
    result.append("\nData File : ").append(getDataFileName());
    result.append("\nClass Index : ");
    // getClassIndex() is one-based; 0 stands for "use the last attribute".
    if (getClassIndex() == 0) {
      result.append("last");
    } else {
      result.append(getClassIndex());
    }
    result.append("\nTraining Pool: ").append(getTrainPoolSize());
    result.append("\nIterations : ").append(getTrainIterations());
    result.append("\nSeed : ").append(getSeed());
    result.append("\nError : ")
      .append(Utils.doubleToString(getError(), 6, 4));
    result.append("\nSigma^2 : ")
      .append(Utils.doubleToString(getSigma(), 6, 4));
    result.append("\nBias^2 : ")
      .append(Utils.doubleToString(getBias(), 6, 4));
    result.append("\nVariance : ")
      .append(Utils.doubleToString(getVariance(), 6, 4));

    return result.append("\n").toString();
  }

  /**
   * Test method for this class. Parses the options, runs the decomposition
   * and prints the result to standard output. If the options cannot be
   * parsed, the error message followed by the list of available options is
   * printed to standard error instead.
   *
   * @param args the command line arguments
   */
  public static void main(String [] args) {

    try {
      BVDecompose bvd = new BVDecompose();

      try {
        bvd.setOptions(args);
        Utils.checkForRemainingOptions(args);
      } catch (Exception ex) {
        // Re-throw with the usage text appended so the outer handler
        // prints both the problem and the available options.
        StringBuffer usage = new StringBuffer();
        usage.append(ex.getMessage()).append("\nBVDecompose Options:\n\n");
        // "enum" is a reserved word since Java 5, hence the neutral name.
        Enumeration optionList = bvd.listOptions();
        while (optionList.hasMoreElements()) {
          Option option = (Option) optionList.nextElement();
          usage.append(option.synopsis()).append("\n");
          usage.append(option.description()).append("\n");
        }
        throw new Exception(usage.toString());
      }

      bvd.decompose();
      System.out.println(bvd.toString());
    } catch (Exception ex) {
      System.err.println(ex.getMessage());
    }
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -