⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 c45.java

📁 mallet是自然语言处理、机器学习领域的一个开源项目。
💻 JAVA
字号:
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org.  For further information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.base.classify;import java.io.IOException;import java.io.ObjectInputStream;import java.io.ObjectOutputStream;import java.io.Serializable;import java.util.logging.Logger;import edu.umass.cs.mallet.base.classify.Boostable;import edu.umass.cs.mallet.base.classify.Classification;import edu.umass.cs.mallet.base.classify.Classifier;import edu.umass.cs.mallet.base.pipe.Pipe;import edu.umass.cs.mallet.base.types.Alphabet;import edu.umass.cs.mallet.base.types.FeatureVector;import edu.umass.cs.mallet.base.types.Instance;import edu.umass.cs.mallet.base.types.InstanceList;import edu.umass.cs.mallet.base.types.GainRatio;import edu.umass.cs.mallet.base.util.MalletLogger;import edu.umass.cs.mallet.base.util.Maths;/** * A C4.5 Decision Tree classifier. * * @see C45Trainer * @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a> */public class C45 extends Classifier implements Boostable, Serializable{	private static Logger logger = MalletLogger.getLogger(C45.class.getName());	Node m_root;		public C45 (Pipe instancePipe, C45.Node root)	{		super (instancePipe);		m_root = root;	}		public Node getRoot ()	{		return m_root;	}		private Node getLeaf (Node node, FeatureVector fv)	{		if (node.getLeftChild() == null && node.getRightChild() == null)			return node;		else if (fv.value(node.getGainRatio().getMaxValuedIndex()) <= node.getGainRatio().getMaxValuedThreshold())			return getLeaf(node.getLeftChild(), fv);		else			return getLeaf(node.getRightChild(), fv);	}		public Classification classify (Instance instance)	{		FeatureVector fv = (FeatureVector) instance.getData (instancePipe);		assert (instancePipe == null || fv.getAlphabet () == this.instancePipe.getDataAlphabet ());				Node leaf = getLeaf(m_root, fv);		return new Classification (instance, this, leaf.getGainRatio().getBaseLabelDistribution());	}		/**	 * Prune the tree using minimum description length	 */	public void prune()	{		getRoot().computeCostAndPrune();	}		/**	 * @return the total number of nodes in this tree	 */	public int getSize()	{		Node root = getRoot();        		if (root == null)			return 0;		return 1+root.getNumDescendants();	}		/**	 * Prints the tree	 */	public void print()	{		if (getRoot() != null)			getRoot().print();	}		public static class Node implements Serializable	{		private static final long serialVersionUID = 1L;				GainRatio m_gainRatio;		// the entire set of instances given to the root node		InstanceList m_ilist;		// indices of instances at this node		int[] m_instIndices;		// data vocabulary		Alphabet m_dataDict;		// mininum number of instances allowed in this node		int m_minNumInsts;		Node m_parent, m_leftChild, m_rightChild;				public Node(InstanceList ilist, Node parent, int minNumInsts)		{			this(ilist, parent, minNumInsts, null);		}				public Node(InstanceList ilist, Node parent, int minNumInsts, int[] instIndices)		{			if (instIndices == null) {				instIndices = new int[ilist.size()];				for (int ii = 0; ii < instIndices.length; ii++)					instIndices[ii] = ii;			}			m_gainRatio = GainRatio.createGainRatio(ilist, instIndices, minNumInsts);			m_ilist = ilist;			m_instIndices = instIndices;			m_dataDict = m_ilist.getDataAlphabet();			m_minNumInsts = minNumInsts;			m_parent = parent;			m_leftChild = m_rightChild = null;		}				/** The root has depth zero. */		public int depth ()		{			int depth = 0;			Node p = m_parent;			while (p != null) {				p = p.m_parent;				depth++;			}			return depth;		}				public int getSize() { return m_instIndices.length; }		public boolean isLeaf() { return (m_leftChild == null && m_rightChild == null); }		public boolean isRoot() { return m_parent == null; }		public Node getParent() { return m_parent; }		public Node getLeftChild() { return m_leftChild; }		public Node getRightChild() { return m_rightChild; }		public GainRatio getGainRatio() { return m_gainRatio; }		public Object getSplitFeature() { return m_dataDict.lookupObject(m_gainRatio.getMaxValuedIndex()); }				public InstanceList getInstances() 		{ 			InstanceList ret = new InstanceList(m_ilist.getPipe());			for (int ii = 0; ii < m_instIndices.length; ii++)				ret.add(m_ilist.getInstance(m_instIndices[ii]));			return ret; 		}				/** 		 * Count the number of non-leaf descendant nodes		 */		public int getNumDescendants()		{			if (isLeaf())				return 0;			int count = 0;			if (! getLeftChild().isLeaf())				count += 1 + getLeftChild().getNumDescendants();			if (! getRightChild().isLeaf())				count += 1 + getRightChild().getNumDescendants();			return count;		}				public void split()		{			if (m_ilist == null)				throw new IllegalStateException ("Frozen.  Cannot split.");			int numLeftChildren = 0;			boolean[] toLeftChild = new boolean[m_instIndices.length];			for (int i = 0; i < m_instIndices.length; i++) {				Instance instance = m_ilist.getInstance(m_instIndices[i]);				FeatureVector fv = (FeatureVector) instance.getData();				if (fv.value (m_gainRatio.getMaxValuedIndex()) <= m_gainRatio.getMaxValuedThreshold()) {					toLeftChild[i] = true;					numLeftChildren++;				}				else					toLeftChild[i] = false;			}			logger.info("leftChild.size=" + numLeftChildren 					+ " rightChild.size=" + (m_instIndices.length-numLeftChildren));			int[] leftIndices = new int[numLeftChildren];			int[] rightIndices = new int[m_instIndices.length - numLeftChildren];			int li = 0, ri = 0;			for (int i = 0; i < m_instIndices.length; i++) {				if (toLeftChild[i])					leftIndices[li++] = m_instIndices[i];				else					rightIndices[ri++] = m_instIndices[i];			}			m_leftChild = new Node(m_ilist, this, m_minNumInsts, leftIndices);			m_rightChild = new Node(m_ilist, this, m_minNumInsts, rightIndices);		}				public double computeCostAndPrune()		{			double costS = getMDL();			if (isLeaf())				return costS + 1;			double minCost1 = getLeftChild().computeCostAndPrune();			double minCost2 = getRightChild().computeCostAndPrune();			double costSplit = Math.log(m_gainRatio.getNumSplitPointsForBestFeature()) / GainRatio.log2;			double minCostN = Math.min(costS+1, costSplit+1+minCost1+minCost2);			if (Maths.almostEquals(minCostN, costS+1))				m_leftChild = m_rightChild = null;			return minCostN;		}				/**		 * Calculates the minimum description length of this node, i.e., 		 * the length of the binary encoding that describes the feature 		 * and the split value used at this node		 */		public double getMDL()		{			int numClasses = m_ilist.getTargetAlphabet().size();			double mdl = getSize() * getGainRatio().getBaseEntropy();			mdl += ((numClasses-1) * Math.log(getSize() / 2.0)) / (2 * GainRatio.log2);			double piPow = Math.pow(Math.PI, numClasses/2.0);			double gammaVal = Maths.gamma(numClasses/2.0);			mdl += Math.log(piPow/gammaVal) / GainRatio.log2;			return mdl;		}				/**		 * Saves memory by allowing ilist to be garbage collected		 * (deletes this node's associated instance list)		 */		public void stopGrowth ()		{			if (m_leftChild != null)				m_leftChild.stopGrowth();			if (m_rightChild != null)				m_rightChild.stopGrowth();	  			m_ilist = null;		}				public String getName()		{			return getStringBufferName().toString();		}				public StringBuffer getStringBufferName()		{			StringBuffer sb = new StringBuffer();			if (m_parent == null)				return sb.append("root");			else if (m_parent.getParent() == null) {				sb.append("(\"");				sb.append(m_dataDict.lookupObject(m_parent.getGainRatio().getMaxValuedIndex()).toString());				sb.append("\"");				if (m_parent.getLeftChild() == this)					sb.append(" <= ");				else					sb.append(" > ");				sb.append(m_parent.getGainRatio().getMaxValuedThreshold());				return sb.append(")");			} 			else {				sb.append(m_parent.getStringBufferName());				sb.append(" && (\"");				sb.append(m_dataDict.lookupObject(m_parent.getGainRatio().getMaxValuedIndex()).toString());				sb.append("\"");				if (m_parent.getLeftChild() == this)					sb.append(" <= ");				else 					sb.append(" > ");				sb.append(m_parent.getGainRatio().getMaxValuedThreshold());				return sb.append(")");			}		}				/**		 * Prints the tree rooted at this node		 */		public void print()		{			print("");		}				public void print(String prefix)		{	  			if (isLeaf()) {				int bestLabelIndex = getGainRatio().getBaseLabelDistribution().getBestIndex();				int numMajorityLabel = (int) (getGainRatio().getBaseLabelDistribution().value(bestLabelIndex) * getSize());				System.out.println("root:" + getGainRatio().getBaseLabelDistribution().getBestLabel() + " " + numMajorityLabel + "/" + getSize());			}			else {				String featName = m_dataDict.lookupObject(getGainRatio().getMaxValuedIndex()).toString();				double threshold = getGainRatio().getMaxValuedThreshold();				System.out.print(prefix + "\"" + featName + "\" <= " + threshold + ":");				if (m_leftChild.isLeaf()) {					int bestLabelIndex = m_leftChild.getGainRatio().getBaseLabelDistribution().getBestIndex();					int numMajorityLabel = (int) (m_leftChild.getGainRatio().getBaseLabelDistribution().value(bestLabelIndex) * m_leftChild.getSize());					System.out.println(m_leftChild.getGainRatio().getBaseLabelDistribution().getBestLabel() + " " + numMajorityLabel + "/" + m_leftChild.getSize());				}				else {					System.out.println();					m_leftChild.print(prefix + "|    ");				}	      				System.out.print(prefix + "\"" + featName + "\" > " + threshold + ":");				if (m_rightChild.isLeaf()) {					int bestLabelIndex = m_rightChild.getGainRatio().getBaseLabelDistribution().getBestIndex();					int numMajorityLabel = (int) (m_rightChild.getGainRatio().getBaseLabelDistribution().value(bestLabelIndex) * m_rightChild.getSize());					System.out.println(m_rightChild.getGainRatio().getBaseLabelDistribution().getBestLabel() + " " + numMajorityLabel + "/" + m_rightChild.getSize());				}				else {					System.out.println();					m_rightChild.print(prefix + "|    ");				}			}		}			}		// Serialization	// serialVersionUID is overriden to prevent innocuous changes in this	// class from making the serialization mechanism think the external	// format has changed.		private static final long serialVersionUID = 1;	private static final int CURRENT_SERIAL_VERSION = 1;		private void writeObject(ObjectOutputStream out) throws IOException	{		out.writeInt(CURRENT_SERIAL_VERSION);		out.writeObject(getInstancePipe());		out.writeObject(m_root);	}		private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {		int version = in.readInt();		if (version != CURRENT_SERIAL_VERSION)			throw new ClassNotFoundException("Mismatched C45 versions: wanted " +					CURRENT_SERIAL_VERSION + ", got " +					version);		instancePipe = (Pipe) in.readObject();		m_root = (Node) in.readObject();			}	}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -