kernelvsmetric.java

来自「wekaUT是 university texas austin 开发的基于wek」· Java 代码 · 共 702 行 · 第 1/2 页

JAVA
702
字号
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    KernelVSMetric.java
 *    Copyright (C) 2001 Mikhail Bilenko, Raymond J. Mooney
 *
 */

package weka.deduping.metrics;

import java.util.*;
import java.text.SimpleDateFormat;
import java.io.*;
import weka.core.*;
import weka.deduping.*;
import weka.classifiers.DistributionClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.sparse.*;
import weka.classifiers.functions.SMO;
import weka.classifiers.Evaluation;

/**
 * This class defines a basic string kernel based on vector space.
 * Strings are tokenized, indexed by token, and represented as instances
 * over one attribute per indexed token, on which a classifier is trained.
 * Some code borrowed from ir.vsr package by Raymond J. Mooney
 *
 * @author Mikhail Bilenko
 */
public class KernelVSMetric extends StringMetric implements DataDependentStringMetric, LearnableStringMetric,
                                                            OptionHandler, Serializable {

  /** Strings are mapped to StringReferences in this hash */
  protected HashMap m_stringRefHash = null;

  /** A HashMap where tokens are indexed. Each indexed token maps
   * to a TokenInfo. */
  protected HashMap m_tokenHash = null;

  /** A HashMap where each token is mapped to the corresponding Attribute */
  protected HashMap m_tokenAttrMap = null;

  /** A list of all indexed strings.  Elements are StringReference's.
   */
  public ArrayList m_stringRefs = null;

  /** An underlying tokenizer that is used for converting strings
   * into HashMapVectors
   */
  protected Tokenizer m_tokenizer = new WordTokenizer();

  /** Should IDF weighting be used? */
  protected boolean m_useIDF = true;

  /** We can have different ways of converting from similarity to distance.
   * These constants identify the available conversions; the matching
   * human-readable descriptions are listed in TAGS_CONVERSION. */
  public static final int CONVERSION_LAPLACIAN = 1;
  public static final int CONVERSION_UNIT = 2;
  public static final int CONVERSION_EXPONENTIAL = 4;
  public static final Tag[] TAGS_CONVERSION = {
    new Tag(CONVERSION_UNIT, "distance = 1-similarity"),
    new Tag(CONVERSION_LAPLACIAN, "distance=1/(1+similarity)"),
    new Tag(CONVERSION_EXPONENTIAL, "distance=exp(-similarity)")
      };

  /** The method of converting; the default is exponential
   * (CONVERSION_EXPONENTIAL) */
  protected int m_conversionType = CONVERSION_EXPONENTIAL;

  /** The classifier; an SVMlight instance by default */
  protected DistributionClassifier m_classifier = new SVMlight();

  /** Individual components of the two vectors can be added to the vector-space
   * representation */
  protected boolean m_useIndividualWeights = false;

  /** A special example can be created that contains *all* features so that rare tokens
   * are never ignored (assuming the example will be used as a support vector) */
  protected boolean m_useAllFeaturesExample = false;

  /** has the classifier been trained?
*/  protected boolean m_trained = false;  /** The dataset for the vector space attributes */  protected Instances m_instances = null;    /** Construct a vector space from a given set of examples   * @param strings a list of strings from which the inverted index is   * to be constructed   */  public KernelVSMetric() {    m_stringRefHash = new HashMap();    m_tokenHash = new HashMap();    m_stringRefs = new ArrayList();  }    /** Given a list of strings, build the vector space   */  public void buildMetric(List strings) throws Exception {    m_stringRefHash = new HashMap();    m_tokenHash = new HashMap();    m_stringRefs = new ArrayList();    m_trained = false;        // Loop, processing each of the examples    Iterator stringIterator = strings.iterator();    while (stringIterator.hasNext()) {      String string = (String)stringIterator.next();      // Create a document vector for this document      HashMapVector vector = m_tokenizer.tokenize(string);      vector.initLength();      indexString(string, vector);    }    // Now that all strings have been processed, we can calculate the IDF weights for    // all tokens and the resulting lengths of all weighted document vectors.    
computeIDFandStringLengths();    initKernel();    System.out.println("Indexed " +  m_stringRefs.size() + " strings with " + size() + " unique terms.");  }  /** Index a given string using its corresponding vector */  protected void indexString(String string, HashMapVector vector) {    // Create a new reference    StringReference strRef = new StringReference(string, vector);    m_stringRefs.add(strRef);        m_stringRefHash.put(string, strRef);    // Iterate through each of the tokens in the document    Iterator mapEntries = vector.iterator();    while (mapEntries.hasNext()) {      Map.Entry entry = (Map.Entry)mapEntries.next();      // An entry in the HashMap maps a token to a Weight      String token = (String)entry.getKey();      // The count for the token is in the value of the Weight      int count = (int)((Weight)entry.getValue()).getValue();      // Add an occurence of this token to the inverted index pointing to this document      indexToken(token, count, strRef);    }  }  /** Add a token occurrence to the index.   * @param token The token to index.   * @param count The number of times it occurs in the document.   * @param strRef A reference to the String it occurs in.   */  protected void indexToken(String token, int count, StringReference strRef) {    // Find this token in the index    TokenInfo tokenInfo = (TokenInfo)m_tokenHash.get(token);    if (tokenInfo == null) {      // If this is a new token, create info for it to put in the hashtable      tokenInfo = new TokenInfo();      m_tokenHash.put(token, tokenInfo);    }    // Add a new occurrence for this token to its info    tokenInfo.occList.add(new TokenOccurrence(strRef, count));  }  /** Compute the IDF factor for every token in the index and the length   * of the string vector for every string referenced in the index. 
*/  protected void computeIDFandStringLengths() {    // Let N be the total number of documents indexed    double N = m_stringRefs.size();    // Iterate through each of the tokens in the index     Iterator mapEntries = m_tokenHash.entrySet().iterator();    while (mapEntries.hasNext()) {      // Get the token and the tokenInfo for each entry in the HashMap      Map.Entry entry = (Map.Entry)mapEntries.next();      String token = (String)entry.getKey();      TokenInfo tokenInfo = (TokenInfo)entry.getValue();      // Get the total number of strings in which this token occurs      double numStringRefs = tokenInfo.occList.size();       // Calculate the IDF factor for this token      double idf = Math.log(N/numStringRefs);      if (idf == 0.0) 	// If IDF is 0, then just remove this "omnipresent" token from the index	mapEntries.remove();      else {	tokenInfo.idf = idf;	// In order to compute document vector lengths,  sum the	// square of the weights (IDF * occurrence count) across	// every token occurrence for each document.	for(int i = 0; i < tokenInfo.occList.size(); i++) {	  TokenOccurrence occ = (TokenOccurrence)tokenInfo.occList.get(i);	  if (m_useIDF) { 	    occ.m_stringRef.m_length = occ.m_stringRef.m_length + Math.pow(idf*occ.m_count, 2);	  } else {	    occ.m_stringRef.m_length = occ.m_stringRef.m_length + occ.m_count * occ.m_count;	  }	}      }    }    // At this point, every document length should be the sum of the squares of    // its token weights.  In order to calculate final lengths, just need to    // set the length of every document reference to the square-root of this sum.    
for(int i = 0; i < m_stringRefs.size(); i++) {      StringReference stringRef = (StringReference)m_stringRefs.get(i);      stringRef.m_length = Math.sqrt(stringRef.m_length);    }  }  /** Provided that all features are known, initialize the feature space for the kernel   */  protected void initKernel() {    m_tokenAttrMap = new HashMap();        // create the features    FastVector attrVector = new FastVector(m_tokenHash.size());    Iterator iterator = m_tokenHash.keySet().iterator();    while (iterator.hasNext()) {      String token = (String) iterator.next();      Attribute attr = new Attribute(token);      attrVector.addElement(attr);      m_tokenAttrMap.put(token, attr);    }    // If we are interested in a "concatenated" representation, add the extra features    if (m_useIndividualWeights) {      Iterator iterator1 = m_tokenHash.keySet().iterator();      while (iterator1.hasNext()) {	String token = (String) iterator1.next();	Attribute attr_s1 = new Attribute("s1_" + token);	Attribute attr_s2 = new Attribute("s2_" + token);	attrVector.addElement(attr_s1);	attrVector.addElement(attr_s2);	m_tokenAttrMap.put("s1_" + token, attr_s1);	m_tokenAttrMap.put("s2_" + token, attr_s2);      }     }     // create the class attribute    FastVector classValues = new FastVector();    classValues.addElement("pos");    classValues.addElement("neg");    Attribute classAttr = new Attribute("__class__", classValues);    attrVector.addElement(classAttr);    // create the dataset for the vector space    m_instances = new Instances("diffInstances", attrVector, 3000);    m_instances.setClass(classAttr);  }     /** Train the metric given a set of aligned strings   * @param pairList the training data as a list of StringPair's   * @returns distance between two strings   */  public void trainMetric(ArrayList pairList) throws Exception {    m_instances.delete();    // some training pairs will be deemed unworthy    int numDiscardedPositives = 0;    int numDiscardedNegatives = 0;        // 
populate the training instances    HashSet seenInstances = new HashSet();    for (int i = 0; i < pairList.size(); i++) {      StringPair pair = (StringPair) pairList.get(i);      SparseInstance pairInstance = createPairInstance(pair.str1, pair.str2);      double[] values = pairInstance.toDoubleArray();      if (seenInstances.contains(values)) {	System.out.println("Seen instance vector, skipping: " + pairInstance + "   <= " + pair.str1 + "\t" + pair.str2);      } else { 	// this pair vector has not been seen before	boolean good = true;	// set the dataset and the class value	pairInstance.setDataset(m_instances);	if (pair.positive) {	  pairInstance.setClassValue(0);	  if (pairInstance.numValues() < 1) {	    System.out.println("Too few values, skipping: " + pairInstance + "   <= " + pair.str1 + "\t" + pair.str2);	    good = false;	    numDiscardedPositives++;	  }	} else {	  // negative example	  pairInstance.setClassValue(1);	}		if (good) {	  m_instances.add(pairInstance);	}      }    }    System.out.println("Discarded " + numDiscardedPositives + " positives; " + 		       "went from " + pairList.size() + " down to " + m_instances.numInstances() + " training instances");    // Add an artificial example containing all features to prevent rare features from being excluded    if (m_useAllFeaturesExample) {       Instance allFeaturesInstance = new Instance(m_instances.numAttributes());       allFeaturesInstance.setDataset(m_instances);      allFeaturesInstance.setClassValue(0);      Iterator mapEntries = m_tokenHash.entrySet().iterator();      while (mapEntries.hasNext()) {	Map.Entry entry = (Map.Entry)mapEntries.next();	String token = (String)entry.getKey();	TokenInfo tokenInfo = (TokenInfo)entry.getValue();	Attribute attr = (Attribute) m_tokenAttrMap.get(token);	allFeaturesInstance.setValue(attr, tokenInfo.idf);	// if we are using concatenated representation, add those features as well	if (m_useIndividualWeights) {	  Attribute attr1 = (Attribute) m_tokenAttrMap.get("s1_" 
+ token);	  allFeaturesInstance.setValue(attr1, tokenInfo.idf);	  Attribute attr2 = (Attribute) m_tokenAttrMap.get("s2_" + token);	  allFeaturesInstance.setValue(attr2, tokenInfo.idf);	}       }      normalizeInstance(allFeaturesInstance);      m_instances.add(allFeaturesInstance);      if (m_classifier instanceof SVMcplex) {	((SVMcplex)m_classifier).setUseAllFeaturesExample(true);       }     }            // BEGIN SANITY CHECK    // dump diff-instances into a temporary file    if (false) {       try {	Instances instances = new Instances(m_instances);	// dump instances	File diffDir = new File("/tmp/KVS");	diffDir.mkdir();	String diffName = Utils.removeSubstring(m_classifier.getClass().getName(), "weka.classifiers.");	PrintWriter writer = new PrintWriter(new BufferedOutputStream (new FileOutputStream(diffDir.getPath() + "/" +											    diffName + ".arff")));	writer.println(instances.toString());	writer.close();	// Do a sanity check - dump out the diffInstances, and	// evaluation classification with an SVM. 	long trainTimeStart = System.currentTimeMillis();	//	SVMlight classifier = new SVMlight();

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?