⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 cattestresult.java

📁 java数据挖掘算法
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
    /** Appends the confusion matrix display to the given result string.
     * Only the ASCII form of the original C++ implementation is active; the
     * option-driven dispatch (no/ascii/both/scatterviz) is preserved below as
     * commented-out C++ for reference.
     * @param stream The display string accumulated so far.
     * @return The input string with the ASCII confusion matrix appended
     *         (produced by display_ascii_confusion_matrix).
     */
    public String display_confusion_matrix(String stream) {
        //   DispConfMatType defaultDispConfMat = no;
        String dispConfMatHelp =
        "Display confusion matrix when displaying results on test";
        //   static DispConfMatType dispConfusionMatrix =
        //      get_option_enum("DISP_CONFUSION_MAT", dispConfMatEnum,
        //		      defaultDispConfMat, dispConfMatHelp, TRUE);
        
        //   switch (dispConfusionMatrix) {
        //      case no:
        //	 break;
        //      case ascii:
        // Unconditionally use the ASCII rendering (the only ported variant).
        return display_ascii_confusion_matrix(stream);
        //	 break;
        //      case both:
        //	 display_ascii_confusion_matrix(stream);
        //      case scatterviz:
        //	 MString outputSVizfile =
        //	    get_option_string("SCATTERVIZ_FILE", EMPTY_STRING,
        //	      "File name for ScatterViz confusion matrix config file", TRUE);
        //
        //	 // if the name is provided, use it to generate a permanent file.
        //	 if(outputSVizfile != EMPTY_STRING) {
        //	    outputSVizfile += ".conf_matrix.scatterviz";
        //	    MLCOStream out(outputSVizfile);
        //	    MString dataName = out.description() + ".data";
        //	    MLCOStream data(dataName);
        //	    display_scatterviz_confusion_matrix(out, data);
        //	 }
        
        // otherwise, use a TEMPORARY scatterviz file only.
        //	 else {
        //	    Array<MString> suffixes(2);
        //	    suffixes[0] = ".scatterviz";
        //	    suffixes[1] = ".scatterviz.data";
        //	    PtrArray<TmpFileName *> *tempNames = gen_temp_file_names(suffixes);
        //	    const TmpFileName& tmpSVizConf = *tempNames->index(0);
        //	    const TmpFileName& tmpSVizData = *tempNames->index(1);
        //	    MLCOStream SVizConf(tmpSVizConf);
        //	    MLCOStream SVizData(tmpSVizData);
        //	    display_scatterviz_confusion_matrix(SVizConf, SVizData);
        //	    SVizConf.close();
        //	    SVizData.close();
        //	    if(system(*GlobalOptions::scattervizUtil + " " +
        //		      tmpSVizConf))
        //	       Mcerr << "CatTestResult::display: "
        //		  "Call to ScatterViz returns an error." << endl;
        //	    delete tempNames;
        //	 }
        //	 break;
        //   }
    }
    
    
    /** Writes all available statistics (not the per-instance displays) to
     * the given writer by emitting this object's toString() output.
     * An IOException from the writer is reported via a stack trace rather
     * than propagated.
     * @param stream	The writer to which the statistics will be displayed.
     */
    public void display(BufferedWriter stream) {
        try {
            stream.write(toString());
        } catch (IOException ioe) {
            ioe.printStackTrace();
        }
    }
    
    /** Converts information in this CatTestResult object to a string for display:
     * instance/weight counts, generalization/memorization error, overall error
     * with its theoretical std-dev and confidence interval, normalized scoring
     * metrics, optional log-loss and loss-matrix totals, and the confusion matrix.
     * Uses a StringBuilder instead of the original repeated String
     * concatenation, which was O(n^2) in the output length.
     * @return The String containing the display.
     */
    public String toString() {
        StringBuilder out = new StringBuilder();
        GetEnv getenv = new GetEnv();
        // NOTE(review): the value is unused because the misclassification dump
        // below is still commented out; the call is kept in case option
        // registration has side effects -- confirm before removing.
        boolean dispMisclassifications = getenv.get_option_bool("DISP_MISCLASS", false,
        "Display misclassified instances",true);
        
        boolean isWeighted = trainIL.is_weighted() || testIL.is_weighted();
        
        //No perf_display method yet -JL
        //   rtrn = rtrn + push_opts + perf_display;
        if (isWeighted) {
            // Weighted lists: show total weights alongside raw instance counts.
            out.append("Total training weight: ").append((float)total_train_weight())
               .append(" (").append(num_train_instances()).append(" instances)").append("\n");
            out.append("Total test weight: ").append((float)total_test_weight())
               .append(" (").append(num_test_instances()).append(" instances)").append("\n");
            out.append("  seen: ").append((float)total_weight_on_train())
               .append(" (").append(num_on_train()).append(" instances)").append("\n");
            out.append("  unseen: ").append((float)total_weight_off_train())
               .append(" (").append(num_off_train()).append(" instances)").append("\n");
            out.append("  correct: ").append((float)total_correct_weight())
               .append(" (").append(num_correct()).append(" instances)").append("\n");
            out.append("  incorrect: ").append((float)total_incorrect_weight())
               .append(" (").append(num_incorrect()).append(" instances)").append("\n");
        }
        else {
            out.append("Number of training instances: ").append(num_train_instances())
               .append("\n");
            out.append("Number of test instances: ").append(num_test_instances())
               .append(".  Unseen: ").append(num_off_train())
               .append(",  seen ").append(num_on_train()).append(".\n");
            out.append("Number correct: ").append(num_correct())
               .append(".  Number incorrect: ").append(num_incorrect()).append("\n");
        }
        
        //   DBG(ASSERT(num_off_train() + num_on_train() == num_test_instances()));
        // Confidence bounds for the overall error rate (filled as out-params).
        DoubleRef confLow = new DoubleRef();
        DoubleRef confHigh = new DoubleRef();
        confidence(confLow, confHigh, error(), num_test_instances());
        out.append("Generalization error: ");
        
        // Generalization error is defined only over instances unseen in training;
        // memorization error only over instances also present in training.
        if (num_off_train() > 0)
            out.append(error(Generalized) * 100).append('%');
        else
            out.append("unknown");
        out.append(".  Memorization error: ");
        if (num_on_train() > 0)
            out.append(error(Memorized) * 100).append('%').append("\n");
        else
            out.append("unknown").append("\n");
        
        out.append("Error: ").append(error() * 100).append("% +- ");
        // Theoretical std-dev is undefined for a total test weight <= 1.
        if (total_test_weight() > 1)
            out.append(theoretical_std_dev(error(), total_test_weight()) * 100)
               .append("%");
        else
            out.append("Undefined");
        
        out.append(" [").append(confLow.value * 100).append("% - ")
           .append(confHigh.value * 100).append("%]").append("\n");
        
        // display scoring metrics if the categorizer supports scoring
        out.append("Average Normalized Mean Squared Error: ")
           .append(100 * total_mean_squared_error() / total_test_weight())
           .append("%").append("\n");
        out.append("Average Normalized Mean Absolute Error: ")
           .append(100 * total_mean_absolute_error() / total_test_weight())
           .append("%").append("\n");
        
        if (computeLogLoss)
            out.append("Average log-loss: ")
               .append(total_log_loss() / total_test_weight()).append("\n");
        
        // display the total and average loss if the test set had a loss matrix
        if (testIL.get_schema().has_loss_matrix()) {
            out.append("Total Loss: ").append(metrics.totalLoss).append("\n");
            out.append("Average Loss: ").append(metrics.totalLoss / total_test_weight())
               .append("\n");
        }
        /*   stream.write(pop_opts);*/
        
        // The helper appends the confusion matrix and returns the final string.
        String rtrn = display_confusion_matrix(out.toString());
/*
   if (dispMisclassifications)
      display_incorrect_instances();
 
   // dispay a confusion scattergram in VRML if requested
   if(get_option_bool("DISP_CONF_SCATTERGRAM", false, "", true)) {
      MString fileName =
         get_option_string_no_default("CONF_SCATTERGRAM_NAME", "", false)
         + ".wrl";
      display_vrml_scattergram(fileName);
   }
 */
        return rtrn;
    }
    
    /** Accessor for the training data.
     * @return The InstanceList this result was trained on.
     */
    public InstanceList get_training_instance_list() {
        return trainIL;
    }
    
    /** Accessor for the test data.
     * @return The InstanceList this result was tested on.
     */
    public InstanceList get_testing_instance_list() {
        return testIL;
    }
    
    /** Accessor for the per-instance test outcomes.
     * @return The array of individual results from testing.
     */
    public CatOneTestResult[] get_results() {
        return results;
    }
    
    /** Accessor for the aggregated scoring metrics.
     * @return The ScoringMetrics collected from the test results.
     */
    public ScoringMetrics get_metrics() {
        return metrics;
    }
    
    /** Accessor for the confusion matrix built during testing.
     * NOTE: this returns the live internal array, not a copy -- callers can
     * mutate this object's state through it.
     * @return The confusion matrix of the results of testing.
     */
    public double[][] get_confusion_matrix() {
        return confusionMatrix;
    }
    
    /** Accessor for the accumulated loss.
     * @return The total loss value from the scoring metrics.
     */
    public double total_loss() { return metrics.totalLoss; }
    
    /** Calculates a normalized loss value: how far the total loss sits
     * between the minimum and maximum possible loss.
     * NOTE(review): if maximumLoss == minimumLoss the division below yields
     * NaN or Infinity -- confirm callers guard against that case.
     * @return The loss value normalized by the loss value range.
     */
    public double normalized_loss() {
        double lossRange = metrics.maximumLoss - metrics.minimumLoss;
        double lossAboveMin = metrics.totalLoss - metrics.minimumLoss;
        return lossAboveMin / lossRange;
    }
    
    /** Sets the computation of log loss option.
     * The flag is static, so this affects every CatTestResult instance
     * (toString only emits the average log-loss line when it is set).
     * @param b The new setting of the log loss option.
     */    
    public static void set_compute_log_loss(boolean b) {
        computeLogLoss = b;
    }

    /** Returns TRUE if the log loss option is set, or FALSE otherwise.
     * Reads the class-wide flag toggled by set_compute_log_loss.
     * @return TRUE if the log loss option is set, FALSE otherwise.
     */    
    public static boolean get_compute_log_loss() {
        return computeLogLoss;
    }
    
/*
protected:
   // Protected methods
   void initialize();
 
public:
   void OK(int level = 1) const;
   CatTestResult(const Categorizer& cat,
                 const InstanceList& trainILSource,
                 const MString& testFile,
                 const MString& namesExtension = DEFAULT_NAMES_EXT,
                 const MString& testExtension = DEFAULT_TEST_EXT);
   CatTestResult(const Categorizer& cat,
                 const InstanceList& trainILSource,
                 const InstanceList& testILSource);
   CatTestResult(const Array<CatOneTestResult>& resultsArray,
                 const InstanceList& trainILSource,
                 const InstanceList& testILSource);
   virtual ~CatTestResult();
 
   void assign_instance_lists(InstanceList *& train,
                              InstanceList *& test);
 
 
   // metrics
   Real total_label_weight(Category label) const;
   Real true_label_weight(Category label) const;
   Real false_label_weight(Category label) const;
 
   Real accuracy(ErrorType errType = Normal) const;
   Real minimum_loss() const;
   Real maximum_loss() const;
 
   const StatData& loss_stats() const { return lossStats; }
   Real mean_loss() const { return total_loss() / total_test_weight(); }
   Real std_dev_loss() const { return loss_stats().std_dev(); }
   void confidence_loss(Real& confLow, Real& confHigh,
                               Real z = CONFIDENCE_INTERVAL_Z) const
   { loss_stats().percentile(z, confLow, confHigh); }
 
 
   virtual InstanceRC get_instance(int num) const;
   virtual const AugCategory& label(int num) const;
   virtual const AugCategory& predicted_label(int num) const;
   // display_* show the instance and both labels (except for
   //   display_correct_instances() which shows only one label).
   // Instances which were in the training set say "(In TS)" on
   //   the display line.
   virtual void display_all_instances(MLCOStream& stream = Mcout) const;
   virtual void display_incorrect_instances(MLCOStream& stream = Mcout) const;
   virtual void display_correct_instances(MLCOStream& stream = Mcout) const;
   // category distribution displays an array where each cell i
   //   consists of: (1) number of test instances in class i,
   //   (2) number of test instances correctly predicted as class i,
   //   (3) number of test instances incorrectly predicted as class i.
   virtual void display_category_distrib(MLCOStream& stream = Mcout) const;
 
   virtual void display_scatterviz_confusion_matrix(MLCOStream& stream,
                                                    MLCOStream& data) const;
   // display_all dumps everything (display + display_all_instances).
   virtual void display_all(MLCOStream& stream = Mcout) const;
   virtual void display_scatterviz_lift_curve(const Category& labelValue,
                                              MString configFileName) const;
   virtual void display_vrml_scattergram(const MString& fileName) const;
*/
    
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -