⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 globalscoresearchalgorithm.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
	 */
	public double cumulativeCV(BayesNet bayesNet) throws Exception {
		m_BayesNet = bayesNet;
		double scoreSum = 0.0;
		double weightSum = 0.0;
		Instances data = bayesNet.m_Instances;
		// Begin from the CPT state produced by initCPTs(); the network is then
		// fed one instance at a time.
		bayesNet.initCPTs();
		int index = 0;
		while (index < data.numInstances()) {
			Instance current = data.instance(index);
			// Score the instance BEFORE the network is updated with it, so each
			// instance is evaluated by a network that has only been updated with
			// the instances preceding it.
			scoreSum += accuracyIncrease(current);
			bayesNet.updateClassifier(current);
			weightSum += current.weight();
			index++;
		}
		// Weighted average; evaluates to NaN (0.0 / 0.0) when there are no instances.
		return scoreSum / weightSum;
	} // cumulativeCV
	
	/**
	 * kFoldCV uses k-fold cross validation to measure the accuracy of a Bayes
	 * network classifier.
	 * <p>
	 * The CPTs are estimated once on the complete data set. For every fold the
	 * instances of that fold are passed to updateClassifier with a negated
	 * weight, the fold is scored with accuracyIncrease, and the instances are
	 * then passed to updateClassifier again with their original weight.
	 * Folds are contiguous index ranges of the data set. On normal return every
	 * instance weight has its original value again.
	 *
	 * @param bayesNet : Bayes Network containing structure to evaluate
	 * @param nNrOfFolds : the number of folds k to perform k-fold cv, at least 1
	 * @return accuracy measured using k-fold cv: the summed accuracy increase
	 *         divided by the summed instance weight (NaN for an empty data set)
	 * @throws IllegalArgumentException if nNrOfFolds is smaller than 1
	 * @throws Exception passed on by updateClassifier
	 */
	public double kFoldCV(BayesNet bayesNet, int nNrOfFolds) throws Exception {
		if (nNrOfFolds < 1) {
			// 0 would divide by zero below; a negative value would never let the
			// fold boundary reach the end of the data set.
			throw new IllegalArgumentException(
				"Number of folds must be at least 1 but was " + nNrOfFolds);
		}
		m_BayesNet = bayesNet;
		double fAccuracy = 0.0;
		double fWeight = 0.0;
		Instances instances = bayesNet.m_Instances;
		// estimate CPTs based on complete data set
		bayesNet.estimateCPTs();
		final int nInstances = instances.numInstances();
		int nFoldStart = 0;
		int nFoldEnd = nInstances / nNrOfFolds;
		int iFold = 1;
		while (nFoldStart < nInstances) {
			// remove influence of fold iFold from the probability distribution
			// NOTE(review): relies on updateClassifier treating a negative weight
			// as a removal - confirm against BayesNet.updateClassifier.
			for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
				Instance instance = instances.instance(iInstance);
				instance.setWeight(-instance.weight());
				bayesNet.updateClassifier(instance);
			}

			// measure accuracy on fold iFold
			for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
				Instance instance = instances.instance(iInstance);
				// score with the original weight ...
				instance.setWeight(-instance.weight());
				fAccuracy += accuracyIncrease(instance);
				fWeight += instance.weight();
				// ... and go back to the negated weight, so that the single flip
				// in the restore loop below ends on the original weight. Without
				// this the restore loop would update with a negated weight again
				// and leave all weights negated.
				instance.setWeight(-instance.weight());
			}

			// restore influence of fold iFold from the probability distribution
			for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
				Instance instance = instances.instance(iInstance);
				instance.setWeight(-instance.weight());
				bayesNet.updateClassifier(instance);
			}

			// go to next fold
			nFoldStart = nFoldEnd;
			iFold++;
			// long arithmetic: iFold * nInstances may exceed the int range
			nFoldEnd = (int) ((long) iFold * nInstances / nNrOfFolds);
		}
		return fAccuracy / fWeight;
	} // kFoldCV
	
	/**
	 * Computes the contribution of a single instance to the accuracy estimate.
	 * <p>
	 * When probabilistic scoring is enabled (m_bUseProb) the contribution is the
	 * probability the network assigns to the instance's actual class, multiplied
	 * by the instance weight. Otherwise the contribution is the instance weight
	 * if the predicted class equals the actual class, and 0 if it does not.
	 *
	 * @param instance : instance for which to calculate the accuracy increase.
	 * @return increase in accuracy due to given instance.
	 * @throws Exception passed on by distributionForInstance and classifyInstance
	 */
	double accuracyIncrease(Instance instance) throws Exception {
		if (!m_bUseProb) {
			// hard decision: all or nothing
			boolean predictedCorrectly =
				m_BayesNet.classifyInstance(instance) == instance.classValue();
			return predictedCorrectly ? instance.weight() : 0;
		}
		// soft decision: credit the probability mass put on the true class
		double[] classDistribution = m_BayesNet.distributionForInstance(instance);
		int actualClass = (int) instance.classValue();
		return classDistribution[actualClass] * instance.weight();
	} // accuracyIncrease

	/**
	 * Tells whether class probabilities are used in the accuracy estimate.
	 *
	 * @return the current value of the use-probabilities setting
	 */
	public boolean getUseProb() {
		return this.m_bUseProb;
	} // getUseProb

	/**
	 * Switches the use of class probabilities in the accuracy estimate on or off.
	 *
	 * @param useProb : true to use probabilities in the accuracy estimate,
	 *                  false to count correct classifications only
	 */
	public void setUseProb(boolean useProb) {
		this.m_bUseProb = useProb;
	} // setUseProb
	
	/**
	 * Sets the cross validation strategy to be used in searching for networks.
	 * The request is silently ignored unless the tag set of the argument is the
	 * very same array object as TAGS_CV_TYPE (reference comparison).
	 *
	 * @param newCVType : cross validation strategy
	 */
	public void setCVType(SelectedTag newCVType) {
		if (newCVType.getTags() != TAGS_CV_TYPE) {
			// tag belongs to a different tag set: leave the current type untouched
			return;
		}
		m_nCVType = newCVType.getSelectedTag().getID();
	} // setCVType

	/**
	 * Gets the cross validation strategy to be used in searching for networks.
	 *
	 * @return a newly created SelectedTag wrapping the current strategy id
	 *         together with TAGS_CV_TYPE
	 */
	public SelectedTag getCVType() {
		SelectedTag currentType = new SelectedTag(m_nCVType, TAGS_CV_TYPE);
		return currentType;
	} // getCVType

	/**
	 * Returns an enumeration describing the available options: the score type
	 * option (-S) and the probabilistic scoring flag (-Q) declared here,
	 * followed by every option listed by the parent class.
	 *
	 * @return an enumeration of all the available options
	 */
	public Enumeration listOptions() {
		Vector result = new Vector(2);

		Option scoreTypeOption =
			new Option(
				"\tScore type (LOO-CV,k-Fold-CV,Cumulative-CV)\n",
				"S",
				1,
				"-S [LOO-CV|k-Fold-CV|Cumulative-CV]");
		result.addElement(scoreTypeOption);

		Option useProbOption =
			new Option("\tUse probabilistic scoring.\n\t(default true)", "Q", 0, "-Q");
		result.addElement(useProbOption);

		// append whatever the parent class offers, in its original order
		for (Enumeration inherited = super.listOptions(); inherited.hasMoreElements();) {
			result.addElement(inherited.nextElement());
		}
		return result.elements();
	} // listOptions

	/**
	 * Parses a given list of options. Options handled here are:<p>
	 *
	 * -S [LOO-CV|k-Fold-CV|Cumulative-CV] <br>
	 * Cross validation strategy. Any other value (including a missing -S)
	 * leaves the current strategy unchanged.<p>
	 *
	 * -Q <br>
	 * Use probabilistic scoring. The setting is taken directly from the result
	 * of Utils.getFlag for this flag.<p>
	 *
	 * The option array is afterwards handed to the parent class.
	 * For other options see search algorithm.
	 *
	 * @param options the list of options as an array of strings
	 * @exception Exception if an option is not supported
	 */
	public void setOptions(String[] options) throws Exception {

		final String scoreName = Utils.getOption('S', options);

		// the three names are mutually exclusive, so at most one branch fires
		if (scoreName.equals("LOO-CV")) {
			setCVType(new SelectedTag(LOOCV, TAGS_CV_TYPE));
		} else if (scoreName.equals("k-Fold-CV")) {
			setCVType(new SelectedTag(KFOLDCV, TAGS_CV_TYPE));
		} else if (scoreName.equals("Cumulative-CV")) {
			setCVType(new SelectedTag(CUMCV, TAGS_CV_TYPE));
		}

		final boolean useProb = Utils.getFlag('Q', options);
		setUseProb(useProb);

		super.setOptions(options);
	} // setOptions

	/**
	 * Gets the current settings of the search algorithm: the score type (-S
	 * followed by its name), the -Q flag when probabilistic scoring is enabled,
	 * and then the options reported by the parent class. Unused trailing slots
	 * of the returned array hold empty strings.
	 *
	 * @return an array of strings suitable for passing to setOptions
	 */
	public String[] getOptions() {
		String[] superOptions = super.getOptions();
		// room for "-S", its value and "-Q", plus everything from the parent
		String[] options = new String[3 + superOptions.length];
		int current = 0;

		options[current++] = "-S";

		switch (m_nCVType) {
			case (LOOCV) :
				options[current++] = "LOO-CV";
				break;
			case (KFOLDCV) :
				options[current++] = "k-Fold-CV";
				break;
			case (CUMCV) :
				options[current++] = "Cumulative-CV";
				break;
		}
		
		if (getUseProb()) {
		  options[current++] = "-Q";
		}

		// insert options from parent class; the array was sized for them but
		// they were previously never copied in and therefore got lost
		System.arraycopy(superOptions, 0, options, current, superOptions.length);
		current += superOptions.length;

		// Fill up rest with empty strings, not nulls!
		while (current < options.length) {
			options[current++] = "";
		}
		return options;
	} // getOptions

	/**
	 * Returns the tip text for the CVType option: an introductory sentence
	 * followed by one line per available cross validation strategy.
	 *
	 * @return a string to describe the CVType option.
	 */
	public String CVTypeTipText() {
	  // The introduction now ends with a line break; previously it ran straight
	  // into the first strategy ("...networks.LOO-CV = ...").
	  return "Select cross validation strategy to be used in searching for networks.\n" +
	  "LOO-CV = Leave one out cross validation\n" +
	  "k-Fold-CV = k fold cross validation\n" +
	  "Cumulative-CV = cumulative cross validation."
	  ;
	} // CVTypeTipText

	/**
	 * Returns the tip text for the UseProb option, explaining the difference
	 * between probabilistic scoring and counting exact classifications.
	 *
	 * @return a string to describe the UseProb option.
	 */
	public String useProbTipText() {
	  // "is returned" corrects the former typo "if returned" in the shown text
	  return "If set to true, the probability of the class is returned in the estimate of the "+
	  "accuracy. If set to false, the accuracy estimate is only increased if the classifier returns " +
	  "exactly the correct class.";
	} // useProbTipText

	/**
	 * Provides a short, human readable description of this search algorithm.
	 *
	 * @return the description text.
	 */
	public String globalInfo() {
	  final String description =
	    "This Bayes Network learning algorithm uses cross validation to estimate "
	    + "classification accuracy.";
	  return description;
	} // globalInfo

} // class CVSearchAlgorithm

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -