// File: informationgain.java
// (Code-viewer page residue: font-size control; not part of the program.)
/* Copyright (C) 2003 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
/** Calculates information gain between two hashmaps of words
* @author Ron Bekkerman <A HREF="mailto:ronb@cs.umass.edu">ronb@cs.umass.edu</A>
*/
package edu.umass.cs.mallet.projects.dex.types;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.Alphabet;
import java.util.*;
import java.lang.Math.*;
import java.io.*;
public class InformationGain implements Serializable{
/**
 * Builds the word and person weight tables from {@code people} and then
 * immediately computes the information-gain rankings, writing the output
 * files into {@code outDirName}.
 *
 * @param people     the persons whose keyword vectors are aggregated
 * @param outDirName directory that receives the generated keyword files
 */
public InformationGain(People people, String outDirName) {
    this.wordWeights = new HashMap();
    this.personWeights = new HashMap();
    // Both tables (and overallWeight) must be complete before any gain is computed.
    this.fillWeightMaps(people);
    this.calculateInformationGain(people, outDirName);
}
/**
 * Copies every location of a feature vector into a map keyed by the word
 * that the vector's alphabet associates with the location's index.
 *
 * @param fv feature vector to convert
 * @return a new map from word ({@code String}) to value ({@code Double})
 */
public HashMap featureVector2HashMap (FeatureVector fv) {
    HashMap valuesByWord = new HashMap (fv.numLocations());
    Alphabet dictionary = fv.getAlphabet();
    int location = 0;
    while (location < fv.numLocations()) {
        double weight = fv.valueAtLocation (location);
        // Alphabet entries are cast to String, as the rest of this class expects.
        String word = (String) dictionary.lookupObject (fv.indexAtLocation (location));
        valuesByWord.put (word, new Double (weight));
        location++;
    }
    return valuesByWord;
}
/**
 * Adds {@code wordCurrentWeight} to the running total kept for {@code word}
 * in {@code wordWeights}.
 *
 * @param word              the word whose total is updated
 * @param wordCurrentWeight the weight to add
 */
public void updateWordWeights(String word, Double wordCurrentWeight) {
    Double accumulated = (Double) wordWeights.get(word);
    if (accumulated == null) {
        // First sighting of this word: its total is just the current weight.
        wordWeights.put(word, wordCurrentWeight);
        return;
    }
    // Word already tracked: replace its total with the enlarged sum.
    double enlarged = accumulated.doubleValue() + wordCurrentWeight.doubleValue();
    wordWeights.put(word, new Double(enlarged));
}
/**
 * Adds {@code wordCurrentWeight} to the running total kept for
 * {@code person} in {@code personWeights}.
 *
 * @param person            the person whose total is updated
 * @param wordCurrentWeight the weight of one of the person's words
 */
public void updatePersonWeights(Person person, Double wordCurrentWeight) {
    Double accumulated = (Double) personWeights.get(person);
    if (accumulated == null) {
        // First weight recorded for this person: it becomes the total.
        personWeights.put(person, wordCurrentWeight);
        return;
    }
    // Person already tracked: replace the total with the enlarged sum.
    double enlarged = accumulated.doubleValue() + wordCurrentWeight.doubleValue();
    personWeights.put(person, new Double(enlarged));
}
/**
 * Walks over every person's keyword vector and accumulates three totals:
 * the weight per word, the weight per person, and the overall weight of
 * all (person, word) entries.
 *
 * @param people the persons to aggregate
 */
public void fillWeightMaps(People people) {
    for (Iterator personIter = people.iterator (); personIter.hasNext(); ) {
        Person person = (Person) personIter.next();
        HashMap words = featureVector2HashMap (person.keyWords);
        for (Iterator entryIter = words.entrySet().iterator(); entryIter.hasNext(); ) {
            Map.Entry entry = (Map.Entry) entryIter.next();
            String word = (String) entry.getKey();
            Double weight = (Double) entry.getValue();
            // Each entry contributes to its word, its person and the grand total.
            updateWordWeights(word, weight);
            updatePersonWeights(person, weight);
            overallWeight += weight.doubleValue();
        }
    }
}
/**
 * Computes the information gain for one (person, word) pair from the
 * eight weights in {@code data}.
 *
 * The result is the sum, over the four cells (word / not-word crossed with
 * person / not-person), of
 * {@code joint * ln(joint * overallWeight / (wordMarginal * personMarginal))},
 * divided by {@code overallWeight}.
 *
 * A cell whose joint weight is exactly zero contributes nothing. All four
 * cells are guarded: evaluating an unguarded zero cell would yield
 * {@code 0 * Math.log(0) = 0 * -Infinity = NaN} and turn the whole score
 * into NaN.
 *
 * @param data the joint and marginal weights for the pair
 * @return the information gain
 */
public double calculateInformationGain(PersonWordWeights data) {
    double infoGain = 0;
    infoGain = infoGain + informationGainTerm(data.w_in_p, data.w, data.p, overallWeight);
    infoGain = infoGain + informationGainTerm(data.w_in_not_p, data.w, data.not_p, overallWeight);
    infoGain = infoGain + informationGainTerm(data.not_w_in_p, data.not_w, data.p, overallWeight);
    infoGain = infoGain + informationGainTerm(data.not_w_in_not_p, data.not_w, data.not_p, overallWeight);
    return infoGain / overallWeight;
}

/** Returns one cell's contribution to the information gain; a zero joint weight contributes 0. */
private static double informationGainTerm(double joint, double wordMarginal,
                                          double personMarginal, double total) {
    if (joint == 0) {
        return 0;
    }
    return joint * Math.log(joint * total / (wordMarginal * personMarginal));
}
/**
 * Computes a log-ratio score for one (person, word) pair:
 * {@code (w_in_p / p) * ln((w_in_p * not_p) / (p * w_in_not_p))}.
 *
 * @param data the joint and marginal weights for the pair
 * @return the log-ratio score
 */
public double calculateLogRatio(PersonWordWeights data) {
    double shareWithinPerson = data.w_in_p / data.p;
    double ratio = data.w_in_p * data.not_p / (data.p * data.w_in_not_p);
    return shareWithinPerson * Math.log(ratio);
}
/**
 * Inserts {@code weightedWord} into {@code sortedWords}, which is kept in
 * descending order of weight. The new element goes in front of the first
 * element whose weight is not greater than its own, or at the end when no
 * such element exists.
 *
 * @param sortedWords  list of {@code WeightedString}, modified in place
 * @param weightedWord the element to insert
 */
public void updateSortedWords(ArrayList sortedWords, WeightedString weightedWord) {
    // Default to appending; an empty list falls straight through to this case.
    int insertAt = sortedWords.size();
    for (int pos = 0; pos < sortedWords.size(); pos++) {
        WeightedString existing = (WeightedString) sortedWords.get(pos);
        if (weightedWord.weight >= existing.weight) {
            insertAt = pos;
            break;
        }
    }
    sortedWords.add(insertAt, weightedWord);
}
/**
 * Scores every keyword of {@code person} by information gain, ranks the
 * keywords in descending order of score, writes the ranking to {@code out}
 * and stores it on the person.
 *
 * If the person has no entry in {@code personWeights} the resulting
 * NullPointerException is reported on standard output and the JVM exits
 * with status 1.
 *
 * @param person the person whose keywords are scored
 * @param out    writer that receives the ranked words
 */
public void calculateInformationGain(Person person, BufferedWriter out) {
    PersonWordWeights data = new PersonWordWeights();
    ArrayList ranking = new ArrayList();
    try {
        Double personTotal = (Double) personWeights.get(person);
        data.p = personTotal.doubleValue();
    } catch (NullPointerException e) {
        System.out.print("Null pointer in ");
        person.print();
        System.exit(1);
    }
    data.not_p = overallWeight - data.p;

    HashMap words = featureVector2HashMap (person.keyWords);
    Iterator entries = words.entrySet().iterator();
    while (entries.hasNext()) {
        Map.Entry entry = (Map.Entry) entries.next();
        String word = (String) entry.getKey();

        data.w_in_p = ((Double) entry.getValue()).doubleValue();
        data.w = ((Double) wordWeights.get(word)).doubleValue();
        data.not_w = overallWeight - data.w;

        // A word used by nobody else gets an artificial outside weight of 1
        // (presumably a smoothing hack to keep the cell non-zero -- confirm).
        double outsideWeight = data.w - data.w_in_p;
        data.w_in_not_p = (outsideWeight == 0) ? 1 : outsideWeight;

        data.not_w_in_p = data.p - data.w_in_p;
        data.not_w_in_not_p = data.not_p - data.w_in_not_p;

        double score = calculateInformationGain(data);
        updateSortedWords(ranking, new WeightedString(word, score));
    }
    printInformationGains(person, ranking, out);
    // Keep the information-gain ranking on the person as well.
    person.setTopKeyWordWeights(ranking);
}
/**
 * Creates the directory {@code dir}, including any missing parent
 * directories, unless something already exists at that path.
 *
 * Failures are reported on standard output rather than thrown: a denied
 * permission prints "No permission to make directory", and any other
 * failure to create the directory prints "Could not create directory".
 *
 * @param dir the directory to create
 */
public static void makeDir(File dir) {
    try {
        if (dir.exists()) {
            return;
        }
        // mkdirs() returns false on failure; the isDirectory() re-check
        // tolerates the directory having been created concurrently.
        if (!dir.mkdirs() && !dir.isDirectory()) {
            System.out.println("Could not create directory " + dir);
        }
    } catch (SecurityException e) {
        System.out.println("No permission to make directory " + dir);
    }
}
/**
 * Computes the information-gain keyword ranking for every person that has
 * at least one keyword and writes each ranking to
 * {@code <outDirName>/<firstName>.txt}. The output directory is created
 * first if it does not exist.
 *
 * An IOException stops the processing of the remaining persons and is
 * reported on standard error.
 *
 * @param people     the persons to process
 * @param outDirName directory that receives the per-person files
 */
public void calculateInformationGain(People people, String outDirName) {
    try {
        makeDir (new File (outDirName));
        Iterator iter = people.iterator();
        while (iter.hasNext()) {
            Person person = (Person)iter.next();
            // Persons without keywords get no output file.
            if (person.keyWords.numLocations() == 0) {
                continue;
            }
            String outFileName =
                outDirName + File.separator + person.getFirstName() + ".txt";
            BufferedWriter out
                = new BufferedWriter(new FileWriter(new File(outFileName)));
            try {
                System.out.println("InfoGain for " + person.names.get(0));
                calculateInformationGain(person, out);
            } finally {
                // Close the file even if scoring throws, so the handle is not leaked.
                out.close();
            }
        }
    } catch(IOException e) {
        System.err.println("IO problem in outputting keywords");
    }
}
/**
 * Writes the leading words of {@code sortedWords} to {@code out}, one word
 * per line and at most 100 of them. Only the word text is written, not
 * its weight. An IOException is reported on standard output.
 *
 * @param person      the person the words belong to (not used for output)
 * @param sortedWords list of {@code WeightedString} in the order to print
 * @param out         destination writer
 */
public void printInformationGains(Person person, ArrayList sortedWords, BufferedWriter out) {
    final int MAX_NUM_OF_WORDS = 100;
    try {
        int written = 0;
        Iterator wordIter = sortedWords.iterator();
        while (wordIter.hasNext() && written < MAX_NUM_OF_WORDS) {
            WeightedString entry = (WeightedString) wordIter.next();
            out.write(entry.str);
            out.newLine();
            written++;
        }
    } catch(IOException e) {
        System.out.println("Problem with writing to infogain file");
    }
}
// Inner classes
/**
 * Mutable holder for the joint and marginal weights of one (person, word)
 * pair, as consumed by the information-gain and log-ratio computations.
 *
 * NOTE(review): this class carries serialization hooks but does not declare
 * {@code implements Serializable} -- confirm whether that is intended.
 */
public class PersonWordWeights {
    // Serialization
    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 0;

    /** Total weight of the person. */
    double p;
    /** Overall weight minus {@code p}. */
    double not_p;
    /** Total weight of the word. */
    double w;
    /** Overall weight minus {@code w}. */
    double not_w;
    /** Weight of the word within the person. */
    double w_in_p;
    /** Weight of the word outside the person. */
    double w_in_not_p;
    /** Weight of the person's other words. */
    double not_w_in_p;
    /** Weight of other words outside the person. */
    double not_w_in_not_p;

    public PersonWordWeights() {
    }

    /** Writes the version marker followed by the eight weights in declaration order. */
    private void writeObject (ObjectOutputStream out) throws IOException {
        out.writeInt (CURRENT_SERIAL_VERSION);
        double[] ordered = {
            p, not_p, w, not_w, w_in_p, w_in_not_p, not_w_in_p, not_w_in_not_p
        };
        for (int k = 0; k < ordered.length; k++) {
            out.writeDouble (ordered[k]);
        }
    }

    /** Reads the version marker (ignored) and the eight weights in the order they were written. */
    private void readObject (ObjectInputStream in) throws IOException {
        in.readInt(); // version marker; consumed but not interpreted
        this.p = in.readDouble ();
        this.not_p = in.readDouble ();
        this.w = in.readDouble ();
        this.not_w = in.readDouble ();
        this.w_in_p = in.readDouble ();
        this.w_in_not_p = in.readDouble ();
        this.not_w_in_p = in.readDouble ();
        this.not_w_in_not_p = in.readDouble ();
    }
}
// Fields
// Sum of the weights of all (person, word) entries, accumulated by fillWeightMaps.
double overallWeight;
// Maps each word (String) to its accumulated weight (Double).
public HashMap wordWeights;
// Maps each Person to the accumulated weight (Double) of that person's words.
public HashMap personWeights;
// serialization
private static final long serialVersionUID = 1;
// Version marker written as the first value of the custom serialized form.
private static final int CURRENT_SERIAL_VERSION = 0;
/**
 * Custom serialized form: the version marker, then the overall weight,
 * then the word-weight map and the person-weight map.
 *
 * @param out stream to write to
 * @throws IOException if writing fails
 */
private void writeObject (ObjectOutputStream out) throws IOException {
    // Header
    out.writeInt (CURRENT_SERIAL_VERSION);
    // Payload, in the order readObject consumes it
    out.writeDouble (this.overallWeight);
    out.writeObject (this.wordWeights);
    out.writeObject (this.personWeights);
}
/**
 * Restores the state written by {@code writeObject}: the version marker,
 * the overall weight, the word-weight map and the person-weight map.
 *
 * The overall weight is assigned to the field itself. Reading it into a
 * local variable of the same name would shadow the field and leave it at
 * 0.0 after deserialization.
 *
 * @param in stream to read from
 * @throws IOException            if reading fails
 * @throws ClassNotFoundException if the class of a stored map cannot be found
 */
private void readObject (ObjectInputStream in) throws IOException, ClassNotFoundException {
    in.readInt(); // version marker; must be consumed, currently not interpreted
    this.overallWeight = in.readDouble();
    this.wordWeights = (HashMap) in.readObject();
    this.personWeights = (HashMap) in.readObject();
}
}
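// Keyboard shortcuts (code-viewer page residue, not part of the program):
//   Copy code:          Ctrl + C
//   Search code:        Ctrl + F
//   Full-screen mode:   F11
//   Toggle theme:       Ctrl + Shift + D
//   Show shortcuts:     ?
//   Increase font size: Ctrl + =
//   Decrease font size: Ctrl + -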