ibk.java

来自「Weka」· Java 代码 · 共 1,028 行 · 第 1/3 页

JAVA
1,028
字号
    // (tail of setOptions) Parse -A: the nearest-neighbour search method,
    // given as a class name optionally followed by that class's own options.
    String nnSearchClass = Utils.getOption('A', options);
    if(nnSearchClass.length() != 0) {
      String nnSearchClassSpec[] = Utils.splitOptions(nnSearchClass);
      if(nnSearchClassSpec.length == 0) {
        throw new Exception("Invalid NearestNeighbourSearch algorithm " +
                           "specification string.");
      }
      String className = nnSearchClassSpec[0];
      // Blank out the class name so only the algorithm's own options
      // are forwarded to the instantiated search object.
      nnSearchClassSpec[0] = "";
      setNearestNeighbourSearchAlgorithm( (NearestNeighbourSearch)
                  Utils.forName( NearestNeighbourSearch.class,
                                 className,
                                 nnSearchClassSpec)
                                        );
    }
    else
      // No -A given: default to brute-force linear search.
      this.setNearestNeighbourSearchAlgorithm(new LinearNNSearch());

    Utils.checkForRemainingOptions(options);
  }

  /**
   * Gets the current settings of IBk.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String [] getOptions() {

    // 11 slots covers the maximum: -K v, -W v, -X, -E, (-I|-F), -A spec.
    String [] options = new String [11];
    int current = 0;
    options[current++] = "-K"; options[current++] = "" + getKNN();
    options[current++] = "-W"; options[current++] = "" + m_WindowSize;
    if (getCrossValidate()) {
      options[current++] = "-X";
    }
    if (getMeanSquared()) {
      options[current++] = "-E";
    }
    // -I and -F are mutually exclusive distance-weighting flags.
    if (m_DistanceWeighting == WEIGHT_INVERSE) {
      options[current++] = "-I";
    } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {
      options[current++] = "-F";
    }
    options[current++] = "-A";
    options[current++] = m_NNSearch.getClass().getName()+" "+Utils.joinOptions(m_NNSearch.getOptions());

    // Pad the unused slots so callers never see null entries.
    while (current < options.length) {
      options[current++] = "";
    }

    return options;
  }

  /**
   * Returns an enumeration of the additional measure names
   * produced by the neighbour search algorithm.
   * @return an enumeration of the measure names
   */
  public Enumeration enumerateMeasures() {
    // Measures are delegated entirely to the search algorithm.
    return m_NNSearch.enumerateMeasures();
  }

  /**
   * Returns the value of the named measure from the
   * neighbour search algorithm.
   * @param additionalMeasureName the name of the measure to query for its value
   * @return the value of the named measure
   * @throws IllegalArgumentException if the named measure is not supported
   */
  public double getMeasure(String additionalMeasureName) {
    return m_NNSearch.getMeasure(additionalMeasureName);
  }

  /**
   * Returns a description of this classifier.
   *
   * @return a description of this classifier as a string.
   */
  public String toString() {

    if (m_Train == null) {
      return "IBk: No model built yet.";
    }
    // Select the best k first if cross-validation was requested but the
    // current k is stale.
    if (!m_kNNValid && m_CrossValidate) {
      crossValidate();
    }

    String result = "IB1 instance-based classifier\n" +
      "using " + m_kNN;
    switch (m_DistanceWeighting) {
    case WEIGHT_INVERSE:
      result += " inverse-distance-weighted";
      break;
    case WEIGHT_SIMILARITY:
      result += " similarity-weighted";
      break;
    }
    result += " nearest neighbour(s) for classification\n";

    if (m_WindowSize != 0) {
      result += "using a maximum of "
        + m_WindowSize + " (windowed) training instances\n";
    }
    return result;
  }

  /**
   * Initialise scheme variables.
   */
  protected void init() {

    setKNN(1);
    m_WindowSize = 0;                    // 0 = unlimited window
    m_DistanceWeighting = WEIGHT_NONE;
    m_CrossValidate = false;
    m_MeanSquared = false;
  }

  /**
   * Turn the list of nearest neighbors into a probability distribution.
*   * @param neighbours the list of nearest neighboring instances   * @param distances the distances of the neighbors   * @return the probability distribution   * @throws Exception if computation goes wrong or has no class attribute   */  protected double [] makeDistribution(Instances neighbours, double[] distances)    throws Exception {    double total = 0, weight;    double [] distribution = new double [m_NumClasses];        // Set up a correction to the estimator    if (m_ClassType == Attribute.NOMINAL) {      for(int i = 0; i < m_NumClasses; i++) {	distribution[i] = 1.0 / Math.max(1,m_Train.numInstances());      }      total = (double)m_NumClasses / Math.max(1,m_Train.numInstances());    }    for(int i=0; i < neighbours.numInstances(); i++) {      // Collect class counts      Instance current = neighbours.instance(i);      distances[i] = distances[i]*distances[i];      distances[i] = Math.sqrt(distances[i]/m_NumAttributesUsed);      switch (m_DistanceWeighting) {        case WEIGHT_INVERSE:          weight = 1.0 / (distances[i] + 0.001); // to avoid div by zero          break;        case WEIGHT_SIMILARITY:          weight = 1.0 - distances[i];          break;        default:                                 // WEIGHT_NONE:          weight = 1.0;          break;      }      weight *= current.weight();      try {        switch (m_ClassType) {          case Attribute.NOMINAL:            distribution[(int)current.classValue()] += weight;            break;          case Attribute.NUMERIC:            distribution[0] += current.classValue() * weight;            break;        }      } catch (Exception ex) {        throw new Error("Data has no class attribute!");      }      total += weight;          }    // Normalise distribution    if (total > 0) {      Utils.normalize(distribution, total);    }    return distribution;  }  /**   * Select the best value for k by hold-one-out cross-validation.   * If the class attribute is nominal, classification error is   * minimised. 
   * If the class attribute is numeric, mean absolute
   * error is minimised.
   */
  protected void crossValidate() {

    try {
      // CoverTree cannot serve the shrinking neighbour lists the
      // hold-one-out loop below relies on.
      if (m_NNSearch instanceof weka.core.neighboursearch.CoverTree)
        throw new Exception("CoverTree doesn't support hold-one-out "+
                            "cross-validation. Use some other NN " +
                            "method.");

      // performanceStats[j] accumulates error for k = j+1 neighbours
      // (absolute error or misclassification count); the Sq variant
      // accumulates squared error for numeric classes.
      double [] performanceStats = new double [m_kNNUpper];
      double [] performanceStatsSq = new double [m_kNNUpper];

      for(int i = 0; i < m_kNNUpper; i++) {
        performanceStats[i] = 0;
        performanceStatsSq[i] = 0;
      }

      // Query with the largest k once per instance, then evaluate every
      // smaller k by pruning the returned neighbour list.
      m_kNN = m_kNNUpper;
      Instance instance;
      Instances neighbours;
      double[] origDistances, convertedDistances;
      for(int i = 0; i < m_Train.numInstances(); i++) {
        if (m_Debug && (i % 50 == 0)) {
          System.err.print("Cross validating "
                           + i + "/" + m_Train.numInstances() + "\r");
        }
        instance = m_Train.instance(i);
        neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
        origDistances = m_NNSearch.getDistances();

        for(int j = m_kNNUpper - 1; j >= 0; j--) {
          // Update the performance stats.  A fresh copy of the distances
          // is needed each pass because makeDistribution rescales its
          // distances argument in place.
          convertedDistances = new double[origDistances.length];
          System.arraycopy(origDistances, 0,
                           convertedDistances, 0, origDistances.length);
          double [] distribution = makeDistribution(neighbours,
                                                    convertedDistances);
          double thisPrediction = Utils.maxIndex(distribution);
          if (m_Train.classAttribute().isNumeric()) {
            thisPrediction = distribution[0];
            double err = thisPrediction - instance.classValue();
            performanceStatsSq[j] += err * err;   // Squared error
            performanceStats[j] += Math.abs(err); // Absolute error
          } else {
            if (thisPrediction != instance.classValue()) {
              performanceStats[j] ++;             // Classification error
            }
          }
          // Shrink the neighbour list to j entries (keeping ties) so the
          // next iteration evaluates k = j.
          if (j >= 1) {
            neighbours = pruneToK(neighbours, convertedDistances, j);
          }
        }
      }

      // Display the results of the cross-validation
      for(int i = 0; i < m_kNNUpper; i++) {
        if (m_Debug) {
          System.err.print("Hold-one-out performance of " + (i + 1)
                           + " neighbors " );
        }
        if (m_Train.classAttribute().isNumeric()) {
          if (m_Debug) {
            if (m_MeanSquared) {
              System.err.println("(RMSE) = "
                                 + Math.sqrt(performanceStatsSq[i]
                                             / m_Train.numInstances()));
            } else {
              System.err.println("(MAE) = "
                                 + performanceStats[i]
                                 / m_Train.numInstances());
            }
          }
        } else {
          if (m_Debug) {
            System.err.println("(%ERR) = "
                               + 100.0 * performanceStats[i]
                               / m_Train.numInstances());
          }
        }
      }

      // Check through the performance stats and select the best
      // k value (or the lowest k if more than one best)
      double [] searchStats = performanceStats;
      if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
        searchStats = performanceStatsSq;
      }
      double bestPerformance = Double.NaN;
      int bestK = 1;
      for(int i = 0; i < m_kNNUpper; i++) {
        // Strict '>' keeps the lowest k on ties.
        if (Double.isNaN(bestPerformance)
            || (bestPerformance > searchStats[i])) {
          bestPerformance = searchStats[i];
          bestK = i + 1;
        }
      }
      m_kNN = bestK;
      if (m_Debug) {
        System.err.println("Selected k = " + bestK);
      }

      m_kNNValid = true;
    } catch (Exception ex) {
      // NOTE(review): wrapping in Error discards the original stack trace;
      // only the message is preserved.
      throw new Error("Couldn't optimize by cross-validation: "
                      +ex.getMessage());
    }
  }

  /**
   * Prunes the list to contain the k nearest neighbors. If there are
   * multiple neighbors at the k'th distance, all will be kept.
   *
   * @param neighbours the neighbour instances.
   * @param distances the distances of the neighbours from target instance.
   * @param k the number of neighbors to keep.
   * @return the pruned neighbours.
*/  public Instances pruneToK(Instances neighbours, double[] distances, int k) {        if(neighbours==null || distances==null || neighbours.numInstances()==0) {      return null;    }    if (k < 1) {      k = 1;    }        int currentK = 0;    double currentDist;    for(int i=0; i < neighbours.numInstances(); i++) {      currentK++;      currentDist = distances[i];      if(currentK>k && currentDist!=distances[i-1]) {        currentK--;        neighbours = new Instances(neighbours, 0, currentK);        break;      }    }    return neighbours;  }    /**   * Main method for testing this class.   *   * @param argv should contain command line options (see setOptions)   */  public static void main(String [] argv) {    runClassifier(new IBk(), argv);  }}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?