IBk.java
  /**
   * Gets the maximum number of instances allowed in the training
   * pool. The addition of new instances above this value will result
   * in old instances being removed. A value of 0 signifies no limit
   * to the number of training instances.
   *
   * @return Value of WindowSize
   */
  public int getWindowSize() {
    return m_WindowSize;
  }

  /**
   * Sets the maximum number of instances allowed in the training
   * pool. The addition of new instances above this value will result
   * in old instances being removed. A value of 0 signifies no limit
   * to the number of training instances.
   *
   * @param newWindowSize Value to assign to WindowSize.
   */
  public void setWindowSize(int newWindowSize) {
    m_WindowSize = newWindowSize;
  }

  /**
   * Gets the distance weighting method used. Will be one of
   * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY.
   *
   * @return the distance weighting method used.
   */
  public SelectedTag getDistanceWeighting() {
    return new SelectedTag(m_DistanceWeighting, TAGS_WEIGHTING);
  }

  /**
   * Sets the distance weighting method used. Values other than
   * WEIGHT_NONE, WEIGHT_INVERSE, or WEIGHT_SIMILARITY will be ignored.
   *
   * @param newMethod the distance weighting method to use
   */
  public void setDistanceWeighting(SelectedTag newMethod) {
    if (newMethod.getTags() == TAGS_WEIGHTING) {
      m_DistanceWeighting = newMethod.getSelectedTag().getID();
    }
  }

  /**
   * Gets whether the mean squared error is used rather than mean
   * absolute error when doing cross-validation.
   *
   * @return true if so.
   */
  public boolean getMeanSquared() {
    return m_MeanSquared;
  }

  /**
   * Sets whether the mean squared error is used rather than mean
   * absolute error when doing cross-validation.
   *
   * @param newMeanSquared true if so.
   */
  public void setMeanSquared(boolean newMeanSquared) {
    m_MeanSquared = newMeanSquared;
  }

  /**
   * Gets whether hold-one-out cross-validation will be used
   * to select the best k value.
   *
   * @return true if cross-validation will be used.
   */
  public boolean getCrossValidate() {
    return m_CrossValidate;
  }

  /**
   * Sets whether hold-one-out cross-validation will be used
   * to select the best k value.
   *
   * @param newCrossValidate true if cross-validation should be used.
   */
  public void setCrossValidate(boolean newCrossValidate) {
    m_CrossValidate = newCrossValidate;
  }

  /**
   * Gets the number of training instances the classifier is currently using.
   */
  public int getNumTraining() {
    return m_Train.numInstances();
  }

  /**
   * Gets an attribute's minimum observed value.
   */
  public double getAttributeMin(int index) throws Exception {
    if (m_Ranges == null) {
      throw new Exception("Minimum value for attribute not available!");
    }
    return m_Ranges[index][R_MIN];
  }

  /**
   * Gets an attribute's maximum observed value.
   */
  public double getAttributeMax(int index) throws Exception {
    if (m_Ranges == null) {
      throw new Exception("Maximum value for attribute not available!");
    }
    return m_Ranges[index][R_MAX];
  }

  /**
   * Gets whether normalization is turned off.
   *
   * @return Value of DontNormalize.
   */
  public boolean getNoNormalization() {
    return m_DontNormalize;
  }

  /**
   * Sets whether normalization is turned off.
   *
   * @param v Value to assign to DontNormalize.
   */
  public void setNoNormalization(boolean v) {
    m_DontNormalize = v;
  }
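  // Usage sketch for the weighting accessors (added illustration, not part of
  // the original source): the weighting mode travels inside a SelectedTag
  // built from the TAGS_WEIGHTING array, e.g.
  //
  //   ibk.setDistanceWeighting(new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));
  //
  // A tag taken from any other tag array fails the getTags() identity check
  // in the setter above and is silently ignored, leaving the current
  // weighting unchanged.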
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    if (instances.classIndex() < 0) {
      throw new Exception("No class attribute assigned to instances");
    }
    if (instances.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
    }
    try {
      m_NumClasses = instances.numClasses();
      m_ClassType = instances.classAttribute().type();
    } catch (Exception ex) {
      throw new Error("This should never be reached");
    }

    // Throw away training instances with missing class
    m_Train = new Instances(instances, 0, instances.numInstances());
    m_Train.deleteWithMissingClass();

    // Throw away initial instances until within the specified window size
    if ((m_WindowSize > 0) && (instances.numInstances() > m_WindowSize)) {
      m_Train = new Instances(m_Train,
                              m_Train.numInstances() - m_WindowSize,
                              m_WindowSize);
    }

    // Compute ranges if needed for normalization and/or for the KDTree
    if ((!m_DontNormalize) || (m_KDTree != null)) {
      // Initializes and calculates the ranges for the training instances
      m_Ranges = m_Train.initializeRanges();
      // Instances.printRanges(m_Ranges);
    }

    // If there are already some instances, build the KDTree
    if ((m_KDTree != null) && (m_Train.numInstances() > 0)) {
      m_KDTree.buildKDTree(m_Train);
      OOPS("KDTree built in buildClassifier");
      OOPS("  " + m_KDTree.toString());
    }

    // Compute the number of attributes that contribute
    // to each prediction
    m_NumAttributesUsed = 0.0;
    for (int i = 0; i < m_Train.numAttributes(); i++) {
      if ((i != m_Train.classIndex())
          && (m_Train.attribute(i).isNominal() || m_Train.attribute(i).isNumeric())) {
        m_NumAttributesUsed += 1.0;
      }
    }

    // Invalidate any k currently selected by cross-validation
    m_kNNValid = false;
  }

  /**
   * Adds the supplied instance to the training set.
   *
   * @param instance the instance to add
   * @exception Exception if instance could not be incorporated successfully
   */
  public void updateClassifier(Instance instance) throws Exception {

    if (m_Train.equalHeaders(instance.dataset()) == false) {
      throw new Exception("Incompatible instance types");
    }
    if (instance.classIsMissing()) {
      return;
    }

    // Update ranges, but only if the normalize flag is on or a KDTree is chosen
    if ((!m_DontNormalize) || (m_KDTree != null)) {
      m_Ranges = Instances.updateRanges(instance, m_Ranges);
    }

    // Add instance to the training set
    m_Train.add(instance);

    // Update the KDTree
    if (m_KDTree != null) {
      if (m_KDTree.isValid() && (m_KDTree.numInstances() > 0)) {
        m_KDTree.updateKDTree(instance);
      }
    }

    m_kNNValid = false;
    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
      while (m_Train.numInstances() > m_WindowSize) {
        m_Train.delete(0);
        if (m_KDTree != null) {
          m_KDTree.setValid(false);
        }
      }
    }
  }
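  // Incremental-training sketch (added illustration; assumes the standard Weka
  // updateable-classifier idiom, not code from this file): build on an empty
  // header-only copy, then feed instances one at a time.
  //
  //   ibk.buildClassifier(new Instances(data, 0));  // header only, no rows
  //   for (int i = 0; i < data.numInstances(); i++) {
  //     ibk.updateClassifier(data.instance(i));     // FIFO window drops oldest
  //   }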
  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if an error occurred during the prediction
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    if (m_Train.numInstances() == 0) {
      throw new Exception("No training instances!");
    }

    // Cut training instances down to the window size
    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
      m_kNNValid = false;
      while (m_Train.numInstances() > m_WindowSize) {
        m_Train.delete(0);
        m_KDTree.setValid(false);
      }
    }

    if ((m_KDTree != null) && (!m_KDTree.isValid())) {
      m_KDTree.buildKDTree(m_Train);
      // OOPS("KDTree built in distributionForInstance");
      // OOPS("  " + m_KDTree.toString());
    }

    // Select k by cross-validation
    if (!m_kNNValid && (m_CrossValidate) && (m_kNN > 1)) {
      crossValidate();
    }

    // Update ranges - for the norm() method
    if (!m_DontNormalize) {
      m_Ranges = Instances.updateRanges(instance, m_Ranges);
    }

    // Update ranges for the norm() method in the Distance class of KDTree
    if (m_KDTree != null) {
      m_KDTree.addLooslyInstance(instance);
    }

    // Find neighbours and make the distribution
    NeighbourList neighbourlist = findNeighbours(instance);
    return makeDistribution(neighbourlist);
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(9);

    newVector.addElement(new Option(
        "\tWeight neighbours by the inverse of their distance\n"
        + "\t(use when k > 1)",
        "D", 0, "-D"));
    newVector.addElement(new Option(
        "\tWeight neighbours by 1 - their distance\n"
        + "\t(use when k > 1)",
        "F", 0, "-F"));
    newVector.addElement(new Option(
        "\tNumber of nearest neighbours (k) used in classification.\n"
        + "\t(Default = 1)",
        "K", 1, "-K <number of neighbours>"));
    newVector.addElement(new Option(
        "\tMinimise mean squared error rather than mean absolute\n"
        + "\terror when using -X option with numeric prediction.",
        "S", 0, "-S"));
    newVector.addElement(new Option(
        "\tMaximum number of training instances maintained.\n"
        + "\tTraining instances are dropped FIFO. (Default = no window)",
        "W", 1, "-W <window size>"));
    newVector.addElement(new Option(
        "\tSelect the number of nearest neighbours between 1\n"
        + "\tand the k value specified using hold-one-out evaluation\n"
        + "\ton the training data (use when k > 1)",
        "X", 0, "-X"));
    newVector.addElement(new Option(
        "\tDon't normalize the data.\n",
        "N", 0, "-N"));
    newVector.addElement(new Option(
        "\tFull class name of KDTree class to use, followed\n"
        + "\tby scheme options.\n"
        + "\teg: \"weka.core.KDTree -P\"\n"
        + "\t(default = no KDTree class used).",
        "E", 1, "-E <KDTree class specification>"));

    return newVector.elements();
  }
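  // Example option string (added illustration): the flags documented above can
  // be combined and handed to setOptions(), e.g.
  //
  //   ibk.setOptions(Utils.splitOptions("-K 10 -X -D -W 500"));
  //
  // which selects k <= 10 by hold-one-out cross-validation, weights votes by
  // inverse distance, and caps the training pool at 500 instances.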
  /**
   * Parses a given list of options. Valid options are:<p>
   *
   * -K num <br>
   * Set the number of nearest neighbours to use in prediction
   * (default 1) <p>
   *
   * -W num <br>
   * Set a fixed window size for incremental train/testing. As
   * new training instances are added, oldest instances are removed
   * to maintain the number of training instances at this size.
   * (default no window) <p>
   *
   * -D <br>
   * Neighbours will be weighted by the inverse of their distance
   * when voting. (default equal weighting) <p>
   *
   * -F <br>
   * Neighbours will be weighted by their similarity when voting.
   * (default equal weighting) <p>
   *
   * -X <br>
   * Select the number of neighbours to use by hold-one-out cross
   * validation, with an upper limit given by the -K option. <p>
   *
   * -S <br>
   * When k is selected by cross-validation for numeric class attributes,
   * minimize mean-squared error. (default mean absolute error) <p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String knnString = Utils.getOption('K', options);
    if (knnString.length() != 0) {
      setKNN(Integer.parseInt(knnString));
    } else {
      setKNN(1);
    }

    String windowString = Utils.getOption('W', options);
    if (windowString.length() != 0) {
      setWindowSize(Integer.parseInt(windowString));
    } else {
      setWindowSize(0);
    }

    if (Utils.getFlag('D', options)) {
      setDistanceWeighting(new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));
    } else if (Utils.getFlag('F', options)) {
      setDistanceWeighting(new SelectedTag(WEIGHT_SIMILARITY, TAGS_WEIGHTING));
    } else {
      setDistanceWeighting(new SelectedTag(WEIGHT_NONE, TAGS_WEIGHTING));
    }

    setCrossValidate(Utils.getFlag('X', options));
    setMeanSquared(Utils.getFlag('S', options));
    setNoNormalization(Utils.getFlag('N', options));

    String funcString = Utils.getOption('E', options);
    if (funcString.length() != 0) {
      String[] funcSpec = Utils.splitOptions(funcString);
      if (funcSpec.length == 0) {
        throw new Exception("Invalid function specification string");
      }
      String funcName = funcSpec[0];
      funcSpec[0] = "";
      setKDTree((KDTree) Utils.forName(KDTree.class, funcName, funcSpec));
    }

    Utils.checkForRemainingOptions(options);
  }

  /**
   * Gets the current settings of IBk.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions() {

    String[] options = new String[11];
    int current = 0;

    options[current++] = "-K";
    options[current++] = "" + getKNN();
    options[current++] = "-W";
    options[current++] = "" + m_WindowSize;
    if (getCrossValidate()) {
      options[current++] = "-X";
    }
    if (getMeanSquared()) {
      options[current++] = "-S";
    }
    if (m_DistanceWeighting == WEIGHT_INVERSE) {
      options[current++] = "-D";
    } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {
      options[current++] = "-F";
    }
    if (m_DontNormalize) {
      options[current++] = "-N";
    }
    // Pad the remaining slots so every array element is initialized
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }
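A minimal end-to-end sketch of driving this classifier, assuming the standard Weka API (weka.core.Instances, weka.core.SelectedTag) and that WEIGHT_INVERSE and TAGS_WEIGHTING are publicly visible, as their use in setOptions() suggests. The file name iris.arff is a placeholder, and the KDTree option (-E) is left at its default:

  import java.io.BufferedReader;
  import java.io.FileReader;

  import weka.core.Instance;
  import weka.core.Instances;
  import weka.core.SelectedTag;

  public class IBkDemo {

    public static void main(String[] args) throws Exception {
      // Load a data set; "iris.arff" is a placeholder file name
      Instances data = new Instances(
          new BufferedReader(new FileReader("iris.arff")));
      data.setClassIndex(data.numAttributes() - 1);

      IBk ibk = new IBk();
      ibk.setKNN(10);              // upper bound on k
      ibk.setCrossValidate(true);  // pick k <= 10 by hold-one-out CV
      ibk.setWindowSize(0);        // no FIFO window
      ibk.setDistanceWeighting(
          new SelectedTag(IBk.WEIGHT_INVERSE, IBk.TAGS_WEIGHTING));

      ibk.buildClassifier(data);

      // Predict the class distribution for the first instance
      Instance first = data.instance(0);
      double[] dist = ibk.distributionForInstance(first);
      for (int i = 0; i < dist.length; i++) {
        System.out.println(data.classAttribute().value(i) + ": " + dist[i]);
      }
    }
  }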