⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 globalscoresearchalgorithm.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * GlobalScoreSearchAlgorithm.java
 * Copyright (C) 2004 Remco Bouckaert
 * 
 */

package weka.classifiers.bayes.net.search.global;

import java.util.Enumeration;
import java.util.Vector;

import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.classifiers.bayes.net.search.SearchAlgorithm;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;

/** The GlobalScoreSearchAlgorithm class supports Bayes net structure search algorithms
 * that are based on cross validation (as opposed to, for example,
 * score based or conditional independence based search algorithms).
 * 
 * @author Remco Bouckaert
 * @version $Revision$
 */
public class GlobalScoreSearchAlgorithm extends SearchAlgorithm {
	
	/**
	 * Bayes network for which a structure is searched. Assigned by
	 * leaveOneOutCV and read by the calcScoreWith...Parent methods.
	 **/
	BayesNet m_BayesNet;
	
	/**
	 * Toggle between scoring using accuracy = 0-1 loss (when false) or
	 * class probabilities (when true).
	 **/
	boolean m_bUseProb = true;
	
	/** Number of folds handed to kFoldCV when k-fold cross validation is selected. **/
	int m_nNrOfFolds = 10;

	/** Constants identifying the cross validation types that calcScore dispatches on. **/
	final static int LOOCV = 0;
	final static int KFOLDCV = 1;
	final static int CUMCV = 2;

	/**
	 * Tags pairing each cross validation type constant with a readable label.
	 * NOTE(review): presumably consumed through SelectedTag by option handling
	 * code that is not part of this section -- confirm.
	 */
	public static final Tag[] TAGS_CV_TYPE =
		{
			new Tag(LOOCV, "LOO-CV"),
			new Tag(KFOLDCV, "k-Fold-CV"),
			new Tag(CUMCV, "Cumulative-CV")
		};
	/**
	 * Holds the cross validation strategy used to measure quality of network;
	 * one of LOOCV, KFOLDCV or CUMCV. Defaults to leave-one-out.
	 */
	int m_nCVType = LOOCV;

	/**
	 * Scores a network structure using the cross validation strategy that is
	 * currently selected in m_nCVType.
	 *
	 * @param bayesNet : Bayes Network containing structure to evaluate
	 * @return the value returned by the selected cross validation routine
	 *         (documented as an accuracy in the interval 0..1)
	 * @throws Exception when m_nCVType is not one of LOOCV, CUMCV or KFOLDCV,
	 *         plus any exception passed on by the selected routine
	 */
	public double calcScore(BayesNet bayesNet) throws Exception {
		final int nSelectedType = m_nCVType;

		if (nSelectedType == LOOCV) {
			return leaveOneOutCV(bayesNet);
		}
		if (nSelectedType == CUMCV) {
			return cumulativeCV(bayesNet);
		}
		if (nSelectedType == KFOLDCV) {
			return kFoldCV(bayesNet, m_nNrOfFolds);
		}

		// none of the known strategies matched
		throw new Exception("Unrecognized cross validation type encountered: " + m_nCVType);
	} // calcScore

	/**
	 * Calculates the network score with one extra parent temporarily added
	 * to a node.
	 *
	 * The candidate is appended to the parent set of nNode, the whole network
	 * is scored through calcScore, and the candidate is removed again so that
	 * the parent set is left as it was found, even when scoring fails.
	 *
	 * @param nNode node for which the score is calculated
	 * @param nCandidateParent candidate parent to add to the existing parent set
	 * @return score reported by calcScore for the modified network, or -1e100
	 *         when nCandidateParent already is a parent of nNode
	 * @throws Exception passed on by calcScore
	 */
	public double calcScoreWithExtraParent(int nNode, int nCandidateParent) throws Exception {
		ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
		Instances instances = m_BayesNet.m_Instances;

		// sanity check: nCandidateParent should not be in parent set already
		if (oParentSet.contains(nCandidateParent)) {
			return -1e100;
		}

		// set up candidate parent (outside the try: nothing to undo if this fails)
		oParentSet.addParent(nCandidateParent, instances);

		try {
			// calculate the score
			return calcScore(m_BayesNet);
		} finally {
			// delete temporarily added parent, also when calcScore throws,
			// so a failed evaluation cannot leave the structure modified
			oParentSet.deleteLastParent(instances);
		}
	} // calcScoreWithExtraParent


	/**
	 * Calculates the network score with one parent temporarily removed
	 * from a node.
	 *
	 * The candidate is deleted from the parent set of nNode, the whole network
	 * is scored through calcScore, and the candidate is reinserted at the
	 * position it was deleted from, even when scoring fails.
	 *
	 * @param nNode node for which the score is calculated
	 * @param nCandidateParent candidate parent to delete from the existing parent set
	 * @return score reported by calcScore for the modified network, or -1e100
	 *         when nCandidateParent is not a parent of nNode
	 * @throws Exception passed on by calcScore
	 */
	public double calcScoreWithMissingParent(int nNode, int nCandidateParent) throws Exception {
		ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
		Instances instances = m_BayesNet.m_Instances;

		// sanity check: nCandidateParent should be in parent set already
		if (!oParentSet.contains(nCandidateParent)) {
			return -1e100;
		}

		// temporarily remove the candidate; keep the value returned by
		// deleteParent so it can be handed back to addParent afterwards
		int iParent = oParentSet.deleteParent(nCandidateParent, instances);

		try {
			// calculate the score
			return calcScore(m_BayesNet);
		} finally {
			// reinsert temporarily deleted parent, also when calcScore throws,
			// so a failed evaluation cannot leave the structure modified
			oParentSet.addParent(nCandidateParent, iParent, instances);
		}
	} // calcScoreWithMissingParent

	/**
	 * Calculates the network score with the arrow between a node and one of
	 * its parents temporarily reversed.
	 *
	 * nCandidateParent is deleted from the parent set of nNode and nNode is
	 * added to the parent set of nCandidateParent. The whole network is then
	 * scored through calcScore and both changes are undone, even when an
	 * exception is thrown along the way.
	 *
	 * @param nNode node for which the score is calculated
	 * @param nCandidateParent parent of nNode whose arrow is to be reversed
	 * @return score reported by calcScore for the modified network, or -1e100
	 *         when nCandidateParent is not a parent of nNode
	 * @throws Exception passed on by calcScore
	 */
	public double calcScoreWithReversedParent(int nNode, int nCandidateParent) throws Exception {
		ParentSet oParentSet = m_BayesNet.getParentSet(nNode);
		ParentSet oParentSet2 = m_BayesNet.getParentSet(nCandidateParent);
		Instances instances = m_BayesNet.m_Instances;

		// sanity check: nCandidateParent should be in parent set already
		if (!oParentSet.contains(nCandidateParent)) {
			return -1e100;
		}

		// step 1 of the reversal: drop the arrow nCandidateParent -> nNode
		int iParent = oParentSet.deleteParent(nCandidateParent, instances);
		try {
			// step 2 of the reversal: add the arrow nNode -> nCandidateParent
			oParentSet2.addParent(nNode, instances);
			try {
				// calculate the score
				return calcScore(m_BayesNet);
			} finally {
				// undo step 2 first, mirroring the order in which it was applied
				oParentSet2.deleteLastParent(instances);
			}
		} finally {
			// undo step 1: reinsert the parent using the value deleteParent returned
			oParentSet.addParent(nCandidateParent, iParent, instances);
		}
	} // calcScoreWithReversedParent

	/**
	 * Scores a network structure with leave one out cross validation over
	 * the instances attached to the network (bayesNet.m_Instances).
	 *
	 * The CPTs are estimated once on all data. Each instance in turn has the
	 * sign of its weight flipped and is passed to updateClassifier, is scored
	 * with accuracyIncrease, and then has its weight flipped back and is
	 * passed to updateClassifier a second time.
	 *
	 * Side effect: m_BayesNet is set to bayesNet.
	 *
	 * @param bayesNet : Bayes Network containing structure to evaluate
	 * @return sum of the accuracyIncrease values divided by the sum of the
	 *         instance weights read while the sign is flipped (documented as
	 *         an accuracy in the interval 0..1)
	 * @throws Exception passed on by updateClassifier
	 */
	public double leaveOneOutCV(BayesNet bayesNet) throws Exception {
		m_BayesNet = bayesNet;
		Instances data = bayesNet.m_Instances;
		bayesNet.estimateCPTs();

		double fScoreSum = 0.0;
		double fWeightSum = 0.0;

		int iCurrent = 0;
		while (iCurrent < data.numInstances()) {
			Instance heldOut = data.instance(iCurrent);

			// flip the sign of the weight and update the classifier with it
			// NOTE(review): presumably this subtracts the instance from the
			// learned counts -- confirm against BayesNet.updateClassifier
			heldOut.setWeight(-heldOut.weight());
			bayesNet.updateClassifier(heldOut);

			// both reads happen while the weight still carries the flipped sign
			// NOTE(review): the final ratio is only positive if accuracyIncrease
			// scales by that same flipped weight -- verify, it is not shown here
			fScoreSum += accuracyIncrease(heldOut);
			fWeightSum += heldOut.weight();

			// restore the original weight and update the classifier again
			heldOut.setWeight(-heldOut.weight());
			bayesNet.updateClassifier(heldOut);

			iCurrent++;
		}

		return fScoreSum / fWeightSum;
	} // LeaveOneOutCV

	/**
	 * CumulativeCV returns the accuracy calculated using cumulative
	 * cross validation. The idea is to run through the data set and
	 * try to classify each of the instances based on the previously
	 * seen data.
	 * The data set used is m_Instances associated with the Bayes Network.
	 * @param bayesNet : Bayes Network containing structure to evaluate
	 * @return accuracy (in interval 0..1) measured using cumulative cv.
	 * @throws Exception passed on by updateClassifier

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -