matlabmetriclearner.java

来自「wekaUT是 university texas austin 开发的基于wek」· Java 代码 · 共 231 行

JAVA
231
字号
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    MatlabMetricLearner.java
 *    Copyright (C) 2002 Mikhail Bilenko
 *
 */

package weka.core.metrics;

import java.util.*;
import java.io.*;
import java.text.SimpleDateFormat;
import weka.core.*;

/**
 * MatlabMetricLearner - learns metric parameters by constructing
 * "difference instances" and then learning weights that classify same-class
 * instances as positive, and different-class instances as negative using an
 * external Matlab program.
 *
 * <p>Training writes a Matlab script and two matrix files to fixed temporary
 * locations, runs Matlab on the script, and reads the resulting weight vector
 * back from a third file. Because the file names are fixed, two learners
 * training at the same time on one machine would overwrite each other's files.
 *
 * @author Mikhail Bilenko (mbilenko@cs.utexas.edu)
 * @version $Revision: 1.1 $
 */
public class MatlabMetricLearner extends MetricLearner implements Serializable {

  /** How long to wait before re-running Matlab after a non-zero exit value, in milliseconds */
  private static final long RETRY_DELAY_MS = 300000;

  /** Matlab program that is used for learning metric weights */
  protected String m_scriptFilename = "/tmp/matlab1.m";

  /** Name of the temporary file where the matrix representing the same-class diff. instances is going to be */
  protected String m_posMatrixFilename = "/tmp/posMatrix.txt";

  /** Name of the temporary file where the matrix representing the diff-class diff. instances is going to be */
  protected String m_negMatrixFilename = "/tmp/negMatrix.txt";

  /** Name of the temporary file where the weights will be stored by Matlab after calculation */
  protected String m_weightsFilename = "/tmp/weights.txt";

  /** Debugging output */
  protected boolean m_debug = true;

  /** Create a new matlab metric learner */
  public MatlabMetricLearner() {
  }

  /**
   * Train a given metric using given training instances.
   * Does nothing if the data has no class attribute. Otherwise the
   * difference instances are dumped to disk, Matlab is run on the generated
   * script, and the weights it saves are installed into the metric.
   *
   * @param metric the metric to train
   * @param instances data to train the metric on
   * @exception Exception if training has gone bad.
   */
  public void trainMetric(LearnableMetric metric, Instances instances) throws Exception {
    // If the data doesn't have a class attribute, bail
    if (instances.classIndex() < 0) {
      return;
    }

    // First, create positive and negative diff-instances
    ArrayList[] diffInstanceLists = createDiffInstanceLists(instances, metric,
                                                            metric.getNumPosDiffInstances(),
                                                            metric.getPosNegDiffInstanceRatio());
    ArrayList posDiffInstanceList = diffInstanceLists[0];
    ArrayList negDiffInstanceList = diffInstanceLists[1];

    // Hand the problem to Matlab through files, then read the answer back
    prepareMatlabScript();
    dumpInstanceList(posDiffInstanceList, m_posMatrixFilename);
    dumpInstanceList(negDiffInstanceList, m_negMatrixFilename);
    runMatlab(m_scriptFilename, "matlab.out");

    double[] coefficients = readVector(m_weightsFilename);
    if (m_debug) {
      System.out.println(getTimestamp() + " Read " + coefficients.length + " coefficients");
    }
    metric.setWeights(coefficients);
  }

  /**
   * Create the Matlab m-file that fits the metric weights.
   * The script is written to <code>m_scriptFilename</code>. It loads the
   * same-class matrix as S and the different-class matrix as D, calls
   * <code>fmincon</code> with objective <code>1/norm(S*x)</code>, D and b as
   * the linear inequality arguments and bounds of 0 and 1, normalizes the
   * result and saves it to <code>m_weightsFilename</code>.
   * Failures are reported on standard error and not propagated.
   */
  public void prepareMatlabScript() {
    PrintWriter writer = null;
    try {
      writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(m_scriptFilename)));
      writer.println("S = load('" + m_posMatrixFilename + "');                                      ");
      writer.println("D = load('" + m_negMatrixFilename + "');                                      ");
      writer.println("[mD,n] = size(D);                                                             ");
      writer.println("[mS,n] = size(S);                                                             ");
      writer.println("");
      writer.println("lb = zeros(n, 1);                                                             ");
      writer.println("ub = ones(n, 1);                                                              ");
      writer.println("x0 = ones(n, 1)/sqrt(n);                                                              ");
      writer.println("");
      writer.println("b = 2* norm(S*x0)/mS * ones(mD, 1);                                              ");
      writer.println("w = fmincon(inline('1/norm(S*x)', 'x', 'S'), x0, D, b, [], [], lb, ub, [],[],S);");
      writer.println("w = w/norm(w)");
      writer.println("save " + m_weightsFilename + " w -ASCII -DOUBLE;");
      // PrintWriter never throws on write failures; ask it explicitly
      if (writer.checkError()) {
        System.err.println("Could not create matlab file: error writing " + m_scriptFilename);
      }
    } catch (Exception e) {
      System.err.println("Could not create matlab file: " + e);
    } finally {
      if (writer != null) {
        writer.close();
      }
    }
  }

  /**
   * Run matlab in command line, feeding it a script on standard input and
   * capturing its standard output in a file. While Matlab exits with a
   * non-zero value the run is repeated after a pause of five minutes.
   * Failures to run are reported on standard error and not propagated.
   *
   * @param inFile file to be input to Matlab
   * @param outFile file where results are stored
   */
  public void runMatlab(String inFile, String outFile) {
    // Runtime.exec(String) only splits on whitespace and does not involve a
    // shell, so "<" and ">" would be handed to Matlab as plain arguments.
    // Going through /bin/sh makes them real redirections.
    String[] command = new String[] {
      "/bin/sh", "-c",
      "matlab -tty < " + shellQuote(inFile) + " > " + shellQuote(outFile)
    };

    // call matlab to do the dirty work
    // NOTE(review): Matlab's standard error is left on the process pipe and
    // is not drained here; very large error output could stall the process.
    try {
      int exitValue;
      do {
        if (m_debug) {
          System.out.println(getTimestamp() + " starting Matlab");
        }
        Process proc = Runtime.getRuntime().exec(command);
        exitValue = proc.waitFor();

        if (exitValue == 126 || exitValue == 127) {
          // Shell convention for "cannot execute" / "command not found":
          // waiting and retrying cannot help, so give up right away.
          System.err.println("Problems running matlab: could not be started (exit value "
                             + exitValue + ")");
          return;
        }
        if (exitValue != 0) {
          System.err.println(getTimestamp() + " WARNING!!!!!  Matlab returned exit value "
                             + exitValue + ", trying again later!");
          Thread.sleep(RETRY_DELAY_MS);
        }
      } while (exitValue != 0);
      if (m_debug) {
        System.out.println(getTimestamp() + " Matlab done");
      }
    } catch (InterruptedException e) {
      // Keep the interrupt visible to whoever owns this thread
      Thread.currentThread().interrupt();
      System.err.println("Problems running matlab: " + e);
    } catch (Exception e) {
      System.err.println("Problems running matlab: " + e);
    }
  }

  /**
   * Wrap a string in single quotes so that /bin/sh treats it as one literal word.
   *
   * @param arg the text to quote
   * @return the quoted text, with embedded single quotes escaped
   */
  private static String shellQuote(String arg) {
    StringBuffer quoted = new StringBuffer("'");
    for (int i = 0; i < arg.length(); i++) {
      char c = arg.charAt(i);
      if (c == '\'') {
        // close the quote, emit an escaped quote, reopen the quote
        quoted.append("'\\''");
      } else {
        quoted.append(c);
      }
    }
    quoted.append('\'');
    return quoted.toString();
  }

  /**
   * Gets a string containing current date and time.
   *
   * @return a string containing the date and time.
   */
  protected static String getTimestamp() {
    return (new SimpleDateFormat("HH:mm:ss:")).format(new Date());
  }

  /**
   * Read a column vector from a text file, one number per line.
   * Lines that cannot be parsed as a double are reported on standard error
   * and skipped, so the result may be shorter than the file.
   *
   * @param name file name
   * @return double[] array corresponding to a vector
   * @exception Exception if the file cannot be opened or read
   */
  public double[] readVector(String name) throws Exception {
    ArrayList vectorList = new ArrayList();
    BufferedReader r = new BufferedReader(new FileReader(name));
    try {
      String s;
      while ((s = r.readLine()) != null) {
        try {
          vectorList.add(Double.valueOf(s));
        } catch (NumberFormatException e) {
          System.err.println("Couldn't parse " + s + " as double");
        }
      }
    } finally {
      // release the file handle even when reading fails
      r.close();
    }

    int length = vectorList.size();
    double[] vector = new double[length];
    for (int i = 0; i < length; i++) {
      vector[i] = ((Double) vectorList.get(i)).doubleValue();
    }
    return vector;
  }

  /**
   * Dump a list of instances as a matrix of attribute values.
   * Each instance becomes one line of space-separated values; the class
   * attribute is left out. Failures are reported on standard error and not
   * propagated.
   *
   * @param instanceList a list of instances
   * @param filename name of the file where the matrix is saved
   */
  public void dumpInstanceList(ArrayList instanceList, String filename) {
    PrintWriter writer = null;
    try {
      writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(filename)));
      int numInstances = instanceList.size();
      for (int i = 0; i < numInstances; i++) {
        Instance instance = (Instance) instanceList.get(i);
        int numAttributes = instance.numAttributes();
        int classIdx = instance.classIndex();
        for (int j = 0; j < numAttributes; j++) {
          if (j != classIdx) {
            writer.print(instance.value(j) + " ");
          }
        }
        writer.println();
      }
      // PrintWriter never throws on write failures; ask it explicitly
      if (writer.checkError()) {
        System.err.println("Could not create a temporary file for dumping the instance list: error writing "
                           + filename);
      }
    } catch (Exception e) {
      System.err.println("Could not create a temporary file for dumping the instance list: " + e);
    } finally {
      if (writer != null) {
        writer.close();
      }
    }
  }

  /**
   * Not supported: this learner only trains weights and cannot act as a
   * similarity function itself.
   *
   * @param instance1 first instance of a pair
   * @param instance2 second instance of a pair
   * @return never returns normally
   * @exception Exception always
   */
  public double getSimilarity(Instance instance1, Instance instance2) throws Exception {
    throw new Exception("MatlabMetricLearner cannot be used as an external distance metric!");
  }

  /**
   * Not supported: this learner only trains weights and cannot act as a
   * distance function itself.
   *
   * @param instance1 first instance of a pair
   * @param instance2 second instance of a pair
   * @return never returns normally
   * @exception Exception always
   */
  public double getDistance(Instance instance1, Instance instance2) throws Exception {
    throw new Exception("MatlabMetricLearner cannot be used as an external distance metric!");
  }
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?