⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 ibkmetric.java

📁 wekaUT是 university texas austin 开发的基于weka的半指导学习(semi supervised learning)的分类器
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
     * -X <br>     * Select the number of neighbors to use by hold-one-out cross     * validation, with an upper limit given by the -K option. <p>     *     * -S <br>     * When k is selected by cross-validation for numeric class attributes,     * minimize mean-squared error. (default mean absolute error) <p>     *     * @param options the list of options as an array of strings     * @exception Exception if an option is not supported     */    public void setOptions(String[] options) throws Exception {    	String knnString = Utils.getOption('K', options);	if (knnString.length() != 0) {	    setKNN(Integer.parseInt(knnString));	} else {	    setKNN(1);	}	String windowString = Utils.getOption('W', options);	if (windowString.length() != 0) {	    setWindowSize(Integer.parseInt(windowString));	} else {	    setWindowSize(0);	}	if (Utils.getFlag('D', options)) {	    setDistanceWeighting(new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));	} else if (Utils.getFlag('F', options)) {	    setDistanceWeighting(new SelectedTag(WEIGHT_SIMILARITY, TAGS_WEIGHTING));	} else {	    setDistanceWeighting(new SelectedTag(WEIGHT_NONE, TAGS_WEIGHTING));	}	setCrossValidate(Utils.getFlag('X', options));	setMeanSquared(Utils.getFlag('S', options));	String metricString = Utils.getOption('M', options);	if (metricString.length() != 0) {	  String[] metricSpec = Utils.splitOptions(metricString);	  String metricName = metricSpec[0]; 	  metricSpec[0] = "";	  System.out.println("Metric name: " + metricName + "\nMetric parameters: " + concatStringArray(metricSpec));	  setMetric(Metric.forName(metricName, metricSpec));	}	Utils.checkForRemainingOptions(options);    }  /**   * Gets the classifier specification string, which contains the class name of   * the classifier and any options to the classifier   *   * @return the classifier string.   */  protected String getMetricSpec() {    if (m_metric instanceof OptionHandler) {      return m_metric.getClass().getName() + " "	+ Utils.joinOptions(((OptionHandler)m_metric).getOptions());    }    return m_metric.getClass().getName();  }  /**   * Gets the current settings of IBkMetric.   *   * @return an array of strings suitable for passing to setOptions()   */  public String [] getOptions() {    String [] options = new String [70];    int current = 0;    options[current++] = "-K"; options[current++] = "" + getKNN();    options[current++] = "-W"; options[current++] = "" + m_WindowSize;    if (getCrossValidate()) {      options[current++] = "-X";    }    if (getMeanSquared()) {      options[current++] = "-S";    }    if (m_DistanceWeighting == WEIGHT_INVERSE) {      options[current++] = "-D";    } else if (m_DistanceWeighting == WEIGHT_SIMILARITY) {      options[current++] = "-F";    }    options[current++] = "-M";    options[current++] = Utils.removeSubstring(m_metric.getClass().getName(), "weka.core.metrics.");    if (m_metric instanceof OptionHandler) {      String[] metricOptions = ((OptionHandler)m_metric).getOptions();      for (int i = 0; i < metricOptions.length; i++) {	options[current++] = metricOptions[i];      }    }        while (current < options.length) {      options[current++] = "";    }    return options;  }    /**     * Returns a description of this classifier.     *     * @return a description of this classifier as a string.     */    public String toString() {	if (m_Train == null) {	    return "IBkMetric: No model built yet.";	}	if (!m_kNNValid && m_CrossValidate) {	    crossValidate();	}	String result = "IB1 instance-based classifier\n" +	    "using " + m_kNN;	switch (m_DistanceWeighting) {	case WEIGHT_INVERSE:	    result += " inverse-distance-weighted";	    break;	case WEIGHT_SIMILARITY:	    result += " similarity-weighted";	    break;	}	result += " nearest neighbour(s) for classification\n";	if (m_WindowSize != 0) {	    result += "using a maximum of " 		+ m_WindowSize + " (windowed) training instances\n";	}	return result;    }    /**     * Initialise scheme variables.     */    protected void init() {	setKNN(1);	m_WindowSize = 0;	m_DistanceWeighting = WEIGHT_NONE;	m_CrossValidate = false;	m_MeanSquared = false;    }        /**     * Build the list of nearest k neighbors to the given test instance.     *     * @param instance the instance to search for neighbours of     * @return a list of neighbors     */    protected NeighborList findNeighbors(Instance instance) throws Exception {	double distance;	NeighborList neighborlist = new NeighborList(m_kNN);	Enumeration enum = m_Train.enumerateInstances();	int i = 0;	while (enum.hasMoreElements()) {	    Instance trainInstance = (Instance) enum.nextElement();	    if (instance != trainInstance) { // for hold-one-out cross-validation		distance = m_metric.distance(instance, trainInstance);		if (neighborlist.isEmpty() || (i < m_kNN) || 		    (distance <= neighborlist.m_Last.m_Distance)) {		    neighborlist.insertSorted(distance, trainInstance);		}		i++;	    }	}	return neighborlist;    }    /**     * Turn the list of nearest neighbors into a probability distribution     *     * @param neighborlist the list of nearest neighboring instances     * @return the probability distribution     */    protected double [] makeDistribution(NeighborList neighborlist) 	throws Exception {	double total = 0, weight;	double [] distribution = new double [m_NumClasses];    	// Set up a Laplacian correction to the estimator	if (m_ClassType == Attribute.NOMINAL) {	    for(int i = 0; i < m_NumClasses; i++) {		distribution[i] = 1.0 / Math.max(1,m_Train.numInstances());	    }	    total = (double)m_NumClasses / Math.max(1,m_Train.numInstances());	}	if (!neighborlist.isEmpty()) {	    // Collect class counts	    NeighborNode current = neighborlist.m_First;	    while (current != null) {		switch (m_DistanceWeighting) {		case WEIGHT_INVERSE:		    weight = 1.0 / (current.m_Distance + m_EPSILON); // to avoid div by zero		    break;		case WEIGHT_SIMILARITY:		    weight = 1.0 - current.m_Distance;		    break;		default:                                       // WEIGHT_NONE:		    weight = 1.0;		    break;		}		weight *= current.m_Instance.weight();		try {		    switch (m_ClassType) {		    case Attribute.NOMINAL:			distribution[(int)current.m_Instance.classValue()] += weight;			break;		    case Attribute.NUMERIC:			distribution[0] += current.m_Instance.classValue() * weight;			break;		    }		} catch (Exception ex) {		    throw new Error("Data has no class attribute!");		}		total += weight;		current = current.m_Next;	    }	}	// Normalise distribution	if (total > 0) {	    Utils.normalize(distribution, total);	}	return distribution;    }    /**     * Select the best value for k by hold-one-out cross-validation.     * If the class attribute is nominal, classification error is     * minimised. If the class attribute is numeric, mean absolute     * error is minimised     */    protected void crossValidate() {	try {	    double [] performanceStats = new double [m_kNNUpper];	    double [] performanceStatsSq = new double [m_kNNUpper];	    for(int i = 0; i < m_kNNUpper; i++) {		performanceStats[i] = 0;		performanceStatsSq[i] = 0;	    }	    m_kNN = m_kNNUpper;	    Instance instance;	    NeighborList neighborlist;	    for(int i = 0; i < m_Train.numInstances(); i++) {		if (m_Debug && (i % 50 == 0)) {		    System.err.print("Cross validating "				     + i + "/" + m_Train.numInstances() + "\r");		}		instance = m_Train.instance(i);		neighborlist = findNeighbors(instance);		for(int j = m_kNNUpper - 1; j >= 0; j--) {		    // Update the performance stats		    double [] distribution = makeDistribution(neighborlist);		    double thisPrediction = Utils.maxIndex(distribution);		    if (m_Train.classAttribute().isNumeric()) {			double err = thisPrediction - instance.classValue();			performanceStatsSq[j] += err * err;   // Squared error			performanceStats[j] += Math.abs(err); // Absolute error		    } else {			if (thisPrediction != instance.classValue()) {			    performanceStats[j] ++;             // Classification error			}		    }		    if (j >= 1) {			neighborlist.pruneToK(j);		    }		}	    }	    // Display the results of the cross-validation	    for(int i = 0; i < m_kNNUpper; i++) {		if (m_Debug) {		    System.err.print("Hold-one-out performance of " + (i + 1)				     + " neighbors " );		}		if (m_Train.classAttribute().isNumeric()) {		    if (m_Debug) {			if (m_MeanSquared) {			    System.err.println("(RMSE) = "					       + Math.sqrt(performanceStatsSq[i]							   / m_Train.numInstances()));			} else {			    System.err.println("(MAE) = "					       + performanceStats[i]					       / m_Train.numInstances());			}		    }		} else {		    if (m_Debug) {			System.err.println("(%ERR) = "					   + 100.0 * performanceStats[i]					   / m_Train.numInstances());		    }		}	    }	    // Check through the performance stats and select the best	    // k value (or the lowest k if more than one best)	    double [] searchStats = performanceStats;	    if (m_Train.classAttribute().isNumeric() && m_MeanSquared) {		searchStats = performanceStatsSq;	    }	    double bestPerformance = Double.NaN;	    int bestK = 1;	    for(int i = 0; i < m_kNNUpper; i++) {		if (Double.isNaN(bestPerformance)		    || (bestPerformance > searchStats[i])) {		    bestPerformance = searchStats[i];		    bestK = i + 1;		}	    }	    m_kNN = bestK;	    if (m_Debug) {		System.err.println("Selected k = " + bestK);	    }      	    m_kNNValid = true;	} catch (Exception ex) {	    throw new Error("Couldn't optimize by cross-validation: "			    +ex.getMessage());	}    }  /** A little helper to create a single String from an array of Strings   * @param strings an array of strings   * @returns a single concatenated string, separated by commas   */  public static String concatStringArray(String[] strings) {    String result = new String();    for (int i = 0; i < strings.length; i++) {      result = result + "\"" + strings[i] + "\" ";    }    return result;  }     /**     * Main method for testing this class.     *     * @param argv should contain command line options (see setOptions)     */    public static void main(String [] argv) {	try {	    System.out.println(Evaluation.evaluateModel(new IBkMetric(), argv));	} catch (Exception e) {	    e.printStackTrace();	    System.err.println(e.getMessage());	}    }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -