📄 knnclassifier.java
字号:
* * @param featureExtractor Feature extractor for training and * classification instances. * @param k Maximum number of neighbors to use during * classification. * @param distance Distance function to use to compare examples. */ public KnnClassifier(FeatureExtractor<? super E> featureExtractor, int k, Distance<Vector> distance) { this(featureExtractor,k,new ProximityWrapper(distance),false); } /** * Construct a k-nearest-neighbor classifier based on the * specified feature extractor, specified maximum number of * neighbors, specified proximity function, and boolean flag * indicating whether or not to weight nearest neighbors by proximity * during classification. * * @param featureExtractor Feature extractor for training and * classification instances. * @param k Maximum number of neighbors to use during * classification. * @param proximity Proximity function to compare examples. * @param weightByProximity Flag indicating whether to weight * neighbors by proximity during classification. */ public KnnClassifier(FeatureExtractor<? super E> featureExtractor, int k, Proximity<Vector> proximity, boolean weightByProximity) { mFeatureExtractor = featureExtractor; mK = k; mProximity = proximity; mWeightByProximity = weightByProximity; mFeatureSymbolTable = new MapSymbolTable(); mCategorySymbolTable = new MapSymbolTable(); mTrainingCategories = new ArrayList<Integer>(); mTrainingVectors = new ArrayList<SparseFloatVector>(); } /** * Handle the specified classified training instance. The * training instance is converted to a feature vector using the * feature extractor, and then stored as a sparse vector relative * to a feature symbol table. * * @param trainingInstance Object being classified during training. * @param classification Classification for specified object. */ public void handle(E trainingInstance, Classification classification) { String category = classification.bestCategory(); Map<String,? extends Number> featureMap = mFeatureExtractor.features(trainingInstance); SparseFloatVector vector = PerceptronClassifier .toVectorAddSymbols(featureMap, mFeatureSymbolTable, Integer.MAX_VALUE-1); mTrainingCategories.add(mCategorySymbolTable.getOrAddSymbolInteger(category)); mTrainingVectors.add(vector); } /** * Return the k-nearest-neighbor classification result for the * specified input object. The resulting classification will have * all of the categories defined, though those with no support in * the nearest neighbors will have scores of zero. * * <p>If this classifier does not weight by proximity, the * resulting score for a category will be the number of nearest * neighbors of the specified category. That is, it will be a * straight vote. * * <p>If the classifier does weight by proximity, the resulting * score for a category will be the sum of the proximity scores * for the nearest neighbors of a given category. Instances with * no near neighbors will be scored zero (<code>0</code>). Thus * proximities should be configured to return positive values. * * @param in Object to classify. * @return Scored classification for the specified object. */ public ScoredClassification classify(E in) { Map<String,? extends Number> featureMap = mFeatureExtractor.features(in); SparseFloatVector inputVector = PerceptronClassifier .toVector(featureMap, mFeatureSymbolTable, Integer.MAX_VALUE-1); BoundedPriorityQueue<ScoredObject<Integer>> queue = new BoundedPriorityQueue<ScoredObject<Integer>>(ScoredObject.SCORE_COMPARATOR, mK); for (int i = 0; i < mTrainingCategories.size(); ++i) { Integer catId = mTrainingCategories.get(i); SparseFloatVector trainingVector = mTrainingVectors.get(i); double score = mProximity.proximity(inputVector,trainingVector); queue.add(new ScoredObject<Integer>(catId,score)); } int numCats = mCategorySymbolTable.numSymbols(); double[] scores = new double[numCats]; for (ScoredObject<Integer> catScore : queue) { int key = catScore.getObject().intValue(); double score = catScore.score(); scores[key] += mWeightByProximity ? score : 1.0; } ScoredObject<String>[] categoryScores = (ScoredObject<String>[]) new ScoredObject[numCats]; for (int i = 0; i < numCats; ++i) categoryScores[i] = new ScoredObject<String>(mCategorySymbolTable.idToSymbol(i), scores[i]); return ScoredClassification.create(categoryScores); } Object writeReplace() { return new Serializer<E>(this); } /** * Compiles this k-nearest-neighbor classifier to the specified * object output stream. * * <p>This is only a convenience method. It provides exactly * the same function as standard serialization. * * @param out Output stream to which this classifier is written. * @throws IOException If there is an underlying I/O exception * during compilation. */ public void compileTo(ObjectOutput out) throws IOException { out.writeObject(writeReplace()); } static class Serializer<F> extends AbstractExternalizable { static final long serialVersionUID = 4951969636521202268L; final KnnClassifier<F> mClassifier; public Serializer() { this(null); } public Serializer(KnnClassifier<F> classifier) { mClassifier = classifier; } public void writeExternal(ObjectOutput out) throws IOException { AbstractExternalizable.serializeOrCompile(mClassifier.mFeatureExtractor,out); out.writeInt(mClassifier.mK); AbstractExternalizable.serializeOrCompile(mClassifier.mProximity,out); out.writeBoolean(mClassifier.mWeightByProximity); int numInstances = mClassifier.mTrainingCategories.size(); out.writeInt(numInstances); List<Integer> catList = mClassifier.mTrainingCategories; for (int i = 0; i < numInstances; ++i) out.writeInt(mClassifier.mTrainingCategories.get(i).intValue()); for (int i = 0; i < numInstances; ++i) AbstractExternalizable.serializeOrCompile(mClassifier.mTrainingVectors.get(i), out); AbstractExternalizable.serializeOrCompile(mClassifier.mFeatureSymbolTable,out); AbstractExternalizable.serializeOrCompile(mClassifier.mCategorySymbolTable,out); } public Object read(ObjectInput in) throws ClassNotFoundException, IOException { FeatureExtractor<? super F> featureExtractor = (FeatureExtractor<? super F>) in.readObject(); int k = in.readInt(); Proximity<Vector> proximity = (Proximity<Vector>) in.readObject(); boolean weightByProximity = in.readBoolean(); int numInstances = in.readInt(); List<Integer> categoryList = new ArrayList<Integer>(numInstances); for (int i = 0; i < numInstances; ++i) categoryList.add(new Integer(in.readInt())); List<SparseFloatVector> vectorList = new ArrayList<SparseFloatVector>(numInstances); for (int i = 0; i < numInstances; ++i) vectorList.add((SparseFloatVector) in.readObject()); MapSymbolTable featureSymbolTable = (MapSymbolTable) in.readObject(); MapSymbolTable categorySymbolTable = (MapSymbolTable) in.readObject(); return new KnnClassifier(featureExtractor,k, proximity,weightByProximity, categoryList,vectorList, featureSymbolTable,categorySymbolTable); } } static class ProximityWrapper implements Proximity<Vector>, Serializable { Distance<Vector> mDistance; public ProximityWrapper() { } public ProximityWrapper(Distance<Vector> distance) { mDistance = distance; } public double proximity(Vector v1, Vector v2) { double d = mDistance.distance(v1,v2); return (d < 0) ? Double.MAX_VALUE : (1.0/(1.0 + d)); } } static class TrainingInstance { final String mCategory; final SparseFloatVector mVector; TrainingInstance(String category, SparseFloatVector vector) { mCategory = category; mVector = vector; } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -