
📄 gdmetriclearner.java

📁 wekaUT is a collection of semi-supervised learning classifiers built on Weka, developed at the University of Texas at Austin.
💻 Language: Java
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    MetricLearner.java
 *    Copyright (C) 2004 Mikhail Bilenko and Sugato Basu
 *
 */

package weka.clusterers.metriclearners;

import java.util.*;

import weka.core.*;
import weka.core.metrics.LearnableMetric;
import weka.clusterers.MPCKMeans;

/**
 * A parent class for MPCKMeans metric learners.
 *
 * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) and Sugato Basu
 * (sugato@cs.utexas.edu)
 * @version $Revision: 1.2 $
 */
public abstract class GDMetricLearner extends MPCKMeansMetricLearner {

  /** Initial value of the gradient descent step parameter */
  protected double m_eta = 0.001;

  public void setEta(double eta) {
    m_eta = eta;
  }

  public void resetLearner() {
    m_currEta = m_eta;
  }

  public double getEta() {
    return m_eta;
  }

  /** Current value of the step parameter */
  protected double m_currEta = 0;

  /** Decay rate of the gradient descent eta */
  protected double m_etaDecayRate = 0.9;

  public void setEtaDecayRate(double etaDecayRate) {
    m_etaDecayRate = etaDecayRate;
  }

  public double getEtaDecayRate() {
    return m_etaDecayRate;
  }

  /** The maximum number of GD iterations */
  protected int m_maxGDIterations = 20;

  public void setMaxGDIterations(int maxGDIterations) {
    m_maxGDIterations = maxGDIterations;
  }

  public int getMaxGDIterations() {
    return m_maxGDIterations;
  }

  /** Compute the regularizer gradient component for each attribute weight */
  protected double[] InitRegularizerComponents(double[] currentWeights) {
    double[] regularizerComponents = new double[m_numAttributes];
    for (int attr = 0; attr < m_numAttributes; attr++) {
      if (currentWeights[attr] > 0) {
        regularizerComponents[attr] = m_regularizerTermWeight
          * m_metric.getRegularizer().gradient(currentWeights[attr]);
      } else {
        regularizerComponents[attr] = 0;
      }
    }
    return regularizerComponents;
  }

  /**
   * Perform a gradient step update using the current weights,
   * the gradients, the regularizers and the current learning rate.
   * Returns the updated weights.
   */
  protected double[] GDUpdate(double[] currentWeights,
                              double[] gradients,
                              double[] regularizerComponents) {
    double[] newWeights = new double[m_numAttributes];

    for (int attr = 0; attr < m_numAttributes; attr++) {
      newWeights[attr] = currentWeights[attr]
        - m_currEta * (gradients[attr] - regularizerComponents[attr]);
      if (newWeights[attr] <= 0) {
        System.out.println("Prevented 0/- weight " + ((float) newWeights[attr])
                           + " for attribute " + m_instances.attribute(attr).name()
                           + ";\tprev=" + ((float) currentWeights[attr])
                           + ";\tgrad=" + ((float) gradients[attr])
                           + ";\treg=" + ((float) regularizerComponents[attr]));
        newWeights[attr] = m_minWeightValue;
      }
    }

    System.out.print("eta=" + (float) m_currEta);
    m_currEta = m_currEta * m_etaDecayRate;
    System.out.print(" -> " + (float) m_currEta);

    // PRINT top weights
    TreeMap map = new TreeMap(Collections.reverseOrder());
    for (int j = 0; j < newWeights.length; j++) {
      map.put(new Double(newWeights[j]), new Integer(j));
    }
    Iterator it = map.entrySet().iterator();
    for (int j = 0; j < 5 && it.hasNext(); j++) {
      Map.Entry entry = (Map.Entry) it.next();
      int idx = ((Integer) entry.getValue()).intValue();
      System.out.println("\t" + m_instances.attribute(idx).name()
                         + "\t" + (float) currentWeights[idx] + "->" + (float) newWeights[idx]
                         + "\tgradient=" + (float) gradients[idx]
                         + "\tregularizer=" + (float) regularizerComponents[idx]);
    }
    // end PRINT top weights

    return newWeights;
  }
}
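For reference, the update performed in GDUpdate reduces to a simple per-attribute rule: each weight moves against its gradient (minus the regularizer component) scaled by the current step size m_currEta, is clamped to a small positive floor if the step would make it non-positive, and the step size is decayed by m_etaDecayRate after every update. Below is a minimal, self-contained sketch of that rule in plain Java; the class name, the concrete weight floor, and the example numbers are illustrative and not part of the wekaUT code.

public class GDUpdateSketch {
  public static void main(String[] args) {
    double eta = 0.001;       // initial step size, playing the role of m_eta
    double decay = 0.9;       // per-update step-size decay, playing the role of m_etaDecayRate
    double minWeight = 1e-8;  // hypothetical floor standing in for m_minWeightValue

    double[] weights  = {1.0, 0.5, 2.0};   // example attribute weights
    double[] gradient = {0.3, -0.1, 0.7};  // example objective gradients
    double[] regular  = {0.05, 0.0, 0.1};  // example regularizer components

    // w[attr] <- w[attr] - eta * (gradient[attr] - regularizer[attr]), kept strictly positive
    for (int attr = 0; attr < weights.length; attr++) {
      double updated = weights[attr] - eta * (gradient[attr] - regular[attr]);
      weights[attr] = (updated > 0) ? updated : minWeight;
    }
    eta *= decay;  // the step size shrinks after each update, as in GDUpdate

    for (double w : weights) {
      System.out.println(w);
    }
    System.out.println("next eta = " + eta);
  }
}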
