⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 hillclimber.java

📁 一个数据挖掘软件ALPHAMINERR的整个过程的JAVA版源代码
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * HillClimber.java
 * Copyright (C) 2004 Remco Bouckaert
 * 
 */
 
package weka.classifiers.bayes.net.search.local;

import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.core.*;
import java.util.*;
import java.io.Serializable;

/** HillClimber implements hill climbing using local search 
 * for learning the structure of Bayesian networks.
 * 
 * @author Remco Bouckaert (rrb@xm.co.nz)
 * Version: $Revision$
 */
public class HillClimber extends LocalScoreSearchAlgorithm {

	/** the Operation class contains info on operations performed
	 * on the current Bayesian network.
	 */
    class Operation implements Serializable {
    	// constants indicating the type of an operation
    	final static int OPERATION_ADD = 0;
    	final static int OPERATION_DEL = 1;
    	final static int OPERATION_REVERSE = 2;
    	/** c'tor **/
        public Operation() {
        }
		/** c'tor + initializers
		 * 
		 * @param nTail
		 * @param nHead
		 * @param nOperation
		 */ 
	    public Operation(int nTail, int nHead, int nOperation) {
			m_nHead = nHead;
			m_nTail = nTail;
			m_nOperation = nOperation;
		}
		/** compare this operation with another
		 * @param other: operation to compare with
		 * @return true if operation is the same
		 */
		public boolean equals(Operation other) {
			if (other == null) {
				return false;
			}
			return ((	m_nOperation == other.m_nOperation) &&
			(m_nHead == other.m_nHead) &&
			(m_nTail == other.m_nTail));
		} // equals
		/** number of the tail node **/
        public int m_nTail;
		/** number of the head node **/
        public int m_nHead;
		/** type of operation (ADD, DEL, REVERSE) **/
        public int m_nOperation;
        /** change of score due to this operation **/
        public double m_fDeltaScore = -1E100;
    } // class Operation

	/** Cache remembering the change in score for individual steps in the
	 * search space. Two square tables are kept, one for arc additions and one
	 * for arc deletions, both indexed as [tail][head].
	 */
	class Cache {
		/** change in score due to adding the arc tail -> head, indexed [tail][head] **/
		double [] [] m_fDeltaScoreAdd;
		/** change in score due to deleting the arc tail -> head, indexed [tail][head] **/
		double [] [] m_fDeltaScoreDel;

		/** c'tor: reserves an nNrOfNodes x nNrOfNodes table for additions and
		 * another one for deletions.
		 * @param nNrOfNodes: number of nodes in the network
		 */
		Cache(int nNrOfNodes) {
			m_fDeltaScoreAdd = new double[nNrOfNodes][nNrOfNodes];
			m_fDeltaScoreDel = new double[nNrOfNodes][nNrOfNodes];
		}

		/** Stores a score change. ADD operations are written to the addition
		 * table; any other operation type is written to the deletion table.
		 * @param oOperation: operation whose tail/head select the entry
		 * @param fValue: score change to store
		 */
		public void put(Operation oOperation, double fValue) {
			double[][] fTable = (oOperation.m_nOperation == Operation.OPERATION_ADD)
				? m_fDeltaScoreAdd
				: m_fDeltaScoreDel;
			fTable[oOperation.m_nTail][oOperation.m_nHead] = fValue;
		} // put

		/** Looks up the score change of an operation. A reversal is priced as
		 * deleting tail -> head plus adding head -> tail.
		 * @param oOperation: operation to look up
		 * @return cached score change, or 0 for an unknown operation type
		 */
		public double get(Operation oOperation) {
			int nType = oOperation.m_nOperation;
			int nTail = oOperation.m_nTail;
			int nHead = oOperation.m_nHead;
			if (nType == Operation.OPERATION_ADD) {
				return m_fDeltaScoreAdd[nTail][nHead];
			}
			if (nType == Operation.OPERATION_DEL) {
				return m_fDeltaScoreDel[nTail][nHead];
			}
			if (nType == Operation.OPERATION_REVERSE) {
				return m_fDeltaScoreDel[nTail][nHead] + m_fDeltaScoreAdd[nHead][nTail];
			}
			// unknown operation type: not expected to happen
			return 0;
		} // get
	} // class Cache

	/** cache for storing score differences of candidate operations;
	 * created by initCache and set back to null at the end of search
	 * to free memory **/
	Cache m_Cache = null;
	
    /** use the arc reversal operator; when enabled, getOptimalOperation also
     * considers reversing arcs (it consults getUseArcReversal(), which
     * presumably reads this flag; defined outside this view) **/
    boolean m_bUseArcReversal = false;
    	

	/**
	 * search determines the network structure/graph of the network
	 * by hill climbing. Starting from the current structure, it repeatedly
	 * applies the optimal operation found by getOptimalOperation for as long
	 * as that operation strictly increases the score.
	 *
	 * The score cache is released when the search ends, including when it
	 * is aborted by an exception.
	 *
	 * @param bayesNet: Bayes network to be learned
	 * @param instances: data set to learn from
	 * @throws Exception propagated from initCache, getOptimalOperation
	 *         or performOperation
	 **/
	protected void search(BayesNet bayesNet, Instances instances) throws Exception {
		try {
			initCache(bayesNet, instances);

			// go do the search: stop once no operation improves the score
			Operation oOperation = getOptimalOperation(bayesNet, instances);
			while ((oOperation != null) && (oOperation.m_fDeltaScore > 0)) {
				performOperation(bayesNet, instances, oOperation);
				oOperation = getOptimalOperation(bayesNet, instances);
			}
		} finally {
			// free up memory, also when initCache or a search step throws
			m_Cache = null;
		}
	} // search


	/** initCache creates a fresh cache sized to the number of attributes and
	 * fills it in three passes: updateCache is called for every node with its
	 * current parent set; every node's current score is computed; and for
	 * every ordered pair of distinct nodes the score change of adding the
	 * arc tail -> head is stored.
	 * @param bayesNet: Bayes network to be learned
	 * @param instances: data set to learn from
	 * @throws Exception if scoring fails
	 */
	void initCache(BayesNet bayesNet, Instances instances)  throws Exception {
		final int nNodes = instances.numAttributes();
		m_Cache = new Cache(nNodes);

		// pass 1: let updateCache process each node's current parent set
		for (int iNode = 0; iNode < nNodes; iNode++) {
			updateCache(iNode, nNodes, bayesNet.getParentSet(iNode));
		}

		// pass 2: score of each node as the network currently stands
		double[] fNodeScore = new double[nNodes];
		for (int iNode = 0; iNode < nNodes; iNode++) {
			fNodeScore[iNode] = calcNodeScore(iNode);
		}

		// pass 3: gain of adding iTail as an extra parent of iHead
		for (int iHead = 0; iHead < nNodes; iHead++) {
			for (int iTail = 0; iTail < nNodes; iTail++) {
				if (iTail == iHead) {
					continue; // no self-loops
				}
				double fGain = calcScoreWithExtraParent(iHead, iTail) - fNodeScore[iHead];
				m_Cache.put(new Operation(iTail, iHead, Operation.OPERATION_ADD), fGain);
			}
		}
	} // initCache

	/** Decides whether an operation is allowed, i.e. not on a tabu list.
	 * The plain hill climber forbids nothing, so every operation is accepted.
	 * @param oOperation: operation to be checked (ignored here)
	 * @return always true
	 */
	boolean isNotTabu(Operation oOperation) {
		// no restrictions for the base hill climber
		return true;
	} // isNotTabu

	/** getOptimalOperation searches for the best operation that can be applied
	 * to the Bayes network. Arc additions and deletions are always considered;
	 * arc reversals only when getUseArcReversal() is true.
	 * @param bayesNet: Bayes network to apply operation on
	 * @param instances: data set to learn from
	 * @return optimal operation found, or null if none was found
	 * @throws Exception if scoring fails
	 */
	Operation getOptimalOperation(BayesNet bayesNet, Instances instances) throws Exception {
		// each finder is handed the current best candidate and returns the new best
		Operation oCandidate = findBestArcToAdd(bayesNet, instances, new Operation());
		oCandidate = findBestArcToDelete(bayesNet, instances, oCandidate);
		if (getUseArcReversal()) {
			oCandidate = findBestArcToReverse(bayesNet, instances, oCandidate);
		}

		// a delta score still at Operation's initial value -1E100 means nothing was found
		boolean bNothingFound = (oCandidate.m_fDeltaScore == -1E100);
		return bNothingFound ? null : oCandidate;
	} // getOptimalOperation

	/** performOperation applies an operation to the Bayes network by calling
	 * applyArcAddition and/or applyArcDeletion, and prints a short trace when
	 * the network is in debug mode. A reversal is carried out as a deletion
	 * followed by an addition with head and tail swapped. Unknown operation
	 * types are ignored.
	 * @param bayesNet: Bayes network to apply operation on
	 * @param instances: data set to learn from
	 * @param oOperation: operation to perform
	 * @throws Exception declared for subclasses/callers; not thrown directly here
	 */
	void performOperation(BayesNet bayesNet, Instances instances, Operation oOperation) throws Exception {
		final int iHead = oOperation.m_nHead;
		final int iTail = oOperation.m_nTail;
		final int nType = oOperation.m_nOperation;

		if (nType == Operation.OPERATION_ADD) {
			applyArcAddition(bayesNet, iHead, iTail, instances);
			if (bayesNet.getDebug()) {
				System.out.print("Add " + iHead + " -> " + iTail);
			}
		} else if (nType == Operation.OPERATION_DEL) {
			applyArcDeletion(bayesNet, iHead, iTail, instances);
			if (bayesNet.getDebug()) {
				System.out.print("Del " + iHead + " -> " + iTail);
			}
		} else if (nType == Operation.OPERATION_REVERSE) {
			// reverse = remove the arc, then add it back in the other direction
			applyArcDeletion(bayesNet, iHead, iTail, instances);
			applyArcAddition(bayesNet, iTail, iHead, instances);
			if (bayesNet.getDebug()) {
				System.out.print("Rev " + iHead + " -> " + iTail);
			}
		}
	} // performOperation


	/** Adds node iTail to the parent set of node iHead, then calls updateCache
	 * for iHead with its modified parent set (presumably refreshing the cached
	 * score changes for that node; updateCache is defined outside this view).
	 * @param bayesNet: Bayes network to modify
	 * @param iHead: node receiving the new parent
	 * @param iTail: node becoming a parent
	 * @param instances: data set to learn from
	 */
	void applyArcAddition(BayesNet bayesNet, int iHead, int iTail, Instances instances) {
		ParentSet headParents = bayesNet.getParentSet(iHead);
		headParents.addParent(iTail, instances);
		int nNodes = instances.numAttributes();
		updateCache(iHead, nNodes, headParents);
	} // applyArcAddition

	void applyArcDeletion(BayesNet bayesNet, int iHead, int iTail, Instances instances) {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -