gdmetriclearner.java
来自「wekaUT 是 University of Texas at Austin 开发的基于 Weka 的……」· Java 代码 · 共 479 行 · 第 1/2 页
JAVA
479 行
/**
 * Sets the learning rate.
 *
 * @param learningRate the gradient update coefficient
 */
public void setLearningRate(double learningRate) {
  m_learningRate = learningRate;
}

/**
 * Gets the learning rate.
 *
 * @return the gradient update coefficient
 */
public double getLearningRate() {
  return m_learningRate;
}

/**
 * Sets the maximum number of gradient update iterations.
 *
 * @param maxIterations the maximum number of gradient updates
 */
public void setMaxIterations(int maxIterations) {
  m_maxIterations = maxIterations;
}

/**
 * Gets the maximum number of gradient update iterations.
 *
 * @return the maximum number of gradient updates
 */
public int getMaxIterations() {
  return m_maxIterations;
}

/**
 * Sets the number of same-class training pairs.
 *
 * @param numPosPairs the number of same-class training pairs to create for training
 */
public void setNumPosPairs(int numPosPairs) {
  m_numPosPairs = numPosPairs;
}

/**
 * Gets the number of same-class training pairs.
 *
 * @return the number of same-class training pairs to create for training
 */
public int getNumPosPairs() {
  return m_numPosPairs;
}

/**
 * Sets the number of different-class training pairs.
 *
 * @param numNegPairs the number of different-class training pairs to create for training
 */
public void setNumNegPairs(int numNegPairs) {
  m_numNegPairs = numNegPairs;
}

/**
 * Gets the number of different-class training pairs.
 *
 * @return the number of different-class training pairs to create for training
 */
public int getNumNegPairs() {
  return m_numNegPairs;
}

/**
 * Sets the pairwise selector.
 *
 * @param selector the selector for training pairs
 */
public void setSelector(PairwiseSelector selector) {
  m_selector = selector;
}

/**
 * Gets the pairwise selector.
 *
 * @return the selector for training pairs (may be null if never set)
 */
public PairwiseSelector getSelector() {
  return m_selector;
}

/**
 * Gets the current settings of this GDMetricLearner.
 *
 * <p>The returned array always has 25 entries; unused trailing entries are
 * empty strings. The {@code -S} option (selector class name) is omitted when
 * no selector has been set.
 *
 * @return an array of strings suitable for passing to setOptions()
 */
public String[] getOptions() {
  String[] options = new String[25];
  int current = 0;

  options[current++] = "-e";
  options[current++] = "" + m_epsilon;
  options[current++] = "-p";
  options[current++] = "" + m_numPosPairs;
  options[current++] = "-n";
  options[current++] = "" + m_numNegPairs;
  options[current++] = "-i";
  options[current++] = "" + m_maxIterations;
  options[current++] = "-l";
  options[current++] = "" + m_learningRate;
  // Without this guard an unset selector caused a NullPointerException here,
  // and therefore also in toString().
  if (m_selector != null) {
    options[current++] = "-S";
    options[current++] = m_selector.getClass().getName();
  }

  // Pad the unused slots with empty strings.
  while (current < options.length) {
    options[current++] = "";
  }
  return options;
}

/**
 * Parses a given list of options. Valid options are:<p>
 *
 * -p num <br>
 * The number of same-class training pairs.<p>
 *
 * -n num <br>
 * The number of different-class training pairs.<p>
 *
 * -i num <br>
 * The maximum number of gradient update iterations.<p>
 *
 * -l num <br>
 * The learning rate (gradient update coefficient).<p>
 *
 * -S classname <br>
 * Fully qualified class name of the pairwise selector; it is instantiated
 * through its public no-argument constructor.<p>
 *
 * Options that are absent leave the corresponding setting unchanged.
 *
 * @param options the list of options as an array of strings
 * @throws Exception if a numeric value cannot be parsed or the selector
 * class cannot be loaded or instantiated
 */
public void setOptions(String[] options) throws Exception {
  String optionString = Utils.getOption('p', options);
  if (optionString.length() != 0) {
    setNumPosPairs(Integer.parseInt(optionString));
  }

  optionString = Utils.getOption('n', options);
  if (optionString.length() != 0) {
    setNumNegPairs(Integer.parseInt(optionString));
  }

  optionString = Utils.getOption('i', options);
  if (optionString.length() != 0) {
    setMaxIterations(Integer.parseInt(optionString));
  }

  optionString = Utils.getOption('l', options);
  if (optionString.length() != 0) {
    setLearningRate(Double.parseDouble(optionString));
  }

  optionString = Utils.getOption('S', options);
  if (optionString.length() != 0) {
    // getOptions() only records the selector's class name, so a fresh
    // instance with default settings is the best that can be restored.
    Object selector = Class.forName(optionString)
        .getConstructor(new Class[0])
        .newInstance(new Object[0]);
    setSelector((PairwiseSelector) selector);
  }

  // NOTE(review): getOptions() also emits -e (epsilon), but the type of
  // m_epsilon and any setter for it are not visible in this chunk, so -e is
  // not parsed yet. TODO confirm the field type and add it.
}

/**
 * Gets a string containing the current time of day.
 *
 * @return the current time formatted as "HH:mm:ss:" (note the trailing colon)
 */
protected static String getTimestamp() {
  return (new SimpleDateFormat("HH:mm:ss:")).format(new Date());
}

/**
 * Returns an enumeration describing the available options.
 *
 * <p>NOTE(review): this currently returns an empty enumeration even though
 * getOptions()/setOptions() support several options.
 *
 * @return an enumeration of all the available options
 */
public Enumeration listOptions() {
  Vector newVector = new Vector(0);
  return newVector.elements();
}

/**
 * Obtains a textual description of the metric learner.
 *
 * @return "GDMetricLearner " followed by the quoted current options
 */
public String toString() {
  return "GDMetricLearner " + concatStringArray(getOptions());
}

/**
 * Creates a single String from an array of Strings by wrapping each element
 * in double quotes and following it with one space. The result therefore has
 * a trailing space; null elements are rendered as "null".
 *
 * @param strings an array of strings
 * @return the concatenated string
 */
public static String concatStringArray(String[] strings) {
  StringBuffer result = new StringBuffer();
  for (int i = 0; i < strings.length; i++) {
    result.append('"').append(strings[i]).append("\" ");
  }
  return result.toString();
}

/**
 * Creates a list of training pairs of two kinds: pairs of instances
 * belonging to the same class, and pairs of instances belonging to
 * different classes.
* */ protected ArrayList createPairList(Instances instances, int numPosPairs, int numNegPairs) { ArrayList pairList = new ArrayList(); // A hashmap where each class will be mapped to a list of instnaces belonging to it HashMap classInstanceMap = new HashMap(); // A list of classes, each element is the double value of the class attribute ArrayList classValueList = new ArrayList(); // go through all instances, hashing them into lists corresponding to each class Enumeration enum = instances.enumerateInstances(); while (enum.hasMoreElements()) { Instance instance = (Instance) enum.nextElement(); if (instance.classIsMissing()) { System.err.println("Instance has missing class!!!"); continue; } Double classValue = new Double(instance.classValue()); // if this class has been seen, add instance to its list if (classInstanceMap.containsKey(classValue)) { ArrayList classInstanceList = (ArrayList) classInstanceMap.get(classValue); classInstanceList.add(instance); } else { // create a new list of instances for a previously unseen class ArrayList classInstanceList = new ArrayList(); classInstanceList.add(instance); classInstanceMap.put(classValue, classInstanceList); classValueList.add(classValue); } } // Create the desired number of random positive instances int numClasses = classInstanceMap.size(); Random random = new Random(); for (int i = 0; i < numPosPairs; i++) { // select a random class... 
TODO: probability must be proportional to the number of instances int class1 = random.nextInt(numClasses); ArrayList list = (ArrayList) classInstanceMap.get(classValueList.get(class1)); int idx1 = random.nextInt(list.size()); int idx2; do { idx2 = random.nextInt(list.size()); } while (idx1 == idx2); Instance instance1 = (Instance) list.get(idx1); Instance instance2 = (Instance) list.get(idx2); TrainingPair posPair = new TrainingPair(instance1, instance2, true, 0); pairList.add(posPair); } // Create negative diff-instances if (numClasses > 1) { random = new Random(); for (int i = 0; i < numNegPairs; i++) { // select two random distinct classes int class1 = random.nextInt(numClasses); int class2 = random.nextInt(numClasses); while (class2 == class1) { class2 = random.nextInt(numClasses); } ArrayList list1 = (ArrayList) classInstanceMap.get(classValueList.get(class1)); Instance instance1 = (Instance) list1.get(random.nextInt(list1.size())); ArrayList list2 = (ArrayList) classInstanceMap.get(classValueList.get(class2)); Instance instance2 = (Instance) list2.get(random.nextInt(list2.size())); TrainingPair negPair = new TrainingPair(instance1, instance2, false, 0); pairList.add(negPair); } } return pairList; } /** Print the heaviest-weighted attributes for a given set of weights */ public void printTopAttributes(double[] weights, int n, int iteration) { // Print top weights - to be moved out into a separate function System.out.println(iteration + " top components:"); int[] sortedIndeces = Utils.sort(weights); for (int i = sortedIndeces.length-1; i > sortedIndeces.length-n && i >=0; i--) { int idx = sortedIndeces[i]; System.out.println((sortedIndeces.length-1-i) + ": " + idx + ":" + m_instances.attribute(idx).name() + "(" + weights[idx] + ")"); } }}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?