bvdecomposesegcvsub.java

来自「Java 编写的多种数据挖掘算法 包括聚类、分类、预处理等」· Java 代码 · 共 1,109 行 · 第 1/3 页

JAVA
1,109
字号
        }
        
        // Round tps (the per-segment training-pool size) up from double to integer.
        tps = (int) Math.ceil( ((double)m_TrainSize / (double)m_P) + 1 );
        // Number of cross-validation folds needed per segment.
        k = (int) Math.ceil( tps / (tps - (double) m_TrainSize));
        
        // number of folds cannot be more than the number of instances in the training pool
        if ( k > tps ) {
            throw new Exception("The required number of folds is too many."
            + "Change p or the size of the training set.");
        }
        
        // calculate the number of segments, round down.
        q = (int) Math.floor( (double) data.numInstances() / (double)tps );
        
        // Confusion-count matrix: one row per instance in the data set (all will be
        // used), one column per class; tallies how often each instance was
        // predicted as each class across all iterations.
        double [][] instanceProbs = new double [data.numInstances()][numClasses];
        // Per-fold bookkeeping: [fold][0] = start offset within a segment, [fold][1] = fold size.
        int [][] foldIndex = new int [ k ][ 2 ];
        // One int[] of data indices per segment, plus one for the leftover segment.
        Vector segmentList = new Vector(q + 1);
        
        // Seeded RNG so runs are reproducible.
        Random random = new Random(m_Seed);
        
        data.randomize(random);
        
        // Create index arrays for the different segments: q full segments of tps
        // instances each, then a final segment holding whatever remains.
        int currentDataIndex = 0;
        for( int count = 1; count <= (q + 1); count++ ){
            if( count > q){
                // Last (partial) segment: the instances left over after q full segments.
                int [] segmentIndex = new int [ (data.numInstances() - (q * tps)) ];
                for(int index = 0; index < segmentIndex.length; index++, currentDataIndex++){
                    segmentIndex[index] = currentDataIndex;
                }
                segmentList.add(segmentIndex);
            } else {
                int [] segmentIndex = new int [ tps ];
                for(int index = 0; index < segmentIndex.length; index++, currentDataIndex++){
                    segmentIndex[index] = currentDataIndex;
                }
                segmentList.add(segmentIndex);
            }
        }
        
        int remainder = tps % k; // remainder is used to determine when to shrink the fold size by 1.
        
        // foldSize = ROUNDUP( tps / k ) (round up, eg 3 -> 3,  3.3->4)
        int foldSize = (int) Math.ceil( (double)tps /(double) k); //roundup fold size double to integer
        int index = 0;
        int currentIndex;
        
        // Record start offset and size for each fold within a segment; folds from
        // position 'remainder' onwards are one element smaller.
        for( int count = 0; count < k; count ++){
            if( remainder != 0 && count == remainder ){
                foldSize -= 1;
            }
            foldIndex[count][0] = index;
            foldIndex[count][1] = foldSize;
            index += foldSize;
        }
        
        for( int l = 0; l < m_ClassifyIterations; l++) {
            
            for(int i = 1; i <= q; i++){
                
                int [] currentSegment = (int[]) segmentList.get(i - 1);
                
                randomize(currentSegment, random);
                
                //CROSS FOLD VALIDATION for current Segment
                for( int j = 1; j <= k; j++){
                    
                    // Build the training pool TP from every fold except held-out fold j.
                    Instances TP = null;
                    for(int foldNum = 1; foldNum <= k; foldNum++){
                        if( foldNum != j){
                            
                            int startFoldIndex = foldIndex[ foldNum - 1 ][ 0 ]; //start index
                            foldSize = foldIndex[ foldNum - 1 ][ 1 ];
                            int endFoldIndex = startFoldIndex + foldSize - 1;
                            
                            for(int currentFoldIndex = startFoldIndex; currentFoldIndex <= endFoldIndex; currentFoldIndex++){
                                
                                if( TP == null ){
                                    TP = new Instances(data, currentSegment[ currentFoldIndex ], 1);
                                }else{
                                    TP.add( data.instance( currentSegment[ currentFoldIndex ] ) );
                                }
                            }
                        }
                    }
                    
                    TP.randomize(random);
                    
                    if( getTrainSize() > TP.numInstances() ){
                        throw new Exception("The training set size of " + getTrainSize() + ", is greater than the training pool "
                        + TP.numInstances() );
                    }
                    
                    // Subsample m_TrainSize instances from the (shuffled) training pool.
                    Instances train = new Instances(TP, 0, m_TrainSize);
                    
                    Classifier current = Classifier.makeCopy(m_Classifier);
                    current.buildClassifier(train); // create a classifier using the instances in train.
                    
                    int currentTestIndex = foldIndex[ j - 1 ][ 0 ]; //start index
                    int testFoldSize = foldIndex[ j - 1 ][ 1 ]; //size
                    int endTestIndex = currentTestIndex + testFoldSize - 1;
                    
                    // Classify every instance of held-out fold j, tallying errors and
                    // the per-class prediction counts.
                    while( currentTestIndex <= endTestIndex ){
                        
                        Instance testInst = data.instance( currentSegment[currentTestIndex] );
                        int pred = (int)current.classifyInstance( testInst );
                        
                        if(pred != testInst.classValue()) {
                            m_Error++; // add 1 to mis-classifications.
                        }
                        instanceProbs[ currentSegment[ currentTestIndex ] ][ pred ]++;
                        currentTestIndex++;
                    }
                    
                    // The first classifier of each pass also classifies the leftover
                    // (remainder) segment so that every instance receives predictions.
                    if( i == 1 && j == 1){
                        int[] segmentElast = (int[])segmentList.lastElement();
                        for( currentIndex = 0; currentIndex < segmentElast.length; currentIndex++){
                            Instance testInst = data.instance( segmentElast[currentIndex] );
                            int pred = (int)current.classifyInstance( testInst );
                            if(pred != testInst.classValue()) {
                                m_Error++; // add 1 to mis-classifications.
                            }
                            
                            instanceProbs[ segmentElast[ currentIndex ] ][ pred ]++;
                        }
                    }
                }
            }
        }
        
        // Normalise the raw misclassification count into an error rate.
        m_Error /= (double)( m_ClassifyIterations * data.numInstances() );
        
        m_KWBias = 0.0;
        m_KWVariance = 0.0;
        m_KWSigma = 0.0;
        
        m_WBias = 0.0;
        m_WVariance = 0.0;
        
        for (int i = 0; i < data.numInstances(); i++) {
            
            Instance current = data.instance( i );
            
            double [] predProbs = instanceProbs[ i ];
            double pActual, pPred;
            double bsum = 0, vsum = 0, ssum = 0;
            double wBSum = 0, wVSum = 0;
            
            Vector centralTendencies = findCentralTendencies( predProbs );
            
            if( centralTendencies == null ){
                throw new Exception("Central tendency was null.");
            }
            
            // Kohavi & Wolpert bias/variance/sigma terms for this instance.
            for (int j = 0; j < numClasses; j++) {
                pActual = (current.classValue() == j) ? 1 : 0;
                pPred = predProbs[j] / m_ClassifyIterations;
                bsum += (pActual - pPred) * (pActual - pPred) - pPred * (1 - pPred) / (m_ClassifyIterations - 1);
                vsum += pPred * pPred;
                ssum += pActual * pActual;
            }
            
            m_KWBias += bsum;
            m_KWVariance += (1 - vsum);
            m_KWSigma += (1 - ssum);
            
            for( int count = 0; count < centralTendencies.size(); count++ ) {
                
                int wB = 0, wV = 0;
                int centralTendency = ((Integer)centralTendencies.get(count)).intValue();
                
                // For a single instance xi, find the bias and variance.
                for (int j = 0; j < numClasses; j++) {
                    
                    //Webb definition
                    if( j != (int)current.classValue() && j == centralTendency ) {
                        wB += predProbs[j];
                    }
                    if( j != (int)current.classValue() && j != centralTendency ) {
                        wV += predProbs[j];
                    }
                    
                }
                wBSum += (double) wB;
                wVSum += (double) wV;
            }
            
            // calculate bias by dividing bSum by the number of central tendencies and
            // total number of instances. (effectively finding the average and dividing
            // by the number of instances to get the normalised probability).
            m_WBias += ( wBSum / ((double) ( centralTendencies.size() * m_ClassifyIterations )));
            // calculate variance by dividing vSum by the total number of iterations
            m_WVariance += ( wVSum / ((double) ( centralTendencies.size() * m_ClassifyIterations )));
            
        }
        
        // Kohavi & Wolpert terms carry a factor of 1/2 in their definition.
        m_KWBias /= (2.0 * (double) data.numInstances());
        m_KWVariance /= (2.0 * (double) data.numInstances());
        m_KWSigma /= (2.0 * (double) data.numInstances());
        
        // bias = bias / number of data instances
        m_WBias /= (double) data.numInstances();
        // variance = variance / number of data instances.
        m_WVariance /= (double) data.numInstances();
        
        if (m_Debug) {
            System.err.println("Decomposition finished");
        }
        
    }
    
    /** Finds the central tendency, given the classifications for an instance.
*     * Where the central tendency is defined as the class that was most commonly     * selected for a given instance.<p>     *     * For example, instance 'x' may be classified out of 3 classes y = {1, 2, 3},     * so if x is classified 10 times, and is classified as follows, '1' = 2 times, '2' = 5 times     * and '3' = 3 times. Then the central tendency is '2'. <p>     *     * However, it is important to note that this method returns a list of all classes     * that have the highest number of classifications.     *     * In cases where there are several classes with the largest number of classifications, then     * all of these classes are returned. For example if 'x' is classified '1' = 4 times,     * '2' = 4 times and '3' = 2 times. Then '1' and '2' are returned.<p>     *     * @param predProbs the array of classifications for a single instance.     *     * @return a Vector containing Integer objects which store the class(s) which     * are the central tendency.     */    public Vector findCentralTendencies(double[] predProbs) {                int centralTValue = 0;        int currentValue = 0;        //array to store the list of classes the have the greatest number of classifictions.        Vector centralTClasses;                centralTClasses = new Vector(); //create an array with size of the number of classes.                // Go through array, finding the central tendency.        for( int i = 0; i < predProbs.length; i++) {            currentValue = (int) predProbs[i];            // if current value is greater than the central tendency value then            // clear vector and add new class to vector array.            
if( currentValue > centralTValue) {                centralTClasses.clear();                centralTClasses.addElement( new Integer(i) );                centralTValue = currentValue;            } else if( currentValue != 0 && currentValue == centralTValue) {                centralTClasses.addElement( new Integer(i) );            }        }        //return all classes that have the greatest number of classifications.        if( centralTValue != 0){            return centralTClasses;        } else {            return null;        }            }        /**     * Returns description of the bias-variance decomposition results.     *     * @return the bias-variance decomposition results as a string     */    public String toString() {                String result = "\nBias-Variance Decomposition Segmentation, Cross Validation\n" +        "with subsampling.\n";                if (getClassifier() == null) {            return "Invalid setup";        }                result += "\nClassifier    : " + getClassifier().getClass().getName();        if (getClassifier() instanceof OptionHandler) {            result += Utils.joinOptions(((OptionHandler)m_Classifier).getOptions());        }        result += "\nData File     : " + getDataFileName();        result += "\nClass Index   : ";        if (getClassIndex() == 0) {            result += "last";        } else {            result += getClassIndex();        }        result += "\nIterations    : " + getClassifyIterations();        result += "\np             : " + getP();        result += "\nTraining Size : " + getTrainSize();        result += "\nSeed          : " + getSeed();                result += "\n\nDefinition   : " +"Kohavi and Wolpert";        result += "\nError         :" + Utils.doubleToString(getError(), 4);        result += "\nBias^2        :" + Utils.doubleToString(getKWBias(), 4);        result += "\nVariance      :" + Utils.doubleToString(getKWVariance(), 4);        result += "\nSigma^2       :" + 
Utils.doubleToString(getKWSigma(), 4);                result += "\n\nDefinition   : " +"Webb";        result += "\nError         :" + Utils.doubleToString(getError(), 4);        result += "\nBias          :" + Utils.doubleToString(getWBias(), 4);        result += "\nVariance      :" + Utils.doubleToString(getWVariance(), 4);                return result;    }                /**     * Test method for this class     *     * @param args the command line arguments     */    public static void main(String [] args) {                try {            BVDecomposeSegCVSub bvd = new BVDecomposeSegCVSub();                        try {                bvd.setOptions(args);                Utils.checkForRemainingOptions(args);            } catch (Exception ex) {                String result = ex.getMessage() + "\nBVDecompose Options:\n\n";                Enumeration enu = bvd.listOptions();                while (enu.hasMoreElements()) {                    Option option = (Option) enu.nextElement();                    result += option.synopsis() + "\n" + option.description() + "\n";                }                throw new Exception(result);            }                        bvd.decompose();                        System.out.println(bvd.toString());                    } catch (Exception ex) {            System.err.println(ex.getMessage());        }            }        /**     * Accepts an array of ints and randomises the values in the array, using the     * random seed.     *     *@param index is the array of integers     *@param random is the Random seed.     */    public final void randomize(int[] index, Random random) {        for( int j = index.length - 1; j > 0; j-- ){            int k = random.nextInt( j + 1 );            int temp = index[j];            index[j] = index[k];            index[k] = temp;        }    }}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?