📄 ibkmetric.java
字号:
/** * Set the value of Debug. * * @param newDebug Value to assign to Debug. */ public void setDebug(boolean newDebug) { m_Debug = newDebug; } /** * Set the number of neighbours the learner is to use. * * @param k the number of neighbours. */ public void setKNN(int k) { m_kNN = k; m_kNNUpper = k; m_kNNValid = false; } /** * Gets the number of neighbours the learner will use. * * @return the number of neighbours. */ public int getKNN() { return m_kNN; } /** * Gets the maximum number of instances allowed in the training * pool. The addition of new instances above this value will result * in old instances being removed. A value of 0 signifies no limit * to the number of training instances. * * @return Value of WindowSize. */ public int getWindowSize() { return m_WindowSize; } /** * Sets the maximum number of instances allowed in the training * pool. The addition of new instances above this value will result * in old instances being removed. A value of 0 signifies no limit * to the number of training instances. * * @param newWindowSize Value to assign to WindowSize. */ public void setWindowSize(int newWindowSize) { m_WindowSize = newWindowSize; } /** * Gets the distance weighting method used. Will be one of * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY * * @return the distance weighting method used. */ public SelectedTag getDistanceWeighting() { return new SelectedTag(m_DistanceWeighting, TAGS_WEIGHTING); } /** * Sets the distance weighting method used. Values other than * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY will be ignored. * * @param newDistanceWeighting the distance weighting method to use */ public void setDistanceWeighting(SelectedTag newMethod) { if (newMethod.getTags() == TAGS_WEIGHTING) { m_DistanceWeighting = newMethod.getSelectedTag().getID(); } } /** * Gets whether the mean squared error is used rather than mean * absolute error when doing cross-validation. * * @return true if so. */ public boolean getMeanSquared() { return m_MeanSquared; } /** * Sets whether the mean squared error is used rather than mean * absolute error when doing cross-validation. * * @param newMeanSquared true if so. */ public void setMeanSquared(boolean newMeanSquared) { m_MeanSquared = newMeanSquared; } /** * Gets whether hold-one-out cross-validation will be used * to select the best k value * * @return true if cross-validation will be used. */ public boolean getCrossValidate() { return m_CrossValidate; } /** * Sets whether hold-one-out cross-validation will be used * to select the best k value * * @param newCrossValidate true if cross-validation should be used. */ public void setCrossValidate(boolean newCrossValidate) { m_CrossValidate = newCrossValidate; } /** * Get the number of training instances the classifier is currently using */ public int getNumTraining() { return m_Train.numInstances(); } /** * Get an attributes minimum observed value */ public double getAttributeMin(int index) throws Exception { if (m_Min == null) { throw new Exception("Minimum value for attribute not available!"); } return m_Min[index]; } /** * Get an attributes maximum observed value */ public double getAttributeMax(int index) throws Exception { if (m_Max == null) { throw new Exception("Maximum value for attribute not available!"); } return m_Max[index]; } /** * Generates the classifier. * * @param instances set of instances serving as training data * @exception Exception if the classifier has not been generated successfully */ public void buildClassifier(Instances instances) throws Exception { if (instances.classIndex() < 0) { throw new Exception ("No class attribute assigned to instances"); } if (instances.checkForStringAttributes()) { throw new UnsupportedAttributeTypeException("Cannot handle string attributes!"); } try { m_NumClasses = instances.numClasses(); m_ClassType = instances.classAttribute().type(); } catch (Exception ex) { throw new Error("This should never be reached"); } // Throw away training instances with missing class m_Train = new Instances(instances, 0, instances.numInstances()); m_Train.deleteWithMissingClass(); // Throw away initial instances until within the specified window size if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) { m_Train = new Instances(m_Train, m_Train.numInstances()-m_WindowSize, m_WindowSize); } // Compute the number of attributes that contribute // to each prediction m_NumAttributesUsed = 0.0; for (int i = 0; i < m_Train.numAttributes(); i++) { if ((i != m_Train.classIndex()) && (m_Train.attribute(i).isNominal() || m_Train.attribute(i).isNumeric())) { m_NumAttributesUsed += 1.0; } } // Invalidate any currently cross-validation selected k m_kNNValid = false; // Train the distance metric m_metric.buildMetric(instances); } /** * Adds the supplied instance to the training set * * @param instance the instance to add * @exception Exception if instance could not be incorporated * successfully */ public void updateClassifier(Instance instance) throws Exception { if (m_Train.equalHeaders(instance.dataset()) == false) { throw new Exception("Incompatible instance types"); } if (instance.classIsMissing()) { return; } m_Train.add(instance); m_kNNValid = false; if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) { while (m_Train.numInstances() > m_WindowSize) { m_Train.delete(0); } } } /** * Calculates the class membership probabilities for the given test instance. * * @param instance the instance to be classified * @return predicted class probability distribution * @exception Exception if an error occurred during the prediction */ public double [] distributionForInstance(Instance instance) throws Exception{ if (m_Train.numInstances() == 0) { throw new Exception("No training instances!"); } if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) { m_kNNValid = false; while (m_Train.numInstances() > m_WindowSize) { m_Train.delete(0); } } // Select k by cross validation if (!m_kNNValid && (m_CrossValidate) && (m_kNN > 1)) { crossValidate(); } NeighborList neighborlist = findNeighbors(instance); return makeDistribution(neighborlist); } /** * Set the distance metric * * @param s the metric */ public void setMetric (Metric m) { m_metric = m; m_MetricName = m_metric.getClass().getName(); } /** * Get the distance metric * * @returns the distance metric used */ public Metric getMetric () { return m_metric; } /** * Set the distance metric * * @param metricName the name of the distance metric that should be used */ public void setMetricName (String metricName) { try { m_MetricName = metricName; m_metric = (Metric) Class.forName(m_MetricName).newInstance(); } catch (Exception e) { System.err.println("Error instantiating metric " + m_MetricName); } } /** * Get the name of the distance metric that is used * Avoid the 'get' prefix so that this doesn't show in the dialogs * * @returns the name of the distance metric */ public String metricName () { return m_MetricName; } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(8); newVector.addElement(new Option( "\tWeight neighbours by the inverse of their distance\n" +"\t(use when k > 1)", "D", 0, "-D")); newVector.addElement(new Option( "\tWeight neighbours by 1 - their distance\n" +"\t(use when k > 1)", "F", 0, "-F")); newVector.addElement(new Option( "\tNumber of nearest neighbours (k) used in classification.\n" +"\t(Default = 1)", "K", 1,"-K <number of neighbors>")); newVector.addElement(new Option( "\tMinimise mean squared error rather than mean absolute\n" +"\terror when using -X option with numeric prediction.", "S", 0,"-S")); newVector.addElement(new Option( "\tMaximum number of training instances maintained.\n" +"\tTraining instances are dropped FIFO. (Default = no window)", "W", 1,"-W <window size>")); newVector.addElement(new Option( "\tSelect the number of nearest neighbours between 1\n" +"\tand the k value specified using hold-one-out evaluation\n" +"\ton the training data (use when k > 1)", "X", 0,"-X")); newVector.addElement(new Option( "\tUse a specific distance metric. (Default=WeightedDotP)\n", "M", 1, "-M")); return newVector.elements(); } /** * Parses a given list of options. Valid options are:<p> * * -K num <br> * Set the number of nearest neighbors to use in prediction * (default 1) <p> * * -W num <br> * Set a fixed window size for incremental train/testing. As * new training instances are added, oldest instances are removed * to maintain the number of training instances at this size. * (default no window) <p> * * -D <br> * Neighbors will be weighted by the inverse of their distance * when voting. (default equal weighting) <p> * * -F <br> * Neighbors will be weighted by their similarity when voting. * (default equal weighting) <p> *
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -