📄 ibk.java
字号:
    }
    // Emit KDTree spec if a KDTree is configured, then pad the unused slots.
    if (getKDTree() != null) {
      options[current++] = "-E";
      options[current++] = "" + getKDTreeSpec();
    }
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Returns a description of this classifier.
   *
   * @return a description of this classifier as a string.
   */
  public String toString() {

    if (m_Train == null) {
      return "IBk: No model built yet.";
    }
    // The best k may not have been selected yet; do it lazily here so the
    // description reports the k that will actually be used.
    if (!m_kNNValid && m_CrossValidate) {
      crossValidate();
    }

    String result = "IB1 instance-based classifier\n"
      + "using " + m_kNN;
    switch (m_DistanceWeighting) {
    case WEIGHT_INVERSE:
      result += " inverse-distance-weighted";
      break;
    case WEIGHT_SIMILARITY:
      result += " similarity-weighted";
      break;
    }
    result += " nearest neighbour(s) for classification\n";

    if (m_WindowSize != 0) {
      result += "using a maximum of "
        + m_WindowSize + " (windowed) training instances\n";
    }
    return result;
  }

  /**
   * Initialise scheme variables to their defaults
   * (k=1, no window, no weighting, no cross-validation).
   */
  private void init() {

    setKNN(1);
    m_WindowSize = 0;
    m_DistanceWeighting = WEIGHT_NONE;
    m_CrossValidate = false;
    m_MeanSquared = false;
    m_DontNormalize = false;
  }

  /**
   * Calculates the distance between two instances (normalized Euclidean
   * distance over the attributes actually used, skipping the class attribute).
   * Iterates both instances' sparse representations in parallel; absent
   * sparse entries are treated as value 0.
   *
   * @param first the first instance
   * @param second the second instance
   * @return the distance between the two given instances, between 0 and 1
   */
  private double distance(Instance first, Instance second) {

    if (!Instances.inRanges(first,m_Ranges)) OOPS("Not in ranges");
    if (!Instances.inRanges(second,m_Ranges)) OOPS("Not in ranges");

    double distance = 0;
    int firstI, secondI;

    // Parallel walk over the two sparse instances; an exhausted side is
    // mapped past the last attribute index so the other side drains first.
    for (int p1 = 0, p2 = 0; p1 < first.numValues() || p2 < second.numValues();) {
      if (p1 >= first.numValues()) {
        firstI = m_Train.numAttributes();
      } else {
        firstI = first.index(p1);
      }
      if (p2 >= second.numValues()) {
        secondI = m_Train.numAttributes();
      } else {
        secondI = second.index(p2);
      }
      // The class attribute never contributes to the distance.
      if (firstI == m_Train.classIndex()) {
        p1++; continue;
      }
      if (secondI == m_Train.classIndex()) {
        p2++; continue;
      }
      double diff;
      if (firstI == secondI) {
        diff = difference(firstI,
                          first.valueSparse(p1),
                          second.valueSparse(p2));
        p1++; p2++;
      } else if (firstI > secondI) {
        // Attribute present only in second; first's value is implicitly 0.
        diff = difference(secondI, 0, second.valueSparse(p2));
        p2++;
      } else {
        // Attribute present only in first; second's value is implicitly 0.
        diff = difference(firstI, first.valueSparse(p1), 0);
        p1++;
      }
      distance += diff * diff;
    }

    distance = Math.sqrt(distance / m_NumAttributesUsed);
    return distance;
  }

  /**
   * Computes the difference between two given attribute values.
   * Nominal attributes differ by 0/1; numeric attributes by their
   * normalized difference, with missing values penalized pessimistically.
   *
   * @param index the index of the attribute being compared
   * @param val1 the first value
   * @param val2 the second value
   * @return the (possibly normalized) difference
   */
  private double difference(int index, double val1, double val2) {

    switch (m_Train.attribute(index).type()) {
    case Attribute.NOMINAL:

      // If attribute is nominal: any missing value or mismatch counts as
      // maximal difference.
      if (Instance.isMissingValue(val1) ||
          Instance.isMissingValue(val2) ||
          ((int)val1 != (int)val2)) {
        return 1;
      } else {
        return 0;
      }
    case Attribute.NUMERIC:

      // If attribute is numeric
      if (Instance.isMissingValue(val1) ||
          Instance.isMissingValue(val2)) {
        if (Instance.isMissingValue(val1) &&
            Instance.isMissingValue(val2)) {
          return 1;
        } else {
          double diff;
          if (Instance.isMissingValue(val2)) {
            diff = norm(val1, index);
          } else {
            diff = norm(val2, index);
          }
          // Assume the worst case: the missing value lies at the far end
          // of the [0,1] range from the known value.
          if (diff < 0.5) {
            diff = 1.0 - diff;
          }
          return diff;
        }
      } else {
        return norm(val1, index) - norm(val2, index);
      }
    default:
      return 0;
    }
  }

  /**
   * Normalizes a given value of a numeric attribute into [0,1]
   * using the observed min/max range.
   *
   * @param x the value to be normalized
   * @param i the attribute's index
   * @return the normalized value, or x unchanged when normalization is
   *         disabled, or 0 when no range is known / the range is degenerate
   */
  private double norm(double x, int i) {

    if (m_DontNormalize) {
      return x;
    } else if (Double.isNaN(m_Ranges[i][R_MIN]) ||
               Utils.eq(m_Ranges[i][R_MAX],m_Ranges[i][R_MIN])) {
      // No value seen yet, or a constant attribute: nothing to scale by.
      return 0;
    } else {
      return (x - m_Ranges[i][R_MIN]) /
        (m_Ranges[i][R_MAX] - m_Ranges[i][R_MIN]);
    }
  }

  /**
   * Updates the minimum and maximum values for all the attributes
   * based on a new instance.
* * @param instance the new instance */ private void updateMinMax(Instance instance) { for (int j = 0;j < m_Train.numAttributes(); j++) { if (!instance.isMissing(j)) { if (Double.isNaN(m_Ranges[j][R_MIN])) { m_Ranges[j][R_MIN] = instance.value(j); m_Ranges[j][R_MAX] = instance.value(j); } else { if (instance.value(j) < m_Ranges[j][R_MIN]) { m_Ranges[j][R_MIN] = instance.value(j); } else { if (instance.value(j) > m_Ranges[j][R_MAX]) { m_Ranges[j][R_MAX] = instance.value(j); } } } } } } /** * Build the list of nearest k neighbours to the given test instance. * * @param instance the instance to search for neighbours of * @return a list of neighbours */ private NeighbourList findNeighbours(Instance instance) throws Exception { double distance; NeighbourList neighbourlist = new NeighbourList(m_kNN); // dont work with kdtree if (m_KDTree == null) { Enumeration enum = m_Train.enumerateInstances(); int i = 0; while (enum.hasMoreElements()) { Instance trainInstance = (Instance) enum.nextElement(); if (instance != trainInstance) { // for hold-one-out cross-validation distance = distance(instance, trainInstance); if (neighbourlist.isEmpty() || (i < m_kNN) || (distance <= neighbourlist.m_Last.m_Distance)) { neighbourlist.insertSorted(distance, trainInstance); } i++; } } } else { // work with KDTree double[] distanceList = new double[m_KDTree.numInstances()]; int[] instanceList = new int[m_KDTree.numInstances()]; int numOfNearest = m_KDTree.findKNearestNeighbour(instance, m_kNN, instanceList, distanceList); for (int i = 0; i < numOfNearest; i++) { neighbourlist.insertSorted(distanceList[i], m_KDTree.getInstances().instance(instanceList[i])); } } //debug //OOPS("Target: "+instance+" found "+neighbourlist.currentLength() + " neighbours\n"); //neighbourlist.printList(); return neighbourlist; } /** * Turn the list of nearest neighbours into a probability distribution * * @param neighbourlist the list of nearest neighbouring instances * @return the probability distribution */ 
  private double [] makeDistribution(NeighbourList neighbourlist)
    throws Exception {

    double total = 0, weight;
    double [] distribution = new double [m_NumClasses];

    // Set up a correction to the estimator: a Laplace-style prior of
    // 1/numInstances per class so no class gets zero probability.
    if (m_ClassType == Attribute.NOMINAL) {
      for(int i = 0; i < m_NumClasses; i++) {
        distribution[i] = 1.0 / Math.max(1,m_Train.numInstances());
      }
      total = (double)m_NumClasses / Math.max(1,m_Train.numInstances());
    }

    if (!neighbourlist.isEmpty()) {
      // Collect (weighted) class counts from each neighbour.
      NeighbourNode current = neighbourlist.m_First;
      while (current != null) {
        switch (m_DistanceWeighting) {
        case WEIGHT_INVERSE:
          weight = 1.0 / (current.m_Distance + 0.001); // to avoid div by zero
          break;
        case WEIGHT_SIMILARITY:
          weight = 1.0 - current.m_Distance;
          break;
        default:                                 // WEIGHT_NONE:
          weight = 1.0;
          break;
        }
        // Combine distance weighting with the instance's own weight.
        weight *= current.m_Instance.weight();
        try {
          switch (m_ClassType) {
          case Attribute.NOMINAL:
            // Nominal: accumulate weight into the neighbour's class bucket.
            distribution[(int)current.m_Instance.classValue()] += weight;
            break;
          case Attribute.NUMERIC:
            // Numeric: distribution[0] accumulates the weighted class sum;
            // dividing by total below yields the weighted mean prediction.
            distribution[0] += current.m_Instance.classValue() * weight;
            break;
          }
        } catch (Exception ex) {
          throw new Error("Data has no class attribute!");
        }
        total += weight;
        current = current.m_Next;
      }
    }

    // Normalise distribution so it sums to 1 (nominal) or becomes the
    // weighted mean (numeric).
    if (total > 0) {
      Utils.normalize(distribution, total);
    }
    return distribution;
  }

  /**
   * Select the best value for k by hold-one-out cross-validation.
   * If the class attribute is nominal, classification error is
   * minimised.
If the class attribute is numeric, mean absolute * error is minimised */ private void crossValidate() { try { double [] performanceStats = new double [m_kNNUpper]; double [] performanceStatsSq = new double [m_kNNUpper]; for(int i = 0; i < m_kNNUpper; i++) { performanceStats[i] = 0; performanceStatsSq[i] = 0; } m_kNN = m_kNNUpper; Instance instance; NeighbourList neighbourlist; for(int i = 0; i < m_Train.numInstances(); i++) { if (m_Debug && (i % 50 == 0)) { System.err.print("Cross validating " + i + "/" + m_Train.numInstances() + "\r"); } instance = m_Train.instance(i); neighbourlist = findNeighbours(instance); for(int j = m_kNNUpper - 1; j >= 0; j--) { // Update the performance stats double [] distribution = makeDistribution(neighbourlist); double thisPrediction = Utils.maxIndex(distribution); if (m_Train.classAttribute().isNumeric()) { double err = thisPrediction - instance.classValue(); performanceStatsSq[j] += err * err; // Squared error performanceStats[j] += Math.abs(err); // Absolute error } else { if (thisPrediction != instance.classValue()) { performanceStats[j] ++; // Classification error } } if (j >= 1) { neighbourlist.pruneToK(j); } } } // Display the results of the cross-validation for(int i = 0; i < m_kNNUpper; i++) { if (m_Debug) { System.err.print("Hold-one-out performance of " + (i + 1) + " neighbours " ); } if (m_Train.classAttribute().isNumeric()) { if (m_Debug) { if (m_MeanSquared) { System.err.println("(RMSE) = " + Math.sqrt(performanceStatsSq[i] / m_Train.numInstances())); } else { System.err.println("(MAE) = " + performanceStats[i] / m_Train.numInstances()); } } } else { if (m_Debug) { System.err.println("(%ERR) = " + 100.0 * performanceStats[i] / m_Train.numInstances()); } } } // Check through the performance stats and select the best // k value (or the lowest k if more than one best) double [] searchStats = performanceStats; if (m_Train.classAttribute().isNumeric() && m_MeanSquared) { searchStats = performanceStatsSq; } double 
bestPerformance = Double.NaN; int bestK = 1; for(int i = 0; i < m_kNNUpper; i++) { if (Double.isNaN(bestPerformance) || (bestPerformance > searchStats[i])) { bestPerformance = searchStats[i]; bestK = i + 1; } } m_kNN = bestK; if (m_Debug) { System.err.println("Selected k = " + bestK); } m_kNNValid = true; } catch (Exception ex) { throw new Error("Couldn't optimize by cross-validation: " +ex.getMessage()); } } /** * Used for debug println's. * @param output string that is printed */ private void OOPS(String output) { System.out.println(output); } /** * Main method for testing this class. * * @param argv should contain command line options (see setOptions) */ public static void main(String [] argv) { try { System.out.println(Evaluation.evaluateModel(new IBk(), argv)); } catch (Exception e) { e.printStackTrace(); System.err.println(e.getMessage()); } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -