⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 logisticregressionclassifier.java

📁 一个自然语言处理的Java开源工具包。LingPipe目前已有很丰富的功能
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
     * @throws IllegalArgumentException If the category is unknown.
     */
    public ObjectToDoubleMap<String> featureValues(String category) {
        int categoryId = categoryToId(category);
        if (categoryId < 0) {
            String msg = "Unknown category=" + category;
            throw new IllegalArgumentException(msg);
        }
        ObjectToDoubleMap<String> result = new ObjectToDoubleMap<String>();
        // The last category always yields an empty map; no weight vector
        // is looked up for it (presumably it is the implicit reference
        // category of the regression -- confirm against LogisticRegression).
        if (categoryId == mCategorySymbols.length - 1) {
            return result;
        }
        int numSymbols = mFeatureSymbolTable.numSymbols();
        Vector[] weightVectors = mModel.weightVectors();
        Vector weightVector = weightVectors[categoryId];
        for (int i = 0; i < numSymbols; ++i) {
            String symbol = mFeatureSymbolTable.idToSymbol(i);
            result.set(symbol, weightVector.value(i));
        }
        return result;
    }

    /**
     * Returns a string-based representation of this classifier.
     * The output reports the number of categories and features,
     * followed by the parameter vector of every category except the
     * last one, with features listed in the order produced by
     * <code>keysOrderedByValueList()</code>.
     *
     * @return A string-based representation of this classifier.
     */
    @Override
    public String toString() {
        CharArrayWriter writer = new CharArrayWriter();
        PrintWriter printWriter = new PrintWriter(writer);
        List<String> categorySymbols = categorySymbols();
        printWriter.println("NUMBER OF CATEGORIES=" + categorySymbols.size());
        printWriter.println("NUMBER OF FEATURES=" + mFeatureSymbolTable.numSymbols());
        // The final category is deliberately skipped: featureValues()
        // returns an empty map for it.
        for (int i = 0; i < categorySymbols.size() - 1; ++i) {
            String category = categorySymbols.get(i);
            printWriter.println("\n  CATEGORY=" + category);
            ObjectToDoubleMap<String> parameterVector = featureValues(category);
            for (String feature : parameterVector.keysOrderedByValueList()) {
                printWriter.printf("%20s %15.6f\n", feature, parameterVector.get(feature));
            }
        }
        printWriter.write('\n');
        return writer.toString();
    }

    // Java serialization hook: an Externalizer wrapping this classifier
    // is written to the stream in place of the classifier itself.
    private Object writeReplace() {
        return new Externalizer<E>(this);
    }

    // Name under which the intercept feature is registered in the
    // feature symbol table when an intercept is requested.
    static final String INTERCEPT_FEATURE_NAME = "*&^INTERCEPT%$^&**";

    /**
     * Returns a trained logistic regression classifier given the specified
     * feature extractor, corpus, model priors and search parameters.
     *
     * <p>Only the training section of the specified corpus is used for
     * training.  The training section is visited twice: once to count
     * features and once to extract the input vectors and output
     * categories.
     *
     * <p>See the class documentation above and the class
     * documentation for {@link LogisticRegression} for more
     * information on the parameters.
     *
     * @param featureExtractor Converter from objects to feature maps.
     * @param corpus Corpus of training data.
     * @param minFeatureCount Minimum count for features in corpus to
     * keep feature as part of model.
     * @param addInterceptFeature A flag set to <code>true</code> if
     * an intercept feature should be added to each input vector.
     * @param prior The prior for regularization of the regression.
     * @param annealingSchedule Class to compute learning rate for each epoch.
     * @param minImprovement Minimum relative improvement in error during
     * an epoch to stop search.
     * @param minEpochs Minimum number of search epochs.
     * @param maxEpochs Maximum number of epochs.
     * @param progressWriter Writer to which progress reports are written.
     * @param <F> Type of objects being classified.
     * @return The classifier wrapping the estimated regression model.
     * @throws IOException If there is an underlying I/O exception
     * reading the data from the corpus.
     */
    public static <F> LogisticRegressionClassifier<F>
        train(FeatureExtractor<? super F> featureExtractor,
              Corpus<ClassificationHandler<F,Classification>> corpus,
              int minFeatureCount,
              boolean addInterceptFeature,
              RegressionPrior prior,
              AnnealingSchedule annealingSchedule,
              double minImprovement,
              int minEpochs,
              int maxEpochs,
              PrintWriter progressWriter) throws IOException {

        MapSymbolTable featureSymbolTable = new MapSymbolTable();
        MapSymbolTable categorySymbolTable = new MapSymbolTable();

        // Register the intercept before any corpus feature so that it is
        // the first symbol added to the table.
        if (addInterceptFeature) {
            featureSymbolTable.getOrAddSymbol(INTERCEPT_FEATURE_NAME);
        }

        // First pass: count features, then prune by minFeatureCount.
        ObjectToCounterMap<String> featureCounter = new ObjectToCounterMap<String>();
        corpus.visitTrain(new FeatureCounter<F>(featureExtractor, featureCounter));
        featureCounter.prune(minFeatureCount);
        for (String feature : featureCounter.keySet()) {
            featureSymbolTable.getOrAddSymbol(feature);
        }

        // Second pass: convert each training instance to a vector and
        // a category identifier.
        DataExtractor<F> dataExtractor
            = new DataExtractor<F>(featureExtractor,
                                   featureSymbolTable,
                                   categorySymbolTable,
                                   addInterceptFeature,
                                   featureSymbolTable.numSymbols());
        corpus.visitTrain(dataExtractor);
        Vector[] inputs = dataExtractor.inputs();
        int[] categories = dataExtractor.categories();

        LogisticRegression model
            = LogisticRegression.estimate(inputs, categories,
                                          prior,
                                          annealingSchedule,
                                          minImprovement,
                                          minEpochs, maxEpochs,
                                          progressWriter);

        // Category names indexed by the identifiers assigned during
        // extraction.
        String[] categorySymbols = new String[categorySymbolTable.numSymbols()];
        for (int i = 0; i < categorySymbols.length; ++i) {
            categorySymbols[i] = categorySymbolTable.idToSymbol(i);
        }

        return new LogisticRegressionClassifier<F>(model,
                                                   featureExtractor,
                                                   addInterceptFeature,
                                                   featureSymbolTable,
                                                   categorySymbols);
    }

    /**
     * Classification handler that increments a counter once for each
     * feature name in the feature map of every handled object.  The
     * feature values and the classification are ignored.
     */
    static class FeatureCounter<H> implements ClassificationHandler<H,Classification> {
        private final FeatureExtractor<? super H> mFeatureExtractor;
        private final ObjectToCounterMap<String> mFeatureCounter;

        FeatureCounter(FeatureExtractor<? super H> featureExtractor,
                       ObjectToCounterMap<String> featureCounter) {
            mFeatureExtractor = featureExtractor;
            mFeatureCounter = featureCounter;
        }

        public void handle(H h, Classification c) {
            Map<String,? extends Number> featureMap = mFeatureExtractor.features(h);
            for (String feature : featureMap.keySet()) {
                mFeatureCounter.increment(feature);
            }
        }
    }

    /**
     * Serialization proxy for a classifier.  Writes the model, feature
     * extractor, intercept flag, feature symbol table and category
     * names, and rebuilds a classifier from them in the same order.
     */
    static class Externalizer<G> extends AbstractExternalizable {
        static final long serialVersionUID = -2003123148721825458L;
        final LogisticRegressionClassifier<G> mClassifier;

        /** No-argument constructor; leaves the wrapped classifier null. */
        public Externalizer() {
            this(null);
        }

        public Externalizer(LogisticRegressionClassifier<G> classifier) {
            mClassifier = classifier;
        }

        public void writeExternal(ObjectOutput objOut) throws IOException {
            objOut.writeObject(mClassifier.mModel);
            objOut.writeObject(mClassifier.mFeatureExtractor);
            objOut.writeBoolean(mClassifier.mAddInterceptFeature);
            objOut.writeObject(mClassifier.mFeatureSymbolTable);
            // Category names are written length-prefixed.
            objOut.writeInt(mClassifier.mCategorySymbols.length);
            for (int i = 0; i < mClassifier.mCategorySymbols.length; ++i) {
                objOut.writeUTF(mClassifier.mCategorySymbols[i]);
            }
        }

        public Object read(ObjectInput objIn) throws IOException, ClassNotFoundException {
            LogisticRegression model = (LogisticRegression) objIn.readObject();
            // The extractor's type argument is erased at runtime, so this
            // cast cannot be checked.
            @SuppressWarnings("unchecked")
            FeatureExtractor<? super G> featureExtractor
                = (FeatureExtractor<? super G>) objIn.readObject();
            boolean addInterceptFeature = objIn.readBoolean();
            SymbolTable featureSymbolTable = (SymbolTable) objIn.readObject();
            int numSymbols = objIn.readInt();
            String[] categorySymbols = new String[numSymbols];
            for (int i = 0; i < categorySymbols.length; ++i) {
                categorySymbols[i] = objIn.readUTF();
            }
            return new LogisticRegressionClassifier<G>(model,
                                                       featureExtractor,
                                                       addInterceptFeature,
                                                       featureSymbolTable,
                                                       categorySymbols);
        }
    }

    /**
     * Classification handler that accumulates, for every handled
     * instance, its feature vector and the identifier of its best
     * category, in the order the instances are handled.
     */
    static class DataExtractor<F> implements ClassificationHandler<F,Classification> {
        final FeatureExtractor<? super F> mFeatureExtractor;
        final SymbolTable mFeatureSymbolTable;
        final SymbolTable mCategorySymbolTable;
        final boolean mAddInterceptFeature;
        final int mNumSymbols;
        final List<Vector> mInputVectorList = new ArrayList<Vector>();
        final List<Integer> mOutputCategoryList = new ArrayList<Integer>();

        // If an intercept is used, the caller has already added it to the
        // feature symbol table, so numSymbols already accounts for it.
        DataExtractor(FeatureExtractor<? super F> featureExtractor,
                      SymbolTable featureSymbolTable,
                      SymbolTable categorySymbolTable,
                      boolean addInterceptFeature,
                      int numSymbols) {
            mFeatureExtractor = featureExtractor;
            mFeatureSymbolTable = featureSymbolTable;
            mCategorySymbolTable = categorySymbolTable;
            mAddInterceptFeature = addInterceptFeature;
            mNumSymbols = numSymbols;
        }

        public void handle(F input, Classification output) {
            String outputCategoryName = output.bestCategory();
            int outputCategoryId = mCategorySymbolTable.getOrAddSymbol(outputCategoryName);
            Map<String,? extends Number> featureMap = mFeatureExtractor.features(input);
            SparseFloatVector vector
                = PerceptronClassifier.toVector(featureMap,
                                                mFeatureSymbolTable,
                                                mNumSymbols,
                                                mAddInterceptFeature);
            mInputVectorList.add(vector);
            mOutputCategoryList.add(Integer.valueOf(outputCategoryId));
        }

        /** Returns the collected category identifiers in handling order. */
        int[] categories() {
            int[] result = new int[mOutputCategoryList.size()];
            for (int i = 0; i < result.length; ++i) {
                result[i] = mOutputCategoryList.get(i).intValue();
            }
            return result;
        }

        /** Returns the collected input vectors in handling order. */
        Vector[] inputs() {
            return mInputVectorList.toArray(new Vector[mInputVectorList.size()]);
        }
    }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -