📄 logisticregressionclassifier.java
字号:
* @throws IllegalArgumentException If the category is unknown.
     */
    public ObjectToDoubleMap<String> featureValues(String category) {
        int categoryId = categoryToId(category);
        if (categoryId < 0) {
            String msg = "Unknown category=" + category;
            throw new IllegalArgumentException(msg);
        }
        ObjectToDoubleMap<String> result = new ObjectToDoubleMap<String>();
        // The last category always yields an empty map.
        // NOTE(review): presumably it is the reference category, for which the
        // model stores no weight vector -- confirm against LogisticRegression.
        if (categoryId == mCategorySymbols.length - 1)
            return result;
        int numSymbols = mFeatureSymbolTable.numSymbols();
        Vector[] weightVectors = mModel.weightVectors();
        Vector weightVector = weightVectors[categoryId];
        for (int i = 0; i < numSymbols; ++i) {
            String symbol = mFeatureSymbolTable.idToSymbol(i);
            result.set(symbol, weightVector.value(i));
        }
        return result;
    }

    /**
     * Returns a string-based representation of this classifier,
     * listing the parameter vectors for each category.
     *
     * <p>The output starts with the number of categories and the
     * number of features, followed by one section for every category
     * except the last, with features listed in the order produced by
     * {@code keysOrderedByValueList()}.
     *
     * @return A string-based representation of this classifier.
     */
    @Override
    public String toString() {
        CharArrayWriter writer = new CharArrayWriter();
        PrintWriter printWriter = new PrintWriter(writer);
        List<String> categorySymbols = categorySymbols();
        printWriter.println("NUMBER OF CATEGORIES=" + categorySymbols.size());
        printWriter.println("NUMBER OF FEATURES=" + mFeatureSymbolTable.numSymbols());
        // The last category is skipped; featureValues() returns an empty map for it.
        for (int i = 0; i < categorySymbols.size() - 1; ++i) {
            String category = categorySymbols.get(i);
            printWriter.println("\n CATEGORY=" + category);
            ObjectToDoubleMap<String> parameterVector = featureValues(category);
            for (String feature : parameterVector.keysOrderedByValueList())
                printWriter.printf("%20s %15.6f\n", feature, parameterVector.get(feature));
        }
        printWriter.write('\n');
        // Flush explicitly so the result does not depend on PrintWriter
        // happening to be unbuffered over the backing CharArrayWriter.
        printWriter.flush();
        return writer.toString();
    }

    /**
     * Substitutes an {@link Externalizer} for this classifier when it
     * is serialized.
     *
     * @return The object to write in place of this classifier.
     */
    private Object writeReplace() {
        return new Externalizer<E>(this);
    }

    /**
     * Name under which the intercept feature is registered in the
     * feature symbol table when training with an intercept.
     */
    static final String INTERCEPT_FEATURE_NAME = "*&^INTERCEPT%$^&**";

    /**
     * Returns a trained logistic regression classifier given the specified
     * feature extractor, corpus, model priors and search parameters.
     *
     * <p>Only the training section of the specified corpus is used for
     * training.  The training section is visited twice: once to count
     * features so that those below the minimum count can be pruned,
     * and once to convert the training instances to vectors.
     *
     * <p>See the class documentation above and the class
     * documentation for {@link LogisticRegression} for more
     * information on the parameters.
     *
     * @param featureExtractor Converter from objects to feature maps.
     * @param corpus Corpus of training data.
     * @param minFeatureCount Minimum count for features in corpus to
     * keep feature as part of model.
     * @param addInterceptFeature A flag set to <code>true</code> if
     * an intercept feature should be added to each input vector.
     * @param prior The prior for regularization of the regression.
     * @param annealingSchedule Class to compute learning rate for each epoch.
     * @param minImprovement Minimum relative improvement in error during
     * an epoch to stop search.
     * @param minEpochs Minimum number of search epochs.
     * @param maxEpochs Maximum number of epochs.
     * @param progressWriter Writer to which progress reports are written.
     * @param <F> Type of objects being classified.
     * @return The classifier wrapping the estimated regression model.
     * @throws IOException If there is an underlying I/O exception
     * reading the data from the corpus.
     */
    public static <F> LogisticRegressionClassifier<F> train(FeatureExtractor<? super F> featureExtractor,
                                                            Corpus<ClassificationHandler<F,Classification>> corpus,
                                                            int minFeatureCount,
                                                            boolean addInterceptFeature,
                                                            RegressionPrior prior,
                                                            AnnealingSchedule annealingSchedule,
                                                            double minImprovement,
                                                            int minEpochs,
                                                            int maxEpochs,
                                                            PrintWriter progressWriter)
        throws IOException {

        MapSymbolTable featureSymbolTable = new MapSymbolTable();
        MapSymbolTable categorySymbolTable = new MapSymbolTable();

        // Register the intercept before any real feature so it is the first symbol added.
        if (addInterceptFeature)
            featureSymbolTable.getOrAddSymbol(INTERCEPT_FEATURE_NAME);

        // First pass: count features, then drop those below the minimum count.
        ObjectToCounterMap<String> featureCounter = new ObjectToCounterMap<String>();
        corpus.visitTrain(new FeatureCounter<F>(featureExtractor, featureCounter));
        featureCounter.prune(minFeatureCount);
        for (String feature : featureCounter.keySet())
            featureSymbolTable.getOrAddSymbol(feature);

        // Second pass: convert every training instance to a vector and a category id.
        DataExtractor<F> dataExtractor
            = new DataExtractor<F>(featureExtractor,
                                   featureSymbolTable,
                                   categorySymbolTable,
                                   addInterceptFeature,
                                   featureSymbolTable.numSymbols());
        corpus.visitTrain(dataExtractor);
        Vector[] inputs = dataExtractor.inputs();
        int[] categories = dataExtractor.categories();

        LogisticRegression model
            = LogisticRegression.estimate(inputs, categories,
                                          prior,
                                          annealingSchedule,
                                          minImprovement,
                                          minEpochs, maxEpochs,
                                          progressWriter);

        // Freeze the category symbol table into an array indexed by category id.
        String[] categorySymbols = new String[categorySymbolTable.numSymbols()];
        for (int i = 0; i < categorySymbols.length; ++i)
            categorySymbols[i] = categorySymbolTable.idToSymbol(i);

        return new LogisticRegressionClassifier<F>(model,
                                                   featureExtractor,
                                                   addInterceptFeature,
                                                   featureSymbolTable,
                                                   categorySymbols);
    }

    /**
     * Handler that increments a shared counter once for every feature
     * name extracted from each handled object.  The classification
     * passed to the handler is ignored.
     *
     * @param <H> Type of objects being handled.
     */
    static class FeatureCounter<H> implements ClassificationHandler<H,Classification> {
        private final FeatureExtractor<? super H> mFeatureExtractor;
        private final ObjectToCounterMap<String> mFeatureCounter;

        FeatureCounter(FeatureExtractor<? super H> featureExtractor,
                       ObjectToCounterMap<String> featureCounter) {
            mFeatureExtractor = featureExtractor;
            mFeatureCounter = featureCounter;
        }

        public void handle(H h, Classification c) {
            Map<String,? extends Number> featureMap = mFeatureExtractor.features(h);
            // Only the presence of a feature is counted; its value is not consulted.
            for (String feature : featureMap.keySet()) {
                mFeatureCounter.increment(feature);
            }
        }
    }

    /**
     * Object written in place of a classifier during serialization
     * (see {@code writeReplace()}), and responsible for reading the
     * classifier back in.
     *
     * @param <G> Type of objects classified by the wrapped classifier.
     */
    static class Externalizer<G> extends AbstractExternalizable {
        static final long serialVersionUID = -2003123148721825458L;
        final LogisticRegressionClassifier<G> mClassifier;

        /** No-argument constructor used for deserialization; wraps no classifier. */
        public Externalizer() {
            this(null);
        }

        public Externalizer(LogisticRegressionClassifier<G> classifier) {
            mClassifier = classifier;
        }

        /**
         * Writes the model, feature extractor, intercept flag, feature
         * symbol table and category symbols, in that order.
         *
         * @param objOut Output to which the classifier is written.
         * @throws IOException If there is an underlying I/O error.
         */
        public void writeExternal(ObjectOutput objOut) throws IOException {
            objOut.writeObject(mClassifier.mModel);
            objOut.writeObject(mClassifier.mFeatureExtractor);
            objOut.writeBoolean(mClassifier.mAddInterceptFeature);
            objOut.writeObject(mClassifier.mFeatureSymbolTable);
            objOut.writeInt(mClassifier.mCategorySymbols.length);
            for (int i = 0; i < mClassifier.mCategorySymbols.length; ++i)
                objOut.writeUTF(mClassifier.mCategorySymbols[i]);
        }

        /**
         * Reads the components in the order written by
         * {@link #writeExternal(ObjectOutput)} and rebuilds the classifier.
         *
         * @param objIn Input from which the classifier is read.
         * @return The reconstituted classifier.
         * @throws IOException If there is an underlying I/O error.
         * @throws ClassNotFoundException If the class of a serialized
         * component cannot be found.
         */
        public Object read(ObjectInput objIn) throws IOException, ClassNotFoundException {
            LogisticRegression model = (LogisticRegression) objIn.readObject();
            // The element type is erased in the stream, so this cast cannot be
            // checked; it mirrors what writeExternal() wrote.
            @SuppressWarnings("unchecked")
            FeatureExtractor<? super G> featureExtractor
                = (FeatureExtractor<? super G>) objIn.readObject();
            boolean addInterceptFeature = objIn.readBoolean();
            SymbolTable featureSymbolTable = (SymbolTable) objIn.readObject();
            int numSymbols = objIn.readInt();
            String[] categorySymbols = new String[numSymbols];
            for (int i = 0; i < categorySymbols.length; ++i)
                categorySymbols[i] = objIn.readUTF();
            return new LogisticRegressionClassifier<G>(model,
                                                       featureExtractor,
                                                       addInterceptFeature,
                                                       featureSymbolTable,
                                                       categorySymbols);
        }
    }

    /**
     * Handler that accumulates, for each handled instance, its input
     * vector and the integer identifier of its best category.
     * Categories are added to the category symbol table as they are
     * first seen.
     *
     * @param <F> Type of objects being handled.
     */
    static class DataExtractor<F> implements ClassificationHandler<F,Classification> {
        final FeatureExtractor<? super F> mFeatureExtractor;
        final SymbolTable mFeatureSymbolTable;
        final SymbolTable mCategorySymbolTable;
        final boolean mAddInterceptFeature;
        final int mNumSymbols;
        final List<Vector> mInputVectorList = new ArrayList<Vector>();
        final List<Integer> mOutputCategoryList = new ArrayList<Integer>();

        // if has intercept, already added
        DataExtractor(FeatureExtractor<? super F> featureExtractor,
                      SymbolTable featureSymbolTable,
                      SymbolTable categorySymbolTable,
                      boolean addInterceptFeature,
                      int numSymbols) {
            mFeatureExtractor = featureExtractor;
            mFeatureSymbolTable = featureSymbolTable;
            mCategorySymbolTable = categorySymbolTable;
            mAddInterceptFeature = addInterceptFeature;
            mNumSymbols = numSymbols;
        }

        public void handle(F input, Classification output) {
            String outputCategoryName = output.bestCategory();
            Integer outputCategoryId = mCategorySymbolTable.getOrAddSymbol(outputCategoryName);
            Map<String,? extends Number> featureMap = mFeatureExtractor.features(input);
            SparseFloatVector vector
                = PerceptronClassifier.toVector(featureMap,
                                                mFeatureSymbolTable,
                                                mNumSymbols,
                                                mAddInterceptFeature);
            // The two lists are appended in lock step, so index i in one
            // corresponds to index i in the other.
            mInputVectorList.add(vector);
            mOutputCategoryList.add(outputCategoryId);
        }

        /**
         * Returns the collected category identifiers, in handling order.
         *
         * @return Array of category identifiers, one per handled instance.
         */
        int[] categories() {
            int[] categories = new int[mOutputCategoryList.size()];
            for (int i = 0; i < categories.length; ++i)
                categories[i] = mOutputCategoryList.get(i).intValue();
            return categories;
        }

        /**
         * Returns the collected input vectors, in handling order.
         *
         * @return Array of input vectors, one per handled instance.
         */
        Vector[] inputs() {
            return mInputVectorList.toArray(new Vector[mInputVectorList.size()]);
        }
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -