📄 ibkmetric.java
字号:
* -X <br>
 * Select the number of neighbors to use by hold-one-out cross
 * validation, with an upper limit given by the -K option. <p>
 *
 * -S <br>
 * When k is selected by cross-validation for numeric class attributes,
 * minimize mean-squared error. (default mean absolute error) <p>
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported
 */
public void setOptions(String[] options) throws Exception {

  // -K: number of nearest neighbours (defaults to 1 when absent).
  String knnString = Utils.getOption('K', options);
  if (knnString.length() != 0) {
    setKNN(Integer.parseInt(knnString));
  } else {
    setKNN(1);
  }

  // -W: maximum number of training instances kept (0 = no window limit).
  String windowString = Utils.getOption('W', options);
  if (windowString.length() != 0) {
    setWindowSize(Integer.parseInt(windowString));
  } else {
    setWindowSize(0);
  }

  // -D / -F select mutually exclusive distance-weighting schemes;
  // neither flag means no distance weighting.
  if (Utils.getFlag('D', options)) {
    setDistanceWeighting(new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));
  } else if (Utils.getFlag('F', options)) {
    setDistanceWeighting(new SelectedTag(WEIGHT_SIMILARITY, TAGS_WEIGHTING));
  } else {
    setDistanceWeighting(new SelectedTag(WEIGHT_NONE, TAGS_WEIGHTING));
  }

  setCrossValidate(Utils.getFlag('X', options));
  setMeanSquared(Utils.getFlag('S', options));

  // -M: metric specification. The first token is the metric class name;
  // the remaining tokens are handed to the metric as its own options.
  String metricString = Utils.getOption('M', options);
  if (metricString.length() != 0) {
    String[] metricSpec = Utils.splitOptions(metricString);
    String metricName = metricSpec[0];
    // blank out the class name so only the metric's options remain
    metricSpec[0] = "";
    System.out.println("Metric name: " + metricName
                       + "\nMetric parameters: " + concatStringArray(metricSpec));
    setMetric(Metric.forName(metricName, metricSpec));
  }

  Utils.checkForRemainingOptions(options);
}

/**
 * Gets the metric specification string, which contains the class name of
 * the metric and any options to the metric.
 *
 * @return the metric specification string.
 */
protected String getMetricSpec() {
  if (m_metric instanceof OptionHandler) {
    return m_metric.getClass().getName() + " "
      + Utils.joinOptions(((OptionHandler)m_metric).getOptions());
  }
  return m_metric.getClass().getName();
}

/**
 * Gets the current settings of IBkMetric.
* * @return an array of strings suitable for passing to setOptions() */ public String [] getOptions() { String [] options = new String [70]; int current = 0; options[current++] = "-K"; options[current++] = "" + getKNN(); options[current++] = "-W"; options[current++] = "" + m_WindowSize; if (getCrossValidate()) { options[current++] = "-X"; } if (getMeanSquared()) { options[current++] = "-S"; } if (m_DistanceWeighting == WEIGHT_INVERSE) { options[current++] = "-D"; } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) { options[current++] = "-F"; } options[current++] = "-M"; options[current++] = Utils.removeSubstring(m_metric.getClass().getName(), "weka.core.metrics."); if (m_metric instanceof OptionHandler) { String[] metricOptions = ((OptionHandler)m_metric).getOptions(); for (int i = 0; i < metricOptions.length; i++) { options[current++] = metricOptions[i]; } } while (current < options.length) { options[current++] = ""; } return options; } /** * Returns a description of this classifier. * * @return a description of this classifier as a string. */ public String toString() { if (m_Train == null) { return "IBkMetric: No model built yet."; } if (!m_kNNValid && m_CrossValidate) { crossValidate(); } String result = "IB1 instance-based classifier\n" + "using " + m_kNN; switch (m_DistanceWeighting) { case WEIGHT_INVERSE: result += " inverse-distance-weighted"; break; case WEIGHT_SIMILARITY: result += " similarity-weighted"; break; } result += " nearest neighbour(s) for classification\n"; if (m_WindowSize != 0) { result += "using a maximum of " + m_WindowSize + " (windowed) training instances\n"; } return result; } /** * Initialise scheme variables. */ protected void init() { setKNN(1); m_WindowSize = 0; m_DistanceWeighting = WEIGHT_NONE; m_CrossValidate = false; m_MeanSquared = false; } /** * Build the list of nearest k neighbors to the given test instance. 
* * @param instance the instance to search for neighbours of * @return a list of neighbors */ protected NeighborList findNeighbors(Instance instance) throws Exception { double distance; NeighborList neighborlist = new NeighborList(m_kNN); Enumeration enum = m_Train.enumerateInstances(); int i = 0; while (enum.hasMoreElements()) { Instance trainInstance = (Instance) enum.nextElement(); if (instance != trainInstance) { // for hold-one-out cross-validation distance = m_metric.distance(instance, trainInstance); if (neighborlist.isEmpty() || (i < m_kNN) || (distance <= neighborlist.m_Last.m_Distance)) { neighborlist.insertSorted(distance, trainInstance); } i++; } } return neighborlist; } /** * Turn the list of nearest neighbors into a probability distribution * * @param neighborlist the list of nearest neighboring instances * @return the probability distribution */ protected double [] makeDistribution(NeighborList neighborlist) throws Exception { double total = 0, weight; double [] distribution = new double [m_NumClasses]; // Set up a Laplacian correction to the estimator if (m_ClassType == Attribute.NOMINAL) { for(int i = 0; i < m_NumClasses; i++) { distribution[i] = 1.0 / Math.max(1,m_Train.numInstances()); } total = (double)m_NumClasses / Math.max(1,m_Train.numInstances()); } if (!neighborlist.isEmpty()) { // Collect class counts NeighborNode current = neighborlist.m_First; while (current != null) { switch (m_DistanceWeighting) { case WEIGHT_INVERSE: weight = 1.0 / (current.m_Distance + m_EPSILON); // to avoid div by zero break; case WEIGHT_SIMILARITY: weight = 1.0 - current.m_Distance; break; default: // WEIGHT_NONE: weight = 1.0; break; } weight *= current.m_Instance.weight(); try { switch (m_ClassType) { case Attribute.NOMINAL: distribution[(int)current.m_Instance.classValue()] += weight; break; case Attribute.NUMERIC: distribution[0] += current.m_Instance.classValue() * weight; break; } } catch (Exception ex) { throw new Error("Data has no class attribute!"); } 
total += weight; current = current.m_Next; } } // Normalise distribution if (total > 0) { Utils.normalize(distribution, total); } return distribution; } /** * Select the best value for k by hold-one-out cross-validation. * If the class attribute is nominal, classification error is * minimised. If the class attribute is numeric, mean absolute * error is minimised */ protected void crossValidate() { try { double [] performanceStats = new double [m_kNNUpper]; double [] performanceStatsSq = new double [m_kNNUpper]; for(int i = 0; i < m_kNNUpper; i++) { performanceStats[i] = 0; performanceStatsSq[i] = 0; } m_kNN = m_kNNUpper; Instance instance; NeighborList neighborlist; for(int i = 0; i < m_Train.numInstances(); i++) { if (m_Debug && (i % 50 == 0)) { System.err.print("Cross validating " + i + "/" + m_Train.numInstances() + "\r"); } instance = m_Train.instance(i); neighborlist = findNeighbors(instance); for(int j = m_kNNUpper - 1; j >= 0; j--) { // Update the performance stats double [] distribution = makeDistribution(neighborlist); double thisPrediction = Utils.maxIndex(distribution); if (m_Train.classAttribute().isNumeric()) { double err = thisPrediction - instance.classValue(); performanceStatsSq[j] += err * err; // Squared error performanceStats[j] += Math.abs(err); // Absolute error } else { if (thisPrediction != instance.classValue()) { performanceStats[j] ++; // Classification error } } if (j >= 1) { neighborlist.pruneToK(j); } } } // Display the results of the cross-validation for(int i = 0; i < m_kNNUpper; i++) { if (m_Debug) { System.err.print("Hold-one-out performance of " + (i + 1) + " neighbors " ); } if (m_Train.classAttribute().isNumeric()) { if (m_Debug) { if (m_MeanSquared) { System.err.println("(RMSE) = " + Math.sqrt(performanceStatsSq[i] / m_Train.numInstances())); } else { System.err.println("(MAE) = " + performanceStats[i] / m_Train.numInstances()); } } } else { if (m_Debug) { System.err.println("(%ERR) = " + 100.0 * performanceStats[i] / 
m_Train.numInstances()); } } } // Check through the performance stats and select the best // k value (or the lowest k if more than one best) double [] searchStats = performanceStats; if (m_Train.classAttribute().isNumeric() && m_MeanSquared) { searchStats = performanceStatsSq; } double bestPerformance = Double.NaN; int bestK = 1; for(int i = 0; i < m_kNNUpper; i++) { if (Double.isNaN(bestPerformance) || (bestPerformance > searchStats[i])) { bestPerformance = searchStats[i]; bestK = i + 1; } } m_kNN = bestK; if (m_Debug) { System.err.println("Selected k = " + bestK); } m_kNNValid = true; } catch (Exception ex) { throw new Error("Couldn't optimize by cross-validation: " +ex.getMessage()); } } /** A little helper to create a single String from an array of Strings * @param strings an array of strings * @returns a single concatenated string, separated by commas */ public static String concatStringArray(String[] strings) { String result = new String(); for (int i = 0; i < strings.length; i++) { result = result + "\"" + strings[i] + "\" "; } return result; } /** * Main method for testing this class. * * @param argv should contain command line options (see setOptions) */ public static void main(String [] argv) { try { System.out.println(Evaluation.evaluateModel(new IBkMetric(), argv)); } catch (Exception e) { e.printStackTrace(); System.err.println(e.getMessage()); } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -