📄 gdmetriclearner.java
字号:
/* * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation; either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program; if not, write to the Free Software * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* * GDMetricLearner.java * Copyright (C) 2002 Mikhail Bilenko * */package weka.core.metrics;import java.util.*;import java.io.Serializable;import java.io.*;import java.text.SimpleDateFormat;import java.text.DecimalFormat;import java.text.NumberFormat;import weka.classifiers.*;import weka.classifiers.functions.*;import weka.core.*;import weka.attributeSelection.*;/** * GDMetricLearner - sets the weights of a metric * using gradient descent * * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) * @version $Revision: 1.1 $ */public class GDMetricLearner extends MetricLearner implements Serializable, OptionHandler { /** The metric that the classifier was used to learn, useful for external-calculation based metrics */ protected LearnableMetric m_metric = null; /** Maximum number of iterations */ protected int m_maxIterations = 20; /** The learning rate */ protected double m_learningRate = 0.0000001; /** The training data */ protected Instances m_instances = null; protected ArrayList m_pairList = null; protected int m_numPosPairs = 200; protected int m_numNegPairs = 200; /** The convergence criterion for total weight updates */ protected double m_epsilon = 10e-5; /** The pairwise selector used by the metric */ protected PairwiseSelector m_selector = new RandomPairwiseSelector(); /** Create a new gradient descent metric learner * @param classifierName the name of the classifier class to be used */ public GDMetricLearner() { } /** * Train a given metric using given training instances * * @param metric the metric to train * @param instances data to train the metric on * @exception Exception if training has gone bad. */ public void trainMetric(LearnableMetric metric, Instances instances) throws Exception { // If the data doesn't have a class attribute, bail if (instances.classIndex() < 0 || instances.numInstances() < 2) { metric.m_trained = false; System.out.println("Problem with training data"); return; } if (metric.getExternal()) { throw new Exception("GDMetricLearner cannot be used as an external distance metric!"); } System.out.println(getTimestamp() + " Starting to calculate weights over " + metric.getNumAttributes() +" attributes"); m_metric = metric; m_instances = instances; m_pairList = m_selector.createPairList(m_instances, m_numPosPairs, m_numNegPairs, metric); int numWeights = metric.getNumAttributes(); double[] currentWeights = new double[numWeights]; Arrays.fill(currentWeights,1.0/numWeights); metric.setWeights(currentWeights); int iterCount = 0; boolean converged = false; while (iterCount < m_maxIterations && !converged) { // calculate the gradient vector double [] gradients = calculateGradients(currentWeights); double updateTotal = 0; // update the weights for (int i = 0; i < numWeights; i++) { // System.out.println("update: " + gradients[i]); double update = m_learningRate * gradients[i]; updateTotal += Math.abs(update); currentWeights[i] += update; } currentWeights = normalizeWeights(currentWeights); metric.setWeights(currentWeights); // check convergence if (updateTotal <= m_epsilon) { converged = true; } iterCount++; } printTopAttributes(currentWeights, 10, iterCount); System.out.println(getTimestamp() + " Gradient descent complete after " + iterCount + " iterations"); metric.m_trained = true; } /** A helper function that calculates the current gradient value * @param weights the current weights vector * @return the values of the partial derivatives */ protected double[] calculateGradients(double[] weights) throws Exception { double [] gradients = new double[weights.length]; // calculate the gradients for (int i = 0; i < m_pairList.size(); i++) { TrainingPair pair = (TrainingPair) m_pairList.get(i); double[] pairGradients = m_metric.getGradients(pair.instance1, pair.instance2); // System.out.println(pair.instance1 + "\t" + pair.instance2 + "\t" + pair.positive); for (int j = 0; j < gradients.length; j++) { //System.out.print(gradients[j] + "(" + pairGradients[j] + ")\t"); gradients[j] = gradients[j] + (pair.positive ? pairGradients[j] : -pairGradients[j]); } // System.out.println(); } return gradients; } /** Normalize weights * @param weights an unnormalized array of weights * @return a normalized array of weights */ protected double[] normalizeWeights(double[] weights) { double sum = 0; for (int i = 0; i < weights.length; i++) { if (weights[i] < 0) { weights[i] = 0; } else { sum += weights[i]; } } double [] newWeights = new double[weights.length]; for (int i = 0; i < weights.length; i++) { newWeights[i] = weights[i] / sum; } return newWeights; } /** Get the norm-2 length of an instance assuming all attributes are numeric * and utilizing the attribute weights * @returns norm-2 length of an instance */ public double lengthWeighted(Instance instance, double[] weights) { int classIndex = instance.classIndex(); double length = 0; if (instance instanceof SparseInstance) { // remap classIndex to an internal index if (classIndex >= 0) { classIndex = ((SparseInstance)instance).locateIndex(classIndex); } for (int i = 0; i < instance.numValues(); i++) { if (i != classIndex) { double value = instance.valueSparse(i); length += weights[i] * value * value; } } } else { // non-sparse instance double[] values = instance.toDoubleArray(); for (int i = 0; i < values.length; i++) { if (i != classIndex) { length += weights[i] * values[i] * values[i]; } } } return Math.sqrt(length); } /** * Use the Classifier for an estimation of similarity * @param instance1 first instance of a pair * @param instance2 second instance of a pair * @returns sim an approximate similarity obtained from the classifier */ public double getSimilarity(Instance instance1, Instance instance2) throws Exception{ throw new Exception("GDMetricLearner cannot be used as an external distance metric!"); } /** * Use the Classifier for an estimation of distance * @param instance1 first instance of a pair * @param instance2 second instance of a pair * @returns an approximate distance obtained from the classifier */ public double getDistance(Instance instance1, Instance instance2) throws Exception{ throw new Exception("GDMetricLearner cannot be used as an external distance metric!"); } /** Set the convergence criterion * @param epsilon the maximum sum of weight updates required for GD to converge */ public void setEpsilon(double epsilon) { m_epsilon = epsilon; } /** Get the convergence criterion * @return the maximum sum of weight updates required for GD to converge */ public double getEpsilon() { return m_epsilon; }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -