📄 qualitymeasurelcv.java
字号:
/** * JBNC - Bayesian Network Classifiers Toolbox <p> * * Latest release available at http://sourceforge.net/projects/jbnc/ <p> * * Copyright (C) 1999-2003 Jarek Sacha <p> * * This program is free software; you can redistribute it and/or modify it * under the terms of the GNU General Public License as published by the Free * Software Foundation; either version 2 of the License, or (at your option) * any later version. <p> * * This program is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for * more details. <p> * * You should have received a copy of the GNU General Public License along with * this program; if not, write to the Free Software Foundation, Inc., 59 Temple * Place - Suite 330, Boston, MA 02111-1307, USA. <br> * http://www.fsf.org/licenses/gpl.txt */package jbnc.measures;import BayesianNetworks.BayesNet;import jbnc.dataset.DatasetInt;import jbnc.util.BNTools;import jbnc.util.FrequencyCalc;import java.util.Vector;/** * LCV - Local Cross Validation. Measure the quality of the bayesian network on * the dataset using local cross validation (on class variable). By default * 10-fold 1-time cross validation will be performed. * * @author Jarek Sacha * @since June 1, 1999 */public class QualityMeasureLCV extends QualityMeasure { protected QualityMeasureLogC logC = null; protected Vector[ /* * cvTimes */ ][ /* * cvFolds */ ] testSets = null; protected DatasetInt trainDataset = null; protected DatasetInt testDataset = null; protected FrequencyCalc[ /* * cvTimes */ ][ /* * cvFolds */ ] fc = null; /** Number of cross-validation folds. */ protected int cvFolds = 10; /* * public void setCVFolds(int cvFolds) * { * this.cvFolds = cvFolds; * } */ /** Number of times the cross validation is repeated (averaged). */ protected int cvTimes = 1; /** Constructor for the QualityMeasureLCV object */ public QualityMeasureLCV() { super(); } /** * Create quality measure for given a number of cross-validation folds. * Number of cross-validation repetitions is set to 1. * * @param cvFolds Description of Parameter */ public QualityMeasureLCV(int cvFolds) { super(); this.cvFolds = cvFolds; this.cvTimes = 1; } /** * Create quality measure for given a dataset and a number of * cross-validation folds. Number of cross-validation repetitions is set to * 1. * * @param dataset Description of Parameter * @param cvFolds Description of Parameter */ public QualityMeasureLCV(DatasetInt dataset, int cvFolds) { super(dataset); this.cvFolds = cvFolds; this.cvTimes = 1; } /** * Create quality measure for given a number of cross-validation folds and a * number of cross-validation repetitions. * * @param cvFolds Description of Parameter * @param cvTimes Description of Parameter */ public QualityMeasureLCV(int cvFolds, int cvTimes) { super(); this.cvFolds = cvFolds; this.cvTimes = cvTimes; } /** * Create quality measure for given a dataset and a number of * cross-validation folds and a number of cross-validation repetitions. * * @param dataset Description of Parameter * @param cvFolds Description of Parameter * @param cvTimes Description of Parameter */ public QualityMeasureLCV(DatasetInt dataset, int cvFolds, int cvTimes) { super(dataset); this.cvFolds = cvFolds; this.cvTimes = cvTimes; } /** * @param dataset The new Dataset value */ public void setDataset(DatasetInt dataset) { this.dataset = dataset; if (dataset == null) { logC = null; testSets = null; trainDataset = null; testDataset = null; fc = null; return; } jbnc.util.CVGenerator cvGenerator = new jbnc.util.CVGenerator(); cvGenerator.setCases(dataset.cases); try { trainDataset = (DatasetInt) dataset.clone(); testDataset = (DatasetInt) dataset.clone(); } catch (Exception e) { } fc = new FrequencyCalc[cvTimes][cvFolds]; Vector[] trainSets = new Vector[cvFolds]; testSets = new Vector[cvTimes][cvFolds]; for (int t = 0; t < cvTimes; ++t) { cvGenerator.generateSets(cvFolds, trainSets, testSets[t]); for (int f = 0; f < cvFolds; ++f) { trainDataset.cases = trainSets[f]; fc[t][f] = new FrequencyCalc(trainDataset); } } try { trainDataset = (DatasetInt) dataset.clone(); testDataset = (DatasetInt) dataset.clone(); } catch (Exception e) { } // Prepare logC logC = new QualityMeasureLogC(); logC.setUsePriors(usePriors); logC.setAlphaK(alphaK); } /** * Gets the Name attribute of the QualityMeasureLCV object * * @return The Name value */ public String getName() { return "Cross-validation " + cvFolds + "-fold " + cvTimes + "-time"; } /** * Gets the number of cross-validation folds. * * @return The CVFolds value */ public int getCVFolds() { return this.cvFolds; } /** * Gets the number of times the cross validation is repeated (averaged). * * @return The CVTimes value */ public int getCVTimes() { return this.cvTimes; } /** * Description of the Method * * @param net Description of Parameter * @return Description of the Returned Value * @exception Exception Description of Exception */ public final double evaluate(BayesNet net) throws Exception { // Calculate LC measure double q = 0; for (int t = 0; t < cvTimes; ++t) { for (int f = 0; f < cvFolds; ++f) { BNTools.learnParameters(net, fc[t][f], usePriors, alphaK); // test testDataset.cases = testSets[t][f]; logC.setDataset(testDataset); q += logC.evaluate(net); } } return q / cvTimes; } /* * public void setCVTimes(int cvTimes) * { * this.cvTimes = cvTimes; * } */}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -