📄 tfidfclassifiertrainer.java
字号:
/**
 * Returns a {@link Serializer} wrapping this trainer.  This is the
 * method name that Java object serialization looks up in order to
 * substitute another object at write time.
 *
 * @return a serializer proxy holding a reference to this trainer
 */
Object writeReplace() {
    return new Serializer<E>(this);
}

/**
 * Computes the inverse document frequency weight as the natural
 * logarithm of {@code numDocs / docFrequency}.  The arguments are
 * not validated.
 *
 * @param docFrequency number of documents containing the feature
 * @param numDocs total number of documents
 * @return {@code Math.log(numDocs / docFrequency)}
 */
static double idf(double docFrequency, double numDocs) {
    return Math.log(numDocs / docFrequency);
}

/**
 * Computes the term frequency weight as the square root of the
 * raw count.
 *
 * @param count raw feature count
 * @return {@code Math.sqrt(count)}
 */
static double tf(double count) {
    return Math.sqrt(count);
}

/**
 * Externalizable proxy that writes a trainer in compiled form and
 * reads it back as a {@link TfIdfClassifier}.
 *
 * <p>Stream layout, in order: feature extractor, feature symbol
 * table, number of categories, category names, one idf per
 * feature, one start offset per feature, total number of
 * (category id, weight) pairs, then the pairs themselves grouped
 * by feature id.
 *
 * @param <F> type of objects being classified
 */
static class Externalizer<F> extends AbstractExternalizable {

    static final long serialVersionUID = 5578122239615646843L;

    final TfIdfClassifierTrainer<F> mTrainer;

    /** No-argument constructor; leaves the trainer reference null. */
    public Externalizer() {
        this(null);
    }

    /**
     * @param trainer trainer whose model is to be written
     */
    public Externalizer(TfIdfClassifierTrainer trainer) {
        mTrainer = trainer;
    }

    /**
     * Writes the compiled model for the wrapped trainer.
     *
     * @param out sink for the compiled model
     * @throws IOException if the underlying stream fails
     */
    public void writeExternal(ObjectOutput out) throws IOException {
        // Feature Extractor
        AbstractExternalizable
            .compileOrSerialize(mTrainer.mFeatureExtractor, out);

        // Feature Symbol Table
        int numFeatures = mTrainer.mFeatureSymbolTable.numSymbols();
        mTrainer.mFeatureSymbolTable.compileTo(out);

        int numCats = mTrainer.mCategorySymbolTable.numSymbols();
        double numCatsD = numCats;

        // num cats
        out.writeInt(numCats);

        // string categories [* num cats]
        for (int i = 0; i < numCats; ++i) {
            out.writeUTF(mTrainer.mCategorySymbolTable.idToSymbol(i));
        }

        // idfs [* num features]; the "document frequency" of a feature
        // is the number of categories holding a count for it
        for (int i = 0; i < numFeatures; ++i) {
            int docFrequency
                = mTrainer.mFeatureToCategoryCount.get(i).size();
            out.writeFloat((float) idf(docFrequency, numCatsD));
        }

        // feature offset [* numFeatures]
        int nextFeatureOffset = 0;
        for (int i = 0; i < numFeatures; ++i) {
            out.writeInt(nextFeatureOffset);
            nextFeatureOffset
                += mTrainer.mFeatureToCategoryCount.get(i).size();
        }

        // catId/tfIdf array sizes
        out.writeInt(nextFeatureOffset);

        // Euclidean length of each category's tf-idf vector
        double[] catLengths = new double[numCats];
        for (ObjectToDoubleMap<Integer> categoryCounts
                 : mTrainer.mFeatureToCategoryCount.values()) {
            double featureIdf = idf(categoryCounts.size(), numCatsD);
            for (Map.Entry<Integer,Double> categoryCount
                     : categoryCounts.entrySet()) {
                int catId = categoryCount.getKey().intValue();
                double count = categoryCount.getValue().doubleValue();
                double tfIdf = tf(count) * featureIdf;
                catLengths[catId] += tfIdf * tfIdf;
            }
        }
        for (int i = 0; i < catLengths.length; ++i) {
            catLengths[i] = Math.sqrt(catLengths[i]);
        }

        // catId, normedTfIdf [* array size]
        for (int featureId = 0; featureId < numFeatures; ++featureId) {
            ObjectToDoubleMap<Integer> categoryCounts
                = mTrainer.mFeatureToCategoryCount.get(featureId);
            double featureIdf = idf(categoryCounts.size(), numCatsD);
            for (Map.Entry<Integer,Double> categoryCount
                     : categoryCounts.entrySet()) {
                int catId = categoryCount.getKey().intValue();
                double count = categoryCount.getValue().doubleValue();
                double catLength = catLengths[catId];
                // A zero-length category vector has only zero weights;
                // write 0 rather than the NaN that 0/0 would produce.
                float normedTfIdf = (catLength == 0.0)
                    ? 0.0f
                    : (float) ((tf(count) * featureIdf) / catLength);
                out.writeInt(catId);
                out.writeFloat(normedTfIdf);
            }
        }
    }

    /**
     * Reads a compiled model in the layout produced by
     * {@link #writeExternal(ObjectOutput)}.
     *
     * @param objIn source of the compiled model
     * @return a {@link TfIdfClassifier} built from the stream
     * @throws ClassNotFoundException if a serialized class is missing
     * @throws IOException if the underlying stream fails
     */
    public Object read(ObjectInput objIn)
        throws ClassNotFoundException, IOException {

        FeatureExtractor<F> featureExtractor
            = (FeatureExtractor<F>) objIn.readObject();
        MapSymbolTable featureSymbolTable
            = (MapSymbolTable) objIn.readObject();
        int numFeatures = featureSymbolTable.numSymbols();

        int numCategories = objIn.readInt();
        String[] categories = new String[numCategories];
        for (int i = 0; i < numCategories; ++i) {
            categories[i] = objIn.readUTF();
        }

        float[] featureIdfs = new float[numFeatures];
        for (int i = 0; i < featureIdfs.length; ++i) {
            featureIdfs[i] = objIn.readFloat();
        }

        // One extra slot so that offsets[f+1] bounds feature f's run,
        // including for the last feature.
        int[] featureOffsets = new int[numFeatures + 1];
        for (int i = 0; i < numFeatures; ++i) {
            featureOffsets[i] = objIn.readInt();
        }
        int catIdTfIdfArraySize = objIn.readInt();
        featureOffsets[featureOffsets.length - 1] = catIdTfIdfArraySize;

        int[] catIds = new int[catIdTfIdfArraySize];
        float[] normedTfIdfs = new float[catIdTfIdfArraySize];
        for (int i = 0; i < catIdTfIdfArraySize; ++i) {
            catIds[i] = objIn.readInt();
            normedTfIdfs[i] = objIn.readFloat();
        }

        return new TfIdfClassifier<F>(featureExtractor,
                                      featureSymbolTable,
                                      categories,
                                      featureIdfs,
                                      featureOffsets,
                                      catIds,
                                      normedTfIdfs);
    }
}

/**
 * Compiled classifier that scores each category by the dot product
 * of the stored per-category weights with the input's tf-idf
 * vector, divided by the input vector's length.
 *
 * @param <G> type of objects being classified
 */
static class TfIdfClassifier<G>
    implements Classifier<G,ScoredClassification> {

    final FeatureExtractor<G> mFeatureExtractor;
    final MapSymbolTable mFeatureSymbolTable;
    final String[] mCategories;

    // Indexed by feature id.  The Externalizer builds mFeatureOffsets
    // with one trailing entry holding the total pair count, so that
    // [mFeatureOffsets[f], mFeatureOffsets[f+1]) is feature f's run.
    final float[] mFeatureIdfs;
    final int[] mFeatureOffsets;

    // parallel (mCategoryIds, mTfIdfs)
    final int[] mCategoryIds;
    final float[] mTfIdfs;

    /**
     * Stores the supplied components without copying.
     *
     * @param featureExtractor extractor applied to inputs
     * @param featureSymbolTable maps feature names to ids
     * @param categories category names indexed by category id
     * @param featureIdfs idf per feature id
     * @param featureOffsets start offset per feature id into the
     *        pair arrays, plus a final end offset
     * @param categoryIds category id of each pair
     * @param tfIdfs weight of each pair
     */
    TfIdfClassifier(FeatureExtractor featureExtractor,
                    MapSymbolTable featureSymbolTable,
                    String[] categories,
                    float[] featureIdfs,
                    int[] featureOffsets,
                    int[] categoryIds,
                    float[] tfIdfs) {
        mFeatureExtractor = featureExtractor;
        mFeatureSymbolTable = featureSymbolTable;
        mCategories = categories;
        mFeatureIdfs = featureIdfs;
        mFeatureOffsets = featureOffsets;
        mCategoryIds = categoryIds;
        mTfIdfs = tfIdfs;
    }

    /**
     * Returns a multi-line dump of the symbol table, categories,
     * per-feature idfs and offsets, and the (category id, weight)
     * pairs.
     *
     * @return textual dump of this classifier
     */
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("TfIdfClassifierTrainer.TfIdfClassifier\n");
        sb.append("Feature Symbol Table\n ");
        sb.append(mFeatureSymbolTable.toString());
        sb.append("\n");
        sb.append("Categories\n");
        for (int i = 0; i < mCategories.length; ++i) {
            sb.append(" ").append(i)
                .append("=").append(mCategories[i])
                .append("\n");
        }
        sb.append("Index Feature IDF offset\n");
        for (int i = 0; i < mFeatureIdfs.length; ++i) {
            sb.append(" ").append(i)
                .append(" ").append(mFeatureSymbolTable.idToSymbol(i))
                .append(" ").append(mFeatureIdfs[i])
                .append(" ").append(mFeatureOffsets[i])
                .append("\n");
        }
        sb.append("Index CategoryID TF-IDF\n");
        for (int i = 0; i < mCategoryIds.length; ++i) {
            sb.append(" ").append(i)
                .append(" ").append(mCategoryIds[i])
                .append(" ").append(mTfIdfs[i])
                .append("\n");
        }
        return sb.toString();
    }

    /**
     * Scores every category against the input.  Features for which
     * the symbol table returns id {@code -1} are ignored.  If the
     * input's tf-idf vector has zero length, every category
     * receives score {@code 0.0}.
     *
     * @param in object to classify
     * @return scored classification over all categories
     */
    public ScoredClassification classify(G in) {
        Map<String,? extends Number> featureVector
            = mFeatureExtractor.features(in);
        double[] scores = new double[mCategories.length];
        double inputLengthSquared = 0.0;
        for (Map.Entry<String,? extends Number> featureValue
                 : featureVector.entrySet()) {
            int featureId
                = mFeatureSymbolTable.symbolToID(featureValue.getKey());
            if (featureId == -1) {
                continue;
            }
            double inputTfIdf
                = tf(featureValue.getValue().doubleValue())
                * mFeatureIdfs[featureId];
            inputLengthSquared += inputTfIdf * inputTfIdf;
            int end = mFeatureOffsets[featureId + 1];
            for (int offset = mFeatureOffsets[featureId];
                 offset < end;
                 ++offset) {
                scores[mCategoryIds[offset]]
                    += mTfIdfs[offset] * inputTfIdf;
            }
        }
        double inputLength = Math.sqrt(inputLengthSquared);
        ScoredObject<String>[] categoryScores
            = (ScoredObject<String>[]) new ScoredObject[mCategories.length];
        for (int i = 0; i < categoryScores.length; ++i) {
            // cosine norm for input; a zero-length input would give
            // 0/0 = NaN, so it is scored as 0 instead
            double score = (inputLength == 0.0)
                ? 0.0
                : scores[i] / inputLength;
            categoryScores[i]
                = new ScoredObject<String>(mCategories[i], score);
        }
        return ScoredClassification.create(categoryScores);
    }
}

/**
 * Externalizable proxy that writes the trainer's full state and
 * reads it back as a trainer.
 *
 * @param <F> type of objects being classified
 */
static class Serializer<F> extends AbstractExternalizable {

    static final long serialVersionUID = -4757808688956812832L;

    final TfIdfClassifierTrainer<F> mTrainer;

    /** No-argument constructor; leaves the trainer reference null. */
    public Serializer() {
        this(null);
    }

    /**
     * @param trainer trainer whose state is to be written
     */
    public Serializer(TfIdfClassifierTrainer trainer) {
        mTrainer = trainer;
    }

    /**
     * Writes the feature extractor, the feature-to-category count
     * map, the feature symbol table and the category symbol table,
     * in that order.
     *
     * @param out sink for the trainer state
     * @throws IOException if the underlying stream fails
     */
    public void writeExternal(ObjectOutput out) throws IOException {
        AbstractExternalizable
            .serializeOrCompile(mTrainer.mFeatureExtractor, out);
        out.writeObject(mTrainer.mFeatureToCategoryCount);
        out.writeObject(mTrainer.mFeatureSymbolTable);
        out.writeObject(mTrainer.mCategorySymbolTable);
    }

    /**
     * Reads the four objects written by
     * {@link #writeExternal(ObjectOutput)} and builds a trainer
     * from them.
     *
     * @param objIn source of the trainer state
     * @return the reconstructed trainer
     * @throws ClassNotFoundException if a serialized class is missing
     * @throws IOException if the underlying stream fails
     */
    public Object read(ObjectInput objIn)
        throws ClassNotFoundException, IOException {

        FeatureExtractor<F> featureExtractor
            = (FeatureExtractor<F>) objIn.readObject();
        Map<Integer,ObjectToDoubleMap<Integer>> featureToCategoryCount
            = (Map<Integer,ObjectToDoubleMap<Integer>>) objIn.readObject();
        MapSymbolTable featureSymbolTable
            = (MapSymbolTable) objIn.readObject();
        MapSymbolTable categorySymbolTable
            = (MapSymbolTable) objIn.readObject();
        return new TfIdfClassifierTrainer(featureExtractor,
                                          featureToCategoryCount,
                                          featureSymbolTable,
                                          categorySymbolTable);
    }
}
} // closes the enclosing class, whose header precedes this excerpt
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -