⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 kl.java

📁 wekaUT 是 University of Texas at Austin 开发的、基于 weka 的半监督学习(semi-supervised learning)分类器
💻 JAVA
📖 第 1 页 / 共 4 页
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    KL.java
 *    Copyright (C) 2003 Mikhail Bilenko
 *
 */

package weka.core.metrics;

import weka.core.*;
import weka.deduping.metrics.HashMapVector;
import java.util.*;
import java.io.*;

/**
 * KL class
 *
 * Implements weighted Kullback-Leibler divergence
 *
 * @author Mikhail Bilenko (mbilenko@cs.utexas.edu)
 * @version $Revision: 1.23 $
 */
public class KL extends SmoothingMetric  implements InstanceConverter, OptionHandler {

  // Natural log of 2; presumably intended for converting natural-log values to
  // base 2 -- not referenced anywhere in the visible portion of this file (TODO confirm).
  public final double LOG2 = Math.log(2);

  /** We can switch between regular KL divergence and I-divergence */
  protected boolean m_useIDivergence = true;

  /** Frequencies over the entire dataset used for smoothing */
  protected HashMapVector m_datasetFrequencies = null;

  /** We hash sum(p log(p)) terms for the input instances to speed up computation */
  protected HashMap m_instanceNormHash = null;

  /** Total number of tokens in the dataset */
  protected int m_numTotalTokens = 0;

  /** Different smoothing methods for obtaining probability distributions from frequencies  */
  public static final int SMOOTHING_UNSMOOTHED = 1;
  public static final int SMOOTHING_DIRICHLET = 2;
  public static final int SMOOTHING_JELINEK_MERCER = 4;
  public static final Tag[] TAGS_SMOOTHING = {
    new Tag(SMOOTHING_UNSMOOTHED, "unsmoothed"),
    new Tag(SMOOTHING_DIRICHLET, "Dirichlet"),
    new Tag(SMOOTHING_JELINEK_MERCER, "Jelinek-Mercer")
  };

  /** The smoothing method */
  protected int m_smoothingType = SMOOTHING_UNSMOOTHED;

  /** The pseudocount value for the Dirichlet smoothing */
  protected double m_pseudoCountDirichlet = 1.0;

  /** The lambda value for the Jelinek-Mercer smoothing */
  protected double m_lambdaJM = 0.5;

  /** We can have different ways of converting from distance to similarity  */
  public static final int CONVERSION_LAPLACIAN = 1;
  public static final int CONVERSION_UNIT = 2;
  public static final int CONVERSION_EXPONENTIAL = 4;
  public static final Tag[] TAGS_CONVERSION = {
    new Tag(CONVERSION_UNIT, "similarity = 1-distance"),
    new Tag(CONVERSION_LAPLACIAN, "similarity=1/(1+distance)"),
    new Tag(CONVERSION_EXPONENTIAL, "similarity=exp(-distance)")
  };

  /** The method of converting, by default laplacian */
  protected int m_conversionType = CONVERSION_LAPLACIAN;

  /** A metric learner responsible for training the parameters of the metric */
  protected MetricLearner m_metricLearner = new ClassifierMetricLearner();
//    protected MetricLearner m_metricLearner = new GDMetricLearner();

  /** A hashmap that maps every instance to a set of instances with which JS has been computed */
  protected HashMap m_instanceConstraintMap = new HashMap();

  /**
   * Create a new metric.
   * @param numAttributes the number of attributes that the metric will work on
   */
  public KL(int numAttributes) throws Exception {
    super();
    buildMetric(numAttributes);
  }

  /** Create a default new metric */
  public KL() {
    // distances are reported on a fixed [0, maxDistance] scale by default
    m_fixedMaxDistance = true;
    m_maxDistance = 1;
  }

  /**
   * Creates a new metric which takes specified attributes.
   *
   * @param _attrIdxs An array containing attribute indices that will
   * be used in the metric
   */
  public KL(int[] _attrIdxs) throws Exception {
    super();
    setAttrIdxs(_attrIdxs);
    buildMetric(_attrIdxs.length);
  }

  /**
   * Reset all values that have been learned
   */
  public void resetMetric() throws Exception {
    super.resetMetric();
    m_currAlpha = m_alpha;
    m_instanceConstraintMap = new HashMap();
  }

  /**
   * Generates a new Metric. Has to initialize all fields of the metric
   * with default values.
   *
   * @param numAttributes the number of attributes that the metric will work on
   * @exception Exception if the distance metric has not been
   * generated successfully.
   */
  public void buildMetric(int numAttributes) throws Exception {
    m_numAttributes = numAttributes;
    m_attrWeights = new double[numAttributes];
    m_attrIdxs = new int[numAttributes];
    // every attribute starts out with equal (unit) weight
    for (int i = 0; i < numAttributes; i++) {
      m_attrWeights[i] = 1;
      m_attrIdxs[i] = i;
    }
    m_instanceConstraintMap = new HashMap();
    m_currAlpha = m_alpha;
  }

  /**
   * Generates a new Metric. Has to initialize all fields of the metric
   * with default values
   *
   * @param options an array of options suitable for passing to setOptions.
   * May be null.
   * @exception Exception if the distance metric has not been
   * generated successfully.
*/
  public void buildMetric(int numAttributes, String[] options) throws Exception {
    buildMetric(numAttributes);    // the options array is currently ignored
  }

  /**
   * Create a new metric for operating on specified instances.
   * Hashes dataset-wide token frequencies (used for smoothing), converts
   * every instance of {@code data} in place into a probability distribution
   * over tokens, caches each instance's sum(p * ln p) norm, and trains the
   * metric if it is trainable.
   *
   * @param data instances that the metric will be used on
   */
  public  void buildMetric(Instances data) throws Exception {
    m_classIndex = data.classIndex();
    m_numAttributes = data.numAttributes();
    // a class attribute (if present) must be last so token indices align with weights
    if (m_classIndex != m_numAttributes-1 && m_classIndex != -1) {
      throw new Exception("Class attribute (" + m_classIndex + ") should be the last attribute!!!");
    }
    if (m_classIndex != -1) {
      m_numAttributes--;
    }
    buildMetric(m_numAttributes);

    // hash the dataset-wide frequencies
    m_datasetFrequencies = new HashMapVector(); //  # of occurrences of each unique token in dataset
    m_numTotalTokens = 0;  // total # of (non-unique) tokens in the dataset
    double []instanceLengths = new double[data.numInstances()]; // num tokens per instance
    // NOTE(review): instanceLengths is filled in but never read in the visible code
    for (int i = 0; i < data.numInstances(); i++) {
      Instance instance = data.instance(i);
      for (int j = 0; j < instance.numValues(); j++) {
        Attribute attr = instance.attributeSparse(j);
        int attrIdx = instance.index(j);
        if (attrIdx != m_classIndex) {
          m_datasetFrequencies.increment(attr.name(), instance.value(attr));
          m_numTotalTokens += instance.value(attr);
          instanceLengths[i] += instance.value(attr);
        }
      }
    }

    // convert all instances in the dataset
    System.out.println("\n\nConverting all instances for KL distance\n");
    Instances convertedData = new Instances(data, data.numInstances());
    for (int i = 0; i < data.numInstances(); i++) {
      convertedData.add(convertInstance(data.instance(i)));
      // progress indicator: a dot every 10 instances, a running count every 100
      if (i % 10 == 9 ) System.out.print(".");
      if (i % 100 == 99) System.out.println(" " + (i+1));
    }
    System.out.println();

    // copy all converted instances to original data
    data.delete();
    for (int i = 0; i < convertedData.numInstances(); i++) {
      data.add(convertedData.instance(i));
    }

    // Hash instance norms: cache sum(p * ln p) for each converted instance
    m_instanceNormHash = new HashMap();
    for (int i = 0; i < data.numInstances(); i++) {
      Instance instance = data.instance(i);
      double norm = 0;
      for (int j = 0; j < instance.numValues(); j++) {
        int attrIdx = instance.index(j);
        if (attrIdx != m_classIndex) {
          double value = instance.value(attrIdx);
          norm += value * Math.log(value);
        }
      }
      m_instanceNormHash.put(instance, new Double(norm));
    }

    if (m_trainable) {
      learnMetric(data);
    }
  }

  /**
   * Converts an instance of raw token frequencies into an instance whose
   * values are (possibly smoothed) token probabilities.
   *
   * @param oldInstance the frequency-valued instance to convert
   * @return a new probability-valued instance; sparse only when the input
   *         was sparse and no smoothing is applied
   */
  public Instance convertInstance(Instance oldInstance) {
    Instance newInstance;
    if (oldInstance instanceof SparseInstance	&& m_smoothingType == SMOOTHING_UNSMOOTHED) {
      newInstance = new SparseInstance(oldInstance);
    } else { // either original data is dense, or smoothing is on - returning dense.
      newInstance = new Instance(oldInstance.numAttributes());
    }
    newInstance.setDataset(oldInstance.dataset());
    int classIdx = oldInstance.classIndex();

    // get the total count of tokens for this instance
    double numTotalTokens = 0;
    for (int i = 0; i < oldInstance.numValues(); i++) {
      int idx = oldInstance.index(i);
      if (idx != classIdx) {
        numTotalTokens += oldInstance.valueSparse(i);
      }
    }

    // we're iterating over newInstance in case
    // there was a transition from sparse to non-sparse.
    for (int i = 0; i < newInstance.numValues(); i++) {
      int idx = newInstance.index(i);
      if (idx != classIdx) {
        Attribute attr = newInstance.attribute(idx);
        newInstance.setValue(idx, convertFrequency(oldInstance.value(idx),
                                                   numTotalTokens, attr.name()));
      }
    }
    return newInstance;
  }

  /** Given a frequency of a given token in a document, convert
   *  it to a probability value for that document's distribution
   * @param freq frequency of a token
   * @param numTotalTokens total number of tokens in the document
   * @param token the token
   * @return a probability value, or -1 if the smoothing method is unknown
   */
  protected double convertFrequency(double freq, double numTotalTokens, String token) {
    double datasetProb = 0;
    switch (m_smoothingType) {
    case SMOOTHING_UNSMOOTHED:
      // maximum-likelihood estimate
      return freq/numTotalTokens;
    case SMOOTHING_DIRICHLET:
      // interpolate with the dataset-wide distribution via a Dirichlet prior
      datasetProb = m_datasetFrequencies.getWeight(token) / m_numTotalTokens;
      return (freq + m_pseudoCountDirichlet * datasetProb) /
        (numTotalTokens + m_pseudoCountDirichlet);
    case SMOOTHING_JELINEK_MERCER:
      // linear interpolation between document and dataset distributions
      datasetProb = m_datasetFrequencies.getWeight(token) / m_numTotalTokens;
      return (1 - m_lambdaJM) * (freq / numTotalTokens) + m_lambdaJM * datasetProb;
    default:
      System.err.println("Unknown smoothing method: " + m_smoothingType);
      return -1;
    }
  }

  /** Smooth an instance by interpolating each value with a uniform prior,
   *  mixing with weight m_alpha */
  public Instance smoothInstance(Instance instance) {
    int numAttributes = instance.numAttributes();
    double[] values = new double[numAttributes];
    double prior = 1.0/ numAttributes;
    for (int j = 0; j < numAttributes; j++) {
      values[j] = 1.0 / (1 + m_alpha) * (instance.value(j) +  m_alpha * prior);
    }
    return  new Instance(1.0, values);
  }

  /**
   * Returns a distance value between two instances.
   * @param instance1 First instance.
   * @param instance2 Second instance.
   * @exception Exception if distance could not be estimated.
*/  public double distance(Instance instance1, Instance instance2) throws Exception {    // either pass the computation to the external classifier, or do the work yourself    if (m_trainable && m_external && m_trained) {      return m_metricLearner.getDistance(instance1, instance2);    } else {      return distanceInternal(instance1, instance2);    }  }  /** Return the penalty contribution - KL */  public double penalty(Instance instance1,			Instance instance2) throws Exception {    double distance = distance(instance1, instance2);    return distance;   }  /** Return the penalty contribution - JS */  public double penaltySymmetric(Instance instance1,			Instance instance2) throws Exception {    double distance = distanceJS(instance1, instance2);    return distance;   }  /**   * Returns a distance value between two instances.    * @param instance1 First instance.   * @param instance2 Second instance.   * @exception Exception if distance could not be estimated.   */  public double distanceInternal(Instance instance1, Instance instance2) throws Exception {    if (instance1 instanceof SparseInstance) {      return distanceSparse((SparseInstance)instance1, instance2);    } else {      return distanceNonSparse(instance1, instance2);    }  }      /** Returns a distance value between two sparse instances.    * @param instance1 First sparse instance.   * @param instance2 Second sparse instance.   * @exception Exception if distance could not be estimated.   
*/  public double distanceSparse(SparseInstance instance1, Instance instance2) throws Exception {    double distance = 0, value1, value2, idivTerm = 0;    int numValues1 = instance1.numValues();    boolean dbg = true;	        // iterate through the attributes that are present in the first instance    for (int i = 0; i < numValues1; i++) {      int attrIdx = instance1.index(i);      if (attrIdx != m_classIndex) {	value1 = instance1.valueSparse(i);	value2 = instance2.value(attrIdx);	if (value2 > 0) { 	  distance += m_attrWeights[attrIdx] * value1 * Math.log(value1/value2);	  if (m_useIDivergence) {	    idivTerm -= m_attrWeights[attrIdx] * value1;	  }	  //	    if (dbg) { System.out.println("\t" + attrIdx + "\t" + value1 + "  " + value2 + "\t" + distance); dbg= false;}	} else {	  System.err.println("KL.distanceNonSparse:  0 value in instance2, attribute=" + attrIdx + "\n" + instance2.value(attrIdx) + "\n" + instance2); 	  	  return Double.MAX_VALUE;	}       }    }    // if i-divergence is used, need to pick up values of instance2    if (m_useIDivergence) {       for (int i = 0; i < instance2.numValues(); i++) {	int attrIdx = instance2.index(i);	if (attrIdx != m_classIndex) {	  value2 = instance2.valueSparse(i);	  idivTerm += m_attrWeights[attrIdx] * value2;	}      }    }    distance = distance + idivTerm;    return distance;  }  /** Returns a distance value between non-sparse instances without using the weights

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -