⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ibk.java

📁 wekaUT 是 University of Texas at Austin 开发的基于 weka 的半监督学习(semi-supervised learning)分类器
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
    }
    if (getKDTree() != null) {
      options[current++] = "-E";
      options[current++] = "" + getKDTreeSpec();
    }
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Returns a description of this classifier.
   *
   * @return a description of this classifier as a string.
   */
  public String toString() {

    if (m_Train == null) {
      return "IBk: No model built yet.";
    }
    // Make sure the chosen k is up to date before reporting it.
    if (!m_kNNValid && m_CrossValidate) {
      crossValidate();
    }
    String result = "IB1 instance-based classifier\n" +
      "using " + m_kNN;
    switch (m_DistanceWeighting) {
    case WEIGHT_INVERSE:
      result += " inverse-distance-weighted";
      break;
    case WEIGHT_SIMILARITY:
      result += " similarity-weighted";
      break;
    }
    result += " nearest neighbour(s) for classification\n";
    if (m_WindowSize != 0) {
      result += "using a maximum of "
        + m_WindowSize + " (windowed) training instances\n";
    }
    return result;
  }

  /**
   * Initialise scheme variables to their default values.
   */
  private void init() {

    setKNN(1);
    m_WindowSize = 0;
    m_DistanceWeighting = WEIGHT_NONE;
    m_CrossValidate = false;
    m_MeanSquared = false;
    m_DontNormalize = false;
  }

  /**
   * Calculates the distance between two instances.
   *
   * @param first the first instance
   * @param second the second instance
   * @return the distance between the two given instances, between 0 and 1
   */
  private double distance(Instance first, Instance second) {

    if (!Instances.inRanges(first, m_Ranges))
      OOPS("Not in ranges");
    if (!Instances.inRanges(second, m_Ranges))
      OOPS("Not in ranges");

    double distance = 0;
    int firstI, secondI;

    // Walk both (possibly sparse) instances in parallel over their
    // stored values; an index past numValues() is flagged with a
    // sentinel one beyond the last attribute.
    for (int p1 = 0, p2 = 0;
         p1 < first.numValues() || p2 < second.numValues();) {
      if (p1 >= first.numValues()) {
        firstI = m_Train.numAttributes();
      } else {
        firstI = first.index(p1);
      }
      if (p2 >= second.numValues()) {
        secondI = m_Train.numAttributes();
      } else {
        secondI = second.index(p2);
      }
      // The class attribute never contributes to the distance.
      if (firstI == m_Train.classIndex()) {
        p1++; continue;
      }
      if (secondI == m_Train.classIndex()) {
        p2++; continue;
      }
      double diff;
      if (firstI == secondI) {
        diff = difference(firstI,
                          first.valueSparse(p1),
                          second.valueSparse(p2));
        p1++; p2++;
      } else if (firstI > secondI) {
        // Attribute absent from 'first' is implicitly zero (sparse format).
        diff = difference(secondI,
                          0, second.valueSparse(p2));
        p2++;
      } else {
        diff = difference(firstI,
                          first.valueSparse(p1), 0);
        p1++;
      }
      distance += diff * diff;
    }

    distance = Math.sqrt(distance / m_NumAttributesUsed);
    return distance;
  }

  /**
   * Computes the difference between two given attribute values.
   *
   * @param index the attribute's index
   * @param val1 the first value
   * @param val2 the second value
   * @return the difference (normalized for numeric attributes)
   */
  private double difference(int index, double val1, double val2) {

    switch (m_Train.attribute(index).type()) {
    case Attribute.NOMINAL:
      // Nominal attribute: 0 iff both values present and equal, else 1.
      if (Instance.isMissingValue(val1) ||
          Instance.isMissingValue(val2) ||
          ((int) val1 != (int) val2)) {
        return 1;
      } else {
        return 0;
      }
    case Attribute.NUMERIC:
      // Numeric attribute
      if (Instance.isMissingValue(val1) ||
          Instance.isMissingValue(val2)) {
        if (Instance.isMissingValue(val1) &&
            Instance.isMissingValue(val2)) {
          // Both missing: assume maximal difference.
          return 1;
        } else {
          // One missing: assume the worst-case difference from the
          // normalized known value.
          double diff;
          if (Instance.isMissingValue(val2)) {
            diff = norm(val1, index);
          } else {
            diff = norm(val2, index);
          }
          if (diff < 0.5) {
            diff = 1.0 - diff;
          }
          return diff;
        }
      } else {
        return norm(val1, index) - norm(val2, index);
      }
    default:
      return 0;
    }
  }

  /**
   * Normalizes a given value of a numeric attribute.
   *
   * @param x the value to be normalized
   * @param i the attribute's index
   * @return the value scaled into [0, 1] by the observed min/max range
   */
  private double norm(double x, int i) {

    if (m_DontNormalize) {
      return x;
    } else if (Double.isNaN(m_Ranges[i][R_MIN]) ||
               Utils.eq(m_Ranges[i][R_MAX], m_Ranges[i][R_MIN])) {
      // No range information yet, or a constant attribute.
      return 0;
    } else {
      return (x - m_Ranges[i][R_MIN]) / (m_Ranges[i][R_MAX] - m_Ranges[i][R_MIN]);
    }
  }

  /**
   * Updates the minimum and maximum values for all the attributes
   * based on a new instance.
   *
   * @param instance the new instance
   */
  private void updateMinMax(Instance instance) {

    for (int j = 0; j < m_Train.numAttributes(); j++) {
      if (!instance.isMissing(j)) {
        if (Double.isNaN(m_Ranges[j][R_MIN])) {
          // First observed value initialises both bounds.
          m_Ranges[j][R_MIN] = instance.value(j);
          m_Ranges[j][R_MAX] = instance.value(j);
        } else if (instance.value(j) < m_Ranges[j][R_MIN]) {
          m_Ranges[j][R_MIN] = instance.value(j);
        } else if (instance.value(j) > m_Ranges[j][R_MAX]) {
          m_Ranges[j][R_MAX] = instance.value(j);
        }
      }
    }
  }

  /**
   * Build the list of nearest k neighbours to the given test instance.
   *
   * @param instance the instance to search for neighbours of
   * @return a list of neighbours
   */
  private NeighbourList findNeighbours(Instance instance) throws Exception {

    double distance;
    NeighbourList neighbourlist = new NeighbourList(m_kNN);

    if (m_KDTree == null) {
      // No KDTree: linear scan over the training data.
      // ("enum" renamed to "en" -- it is a reserved word since Java 5
      // and no longer compiles as an identifier.)
      Enumeration en = m_Train.enumerateInstances();
      int i = 0;
      while (en.hasMoreElements()) {
        Instance trainInstance = (Instance) en.nextElement();
        if (instance != trainInstance) { // for hold-one-out cross-validation
          distance = distance(instance, trainInstance);
          if (neighbourlist.isEmpty() || (i < m_kNN) ||
              (distance <= neighbourlist.m_Last.m_Distance)) {
            neighbourlist.insertSorted(distance, trainInstance);
          }
          i++;
        }
      }
    } else {
      // Work with the KDTree.
      double[] distanceList = new double[m_KDTree.numInstances()];
      int[] instanceList = new int[m_KDTree.numInstances()];
      int numOfNearest = m_KDTree.findKNearestNeighbour(instance, m_kNN,
                                                        instanceList, distanceList);
      for (int i = 0; i < numOfNearest; i++) {
        neighbourlist.insertSorted(distanceList[i],
                                   m_KDTree.getInstances().instance(instanceList[i]));
      }
    }
    return neighbourlist;
  }

  /**
   * Turn the list of nearest neighbours into a probability distribution.
   *
   * @param neighbourlist the list of nearest neighbouring instances
   * @return the probability distribution
   */
  private double [] makeDistribution(NeighbourList neighbourlist)
    throws Exception {

    double total = 0, weight;
    double [] distribution = new double [m_NumClasses];

    // Set up a correction to the estimator: a prior count of
    // 1/numInstances per class so no class gets zero probability.
    if (m_ClassType == Attribute.NOMINAL) {
      for (int i = 0; i < m_NumClasses; i++) {
        distribution[i] = 1.0 / Math.max(1, m_Train.numInstances());
      }
      total = (double) m_NumClasses / Math.max(1, m_Train.numInstances());
    }

    if (!neighbourlist.isEmpty()) {
      // Collect class counts
      NeighbourNode current = neighbourlist.m_First;
      while (current != null) {
        switch (m_DistanceWeighting) {
        case WEIGHT_INVERSE:
          weight = 1.0 / (current.m_Distance + 0.001); // to avoid div by zero
          break;
        case WEIGHT_SIMILARITY:
          weight = 1.0 - current.m_Distance;
          break;
        default:                                       // WEIGHT_NONE:
          weight = 1.0;
          break;
        }
        weight *= current.m_Instance.weight();
        try {
          switch (m_ClassType) {
          case Attribute.NOMINAL:
            distribution[(int)current.m_Instance.classValue()] += weight;
            break;
          case Attribute.NUMERIC:
            distribution[0] += current.m_Instance.classValue() * weight;
            break;
          }
        } catch (Exception ex) {
          throw new Error("Data has no class attribute!");
        }
        total += weight;
        current = current.m_Next;
      }
    }

    // Normalise distribution
    if (total > 0) {
      Utils.normalize(distribution, total);
    }
    return distribution;
  }

  /**
   * Select the best value for k by hold-one-out cross-validation.
   * If the class attribute is nominal, classification error is
   * minimised.  If the class attribute is numeric, mean absolute
   * error is minimised (or mean squared error if m_MeanSquared is set).
   */
  private void crossValidate() {

    try {
      double [] performanceStats = new double [m_kNNUpper];
      double [] performanceStatsSq = new double [m_kNNUpper];

      for (int i = 0; i < m_kNNUpper; i++) {
        performanceStats[i] = 0;
        performanceStatsSq[i] = 0;
      }

      m_kNN = m_kNNUpper;
      Instance instance;
      NeighbourList neighbourlist;
      for (int i = 0; i < m_Train.numInstances(); i++) {
        if (m_Debug && (i % 50 == 0)) {
          System.err.print("Cross validating "
                           + i + "/" + m_Train.numInstances() + "\r");
        }
        instance = m_Train.instance(i);
        neighbourlist = findNeighbours(instance);

        // Evaluate every k from m_kNNUpper down to 1, pruning the
        // neighbour list as we go.
        for (int j = m_kNNUpper - 1; j >= 0; j--) {
          // Update the performance stats
          double [] distribution = makeDistribution(neighbourlist);
          double thisPrediction = Utils.maxIndex(distribution);
          if (m_Train.classAttribute().isNumeric()) {
            // BUG FIX: for a numeric class the prediction is the
            // weighted mean stored in distribution[0]; maxIndex() of a
            // one-element array is always 0 and gave a bogus error.
            thisPrediction = distribution[0];
            double err = thisPrediction - instance.classValue();
            performanceStatsSq[j] += err * err;   // Squared error
            performanceStats[j] += Math.abs(err); // Absolute error
          } else {
            if (thisPrediction != instance.classValue()) {
              performanceStats[j] ++;             // Classification error
            }
          }
          if (j >= 1) {
            neighbourlist.pruneToK(j);
          }
        }
      }

      // Display the results of the cross-validation
      for (int i = 0; i < m_kNNUpper; i++) {
        if (m_Debug) {
          System.err.print("Hold-one-out performance of " + (i + 1)
                           + " neighbours " );
        }
        if (m_Train.classAttribute().isNumeric()) {
          if (m_Debug) {
            if (m_MeanSquared) {
              System.err.println("(RMSE) = "
                                 + Math.sqrt(performanceStatsSq[i]
                                             / m_Train.numInstances()));
            } else {
              System.err.println("(MAE) = "
                                 + performanceStats[i]
                                 / m_Train.numInstances());
            }
          }
        } else {
          if (m_Debug) {
            System.err.println("(%ERR) = "
                               + 100.0 * performanceStats[i]
                               / m_Train.numInstances());
          }
        }
      }

      // Check through the performance stats and select the best
      // k value (or the lowest k if more than one best)
      double [] searchStats = performanceStats;
      if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {
        searchStats = performanceStatsSq;
      }
      double bestPerformance = Double.NaN;
      int bestK = 1;
      for (int i = 0; i < m_kNNUpper; i++) {
        if (Double.isNaN(bestPerformance)
            || (bestPerformance > searchStats[i])) {
          bestPerformance = searchStats[i];
          bestK = i + 1;
        }
      }
      m_kNN = bestK;
      if (m_Debug) {
        System.err.println("Selected k = " + bestK);
      }

      m_kNNValid = true;
    } catch (Exception ex) {
      // Preserve the original failure as the cause instead of
      // flattening it into the message only.
      throw new Error("Couldn't optimize by cross-validation: "
                      + ex.getMessage(), ex);
    }
  }

  /**
   * Used for debug println's.
   *
   * @param output string that is printed
   */
  private void OOPS(String output) {
    System.out.println(output);
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain command line options (see setOptions)
   */
  public static void main(String [] argv) {
    try {
      System.out.println(Evaluation.evaluateModel(new IBk(), argv));
    } catch (Exception e) {
      e.printStackTrace();
      System.err.println(e.getMessage());
    }
  }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -