bvdecomposesegcvsub.java
来自「Java 编写的多种数据挖掘算法 包括聚类、分类、预处理等」· Java 代码 · 共 1,109 行 · 第 1/3 页
JAVA
1,109 行
} //roundup tps from double to integer tps = (int) Math.ceil( ((double)m_TrainSize / (double)m_P) + 1 ); k = (int) Math.ceil( tps / (tps - (double) m_TrainSize)); // number of folds cannot be more than the number of instances in the training pool if ( k > tps ) { throw new Exception("The required number of folds is too many." + "Change p or the size of the training set."); } // calculate the number of segments, round down. q = (int) Math.floor( (double) data.numInstances() / (double)tps ); //create confusion matrix, columns = number of instances in data set, as all will be used, by rows = number of classes. double [][] instanceProbs = new double [data.numInstances()][numClasses]; int [][] foldIndex = new int [ k ][ 2 ]; Vector segmentList = new Vector(q + 1); //Set random seed Random random = new Random(m_Seed); data.randomize(random); //create index arrays for different segments int currentDataIndex = 0; for( int count = 1; count <= (q + 1); count++ ){ if( count > q){ int [] segmentIndex = new int [ (data.numInstances() - (q * tps)) ]; for(int index = 0; index < segmentIndex.length; index++, currentDataIndex++){ segmentIndex[index] = currentDataIndex; } segmentList.add(segmentIndex); } else { int [] segmentIndex = new int [ tps ]; for(int index = 0; index < segmentIndex.length; index++, currentDataIndex++){ segmentIndex[index] = currentDataIndex; } segmentList.add(segmentIndex); } } int remainder = tps % k; // remainder is used to determine when to shrink the fold size by 1. //foldSize = ROUNDUP( tps / k ) (round up, eg 3 -> 3, 3.3->4) int foldSize = (int) Math.ceil( (double)tps /(double) k); //roundup fold size double to integer int index = 0; int currentIndex; for( int count = 0; count < k; count ++){ if( remainder != 0 && count == remainder ){ foldSize -= 1; } foldIndex[count][0] = index; foldIndex[count][1] = foldSize; index += foldSize; } for( int l = 0; l < m_ClassifyIterations; l++) { for(int i = 1; i <= q; i++){ int [] currentSegment = (int[]) segmentList.get(i - 1); randomize(currentSegment, random); //CROSS FOLD VALIDATION for current Segment for( int j = 1; j <= k; j++){ Instances TP = null; for(int foldNum = 1; foldNum <= k; foldNum++){ if( foldNum != j){ int startFoldIndex = foldIndex[ foldNum - 1 ][ 0 ]; //start index foldSize = foldIndex[ foldNum - 1 ][ 1 ]; int endFoldIndex = startFoldIndex + foldSize - 1; for(int currentFoldIndex = startFoldIndex; currentFoldIndex <= endFoldIndex; currentFoldIndex++){ if( TP == null ){ TP = new Instances(data, currentSegment[ currentFoldIndex ], 1); }else{ TP.add( data.instance( currentSegment[ currentFoldIndex ] ) ); } } } } TP.randomize(random); if( getTrainSize() > TP.numInstances() ){ throw new Exception("The training set size of " + getTrainSize() + ", is greater than the training pool " + TP.numInstances() ); } Instances train = new Instances(TP, 0, m_TrainSize); Classifier current = Classifier.makeCopy(m_Classifier); current.buildClassifier(train); // create a clssifier using the instances in train. int currentTestIndex = foldIndex[ j - 1 ][ 0 ]; //start index int testFoldSize = foldIndex[ j - 1 ][ 1 ]; //size int endTestIndex = currentTestIndex + testFoldSize - 1; while( currentTestIndex <= endTestIndex ){ Instance testInst = data.instance( currentSegment[currentTestIndex] ); int pred = (int)current.classifyInstance( testInst ); if(pred != testInst.classValue()) { m_Error++; // add 1 to mis-classifications. } instanceProbs[ currentSegment[ currentTestIndex ] ][ pred ]++; currentTestIndex++; } if( i == 1 && j == 1){ int[] segmentElast = (int[])segmentList.lastElement(); for( currentIndex = 0; currentIndex < segmentElast.length; currentIndex++){ Instance testInst = data.instance( segmentElast[currentIndex] ); int pred = (int)current.classifyInstance( testInst ); if(pred != testInst.classValue()) { m_Error++; // add 1 to mis-classifications. } instanceProbs[ segmentElast[ currentIndex ] ][ pred ]++; } } } } } m_Error /= (double)( m_ClassifyIterations * data.numInstances() ); m_KWBias = 0.0; m_KWVariance = 0.0; m_KWSigma = 0.0; m_WBias = 0.0; m_WVariance = 0.0; for (int i = 0; i < data.numInstances(); i++) { Instance current = data.instance( i ); double [] predProbs = instanceProbs[ i ]; double pActual, pPred; double bsum = 0, vsum = 0, ssum = 0; double wBSum = 0, wVSum = 0; Vector centralTendencies = findCentralTendencies( predProbs ); if( centralTendencies == null ){ throw new Exception("Central tendency was null."); } for (int j = 0; j < numClasses; j++) { pActual = (current.classValue() == j) ? 1 : 0; pPred = predProbs[j] / m_ClassifyIterations; bsum += (pActual - pPred) * (pActual - pPred) - pPred * (1 - pPred) / (m_ClassifyIterations - 1); vsum += pPred * pPred; ssum += pActual * pActual; } m_KWBias += bsum; m_KWVariance += (1 - vsum); m_KWSigma += (1 - ssum); for( int count = 0; count < centralTendencies.size(); count++ ) { int wB = 0, wV = 0; int centralTendency = ((Integer)centralTendencies.get(count)).intValue(); // For a single instance xi, find the bias and variance. for (int j = 0; j < numClasses; j++) { //Webb definition if( j != (int)current.classValue() && j == centralTendency ) { wB += predProbs[j]; } if( j != (int)current.classValue() && j != centralTendency ) { wV += predProbs[j]; } } wBSum += (double) wB; wVSum += (double) wV; } // calculate bais by dividing bSum by the number of central tendencies and // total number of instances. (effectively finding the average and dividing // by the number of instances to get the nominalised probability). m_WBias += ( wBSum / ((double) ( centralTendencies.size() * m_ClassifyIterations ))); // calculate variance by dividing vSum by the total number of interations m_WVariance += ( wVSum / ((double) ( centralTendencies.size() * m_ClassifyIterations ))); } m_KWBias /= (2.0 * (double) data.numInstances()); m_KWVariance /= (2.0 * (double) data.numInstances()); m_KWSigma /= (2.0 * (double) data.numInstances()); // bias = bias / number of data instances m_WBias /= (double) data.numInstances(); // variance = variance / number of data instances. m_WVariance /= (double) data.numInstances(); if (m_Debug) { System.err.println("Decomposition finished"); } } /** Finds the central tendency, given the classifications for an instance. * * Where the central tendency is defined as the class that was most commonly * selected for a given instance.<p> * * For example, instance 'x' may be classified out of 3 classes y = {1, 2, 3}, * so if x is classified 10 times, and is classified as follows, '1' = 2 times, '2' = 5 times * and '3' = 3 times. Then the central tendency is '2'. <p> * * However, it is important to note that this method returns a list of all classes * that have the highest number of classifications. * * In cases where there are several classes with the largest number of classifications, then * all of these classes are returned. For example if 'x' is classified '1' = 4 times, * '2' = 4 times and '3' = 2 times. Then '1' and '2' are returned.<p> * * @param predProbs the array of classifications for a single instance. * * @return a Vector containing Integer objects which store the class(s) which * are the central tendency. */ public Vector findCentralTendencies(double[] predProbs) { int centralTValue = 0; int currentValue = 0; //array to store the list of classes the have the greatest number of classifictions. Vector centralTClasses; centralTClasses = new Vector(); //create an array with size of the number of classes. // Go through array, finding the central tendency. for( int i = 0; i < predProbs.length; i++) { currentValue = (int) predProbs[i]; // if current value is greater than the central tendency value then // clear vector and add new class to vector array. if( currentValue > centralTValue) { centralTClasses.clear(); centralTClasses.addElement( new Integer(i) ); centralTValue = currentValue; } else if( currentValue != 0 && currentValue == centralTValue) { centralTClasses.addElement( new Integer(i) ); } } //return all classes that have the greatest number of classifications. if( centralTValue != 0){ return centralTClasses; } else { return null; } } /** * Returns description of the bias-variance decomposition results. * * @return the bias-variance decomposition results as a string */ public String toString() { String result = "\nBias-Variance Decomposition Segmentation, Cross Validation\n" + "with subsampling.\n"; if (getClassifier() == null) { return "Invalid setup"; } result += "\nClassifier : " + getClassifier().getClass().getName(); if (getClassifier() instanceof OptionHandler) { result += Utils.joinOptions(((OptionHandler)m_Classifier).getOptions()); } result += "\nData File : " + getDataFileName(); result += "\nClass Index : "; if (getClassIndex() == 0) { result += "last"; } else { result += getClassIndex(); } result += "\nIterations : " + getClassifyIterations(); result += "\np : " + getP(); result += "\nTraining Size : " + getTrainSize(); result += "\nSeed : " + getSeed(); result += "\n\nDefinition : " +"Kohavi and Wolpert"; result += "\nError :" + Utils.doubleToString(getError(), 4); result += "\nBias^2 :" + Utils.doubleToString(getKWBias(), 4); result += "\nVariance :" + Utils.doubleToString(getKWVariance(), 4); result += "\nSigma^2 :" + Utils.doubleToString(getKWSigma(), 4); result += "\n\nDefinition : " +"Webb"; result += "\nError :" + Utils.doubleToString(getError(), 4); result += "\nBias :" + Utils.doubleToString(getWBias(), 4); result += "\nVariance :" + Utils.doubleToString(getWVariance(), 4); return result; } /** * Test method for this class * * @param args the command line arguments */ public static void main(String [] args) { try { BVDecomposeSegCVSub bvd = new BVDecomposeSegCVSub(); try { bvd.setOptions(args); Utils.checkForRemainingOptions(args); } catch (Exception ex) { String result = ex.getMessage() + "\nBVDecompose Options:\n\n"; Enumeration enu = bvd.listOptions(); while (enu.hasMoreElements()) { Option option = (Option) enu.nextElement(); result += option.synopsis() + "\n" + option.description() + "\n"; } throw new Exception(result); } bvd.decompose(); System.out.println(bvd.toString()); } catch (Exception ex) { System.err.println(ex.getMessage()); } } /** * Accepts an array of ints and randomises the values in the array, using the * random seed. * *@param index is the array of integers *@param random is the Random seed. */ public final void randomize(int[] index, Random random) { for( int j = index.length - 1; j > 0; j-- ){ int k = random.nextInt( j + 1 ); int temp = index[j]; index[j] = index[k]; index[k] = temp; } }}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?