ibk.java
    // Invalidate any currently cross-validation selected k
    m_kNNValid = false;
  }

  /**
   * Adds the supplied instance to the training set.
   *
   * @param instance the instance to add
   * @throws Exception if instance could not be incorporated successfully
   */
  public void updateClassifier(Instance instance) throws Exception {

    if (m_Train.equalHeaders(instance.dataset()) == false) {
      throw new Exception("Incompatible instance types");
    }
    if (instance.classIsMissing()) {
      return;
    }

    m_Train.add(instance);
    m_NNSearch.update(instance);
    m_kNNValid = false;
    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
      boolean deletedInstance = false;
      while (m_Train.numInstances() > m_WindowSize) {
        m_Train.delete(0);
        deletedInstance = true;
      }
      // Rebuild the datastructure; KDTree currently can't delete
      if (deletedInstance == true)
        m_NNSearch.setInstances(m_Train);
    }
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if an error occurred during the prediction
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    if (m_Train.numInstances() == 0) {
      throw new Exception("No training instances!");
    }
    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {
      m_kNNValid = false;
      boolean deletedInstance = false;
      while (m_Train.numInstances() > m_WindowSize) {
        m_Train.delete(0);
        deletedInstance = true; // bug fix: this flag was never set here, so
                                // the search structure was never rebuilt
      }
      // Rebuild the datastructure; KDTree currently can't delete
      if (deletedInstance == true)
        m_NNSearch.setInstances(m_Train);
    }

    // Select k by cross validation
    if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) {
      crossValidate();
    }

    m_NNSearch.addInstanceInfo(instance);

    Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
    double[] distances = m_NNSearch.getDistances();
    double[] distribution = makeDistribution(neighbours, distances);

    // debug
    // int maxIndex = Utils.maxIndex(distribution);
    // if (instance.toString().startsWith("10,4,3,1,3,3,6,5,2,?")) {
    //   System.out.println("Target: " + instance + " "
    //       + m_Train.attribute(m_Train.classIndex()).value(maxIndex)
    //       + " found " + neighbours.numInstances() + " neighbours\n");
    //   for (int k = 0; k < neighbours.numInstances(); k++) {
    //     System.out.println(instNum + ", " + neighbours.instance(k)
    //         // "Node: instance " + neighbours.instance(k) + ", distance " +
    //         + ", distance " + distances[k]);
    //   }
    //   instNum++;
    //   System.out.println("");
    // }

    return distribution;
  }
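  // Usage sketch (illustrative, not part of the original Weka source; assumes
  // the standard weka.core.Instances API and hypothetical "train"/"test"
  // datasets with the class index already set):
  //
  //   IBk knn = new IBk();
  //   knn.buildClassifier(train);               // initial model
  //   knn.updateClassifier(train.instance(0));  // incremental update
  //   double[] dist = knn.distributionForInstance(test.instance(0));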
  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(8);

    newVector.addElement(new Option(
        "\tWeight neighbours by the inverse of their distance\n"
        + "\t(use when k > 1)",
        "I", 0, "-I"));
    newVector.addElement(new Option(
        "\tWeight neighbours by 1 - their distance\n"
        + "\t(use when k > 1)",
        "F", 0, "-F"));
    newVector.addElement(new Option(
        "\tNumber of nearest neighbours (k) used in classification.\n"
        + "\t(Default = 1)",
        "K", 1, "-K <number of neighbors>"));
    newVector.addElement(new Option(
        "\tMinimise mean squared error rather than mean absolute\n"
        + "\terror when using -X option with numeric prediction.",
        "E", 0, "-E"));
    newVector.addElement(new Option(
        "\tMaximum number of training instances maintained.\n"
        + "\tTraining instances are dropped FIFO. (Default = no window)",
        "W", 1, "-W <window size>"));
    newVector.addElement(new Option(
        "\tSelect the number of nearest neighbours between 1\n"
        + "\tand the k value specified using hold-one-out evaluation\n"
        + "\ton the training data (use when k > 1)",
        "X", 0, "-X"));
    newVector.addElement(new Option(
        "\tThe nearest neighbour search algorithm to use "
        + "(default: LinearNN).\n",
        "A", 0, "-A"));

    return newVector.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   * <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -I
   *  Weight neighbours by the inverse of their distance
   *  (use when k > 1)</pre>
   *
   * <pre> -F
   *  Weight neighbours by 1 - their distance
   *  (use when k > 1)</pre>
   *
   * <pre> -K <number of neighbors>
   *  Number of nearest neighbours (k) used in classification.
   *  (Default = 1)</pre>
   *
   * <pre> -E
   *  Minimise mean squared error rather than mean absolute
   *  error when using -X option with numeric prediction.</pre>
   *
   * <pre> -W <window size>
   *  Maximum number of training instances maintained.
   *  Training instances are dropped FIFO. (Default = no window)</pre>
   *
   * <pre> -X
   *  Select the number of nearest neighbours between 1
   *  and the k value specified using hold-one-out evaluation
   *  on the training data (use when k > 1)</pre>
   *
   * <pre> -A
   *  The nearest neighbour search algorithm to use (default: LinearNN).
   * </pre>
   * <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String knnString = Utils.getOption('K', options);
    if (knnString.length() != 0) {
      setKNN(Integer.parseInt(knnString));
    } else {
      setKNN(1);
    }

    String windowString = Utils.getOption('W', options);
    if (windowString.length() != 0) {
      setWindowSize(Integer.parseInt(windowString));
    } else {
      setWindowSize(0);
    }

    if (Utils.getFlag('I', options)) {
      setDistanceWeighting(new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));
    } else if (Utils.getFlag('F', options)) {
      setDistanceWeighting(new SelectedTag(WEIGHT_SIMILARITY, TAGS_WEIGHTING));
    } else {
      setDistanceWeighting(new SelectedTag(WEIGHT_NONE, TAGS_WEIGHTING));
    }

    setCrossValidate(Utils.getFlag('X', options));
    setMeanSquared(Utils.getFlag('E', options));

    String nnSearchClass = Utils.getOption('A', options);
    if (nnSearchClass.length() != 0) {
      String nnSearchClassSpec[] = Utils.splitOptions(nnSearchClass);
      if (nnSearchClassSpec.length == 0) {
        throw new Exception("Invalid NearestNeighbourSearch algorithm "
            + "specification string.");
      }
      String className = nnSearchClassSpec[0];
      nnSearchClassSpec[0] = "";
      setNearestNeighbourSearchAlgorithm(
          (NearestNeighbourSearch) Utils.forName(
              NearestNeighbourSearch.class, className, nnSearchClassSpec));
    } else {
      this.setNearestNeighbourSearchAlgorithm(new LinearNN());
    }

    Utils.checkForRemainingOptions(options);
  }
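  // Option-string sketch (illustrative, not part of the original source;
  // Utils.splitOptions is the same weka.core.Utils helper used above):
  //
  //   IBk knn = new IBk();
  //   // 5 neighbours, hold-one-out selection of k, inverse-distance weights
  //   knn.setOptions(Utils.splitOptions("-K 5 -X -I"));
  //   String[] current = knn.getOptions();  // round-trips the same settings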
  /**
   * Gets the current settings of IBk.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions() {

    String[] options = new String[11];
    int current = 0;

    options[current++] = "-K";
    options[current++] = "" + getKNN();
    options[current++] = "-W";
    options[current++] = "" + m_WindowSize;
    if (getCrossValidate()) {
      options[current++] = "-X";
    }
    if (getMeanSquared()) {
      options[current++] = "-E";
    }
    if (m_DistanceWeighting == WEIGHT_INVERSE) {
      options[current++] = "-I";
    } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {
      options[current++] = "-F";
    }
    options[current++] = "-A";
    options[current++] = m_NNSearch.getClass().getName() + " "
        + Utils.joinOptions(m_NNSearch.getOptions());

    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Returns a description of this classifier.
   *
   * @return a description of this classifier as a string.
   */
  public String toString() {

    if (m_Train == null) {
      return "IBk: No model built yet.";
    }
    if (!m_kNNValid && m_CrossValidate) {
      crossValidate();
    }

    String result = "IB1 instance-based classifier\n"
        + "using " + m_kNN;
    switch (m_DistanceWeighting) {
      case WEIGHT_INVERSE:
        result += " inverse-distance-weighted";
        break;
      case WEIGHT_SIMILARITY:
        result += " similarity-weighted";
        break;
    }
    result += " nearest neighbour(s) for classification\n";

    if (m_WindowSize != 0) {
      result += "using a maximum of "
          + m_WindowSize + " (windowed) training instances\n";
    }
    return result;
  }

  /**
   * Initialise scheme variables.
   */
  protected void init() {

    setKNN(1);
    m_WindowSize = 0;
    m_DistanceWeighting = WEIGHT_NONE;
    m_CrossValidate = false;
    m_MeanSquared = false;
  }

  /**
   * Turn the list of nearest neighbors into a probability distribution.
   *
   * @param neighbours the list of nearest neighboring instances
   * @param distances the distances of the neighbors
   * @return the probability distribution
   * @throws Exception if computation goes wrong or has no class attribute
   */
  protected double[] makeDistribution(Instances neighbours, double[] distances)
      throws Exception {

    double total = 0, weight;
    double[] distribution = new double[m_NumClasses];

    // Set up a correction to the estimator
    if (m_ClassType == Attribute.NOMINAL) {
      for (int i = 0; i < m_NumClasses; i++) {
        distribution[i] = 1.0 / Math.max(1, m_Train.numInstances());
      }
      total = (double) m_NumClasses / Math.max(1, m_Train.numInstances());
    }

    for (int i = 0; i < neighbours.numInstances(); i++) {
      // Collect class counts
      Instance current = neighbours.instance(i);
      distances[i] = distances[i] * distances[i];
      distances[i] = Math.sqrt(distances[i] / m_NumAttributesUsed);
      switch (m_DistanceWeighting) {
        case WEIGHT_INVERSE:
          weight = 1.0 / (distances[i] + 0.001); // to avoid div by zero
          break;
        case WEIGHT_SIMILARITY:
          weight = 1.0 - distances[i];
          break;
        default:                                 // WEIGHT_NONE
          weight = 1.0;
          break;
      }
      weight *= current.weight();
      try {
        switch (m_ClassType) {
          case Attribute.NOMINAL:
            distribution[(int) current.classValue()] += weight;
            break;
          case Attribute.NUMERIC:
            distribution[0] += current.classValue() * weight;
            break;
        }
      } catch (Exception ex) {
        throw new Error("Data has no class attribute!");
      }
      total += weight;
    }

    // Normalise distribution
    if (total > 0) {
      Utils.normalize(distribution, total);
    }
    return distribution;
  }
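  // Worked example of the weighting above (illustrative numbers, not from the
  // original source): with two same-class neighbours at rescaled distances
  // 0.1 and 0.4, WEIGHT_INVERSE gives weights 1/(0.1+0.001) ~ 9.90 and
  // 1/(0.4+0.001) ~ 2.49, while WEIGHT_SIMILARITY gives 0.9 and 0.6; after
  // Utils.normalize the class probabilities sum to 1 in either scheme.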
  /**
   * Select the best value for k by hold-one-out cross-validation.
   * If the class attribute is nominal, classification error is
   * minimised. If the class attribute is numeric, mean absolute
   * error is minimised.
   */
  protected void crossValidate() {

    try {
      double[] performanceStats = new double[m_kNNUpper];
      double[] performanceStatsSq = new double[m_kNNUpper];

      for (int i = 0; i < m_kNNUpper; i++) {
        performanceStats[i] = 0;
        performanceStatsSq[i] = 0;
      }

      m_kNN = m_kNNUpper;
      Instance instance;
      Instances neighbours;
      double[] origDistances, convertedDistances;
      for (int i = 0; i < m_Train.numInstances(); i++) {
        if (m_Debug && (i % 50 == 0)) {
          System.err.print("Cross validating "
              + i + "/" + m_Train.numInstances() + "\r");
        }
        instance = m_Train.instance(i);
        neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
        origDistances = m_NNSearch.getDistances();

        for (int j = m_kNNUpper - 1; j >= 0; j--) {
          // Update the performance stats
          convertedDistances = new double[origDistances.length];
          System.arraycopy(origDistances, 0,
              convertedDistances, 0, origDistances.length);
          double[] distribution = makeDistribution(neighbours,
              convertedDistances);
          double thisPrediction = Utils.maxIndex(distribution);
          if (m_Train.classAttribute().isNumeric()) {
            thisPrediction = distribution[0];
            double err = thisPrediction - instance.classValue();
            performanceStatsSq[j] += err * err;    // Squared error
            performanceStats[j] += Math.abs(err);  // Absolute error
          } else {
            if (thisPrediction != instance.classValue()) {
              performanceStats[j]++;               // Classification error
            }
          }
          if (j >= 1) {
            neighbours = pruneToK(neighbours, convertedDistances, j);
          }
        }
      }

      // Display the results of the cross-validation
      for (int i = 0; i < m_kNNUpper; i++) {
        if (m_Debug) {
          System.err.print("Hold-one-out performance of " + (i + 1)
              + " neighbors ");
        }
        if (m_Train.classAttribute().isNumeric()) {
          if (m_Debug) {
            if (m_MeanSquared) {
              System.err.println("(RMSE) = "
                  + Math.sqrt(performanceStatsSq[i]
                      / m_Train.numInstances()));
            } else {
              System.err.println("(MAE) = "
                  + performanceStats[i] / m_Train.numInstances());
            }
          }
        } else {
          if (m_Debug) {
            System.err.println("(%ERR) = "
                + 100.0 * performanceStats[i] / m_Train.numInstances());
          }
        }
      }

      // Check through the performance stats and select the best
      // k value (or the lowest k if more than one best)
      double[] searchStats = performanceStats;
      if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
        searchStats = performanceStatsSq;
      }
      double bestPerformance = Double.NaN;
      int bestK = 1;
      for (int i = 0; i < m_kNNUpper; i++) {
        if (Double.isNaN(bestPerformance)
            || (bestPerformance > searchStats[i])) {
          bestPerformance = searchStats[i];
          bestK = i + 1;
        }
      }
      m_kNN = bestK;
      if (m_Debug) {
        System.err.println("Selected k = " + bestK);
      }

      m_kNNValid = true;
    } catch (Exception ex) {
      throw new Error("Couldn't optimize by cross-validation: "
          + ex.getMessage());
    }
  }

  /**
   * Prunes the list to contain the k nearest neighbors. If there are
   * multiple neighbors at the k'th distance, all will be kept.
   *
   * @param neighbours the neighbour instances.
   * @param distances the distances of the neighbours from target instance.
   * @param k the number of neighbors to keep.
   * @return the pruned neighbours.
   */
  public Instances pruneToK(Instances neighbours, double[] distances, int k) {

    if (neighbours == null || distances == null
        || neighbours.numInstances() == 0) {
      return null;
    }
    if (k < 1) {
      k = 1;
    }

    int currentK = 0;
    double currentDist;
    for (int i = 0; i < neighbours.numInstances(); i++) {
      currentK++;
      currentDist = distances[i];
      if (currentK > k && currentDist != distances[i - 1]) {
        currentK--;
        neighbours = new Instances(neighbours, 0, currentK);
        break;
      }
    }
    return neighbours;
  }
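  // Cross-validated k selection sketch (illustrative, not part of the
  // original source): with -X, the -K value acts as the upper bound
  // m_kNNUpper, and crossValidate() keeps the lowest-error k in 1..K
  // (ties go to the smaller k). Note the selection runs lazily, at the
  // first prediction or toString() call, not inside buildClassifier:
  //
  //   IBk knn = new IBk();
  //   knn.setKNN(10);                 // search k = 1..10
  //   knn.setCrossValidate(true);
  //   knn.buildClassifier(train);     // "train" is a hypothetical dataset
  //   knn.distributionForInstance(test.instance(0)); // triggers crossValidate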
  /**
   * Main method for testing this class.
   *
   * @param argv should contain command line options (see setOptions)
   */
  public static void main(String[] argv) {
    runClassifier(new IBk(), argv);
  }
}
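// Command-line sketch (illustrative; assumes the class sits at its usual
// package path, weka.classifiers.lazy.IBk, and that data.arff exists):
//
//   java weka.classifiers.lazy.IBk -t data.arff -K 3 -I
//
// runClassifier, the evaluation harness inherited from Classifier, handles
// -t plus the options documented in setOptions above and prints the
// standard Weka evaluation output.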