⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 GDMetricLearner.java

📁 wekaUT 是 University of Texas at Austin 开发的基于 weka 的半监督学习(semi-supervised learning)分类器
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    GDMetricLearner.java
 *    Copyright (C) 2002 Mikhail Bilenko
 *
 */

package weka.core.metrics;

import java.util.*;
import java.io.Serializable;
import java.io.*;
import java.text.SimpleDateFormat;
import java.text.DecimalFormat;
import java.text.NumberFormat;

import weka.classifiers.*;
import weka.classifiers.functions.*;
import weka.core.*;
import weka.attributeSelection.*;

/**
 * GDMetricLearner - sets the weights of a metric using gradient descent.
 *
 * Training works on a sampled list of positive (same-class) and negative
 * (different-class) instance pairs: each iteration sums the metric's
 * per-attribute gradients over all pairs (positive pairs pull weights up,
 * negative pairs push them down), takes a step scaled by the learning rate,
 * and re-normalizes the weights to be non-negative and sum to 1.
 *
 * @author Mikhail Bilenko (mbilenko@cs.utexas.edu)
 * @version $Revision: 1.1 $
 */
public class GDMetricLearner extends MetricLearner implements Serializable, OptionHandler {

  /** The metric that the classifier was used to learn, useful for external-calculation based metrics */
  protected LearnableMetric m_metric = null;

  /** Maximum number of gradient-descent iterations */
  protected int m_maxIterations = 20;

  /** The learning rate (step size); very small because pair gradients are summed, not averaged */
  protected double m_learningRate = 0.0000001;

  /** The training data */
  protected Instances m_instances = null;

  /** Training pairs (of TrainingPair) sampled from m_instances by m_selector */
  protected ArrayList m_pairList = null;

  /** Number of positive (same-class) pairs to sample for training */
  protected int m_numPosPairs = 200;

  /** Number of negative (different-class) pairs to sample for training */
  protected int m_numNegPairs = 200;

  /** The convergence criterion for total weight updates.
   *  NOTE(review): 10e-5 is 1e-4, not 1e-5 — possibly a typo for 1e-5; confirm intent. */
  protected double m_epsilon = 10e-5;

  /** The pairwise selector used by the metric */
  protected PairwiseSelector m_selector = new RandomPairwiseSelector();

  /**
   * Create a new gradient descent metric learner with default settings.
   */
  public GDMetricLearner() {
  }

  /**
   * Train a given metric using given training instances.
   *
   * Side effects: stores the metric and data in m_metric/m_instances,
   * repeatedly calls metric.setWeights() during descent, and sets
   * metric.m_trained on exit (false on bad data, true on completion).
   *
   * @param metric the metric to train
   * @param instances data to train the metric on; must have a class
   *        attribute set and at least two instances
   * @exception Exception if training has gone bad, or if the metric is
   *        configured as an external distance metric
   */
  public void trainMetric(LearnableMetric metric, Instances instances) throws Exception {
    // If the data doesn't have a class attribute (or is too small to form pairs), bail
    if (instances.classIndex() < 0 || instances.numInstances() < 2) {
      metric.m_trained = false;
      System.out.println("Problem with training data");
      return;
    }
    if (metric.getExternal()) {
      throw new Exception("GDMetricLearner cannot be used as an external distance metric!");
    }

    // getTimestamp() is defined elsewhere in this class (not visible in this chunk)
    System.out.println(getTimestamp() + "  Starting to calculate weights over " + metric.getNumAttributes() +" attributes");
    m_metric = metric;
    m_instances = instances;
    // Sample the positive/negative training pairs used by calculateGradients()
    m_pairList = m_selector.createPairList(m_instances, m_numPosPairs, m_numNegPairs, metric);

    // Start from uniform weights that sum to 1
    int numWeights = metric.getNumAttributes();
    double[] currentWeights = new double[numWeights];
    Arrays.fill(currentWeights,1.0/numWeights);
    metric.setWeights(currentWeights);

    int iterCount = 0;
    boolean converged = false;
    while (iterCount < m_maxIterations && !converged) {
      // calculate the gradient vector over all training pairs
      double [] gradients = calculateGradients(currentWeights);
      double updateTotal = 0;

      // update the weights; updateTotal accumulates the L1 size of this step
      for (int i = 0; i < numWeights; i++) {
	double update = m_learningRate * gradients[i];
	updateTotal += Math.abs(update);
	currentWeights[i] += update;
      }
      // Project back onto the simplex: clip negatives to 0, rescale to sum to 1
      currentWeights = normalizeWeights(currentWeights);
      metric.setWeights(currentWeights);

      // check convergence: stop once the raw (pre-normalization) step is tiny
      if (updateTotal <= m_epsilon) {
	converged = true;
      }
      iterCount++;
    }
    // printTopAttributes() is defined elsewhere in this class (not visible in this chunk)
    printTopAttributes(currentWeights, 10, iterCount);
    System.out.println(getTimestamp() + "  Gradient descent complete after " + iterCount + " iterations");
    metric.m_trained = true;
  }

  /**
   * A helper function that calculates the current gradient value by summing
   * the metric's per-attribute gradients over all training pairs: positive
   * pairs contribute +gradient, negative pairs contribute -gradient.
   *
   * NOTE(review): the weights parameter is never read here — the metric
   * presumably uses the weights previously pushed via setWeights(); confirm.
   *
   * @param weights the current weights vector (unused, see note)
   * @return the values of the partial derivatives
   * @exception Exception if the metric fails to compute gradients
   */
  protected double[] calculateGradients(double[] weights) throws Exception {
    double [] gradients = new double[weights.length];
    // calculate the gradients
    for (int i = 0; i < m_pairList.size(); i++) {
      TrainingPair pair = (TrainingPair) m_pairList.get(i);
      double[] pairGradients = m_metric.getGradients(pair.instance1, pair.instance2);
      for (int j = 0; j < gradients.length; j++) {
	gradients[j] = gradients[j] + (pair.positive ? pairGradients[j] : -pairGradients[j]);
      }
    }
    return gradients;
  }

  /**
   * Normalize weights: negative entries are clipped to 0, the rest are
   * rescaled so the result sums to 1.
   *
   * NOTE(review): if every weight is &lt;= 0, sum stays 0 and the division
   * below produces NaN/Infinity — confirm this cannot happen upstream.
   *
   * @param weights an unnormalized array of weights (negatives are zeroed
   *        in place as a side effect)
   * @return a normalized array of weights
   */
  protected double[] normalizeWeights(double[] weights) {
    double sum = 0;
    for (int i = 0; i < weights.length; i++) {
      if (weights[i] < 0) {
	weights[i] = 0;
      } else {
	sum += weights[i];
      }
    }

    double [] newWeights = new double[weights.length];
    for (int i = 0; i < weights.length; i++) {
      newWeights[i] = weights[i] / sum;
    }
    return newWeights;
  }

  /**
   * Get the norm-2 length of an instance assuming all attributes are numeric
   * and utilizing the attribute weights. The class attribute (if set) is
   * excluded from the sum.
   *
   * NOTE(review): in the sparse branch, weights is indexed by the sparse
   * storage position i rather than by the attribute index
   * (instance.index(i)) — verify this matches the intended weight layout,
   * since weights elsewhere in this class are per-attribute.
   *
   * @param instance the instance to measure
   * @param weights per-attribute weights applied to each squared value
   * @return norm-2 length of an instance
   */
  public double lengthWeighted(Instance instance, double[] weights) {
    int classIndex = instance.classIndex();
    double length = 0;
    if (instance instanceof SparseInstance) {
      // remap classIndex to an internal (sparse storage) index
      if (classIndex >= 0) {
	classIndex = ((SparseInstance)instance).locateIndex(classIndex);
      }
      for (int i = 0; i < instance.numValues(); i++) {
	if (i != classIndex) {
	  double value = instance.valueSparse(i);
	  length += weights[i] * value * value;
	}
      }
    } else {  // non-sparse instance
      double[] values = instance.toDoubleArray();
      for (int i = 0; i < values.length; i++) {
	if (i != classIndex) {
	  length += weights[i] * values[i] * values[i];
	}
      }
    }
    return Math.sqrt(length);
  }

  /**
   * Use the Classifier for an estimation of similarity. This learner cannot
   * act as an external metric, so this always throws.
   *
   * @param instance1 first instance of a pair
   * @param instance2 second instance of a pair
   * @return an approximate similarity obtained from the classifier (never
   *         returns normally)
   * @exception Exception always — GDMetricLearner is not an external metric
   */
  public double getSimilarity(Instance instance1, Instance instance2) throws Exception{
    throw new Exception("GDMetricLearner cannot be used as an external distance metric!");
  }

  /**
   * Use the Classifier for an estimation of distance. This learner cannot
   * act as an external metric, so this always throws.
   *
   * @param instance1 first instance of a pair
   * @param instance2 second instance of a pair
   * @return an approximate distance obtained from the classifier (never
   *         returns normally)
   * @exception Exception always — GDMetricLearner is not an external metric
   */
  public double getDistance(Instance instance1, Instance instance2) throws Exception{
    throw new Exception("GDMetricLearner cannot be used as an external distance metric!");
  }

  /**
   * Set the convergence criterion.
   * @param epsilon the maximum sum of weight updates required for GD to converge
   */
  public void setEpsilon(double epsilon) {
    m_epsilon = epsilon;
  }

  /**
   * Get the convergence criterion.
   * @return the maximum sum of weight updates required for GD to converge
   */
  public double getEpsilon() {
    return m_epsilon;
  }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -