
📄 BalancedWinnowTrainer.java

📁 Common machine-learning algorithms, Java source code, containing common classification algorithms, with documentation
💻 JAVA
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package edu.umass.cs.mallet.base.classify;

import edu.umass.cs.mallet.base.classify.Classifier;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.Labeling;
import edu.umass.cs.mallet.base.types.LabelVector;
import edu.umass.cs.mallet.base.types.FeatureSelection;
import edu.umass.cs.mallet.base.types.MatrixOps;
import edu.umass.cs.mallet.base.classify.BalancedWinnow;
import edu.umass.cs.mallet.base.pipe.Pipe;
import java.util.Arrays;

/**
 * An implementation of the training methods of a BalancedWinnow2
 * on-line classifier.  Given a labeled instance (x, y), the algorithm
 * computes dot(x, wi) for w1, ..., wc, where wi is the weight
 * vector for class i.  The instance is classified as class j
 * if the value of dot(x, wj) is the largest among the c dot
 * products.
 *
 * <p>The weight vectors are updated whenever the classifier
 * makes a mistake or only barely gets the correct answer (the highest
 * dot product is within delta percent of the second highest).
 * Suppose the classifier guessed j and the answer was j'.  For each
 * feature i that is present, multiply w_ji by (1-epsilon) and
 * multiply w_j'i by (1+epsilon).
 *
 * <p>The above procedure is applied to the training examples
 * multiple times (default is 5), and epsilon is cut in half on each
 * iteration.
 *
 * <p>Limitations: BalancedWinnow2 considers only binary feature
 * vectors (i.e. whether or not the feature is present, not its value).
 *
 * @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a>
 */
public class BalancedWinnowTrainer extends ClassifierTrainer implements Boostable
{
    /** 0.5 */
    public static final double DEFAULT_EPSILON = .5;

    /** 0.1 */
    public static final double DEFAULT_DELTA = .1;

    /** 5 */
    public static final int DEFAULT_NUM_ITERATIONS = 5;

    double m_epsilon;
    double m_delta;
    int m_numIterations;

    /**
     * Array of weights, one for each class and feature, initialized to 1.
     * For each class, there is an additional default "feature" weight
     * that is set to 1 in every example (it remains constant; this is
     * used to prevent dot products from becoming too small).
     */
    double[][] m_weights;

    /**
     * Default constructor.  Sets all parameters to their defaults.
     */
    public BalancedWinnowTrainer()
    {
        this(DEFAULT_EPSILON, DEFAULT_DELTA, DEFAULT_NUM_ITERATIONS);
    }

    /**
     * @param epsilon percentage by which to increase/decrease the weight
     * vectors when an example is misclassified.
     * @param delta percentage by which the highest (and correct) dot product
     * should exceed the second-highest dot product before we consider an
     * example to be correctly classified (margin width) when adjusting weights.
     * @param numIterations number of times to loop through the training examples.
     */
    public BalancedWinnowTrainer(double epsilon, double delta, int numIterations)
    {
        m_epsilon = epsilon;
        m_delta = delta;
        m_numIterations = numIterations;
    }

    /**
     * Trains the classifier on the instance list, updating
     * class weight vectors as appropriate.
     * @param trainingList instance list to be trained on
     * @return Classifier object containing the learned weights
     */
    public Classifier train (InstanceList trainingList,
                             InstanceList validationList,
                             InstanceList testSet,
                             ClassifierEvaluating evaluator,
                             Classifier initialClassifier)
    {
        FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
        if (selectedFeatures != null)
            // xxx Attend to FeatureSelection!!!
            throw new UnsupportedOperationException ("FeatureSelection not yet implemented.");

        double epsilon = m_epsilon;
        Alphabet dict = (Alphabet) trainingList.getDataAlphabet ();
        int numLabels = trainingList.getTargetAlphabet().size();
        int numFeats = dict.size();
        m_weights = new double [numLabels][numFeats];

        // init weights to 1
        for (int i = 0; i < numLabels; i++)
            Arrays.fill(m_weights[i], 1.0);

        // Loop through training instances multiple times
        double[] results = new double[numLabels];
        for (int iter = 0; iter < m_numIterations; iter++) {

            // loop through all instances
            for (int ii = 0; ii < trainingList.size(); ii++) {
                Instance inst = trainingList.getInstance(ii);
                Labeling labeling = inst.getLabeling ();
                FeatureVector fv = (FeatureVector) inst.getData();
                int fvisize = fv.numLocations();
                int correctIndex = labeling.getBestIndex();
                Arrays.fill(results, 0.0);

                // compute dot(x, wi) for each class i
                for (int fvi = 0; fvi < fvisize; fvi++) {
                    int fi = fv.indexAtLocation(fvi);
                    for (int lpos = 0; lpos < numLabels; lpos++) {
                        // The 1 comes from the constant "default"
                        // feature shared by all examples
                        results[lpos] += m_weights[lpos][fi] + 1;
                    }
                }

                // Get indices of the classes with the 2 highest dot products.
                // Initialize to NEGATIVE_INFINITY rather than Double.MIN_VALUE
                // (which is the smallest *positive* double, not the most
                // negative value).
                int predictedIndex = 0;
                int secondHighestIndex = 0;
                double max = Double.NEGATIVE_INFINITY;
                double secondMax = Double.NEGATIVE_INFINITY;
                for (int i = 0; i < numLabels; i++) {
                    if (results[i] > max) {
                        secondMax = max;
                        max = results[i];
                        secondHighestIndex = predictedIndex;
                        predictedIndex = i;
                    }
                    else if (results[i] > secondMax) {
                        // also track the runner-up when it does not
                        // displace the current maximum
                        secondMax = results[i];
                        secondHighestIndex = i;
                    }
                }

                // Adjust weights if this example is mispredicted
                // or just barely correct
                if (predictedIndex != correctIndex) {
                    for (int fvi = 0; fvi < fvisize; fvi++) {
                        int fi = fv.indexAtLocation(fvi);
                        m_weights[predictedIndex][fi] *= (1 - epsilon);
                        m_weights[correctIndex][fi] *= (1 + epsilon);
                    }
                }
                else if (max / secondMax - 1 < m_delta) {
                    for (int fvi = 0; fvi < fvisize; fvi++) {
                        int fi = fv.indexAtLocation(fvi);
                        m_weights[secondHighestIndex][fi] *= (1 - epsilon);
                        m_weights[correctIndex][fi] *= (1 + epsilon);
                    }
                }
            }

            // Cut epsilon in half
            epsilon *= 0.5;
        }
        return new BalancedWinnow (trainingList.getPipe(), m_weights);
    }
}
