
📄 BalancedWinnowTrainer.java

📁 MALLET is an open-source project in the fields of natural language processing and machine learning.
💻 Java
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package edu.umass.cs.mallet.base.classify;

import java.io.Serializable;
import java.util.Arrays;

import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.FeatureSelection;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.Labeling;

/**
 * An implementation of the training methods of a BalancedWinnow
 * on-line classifier. Given a labeled instance (x, y), the algorithm
 * computes dot(x, w_i) for w_1, ..., w_c, where w_i is the weight
 * vector for class i. The instance is classified as class j
 * if dot(x, w_j) is the largest of the c dot products.
 *
 * <p>The weight vectors are updated whenever the classifier
 * makes a mistake or is only barely correct (the highest
 * dot product is within delta percent of the second highest).
 * Suppose the classifier guessed j and the answer was j'. For each
 * feature i that is present, multiply w_ji by (1 - epsilon) and
 * multiply w_j'i by (1 + epsilon).
 *
 * <p>The above procedure is repeated over the training examples
 * for multiple iterations (default is 30), and epsilon is reduced
 * by the cooling rate after each iteration (default cuts epsilon in half).
 *
 * @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a>
 */
public class BalancedWinnowTrainer extends ClassifierTrainer implements Boostable, Serializable
{
    private static final long serialVersionUID = 1L;

    /** 0.5 */
    public static final double DEFAULT_EPSILON = .5;
    /** 0.1 */
    public static final double DEFAULT_DELTA = .1;
    /** 30 */
    public static final int DEFAULT_MAX_ITERATIONS = 30;
    /** 0.5 */
    public static final double DEFAULT_COOLING_RATE = .5;

    double m_epsilon;
    double m_delta;
    int m_maxIterations;
    double m_coolingRate;

    /**
     * Array of weights, one for each class and feature, initialized to 1.
     * For each class, there is an additional default "feature" weight
     * that is set to 1 in every example (it remains constant; this is
     * used to prevent the instance from having 0 dot product with a class).
     */
    double[][] m_weights;

    /**
     * Default constructor. Sets all parameters to their defaults.
     */
    public BalancedWinnowTrainer()
    {
        this(DEFAULT_EPSILON,
             DEFAULT_DELTA,
             DEFAULT_MAX_ITERATIONS,
             DEFAULT_COOLING_RATE);
    }

    /**
     * @param epsilon percentage by which to increase/decrease weight vectors
     * when an example is misclassified.
     * @param delta percentage by which the highest (and correct) dot product
     * should exceed the second highest dot product before we consider an example
     * to be correctly classified (margin width) when adjusting weights.
     * @param maxIterations maximum number of times to loop through training examples.
     * @param coolingRate percentage of epsilon to decrease after each iteration
     */
    public BalancedWinnowTrainer(double epsilon,
                                 double delta,
                                 int maxIterations,
                                 double coolingRate)
    {
        m_epsilon = epsilon;
        m_delta = delta;
        m_maxIterations = maxIterations;
        m_coolingRate = coolingRate;
    }

    /**
     * Trains the classifier on the instance list, updating
     * class weight vectors as appropriate.
     * @param trainingList Instance list to be trained on
     * @return Classifier object containing learned weights
     */
    public Classifier train (InstanceList trainingList,
                             InstanceList validationList,
                             InstanceList testSet,
                             ClassifierEvaluating evaluator,
                             Classifier initialClassifier)
    {
        FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
        if (selectedFeatures != null)
            // xxx Attend to FeatureSelection!!!
            throw new UnsupportedOperationException ("FeatureSelection not yet implemented.");

        double epsilon = m_epsilon;
        Alphabet dict = (Alphabet) trainingList.getDataAlphabet ();
        int numLabels = trainingList.getTargetAlphabet().size();
        int numFeats = dict.size();
        m_weights = new double [numLabels][numFeats+1];

        // Initialize all weights to 1
        for (int i = 0; i < numLabels; i++)
            Arrays.fill(m_weights[i], 1.0);

        // Loop through the training instances multiple times
        double[] results = new double[numLabels];
        for (int iter = 0; iter < m_maxIterations; iter++) {

            // Loop through all instances
            for (int ii = 0; ii < trainingList.size(); ii++) {
                Instance inst = trainingList.getInstance(ii);
                Labeling labeling = inst.getLabeling ();
                FeatureVector fv = (FeatureVector) inst.getData();
                int fvisize = fv.numLocations();
                int correctIndex = labeling.getBestIndex();
                Arrays.fill(results, 0);

                // Compute dot(x, w_i) for each class i
                for (int lpos = 0; lpos < numLabels; lpos++) {
                    for (int fvi = 0; fvi < fvisize; fvi++) {
                        int fi = fv.indexAtLocation(fvi);
                        double vi = fv.valueAtLocation(fvi);
                        results[lpos] += vi * m_weights[lpos][fi];
                    }
                    // This extra value comes from the extra
                    // "feature" present in all examples
                    results[lpos] += m_weights[lpos][numFeats];
                }

                // Get the indices of the classes with the two highest dot
                // products. (Double.NEGATIVE_INFINITY is the correct
                // sentinel here: Double.MIN_VALUE is the smallest
                // *positive* double, so it would never be beaten by a
                // negative dot product.)
                int predictedIndex = 0;
                int secondHighestIndex = 0;
                double max = Double.NEGATIVE_INFINITY;
                double secondMax = Double.NEGATIVE_INFINITY;
                for (int i = 0; i < numLabels; i++) {
                    if (results[i] > max) {
                        secondMax = max;
                        max = results[i];
                        secondHighestIndex = predictedIndex;
                        predictedIndex = i;
                    }
                    else if (results[i] > secondMax) {
                        secondMax = results[i];
                        secondHighestIndex = i;
                    }
                }

                // Adjust the weights if this example is mispredicted
                // or only barely correct
                if (predictedIndex != correctIndex) {
                    for (int fvi = 0; fvi < fvisize; fvi++) {
                        int fi = fv.indexAtLocation(fvi);
                        m_weights[predictedIndex][fi] *= (1 - epsilon);
                        m_weights[correctIndex][fi] *= (1 + epsilon);
                    }
                    m_weights[predictedIndex][numFeats] *= (1 - epsilon);
                    m_weights[correctIndex][numFeats] *= (1 + epsilon);
                }
                else if (max / secondMax - 1 < m_delta) {
                    for (int fvi = 0; fvi < fvisize; fvi++) {
                        int fi = fv.indexAtLocation(fvi);
                        m_weights[secondHighestIndex][fi] *= (1 - epsilon);
                        m_weights[correctIndex][fi] *= (1 + epsilon);
                    }
                    m_weights[secondHighestIndex][numFeats] *= (1 - epsilon);
                    m_weights[correctIndex][numFeats] *= (1 + epsilon);
                }
            }

            // Cut epsilon by the cooling rate
            epsilon *= (1 - m_coolingRate);
        }

        return new BalancedWinnow (trainingList.getPipe(), m_weights);
    }
}
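The multiplicative update at the heart of the trainer can be seen in isolation with a toy example. The sketch below (not part of MALLET; class name, feature indices, and values are made up for illustration) runs one BalancedWinnow update on a two-class problem with three features plus the constant default "feature", using the default epsilon of 0.5, then applies one cooling step:

```java
import java.util.Arrays;

// Illustrative sketch of one BalancedWinnow update (hypothetical example,
// independent of the MALLET API).
public class BalancedWinnowSketch {

    public static void main(String[] args) {
        double epsilon = 0.5;                 // DEFAULT_EPSILON
        double[][] w = { {1, 1, 1, 1},        // class 0 weights (last slot = default "feature")
                         {1, 1, 1, 1} };      // class 1 weights

        // One training example: features 0 and 2 are active (value 1),
        // and the true label is class 1.
        int[] active = {0, 2};
        int correct = 1;

        // Dot products; the default "feature" always contributes w[c][3]
        double[] score = new double[2];
        for (int c = 0; c < 2; c++) {
            for (int fi : active) score[c] += w[c][fi];
            score[c] += w[c][3];
        }
        int predicted = score[0] >= score[1] ? 0 : 1;  // tie: both scores are 3, so class 0

        // Mistake: demote the predicted class and promote the correct
        // class, multiplying the weights of the active features
        if (predicted != correct) {
            for (int fi : active) {
                w[predicted][fi] *= (1 - epsilon);
                w[correct][fi]   *= (1 + epsilon);
            }
            w[predicted][3] *= (1 - epsilon);
            w[correct][3]   *= (1 + epsilon);
        }

        System.out.println(Arrays.toString(w[0]));  // demoted:  [0.5, 1.0, 0.5, 0.5]
        System.out.println(Arrays.toString(w[1]));  // promoted: [1.5, 1.0, 1.5, 1.5]

        // After a full pass, epsilon is cut by the cooling rate
        epsilon *= (1 - 0.5);                  // DEFAULT_COOLING_RATE
        System.out.println(epsilon);           // 0.25
    }
}
```

Note that feature 1 is inactive in the example, so its weights in both classes stay at 1.0; only active features (and the constant default slot) are promoted or demoted.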
