
📄 knnclassifier.java

📁 A Java open-source toolkit for natural language processing. LingPipe currently offers a rich set of features.
💻 JAVA
     *
     * @param featureExtractor Feature extractor for training and
     * classification instances.
     * @param k Maximum number of neighbors to use during
     * classification.
     * @param distance Distance function to use to compare examples.
     */
    public KnnClassifier(FeatureExtractor<? super E> featureExtractor,
                         int k,
                         Distance<Vector> distance) {
        this(featureExtractor,k,new ProximityWrapper(distance),false);
    }

    /**
     * Construct a k-nearest-neighbor classifier based on the
     * specified feature extractor, specified maximum number of
     * neighbors, specified proximity function, and boolean flag
     * indicating whether or not to weight nearest neighbors by proximity
     * during classification.
     *
     * @param featureExtractor Feature extractor for training and
     * classification instances.
     * @param k Maximum number of neighbors to use during
     * classification.
     * @param proximity Proximity function to compare examples.
     * @param weightByProximity Flag indicating whether to weight
     * neighbors by proximity during classification.
     */
    public KnnClassifier(FeatureExtractor<? super E> featureExtractor,
                         int k,
                         Proximity<Vector> proximity,
                         boolean weightByProximity) {
        mFeatureExtractor = featureExtractor;
        mK = k;
        mProximity = proximity;
        mWeightByProximity = weightByProximity;
        mFeatureSymbolTable = new MapSymbolTable();
        mCategorySymbolTable = new MapSymbolTable();
        mTrainingCategories = new ArrayList<Integer>();
        mTrainingVectors = new ArrayList<SparseFloatVector>();
    }

    /**
     * Handle the specified classified training instance.  The
     * training instance is converted to a feature vector using the
     * feature extractor, and then stored as a sparse vector relative
     * to a feature symbol table.
     *
     * @param trainingInstance Object being classified during training.
     * @param classification Classification for specified object.
     */
    public void handle(E trainingInstance, Classification classification) {
        String category = classification.bestCategory();
        Map<String,? extends Number> featureMap
            = mFeatureExtractor.features(trainingInstance);
        SparseFloatVector vector
            = PerceptronClassifier
            .toVectorAddSymbols(featureMap,
                                mFeatureSymbolTable,
                                Integer.MAX_VALUE-1);
        mTrainingCategories.add(mCategorySymbolTable.getOrAddSymbolInteger(category));
        mTrainingVectors.add(vector);
    }

    /**
     * Return the k-nearest-neighbor classification result for the
     * specified input object.  The resulting classification will have
     * all of the categories defined, though those with no support in
     * the nearest neighbors will have scores of zero.
     *
     * <p>If this classifier does not weight by proximity, the
     * resulting score for a category will be the number of nearest
     * neighbors of the specified category.  That is, it will be a
     * straight vote.
     *
     * <p>If the classifier does weight by proximity, the resulting
     * score for a category will be the sum of the proximity scores
     * for the nearest neighbors of a given category.  Instances with
     * no near neighbors will be scored zero (<code>0</code>).  Thus
     * proximities should be configured to return positive values.
     *
     * @param in Object to classify.
     * @return Scored classification for the specified object.
     */
    public ScoredClassification classify(E in) {
        Map<String,? extends Number> featureMap
            = mFeatureExtractor.features(in);
        SparseFloatVector inputVector
            = PerceptronClassifier
            .toVector(featureMap,
                      mFeatureSymbolTable,
                      Integer.MAX_VALUE-1);
        BoundedPriorityQueue<ScoredObject<Integer>> queue
            = new BoundedPriorityQueue<ScoredObject<Integer>>(ScoredObject.SCORE_COMPARATOR,
                                                              mK);
        for (int i = 0; i < mTrainingCategories.size(); ++i) {
            Integer catId = mTrainingCategories.get(i);
            SparseFloatVector trainingVector = mTrainingVectors.get(i);
            double score = mProximity.proximity(inputVector,trainingVector);
            queue.add(new ScoredObject<Integer>(catId,score));
        }
        int numCats = mCategorySymbolTable.numSymbols();
        double[] scores = new double[numCats];
        for (ScoredObject<Integer> catScore : queue) {
            int key = catScore.getObject().intValue();
            double score = catScore.score();
            scores[key] += mWeightByProximity ? score : 1.0;
        }
        ScoredObject<String>[] categoryScores
            = (ScoredObject<String>[]) new ScoredObject[numCats];
        for (int i = 0; i < numCats; ++i)
            categoryScores[i]
                = new ScoredObject<String>(mCategorySymbolTable.idToSymbol(i),
                                           scores[i]);
        return ScoredClassification.create(categoryScores);
    }

    Object writeReplace() {
        return new Serializer<E>(this);
    }

    /**
     * Compiles this k-nearest-neighbor classifier to the specified
     * object output stream.
     *
     * <p>This is only a convenience method.  It provides exactly
     * the same function as standard serialization.
     *
     * @param out Output stream to which this classifier is written.
     * @throws IOException If there is an underlying I/O exception
     * during compilation.
     */
    public void compileTo(ObjectOutput out) throws IOException {
        out.writeObject(writeReplace());
    }

    static class Serializer<F> extends AbstractExternalizable {
        static final long serialVersionUID = 4951969636521202268L;
        final KnnClassifier<F> mClassifier;
        public Serializer() {
            this(null);
        }
        public Serializer(KnnClassifier<F> classifier) {
            mClassifier = classifier;
        }
        public void writeExternal(ObjectOutput out) throws IOException {
            AbstractExternalizable.serializeOrCompile(mClassifier.mFeatureExtractor,out);
            out.writeInt(mClassifier.mK);
            AbstractExternalizable.serializeOrCompile(mClassifier.mProximity,out);
            out.writeBoolean(mClassifier.mWeightByProximity);
            int numInstances = mClassifier.mTrainingCategories.size();
            out.writeInt(numInstances);
            List<Integer> catList = mClassifier.mTrainingCategories;
            for (int i = 0; i < numInstances; ++i)
                out.writeInt(mClassifier.mTrainingCategories.get(i).intValue());
            for (int i = 0; i < numInstances; ++i)
                AbstractExternalizable.serializeOrCompile(mClassifier.mTrainingVectors.get(i),
                                                          out);
            AbstractExternalizable.serializeOrCompile(mClassifier.mFeatureSymbolTable,out);
            AbstractExternalizable.serializeOrCompile(mClassifier.mCategorySymbolTable,out);
        }
        public Object read(ObjectInput in)
            throws ClassNotFoundException, IOException {
            FeatureExtractor<? super F> featureExtractor
                = (FeatureExtractor<? super F>) in.readObject();
            int k = in.readInt();
            Proximity<Vector> proximity
                = (Proximity<Vector>) in.readObject();
            boolean weightByProximity = in.readBoolean();
            int numInstances = in.readInt();
            List<Integer> categoryList = new ArrayList<Integer>(numInstances);
            for (int i = 0; i < numInstances; ++i)
                categoryList.add(new Integer(in.readInt()));
            List<SparseFloatVector> vectorList
                = new ArrayList<SparseFloatVector>(numInstances);
            for (int i = 0; i < numInstances; ++i)
                vectorList.add((SparseFloatVector) in.readObject());
            MapSymbolTable featureSymbolTable
                = (MapSymbolTable) in.readObject();
            MapSymbolTable categorySymbolTable
                = (MapSymbolTable) in.readObject();
            return new KnnClassifier(featureExtractor,k,
                                     proximity,weightByProximity,
                                     categoryList,vectorList,
                                     featureSymbolTable,categorySymbolTable);
        }
    }

    static class ProximityWrapper
        implements Proximity<Vector>, Serializable {
        Distance<Vector> mDistance;
        public ProximityWrapper() { }
        public ProximityWrapper(Distance<Vector> distance) {
            mDistance = distance;
        }
        public double proximity(Vector v1, Vector v2) {
            double d = mDistance.distance(v1,v2);
            return (d < 0) ? Double.MAX_VALUE : (1.0/(1.0 + d));
        }
    }

    static class TrainingInstance {
        final String mCategory;
        final SparseFloatVector mVector;
        TrainingInstance(String category, SparseFloatVector vector) {
            mCategory = category;
            mVector = vector;
        }
    }
}
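
For context, here is a minimal usage sketch of the classifier above: train with handle(), then score a new input with classify(). It assumes a LingPipe-style setup with TokenFeatureExtractor over IndoEuropeanTokenizerFactory for token-count features and EuclideanDistance as the distance function (none of these appear in this file); the KnnDemo class name and the training strings are invented for illustration.

// Minimal usage sketch; class paths assume the LingPipe 3.x package layout.
import com.aliasi.classify.Classification;
import com.aliasi.classify.KnnClassifier;
import com.aliasi.classify.ScoredClassification;
import com.aliasi.matrix.EuclideanDistance;
import com.aliasi.tokenizer.IndoEuropeanTokenizerFactory;
import com.aliasi.tokenizer.TokenFeatureExtractor;
import com.aliasi.util.FeatureExtractor;

public class KnnDemo {
    public static void main(String[] args) {
        // Map each CharSequence to token-count features.
        FeatureExtractor<CharSequence> extractor
            = new TokenFeatureExtractor(new IndoEuropeanTokenizerFactory());

        // k = 3 neighbors; the Distance-based constructor wraps the distance
        // as a proximity and leaves weightByProximity = false, so scores are
        // straight votes over the nearest neighbors.
        KnnClassifier<CharSequence> classifier
            = new KnnClassifier<CharSequence>(extractor, 3, new EuclideanDistance());

        // Training: each handle() call stores one labeled feature vector.
        classifier.handle("cheap meds buy now", new Classification("spam"));
        classifier.handle("lunch at noon tomorrow", new Classification("ham"));
        classifier.handle("win a free prize today", new Classification("spam"));

        // Classification: every known category receives a score, zero if it
        // has no support among the k nearest neighbors.
        ScoredClassification result = classifier.classify("free meds today");
        System.out.println("best category: " + result.bestCategory());
    }
}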
