// ibk.java: excerpt from Weka's IBk (k-nearest-neighbours) classifier
  */
  public int getWindowSize() {
    return m_WindowSize;
  }

  /**
   * Sets the maximum number of instances allowed in the training
   * pool. The addition of new instances above this value will result
   * in old instances being removed. A value of 0 signifies no limit
   * to the number of training instances.
   *
   * @param newWindowSize Value to assign to WindowSize.
   */
  public void setWindowSize(int newWindowSize) {
    m_WindowSize = newWindowSize;
  }

  /**
   * Gets the distance weighting method used. Will be one of
   * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY.
   *
   * @return the distance weighting method used.
   */
  public SelectedTag getDistanceWeighting() {
    return new SelectedTag(m_DistanceWeighting, TAGS_WEIGHTING);
  }

  /**
   * Sets the distance weighting method used. Values other than
   * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY will be ignored.
   *
   * @param newMethod the distance weighting method to use
   */
  public void setDistanceWeighting(SelectedTag newMethod) {
    if (newMethod.getTags() == TAGS_WEIGHTING) {
      m_DistanceWeighting = newMethod.getSelectedTag().getID();
    }
  }

  /**
   * Gets whether the mean squared error is used rather than mean
   * absolute error when doing cross-validation.
   *
   * @return true if so.
   */
  public boolean getMeanSquared() {
    return m_MeanSquared;
  }

  /**
   * Sets whether the mean squared error is used rather than mean
   * absolute error when doing cross-validation.
   *
   * @param newMeanSquared true if so.
   */
  public void setMeanSquared(boolean newMeanSquared) {
    m_MeanSquared = newMeanSquared;
  }

  /**
   * Gets whether hold-one-out cross-validation will be used
   * to select the best k value.
   *
   * @return true if cross-validation will be used.
   */
  public boolean getCrossValidate() {
    return m_CrossValidate;
  }

  /**
   * Sets whether hold-one-out cross-validation will be used
   * to select the best k value.
   *
   * @param newCrossValidate true if cross-validation should be used.
   */
  public void setCrossValidate(boolean newCrossValidate) {
    m_CrossValidate = newCrossValidate;
  }

  /**
   * Gets the number of training instances the classifier is currently using.
   *
   * @return the number of training instances
   */
  public int getNumTraining() {
    return m_Train.numInstances();
  }

  /**
   * Gets an attribute's minimum observed value.
   *
   * @param index the index of the attribute
   * @return the minimum observed value
   * @exception Exception if the value is not available
   */
  public double getAttributeMin(int index) throws Exception {
    if (m_Min == null) {
      throw new Exception("Minimum value for attribute not available!");
    }
    return m_Min[index];
  }

  /**
   * Gets an attribute's maximum observed value.
   *
   * @param index the index of the attribute
   * @return the maximum observed value
   * @exception Exception if the value is not available
   */
  public double getAttributeMax(int index) throws Exception {
    if (m_Max == null) {
      throw new Exception("Maximum value for attribute not available!");
    }
    return m_Max[index];
  }

  /**
   * Gets whether normalization is turned off.
   *
   * @return Value of DontNormalize.
   */
  public boolean getNoNormalization() {
    return m_DontNormalize;
  }

  /**
   * Sets whether normalization is turned off.
   *
   * @param v Value to assign to DontNormalize.
   */
  public void setNoNormalization(boolean v) {
    m_DontNormalize = v;
  }
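  /*
   * Usage sketch (not part of the original source; variable names are
   * illustrative): configuring the properties above programmatically.
   * The SelectedTag construction mirrors the one used in setOptions below.
   *
   *   IBk knn = new IBk();
   *   knn.setKNN(5);                // consider up to 5 neighbours
   *   knn.setWindowSize(1000);      // keep at most 1000 training instances (FIFO)
   *   knn.setCrossValidate(true);   // pick the best k in 1..5 by hold-one-out CV
   *   knn.setDistanceWeighting(
   *       new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));  // 1/distance voting
   */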
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    if (instances.classIndex() < 0) {
      throw new Exception("No class attribute assigned to instances");
    }
    if (instances.checkForStringAttributes()) {
      throw new Exception("Can't handle string attributes!");
    }
    try {
      m_NumClasses = instances.numClasses();
      m_ClassType = instances.classAttribute().type();
    } catch (Exception ex) {
      throw new Error("This should never be reached");
    }

    // Throw away training instances with a missing class
    m_Train = new Instances(instances, 0, instances.numInstances());
    m_Train.deleteWithMissingClass();

    // Throw away initial instances until within the specified window size
    if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) {
      m_Train = new Instances(m_Train,
                              m_Train.numInstances() - m_WindowSize,
                              m_WindowSize);
    }

    // Calculate the minimum and maximum values
    if (m_DontNormalize) {
      m_Min = null;
      m_Max = null;
    } else {
      m_Min = new double[m_Train.numAttributes()];
      m_Max = new double[m_Train.numAttributes()];
      for (int i = 0; i < m_Train.numAttributes(); i++) {
        m_Min[i] = m_Max[i] = Double.NaN;
      }
      Enumeration en = m_Train.enumerateInstances();
      while (en.hasMoreElements()) {
        updateMinMax((Instance) en.nextElement());
      }
    }

    // Compute the number of attributes that contribute to each prediction
    m_NumAttributesUsed = 0.0;
    for (int i = 0; i < m_Train.numAttributes(); i++) {
      if ((i != m_Train.classIndex())
          && (m_Train.attribute(i).isNominal()
              || m_Train.attribute(i).isNumeric())) {
        m_NumAttributesUsed += 1.0;
      }
    }

    // Invalidate any k selected by a previous cross-validation
    m_kNNValid = false;
  }

  /**
   * Adds the supplied instance to the training set.
   *
   * @param instance the instance to add
   * @exception Exception if the instance could not be incorporated
   * successfully
   */
  public void updateClassifier(Instance instance) throws Exception {

    if (!m_Train.equalHeaders(instance.dataset())) {
      throw new Exception("Incompatible instance types");
    }
    if (instance.classIsMissing()) {
      return;
    }
    if (!m_DontNormalize) {
      updateMinMax(instance);
    }
    m_Train.add(instance);
    m_kNNValid = false;
    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
      while (m_Train.numInstances() > m_WindowSize) {
        m_Train.delete(0);
      }
    }
  }
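  /*
   * Usage sketch (not in the original source): batch training followed by an
   * incremental update. Assumes "data" is a weka.core.Instances object whose
   * class index has been set, and "newInstance" is a compatible Instance.
   *
   *   knn.buildClassifier(data);          // copies the data, drops missing-class
   *                                       // rows, trims to the window size
   *   knn.updateClassifier(newInstance);  // adds one instance; the oldest rows
   *                                       // are deleted once the window overflows
   */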
  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if an error occurred during the prediction
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    if (m_Train.numInstances() == 0) {
      throw new Exception("No training instances!");
    }
    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
      m_kNNValid = false;
      while (m_Train.numInstances() > m_WindowSize) {
        m_Train.delete(0);
      }
    }

    // Select k by cross-validation
    if (!m_kNNValid && m_CrossValidate && (m_kNN > 1)) {
      crossValidate();
    }

    if (!m_DontNormalize) {
      updateMinMax(instance);
    }
    NeighborList neighborlist = findNeighbors(instance);
    return makeDistribution(neighborlist);
  }
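  /*
   * Usage sketch (hypothetical names): classifying a test instance. For a
   * nominal class the returned array holds one probability per class value,
   * so weka.core.Utils.maxIndex picks the predicted class index.
   *
   *   double[] dist = knn.distributionForInstance(testInstance);
   *   int predicted = Utils.maxIndex(dist);
   *
   * The configuration shown earlier can also be expressed through the option
   * interface defined below, e.g.
   *
   *   knn.setOptions(Utils.splitOptions("-K 5 -X -D -W 1000"));
   */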
  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(7);

    newVector.addElement(new Option(
        "\tWeight neighbours by the inverse of their distance\n"
        + "\t(use when k > 1)",
        "D", 0, "-D"));
    newVector.addElement(new Option(
        "\tWeight neighbours by 1 - their distance\n"
        + "\t(use when k > 1)",
        "F", 0, "-F"));
    newVector.addElement(new Option(
        "\tNumber of nearest neighbours (k) used in classification.\n"
        + "\t(Default = 1)",
        "K", 1, "-K <number of neighbors>"));
    newVector.addElement(new Option(
        "\tMinimise mean squared error rather than mean absolute\n"
        + "\terror when using -X option with numeric prediction.",
        "S", 0, "-S"));
    newVector.addElement(new Option(
        "\tMaximum number of training instances maintained.\n"
        + "\tTraining instances are dropped FIFO. (Default = no window)",
        "W", 1, "-W <window size>"));
    newVector.addElement(new Option(
        "\tSelect the number of nearest neighbours between 1\n"
        + "\tand the k value specified using hold-one-out evaluation\n"
        + "\ton the training data (use when k > 1)",
        "X", 0, "-X"));
    newVector.addElement(new Option(
        "\tDon't normalize the data.\n",
        "N", 0, "-N"));

    return newVector.elements();
  }

  /**
   * Parses a given list of options. Valid options are:<p>
   *
   * -K num <br>
   * Set the number of nearest neighbours to use in prediction
   * (default 1) <p>
   *
   * -W num <br>
   * Set a fixed window size for incremental training/testing. As
   * new training instances are added, the oldest instances are removed
   * to maintain the number of training instances at this size.
   * (default no window) <p>
   *
   * -D <br>
   * Neighbours will be weighted by the inverse of their distance
   * when voting. (default equal weighting) <p>
   *
   * -F <br>
   * Neighbours will be weighted by their similarity when voting.
   * (default equal weighting) <p>
   *
   * -X <br>
   * Select the number of neighbours to use by hold-one-out
   * cross-validation, with an upper limit given by the -K option. <p>
   *
   * -S <br>
   * When k is selected by cross-validation for numeric class attributes,
   * minimize mean squared error. (default mean absolute error) <p>
   *
   * -N <br>
   * Don't normalize the data. <p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String knnString = Utils.getOption('K', options);
    if (knnString.length() != 0) {
      setKNN(Integer.parseInt(knnString));
    } else {
      setKNN(1);
    }
    String windowString = Utils.getOption('W', options);
    if (windowString.length() != 0) {
      setWindowSize(Integer.parseInt(windowString));
    } else {
      setWindowSize(0);
    }
    if (Utils.getFlag('D', options)) {
      setDistanceWeighting(new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));
    } else if (Utils.getFlag('F', options)) {
      setDistanceWeighting(new SelectedTag(WEIGHT_SIMILARITY, TAGS_WEIGHTING));
    } else {
      setDistanceWeighting(new SelectedTag(WEIGHT_NONE, TAGS_WEIGHTING));
    }
    setCrossValidate(Utils.getFlag('X', options));
    setMeanSquared(Utils.getFlag('S', options));
    setNoNormalization(Utils.getFlag('N', options));
    Utils.checkForRemainingOptions(options);
  }

  /**
   * Gets the current settings of IBk.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions() {

    String[] options = new String[11];
    int current = 0;
    options[current++] = "-K";
    options[current++] = "" + getKNN();
    options[current++] = "-W";
    options[current++] = "" + m_WindowSize;
    if (getCrossValidate()) {
      options[current++] = "-X";
    }
    if (getMeanSquared()) {
      options[current++] = "-S";
    }
    if (m_DistanceWeighting == WEIGHT_INVERSE) {
      options[current++] = "-D";
    } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {
      options[current++] = "-F";
    }
    if (m_DontNormalize) {
      options[current++] = "-N";
    }
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Returns a description of this classifier.
   *
   * @return a description of this classifier as a string.
   */
  public String toString() {