📄 kl.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * KL.java * Copyright (C) 2003 Mikhail Bilenko * */package weka.core.metrics;import weka.core.*;import weka.deduping.metrics.HashMapVector;import java.util.*;import java.io.*;/** * KL class * * Implements weighted Kullback-Leibler divergence * * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) * @version $Revision: 1.23 $ */public class KL extends SmoothingMetric implements InstanceConverter, OptionHandler { public final double LOG2 = Math.log(2); /** We can switch between regular KL divergence and I-divergence */ protected boolean m_useIDivergence = true; /** Frequencies over the entire dataset used for smoothing */ protected HashMapVector m_datasetFrequencies = null; /** We hash sum(p log(p)) terms for the input instances to speed up computation */ protected HashMap m_instanceNormHash = null; /** Total number of tokens in the dataset */ protected int m_numTotalTokens = 0; /** Different smoothing methods for obtaining probability distributions from frequencies */ public static final int SMOOTHING_UNSMOOTHED = 1; public static final int SMOOTHING_DIRICHLET = 2; public static final int SMOOTHING_JELINEK_MERCER = 4; public static final Tag[] TAGS_SMOOTHING = { new Tag(SMOOTHING_UNSMOOTHED, "unsmoothed"), new Tag(SMOOTHING_DIRICHLET, "Dirichlet"), new Tag(SMOOTHING_JELINEK_MERCER, "Jelinek-Mercer") }; /** The smoothing method */ protected int m_smoothingType = SMOOTHING_UNSMOOTHED; /** The pseudocount value for the Dirichlet smoothing */ protected double m_pseudoCountDirichlet = 1.0; /** The lambda value for the Jelinek-Mercer smoothing */ protected double m_lambdaJM = 0.5; /** We can have different ways of converting from distance to similarity */ public static final int CONVERSION_LAPLACIAN = 1; public static final int CONVERSION_UNIT = 2; public static final int CONVERSION_EXPONENTIAL = 4; public static final Tag[] TAGS_CONVERSION = { new Tag(CONVERSION_UNIT, "similarity = 1-distance"), new Tag(CONVERSION_LAPLACIAN, "similarity=1/(1+distance)"), new Tag(CONVERSION_EXPONENTIAL, "similarity=exp(-distance)") }; /** The method of converting, by default laplacian */ protected int m_conversionType = CONVERSION_LAPLACIAN; /** A metric learner responsible for training the parameters of the metric */ protected MetricLearner m_metricLearner = new ClassifierMetricLearner();// protected MetricLearner m_metricLearner = new GDMetricLearner(); /** A hashmap that maps every instance to a set of instances with which JS has been computed */ protected HashMap m_instanceConstraintMap = new HashMap(); /** * Create a new metric. * @param numAttributes the number of attributes that the metric will work on */ public KL(int numAttributes) throws Exception { super(); buildMetric(numAttributes); } /** Create a default new metric */ public KL() { m_fixedMaxDistance = true; m_maxDistance = 1; } /** * Creates a new metric which takes specified attributes. * * @param _attrIdxs An array containing attribute indeces that will * be used in the metric */ public KL(int[] _attrIdxs) throws Exception { super(); setAttrIdxs(_attrIdxs); buildMetric(_attrIdxs.length); } /** * Reset all values that have been learned */ public void resetMetric() throws Exception { super.resetMetric(); m_currAlpha = m_alpha; m_instanceConstraintMap = new HashMap(); } /** * Generates a new Metric. Has to initialize all fields of the metric * with default values. * * @param numAttributes the number of attributes that the metric will work on * @exception Exception if the distance metric has not been * generated successfully. */ public void buildMetric(int numAttributes) throws Exception { m_numAttributes = numAttributes; m_attrWeights = new double[numAttributes]; m_attrIdxs = new int[numAttributes]; for (int i = 0; i < numAttributes; i++) { m_attrWeights[i] = 1; m_attrIdxs[i] = i; } m_instanceConstraintMap = new HashMap(); m_currAlpha = m_alpha; } /** * Generates a new Metric. Has to initialize all fields of the metric * with default values * * @param options an array of options suitable for passing to setOptions. * May be null. * @exception Exception if the distance metric has not been * generated successfully. */ public void buildMetric(int numAttributes, String[] options) throws Exception { buildMetric(numAttributes); } /** * Create a new metric for operating on specified instances * @param data instances that the metric will be used on */ public void buildMetric(Instances data) throws Exception { m_classIndex = data.classIndex(); m_numAttributes = data.numAttributes(); if (m_classIndex != m_numAttributes-1 && m_classIndex != -1) { throw new Exception("Class attribute (" + m_classIndex + ") should be the last attribute!!!"); } if (m_classIndex != -1) { m_numAttributes--; } buildMetric(m_numAttributes); // hash the dataset-wide frequencies m_datasetFrequencies = new HashMapVector(); // # of occurrences of each unique token in dataset m_numTotalTokens = 0; // total # of (non-unique) tokens in the dataset double []instanceLengths = new double[data.numInstances()]; // num tokens per instance for (int i = 0; i < data.numInstances(); i++) { Instance instance = data.instance(i); for (int j = 0; j < instance.numValues(); j++) { Attribute attr = instance.attributeSparse(j); int attrIdx = instance.index(j); if (attrIdx != m_classIndex) { m_datasetFrequencies.increment(attr.name(), instance.value(attr)); m_numTotalTokens += instance.value(attr); instanceLengths[i] += instance.value(attr); } } } // convert all instances in the dataset System.out.println("\n\nConverting all instances for KL distance\n"); Instances convertedData = new Instances(data, data.numInstances()); for (int i = 0; i < data.numInstances(); i++) { convertedData.add(convertInstance(data.instance(i))); if (i % 10 == 9 ) System.out.print("."); if (i % 100 == 99) System.out.println(" " + (i+1)); } System.out.println(); // copy all converted instances to original data data.delete(); for (int i = 0; i < convertedData.numInstances(); i++) { data.add(convertedData.instance(i)); } // Hash instance norms m_instanceNormHash = new HashMap(); for (int i = 0; i < data.numInstances(); i++) { Instance instance = data.instance(i); double norm = 0; for (int j = 0; j < instance.numValues(); j++) { int attrIdx = instance.index(j); if (attrIdx != m_classIndex) { double value = instance.value(attrIdx); norm += value * Math.log(value); } } m_instanceNormHash.put(instance, new Double(norm)); } if (m_trainable) { learnMetric(data); } } public Instance convertInstance(Instance oldInstance) { Instance newInstance; if (oldInstance instanceof SparseInstance && m_smoothingType == SMOOTHING_UNSMOOTHED) { newInstance = new SparseInstance(oldInstance); } else { // either original data is dense, or smoothing is on - returning dense. newInstance = new Instance(oldInstance.numAttributes()); } newInstance.setDataset(oldInstance.dataset()); int classIdx = oldInstance.classIndex(); // get the total count of tokens for this instance double numTotalTokens = 0; for (int i = 0; i < oldInstance.numValues(); i++) { int idx = oldInstance.index(i); if (idx != classIdx) { numTotalTokens += oldInstance.valueSparse(i); } } // we're iterating over newInstance in case // there was a transition from sparse to non-sparse. for (int i = 0; i < newInstance.numValues(); i++) { int idx = newInstance.index(i); if (idx != classIdx) { Attribute attr = newInstance.attribute(idx); newInstance.setValue(idx, convertFrequency(oldInstance.value(idx), numTotalTokens, attr.name())); } } return newInstance; } /** Given a frequency of a given token in a document, convert * it to a probability value for that document's distribution * @param freq frequency of a token * @param token the token * @returns a probability value */ protected double convertFrequency(double freq, double numTotalTokens, String token) { double datasetProb = 0; switch (m_smoothingType) { case SMOOTHING_UNSMOOTHED: return freq/numTotalTokens; case SMOOTHING_DIRICHLET: datasetProb = m_datasetFrequencies.getWeight(token) / m_numTotalTokens; return (freq + m_pseudoCountDirichlet * datasetProb) / (numTotalTokens + m_pseudoCountDirichlet); case SMOOTHING_JELINEK_MERCER: datasetProb = m_datasetFrequencies.getWeight(token) / m_numTotalTokens; return (1 - m_lambdaJM) * (freq / numTotalTokens) + m_lambdaJM * datasetProb; default: System.err.println("Unknown smoothing method: " + m_smoothingType); return -1; } } /** Smooth an instance */ public Instance smoothInstance(Instance instance) { int numAttributes = instance.numAttributes(); double[] values = new double[numAttributes]; double prior = 1.0/ numAttributes; for (int j = 0; j < numAttributes; j++) { values[j] = 1.0 / (1 + m_alpha) * (instance.value(j) + m_alpha * prior); } return new Instance(1.0, values); } /** * Returns a distance value between two instances. * @param instance1 First instance. * @param instance2 Second instance. * @exception Exception if distance could not be estimated. */ public double distance(Instance instance1, Instance instance2) throws Exception { // either pass the computation to the external classifier, or do the work yourself if (m_trainable && m_external && m_trained) { return m_metricLearner.getDistance(instance1, instance2); } else { return distanceInternal(instance1, instance2); } } /** Return the penalty contribution - KL */ public double penalty(Instance instance1, Instance instance2) throws Exception { double distance = distance(instance1, instance2); return distance; } /** Return the penalty contribution - JS */ public double penaltySymmetric(Instance instance1, Instance instance2) throws Exception { double distance = distanceJS(instance1, instance2); return distance; } /** * Returns a distance value between two instances. * @param instance1 First instance. * @param instance2 Second instance. * @exception Exception if distance could not be estimated. */ public double distanceInternal(Instance instance1, Instance instance2) throws Exception { if (instance1 instanceof SparseInstance) { return distanceSparse((SparseInstance)instance1, instance2); } else { return distanceNonSparse(instance1, instance2); } } /** Returns a distance value between two sparse instances. * @param instance1 First sparse instance. * @param instance2 Second sparse instance. * @exception Exception if distance could not be estimated. */ public double distanceSparse(SparseInstance instance1, Instance instance2) throws Exception { double distance = 0, value1, value2, idivTerm = 0; int numValues1 = instance1.numValues(); boolean dbg = true; // iterate through the attributes that are present in the first instance for (int i = 0; i < numValues1; i++) { int attrIdx = instance1.index(i); if (attrIdx != m_classIndex) { value1 = instance1.valueSparse(i); value2 = instance2.value(attrIdx); if (value2 > 0) { distance += m_attrWeights[attrIdx] * value1 * Math.log(value1/value2); if (m_useIDivergence) { idivTerm -= m_attrWeights[attrIdx] * value1; } // if (dbg) { System.out.println("\t" + attrIdx + "\t" + value1 + " " + value2 + "\t" + distance); dbg= false;} } else { System.err.println("KL.distanceNonSparse: 0 value in instance2, attribute=" + attrIdx + "\n" + instance2.value(attrIdx) + "\n" + instance2); return Double.MAX_VALUE; } } } // if i-divergence is used, need to pick up values of instance2 if (m_useIDivergence) { for (int i = 0; i < instance2.numValues(); i++) { int attrIdx = instance2.index(i); if (attrIdx != m_classIndex) { value2 = instance2.valueSparse(i); idivTerm += m_attrWeights[attrIdx] * value2; } } } distance = distance + idivTerm; return distance; } /** Returns a distance value between non-sparse instances without using the weights
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -