⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 onlinelogisticregressionclassifier.java

📁 一个自然语言处理的Java开源工具包。LingPipe目前已有很丰富的功能
💻 JAVA
字号:
/*
 * LingPipe v. 3.5
 * Copyright (C) 2003-2008 Alias-i
 *
 * This program is licensed under the Alias-i Royalty Free License
 * Version 1 WITHOUT ANY WARRANTY, without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the Alias-i
 * Royalty Free License Version 1 for more details.
 *
 * You should have received a copy of the Alias-i Royalty Free License
 * Version 1 along with this program; if not, visit
 * http://alias-i.com/lingpipe/licenses/lingpipe-license-1.txt or contact
 * Alias-i, Inc. at 181 North 11th Street, Suite 401, Brooklyn, NY 11211,
 * +1 (718) 290-9170.
 */

package com.aliasi.classify;

import com.aliasi.classify.Classification;
import com.aliasi.classify.Classifier;
import com.aliasi.classify.ConditionalClassification;

import com.aliasi.corpus.ClassificationHandler;
import com.aliasi.corpus.Corpus;
import com.aliasi.corpus.ObjectHandler;

import com.aliasi.matrix.SparseFloatVector;
import com.aliasi.matrix.Vector;

import com.aliasi.stats.AnnealingSchedule;
import com.aliasi.stats.LogisticRegression;
import com.aliasi.stats.RegressionPrior;

import com.aliasi.symbol.MapSymbolTable;
import com.aliasi.symbol.SymbolTable;

import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Compilable;
import com.aliasi.util.FeatureExtractor;
import com.aliasi.util.ObjectToCounterMap;
import com.aliasi.util.ObjectToDoubleMap;
import com.aliasi.util.Scored;
import com.aliasi.util.ScoredObject;

import java.io.CharArrayWriter;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.PrintWriter;
import java.io.Serializable;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Placeholder for an online (incrementally trained) logistic regression
 * classifier.
 *
 * <p>As written, this class declares no fields, constructors, or methods,
 * and implements no interfaces: everything between the opening and
 * closing braces is a single block comment containing an unfinished
 * draft.  The imports above are presumably there for that draft; none of
 * them is used by live code in this file.
 *
 * @param <E> in the commented-out draft, the type of the objects passed
 * to {@code handle} and {@code classify}; unused by the live code
 */
public class OnlineLogisticRegressionClassifier<E> {

    // NOTE(review): the draft below is disabled and does not compile as
    // written.  Inconsistencies visible in the draft itself:
    //
    //  - The implements list is missing a comma after
    //    ClassificationHandler<E,Classification>.
    //  - The constructor reads categorySymbols.length and
    //    initialRegularizationDiscount, but its parameters are named
    //    categories and initialRegularizationRate.
    //  - mCategories is assigned and read but never declared; other
    //    methods refer to mCategorySymbols instead (also undeclared).
    //  - handle() calls Arrays.binarySearch with a single argument named
    //    categories (no such local; the search key is missing), and uses
    //    the undeclared names mSymbolTable and mModel where the declared
    //    fields are mFeatureSymbolTable and mRegression.
    //  - classify() passes an undeclared variable named vector to
    //    mRegression.classify; the local it builds is featureVector.
    //  - Serializer.writeExternal writes mPrior and mSymbolTable, neither
    //    of which is declared; Serializer.read has an empty initializer
    //    for prior and a constructor call that is cut off mid-argument
    //    list, and its read order does not mirror the write order.
    //  - Compiler.writeExternal calls compileOrSerialize(mSymbolTable)
    //    without the mClassifier qualifier or the output argument, and
    //    Compiler.read misspells categories as cateogries.
    //  - OnlineLogisticRegression and PerceptronClassifier.toVector are
    //    not defined in this file; TODO confirm they exist in the project
    //    before re-enabling any of this.

    /*
    implements Classifier<E,ConditionalClassification>,
               ClassificationHandler<E,Classification>
               Compilable,
               Serializable {

    private final OnlineLogisticRegression mRegression;
    private final FeatureExtractor<? super E> mFeatureExtractor;
    private final boolean mAddInterceptFeature;
    private final SymbolTable mFeatureSymbolTable;

    public OnlineLogisticRegressionClassifier(String[] categories,
                                              RegressionPrior prior,
                                              FeatureExtractor<? super E> featureExtractor,
                                              boolean addInterceptFeature,
                                              SymbolTable featureSymbolTable,
                                              int initialDataDimensions,
                                              double initialLearningRate,
                                              double initialRegularizationRate) {
        mRegression = new OnlineLogisticRegression(categorySymbols.length,
                                                   prior,
                                                   initialDataDimensions,
                                                   initialLearningRate,
                                                   initialRegularizationDiscount);
        mFeatureExtractor = featureExtractor;
        mAddInterceptFeature = addInterceptFeature;
        mFeatureSymbolTable = featureSymbolTable;
        Arrays.sort(categories);
        mCategories = categories;
        if (mAddInterceptFeature)
            mFeatureSymbolTable.getOrAddSymbol(LogisticRegressionClassifier
                                               .INTERCEPT_FEATURE_NAME);
    }

    public String[] categories() {
        return copy(mCategories);
    }

    public void setLearningRate(double rate) {
        mRegression.setLearningRate(rate);
    }

    public void setRegularizationDiscount(double regularizationDiscount) {
        mRegression.setRegularizationDiscount(regularizationDiscount);
    }

    public void handle(E input, Classification classification) {
        String category = classification.bestCategory();
        int categoryIndex = Arrays.<String>binarySearch(categories);
        if (categoryIndex < 0) {
            String msg = "Unknown classification category."
                + " Found category=" + category
                + " Known categories=" + Arrays.asList(mCategories);
            throw new IllegalArgumentException(msg);
        }
        Map<String,? extends Number> featureMap = mFeatureExtractor.features(input);
        for (String featureName : featureMap.keySet())
            mSymbolTable.getOrAddSymbol(featureName);
        SparseFloatVector featureVector
            = PerceptronClassifier
            .toVector(featureMap,
                      mFeatureSymbolTable,
                      mFeatureSymbolTable.numSymbols(),
                      mAddInterceptFeature);
        mModel.train(featureVector,categoryIndex);
    }

    public ConditionalClassification classify(E input) {
        Map<String,? extends Number> featureMap = mFeatureExtractor.features(input);
        SparseFloatVector featureVector
            = PerceptronClassifier
            .toVector(featureMap,
                      mFeatureSymbolTable,
                      mRegression.weightVectorDimensionality(),
                      mAddInterceptFeature);
        double[] conditionalProbs = mRegression.classify(vector);
        ScoredObject[] sos = new ScoredObject[conditionalProbs.length];
        for (int i = 0; i < conditionalProbs.length; ++i)
            sos[i] = new ScoredObject<String>(mCategorySymbols[i],conditionalProbs[i]);
        Arrays.sort(sos,Scored.REVERSE_SCORE_COMPARATOR);
        String[] categories = new String[conditionalProbs.length];
        for (int i = 0; i < conditionalProbs.length; ++i) {
            categories[i] = sos[i].getObject().toString();
            conditionalProbs[i] = sos[i].score();
        }
        return new ConditionalClassification(categories,conditionalProbs);
    }

    public void compileTo(ObjectOutput out) throws IOException {
        out.writeObject(new Compiler<E>(this));
    }

    private Object writeReplace() {
        return new Serializer<E>(this);
    }

    private static String[] copy(String[] categories) {
        String[] copy = new String[categories.length];
        for (int i = 0; i < copy.length; ++i)
            copy[i] = categories[i];
        return copy;
    }

    static class Serializer<F> extends AbstractExternalizable {
        OnlineLogisticRegressionClassifier<F> mClassifier;
        Serializer(OnlineLogisticRegressionClassifier<F> classifier) {
            mClassifier = classifier;
        }
        Serializer() {
            this(null);
        }
        public void writeExternal(ObjectOutput out) throws IOException {
            out.writeObject(mClassifier.mRegression);
            out.writeObject(mClassifier.mPrior);
            out.writeObject(mClassifier.mFeatureSymbolTable);
            out.writeObject(mClassifier.mSymbolTable);
            String[] categories = mClassifier.mCategorySymbols;
            out.writeInt(categories.length);
            for (String category : categories)
                out.writeUTF(category);
        }
        public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
            OnlineLogisticRegression regression
                = (OnlineLogisticRegression) in.readObject();
            RegressionPrior prior
                =
            FeatureExtractor<F> featureExtractor
                = (FeatureExtractor<F>) in.readObject();
            SymbolTable featureSymbolTable
                = (SymbolTable) in.readObject();
            int numCategories = in.readInt();
            String[] categories = new String[numCategories];
            for (int i = 0; i < categories.length; ++i)
                categories[i] = in.readUTF();
            return new OnlineLogisticRegressionClassifier(categories,
                                                          prior,
        }
    }

    static class Compiler<F> extends AbstractExternalizable {
        OnlineLogisticRegressionClassifier<F> mClassifier;
        Compiler(OnlineLogisticRegressionClassifier<F> classifier) {
            mClassifier = classifier;
        }
        Compiler() {
            this(null);
        }
        public void writeExternal(ObjectOutput out) throws IOException {
            mClassifier.mRegression.compileTo(out);
            AbstractExternalizable.compileOrSerialize(mClassifier.mFeatureExtractor,out);
            AbstractExternalizable.compileOrSerialize(mSymbolTable);
            out.writeBoolean(mClassifier.mAddInterceptFeature);
            String[] categories = mClassifier.mCategorySymbols;
            out.writeInt(categories.length);
            for (String category : categories)
                out.writeUTF(category);
        }
        public Object read(ObjectInput in) throws IOException, ClassNotFoundException {
            LogisticRegression regression = (LogisticRegression) in.readObject();
            FeatureExtractor<F> featureExtractor = (FeatureExtractor<F>) in.readObject();
            SymbolTable featureSymbolTable = (SymbolTable) in.readObject();
            boolean addInterceptFeature = in.readBoolean();
            int numCategories = in.readInt();
            String[] categories = new String[numCategories];
            for (int i = 0; i < categories.length; ++i)
                cateogries[i] = in.readUTF();
            return new LogisticRegressionClassifier(regression,
                                                    featureExtractor,
                                                    addInterceptFeature,
                                                    featureSymbolTable,
                                                    categories);
        }
    }
    */
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -