📄 BalancedWinnowTrainer.java
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept.
   This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
   http://www.cs.umass.edu/~mccallum/mallet
   This software is provided under the terms of the Common Public License,
   version 1.0, as published by http://www.opensource.org.  For further
   information, see the file `LICENSE' included with this distribution. */

package edu.umass.cs.mallet.base.classify;

import edu.umass.cs.mallet.base.classify.Classifier;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.Labeling;
import edu.umass.cs.mallet.base.types.LabelVector;
import edu.umass.cs.mallet.base.types.FeatureSelection;
import edu.umass.cs.mallet.base.types.MatrixOps;
import edu.umass.cs.mallet.base.classify.BalancedWinnow;
import edu.umass.cs.mallet.base.pipe.Pipe;
import java.util.Arrays;

/**
 * An implementation of the training methods of a BalancedWinnow2
 * on-line classifier. Given a labeled instance (x, y), the algorithm
 * computes dot(x, w_i) for w_1, ..., w_c, where w_i is the weight
 * vector for class i. The instance is classified as class j
 * if dot(x, w_j) is the largest of the c dot products.
 *
 * <p>The weight vectors are updated whenever the classifier
 * makes a mistake or is only barely correct (the highest
 * dot product is less than delta percent above the second highest).
 * Suppose the classifier guessed j and the answer was j'. For each
 * feature i that is present, multiply w_ji by (1 - epsilon) and
 * multiply w_j'i by (1 + epsilon).
 *
 * <p>This procedure is repeated over the training examples
 * multiple times (default is 5), and epsilon is cut in half on each
 * iteration.
 *
 * <p>Limitations: BalancedWinnow2 considers only binary feature
 * vectors (i.e. whether or not a feature is present,
 * not its value).
 *
 * @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a>
 */
public class BalancedWinnowTrainer extends ClassifierTrainer implements Boostable
{
    /** 0.5 */
    public static final double DEFAULT_EPSILON = .5;
    /** 0.1 */
    public static final double DEFAULT_DELTA = .1;
    /** 5 */
    public static final int DEFAULT_NUM_ITERATIONS = 5;

    double m_epsilon;
    double m_delta;
    int m_numIterations;

    /**
     * Array of weights, one for each class and feature, initialized to 1.
     * For each class, there is an additional default "feature" weight
     * that is set to 1 in every example (it remains constant; this is
     * used to prevent dot products from becoming too small).
     */
    double[][] m_weights;

    /**
     * Default constructor. Sets all parameters to their defaults.
     */
    public BalancedWinnowTrainer()
    {
        this(DEFAULT_EPSILON, DEFAULT_DELTA, DEFAULT_NUM_ITERATIONS);
    }

    /**
     * @param epsilon percentage by which to increase/decrease weight vectors
     * when an example is misclassified.
     * @param delta percentage by which the highest (and correct) dot product
     * should exceed the second highest dot product before we consider an example
     * to be correctly classified (margin width) when adjusting weights.
     * @param numIterations number of times to loop through training examples.
     */
    public BalancedWinnowTrainer(double epsilon, double delta, int numIterations)
    {
        m_epsilon = epsilon;
        m_delta = delta;
        m_numIterations = numIterations;
    }

    /**
     * Trains the classifier on the instance list, updating
     * class weight vectors as appropriate.
     * @param trainingList instance list to be trained on
     * @return Classifier object containing learned weights
     */
    public Classifier train (InstanceList trainingList,
                             InstanceList validationList,
                             InstanceList testSet,
                             ClassifierEvaluating evaluator,
                             Classifier initialClassifier)
    {
        FeatureSelection selectedFeatures = trainingList.getFeatureSelection();
        if (selectedFeatures != null)
            // xxx Attend to FeatureSelection!!!
            throw new UnsupportedOperationException ("FeatureSelection not yet implemented.");

        double epsilon = m_epsilon;
        Alphabet dict = (Alphabet) trainingList.getDataAlphabet ();
        int numLabels = trainingList.getTargetAlphabet().size();
        int numFeats = dict.size();
        m_weights = new double [numLabels][numFeats];

        // Initialize weights to 1
        for (int i = 0; i < numLabels; i++)
            Arrays.fill(m_weights[i], 1.0);

        // Loop through training instances multiple times
        double[] results = new double[numLabels];
        for (int iter = 0; iter < m_numIterations; iter++) {

            // Loop through all instances
            for (int ii = 0; ii < trainingList.size(); ii++) {
                Instance inst = trainingList.getInstance(ii);
                Labeling labeling = inst.getLabeling ();
                FeatureVector fv = (FeatureVector) inst.getData();
                int fvisize = fv.numLocations();
                int correctIndex = labeling.getBestIndex();
                Arrays.fill(results, 0.0);

                // Compute dot(x, w_i) for each class i
                for (int fvi = 0; fvi < fvisize; fvi++) {
                    int fi = fv.indexAtLocation(fvi);
                    for (int lpos = 0; lpos < numLabels; lpos++) {
                        // The 1 comes from the constant "default"
                        // feature shared by all examples
                        results[lpos] += m_weights[lpos][fi] + 1;
                    }
                }

                // Get the indices of the classes with the 2 highest dot products
                int predictedIndex = 0;
                int secondHighestIndex = 0;
                double max = Double.MIN_VALUE;
                double secondMax = Double.MIN_VALUE;
                for (int i = 0; i < numLabels; i++) {
                    if (results[i] > max) {
                        secondMax = max;
                        max = results[i];
                        secondHighestIndex = predictedIndex;
                        predictedIndex = i;
                    }
                    else if (results[i] > secondMax) {
                        secondMax = results[i];
                        secondHighestIndex = i;
                    }
                }

                // Adjust weights if this example is mispredicted
                // or just barely correct
                if (predictedIndex != correctIndex) {
                    for (int fvi = 0; fvi < fvisize; fvi++) {
                        int fi = fv.indexAtLocation(fvi);
                        m_weights[predictedIndex][fi] *= (1 - epsilon);
                        m_weights[correctIndex][fi] *= (1 + epsilon);
                    }
                }
                else if (max / secondMax - 1 < m_delta) {
                    for (int fvi = 0; fvi < fvisize; fvi++) {
                        int fi = fv.indexAtLocation(fvi);
                        m_weights[secondHighestIndex][fi] *= (1 - epsilon);
                        m_weights[correctIndex][fi] *= (1 + epsilon);
                    }
                }
            }

            // Cut epsilon in half
            epsilon *= 0.5;
        }
        return new BalancedWinnow (trainingList.getPipe(), m_weights);
    }
}
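
To make the update rule described in the class javadoc concrete, here is a minimal self-contained toy sketch of the same multiplicative update on plain arrays. The class name, the toy data, and the constants are illustrative assumptions for this sketch only; it does not use the MALLET API above.

// Toy sketch of the BalancedWinnow2-style update described above.
// Everything here (class name, data, constants) is illustrative, not MALLET code.
import java.util.Arrays;

public class BalancedWinnowSketch {
    public static void main(String[] args) {
        // Two classes, three binary features; all weights start at 1.
        double[][] w = new double[2][3];
        for (double[] row : w) Arrays.fill(row, 1.0);

        // Toy training set: each row lists the indices of the active features.
        int[][] examples = { {0, 1}, {1, 2}, {0}, {2} };
        int[] labels     = {   0,      1,    0,   1  };

        double epsilon = 0.5, delta = 0.1;
        for (int iter = 0; iter < 5; iter++) {
            for (int n = 0; n < examples.length; n++) {
                // dot(x, w_c) for each class c; the +1 plays the role of the
                // constant "default" feature shared by all examples.
                double[] scores = new double[2];
                for (int fi : examples[n])
                    for (int c = 0; c < 2; c++)
                        scores[c] += w[c][fi] + 1;

                // Track the two highest-scoring classes.
                int best = 0, second = 0;
                double max = Double.NEGATIVE_INFINITY, secondMax = Double.NEGATIVE_INFINITY;
                for (int c = 0; c < 2; c++) {
                    if (scores[c] > max) {
                        secondMax = max; second = best;
                        max = scores[c]; best = c;
                    } else if (scores[c] > secondMax) {
                        secondMax = scores[c]; second = c;
                    }
                }

                // Update on a mistake, or when the win is thinner than delta.
                int y = labels[n];
                boolean mistake = best != y;
                boolean thinWin = !mistake && (max / secondMax - 1 < delta);
                if (mistake || thinWin) {
                    int demote = mistake ? best : second;
                    for (int fi : examples[n]) {
                        w[demote][fi] *= (1 - epsilon);  // demote wrong guess / runner-up
                        w[y][fi]      *= (1 + epsilon);  // promote the correct class
                    }
                }
            }
            epsilon *= 0.5;  // halve the learning rate each pass, as in the trainer
        }
        System.out.println(Arrays.deepToString(w));
    }
}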