gainratio.java

来自「mallet是自然语言处理、机器学习领域的一个开源项目。」· Java 代码 · 共 367 行

JAVA
367
字号
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org.  For further information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.base.types;import java.awt.geom.Point2D;import java.util.ArrayList;import java.util.Arrays;import java.util.Collections;import java.util.Comparator;import java.util.Hashtable;import java.util.Iterator;import java.util.logging.Logger;import edu.umass.cs.mallet.base.util.MalletLogger;import edu.umass.cs.mallet.base.util.Maths;/** * List of features along with their thresholds sorted in descending order of  * the ratio of (1) information gained by splitting instances on the  * feature at its associated threshold value, to (2) the split information.<p> * * The calculations performed do not take into consideration the instance weights.<p> * * To create an instance of GainRatio from an InstanceList, one must do the following:<p><tt> * * InstanceList ilist = ...  *         ... * GainRatio gr = GainRatio.createGainRatio(ilist); * </tt><p> * * J. R. Quinlan * "Improved Use of Continuous Attributes in C4.5"  * ftp://ftp.cs.cmu.edu/project/jair/volume4/quinlan96a.ps * * @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a> */public class GainRatio extends RankedFeatureVector{	private static final Logger logger = MalletLogger.getLogger (GainRatio.class.getName ());		private static final long serialVersionUID = 1L;		public static final double log2 = Math.log(2);		double[] m_splitPoints;	double m_baseEntropy;	LabelVector m_baseLabelDistribution;	int m_numSplitPointsForBestFeature;	int m_minNumInsts;		/**	 * Calculates gain ratios for all (feature, split point) pairs 	 * snd returns array of:<pre>	 *   1.  gain ratios (each element is the max gain ratio of a feature 	 * for those split points with at least average gain)	 *   2.  the optimal split point for each feature	 *   3.  the overall entropy 	 *   4.  the overall label distribution of the given instances	 *   5.  the number of split points of the split feature.	 *   </pre>	 */	protected static Object[] calcGainRatios(InstanceList ilist, int[] instIndices, int minNumInsts)	{		int numInsts = instIndices.length;		Alphabet dataDict = ilist.getDataAlphabet();		LabelAlphabet targetDict = (LabelAlphabet) ilist.getTargetAlphabet();		double[] targetCounts = new double[targetDict.size()];		// Accumulate target label counts and make sure		// the sum of each instance's target label is 1		for (int ii = 0; ii < numInsts; ii++) {			Instance inst = ilist.getInstance(instIndices[ii]);			Labeling labeling = inst.getLabeling();			double labelWeightSum = 0;			for (int ll = 0; ll < labeling.numLocations(); ll++) {				int li = labeling.indexAtLocation(ll);				double labelWeight = labeling.valueAtLocation(ll);				labelWeightSum += labelWeight;				targetCounts[li] += labelWeight;			}			assert(Maths.almostEquals(labelWeightSum, 1));		}		// Calculate the base entropy Info(D) and the the 		// label distribution of the given instances		double[] targetDistribution = new double[targetDict.size()];		double baseEntropy = 0;		for (int ci = 0; ci < targetDict.size(); ci++) {			double p = targetCounts[ci] / numInsts;			targetDistribution[ci] = p;			if (p > 0)				baseEntropy -= p * Math.log(p) / log2;		}		LabelVector baseLabelDistribution = new LabelVector(targetDict, targetDistribution);		double infoGainSum = 0;		int totalNumSplitPoints = 0;		double[] passTestTargetCounts = new double[targetDict.size()];		// Maps feature index -> Hashtable, and each table 		// maps (split point) -> (info gain, split ratio)		Hashtable[] featureToInfo = new Hashtable[dataDict.size()]; 				// Go through each feature's split points in ascending order		for (int fi = 0; fi < dataDict.size(); fi++) {						if ((fi+1) % 1000 == 0)				logger.info("at feature " + (fi+1) + " / " + dataDict.size());						featureToInfo[fi] = new Hashtable();			Arrays.fill(passTestTargetCounts, 0);			// Sort instances on this feature's values			instIndices = sortInstances(ilist, instIndices, fi);			// Iterate through the sorted instances			for (int ii = 0; ii < numInsts-1; ii++) {				Instance inst = ilist.getInstance(instIndices[ii]);				Instance instPlusOne = ilist.getInstance(instIndices[ii+1]);				FeatureVector fv1 = (FeatureVector) inst.getData();				FeatureVector fv2 = (FeatureVector) instPlusOne.getData();				double lower = fv1.value(fi);				double higher = fv2.value(fi);				// Accumulate the label weights for instances passing the test				Labeling labeling = inst.getLabeling();				for (int ll = 0; ll < labeling.numLocations(); ll++) {					int li = labeling.indexAtLocation(ll);					double labelWeight = labeling.valueAtLocation(ll);					passTestTargetCounts[li] += labelWeight;				}				if (Maths.almostEquals(lower, higher) 						|| 						inst.getLabeling().toString().equals(instPlusOne.getLabeling().toString()))					continue;				// For this (feature, spilt point) pair, calculate the 				// info gain of using this pair to split insts into those 				// with value of feature <= p versus > p				totalNumSplitPoints++;				double splitPoint = (lower + higher) / 2;				double numPassInsts = ii+1;								// If this split point creates a partition 				// with too few instances, ignore it				double numFailInsts = numInsts - numPassInsts;				if (numPassInsts < minNumInsts || numFailInsts < minNumInsts)					continue;								// If all instances pass or fail this test, it is useless				double passProportion = numPassInsts / numInsts;				if (Maths.almostEquals(passProportion, 0) || Maths.almostEquals(passProportion, 1))					continue;	  								// Calculate the entropy of instances passing and failing the test				double passEntropy = 0;				double failEntropy = 0;				double p;								for (int ci = 0; ci < targetDict.size(); ci++) {					if (numPassInsts > 0) {						p = passTestTargetCounts[ci] / numPassInsts;						if (p > 0)							passEntropy -= p * Math.log(p) / log2;					}					if (numFailInsts > 0) {						double failTestTargetCount = targetCounts[ci] - passTestTargetCounts[ci];						p = failTestTargetCount / numFailInsts;						if (p  > 0)							failEntropy -= p * Math.log(p) / log2;					}				}								// Calculate Gain(D, T), the information gained 				// by testing on this (feature, split-point) pair				double gainDT = baseEntropy 				- passProportion * passEntropy				- (1-passProportion) * failEntropy;	      				infoGainSum += gainDT;				// Calculate Split(D, T), the split information				double splitDT = 					- passProportion * Math.log(passProportion) / log2					- (1-passProportion) * Math.log(1-passProportion) / log2;				// Calculate the gain ratio				double gainRatio = gainDT / splitDT;				featureToInfo[fi].put(new Double(splitPoint),						new Point2D.Double(gainDT, gainRatio));			}  // End loop through sorted instances		}  // End loop through features		// For each feature's split point with at least average gain, 		// get the maximum gain ratio and the associated split point		// (using the info gain as tie breaker)		double[] gainRatios = new double[dataDict.size()];		double[] splitPoints = new double[dataDict.size()];		int numSplitsForBestFeature = 0;		// If all feature vectors are identical or no splits are worthy, return all 0s		if (totalNumSplitPoints == 0 || Maths.almostEquals(infoGainSum, 0))			return new Object[] {gainRatios, splitPoints, new Double(baseEntropy), 				baseLabelDistribution, new Integer(numSplitsForBestFeature)};		double avgInfoGain = infoGainSum / totalNumSplitPoints;		double maxGainRatio = 0;		double gainForMaxGainRatio = 0; // tie breaker				int xxx = 0;						for (int fi = 0; fi < dataDict.size(); fi++) {			double featureMaxGainRatio = 0;			double featureGainForMaxGainRatio = 0;			double bestSplitPoint = Double.NaN;			for (Iterator iter = featureToInfo[fi].keySet().iterator(); iter.hasNext(); ) {				Object key = iter.next();				Point2D.Double pt = (Point2D.Double) featureToInfo[fi].get(key);				double splitPoint = ((Double) key).doubleValue();				double infoGain = pt.getX();				double gainRatio = pt.getY();				if (infoGain >= avgInfoGain) {					if (gainRatio > featureMaxGainRatio							||							(gainRatio == featureMaxGainRatio && infoGain > featureGainForMaxGainRatio)) {						featureMaxGainRatio = gainRatio;						featureGainForMaxGainRatio = infoGain;						bestSplitPoint = splitPoint;					}				}												else 					xxx++;											}			assert(bestSplitPoint != Double.NaN);			gainRatios[fi] = featureMaxGainRatio;			splitPoints[fi] = bestSplitPoint;			if (featureMaxGainRatio > maxGainRatio					||					(featureMaxGainRatio == maxGainRatio && featureGainForMaxGainRatio > gainForMaxGainRatio)) {				maxGainRatio = featureMaxGainRatio;				gainForMaxGainRatio = featureGainForMaxGainRatio;				numSplitsForBestFeature = featureToInfo[fi].size();			}		}				logger.info("label distrib:\n" + baseLabelDistribution);		logger.info("base entropy=" + baseEntropy + ", info gain sum=" + infoGainSum + ", total num split points=" + totalNumSplitPoints + ", avg info gain=" + avgInfoGain + ", num splits with < avg gain=" + xxx);				return new Object[] {gainRatios, splitPoints, new Double(baseEntropy), 				baseLabelDistribution, new Integer(numSplitsForBestFeature)};	}		public static int[] sortInstances(InstanceList ilist, int[] instIndices, int featureIndex)	{		ArrayList list = new ArrayList();		for (int ii = 0; ii < instIndices.length; ii++) {			Instance inst = ilist.getInstance(instIndices[ii]);			FeatureVector fv = (FeatureVector) inst.getData();			list.add(new Point2D.Double(instIndices[ii], fv.value(featureIndex)));		}		Collections.sort(list, new Comparator() 				{			public int compare(Object o1, Object o2)			{				Point2D.Double p1 = (Point2D.Double) o1;				Point2D.Double p2 = (Point2D.Double) o2;				if (p1.y == p2.y) {					assert(p1.x != p2.x);					return p1.x > p2.x ? 1 : -1;				}				else 					return p1.y > p2.y ? 1 : -1;		    			}				});		int[] sorted = new int[instIndices.length];		for (int i = 0; i < list.size(); i++)			sorted[i] = (int) ((Point2D.Double) list.get(i)).getX();		return sorted;	}		/**	 * Constructs a GainRatio object.	 */	public static GainRatio createGainRatio(InstanceList ilist)	{		int[] instIndices = new int[ilist.size()];		for (int ii = 0; ii < instIndices.length; ii++)			instIndices[ii] = ii;		return createGainRatio(ilist, instIndices, 2);	}		/**	 * Constructs a GainRatio object	 */	public static GainRatio createGainRatio(InstanceList ilist, int[] instIndices, 			int minNumInsts)	{		Object[] objs = calcGainRatios(ilist, instIndices, minNumInsts);		double[] gainRatios = (double[]) objs[0];		double[] splitPoints = (double[]) objs[1];		double baseEntropy = ((Double) objs[2]).doubleValue();		LabelVector baseLabelDistribution = (LabelVector) objs[3];		int numSplitPointsForBestFeature = ((Integer) objs[4]).intValue();		return new GainRatio(ilist.getDataAlphabet(), gainRatios, splitPoints, 				baseEntropy, baseLabelDistribution, 				numSplitPointsForBestFeature, minNumInsts);	}		protected GainRatio(Alphabet dataAlphabet, double[] gainRatios, double[] splitPoints, 			double baseEntropy, LabelVector baseLabelDistribution,			int numSplitPointsForBestFeature, int minNumInsts)	{		super (dataAlphabet, gainRatios);		m_splitPoints = splitPoints;		m_baseEntropy = baseEntropy;		m_baseLabelDistribution = baseLabelDistribution;		m_numSplitPointsForBestFeature = numSplitPointsForBestFeature;		m_minNumInsts = minNumInsts;	}		/**	 * @return the threshold of the (feature, threshold) 	 * pair with with maximum gain ratio	 */	public double getMaxValuedThreshold()	{		return getThresholdAtRank(0);	}		/**	 * @return the threshold of the (feature, threshold)	 * pair with the given rank	 */	public double getThresholdAtRank(int rank)	{		int index = getIndexAtRank(rank);		return m_splitPoints[index];	}		public double getBaseEntropy ()	{		return m_baseEntropy;	}		public LabelVector getBaseLabelDistribution ()	{		return m_baseLabelDistribution;	}		public int getNumSplitPointsForBestFeature()	{		return m_numSplitPointsForBestFeature;	}	}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?