⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ibk.java

📁 《数据挖掘——实用机器学习技术及Java实现》一书的配套源程序
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
    if (m_Train == null) {      return "IBk: No model built yet.";    }    if (!m_kNNValid && m_CrossValidate) {      crossValidate();    }    String result = "IB1 instance-based classifier\n" +      "using " + m_kNN;    switch (m_DistanceWeighting) {    case WEIGHT_INVERSE:      result += " inverse-distance-weighted";      break;    case WEIGHT_SIMILARITY:      result += " similarity-weighted";      break;    }    result += " nearest neighbour(s) for classification\n";    if (m_WindowSize != 0) {      result += "using a maximum of " 	+ m_WindowSize + " (windowed) training instances\n";    }    return result;  }  /**   * Initialise scheme variables.   */  private void init() {    setKNN(1);    m_WindowSize = 0;    m_DistanceWeighting = WEIGHT_NONE;    m_CrossValidate = false;    m_MeanSquared = false;    m_DontNormalize = false;  }  /**   * Calculates the distance between two instances   *   * @param test the first instance   * @param train the second instance   * @return the distance between the two given instances, between 0 and 1   */            private double distance(Instance first, Instance second) {      double distance = 0;    int firstI, secondI;    for (int p1 = 0, p2 = 0; 	 p1 < first.numValues() || p2 < second.numValues();) {      if (p1 >= first.numValues()) {	firstI = m_Train.numAttributes();      } else {	firstI = first.index(p1);       }      if (p2 >= second.numValues()) {	secondI = m_Train.numAttributes();      } else {	secondI = second.index(p2);      }      if (firstI == m_Train.classIndex()) {	p1++; continue;      }       if (secondI == m_Train.classIndex()) {	p2++; continue;      }       double diff;      if (firstI == secondI) {	diff = difference(firstI, 			  first.valueSparse(p1),			  second.valueSparse(p2));	p1++; p2++;      } else if (firstI > secondI) {	diff = difference(secondI, 			  0, second.valueSparse(p2));	p2++;      } else {	diff = difference(firstI, 			  first.valueSparse(p1), 0);	p1++;      }      distance += diff * diff;    }    
    return Math.sqrt(distance / m_NumAttributesUsed);  }     /**   * Computes the difference between two given attribute   * values.   */  private double difference(int index, double val1, double val2) {    switch (m_Train.attribute(index).type()) {    case Attribute.NOMINAL:            // If attribute is nominal      if (Instance.isMissingValue(val1) || 	  Instance.isMissingValue(val2) ||	  ((int)val1 != (int)val2)) {	return 1;      } else {	return 0;      }    case Attribute.NUMERIC:      // If attribute is numeric      if (Instance.isMissingValue(val1) || 	  Instance.isMissingValue(val2)) {	if (Instance.isMissingValue(val1) && 	    Instance.isMissingValue(val2)) {	  return 1;	} else {	  double diff;	  if (Instance.isMissingValue(val2)) {	    diff = norm(val1, index);	  } else {	    diff = norm(val2, index);	  }	  if (diff < 0.5) {	    diff = 1.0 - diff;	  }	  return diff;	}      } else {	return norm(val1, index) - norm(val2, index);      }    default:      return 0;    }  }  /**   * Normalizes a given value of a numeric attribute.   *   * @param x the value to be normalized   * @param i the attribute's index   */  private double norm(double x, int i) {    if (m_DontNormalize) {      return x;    } else if (Double.isNaN(m_Min[i]) || Utils.eq(m_Max[i],m_Min[i])) {      return 0;    } else {      return (x - m_Min[i]) / (m_Max[i] - m_Min[i]);    }  }                        /**   * Updates the minimum and maximum values for all the attributes   * based on a new instance.   
*   * @param instance the new instance   */  private void updateMinMax(Instance instance) {      for (int j = 0;j < m_Train.numAttributes(); j++) {      if (!instance.isMissing(j)) {	if (Double.isNaN(m_Min[j])) {	  m_Min[j] = instance.value(j);	  m_Max[j] = instance.value(j);	} else {	  if (instance.value(j) < m_Min[j]) {	    m_Min[j] = instance.value(j);	  } else {	    if (instance.value(j) > m_Max[j]) {	      m_Max[j] = instance.value(j);	    }	  }	}      }    }  }      /**   * Build the list of nearest k neighbors to the given test instance.   *   * @param instance the instance to search for neighbours of   * @return a list of neighbors   */  private NeighborList findNeighbors(Instance instance) {    double distance;    NeighborList neighborlist = new NeighborList(m_kNN);    Enumeration enum = m_Train.enumerateInstances();    int i = 0;    while (enum.hasMoreElements()) {      Instance trainInstance = (Instance) enum.nextElement();      if (instance != trainInstance) { // for hold-one-out cross-validation	distance = distance(instance, trainInstance);	if (neighborlist.isEmpty() || (i < m_kNN) || 	    (distance <= neighborlist.m_Last.m_Distance)) {	  neighborlist.insertSorted(distance, trainInstance);	}	i++;      }    }    return neighborlist;  }  /**   * Turn the list of nearest neighbors into a probability distribution   *   * @param neighborlist the list of nearest neighboring instances   * @return the probability distribution   */  private double [] makeDistribution(NeighborList neighborlist)     throws Exception {    double total = 0, weight;    double [] distribution = new double [m_NumClasses];        // Set up a correction to the estimator    if (m_ClassType == Attribute.NOMINAL) {      for(int i = 0; i < m_NumClasses; i++) {	distribution[i] = 1.0 / Math.max(1,m_Train.numInstances());      }      total = (double)m_NumClasses / Math.max(1,m_Train.numInstances());    }    if (!neighborlist.isEmpty()) {      // Collect class counts      NeighborNode current = 
neighborlist.m_First;      while (current != null) {	switch (m_DistanceWeighting) {	case WEIGHT_INVERSE:	  weight = 1.0 / (current.m_Distance + 0.001); // to avoid div by zero	  break;	case WEIGHT_SIMILARITY:	  weight = 1.0 - current.m_Distance;	  break;	default:                                       // WEIGHT_NONE:	  weight = 1.0;	  break;	}	weight *= current.m_Instance.weight();	try {	  switch (m_ClassType) {	  case Attribute.NOMINAL:	    distribution[(int)current.m_Instance.classValue()] += weight;	    break;	  case Attribute.NUMERIC:	    distribution[0] += current.m_Instance.classValue() * weight;	    break;	  }	} catch (Exception ex) {	  throw new Error("Data has no class attribute!");	}	total += weight;	current = current.m_Next;      }    }    // Normalise distribution    if (total > 0) {      Utils.normalize(distribution, total);    }    return distribution;  }  /**   * Select the best value for k by hold-one-out cross-validation.   * If the class attribute is nominal, classification error is   * minimised. 
If the class attribute is numeric, mean absolute   * error is minimised   */  private void crossValidate() {    try {      double [] performanceStats = new double [m_kNNUpper];      double [] performanceStatsSq = new double [m_kNNUpper];      for(int i = 0; i < m_kNNUpper; i++) {	performanceStats[i] = 0;	performanceStatsSq[i] = 0;      }      m_kNN = m_kNNUpper;      Instance instance;      NeighborList neighborlist;      for(int i = 0; i < m_Train.numInstances(); i++) {	if (m_Debug && (i % 50 == 0)) {	  System.err.print("Cross validating "			   + i + "/" + m_Train.numInstances() + "\r");	}	instance = m_Train.instance(i);	neighborlist = findNeighbors(instance);	for(int j = m_kNNUpper - 1; j >= 0; j--) {	  // Update the performance stats	  double [] distribution = makeDistribution(neighborlist);	  double thisPrediction = Utils.maxIndex(distribution);	  if (m_Train.classAttribute().isNumeric()) {	    double err = thisPrediction - instance.classValue();	    performanceStatsSq[j] += err * err;   // Squared error	    performanceStats[j] += Math.abs(err); // Absolute error	  } else {	    if (thisPrediction != instance.classValue()) {	      performanceStats[j] ++;             // Classification error	    }	  }	  if (j >= 1) {	    neighborlist.pruneToK(j);	  }	}      }      // Display the results of the cross-validation      for(int i = 0; i < m_kNNUpper; i++) {	if (m_Debug) {	  System.err.print("Hold-one-out performance of " + (i + 1)			   + " neighbors " );	}	if (m_Train.classAttribute().isNumeric()) {	  if (m_Debug) {	    if (m_MeanSquared) {	      System.err.println("(RMSE) = "				 + Math.sqrt(performanceStatsSq[i]					     / m_Train.numInstances()));	    } else {	      System.err.println("(MAE) = "				 + performanceStats[i]				 / m_Train.numInstances());	    }	  }	} else {	  if (m_Debug) {	    System.err.println("(%ERR) = "			       + 100.0 * performanceStats[i]			       / m_Train.numInstances());	  }	}      }      // Check through the performance stats and select the 
best      // k value (or the lowest k if more than one best)      double [] searchStats = performanceStats;      if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {	searchStats = performanceStatsSq;      }      double bestPerformance = Double.NaN;      int bestK = 1;      for(int i = 0; i < m_kNNUpper; i++) {	if (Double.isNaN(bestPerformance)	    || (bestPerformance > searchStats[i])) {	  bestPerformance = searchStats[i];	  bestK = i + 1;	}      }      m_kNN = bestK;      if (m_Debug) {	System.err.println("Selected k = " + bestK);      }            m_kNNValid = true;    } catch (Exception ex) {      throw new Error("Couldn't optimize by cross-validation: "		      +ex.getMessage());    }  }  /**   * Main method for testing this class.   *   * @param argv should contain command line options (see setOptions)   */  public static void main(String [] argv) {    try {      System.out.println(Evaluation.evaluateModel(new IBk(), argv));    } catch (Exception e) {      e.printStackTrace();      System.err.println(e.getMessage());    }  }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -