⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 bvdecomposesegcvsub.java

📁 MacroWeka扩展了著名数据挖掘工具weka
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
                int [] segmentIndex = new int [ tps ];
                
                for(int index = 0; index < segmentIndex.length; index++, currentDataIndex++){
                    segmentIndex[index] = currentDataIndex;
                }
                segmentList.add(segmentIndex);
            }
        }
        
        int remainder = tps % k; // remainder is used to determine when to shrink the fold size by 1.
        
        //foldSize = ROUNDUP( tps / k ) (round up, eg 3 -> 3,  3.3->4)
        int foldSize = (int) Math.ceil( (double)tps /(double) k); //roundup fold size double to integer
        int index = 0;
        int currentIndex, endIndex;
        
        for( int count = 0; count < k; count ++){
            if( remainder != 0 && count == remainder ){
                foldSize -= 1;
            }
            foldIndex[count][0] = index;
            foldIndex[count][1] = foldSize;
            index += foldSize;
        }
        
        for( int l = 0; l < m_ClassifyIterations; l++) {
            
            for(int i = 1; i <= q; i++){
                
                int [] currentSegment = (int[]) segmentList.get(i - 1);
                
                randomize(currentSegment, random);
                
                //CROSS FOLD VALIDATION for current Segment
                for( int j = 1; j <= k; j++){
                    
                    Instances TP = null;
                    for(int foldNum = 1; foldNum <= k; foldNum++){
                        if( foldNum != j){
                            
                            int startFoldIndex = foldIndex[ foldNum - 1 ][ 0 ]; //start index
                            foldSize = foldIndex[ foldNum - 1 ][ 1 ];
                            int endFoldIndex = startFoldIndex + foldSize - 1;
                            
                            for(int currentFoldIndex = startFoldIndex; currentFoldIndex <= endFoldIndex; currentFoldIndex++){
                                
                                if( TP == null ){
                                    TP = new Instances(data, currentSegment[ currentFoldIndex ], 1);
                                }else{
                                    TP.add( data.instance( currentSegment[ currentFoldIndex ] ) );
                                }
                            }
                        }
                    }
                    
                    TP.randomize(random);
                    
                    if( getTrainSize() > TP.numInstances() ){
                        throw new Exception("The training set size of " + getTrainSize() + ", is greater than the training pool "
                        + TP.numInstances() );
                    }
                    
                    Instances train = new Instances(TP, 0, m_TrainSize);
                    
		    Classifier current = Classifier.makeCopy(m_Classifier);
                    current.buildClassifier(train); // create a clssifier using the instances in train.
                    
                    int currentTestIndex = foldIndex[ j - 1 ][ 0 ]; //start index
                    int testFoldSize = foldIndex[ j - 1 ][ 1 ]; //size
                    int endTestIndex = currentTestIndex + testFoldSize - 1;
                    
                    while( currentTestIndex <= endTestIndex ){
                        
                        Instance testInst = data.instance( currentSegment[currentTestIndex] );
                        int pred = (int)current.classifyInstance( testInst );
                        
                        
                        if(pred != testInst.classValue()) {
                            m_Error++; // add 1 to mis-classifications.
                        }
                        instanceProbs[ currentSegment[ currentTestIndex ] ][ pred ]++;
                        currentTestIndex++;
                    }
                    
                    if( i == 1 && j == 1){
                        int[] segmentElast = (int[])segmentList.lastElement();
                        for( currentIndex = 0; currentIndex < segmentElast.length; currentIndex++){
                            Instance testInst = data.instance( segmentElast[currentIndex] );
                            int pred = (int)current.classifyInstance( testInst );
                            if(pred != testInst.classValue()) {
                                m_Error++; // add 1 to mis-classifications.
                            }
                            
                            instanceProbs[ segmentElast[ currentIndex ] ][ pred ]++;
                        }
                    }
                }
            }
        }
        
        m_Error /= (double)( m_ClassifyIterations * data.numInstances() );
        
        m_KWBias = 0.0;
        m_KWVariance = 0.0;
        m_KWSigma = 0.0;
        
        m_WBias = 0.0;
        m_WVariance = 0.0;
        
        for (int i = 0; i < data.numInstances(); i++) {
            
            Instance current = data.instance( i );
            
            double [] predProbs = instanceProbs[ i ];
            double pActual, pPred;
            double bsum = 0, vsum = 0, ssum = 0;
            double wBSum = 0, wVSum = 0;
            
            Vector centralTendencies = findCentralTendencies( predProbs );
            
            if( centralTendencies == null ){
                throw new Exception("Central tendency was null.");
            }
            
            for (int j = 0; j < numClasses; j++) {
                pActual = (current.classValue() == j) ? 1 : 0;
                pPred = predProbs[j] / m_ClassifyIterations;
                bsum += (pActual - pPred) * (pActual - pPred) - pPred * (1 - pPred) / (m_ClassifyIterations - 1);
                vsum += pPred * pPred;
                ssum += pActual * pActual;
            }
            
            m_KWBias += bsum;
            m_KWVariance += (1 - vsum);
            m_KWSigma += (1 - ssum);
            
            for( int count = 0; count < centralTendencies.size(); count++ ) {
                
                int wB = 0, wV = 0;
                int centralTendency = ((Integer)centralTendencies.get(count)).intValue();
                
                // For a single instance xi, find the bias and variance.
                for (int j = 0; j < numClasses; j++) {
                    
                    //Webb definition
                    if( j != (int)current.classValue() && j == centralTendency ) {
                        wB += predProbs[j];
                    }
                    if( j != (int)current.classValue() && j != centralTendency ) {
                        wV += predProbs[j];
                    }
                    
                }
                wBSum += (double) wB;
                wVSum += (double) wV;
            }
            
            // calculate bais by dividing bSum by the number of central tendencies and
            // total number of instances. (effectively finding the average and dividing
            // by the number of instances to get the nominalised probability).
            
            m_WBias += ( wBSum / ((double) ( centralTendencies.size() * m_ClassifyIterations )));
            // calculate variance by dividing vSum by the total number of interations
            m_WVariance += ( wVSum / ((double) ( centralTendencies.size() * m_ClassifyIterations )));
            
        }
        
        m_KWBias /= (2.0 * (double) data.numInstances());
        m_KWVariance /= (2.0 * (double) data.numInstances());
        m_KWSigma /= (2.0 * (double) data.numInstances());
        
        // bias = bias / number of data instances
        m_WBias /= (double) data.numInstances();
        // variance = variance / number of data instances.
        m_WVariance /= (double) data.numInstances();
        
        if (m_Debug) {
            System.err.println("Decomposition finished");
        }
        
    }
    
    /** Finds the central tendency, given the classifications for an instance.
     *
     * Where the central tendency is defined as the class that was most commonly
     * selected for a given instance.<p>
     *
     * For example, instance 'x' may be classified out of 3 classes y = {1, 2, 3},
     * so if x is classified 10 times, and is classified as follows, '1' = 2 times, '2' = 5 times
     * and '3' = 3 times. Then the central tendency is '2'. <p>
     *
     * However, it is important to note that this method returns a list of all classes
     * that have the highest number of classifications.
     *
     * In cases where there are several classes with the largest number of classifications, then
     * all of these classes are returned, in ascending class-index order. For example if 'x' is
     * classified '1' = 4 times, '2' = 4 times and '3' = 2 times. Then '1' and '2' are returned.<p>
     *
     * Entries are compared by their exact double value (they are not truncated to
     * integers first), so fractional weights or probabilities are ranked correctly.
     *
     * @param predProbs the per-class classification counts for a single instance,
     * indexed by class; may be null.
     *
     * @return a Vector containing Integer objects which store the class(es) sharing
     * the largest positive value, or null when no entry is greater than zero
     * (including a null or empty array).
     */
    public Vector findCentralTendencies(double[] predProbs) {
        
        if (predProbs == null) {
            return null;
        }
        
        // Classes that currently share the greatest number of classifications.
        Vector centralTClasses = new Vector();
        double centralTValue = 0.0;
        
        for (int i = 0; i < predProbs.length; i++) {
            double currentValue = predProbs[i];
            if (currentValue > centralTValue) {
                // A strictly larger value supersedes every class found so far.
                centralTClasses.clear();
                centralTClasses.addElement(Integer.valueOf(i));
                centralTValue = currentValue;
            } else if (currentValue > 0.0 && currentValue == centralTValue) {
                // Tie with the current maximum: keep every tied class.
                centralTClasses.addElement(Integer.valueOf(i));
            }
        }
        
        // null (rather than an empty Vector) tells callers no class was ever predicted.
        return (centralTValue > 0.0) ? centralTClasses : null;
    }
    
    /**
     * Returns a description of the experiment setup and of the bias-variance
     * decomposition results, reported under both the Kohavi and Wolpert
     * definition and the Webb definition.
     *
     * @return the bias-variance decomposition results as a string, or
     * "Invalid setup" when no classifier has been set
     */
    public String toString() {
        
        if (getClassifier() == null) {
            return "Invalid setup";
        }
        
        StringBuffer result = new StringBuffer();
        result.append("\nBias-Variance Decomposition Segmentation, Cross Validation\n");
        result.append("with subsampling.\n");
        
        result.append("\nClassifier    : ").append(getClassifier().getClass().getName());
        if (getClassifier() instanceof OptionHandler) {
            // Check and cast the same object. joinOptions() adds no leading separator,
            // so insert one; otherwise the options are glued onto the class name.
            String options = Utils.joinOptions(((OptionHandler) getClassifier()).getOptions());
            if (options.length() > 0) {
                result.append(' ').append(options);
            }
        }
        result.append("\nData File     : ").append(getDataFileName());
        result.append("\nClass Index   : ");
        if (getClassIndex() == 0) {
            result.append("last");
        } else {
            result.append(getClassIndex());
        }
        result.append("\nIterations    : ").append(getClassifyIterations());
        result.append("\np             : ").append(getP());
        result.append("\nTraining Size : ").append(getTrainSize());
        result.append("\nSeed          : ").append(getSeed());
        
        result.append("\n\nDefinition   : ").append("Kohavi and Wolpert");
        result.append("\nError         :").append(Utils.doubleToString(getError(), 4));
        result.append("\nBias^2        :").append(Utils.doubleToString(getKWBias(), 4));
        result.append("\nVariance      :").append(Utils.doubleToString(getKWVariance(), 4));
        result.append("\nSigma^2       :").append(Utils.doubleToString(getKWSigma(), 4));
        
        result.append("\n\nDefinition   : ").append("Webb");
        result.append("\nError         :").append(Utils.doubleToString(getError(), 4));
        result.append("\nBias          :").append(Utils.doubleToString(getWBias(), 4));
        result.append("\nVariance      :").append(Utils.doubleToString(getWVariance(), 4));
        
        return result.toString();
    }
    
    
    
    /**
     * Command-line entry point: configures a BVDecomposeSegCVSub from the given
     * options, runs {@code decompose()} and prints {@code toString()} to
     * standard output.
     *
     * If the options cannot be applied, the error message is printed to standard
     * error followed by the synopsis and description of every option returned by
     * {@code listOptions()}. Any other exception is reported on standard error.
     *
     * @param args the command line arguments
     */
    public static void main(String [] args) {
        
        try {
            BVDecomposeSegCVSub bvd = new BVDecomposeSegCVSub();
            
            try {
                bvd.setOptions(args);
                Utils.checkForRemainingOptions(args);
            } catch (Exception ex) {
                // Some exceptions (e.g. NullPointerException) carry no message; fall back
                // to toString() so the user still learns what went wrong.
                String reason = (ex.getMessage() != null) ? ex.getMessage() : ex.toString();
                StringBuffer usage = new StringBuffer(reason);
                usage.append("\nBVDecompose Options:\n\n");
                Enumeration enu = bvd.listOptions();
                while (enu.hasMoreElements()) {
                    Option option = (Option) enu.nextElement();
                    usage.append(option.synopsis()).append('\n');
                    usage.append(option.description()).append('\n');
                }
                throw new Exception(usage.toString());
            }
            
            bvd.decompose();
            
            System.out.println(bvd.toString());
            
        } catch (Exception ex) {
            // Same fallback as above: never print the bare word "null" as the error.
            String reason = (ex.getMessage() != null) ? ex.getMessage() : ex.toString();
            System.err.println(reason);
        }
        
    }
    
    /**
     * Shuffles an array of ints in place using the Fisher-Yates algorithm.
     *
     * Slots are filled from the last position down to position 1. Each slot is
     * swapped with a position drawn uniformly from [0, slot] by the supplied
     * random number generator.
     *
     * @param index the array of integers to shuffle; it is modified in place
     * @param random the random number generator that picks the swap partners
     */
    public final void randomize(int[] index, Random random) {
        int slot = index.length;
        while (--slot > 0) {
            int partner = random.nextInt(slot + 1);
            int saved = index[partner];
            index[partner] = index[slot];
            index[slot] = saved;
        }
    }
    
    
    
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -