WeightedDotP.java
//   /**
//    * Create an instance with features corresponding to dot-product components of the two given instances
//    * @param instance1 first sparse instance
//    * @param instance2 second non-sparse instance
//    */
//   protected Instance createDiffInstanceSparseNonSparse(SparseInstance instance1, Instance instance2) {
//     double[] values2 = instance2.toDoubleArray();
//     int numAttributes = values2.length;
//
//     // create an extra attribute if there was no class index originally
//     int classIndex = m_classIndex;
//     if (classIndex < 0) {
//       classIndex = numAttributes;
//       numAttributes++;
//     }
//     double[] diffInstanceValues = new double[numAttributes];
//
//     // iterate through the attributes that are present in the sparse instance
//     for (int i = 0; i < instance1.numValues(); i++) {
//       Attribute attribute = instance1.attributeSparse(i);
//       int attrIdx = attribute.index();
//       if (attrIdx != classIndex) {
//         diffInstanceValues[attrIdx] = m_attrWeights[attrIdx] * instance1.value(attrIdx) * values2[attrIdx];
//       }
//     }
//     Instance diffInstance = new Instance(1.0, diffInstanceValues);
//     diffInstance.setDataset(instance1.dataset());
//     return diffInstance;
//   }

  /** The computation of a metric can be either distance-based or similarity-based
   * @return false because the dot product fundamentally computes similarity */
  public boolean isDistanceBased() {
    return false;
  }

  /**
   * Given a cluster of instances, return the centroid of that cluster
   * @param instances objects belonging to a cluster
   * @param fastMode whether fast mode should be used for SparseInstances
   * @param normalized normalize centroids for SPKMeans
   * @return a centroid instance for the given cluster
   */
  public Instance getCentroidInstance(Instances instances, boolean fastMode, boolean normalized) {
    double[] values = new double[instances.numAttributes()];
    if (fastMode) {
      values = meanOrMode(instances); // uses fast meanOrMode
    } else {
      for (int j = 0; j < instances.numAttributes(); j++) {
        values[j] = instances.meanOrMode(j); // uses usual meanOrMode
      }
    }
    Instance centroid = new Instance(1.0, values); // cluster centroids are dense in SPKMeans
    if (normalized) {
      try {
        normalizeInstanceWeighted(centroid);
      } catch (Exception e) {
        e.printStackTrace();
      }
    }
    return centroid;
  }

  /** Get the values of the partial derivatives of the metric components
   * for a particular instance pair
   * @param instance1 the first instance
   * @param instance2 the second instance
   */
  public double[] getGradients(Instance instance1, Instance instance2) throws Exception {
    double[] gradients = new double[m_numAttributes];
    double length1 = lengthWeighted(instance1);
    double length2 = lengthWeighted(instance2);
    double l1l2 = length1 * length2;
    double l1sq = length1 * length1;
    double l2sq = length2 * length2;
    double dotp = similarityInternal(instance1, instance2);

    // take care of SparseInstances by enumerating over the values of the first instance
    int numAttrs = Math.min(m_numAttributes, instance1.numValues());
    for (int i = 0; i < numAttrs; i++) {
      // get the values
      double val1 = instance1.valueSparse(i);
      Attribute attr = instance1.attributeSparse(i);
      double val2 = instance2.value(attr);
      int attrIdx = attr.index();
      if (attrIdx != m_classIndex) {
        gradients[attrIdx] = val1 * val2 / l1l2
          - dotp / l1l2 * (l1sq * val2 * val2 + l2sq * val1 * val1) / (2 * l1l2);
      }
    }
    return gradients;
  }

  /** Get the norm-2 length of an instance assuming all attributes are numeric,
   * utilizing the attribute weights
   * @return norm-2 length of an instance
   */
  public double lengthWeighted(Instance instance) {
    int classIndex = instance.classIndex();
    double length = 0;
    if (instance instanceof SparseInstance) {
      for (int i = 0; i < instance.numValues(); i++) {
        int attrIdx = instance.index(i);
        if (attrIdx != classIndex) {
          double value = instance.valueSparse(i);
          length += m_attrWeights[attrIdx] * value * value;
        }
      }
    } else { // non-sparse instance
      double[] values = instance.toDoubleArray();
      for (int i = 0; i < values.length; i++) {
        if (i != classIndex) {
          // length += values[i] * values[i];
          length += m_attrWeights[i] * values[i] * values[i];
        }
      }
    }
    return Math.sqrt(length);
  }

  /**
   * Parses a given list of options. Valid options are:<p>
   *
   * -N <br>
   * Normalize the dot product by vector lengths <p>
   *
   * -E <br>
   * Use exponential conversion from similarity to distance
   * (default laplacian conversion) <p>
   *
   * -U <br>
   * Use unit conversion from similarity to distance (dist=1-sim)
   * (default laplacian conversion) <p>
   *
   * -R <br>
   * The metric is trainable and will be trained using the current MetricLearner
   * (default non-trainable) <p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    if (Utils.getFlag('E', options)) {
      setConversionType(new SelectedTag(CONVERSION_EXPONENTIAL, TAGS_CONVERSION));
    } else if (Utils.getFlag('U', options)) {
      setConversionType(new SelectedTag(CONVERSION_UNIT, TAGS_CONVERSION));
    } else {
      setConversionType(new SelectedTag(CONVERSION_LAPLACIAN, TAGS_CONVERSION));
    }

    setLengthNormalized(Utils.getFlag('N', options));

    if (Utils.getFlag('R', options)) {
      // the -R flag was consumed by the getFlag call above, so pass true explicitly
      setTrainable(true);
      setExternal(Utils.getFlag('X', options));
      String metricLearnerString = Utils.getOption('L', options);
      if (metricLearnerString.length() != 0) {
        String[] metricLearnerSpec = Utils.splitOptions(metricLearnerString);
        String metricLearnerName = metricLearnerSpec[0];
        metricLearnerSpec[0] = "";
        System.out.println("Got metric learner spec: " + metricLearnerString);
        setMetricLearner(MetricLearner.forName(metricLearnerName, metricLearnerSpec));
      }
    }
    Utils.checkForRemainingOptions(options);
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector newVector = new Vector(6);
    newVector.addElement(new Option("\tNormalize the dot product by vector lengths\n",
                                    "N", 0, "-N"));
    newVector.addElement(new Option("\tUse exponential conversion from similarity to distance\n",
                                    "E", 0, "-E"));
    newVector.addElement(new Option("\tUse unit conversion from similarity to distance\n",
                                    "U", 0, "-U"));
    newVector.addElement(new Option("\tTrain the metric\n",
                                    "R", 0, "-R"));
    newVector.addElement(new Option("\tUse the metric learner for similarity calculations (\"external\")",
                                    "X", 0, "-X"));
    newVector.addElement(new Option("\tFull class name of metric learner to use, followed\n"
                                    + "\tby scheme options. (required)\n"
                                    + "\teg: \"weka.core.metrics.ClassifierMetricLearner -B weka.classifiers.function.SMO\"",
                                    "L", 1, "-L <classifier specification>"));
    return newVector.elements();
  }

  /**
   * Gets the current settings of WeightedDotP.
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions() {
    String[] options = new String[45];
    int current = 0;

    if (getLengthNormalized()) {
      options[current++] = "-N";
    }
    if (m_conversionType == CONVERSION_EXPONENTIAL) {
      options[current++] = "-E";
    } else if (m_conversionType == CONVERSION_UNIT) {
      options[current++] = "-U";
    }
    if (m_trainable) {
      options[current++] = "-R";
      if (m_external) {
        options[current++] = "-X";
      }
      options[current++] = "-L";
      options[current++] = Utils.removeSubstring(m_metricLearner.getClass().getName(), "weka.core.metrics.");
      if (m_metricLearner instanceof OptionHandler) {
        String[] metricOptions = ((OptionHandler) m_metricLearner).getOptions();
        for (int i = 0; i < metricOptions.length; i++) {
          options[current++] = metricOptions[i];
        }
      }
    }

    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  public static void main(String[] args) {
    try {
      // Create numeric attributes "length", "weight" and "height"
      Attribute length = new Attribute("length");
      Attribute weight = new Attribute("weight");
      Attribute height = new Attribute("height");

      // Create vector to hold nominal values "first", "second", "third"
      FastVector my_nominal_values = new FastVector(3);
      my_nominal_values.addElement("first");
      my_nominal_values.addElement("second");
      my_nominal_values.addElement("third");

      // Create nominal attribute "position"
      Attribute position = new Attribute("position", my_nominal_values);

      // Create vector of the above attributes
      FastVector attributes = new FastVector(4);
      attributes.addElement(length);
      attributes.addElement(weight);
      attributes.addElement(height);
      attributes.addElement(position);

      // Create the empty dataset "race" with the above attributes
      Instances race = new Instances("race", attributes, 0);

      // Make position the class attribute
      race.setClassIndex(position.index());

      // Create a sparse instance with three attribute values
      SparseInstance s_inst1 = new SparseInstance(1, new double[0], new int[0], 4);
      s_inst1.setValue(length, 2);
      s_inst1.setValue(weight, 1);
      s_inst1.setValue(position, "third");

      // Create another sparse instance with three attribute values
      SparseInstance s_inst2 = new SparseInstance(1, new double[0], new int[0], 4);
      s_inst2.setValue(length, 1);
      s_inst2.setValue(height, 5);
      s_inst2.setValue(position, "second");

      // Create a non-sparse instance with all attribute values
      Instance inst1 = new Instance(4);
      inst1.setValue(length, 3);
      inst1.setValue(weight, 4);
      inst1.setValue(height, 5);
      inst1.setValue(position, "first");

      // Create another non-sparse instance with all attribute values
      Instance inst2 = new Instance(4);
      inst2.setValue(length, 2);
      inst2.setValue(weight, 2);
      inst2.setValue(height, 2);
      inst2.setValue(position, "second");

      // Set instances' dataset to be the dataset "race"
      s_inst1.setDataset(race);
      s_inst2.setDataset(race);
      inst1.setDataset(race);
      inst2.setDataset(race);

      // Print the instances
      System.out.println("Sparse instance S1: " + s_inst1);
      System.out.println("Sparse instance S2: " + s_inst2);
      System.out.println("Non-sparse instance NS1: " + inst1);
      System.out.println("Non-sparse instance NS2: " + inst2);

      // Print the class attribute
      System.out.println("Class attribute: " + s_inst1.classAttribute());

      // Print the class index
      System.out.println("Class index: " + s_inst1.classIndex());

      // Create a new metric and print the similarities and distances
      WeightedDotP metric = new WeightedDotP(3);
      metric.setClassIndex(position.index());
      System.out.println("Similarity is length-normalized? " + metric.getLengthNormalized() + "\n");
      System.out.println("Similarity between S1 and S2: " + metric.similarity(s_inst1, s_inst2));
      System.out.println("Similarity between S1 and NS1: " + metric.similarity(s_inst1, inst1));
      System.out.println("Similarity between NS1 and S1: " + metric.similarity(inst1, s_inst1));
      System.out.println("Similarity between NS1 and NS2: " + metric.similarity(inst1, inst2));

      System.out.println("\nSimilarity-distance conversion type: "
                         + metric.getConversionType().getSelectedTag().getReadable());
      System.out.println("Distance between S1 and S2: " + metric.distance(s_inst1, s_inst2));
      System.out.println("Distance between S1 and NS1: " + metric.distance(s_inst1, inst1));
      System.out.println("Distance between NS1 and S1: " + metric.distance(inst1, s_inst1));
      System.out.println("Distance between NS1 and NS2: " + metric.distance(inst1, inst2));
      System.out.println();

      System.out.println("Difference instance S1-S2: " + metric.createDiffInstance(s_inst1, s_inst2));
      System.out.println("Difference instance S1-NS1: " + metric.createDiffInstance(s_inst1, inst1));
      System.out.println("Difference instance NS1-S1: " + metric.createDiffInstance(inst1, s_inst1));
      System.out.println("Difference instance NS1-NS2: " + metric.createDiffInstance(inst1, inst2));
      System.out.println();

      double[] weights = {0.2, 0.3, 0.4};
      metric.setWeights(weights);
      System.out.println("Weights: 0.2 0.3 0.4");
      System.out.print("NS1: " + inst1);
      metric.normalizeInstanceWeighted(inst1);
      System.out.println("\tnormalized: " + inst1);
      System.out.print("S1: " + s_inst1);
      metric.normalizeInstanceWeighted(s_inst1);
      System.out.println("\tnormalized: " + s_inst1);
    } catch (Exception e) {
      e.printStackTrace();
    }
  }
}
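The setOptions()/getOptions() pair above follows the usual Weka OptionHandler pattern. Below is a minimal standalone sketch of driving it programmatically; the package weka.core.metrics and the demo class name are assumptions (no package declaration appears in this listing), and WeightedDotP(3) is the constructor already used in main() above, whose argument's meaning is not shown here.

// Hypothetical demo class; the package and import are assumptions, not taken from the listing.
import weka.core.metrics.WeightedDotP;

public class WeightedDotPOptionsDemo {
  public static void main(String[] args) throws Exception {
    // Same constructor call as in WeightedDotP.main() above.
    WeightedDotP metric = new WeightedDotP(3);

    // -N: length-normalize the dot product; -U: unit conversion (dist = 1 - sim).
    metric.setOptions(new String[] {"-N", "-U"});

    System.out.println("Length-normalized: " + metric.getLengthNormalized());
    System.out.println("Conversion type: "
        + metric.getConversionType().getSelectedTag().getReadable());

    // getOptions() pads its array with empty strings; print only the flags that are set.
    StringBuilder current = new StringBuilder();
    for (String opt : metric.getOptions()) {
      if (opt.length() > 0) {
        current.append(opt).append(' ');
      }
    }
    System.out.println("Options round-trip: " + current.toString().trim());
  }
}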