⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 classifiermetriclearner.java

📁 wekaUT 是德克萨斯大学奥斯汀分校(University of Texas at Austin)开发的基于 Weka 的半监督学习(semi-supervised learning)分类器
💻 JAVA
字号:
/* *    This program is free software; you can redistribute it and/or modify *    it under the terms of the GNU General Public License as published by *    the Free Software Foundation; either version 2 of the License, or *    (at your option) any later version. * *    This program is distributed in the hope that it will be useful, *    but WITHOUT ANY WARRANTY; without even the implied warranty of *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the *    GNU General Public License for more details. * *    You should have received a copy of the GNU General Public License *    along with this program; if not, write to the Free Software *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA. *//* *    MetricLearner.java *    Copyright (C) 2002 Mikhail Bilenko * */package weka.core.metrics;import java.util.*;import java.io.Serializable;import java.io.*;import java.text.SimpleDateFormat;import weka.classifiers.*;import weka.classifiers.sparse.*;import weka.classifiers.functions.*;import weka.core.*;/**  * ClassifierMetricLearner - learns metric parameters by constructing * "difference instances" and then learning weights that classify same-class * instances as positive, and different-class instances as negative. 
* * @author Mikhail Bilenko (mbilenko@cs.utexas.edu) * @version $Revision: 1.4 $ */public class ClassifierMetricLearner extends MetricLearner implements Serializable, OptionHandler {  /** Classifier that is used for learning metric weights */  protected Classifier m_classifier = new SVMlight();  /** Class attribute for diff-instances can be either nominal or numeric */  protected boolean m_isDiffClassNominal = true;  /** The metric that the classifier was used to learn, useful for external-calculation based metrics */  protected LearnableMetric m_metric = null;  /** The pairwise selector used by the metric */  protected PairwiseSelector m_selector = new HardPairwiseSelector();  protected int m_numPosPairs = 200;  protected int m_numNegPairs = 200;  /** Create a new classifier metric learner    */  public ClassifierMetricLearner() {  }        /**   * Train a given metric using given training instances   *   * @param metric the metric to train   * @param instances data to train the metric on   * @exception Exception if training has gone bad.   */  public void trainMetric(LearnableMetric metric, Instances instances) throws Exception {    // If the data doesn't have a class attribute, bail    if (instances.classIndex() < 0 || instances.numInstances() < 2) {      metric.m_trained = false;      System.out.println("Problem with training data");      return;    }    if (metric.getExternal()) {      m_metric = metric;    }    ArrayList pairList = m_selector.createPairList(instances, m_numPosPairs, m_numNegPairs, metric);    Instances diffInstances = createDiffInstances(pairList, metric);    if (diffInstances == null) {      metric.m_trained = false;      System.out.println("null diffInstances");      return;    }    // BEGIN SANITY CHECK    if (true) {       // dump diff-instances into a temporary file      try {	File diffDir = new File("/tmp/diff");	diffDir.mkdir();	String diffName = instances.relationName() + "." 
+	  Utils.removeSubstring(metric.getClass().getName(), "weka.core.metrics.") + "." +	  Utils.removeSubstring(m_selector.getClass().getName(), "weka.core.metrics.");	if (m_selector instanceof HardPairwiseSelector) {	  diffName = diffName + ((HardPairwiseSelector)m_selector).getNegativesMode().getSelectedTag().getReadable();	  diffName = diffName + ((HardPairwiseSelector)m_selector).getPositivesMode().getSelectedTag().getReadable();	} 	diffInstances.setRelationName(diffName);	PrintWriter writer = new PrintWriter(new BufferedOutputStream (new FileOutputStream(diffDir.getPath() + "/" +											    diffName + ".arff")));	writer.println(diffInstances.toString());	writer.close();	// Do a sanity check - dump out the diffInstances, and	// evaluation classification with an SVM. 	long trainTimeStart = System.currentTimeMillis();	SVMlight classifier = new SVMlight();	Evaluation eval = new Evaluation(diffInstances);	eval.crossValidateModel(classifier, diffInstances, 3);	writer = new PrintWriter(new BufferedOutputStream (new FileOutputStream(diffDir.getPath() + "/" +										diffName + ".dat", true)));	writer.println(eval.pctCorrect());	writer.close();	System.out.println("** Record Sanity:" + (System.currentTimeMillis() - trainTimeStart) + " ms; " +			   eval.pctCorrect() + "% correct\t" +			   eval.numFalseNegatives(0) + "(" + (eval.falseNegativeRate(0)*100) + "%) false negatives\t" +			   eval.numFalsePositives(0) + "(" + (eval.falsePositiveRate(0)*100) + "%) false positives\t");            } catch (Exception e) {	e.printStackTrace();	System.out.println(e.toString());       }    }    // END SANITY CHECK    System.out.println(getTimestamp()+ "  Building classifier: " +		       Utils.removeSubstring(m_classifier.getClass().getName(), "weka.classifiers.") + " "		       + concatStringArray(((OptionHandler)m_classifier).getOptions()));    m_classifier.buildClassifier(diffInstances);    // if we are learning coefficients, put them back into the distance metric    if 
(!metric.getExternal()) {      String classifierName = Utils.removeSubstring(m_classifier.getClass().getName(), "weka.classifiers");      if (classifierName.equals("functions.LinearRegression")) {	double[] coefficients = ((LinearRegression)m_classifier).coefficients();	System.out.println("Learned coefficients " + coefficients.length);	metric.setWeights(coefficients);      }  else if (classifierName.equals("functions.SMO")) {	FastVector weights = ((SMO)m_classifier).weights();	double[] m_sparseWeights = (double[]) weights.elementAt(0);	int[] m_sparseIndices = (int[]) weights.elementAt(1);	double[] coefficients = new double[metric.getNumAttributes()];	for (int i = 0; i < m_sparseIndices.length; i++) {	  coefficients[m_sparseIndices[i]] = m_sparseWeights[i];	} 	metric.setWeights(coefficients);      }    }     System.out.println(getTimestamp() + "  Done building " +		       Utils.removeSubstring(m_classifier.getClass().getName(), "weka.classifiers."));    metric.m_trained = true;  }  /**   * Set the classifier   *   * @param classifier the classifier   */  public void setClassifier (Classifier classifier) {    m_classifier = classifier;  }  /**   * Get the classifier   *   * @returns the classifierthat this metric employs   */  public Classifier getClassifier () {    return m_classifier;  }    /** Set the pairwise selector   * @param selector the selector for training pairs   */  public void setSelector (PairwiseSelector selector) {    m_selector = selector;  }     /** Get the pairwise selector   * @return the selector for training pairs   */  public PairwiseSelector getSelector() {    return m_selector;  }  /** Set the number of same-class training pairs   * @param numPosPairs the number of same-class training pairs to create for training   */  public void setNumPosPairs(int numPosPairs) {    m_numPosPairs = numPosPairs;  }  /** Get  the number of same-class training pairs   * @return the number of same-class training pairs to create for training   */  public int 
getNumPosPairs() {    return m_numPosPairs;  }   /** Set the number of different-class training pairs   * @param numNegPairs the number of different-class training pairs to create for training   */  public void setNumNegPairs(int numNegPairs) {    m_numNegPairs = numNegPairs;  }  /** Get  the number of different-class training pairs   * @return the number of different-class training pairs to create for training   */  public int getNumNegPairs() {    return m_numNegPairs;  }  /**   * Use the Classifier for an estimation of similarity   * @param instance1 first instance of a pair   * @param instance2 second instance of a pair   * @returns sim an approximate similarity obtained from the classifier   */  public double getSimilarity(Instance instance1, Instance instance2) throws Exception{    Instance diffInstance = m_metric.createDiffInstance(instance1, instance2);    double d = (((DistributionClassifier)m_classifier).distributionForInstance(diffInstance))[1];    return Math.exp(-d);  }  /**   * Use the Classifier for an estimation of distance   * @param instance1 first instance of a pair   * @param instance2 second instance of a pair   * @returns  an approximate distance obtained from the classifier   */  public double getDistance(Instance instance1, Instance instance2) throws Exception{    Instance diffInstance = m_metric.createDiffInstance(instance1, instance2);    double d =  (((DistributionClassifier)m_classifier).distributionForInstance(diffInstance))[1];    return d;  }  /**   * Gets the current settings of WeightedDotP.   
*   * @return an array of strings suitable for passing to setOptions()   */  public String [] getOptions() {    String [] options = new String [40];        int current = 0;    options[current++] = "-p";    options[current++] = "" + m_numPosPairs;    options[current++] = "-n";    options[current++] = "" + m_numNegPairs;    options[current++] = "-S";    options[current++] = Utils.removeSubstring(m_selector.getClass().getName(), "weka.core.metrics.");    if (m_selector instanceof OptionHandler) {      String[] selectorOptions = ((OptionHandler)m_selector).getOptions();      for (int i = 0; i < selectorOptions.length; i++) {	options[current++] = selectorOptions[i];      }    }        options[current++] = "-B";    options[current++] = Utils.removeSubstring(m_classifier.getClass().getName(), "weka.classifiers.");    if (m_classifier instanceof OptionHandler) {      String[] classifierOptions = ((OptionHandler)m_classifier).getOptions();      for (int i = 0; i < classifierOptions.length; i++) {	options[current++] = classifierOptions[i];      }    }        while (current < options.length) {      options[current++] = "";    }    return options;  }  /**   * Parses a given list of options. 
Valid options are:<p>   *   * -B classifierstring   */  public void setOptions(String[] options) throws Exception {    String classifierString = Utils.getOption('B', options);    if (classifierString.length() == 0) {      throw new Exception("A classifier must be specified"			  + " with the -B option.");    }    String [] classifierSpec = Utils.splitOptions(classifierString);    if (classifierSpec.length == 0) {      throw new Exception("Invalid classifier specification string");    }    String classifierName = classifierSpec[0];    classifierSpec[0] = "";    System.out.println("Classifier name: " + classifierName + "\nClassifier parameters: " + weka.classifiers.sparse.IBkMetric.concatStringArray(classifierSpec));    setClassifier(Classifier.forName(classifierName, classifierSpec));  }  /**   * Gets a string containing current date and time.   *   * @return a string containing the date and time.   */  protected static String getTimestamp() {    return (new SimpleDateFormat("HH:mm:ss:")).format(new Date());  }  /**   * Returns an enumeration describing the available options.   *   * @return an enumeration of all the available options.   */  public Enumeration listOptions() {     Vector newVector = new Vector(1);    newVector.addElement(new Option(	      "\tFull class name of classifier to use, followed\n"	      + "\tby scheme options. 
(required)\n"	      + "\teg: \"weka.classifiers.bayes.NaiveBayes -D\"",	      "B", 1, "-B <classifier specification>"));        return newVector.elements();  }  /** Obtain a textual description of the metriclearner   * @return a textual description of the metric learner   */  public String toString() {    return new String("ClassifierMetricLearner " + concatStringArray(getOptions()));  }     /** A little helper to create a single String from an array of Strings   * @param strings an array of strings   * @returns a single concatenated string, separated by commas   */  public static String concatStringArray(String[] strings) {    String result = new String();    for (int i = 0; i < strings.length; i++) {      result = result + "\"" + strings[i] + "\" ";    }    return result;  } }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -