⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ibk.java

📁 代码是一个分类器的实现,其中使用了部分weka的源代码。可以将项目导入eclipse运行
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
    // Invalidate any currently cross-validation selected k    m_kNNValid = false;  }  /**   * Adds the supplied instance to the training set   *   * @param instance the instance to add   * @throws Exception if instance could not be incorporated   * successfully   */  public void updateClassifier(Instance instance) throws Exception {    if (m_Train.equalHeaders(instance.dataset()) == false) {      throw new Exception("Incompatible instance types");    }    if (instance.classIsMissing()) {      return;    }    m_Train.add(instance);    m_NNSearch.update(instance);    m_kNNValid = false;    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {      boolean deletedInstance=false;      while (m_Train.numInstances() > m_WindowSize) {	m_Train.delete(0);        deletedInstance=true;      }      //rebuild datastructure KDTree currently can't delete      if(deletedInstance==true)        m_NNSearch.setInstances(m_Train);    }  }  /**   * Calculates the class membership probabilities for the given test instance.   
*   * @param instance the instance to be classified   * @return predicted class probability distribution   * @throws Exception if an error occurred during the prediction   */  public double [] distributionForInstance(Instance instance) throws Exception {    if (m_Train.numInstances() == 0) {      throw new Exception("No training instances!");    }    if ((m_WindowSize > 0) && (m_Train.numInstances() > m_WindowSize)) {      m_kNNValid = false;      boolean deletedInstance=false;      while (m_Train.numInstances() > m_WindowSize) {	m_Train.delete(0);      }      //rebuild datastructure KDTree currently can't delete      if(deletedInstance==true)        m_NNSearch.setInstances(m_Train);    }    // Select k by cross validation    if (!m_kNNValid && (m_CrossValidate) && (m_kNNUpper >= 1)) {      crossValidate();    }    m_NNSearch.addInstanceInfo(instance);    Instances neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);    double [] distances = m_NNSearch.getDistances();    double [] distribution = makeDistribution( neighbours, distances );        //debug//    int maxIndex = Utils.maxIndex(distribution);//    if(instance.toString().startsWith("10,4,3,1,3,3,6,5,2,?")) {//      System.out.println("Target: "+instance+" "+m_Train.attribute(m_Train.classIndex()).value(maxIndex)+" found "+neighbours.numInstances()+" neighbours\n");//      for(int k=0; k<neighbours.numInstances(); k++) {//        System.out.println(instNum+", "+neighbours.instance(k)+", distance "+ //"Node: instance "+neighbours.instance(k)+", distance "+//        distances[k]);//      } instNum++;//      System.out.println("");//    }    return distribution;  }   /**   * Returns an enumeration describing the available options.   *   * @return an enumeration of all the available options.   
*/  public Enumeration listOptions() {    Vector newVector = new Vector(8);    newVector.addElement(new Option(	      "\tWeight neighbours by the inverse of their distance\n"+	      "\t(use when k > 1)",	      "I", 0, "-I"));    newVector.addElement(new Option(	      "\tWeight neighbours by 1 - their distance\n"+	      "\t(use when k > 1)",	      "F", 0, "-F"));    newVector.addElement(new Option(	      "\tNumber of nearest neighbours (k) used in classification.\n"+	      "\t(Default = 1)",	      "K", 1,"-K <number of neighbors>"));    newVector.addElement(new Option(              "\tMinimise mean squared error rather than mean absolute\n"+	      "\terror when using -X option with numeric prediction.",	      "E", 0,"-E"));    newVector.addElement(new Option(              "\tMaximum number of training instances maintained.\n"+	      "\tTraining instances are dropped FIFO. (Default = no window)",	      "W", 1,"-W <window size>"));    newVector.addElement(new Option(	      "\tSelect the number of nearest neighbours between 1\n"+	      "\tand the k value specified using hold-one-out evaluation\n"+	      "\ton the training data (use when k > 1)",	      "X", 0,"-X"));    newVector.addElement(new Option(	      "\tThe nearest neighbour search algorithm to use "+              "(default: LinearNN).\n",	      "A", 0, "-A"));    return newVector.elements();  }  /**   * Parses a given list of options. <p/>   *   <!-- options-start -->   * Valid options are: <p/>   *    * <pre> -I   *  Weight neighbours by the inverse of their distance   *  (use when k &gt; 1)</pre>   *    * <pre> -F   *  Weight neighbours by 1 - their distance   *  (use when k &gt; 1)</pre>   *    * <pre> -K &lt;number of neighbors&gt;   *  Number of nearest neighbours (k) used in classification.   
*  (Default = 1)</pre>   *    * <pre> -E   *  Minimise mean squared error rather than mean absolute   *  error when using -X option with numeric prediction.</pre>   *    * <pre> -W &lt;window size&gt;   *  Maximum number of training instances maintained.   *  Training instances are dropped FIFO. (Default = no window)</pre>   *    * <pre> -X   *  Select the number of nearest neighbours between 1   *  and the k value specified using hold-one-out evaluation   *  on the training data (use when k &gt; 1)</pre>   *    * <pre> -A   *  The nearest neighbour search algorithm to use (default: LinearNN).   * </pre>   *    <!-- options-end -->   *   * @param options the list of options as an array of strings   * @throws Exception if an option is not supported   */  public void setOptions(String[] options) throws Exception {        String knnString = Utils.getOption('K', options);    if (knnString.length() != 0) {      setKNN(Integer.parseInt(knnString));    } else {      setKNN(1);    }    String windowString = Utils.getOption('W', options);    if (windowString.length() != 0) {      setWindowSize(Integer.parseInt(windowString));    } else {      setWindowSize(0);    }    if (Utils.getFlag('I', options)) {      setDistanceWeighting(new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));    } else if (Utils.getFlag('F', options)) {      setDistanceWeighting(new SelectedTag(WEIGHT_SIMILARITY, TAGS_WEIGHTING));    } else {      setDistanceWeighting(new SelectedTag(WEIGHT_NONE, TAGS_WEIGHTING));    }    setCrossValidate(Utils.getFlag('X', options));    setMeanSquared(Utils.getFlag('E', options));    String nnSearchClass = Utils.getOption('A', options);    if(nnSearchClass.length() != 0) {      String nnSearchClassSpec[] = Utils.splitOptions(nnSearchClass);      if(nnSearchClassSpec.length == 0) {         throw new Exception("Invalid NearestNeighbourSearch algorithm " +                            "specification string.");       }      String className = nnSearchClassSpec[0];      
nnSearchClassSpec[0] = "";      setNearestNeighbourSearchAlgorithm( (NearestNeighbourSearch)                  Utils.forName( NearestNeighbourSearch.class,                                  className,                                  nnSearchClassSpec)                                        );    }    else       this.setNearestNeighbourSearchAlgorithm(new LinearNN());        Utils.checkForRemainingOptions(options);  }  /**   * Gets the current settings of IBk.   *   * @return an array of strings suitable for passing to setOptions()   */  public String [] getOptions() {    String [] options = new String [11];    int current = 0;    options[current++] = "-K"; options[current++] = "" + getKNN();    options[current++] = "-W"; options[current++] = "" + m_WindowSize;    if (getCrossValidate()) {      options[current++] = "-X";    }    if (getMeanSquared()) {      options[current++] = "-E";    }    if (m_DistanceWeighting == WEIGHT_INVERSE) {      options[current++] = "-I";    } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {      options[current++] = "-F";    }    options[current++] = "-A";    options[current++] = m_NNSearch.getClass().getName()+" "+Utils.joinOptions(m_NNSearch.getOptions());         while (current < options.length) {      options[current++] = "";    }        return options;  }  /**   * Returns a description of this classifier.   *   * @return a description of this classifier as a string.   
   */
  public String toString() {

    // No model has been built yet.
    if (m_Train == null) {
      return "IBk: No model built yet.";
    }

    // Make sure k is up to date before reporting it.
    if (!m_kNNValid && m_CrossValidate) {
      crossValidate();
    }

    String result = "IB1 instance-based classifier\n" +
      "using " + m_kNN;
    switch (m_DistanceWeighting) {
    case WEIGHT_INVERSE:
      result += " inverse-distance-weighted";
      break;
    case WEIGHT_SIMILARITY:
      result += " similarity-weighted";
      break;
    }
    result += " nearest neighbour(s) for classification\n";

    if (m_WindowSize != 0) {
      result += "using a maximum of "
	+ m_WindowSize + " (windowed) training instances\n";
    }
    return result;
  }

  /**
   * Initialise scheme variables to their defaults
   * (k = 1, no window, no distance weighting, no cross-validation).
   */
  protected void init() {

    setKNN(1);
    m_WindowSize = 0;
    m_DistanceWeighting = WEIGHT_NONE;
    m_CrossValidate = false;
    m_MeanSquared = false;
  }

  /**
   * Turn the list of nearest neighbors into a probability distribution.
   * Note: modifies the supplied distances array in place (normalises each
   * entry by the number of attributes used).
   *
   * @param neighbours the list of nearest neighboring instances
   * @param distances the distances of the neighbors
   * @return the probability distribution
   * @throws Exception if computation goes wrong or has no class attribute
   */
  protected double [] makeDistribution(Instances neighbours, double[] distances)
    throws Exception {

    double total = 0, weight;
    double [] distribution = new double [m_NumClasses];

    // Set up a correction to the estimator: for nominal classes each class
    // starts with a small pseudo-count so that no class has zero probability.
    if (m_ClassType == Attribute.NOMINAL) {
      for(int i = 0; i < m_NumClasses; i++) {
	distribution[i] = 1.0 / Math.max(1,m_Train.numInstances());
      }
      total = (double)m_NumClasses / Math.max(1,m_Train.numInstances());
    }

    for(int i=0; i < neighbours.numInstances(); i++) {
      // Collect class counts
      Instance current = neighbours.instance(i);
      // Scale the distance by the number of attributes used:
      // square then sqrt(sq / n), i.e. |d| / sqrt(m_NumAttributesUsed).
      distances[i] = distances[i]*distances[i];
      distances[i] = Math.sqrt(distances[i]/m_NumAttributesUsed);
      switch (m_DistanceWeighting) {
        case WEIGHT_INVERSE:
          weight = 1.0 / (distances[i] + 0.001); // to avoid div by zero
          break;
        case WEIGHT_SIMILARITY:
          weight = 1.0 - distances[i];
          break;
        default:                                 // WEIGHT_NONE:
          weight = 1.0;
          break;
      }
      weight *= current.weight();
      try {
        switch (m_ClassType) {
          case Attribute.NOMINAL:
            // Nominal: accumulate weighted votes per class value.
            distribution[(int)current.classValue()] += weight;
            break;
          case Attribute.NUMERIC:
            // Numeric: accumulate the weighted class-value sum in slot 0.
            distribution[0] += current.classValue() * weight;
            break;
        }
      } catch (Exception ex) {
        throw new Error("Data has no class attribute!");
      }
      total += weight;
    }

    // Normalise distribution
    if (total > 0) {
      Utils.normalize(distribution, total);
    }
    return distribution;
  }

  /**
   * Select the best value for k by hold-one-out cross-validation.
   * If the class attribute is nominal, classification error is
   * minimised. If the class attribute is numeric, mean absolute
   * error is minimised.
   */
  protected void crossValidate() {

    try {
      // performanceStats[j] accumulates the error made when using (j+1)
      // neighbours; performanceStatsSq holds the squared-error variant.
      double [] performanceStats = new double [m_kNNUpper];
      double [] performanceStatsSq = new double [m_kNNUpper];

      for(int i = 0; i < m_kNNUpper; i++) {
	performanceStats[i] = 0;
	performanceStatsSq[i] = 0;
      }

      m_kNN = m_kNNUpper;
      Instance instance;
      Instances neighbours;
      double[] origDistances, convertedDistances;
      for(int i = 0; i < m_Train.numInstances(); i++) {
	if (m_Debug && (i % 50 == 0)) {
	  System.err.print("Cross validating "
			   + i + "/" + m_Train.numInstances() + "\r");
	}
	instance = m_Train.instance(i);
	neighbours = m_NNSearch.kNearestNeighbours(instance, m_kNN);
        origDistances = m_NNSearch.getDistances();

	// Evaluate every k from m_kNNUpper down to 1, pruning the neighbour
	// list between iterations instead of re-querying the search structure.
	for(int j = m_kNNUpper - 1; j >= 0; j--) {
	  // Update the performance stats.
	  // Work on a copy: makeDistribution() modifies the distances array.
          convertedDistances = new double[origDistances.length];
          System.arraycopy(origDistances, 0,
                           convertedDistances, 0, origDistances.length);
	  double [] distribution = makeDistribution(neighbours,
                                                    convertedDistances);
          double thisPrediction = Utils.maxIndex(distribution);
	  if (m_Train.classAttribute().isNumeric()) {
	    thisPrediction = distribution[0];
	    double err = thisPrediction - instance.classValue();
	    performanceStatsSq[j] += err * err;   // Squared error
	    performanceStats[j] += Math.abs(err); // Absolute error
	  } else {
	    if (thisPrediction != instance.classValue()) {
	      performanceStats[j] ++;             // Classification error
	    }
	  }
	  if (j >= 1) {
	    neighbours = pruneToK(neighbours, convertedDistances, j);
	  }
	}
      }

      // Display the results of the cross-validation
      for(int i = 0; i < m_kNNUpper; i++) {
	if (m_Debug) {
	  System.err.print("Hold-one-out performance of " + (i + 1)
			   + " neighbors " );
	}
	if (m_Train.classAttribute().isNumeric()) {
	  if (m_Debug) {
	    if (m_MeanSquared) {
	      System.err.println("(RMSE) = "
				 + Math.sqrt(performanceStatsSq[i]
					     / m_Train.numInstances()));
	    } else {
	      System.err.println("(MAE) = "
				 + performanceStats[i]
				 / m_Train.numInstances());
	    }
	  }
	} else {
	  if (m_Debug) {
	    System.err.println("(%ERR) = "
			       + 100.0 * performanceStats[i]
			       / m_Train.numInstances());
	  }
	}
      }

      // Check through the performance stats and select the best
      // k value (or the lowest k if more than one best)
      double [] searchStats = performanceStats;
      if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
	searchStats = performanceStatsSq;
      }

      double bestPerformance = Double.NaN;
      int bestK = 1;
      for(int i = 0; i < m_kNNUpper; i++) {
	if (Double.isNaN(bestPerformance)
	    || (bestPerformance > searchStats[i])) {
	  bestPerformance = searchStats[i];
	  bestK = i + 1;
	}
      }
      m_kNN = bestK;
      if (m_Debug) {
	System.err.println("Selected k = " + bestK);
      }

      m_kNNValid = true;
    } catch (Exception ex) {
      throw new Error("Couldn't optimize by cross-validation: "
		      +ex.getMessage());
    }
  }

  /**
   * Prunes the list to contain the k nearest neighbors. If there are
   * multiple neighbors at the k'th distance, all will be kept.
   * Assumes distances are sorted in ascending order — TODO confirm
   * against the NearestNeighbourSearch contract.
   *
   * @param neighbours the neighbour instances.
   * @param distances the distances of the neighbours from target instance.
   * @param k the number of neighbors to keep.
   * @return the pruned neighbours.
   */
  public Instances pruneToK(Instances neighbours, double[] distances, int k) {

    if(neighbours==null || distances==null || neighbours.numInstances()==0) {
      return null;
    }
    if (k < 1) {
      k = 1;
    }

    int currentK = 0;
    double currentDist;
    for(int i=0; i < neighbours.numInstances(); i++) {
      currentK++;
      currentDist = distances[i];
      // Cut only once we are past k AND the distance strictly increased,
      // so ties at the k'th distance are all retained.
      if(currentK>k && currentDist!=distances[i-1]) {
        currentK--;
        neighbours = new Instances(neighbours, 0, currentK);
        break;
      }
    }
    return neighbours;
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain command line options (see setOptions)
   */
  public static void main(String [] argv) {
    runClassifier(new IBk(), argv);
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -