📄 regressionbvdecompose.java
字号:
*/
  public void setClassifier(Classifier newClassifier) {
    this.m_Classifier = newClassifier;
  }

  /**
   * Returns the classifier currently configured for analysis.
   *
   * @return the classifier being analysed
   */
  public Classifier getClassifier() {
    return this.m_Classifier;
  }

  /**
   * Switches debugging output on or off.
   *
   * @param debug true if debug output should be printed
   */
  public void setDebug(boolean debug) {
    this.m_Debug = debug;
  }

  /**
   * Reports whether debugging output is enabled.
   *
   * @return true if debugging output is on
   */
  public boolean getDebug() {
    return this.m_Debug;
  }

  /**
   * Stores the random number seed.
   *
   * @param seed the random number seed
   */
  public void setSeed(int seed) {
    this.m_Seed = seed;
  }

  /**
   * Returns the stored random number seed.
   *
   * @return the random number seed
   */
  public int getSeed() {
    return this.m_Seed;
  }

  /**
   * Stores the number of training iterations.
   *
   * @param trainIterations the number of training iterations
   */
  public void setTrainIterations(int trainIterations) {
    this.m_TrainIterations = trainIterations;
  }

  /**
   * Returns the stored number of training iterations.
   *
   * @return the number of training iterations
   */
  public int getTrainIterations() {
    return this.m_TrainIterations;
  }

  /**
   * Stores the name of the data file used for the decomposition.
   *
   * @param dataFileName the name of the data file
   */
  public void setDataFileName(String dataFileName) {
    this.m_DataFileName = dataFileName;
  }

  /**
   * Returns the name of the data file used for the decomposition.
   *
   * @return the name of the data file
   */
  public String getDataFileName() {
    return this.m_DataFileName;
  }

  /**
   * Get the index (starting from 1) of the attribute used as the class.
* * @return the index of the class attribute */ public int getClassIndex() { return m_ClassIndex + 1; } /** * Sets index of attribute to discretize on * * @param index the index (starting from 1) of the class attribute */ public void setClassIndex(int classIndex) { m_ClassIndex = classIndex - 1; } /** * Get the calculated bias squared * * @return the bias squared */ public double getBias() { return m_Bias; } /** * Get the calculated variance * * @return the variance */ public double getVariance() { return m_Variance; } /** * Get the calculated error rate * * @return the error rate */ public double getError() { return m_Error; } /** * Carry out the bias-variance decomposition * * @exception Exception if the decomposition couldn't be carried out */ public void decompose() throws Exception { Reader dataReader = new BufferedReader(new FileReader(m_DataFileName)); Instances data = new Instances(dataReader); if (m_ClassIndex < 0) { data.setClassIndex(data.numAttributes() - 1); } else { data.setClassIndex(m_ClassIndex); } if (data.classAttribute().type() != Attribute.NUMERIC) { throw new Exception("Class attribute must be numeric"); } int numClasses = data.numClasses(); data.deleteWithMissingClass(); if (data.checkForStringAttributes()) { throw new Exception("Can't handle string attributes!"); } if (data.numInstances() < 2 * m_TrainPoolSize) { throw new Exception("The dataset must contain at least " + (2 * m_TrainPoolSize) + " instances"); } Random random = new Random(m_Seed); data.randomize(random); Instances trainPool = new Instances(data, 0, m_TrainPoolSize); Instances test = new Instances(data, m_TrainPoolSize, data.numInstances() - m_TrainPoolSize); int numTest = test.numInstances(); double [][] pred = new double [numTest][m_TrainIterations]; //collect predictions for (int i = 0; i < m_TrainIterations; i++) { if (m_Debug) { System.err.println("Iteration " + (i + 1)); } trainPool.randomize(random); Instances train = new Instances(trainPool, 0, m_TrainPoolSize / 2); 
m_Classifier.buildClassifier(train); //// Evaluate the classifier on test, updating BVD stats for (int j = 0; j < numTest; j++) { pred[j][i] = m_Classifier.classifyInstance(test.instance(j)); } } //if (pred != test.instance(j).classValue()) { // Average the BV over each instance in test. m_Bias = 0; m_Variance = 0; m_Error = 0; for (int i = 0; i < numTest; i++) { double error=0, bias, variance, sum=0, sumSquared=0; double x, y = test.instance(i).classValue(); //compute stats for single test example for (int j = 0; j < m_TrainIterations; j++) { x = pred[i][j]; error += Math.pow((y - x),2); sum += x; sumSquared += x*x; } error = (error/m_TrainIterations); bias = (y - (sum/m_TrainIterations)); bias = bias * bias; variance = (sumSquared - (sum * sum / m_TrainIterations))/m_TrainIterations; //System.out.println("bias: "+bias+" + var: "+variance+" = "+error+"\tChkSum: "+(error-(bias+variance))); //update test set stats m_Error += error; m_Bias += bias; m_Variance += variance; } m_Error /= numTest; m_Bias /= numTest; m_Variance /= numTest; if (m_Debug) { System.err.println("Decomposition finished"); } } /** * Returns description of the bias-variance decomposition results. 
*
   * @return the bias-variance decomposition results as a string, or
   *         "Invalid setup" when no classifier has been set
   */
  public String toString() {

    Classifier classifier = getClassifier();
    if (classifier == null) {
      return "Invalid setup";
    }

    StringBuffer result = new StringBuffer("\nBias-Variance Decomposition\n");

    result.append("\nClassifier : ").append(classifier.getClass().getName());
    if (classifier instanceof OptionHandler) {
      String options = Utils.joinOptions(((OptionHandler) classifier).getOptions());
      // Separate the class name from its options; previously they ran together.
      if (options.length() > 0) {
        result.append(' ').append(options);
      }
    }
    result.append("\nData File : ").append(getDataFileName());
    result.append("\nClass Index : ");
    // getClassIndex() is 1-based; 0 means no explicit index was set, in which
    // case decompose() uses the last attribute.
    if (getClassIndex() == 0) {
      result.append("last");
    } else {
      result.append(getClassIndex());
    }
    result.append("\nTraining Pool: ").append(getTrainPoolSize());
    result.append("\nIterations : ").append(getTrainIterations());
    result.append("\nSeed : ").append(getSeed());
    result.append("\nError : ").append(Utils.doubleToString(getError(), 6, 4));
    result.append("\nBias^2 : ").append(Utils.doubleToString(getBias(), 6, 4));
    result.append("\nVariance : ").append(Utils.doubleToString(getVariance(), 6, 4));
    // Difference between the error and bias^2 + variance.
    result.append("\nChkSum: ").append(getError() - (getBias() + getVariance()));

    return result.append("\n").toString();
  }

  /**
   * Test method for this class. Parses the options, runs the decomposition
   * and prints the result to standard output. If the options are rejected,
   * the error message followed by the list of available options is printed
   * to standard error instead.
   *
   * @param args the command line arguments
   */
  public static void main(String [] args) {

    try {
      RegressionBVDecompose bvd = new RegressionBVDecompose();

      try {
        bvd.setOptions(args);
        Utils.checkForRemainingOptions(args);
      } catch (Exception ex) {
        StringBuffer usage = new StringBuffer();
        usage.append(ex.getMessage()).append("\nRegressionBVDecompose Options:\n\n");
        // "enum" is a reserved word since Java 5 and cannot name a variable.
        Enumeration options = bvd.listOptions();
        while (options.hasMoreElements()) {
          Option option = (Option) options.nextElement();
          usage.append(option.synopsis()).append("\n")
               .append(option.description()).append("\n");
        }
        throw new Exception(usage.toString(), ex);
      }

      bvd.decompose();
      System.out.println(bvd.toString());
    } catch (Exception ex) {
      System.err.println(ex.getMessage());
    }
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -