LogisticRegressionClassifier.java
/*
 * LingPipe v. 3.5
 * Copyright (C) 2003-2008 Alias-i
 *
 * This program is licensed under the Alias-i Royalty Free License
 * Version 1 WITHOUT ANY WARRANTY, without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the Alias-i
 * Royalty Free License Version 1 for more details.
 *
 * You should have received a copy of the Alias-i Royalty Free License
 * Version 1 along with this program; if not, visit
 * http://alias-i.com/lingpipe/licenses/lingpipe-license-1.txt or contact
 * Alias-i, Inc. at 181 North 11th Street, Suite 401, Brooklyn, NY 11211,
 * +1 (718) 290-9170.
 */

package com.aliasi.classify;

import com.aliasi.corpus.ClassificationHandler;
import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;

import com.aliasi.matrix.SparseFloatVector;
import com.aliasi.matrix.Vector;

import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.LogisticRegression;
import com.aliasi.stats.RegressionPrior;

import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.symbol.SymbolTable;

import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.Scored;
import com.aliasi.util.ScoredObject;

import java.io.CharArrayWriter;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.PrintWriter;
import java.io.Serializable;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * A <code>LogisticRegressionClassifier</code> provides conditional
 * probability classifications of input objects using an underlying
 * logistic regression model and feature extractor.  Logistic regression
 * is a discriminative classifier which operates over arbitrary
 * floating-point-valued features of objects.
 *
 * <h4>Training</h4>
 *
 * <p>Logistic regression classifiers may be trained from a data
 * corpus using the method {@link
 * #train(FeatureExtractor,Corpus,int,boolean,RegressionPrior,AnnealingSchedule,double,int,int,PrintWriter)},
 * the last six arguments of which are shared with the logistic
 * regression training method {@link
 * LogisticRegression#estimate(Vector[],int[],RegressionPrior,AnnealingSchedule,double,int,int,PrintWriter)}.
 * The first four arguments are required to adapt logistic regression
 * to general classification, and consist of a feature extractor, a
 * corpus to train over, a minimum count for features to be retained,
 * and a boolean flag indicating whether or not to add an intercept
 * feature to every input vector.
 *
 * <p>This class merely acts as an adapter to implement the {@link
 * Classifier} interface based on the {@link LogisticRegression} class
 * in the statistics package.  The basis of the adaptation is a
 * general feature extractor, which is an instance of {@link
 * FeatureExtractor}.  A feature extractor converts an arbitrary input
 * object (whose type is specified generically in this class) to a
 * mapping from features (represented as strings) to values
 * (represented as instances of {@link Number}).  The class then uses
 * a symbol table for features to convert the maps from feature names
 * to numbers into sparse vectors, where the dimensions are the
 * identifiers for the features in the symbol table.  By convention,
 * if the intercept feature flag is set, it will set dimension 0 of
 * all inputs to 1.0.
 *
 * <p>For more information on the logistic regression model itself and
 * the training procedure used, see the class documentation for {@link
 * LogisticRegression}.
 *
 * <h4>Serialization and Compilation</h4>
 *
 * <p>This class implements both {@link Serializable} and {@link
 * Compilable}, but both do the same thing and simply write the
 * content of the model to the object output.  The model read back in
 * will be an instance of {@link LogisticRegressionClassifier} with
 * the same components as the model that was serialized or compiled.
 *
 * @author Bob Carpenter
 * @version 3.5
 * @since LingPipe3.5
 */
public class LogisticRegressionClassifier<E>
    implements Classifier<E,ConditionalClassification>, Compilable, Serializable {

    private final LogisticRegression mModel;
    private final FeatureExtractor<? super E> mFeatureExtractor;
    private final boolean mAddInterceptFeature;
    private final SymbolTable mFeatureSymbolTable;
    private final String[] mCategorySymbols;

    /**
     * Construct a logistic regression classifier using the specified
     * model, feature extractor, intercept flag, symbol table for
     * features, and categories.
     *
     * @param model Logistic regression model.
     * @param featureExtractor Feature extractor to convert input
     * objects to feature maps.
     * @param addInterceptFeature Flag set to <code>true</code> if the intercept
     * feature at dimension 0 should always be set to 1.0.
     * @param featureSymbolTable Symbol table for converting features to vector
     * dimensions.
     * @param categorySymbols List of outputs aligned with the model's categories.
     * @throws IllegalArgumentException If the number of outcomes in the model is
     * not the same as the length of the category symbols array, or if the
     * category symbols are not all unique.
     */
    LogisticRegressionClassifier(LogisticRegression model,
                                 FeatureExtractor<? super E> featureExtractor,
                                 boolean addInterceptFeature,
                                 SymbolTable featureSymbolTable,
                                 String[] categorySymbols) {
        if (model.numOutcomes() != categorySymbols.length) {
            String msg = "Number of model outcomes must match category symbols length."
                + " Found model.numOutcomes()=" + model.numOutcomes()
                + " categorySymbols.length=" + categorySymbols.length;
            throw new IllegalArgumentException(msg);
        }
        Set<String> categorySymbolSet = new HashSet<String>();
        for (int i = 0; i < categorySymbols.length; ++i) {
            if (!categorySymbolSet.add(categorySymbols[i])) {
                String msg = "Categories must be unique."
                    + " Found duplicate category categorySymbols[" + i + "]="
                    + categorySymbols[i];
                throw new IllegalArgumentException(msg);
            }
        }
        mModel = model;
        mFeatureExtractor = featureExtractor;
        mAddInterceptFeature = addInterceptFeature;
        mFeatureSymbolTable = featureSymbolTable;
        mCategorySymbols = categorySymbols;
    }

    /**
     * Returns an unmodifiable view of the symbol table used for
     * features in this classifier.
     *
     * @return The feature symbol table for this classifier.
     */
    public SymbolTable featureSymbolTable() {
        return MapSymbolTable.unmodifiableView(mFeatureSymbolTable);
    }

    /**
     * Returns the category symbols used by this classifier.  Classifications
     * that this class returns will use only these symbols.
     *
     * @return The category symbols for this classifier.
     */
    public List<String> categorySymbols() {
        return Arrays.<String>asList(mCategorySymbols);
    }

    /**
     * Return the conditional classification of the specified object
     * using logistic regression classification.  All categories will
     * have conditional probabilities in results.
     *
     * @param in Input object to classify.
     * @return The conditional classification of the object.
     */
    public ConditionalClassification classify(E in) {
        Map<String,? extends Number> featureMap = mFeatureExtractor.features(in);
        SparseFloatVector vector
            = PerceptronClassifier.toVector(featureMap,
                                            mFeatureSymbolTable,
                                            mFeatureSymbolTable.numSymbols(),
                                            mAddInterceptFeature);
        double[] conditionalProbs = mModel.classify(vector);
        ScoredObject[] sos = new ScoredObject[conditionalProbs.length];
        for (int i = 0; i < conditionalProbs.length; ++i)
            sos[i] = new ScoredObject<String>(mCategorySymbols[i],conditionalProbs[i]);
        Arrays.sort(sos,Scored.REVERSE_SCORE_COMPARATOR);
        String[] categories = new String[conditionalProbs.length];
        for (int i = 0; i < conditionalProbs.length; ++i) {
            categories[i] = sos[i].getObject().toString();
            conditionalProbs[i] = sos[i].score();
        }
        return new ConditionalClassification(categories,conditionalProbs);
    }

    /**
     * Compile this classifier to the specified object output.  This
     * method is only for storage convenience; the classifier read
     * back in from the serialized object will be equivalent to this
     * one (but not in the <code>Object.equals()</code> sense).
     *
     * <p>Serializing this class produces exactly the same output.
     *
     * @param objOut Object output to which this classifier is
     * written.
     * @throws IOException If there is an underlying I/O error
     * writing the model to the stream.
     */
    public void compileTo(ObjectOutput objOut) throws IOException {
        objOut.writeObject(new Externalizer<E>(this));
    }

    private int categoryToId(String category) {
        for (int i = 0; i < mCategorySymbols.length; ++i)
            if (mCategorySymbols[i].equals(category))
                return i;
        return -1;
    }

    /**
     * Returns a mapping from features to their parameter values for
     * the specified category.  If the category is the last category,
     * which implicitly has zero values for all parameters, the map returned
     * by this method will also have zero values for all features.
     *
     * @param category Classification category.
     * @return The map from features to their parameter values for the
     * specified category.
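
Below is a minimal usage sketch of the class as documented above, assuming a LingPipe 3.5-style API. The demo class, the token-count feature extractor, the toy corpus and category labels, and the RegressionPrior.gaussian and AnnealingSchedule.exponential factory calls are illustrative assumptions rather than part of the source file; the train(...) argument order follows the Javadoc link in the class comment.

import com.aliasi.classify.Classification;
import com.aliasi.classify.ConditionalClassification;
import com.aliasi.classify.LogisticRegressionClassifier;

import com.aliasi.corpus.ClassificationHandler;
import com.aliasi.corpus.Corpus;

import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.RegressionPrior;

import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.FeatureExtractor;

import java.io.PrintWriter;
import java.util.HashMap;
import java.util.Map;

public class LogisticRegressionClassifierDemo {

    // Hypothetical feature extractor: one count feature per whitespace-separated token.
    static class TokenCountFeatureExtractor implements FeatureExtractor<CharSequence> {
        public Map<String,? extends Number> features(CharSequence in) {
            Map<String,Integer> feats = new HashMap<String,Integer>();
            for (String tok : in.toString().split("\\s+")) {
                Integer count = feats.get(tok);
                feats.put(tok, count == null ? 1 : count + 1);
            }
            return feats;
        }
    }

    public static void main(String[] args) throws Exception {
        // Tiny in-memory corpus; the Corpus/ClassificationHandler signatures are
        // assumed from the LingPipe 3.x corpus package imported by the class above.
        Corpus<ClassificationHandler<CharSequence,Classification>> corpus
            = new Corpus<ClassificationHandler<CharSequence,Classification>>() {
                public void visitTrain(ClassificationHandler<CharSequence,Classification> handler) {
                    handler.handle("good great fine", new Classification("pos"));
                    handler.handle("bad awful poor", new Classification("neg"));
                }
                public void visitTest(ClassificationHandler<CharSequence,Classification> handler) {
                    // no held-out data in this sketch
                }
            };

        // Argument order per the train(...) Javadoc link above:
        // feature extractor, corpus, min feature count, intercept flag,
        // prior, annealing schedule, min improvement, min epochs, max epochs, progress writer.
        LogisticRegressionClassifier<CharSequence> classifier
            = LogisticRegressionClassifier.<CharSequence>train(
                  new TokenCountFeatureExtractor(),
                  corpus,
                  1,                                           // min feature count
                  true,                                        // add intercept feature at dimension 0
                  RegressionPrior.gaussian(1.0, true),         // assumed factory method
                  AnnealingSchedule.exponential(0.002, 0.999), // assumed factory method
                  0.000001,                                    // min relative improvement per epoch
                  10,                                          // min epochs
                  1000,                                        // max epochs
                  new PrintWriter(System.out));                // progress output

        // classify(E) returns a ConditionalClassification with categories sorted
        // in decreasing order of conditional probability.
        ConditionalClassification c = classifier.classify("good fine example");
        for (int rank = 0; rank < c.size(); ++rank)
            System.out.println(c.category(rank) + " p=" + c.conditionalProbability(rank));

        // Compilation and serialization are equivalent for this class; the
        // round-trip yields another LogisticRegressionClassifier.
        @SuppressWarnings("unchecked")
        LogisticRegressionClassifier<CharSequence> compiled
            = (LogisticRegressionClassifier<CharSequence>)
              AbstractExternalizable.compile(classifier);
    }
}

The token-count features are only to keep the sketch self-contained; any FeatureExtractor over the input type, paired with a corpus of labeled instances of that type, fits the same train/classify pattern.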