⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 informationgain.java

📁 常用机器学习算法,java编写源代码,内含常用分类算法,包括说明文档
💻 JAVA
字号:
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

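/**  Ranks each person's keywords by information gain, using word and person
 *   weights accumulated over all people, and writes the top-ranked words to
 *   one text file per person.
 *   @author Ron Bekkerman <A HREF="mailto:ronb@cs.umass.edu">ronb@cs.umass.edu</A>
*/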


package edu.umass.cs.mallet.projects.dex.types;

import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.Alphabet;

import java.util.*;
import java.lang.Math.*;
import java.io.*;


/**
 * Scores the keywords of every {@link Person} in a {@link People} collection
 * by information gain and writes each person's top-ranked words to a text file.
 *
 * <p>Construction does all the work: it accumulates the total weight of every
 * word ({@link #wordWeights}), the total keyword weight of every person
 * ({@link #personWeights}) and the grand total, then scores and outputs the
 * keywords of each person.
 */
public class InformationGain implements Serializable{

	/** Maximum number of top-ranked words written to a person's output file. */
	private static final int MAX_NUM_OF_WORDS = 100;

	/**
	 * Builds the weight tables from <code>people</code> and immediately writes
	 * one ranked keyword file per person into <code>outDirName</code>.
	 *
	 * @param people     the people whose keywords are scored
	 * @param outDirName directory that receives the per-person output files
	 */
	public InformationGain(People people, String outDirName) {
		wordWeights = new HashMap();
		personWeights = new HashMap();

		fillWeightMaps(people);
		calculateInformationGain(people, outDirName);
	}

	/**
	 * Converts a feature vector into a map from word (String) to value (Double).
	 *
	 * @param fv the feature vector; its alphabet entries are cast to String
	 * @return a new map with one entry per location of <code>fv</code>
	 */
	public HashMap featureVector2HashMap (FeatureVector fv) {
		HashMap h = new HashMap (fv.numLocations());
		Alphabet alph = fv.getAlphabet();
		for (int i = 0; i < fv.numLocations(); i++) {
			double value = fv.valueAtLocation (i);
			String word = (String) alph.lookupObject (fv.indexAtLocation (i));
			h.put (word, new Double (value));
		}
		return h;
	}

	/**
	 * Adds <code>wordCurrentWeight</code> to the running total kept for
	 * <code>word</code> in {@link #wordWeights}.
	 *
	 * @param word              the word whose total is updated
	 * @param wordCurrentWeight the weight to add
	 */
	public void updateWordWeights(String word, Double wordCurrentWeight) {
		addWeight(wordWeights, word, wordCurrentWeight);
	}

	/**
	 * Adds <code>wordCurrentWeight</code> to the running total kept for
	 * <code>person</code> in {@link #personWeights}.
	 *
	 * @param person            the person whose total is updated
	 * @param wordCurrentWeight the weight of one of the person's words
	 */
	public void updatePersonWeights(Person person, Double wordCurrentWeight) {
		addWeight(personWeights, person, wordCurrentWeight);
	}

	/** Adds <code>increment</code> to the Double stored under <code>key</code>, creating the entry on first sight. */
	private static void addWeight(HashMap weights, Object key, Double increment) {
		Double current = (Double) weights.get(key);
		if (current == null) {
			// First occurrence of this key: store the weight as is.
			weights.put(key, increment);
		} else {
			weights.put(key, new Double(current.doubleValue() + increment.doubleValue()));
		}
	}

	/**
	 * Accumulates per-word, per-person and overall weight totals from the
	 * keywords of every person. Totals are added to whatever the maps and
	 * {@link #overallWeight} already hold.
	 *
	 * @param people the people whose keyword vectors are read
	 */
	public void fillWeightMaps(People people) {
		Iterator iter = people.iterator ();
		while (iter.hasNext()) {
			Person person = (Person) iter.next();
			HashMap words = featureVector2HashMap (person.keyWords);
			Iterator entries = words.entrySet().iterator();
			while (entries.hasNext()) {
				Map.Entry entry = (Map.Entry) entries.next();
				Double wordCurrentWeight = (Double) entry.getValue();
				updateWordWeights((String) entry.getKey(), wordCurrentWeight);
				updatePersonWeights(person, wordCurrentWeight);
				overallWeight = overallWeight + wordCurrentWeight.doubleValue();
			}
		}
	}

	/**
	 * Computes the information gain described by the four joint weights and
	 * four marginal weights in <code>data</code>, normalised by
	 * {@link #overallWeight}.
	 *
	 * <p>A joint weight that is not positive contributes nothing to the sum.
	 * Without that guard a zero joint weight would evaluate
	 * <code>0 * Math.log(0)</code>, which is NaN.
	 *
	 * @param data the weights of one (person, word) pair
	 * @return the sum of the four terms divided by the overall weight
	 */
	public double calculateInformationGain(PersonWordWeights data) {
		double infoGain = informationTerm(data.w_in_p, data.w, data.p);
		infoGain = infoGain + informationTerm(data.w_in_not_p, data.w, data.not_p);
		infoGain = infoGain + informationTerm(data.not_w_in_p, data.not_w, data.p);
		infoGain = infoGain + informationTerm(data.not_w_in_not_p, data.not_w, data.not_p);
		return infoGain / overallWeight;
	}

	/** Returns <code>joint * log(joint * overallWeight / (wordMarginal * personMarginal))</code>, or 0 when <code>joint</code> is not positive. */
	private double informationTerm(double joint, double wordMarginal, double personMarginal) {
		if (joint <= 0) {
			return 0;
		}
		return joint * Math.log(joint * overallWeight / (wordMarginal * personMarginal));
	}

	/**
	 * Alternative word score: <code>(w_in_p / p) * log(w_in_p * not_p / (p * w_in_not_p))</code>.
	 * No guard is applied, so zero denominators yield infinite or NaN results.
	 *
	 * @param data the weights of one (person, word) pair
	 * @return the log-ratio score
	 */
	public double calculateLogRatio(PersonWordWeights data) {
		return (data.w_in_p / data.p) *
			Math.log(data.w_in_p * data.not_p / (data.p * data.w_in_not_p));
	}

	/**
	 * Inserts <code>weightedWord</code> into <code>sortedWords</code>, which is
	 * kept in descending weight order. The new element is placed in front of
	 * existing elements of equal weight.
	 *
	 * @param sortedWords  list of {@link WeightedString}, already sorted
	 * @param weightedWord the element to insert
	 */
	public void updateSortedWords(ArrayList sortedWords, WeightedString weightedWord) {
		for (int i = 0; i < sortedWords.size(); i++) {
			WeightedString existing = (WeightedString) sortedWords.get(i);
			if (weightedWord.weight >= existing.weight) {
				sortedWords.add(i, weightedWord);
				return;
			}
		}
		// Lighter than everything present (or the list is empty): append.
		sortedWords.add(weightedWord);
	}

	/**
	 * Scores every keyword of <code>person</code>, writes the ranked words to
	 * <code>out</code> and stores the ranked list on the person.
	 *
	 * <p>If the person has no entry in {@link #personWeights}, a message is
	 * printed and the JVM exits with status 1.
	 *
	 * @param person the person whose keywords are scored
	 * @param out    destination for the ranked words; not closed here
	 */
	public void calculateInformationGain(Person person, BufferedWriter out) {
		Double personWeight = (Double) personWeights.get(person);
		if (personWeight == null) {
			System.out.print("Null pointer in ");
			person.print();
			System.exit(1);
		}

		ArrayList sortedWords = new ArrayList();
		PersonWordWeights data = new PersonWordWeights();
		data.p = personWeight.doubleValue();
		data.not_p = overallWeight - data.p;

		HashMap words = featureVector2HashMap (person.keyWords);
		Iterator entries = words.entrySet().iterator();
		while (entries.hasNext()) {
			Map.Entry entry = (Map.Entry) entries.next();
			String word = (String) entry.getKey();
			data.w_in_p = ((Double) entry.getValue()).doubleValue();
			data.w = ((Double) wordWeights.get(word)).doubleValue();
			data.not_w = overallWeight - data.w;
			data.w_in_not_p = data.w - data.w_in_p;
			if (data.w_in_not_p == 0) {
				// Word occurs only for this person: substitute 1 for the zero
				// weight. NOTE(review): presumably a smoothing step - confirm.
				data.w_in_not_p = 1;
			}
			data.not_w_in_p = data.p - data.w_in_p;
			data.not_w_in_not_p = data.not_p - data.w_in_not_p;
			// calculateLogRatio(data) is an alternative scoring function.
			double infoGain = calculateInformationGain(data);
			updateSortedWords(sortedWords, new WeightedString(word, infoGain));
		}
		printInformationGains(person, sortedWords, out);
		person.setTopKeyWordWeights(sortedWords); // Store Information Gain scores for Person
	}

	/**
	 * Creates <code>dir</code>, including any missing parent directories, if
	 * it does not exist yet. Failures are reported on standard output.
	 *
	 * @param dir the directory to create
	 */
	public static void makeDir(File dir) {
		try {
			if (!dir.exists() && !dir.mkdirs() && !dir.isDirectory()) {
				System.out.println("Could not create directory " + dir);
			}
		} catch (SecurityException e) {
			System.out.println("No permission to make directory " + dir);
		}
	}

	/**
	 * Scores the keywords of every person that has at least one keyword and
	 * writes the ranked words to <code>outDirName/&lt;firstName&gt;.txt</code>.
	 * The first I/O error stops processing of the remaining people.
	 *
	 * @param people     the people to process
	 * @param outDirName output directory; created if missing
	 */
	public void calculateInformationGain(People people, String outDirName) {
		try {
			makeDir (new File (outDirName));
			Iterator iter = people.iterator();
			while (iter.hasNext()) {
				Person person = (Person) iter.next();
				if (person.keyWords.numLocations() == 0) {
					continue;
				}
				String outFileName =
					outDirName + File.separator + person.getFirstName() + ".txt";
				BufferedWriter out
					= new BufferedWriter(new FileWriter(new File(outFileName)));
				try {
					System.out.println("InfoGain for " + person.names.get(0));
					calculateInformationGain(person, out);
				} finally {
					// Release the file handle even when scoring fails.
					out.close();
				}
			}
		} catch (IOException e) {
			System.err.println("IO problem in outputting keywords");
		}
	}

	/**
	 * Writes the first {@link #MAX_NUM_OF_WORDS} words of
	 * <code>sortedWords</code> to <code>out</code>, one word per line.
	 *
	 * @param person      the person the words belong to (not used for output)
	 * @param sortedWords list of {@link WeightedString} in output order
	 * @param out         destination writer; not closed here
	 */
	public void printInformationGains(Person person, ArrayList sortedWords, BufferedWriter out) {
		try {
			int limit = Math.min(sortedWords.size(), MAX_NUM_OF_WORDS);
			for (int i = 0; i < limit; i++) {
				WeightedString weightedWord = (WeightedString) sortedWords.get(i);
				out.write(weightedWord.str);
				out.newLine();
			}
		} catch (IOException e) {
			System.out.println("Problem with writing to infogain file");
		}
	}

	// Inner classes

	/**
	 * Weights describing one (person, word) pair: the marginal weights of the
	 * person and the word, their complements, and the four joint weights.
	 */
	public class PersonWordWeights {
		public PersonWordWeights() {
		}

		double p;               // total weight of the person
		double not_p;           // overall weight minus p
		double w;               // total weight of the word
		double not_w;           // overall weight minus w
		double w_in_p;          // weight of the word for this person
		double w_in_not_p;      // weight of the word for all other people
		double not_w_in_p;      // weight of the person's other words
		double not_w_in_not_p;  // remaining weight

		// Serialization
		// This class does not declare Serializable, so the serialization
		// runtime does not invoke the two hooks below.
		private static final long serialVersionUID = 1;
		private static final int CURRENT_SERIAL_VERSION = 0;

		private void writeObject (ObjectOutputStream out) throws IOException {
			out.writeInt (CURRENT_SERIAL_VERSION);
			out.writeDouble (p);
			out.writeDouble (not_p);
			out.writeDouble (w);
			out.writeDouble (not_w);
			out.writeDouble (w_in_p);
			out.writeDouble (w_in_not_p);
			out.writeDouble (not_w_in_p);
			out.writeDouble (not_w_in_not_p);
		}
		private void readObject (ObjectInputStream in) throws IOException {
			in.readInt(); // serial version tag; read to advance the stream
			p = in.readDouble ();
			not_p = in.readDouble ();
			w = in.readDouble ();
			not_w = in.readDouble ();
			w_in_p = in.readDouble ();
			w_in_not_p = in.readDouble ();
			not_w_in_p = in.readDouble ();
			not_w_in_not_p = in.readDouble ();
		}
	}


	// Fields

	/** Sum of all keyword weights over all people. */
	double overallWeight;
	/** Maps word (String) to its total weight over all people (Double). */
	public HashMap wordWeights;
	/** Maps Person to the total weight of that person's keywords (Double). */
	public HashMap personWeights;

	// serialization
	private static final long serialVersionUID = 1;
	private static final int CURRENT_SERIAL_VERSION = 0;

	private void writeObject (ObjectOutputStream out) throws IOException {
		out.writeInt (CURRENT_SERIAL_VERSION);
		out.writeDouble (overallWeight);
		out.writeObject (wordWeights);
		out.writeObject (personWeights);
	}

	private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
		in.readInt(); // serial version tag; read to advance the stream
		// Assign the field: a local of the same name would leave it at 0.
		overallWeight = in.readDouble();
		wordWeights = (HashMap) in.readObject();
		personWeights = (HashMap) in.readObject();
	}

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -