📄 balancedwinnow.java
字号:
/* Copyright (C) 2002 Univ. of Massachusetts Amherst, Computer Science Dept. This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit). http://www.cs.umass.edu/~mccallum/mallet This software is provided under the terms of the Common Public License, version 1.0, as published by http://www.opensource.org. For further information, see the file `LICENSE' included with this distribution. */package edu.umass.cs.mallet.base.classify;import edu.umass.cs.mallet.base.classify.Classifier;import edu.umass.cs.mallet.base.types.Instance;import edu.umass.cs.mallet.base.types.Alphabet;import edu.umass.cs.mallet.base.types.FeatureVector;import edu.umass.cs.mallet.base.types.Labeling;import edu.umass.cs.mallet.base.types.LabelVector;import edu.umass.cs.mallet.base.types.MatrixOps;import edu.umass.cs.mallet.base.pipe.Pipe;/** * Classification methods of BalancedWinnow2 algorithm. * @see BalancedWinnowTrainer * @author Gary Huang <a href="mailto:ghuang@cs.umass.edu">ghuang@cs.umass.edu</a> */public class BalancedWinnow extends Classifier{ double [][] m_weights; /** * Passes along data pipe and weights from * {@link #BalancedWinnowTrainer BalancedWinnowTrainer} * @param dataPipe needed for dictionary, labels, feature vectors, etc * @param weights weights calculated during training phase */ public BalancedWinnow (Pipe dataPipe, double [][] weights) { super (dataPipe); m_weights = new double[weights.length][weights[0].length]; for (int i = 0; i < weights.length; i++) for (int j = 0; j < weights[0].length; j++) m_weights[i][j] = weights[i][j]; } /** * Classifies an instance using BalancedWinnow's weights * * <p>Returns a Classification containing the normalized * dot products between class weight vectors and the instance * feature vector. * * <p>One can obtain the confidence of the classification by * calculating weight(j')/weight(j), where j' is the * highest weight prediction and j is the 2nd-highest. * Another possibility is to calculate * <br><center>e^{dot(w_j', x} / sum_j[e^{dot(w_j, x)}]</center></br> */ public Classification classify (Instance instance) { int numClasses = getLabelAlphabet().size(); double[] scores = new double[numClasses]; FeatureVector fv = (FeatureVector) instance.getData (this.instancePipe); // Make sure the feature vector's feature dictionary matches // what we are expecting from our data pipe (and thus our notion // of feature probabilities. assert (fv.getAlphabet () == this.instancePipe.getDataAlphabet()); int fvisize = fv.numLocations(); // Take dot products double sum = 0; for (int fvi = 0; fvi < fvisize; fvi++) { int fi = fv.indexAtLocation (fvi); for (int ci = 0; ci < numClasses; ci++) { scores[ci] += m_weights[ci][fi]; sum += m_weights[ci][fi]; } } MatrixOps.timesEquals(scores, 1.0 / sum); // Create and return a Classification object return new Classification (instance, this, new LabelVector (getLabelAlphabet(), scores)); }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -