📄 onlinelogisticregressionclassifier.java
字号:
/*
 * LingPipe v. 3.5
 * Copyright (C) 2003-2008 Alias-i
 *
 * This program is licensed under the Alias-i Royalty Free License
 * Version 1 WITHOUT ANY WARRANTY, without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the Alias-i
 * Royalty Free License Version 1 for more details.
 *
 * You should have received a copy of the Alias-i Royalty Free License
 * Version 1 along with this program; if not, visit
 * http://alias-i.com/lingpipe/licenses/lingpipe-license-1.txt or contact
 * Alias-i, Inc. at 181 North 11th Street, Suite 401, Brooklyn, NY 11211,
 * +1 (718) 290-9170.
 */

package com.aliasi.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classifier;
import com.aliasi.classify.ConditionalClassification;

import com.aliasi.corpus.ClassificationHandler;
import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;

import com.aliasi.matrix.SparseFloatVector;
import com.aliasi.matrix.Vector;

import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.LogisticRegression;
import com.aliasi.stats.RegressionPrior;

import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.symbol.SymbolTable;

import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.Scored;
import com.aliasi.util.ScoredObject;

import java.io.CharArrayWriter;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.PrintWriter;
import java.io.Serializable;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Placeholder for an online logistic regression classifier.
 *
 * <p>As written, this class declares no members and implements no
 * interfaces: everything after the opening brace, including the intended
 * {@code implements} clause, is enclosed in a single block comment.  The
 * comment holds a draft implementation that is preserved verbatim below.
 *
 * <p>In the draft, {@code handle(E, Classification)} converts an input to a
 * feature vector and trains on it, {@code classify(E)} returns a
 * {@code ConditionalClassification} with categories ordered by
 * {@code Scored.REVERSE_SCORE_COMPARATOR}, and two nested
 * {@code AbstractExternalizable} helpers ({@code Serializer},
 * {@code Compiler}) handle serialization and compilation.  None of this is
 * active code.
 *
 * @param <E> type parameter of the classifier; in the draft it is the type
 *     of the objects passed to {@code handle} and {@code classify}
 */
public class OnlineLogisticRegressionClassifier<E> {

    // NOTE(review): the draft below is disabled and would not compile if it
    // were simply uncommented.  Inconsistencies visible in the draft itself:
    //
    //  - The implements clause lacks a comma between
    //    ClassificationHandler<E,Classification> and Compilable.
    //  - Only four fields are declared (mRegression, mFeatureExtractor,
    //    mAddInterceptFeature, mFeatureSymbolTable), yet the code also reads
    //    or writes mCategories, mCategorySymbols, mSymbolTable, mModel and
    //    mPrior.
    //  - The constructor's parameters are named categories and
    //    initialRegularizationRate, but its body uses categorySymbols and
    //    initialRegularizationDiscount.
    //  - handle() calls Arrays.<String>binarySearch(categories): there is no
    //    variable named categories in that scope and no search key is
    //    passed, although the local category is computed just above.
    //  - classify() builds featureVector but passes an undeclared name,
    //    vector, to mRegression.classify().
    //  - Serializer.writeExternal() writes four objects, while
    //    Serializer.read() is truncated: "RegressionPrior prior =" has no
    //    initializer and the final constructor call is unfinished.
    //  - Compiler.writeExternal() calls compileOrSerialize(mSymbolTable)
    //    with one argument and without the mClassifier qualifier, whereas
    //    the preceding call passes (object, out).
    //  - Compiler.read() assigns to cateogries (misspelled) and returns a
    //    LogisticRegressionClassifier rather than this class.
    //  - OnlineLogisticRegression and PerceptronClassifier are not among
    //    the imports; presumably they are expected in this package --
    //    TODO confirm that OnlineLogisticRegression exists.

    /*
    implements Classifier<E,ConditionalClassification>,
               ClassificationHandler<E,Classification>
               Compilable,
               Serializable {

    private final OnlineLogisticRegression mRegression;
    private final FeatureExtractor<? super E> mFeatureExtractor;
    private final boolean mAddInterceptFeature;
    private final SymbolTable mFeatureSymbolTable;

    public OnlineLogisticRegressionClassifier(String[] categories,
                                              RegressionPrior prior,
                                              FeatureExtractor<? super E> featureExtractor,
                                              boolean addInterceptFeature,
                                              SymbolTable featureSymbolTable,
                                              int initialDataDimensions,
                                              double initialLearningRate,
                                              double initialRegularizationRate) {
        mRegression = new OnlineLogisticRegression(categorySymbols.length,
                                                   prior,
                                                   initialDataDimensions,
                                                   initialLearningRate,
                                                   initialRegularizationDiscount);
        mFeatureExtractor = featureExtractor;
        mAddInterceptFeature = addInterceptFeature;
        mFeatureSymbolTable = featureSymbolTable;
        Arrays.sort(categories);
        mCategories = categories;
        if (mAddInterceptFeature)
            mFeatureSymbolTable.getOrAddSymbol(LogisticRegressionClassifier
                                               .INTERCEPT_FEATURE_NAME);
    }

    public String[] categories() {
        return copy(mCategories);
    }

    public void setLearningRate(double rate) {
        mRegression.setLearningRate(rate);
    }

    public void setRegularizationDiscount(double regularizationDiscount) {
        mRegression.setRegularizationDiscount(regularizationDiscount);
    }

    public void handle(E input, Classification classification) {
        String category = classification.bestCategory();
        int categoryIndex = Arrays.<String>binarySearch(categories);
        if (categoryIndex < 0) {
            String msg = "Unknown classification category."
                + " Found category=" + category
                + " Known categories=" + Arrays.asList(mCategories);
            throw new IllegalArgumentException(msg);
        }
        Map<String,? extends Number> featureMap
            = mFeatureExtractor.features(input);
        for (String featureName : featureMap.keySet())
            mSymbolTable.getOrAddSymbol(featureName);
        SparseFloatVector featureVector
            = PerceptronClassifier
            .toVector(featureMap,
                      mFeatureSymbolTable,
                      mFeatureSymbolTable.numSymbols(),
                      mAddInterceptFeature);
        mModel.train(featureVector,categoryIndex);
    }

    public ConditionalClassification classify(E input) {
        Map<String,? extends Number> featureMap
            = mFeatureExtractor.features(input);
        SparseFloatVector featureVector
            = PerceptronClassifier
            .toVector(featureMap,
                      mFeatureSymbolTable,
                      mRegression.weightVectorDimensionality(),
                      mAddInterceptFeature);
        double[] conditionalProbs = mRegression.classify(vector);
        ScoredObject[] sos = new ScoredObject[conditionalProbs.length];
        for (int i = 0; i < conditionalProbs.length; ++i)
            sos[i] = new ScoredObject<String>(mCategorySymbols[i],conditionalProbs[i]);
        Arrays.sort(sos,Scored.REVERSE_SCORE_COMPARATOR);
        String[] categories = new String[conditionalProbs.length];
        for (int i = 0; i < conditionalProbs.length; ++i) {
            categories[i] = sos[i].getObject().toString();
            conditionalProbs[i] = sos[i].score();
        }
        return new ConditionalClassification(categories,conditionalProbs);
    }

    public void compileTo(ObjectOutput out) throws IOException {
        out.writeObject(new Compiler<E>(this));
    }

    private Object writeReplace() {
        return new Serializer<E>(this);
    }

    private static String[] copy(String[] categories) {
        String[] copy = new String[categories.length];
        for (int i = 0; i < copy.length; ++i)
            copy[i] = categories[i];
        return copy;
    }

    static class Serializer<F> extends AbstractExternalizable {
        OnlineLogisticRegressionClassifier<F> mClassifier;
        Serializer(OnlineLogisticRegressionClassifier<F> classifier) {
            mClassifier = classifier;
        }
        Serializer() {
            this(null);
        }
        public void writeExternal(ObjectOutput out) throws IOException {
            out.writeObject(mClassifier.mRegression);
            out.writeObject(mClassifier.mPrior);
            out.writeObject(mClassifier.mFeatureSymbolTable);
            out.writeObject(mClassifier.mSymbolTable);
            String[] categories = mClassifier.mCategorySymbols;
            out.writeInt(categories.length);
            for (String category : categories)
                out.writeUTF(category);
        }
        public Object read(ObjectInput in)
            throws IOException, ClassNotFoundException {

            OnlineLogisticRegression regression
                = (OnlineLogisticRegression) in.readObject();
            RegressionPrior prior =
            FeatureExtractor<F> featureExtractor
                = (FeatureExtractor<F>) in.readObject();
            SymbolTable featureSymbolTable
                = (SymbolTable) in.readObject();
            int numCategories = in.readInt();
            String[] categories = new String[numCategories];
            for (int i = 0; i < categories.length; ++i)
                categories[i] = in.readUTF();
            return new OnlineLogisticRegressionClassifier(categories,
                                                          prior,
        }
    }

    static class Compiler<F> extends AbstractExternalizable {
        OnlineLogisticRegressionClassifier<F> mClassifier;
        Compiler(OnlineLogisticRegressionClassifier<F> classifier) {
            mClassifier = classifier;
        }
        Compiler() {
            this(null);
        }
        public void writeExternal(ObjectOutput out) throws IOException {
            mClassifier.mRegression.compileTo(out);
            AbstractExternalizable.compileOrSerialize(mClassifier.mFeatureExtractor,out);
            AbstractExternalizable.compileOrSerialize(mSymbolTable);
            out.writeBoolean(mClassifier.mAddInterceptFeature);
            String[] categories = mClassifier.mCategorySymbols;
            out.writeInt(categories.length);
            for (String category : categories)
                out.writeUTF(category);
        }
        public Object read(ObjectInput in)
            throws IOException, ClassNotFoundException {

            LogisticRegression regression
                = (LogisticRegression) in.readObject();
            FeatureExtractor<F> featureExtractor
                = (FeatureExtractor<F>) in.readObject();
            SymbolTable featureSymbolTable
                = (SymbolTable) in.readObject();
            boolean addInterceptFeature = in.readBoolean();
            int numCategories = in.readInt();
            String[] categories = new String[numCategories];
            for (int i = 0; i < categories.length; ++i)
                cateogries[i] = in.readUTF();
            return new LogisticRegressionClassifier(regression,
                                                    featureExtractor,
                                                    addInterceptFeature,
                                                    featureSymbolTable,
                                                    categories);
        }
    }
    */
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -