📄 kl.java
  }

  /**
   * Return the type of smoothing
   * @return one of SMOOTHING_UNSMOOTHED, SMOOTHING_DIRICHLET, SMOOTHING_JELINEK_MERCER
   */
  public SelectedTag getSmoothingType() {
    return new SelectedTag(m_smoothingType, TAGS_SMOOTHING);
  }

  /** Set the pseudo-count value for Dirichlet smoothing
   * @param pseudoCountDirichlet the pseudo-count value
   */
  public void setPseudoCountDirichlet(double pseudoCountDirichlet) {
    m_pseudoCountDirichlet = pseudoCountDirichlet;
  }

  /** Get the pseudo-count value for Dirichlet smoothing
   * @return the pseudo-count value
   */
  public double getPseudoCountDirichlet() {
    return m_pseudoCountDirichlet;
  }

  /** Set the lambda parameter for Jelinek-Mercer smoothing
   * @param lambdaJM the lambda value
   */
  public void setLambdaJM(double lambdaJM) {
    m_lambdaJM = lambdaJM;
  }

  /** Get the lambda parameter for Jelinek-Mercer smoothing
   * @return lambda
   */
  public double getLambdaJM() {
    return m_lambdaJM;
  }

  /** The computation of a metric can be either based on distance, or on similarity
   * @return true because the KL metric fundamentally computes distance
   */
  public boolean isDistanceBased() {
    return true;
  }

  /** Switch between regular KL divergence and I-divergence */
  public void setUseIDivergence(boolean useID) {
    m_useIDivergence = useID;
  }

  /** Check whether regular KL divergence or I-divergence is used */
  public boolean getUseIDivergence() {
    return m_useIDivergence;
  }

  /**
   * Given a cluster of instances, return the centroid of that cluster
   * @param instances objects belonging to a cluster
   * @param fastMode whether fast mode should be used for SparseInstances
   * @param normalized whether centroids should be normalized for SPKMeans
   * @return a centroid instance for the given cluster
   */
  public Instance getCentroidInstance(Instances instances, boolean fastMode, boolean normalized) {
    double [] values = new double[instances.numAttributes()];
    if (fastMode) {
      values = meanOrMode(instances);  // uses fast meanOrMode
    } else {
      for (int j = 0; j < instances.numAttributes(); j++) {
        values[j] = instances.meanOrMode(j);  // uses usual meanOrMode
      }
    }
    Instance centroid = new Instance(1.0, values);  // cluster centroids are dense in SPKMeans
    if (normalized) {
      try {
        normalizeInstanceWeighted(centroid);
      } catch (Exception e) {
        e.printStackTrace();
      }
    }
    return centroid;
  }
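  // --- Illustrative sketch, not part of the original class -------------------------------
  // The accessors above configure Dirichlet smoothing (a pseudo-count added to every event)
  // and the choice between regular KL divergence and I-divergence (generalized KL, which is
  // defined for unnormalized non-negative vectors). The hypothetical helper below shows those
  // formulas on two plain double[] vectors; the actual metric operates on (Sparse)Instances
  // and its own smoothing fields, so the name, signature, and smoothing details here are
  // assumptions made only for illustration.
  private static double sketchSmoothedDivergence(double[] values1, double[] values2,
                                                 double pseudoCount, boolean useIDivergence) {
    int n = values1.length;
    // Dirichlet-style smoothing: add the pseudo-count to every component so that
    // zero entries no longer cause log(0) or division by zero
    double[] p = new double[n];
    double[] q = new double[n];
    double sumP = 0, sumQ = 0;
    for (int i = 0; i < n; i++) {
      p[i] = values1[i] + pseudoCount;
      q[i] = values2[i] + pseudoCount;
      sumP += p[i];
      sumQ += q[i];
    }
    double divergence = 0;
    if (useIDivergence) {
      // I-divergence (generalized KL) works directly on the unnormalized smoothed vectors
      for (int i = 0; i < n; i++) {
        divergence += p[i] * Math.log(p[i] / q[i]) - p[i] + q[i];
      }
    } else {
      // regular KL divergence compares the two vectors as probability distributions
      for (int i = 0; i < n; i++) {
        divergence += (p[i] / sumP) * Math.log((p[i] / sumP) / (q[i] / sumQ));
      }
    }
    return divergence;
  }
  // ----------------------------------------------------------------------------------------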
  /**
   * Parses a given list of options. Valid options are:<p>
   *
   * -N <br>
   * Normalize the distance by the vectors' lengths <p>
   *
   * -E <br>
   * Use exponential conversion from distance to similarity
   * (default laplacian conversion) <p>
   *
   * -U <br>
   * Use unit conversion from similarity to distance (dist=1-sim)
   * (default laplacian conversion) <p>
   *
   * -R <br>
   * The metric is trainable and will be trained using the current MetricLearner
   * (default non-trainable)
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    if (Utils.getFlag('E', options)) {
      setConversionType(new SelectedTag(CONVERSION_EXPONENTIAL, TAGS_CONVERSION));
    } else if (Utils.getFlag('U', options)) {
      setConversionType(new SelectedTag(CONVERSION_UNIT, TAGS_CONVERSION));
    } else {
      setConversionType(new SelectedTag(CONVERSION_LAPLACIAN, TAGS_CONVERSION));
    }

    if (Utils.getFlag('R', options)) {
      // the -R flag has already been consumed by getFlag above, so pass true explicitly
      setTrainable(true);
      setExternal(Utils.getFlag('X', options));
      String metricLearnerString = Utils.getOption('L', options);
      if (metricLearnerString.length() != 0) {
        String [] metricLearnerSpec = Utils.splitOptions(metricLearnerString);
        String metricLearnerName = metricLearnerSpec[0];
        metricLearnerSpec[0] = "";
        System.out.println("Got metric learner " + metricLearnerName
                           + " spec: " + Utils.joinOptions(metricLearnerSpec));
        setMetricLearner(MetricLearner.forName(metricLearnerName, metricLearnerSpec));
      }
    }
    Utils.checkForRemainingOptions(options);
  }

  /**
   * Gets the metric learner specification string, which contains the class name of
   * the metric learner and any options to the metric learner
   *
   * @return the metric learner specification string.
   */
  protected String getMetricLearnerSpec() {
    if (m_metricLearner instanceof OptionHandler) {
      return m_metricLearner.getClass().getName() + " "
        + Utils.joinOptions(((OptionHandler)m_metricLearner).getOptions());
    }
    return m_metricLearner.getClass().getName();
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector newVector = new Vector(6);

    newVector.addElement(new Option("\tNormalize the distance by the vectors' lengths\n",
                                    "N", 0, "-N"));
    newVector.addElement(new Option("\tUse exponential conversion from distance to similarity\n",
                                    "E", 0, "-E"));
    newVector.addElement(new Option("\tUse unit conversion from similarity to distance\n",
                                    "U", 0, "-U"));
    newVector.addElement(new Option("\tTrain the metric\n",
                                    "R", 0, "-R"));
    newVector.addElement(new Option("\tUse the metric learner for similarity calculations (\"external\")\n",
                                    "X", 0, "-X"));
    newVector.addElement(new Option(
        "\tFull class name of metric learner to use, followed\n"
        + "\tby scheme options. (required)\n"
        + "\teg: \"weka.core.metrics.ClassifierMetricLearner -B weka.classifiers.functions.SMO\"",
        "L", 1, "-L <metric learner specification>"));

    return newVector.elements();
  }
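  // --- Illustrative sketch, not part of the original class -------------------------------
  // setOptions()/listOptions() above define the command-line style flags the metric
  // understands. A minimal programmatic configuration might look like the hypothetical
  // helper below; the flag values are taken from the handlers above, and the metric
  // learner specification reuses the example string from listOptions().
  private static KL sketchConfiguredMetric() throws Exception {
    KL metric = new KL();
    // -E: exponential distance-to-similarity conversion
    // -R: make the metric trainable; -L: the metric learner specification to train it with
    metric.setOptions(new String[] {
        "-E",
        "-R",
        "-L", "weka.core.metrics.ClassifierMetricLearner -B weka.classifiers.functions.SMO"
    });
    return metric;
  }
  // ----------------------------------------------------------------------------------------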
  /**
   * Gets the current settings of KL.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String [] getOptions() {
    String [] options = new String [55];
    int current = 0;

    if (m_useIDivergence) {
      options[current++] = "-I";
    }

    if (m_conversionType == CONVERSION_EXPONENTIAL) {
      options[current++] = "-E";
    } else if (m_conversionType == CONVERSION_UNIT) {
      options[current++] = "-U";
    }

    if (m_smoothingType == SMOOTHING_DIRICHLET) {
      options[current++] = "-D";
      options[current++] = "" + m_pseudoCountDirichlet;
    } else if (m_smoothingType == SMOOTHING_JELINEK_MERCER) {
      options[current++] = "-J";
      options[current++] = "" + m_lambdaJM;
    }

    if (m_useSmoothing) {
      options[current++] = "-S";
      options[current++] = "" + m_alpha;
      options[current++] = "-R";
      options[current++] = "" + m_alphaDecayRate;
    }

    if (m_trainable) {
      options[current++] = "-R";
      if (m_external) {
        options[current++] = "-X";
      }
      options[current++] = "-L";
      options[current++] = Utils.removeSubstring(m_metricLearner.getClass().getName(), "weka.core.metrics.");
      if (m_metricLearner instanceof OptionHandler) {
        String[] metricOptions = ((OptionHandler)m_metricLearner).getOptions();
        for (int i = 0; i < metricOptions.length; i++) {
          options[current++] = metricOptions[i];
        }
      }
    }

    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /** Create a copy of this metric */
  public Object clone() {
    KL m = null;
    m = (KL) super.clone();  // clone the fields
    // for now clone a metric learner via serialization; TODO: implement proper cloning in MetricLearners
    System.out.println("New alpha=" + m_alpha);
    try {
      SerializedObject so = new SerializedObject(m_metricLearner);
      m.m_metricLearner = (MetricLearner) so.getObject();
    } catch (Exception e) {
      System.err.println("Problems cloning m_metricLearner while cloning KL");
    }
    return m;
  }

  public static void main(String[] args) {
    try {
//      // Create numeric attributes
//      Attribute attr1 = new Attribute("attr1");
//      Attribute attr2 = new Attribute("attr2");
//      Attribute attr3 = new Attribute("attr3");
//      Attribute attr4 = new Attribute("attr4");

//      // Create vector to hold nominal values "first", "second", "third"
//      FastVector my_nominal_values = new FastVector(3);
//      my_nominal_values.addElement("first");
//      my_nominal_values.addElement("second");
//      my_nominal_values.addElement("third");

//      // Create nominal attribute "classAttr"
//      Attribute classAttr = new Attribute("classAttr", my_nominal_values);

//      // Create vector of the above attributes
//      FastVector attributes = new FastVector(4);
//      attributes.addElement(attr1);
//      attributes.addElement(attr2);
//      attributes.addElement(attr3);
//      attributes.addElement(attr4);
//      attributes.addElement(classAttr);

//      // Create the empty dataset with above attributes
//      Instances dataset = new Instances("dataset", attributes, 0);

//      // Make position the class attribute
//      dataset.setClassIndex(classAttr.index());

//      // Create a sparse instance with three attribute values
//      SparseInstance s_inst1 = new SparseInstance(1, new double[0], new int[0], 4);
//      s_inst1.setValue(attr1, 0.5);
//      s_inst1.setValue(attr3, 0.5);
//      s_inst1.setValue(classAttr, "third");

//      // Create a sparse instance with three attribute values
//      SparseInstance s_inst2 = new SparseInstance(1, new double[0], new int[0], 4);
//      s_inst2.setValue(attr2, 0.5);
//      s_inst2.setValue(attr3, 0.5);
//      s_inst2.setValue(classAttr, "second");

//      // Create a non-sparse instance with all attribute values
//      Instance inst1 = new Instance(5);
//      inst1.setValue(attr1, 3);
//      inst1.setValue(attr2, 4);
//      inst1.setValue(attr3, 5);
//      inst1.setValue(attr4, 2);
//      inst1.setValue(classAttr, "first");

//      // Create a sparse instance with three attribute values
//      Instance inst2 = new Instance(5);
//      inst2.setValue(attr1, 2);
//      inst2.setValue(attr2, 2);
//      inst2.setValue(attr3, 2);
//      inst2.setValue(attr4, 3);
//      inst2.setValue(classAttr, "second");

//      // Set instances' dataset to be the dataset "dataset"
//      s_inst1.setDataset(dataset);
//      s_inst2.setDataset(dataset);
//      inst1.setDataset(dataset);
//      inst2.setDataset(dataset);

//      // Print the instances
//      System.out.println("Sparse instance S1: " + s_inst1);
//      System.out.println("Sparse instance S2: " + s_inst2);
//      System.out.println("Non-sparse instance NS1: " + inst1);
//      System.out.println("Non-sparse instance NS2: " + inst2);

//      // Print the class attribute
//      System.out.println("Class attribute: " + s_inst1.classAttribute());

//      // Print the class index
//      System.out.println("Class index: " + s_inst1.classIndex());

//      // Create a new metric and print the distances
//      KL metric = new KL(4);
//      metric.setClassIndex(classAttr.index());
//      System.out.println("Distance between S1 and S2: " + metric.distanceJS(s_inst1, s_inst2));
//      System.out.println("Distance between S1 and NS1: " + metric.distanceJS(s_inst1, inst1));
//      System.out.println("Distance between NS1 and S1: " + metric.distanceJS(inst1, s_inst1));
//      System.out.println("Distance between NS1 and NS2: " + metric.distanceJS(inst1, inst2));
//      System.out.println("\nDistance-similarity conversion type: " +
//                         metric.getConversionType().getSelectedTag().getReadable());
//      System.out.println("Similarity between S1 and S2: " + metric.similarity(s_inst1, s_inst2));
//      System.out.println("Similarity between S1 and NS1: " + metric.similarity(s_inst1, inst1));
//      System.out.println("Similarity between NS1 and S1: " + metric.similarity(inst1, s_inst1));
//      System.out.println("Similarity between NS1 and NS2: " + metric.similarity(inst1, inst2));
//      System.out.println();
//      System.out.println("Difference instance S1-S2: " + metric.createDiffInstance(s_inst1, s_inst2));
//      System.out.println("Difference instance S1-NS1: " + metric.createDiffInstance(s_inst1, inst1));
//      System.out.println("Difference instance NS1-S1: " + metric.createDiffInstance(inst1, s_inst1));
//      System.out.println("Difference instance NS1-NS2: " + metric.createDiffInstance(inst1, inst2));

      Instances instances = new Instances(new FileReader("/tmp/INST.arff"));
      KL metric = new KL();
      metric.buildMetric(instances);
      Instance newCentroid = instances.instance(0);
      Instance oldCentroid = instances.instance(1);
      instances.delete(0);
      instances.delete(0);

      metric.setSmoothingType(new SelectedTag(SMOOTHING_DIRICHLET, TAGS_SMOOTHING));
      metric.setUseIDivergence(true);

      double [] values = new double[instances.numAttributes()];
      for (int i = 0; i < instances.numInstances(); i++) {
        Instance instance = instances.instance(i);
        for (int j = 0; j < instance.numAttributes(); j++) {
          values[j] += instance.value(j);
        }
      }
      for (int j = 0; j < instances.numAttributes(); j++) {
        values[j] /= instances.numInstances();
      }
      Instance saneCentroid = new SparseInstance(1.0, values);

      System.out.println("NumInstances=" + instances.numInstances());
      double prevTotal = 0, currTotal = 0, saneTotal = 0;
      for (int i = 0; i < instances.numInstances(); i++) {
        double prevPen = metric.distanceNonSparse(instances.instance(i), oldCentroid);
        double currPen = metric.distanceNonSparse(instances.instance(i), newCentroid);
        double sanePen = metric.distanceNonSparse(instances.instance(i), saneCentroid);
        prevTotal += prevPen;
        currTotal += currPen;
        saneTotal += sanePen;
        System.out.println(prevPen + " -> " + currPen + "\t" + sanePen);
      }
      System.out.println("====\n" + prevTotal + "\t" + currTotal + "\t" + saneTotal);
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
}