⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cattestresult.java

📁 java数据挖掘算法
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
                    Error.fatalErr("CatTestResult::error(Memorized): No test instances "
                    +"are in the training set.  This causes division by 0");
                
                double weightOnTrainIncorrect = 0;
                //	 for (int i = results.low(); i <= results.high(); i++)
                for (int i = 0; i < results.length; i++)
                    if ( (results[i].inTrainIL) &&
                    ( results[i].augCat.num() != results[i].correctCat ) )
                        weightOnTrainIncorrect += results[i].instance.get_weight();
                return weightOnTrainIncorrect / total_weight_on_train();
            }
            default:
                Error.fatalErr("CatTestResult::error unexpected error type"
                +(int)errType);
                return 1.0;
        } // end switch
    }
    
    /** Returns the total weight of the test instances that also appear in
     * the training data. On first use this lazily runs initializeTrainTable(),
     * which flags each matching test result; later calls reuse the cached sum.
     * @return The total weight of the test instances found in the training set.
     */
    public double total_weight_on_train() {
        if (!inTrainingSetInitialized) {
            initializeTrainTable();
        }
        return weightOnTrain;
    }
    
    /** Returns the total weight of the test instances that do NOT appear in
     * the training data, i.e. the total test weight minus the weight found in
     * the training set. Triggers the lazy train-set lookup (through
     * total_weight_on_train()) if it has not run yet.
     * @return The total weight of the test instances not found in the training set.
     */
    public double total_weight_off_train() {
        final double testWeight = total_test_weight();
        final double onTrainWeight = total_weight_on_train();
        return testWeight - onTrainWeight;
    }
    
    /** Flags every test instance that also occurs in the training set.
     * A TableCategorizer built over the training list is used as a fast
     * lookup: a test instance it does not categorize as
     * Globals.UNKNOWN_CATEGORY_VAL is treated as present in the training data.
     * For each such instance results[i].inTrainIL is set to true; the number
     * of matches is stored in numOnTrain, their summed weight in weightOnTrain,
     * and inTrainingSetInitialized is set. Only called when inTrainIL data is
     * needed.
     */
    protected void initializeTrainTable() {
        TableCategorizer lookupTable = new TableCategorizer(trainIL,
            Globals.UNKNOWN_CATEGORY_VAL,
            "Training Set Lookup Table");

        int matchCount = 0;
        double matchWeight = 0;
        ListIterator testIter = testIL.instance_list().listIterator();
        for (int idx = 0; testIter.hasNext(); idx++) {
            Instance testInst = (Instance) testIter.next();
            boolean unseen = lookupTable.categorize(testInst).Category()
                == Globals.UNKNOWN_CATEGORY_VAL;
            if (unseen) {
                continue; // idx is still advanced by the for-update
            }
            // The constructor initializes inTrainIL to false; only matches are flipped.
            results[idx].inTrainIL = true;
            matchCount++;
            matchWeight += testInst.get_weight();
        }

        numOnTrain = matchCount;
        weightOnTrain = matchWeight;
        inTrainingSetInitialized = true;
    }
    
    /** Returns the summed weight of all instances in the test list.
     * @return The total weight of the test data set, as reported by testIL.
     */
    public double total_test_weight() {
        return this.testIL.total_weight();
    }
    
    /** Returns the summed weight of the test instances that were
     * misclassified, as accumulated into the overall metrics while
     * categorizing the test set.
     * @return The total weight of incorrectly classified instances.
     */
    public double total_incorrect_weight() {
        return this.metrics.weightIncorrect;
    }
    
    /** Initializes this CatTestResult by categorizing every instance of the
     * test data set with the given Categorizer.
     * For each test instance this stores in results[i] the correct category,
     * a copy of the instance, the predicted category, and the predicted and
     * correct category distributions. It then computes the instance's
     * ScoringMetrics and folds them into the per-category distribution
     * (catDistrib), the confusion matrix, the overall metrics and the loss
     * statistics (lossStats).
     * A test instance with an unknown label is reported via Error.fatalErr.
     * @param cat	The categorizer with which the test data set will be categorized.
     */
    protected void initialize(Categorizer cat) {
        int i = 0;
        for (ListIterator pixIL = testIL.instance_list().listIterator(); pixIL.hasNext(); ++i) {
            Instance instance = (Instance) pixIL.next();
            logOptions.LOG(2, "Instance: " + instance);

            // Validate the label BEFORE converting it to a category index, so an
            // unknown label is reported with this message rather than being fed
            // to get_nominal_val() and stored into results[i] first.
            if (instance.label_info().is_unknown(instance.get_label())) {
                Error.fatalErr("Test instance (" + instance + ") has an unknown "
                    + "label value");
            }

            int correctCat = instance.label_info().get_nominal_val(instance.get_label());
            results[i].correctCat = correctCat;
            results[i].instance = new Instance(instance);

            // NOTE(review): a commented-out debug step in the original code set
            // the label to UNKNOWN before categorizing "so that categorizers
            // don't cheat"; the labeled instance is passed through here.

            // Categorize the instance. A scoring categorizer supplies the
            // predicted distribution directly; otherwise an all-or-nothing
            // distribution is built around the predicted category.
            if (cat.supports_scoring()) {
                results[i].predDist = cat.score(instance);
                results[i].augCat = new AugCategory(results[i].predDist.best_category());
            } else {
                AugCategory predicted = cat.categorize(instance);
                results[i].augCat = new AugCategory(predicted);
                results[i].predDist = new CatDist(cat.get_schema(), predicted);
            }
            // Both paths use an all-or-nothing distribution on the true category.
            results[i].correctDist = new CatDist(cat.get_schema(), correctCat);

            // Compute this instance's scoring metrics.
            ScoringMetrics metric = new ScoringMetrics();
            compute_scoring_metrics(metric,
                results[i].augCat,
                new AugCategory(correctCat, "correct"),
                results[i].predDist,
                results[i].correctDist,
                instance.get_weight(),
                testIL.get_schema(),
                computeLogLoss);

            // Fold them into the per-category, confusion-matrix and overall totals.
            AugCategory predCat = results[i].augCat;
            int predNum = predCat.num();
            catDistrib[predNum].numTestSet++;
            accumulate_scoring_metrics(catDistrib[predNum].metrics, metric);
            confusionMatrix[correctCat][predNum] += instance.get_weight();
            accumulate_scoring_metrics(metrics, metric);

            // Accumulate the loss into the loss statistics.
            lossStats.insert(metric.totalLoss);

            logOptions.LOG(2, "Correct label: "
                + instance.get_schema().category_to_label_string(correctCat)
                + " Predicted label: " + predCat.description() + ". ");
            logOptions.LOG(2, (predNum == correctCat) ? "Correct. " : "Incorrect. ");
            logOptions.LOG(2, "No correct: " + metrics.numCorrect + '/' + (i + 1) + '\n');
        }
    }
    
    
    /** Adds every field of one ScoringMetrics into another, field by field.
     * @param dest	The ScoringMetrics that receives the sums (modified in place).
     * @param src 	The ScoringMetrics whose values are added; it is not modified.
     */
    private void accumulate_scoring_metrics(ScoringMetrics dest, ScoringMetrics src) {
        // Prediction counts.
        dest.numCorrect += src.numCorrect;
        dest.numIncorrect += src.numIncorrect;

        // Prediction weights.
        dest.weightCorrect += src.weightCorrect;
        dest.weightIncorrect += src.weightIncorrect;

        // Loss totals, including the summed minimum/maximum possible losses.
        dest.totalLoss += src.totalLoss;
        dest.minimumLoss += src.minimumLoss;
        dest.maximumLoss += src.maximumLoss;
        dest.totalLogLoss += src.totalLogLoss;

        // Distribution error sums.
        dest.meanSquaredError += src.meanSquaredError;
        dest.meanAbsoluteError += src.meanAbsoluteError;
    }
    
    /** Useful functions for computing scoring metrics.
     * All metrics (probabilistic or normal) are computed
     * within this function.
     * @param metrics		The ScoringMetrics
     * @param predictedCat	The category determined by a Categorizer.
     * @param correctCat		The correct category for a test Instance.
     * @param predictedDist	The distribution of categories determined by a Categorizer.
     * @param correctDist	The distribution of the correct categories for a test data set.
     * @param weight		The weight of the instance currently being added.
     * @param testSchema		The Schema for the test Instance.
     * @param computeLogLoss	Indicator of whether LogLoss should be computed. True indicates
     * LogLoss should be computed.
     */
    private void compute_scoring_metrics(ScoringMetrics metrics,
    AugCategory predictedCat,
    AugCategory correctCat,
    CatDist predictedDist,
    CatDist correctDist,
    double weight,
    Schema testSchema,
    boolean computeLogLoss) {
        double[] predScores = predictedDist.get_scores();
        double[] corrScores = correctDist.get_scores();
        if(predScores.length != corrScores.length)
            Error.fatalErr("compute_scoring_metrics: correct and predicted "
            +"distributions have different sizes" );
        
        // Compare only the categories, because the predictedCat and correctCat
        // may get different descriptions if one was created by a CatDist
        // based on a scoring inducer.
        if(predictedCat.Category() == correctCat.Category()) {
            metrics.weightCorrect += weight;
            metrics.numCorrect++;
        }
        else {
            metrics.weightIncorrect += weight;
            metrics.numIncorrect++;
        }
        
        // If a loss matrix is defined for the TEST set, look up the loss
        // and add it to our totalLoss.  Then find the minimum and maximum
        // losses given this correctCat.  Add these to the minimum and
        // maximum loss metrics respectively.
        // All losses must be scaled by the instance's weight.
        if(testSchema.has_loss_matrix()) {
            metrics.totalLoss += testSchema.get_loss_matrix()[correctCat.num()][predictedCat.num()] * weight;
            int index = 0;
            metrics.minimumLoss += Matrix.min_in_row(correctCat.num(), index, testSchema.get_loss_matrix()) * weight;
            metrics.maximumLoss += Matrix.max_in_row(correctCat.num(), index, testSchema.get_loss_matrix()) * weight;
        }
        else {
            if(predictedCat.notequal(correctCat))
                metrics.totalLoss += weight;
            metrics.maximumLoss += weight;
        }
        
        // Compute the mean squared and mean absolute error between the
        // predicted and correct distributions.
        // The correct distribution will generally be an
        // all-or-nothing distribution in favor of the correct category.
        // Normalize the mean squared/mean absolute errors to be always
        // between 0 and 1.  This involves dividing by 2.
        double mse = 0;
        double mae = 0;
        for(int i=0; i<predScores.length; i++) {
            double diff = predScores[i] - corrScores[i];
            mse += diff*diff;
            mae += Math.abs(diff);
        }
        mse /= 2;
        mae /= 2;
        
        metrics.meanSquaredError += weight*mse;
        metrics.meanAbsoluteError += weight*mae;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -