gdmetriclearner.java

来自「wekaUT是 University of Texas at Austin 开发的基于 Weka 的…」· Java 代码 · 共 479 行 · 第 1/2 页

JAVA
479
字号
  /** Set the learning rate
   * @param learningRate the gradient update coefficient
   */
  public void setLearningRate(double learningRate) {
    m_learningRate = learningRate;
  }

  /** Get the learning rate
   * @return the gradient update coefficient
   */
  public double getLearningRate() {
    return m_learningRate;
  }

  /** Set the maximum number of gradient-update iterations
   * @param maxIterations the maximum number of gradient updates
   */
  public void setMaxIterations(int maxIterations) {
    m_maxIterations = maxIterations;
  }

  /** Get the maximum number of gradient-update iterations
   * @return the maximum number of gradient updates
   */
  public int getMaxIterations() {
    return m_maxIterations;
  }

  /** Set the number of same-class training pairs
   * @param numPosPairs the number of same-class training pairs to create for training
   */
  public void setNumPosPairs(int numPosPairs) {
    m_numPosPairs = numPosPairs;
  }

  /** Get the number of same-class training pairs
   * @return the number of same-class training pairs to create for training
   */
  public int getNumPosPairs() {
    return m_numPosPairs;
  }

  /** Set the number of different-class training pairs
   * @param numNegPairs the number of different-class training pairs to create for training
   */
  public void setNumNegPairs(int numNegPairs) {
    m_numNegPairs = numNegPairs;
  }

  /** Get the number of different-class training pairs
   * @return the number of different-class training pairs to create for training
   */
  public int getNumNegPairs() {
    return m_numNegPairs;
  }

  /** Set the pairwise selector
   * @param selector the selector for training pairs
   */
  public void setSelector(PairwiseSelector selector) {
    m_selector = selector;
  }

  /** Get the pairwise selector
   * @return the selector for training pairs
   */
  public PairwiseSelector getSelector() {
    return m_selector;
  }

  /**
   * Gets the current settings of this GDMetricLearner.
   *
   * The array holds flag/value pairs in this order:
   * <ul>
   * <li>-e: m_epsilon (field declared elsewhere)</li>
   * <li>-p: number of same-class pairs</li>
   * <li>-n: number of different-class pairs</li>
   * <li>-i: maximum iterations</li>
   * <li>-l: learning rate</li>
   * <li>-S: selector class name, only when a selector is set</li>
   * </ul>
   * Remaining slots of the fixed-size array are filled with empty strings.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String [] getOptions() {
    String [] options = new String [25];
    int current = 0;
    options[current++] = "-e";
    options[current++] = "" + m_epsilon;
    options[current++] = "-p";
    options[current++] = "" + m_numPosPairs;
    options[current++] = "-n";
    options[current++] = "" + m_numNegPairs;
    options[current++] = "-i";
    options[current++] = "" + m_maxIterations;
    options[current++] = "-l";
    options[current++] = "" + m_learningRate;
    // An unset selector used to cause a NullPointerException here (and thus in toString())
    if (m_selector != null) {
      options[current++] = "-S";
      options[current++] = m_selector.getClass().getName();
    }
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * Parses a given list of options.
   *
   * NOTE(review): this is currently a no-op. None of the options emitted by
   * getOptions() (-e, -p, -n, -i, -l, -S) are parsed, so a getOptions()/setOptions()
   * round trip does not restore settings. Use the individual setters instead.
   *
   * @param options the options to parse (currently ignored)
   * @throws Exception declared in the signature; the current implementation never throws
   */
  public void setOptions(String[] options) throws Exception {
    // intentionally empty -- see NOTE(review) in the javadoc
  }

  /**
   * Gets a string containing the current time of day.
   *
   * @return the current time formatted with the pattern "HH:mm:ss:" (note the trailing colon)
   */
  protected static String getTimestamp() {
    // A fresh formatter per call, since SimpleDateFormat instances are not thread-safe
    return (new SimpleDateFormat("HH:mm:ss:")).format(new Date());
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * NOTE(review): this currently returns an empty enumeration even though
   * getOptions() emits several options.
   *
   * @return an empty enumeration
   */
  public Enumeration listOptions() {
    Vector newVector = new Vector(0);
    return newVector.elements();
  }

  /** Obtain a textual description of the metric learner
   * @return "GDMetricLearner " followed by the quoted, space-separated values of getOptions()
   */
  public String toString() {
    return "GDMetricLearner " + concatStringArray(getOptions());
  }

  /** A little helper to create a single String from an array of Strings.
   * Each element is wrapped in double quotes and followed by a single space
   * (including the last one); null elements are rendered as "null".
   * @param strings an array of strings (must not be null)
   * @return a single concatenated string, e.g. {"a", "b"} yields "\"a\" \"b\" "
   */
  public static String concatStringArray(String[] strings) {
    // StringBuffer avoids the quadratic cost of repeated String concatenation
    StringBuffer result = new StringBuffer();
    for (int i = 0; i < strings.length; i++) {
      result.append('"').append(strings[i]).append("\" ");
    }
    return result.toString();
  }

  /** Create a list of randomly drawn training pairs of two kinds: pairs of distinct
   * instances belonging to the same class, followed by pairs of instances belonging
   * to different classes. Pairs are drawn independently, so the same pair may occur
   * more than once. Instances with a missing class are skipped with a warning.
   * This method does not consult m_selector.
   *
   * @param instances the instances to draw pairs from
   * @param numPosPairs number of same-class pairs to create; none are created if no
   *        class has at least two instances
   * @param numNegPairs number of different-class pairs to create; none are created
   *        if fewer than two classes are present
   * @return an ArrayList of TrainingPair objects, same-class pairs first
   */
  protected ArrayList createPairList(Instances instances, int numPosPairs, int numNegPairs) {
    ArrayList pairList = new ArrayList();
    // Maps each class value (a Double) to the list of instances belonging to it
    HashMap classInstanceMap = new HashMap();
    // Distinct class values in order of first appearance; indexed by the random draws below
    ArrayList classValueList = new ArrayList();

    // Bucket all instances by class value. (The local used to be named "enum",
    // which is a reserved keyword since Java 5 and does not compile.)
    Enumeration instanceEnum = instances.enumerateInstances();
    while (instanceEnum.hasMoreElements()) {
      Instance instance = (Instance) instanceEnum.nextElement();
      if (instance.classIsMissing()) {
        System.err.println("Instance has missing class!!!");
        continue;
      }
      Double classValue = new Double(instance.classValue());
      ArrayList classInstanceList = (ArrayList) classInstanceMap.get(classValue);
      if (classInstanceList == null) {  // first instance of a previously unseen class
        classInstanceList = new ArrayList();
        classInstanceMap.put(classValue, classInstanceList);
        classValueList.add(classValue);
      }
      classInstanceList.add(instance);
    }

    int numClasses = classValueList.size();
    Random random = new Random();

    // Only classes with at least two instances can yield a pair of distinct instances.
    // Drawing from a singleton class used to spin forever; drawing with no classes
    // at all used to throw from random.nextInt(0).
    ArrayList pairableLists = new ArrayList();
    for (int c = 0; c < numClasses; c++) {
      ArrayList list = (ArrayList) classInstanceMap.get(classValueList.get(c));
      if (list.size() > 1) {
        pairableLists.add(list);
      }
    }

    // Create the desired number of random same-class (positive) pairs
    if (numPosPairs > 0 && pairableLists.isEmpty()) {
      System.err.println("No class has two or more instances; no same-class pairs created");
    } else {
      for (int i = 0; i < numPosPairs; i++) {
        // select a random class... TODO: probability should be proportional to the number of instances
        ArrayList list = (ArrayList) pairableLists.get(random.nextInt(pairableLists.size()));
        int size = list.size();
        int idx1 = random.nextInt(size);
        // Uniform over the size-1 indices other than idx1, without a rejection loop
        int idx2 = random.nextInt(size - 1);
        if (idx2 >= idx1) {
          idx2++;
        }
        Instance instance1 = (Instance) list.get(idx1);
        Instance instance2 = (Instance) list.get(idx2);
        pairList.add(new TrainingPair(instance1, instance2, true, 0));
      }
    }

    // Create different-class (negative) pairs; impossible with fewer than two classes
    if (numClasses > 1) {
      for (int i = 0; i < numNegPairs; i++) {
        // select two random distinct classes, uniformly, without a rejection loop
        int class1 = random.nextInt(numClasses);
        int class2 = random.nextInt(numClasses - 1);
        if (class2 >= class1) {
          class2++;
        }
        ArrayList list1 = (ArrayList) classInstanceMap.get(classValueList.get(class1));
        Instance instance1 = (Instance) list1.get(random.nextInt(list1.size()));
        ArrayList list2 = (ArrayList) classInstanceMap.get(classValueList.get(class2));
        Instance instance2 = (Instance) list2.get(random.nextInt(list2.size()));
        pairList.add(new TrainingPair(instance1, instance2, false, 0));
      }
    }
    return pairList;
  }

  /** Print the n largest weights (by value, not magnitude) together with the index and
   * name of the corresponding attribute of m_instances, ranked from 0.
   * @param weights one weight per attribute of m_instances
   * @param n how many attributes to print (fewer if weights is shorter)
   * @param iteration a label printed in the header line
   */
  public void printTopAttributes(double[] weights, int n, int iteration) {
    System.out.println(iteration + " top components:");
    // The result of Utils.sort is walked from its end, i.e. treated as ascending order
    int[] sortedIndices = Utils.sort(weights);
    int last = sortedIndices.length - 1;
    // Previously the bound was "i > length-n", which printed only n-1 attributes
    for (int rank = 0; rank < n && rank <= last; rank++) {
      int idx = sortedIndices[last - rank];
      System.out.println(rank + ": " +
                         idx + ":" + m_instances.attribute(idx).name() + "(" + weights[idx] + ")");
    }
  }
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?