kernelvsmetric.java
来自「wekaUT是 university texas austin 开发的基于wek」· Java 代码 · 共 702 行 · 第 1/2 页
JAVA
702 行
// --- tail of the training method; its declaration and the opening of this
// --- try/for block are above this excerpt ---
// Sanity check: build a fresh copy of the classifier, 2-fold cross-validate
// it on the pair instances, and append the accuracy to a .dat file.
          Classifier classifier = (Classifier) Class.forName(m_classifier.getClass().getName()).newInstance();
          if (m_classifier instanceof OptionHandler) {
            // configure the copy identically to the original classifier
            ((OptionHandler)classifier).setOptions(((OptionHandler)m_classifier).getOptions());
          }
          Evaluation eval = new Evaluation(instances);
          eval.crossValidateModel(classifier, instances, 2);
          writer = new PrintWriter(new BufferedOutputStream(new FileOutputStream(diffDir.getPath() + "/" + diffName + ".dat", true)));
          writer.println(eval.pctCorrect());
          writer.close();
          System.out.println("** String sanity: " + (System.currentTimeMillis() - trainTimeStart) + " ms; "
                             + eval.pctCorrect() + "% correct\t"
                             + eval.numFalseNegatives(0) + "(" + eval.falseNegativeRate(0) + "%) false negatives\t"
                             + eval.numFalsePositives(0) + "(" + eval.falsePositiveRate(0) + "%) false positives\t");
        } catch (Exception e) {
          e.printStackTrace();
          System.out.println(e.toString());
        }
      } // END SANITY CHECK

      // Log a timestamp and the classifier's options, then do the real training.
      System.out.println((new SimpleDateFormat("HH:mm:ss:")).format(new Date())
                         + weka.classifiers.sparse.IBkMetric.concatStringArray(((OptionHandler)m_classifier).getOptions()));
      System.out.println("Now got " + m_instances.numInstances());
      m_classifier.buildClassifier(m_instances);
      m_trained = true;  // similarity() will now use the trained classifier
    }

  /** Given a pair of strings, create the "diff-instance" whose attributes are
   *  the per-token components of the vector-space kernel between the two
   *  strings.  Both strings must already have been indexed in m_stringRefHash.
   *  @param s1 first string
   *  @param s2 second string
   *  @return a SparseInstance holding the kernel components
   */
  protected SparseInstance createPairInstance(String s1, String s2) {
    StringReference stringRef1 = (StringReference) m_stringRefHash.get(s1);
    StringReference stringRef2 = (StringReference) m_stringRefHash.get(s2);
    // Normalization factor: inverse product of the two vector lengths.
    // NOTE(review): if m_length were an integral type this would be integer
    // division (0 for any product > 1) -- presumed to be double; confirm.
    double invLength = 1/(stringRef1.m_length * stringRef2.m_length);
    HashMapVector v1 = stringRef1.m_vector;
    HashMapVector v2 = stringRef2.m_vector;
    SparseInstance pairInstance = new SparseInstance(1, new double[0], new int[0], m_tokenHash.size()+1);
    // calculate all the components of the kernel: one attribute per token
    // that occurs in BOTH strings
    Iterator mapEntries = v1.iterator();
    while (mapEntries.hasNext()) {
      Map.Entry entry = (Map.Entry)mapEntries.next();
      String token = (String)entry.getKey();
      if (v2.hashMap.containsKey(token)) {
        Attribute attr = (Attribute) m_tokenAttrMap.get(token);
        double tf1 = ((Weight)entry.getValue()).getValue();
        double tf2 = ((Weight)v2.hashMap.get(token)).getValue();
        TokenInfo tokenInfo = (TokenInfo) m_tokenHash.get(token);
        // add this component unless it was killed (with idf=0)
        if (tokenInfo != null) {
          if (m_useIDF) {
            pairInstance.setValue(attr, tf1 * tf2 * tokenInfo.idf * tokenInfo.idf * invLength );
          } else {
            pairInstance.setValue(attr, tf1 * tf2 * invLength );
          }
          if (m_useIndividualWeights) {
            // also emit the two per-string weights as separate attributes
            Attribute attr_s1 = (Attribute) m_tokenAttrMap.get("s1_" + token);
            Attribute attr_s2 = (Attribute) m_tokenAttrMap.get("s2_" + token);
            if (m_useIDF) {
              // TODO: this is not right; invLength should be different!
              pairInstance.setValue(attr_s1, tf1 * tokenInfo.idf * invLength);
              pairInstance.setValue(attr_s2, tf2 * tokenInfo.idf * invLength);
            } else {
              pairInstance.setValue(attr_s1, tf1 * invLength);
              pairInstance.setValue(attr_s2, tf2 * invLength);
            }
          }
        }
      }
    }
    return pairInstance;
  }

  /** Compute similarity between two strings
   * @param s1 first string
   * @param s2 second string
   * @return similarity between the two strings
   * @throws Exception if the pair instance cannot be classified
   */
  public double similarity(String s1, String s2) throws Exception {
    SparseInstance pairInstance = createPairInstance(s1, s2);
    pairInstance.setDataset(m_instances);
    double sim = 0;
    // if the classifier has been trained, use it
    if (m_trained) {
      // the probability of class 0 is taken as the similarity
      double[] res = m_classifier.distributionForInstance(pairInstance);
      sim = res[0];
    } else { // otherwise, return the old-fashioned dot product
      for (int j = 0; j < pairInstance.numValues(); j++) {
        Attribute attribute = pairInstance.attributeSparse(j);
        int attrIdx = attribute.index();
        sim += pairInstance.value(attrIdx);
      }
    }
    return sim;
  }

  /** The computation of a metric can be either based on distance, or on similarity
   * @return false because dot product fundamentally computes similarity
   */
  public boolean isDistanceBased() {
    return false;
  }

  /** Set the tokenizer to use
   * @param tokenizer the tokenizer that is used
   */
  public void setTokenizer(Tokenizer tokenizer) {
    m_tokenizer = tokenizer;
  }

  /** Get the tokenizer to use
   * @return the tokenizer that is used
   */
  public Tokenizer getTokenizer() {
    return m_tokenizer;
  }

  /**
   * Set the classifier
   * @param classifier the classifier
   */
  public void setClassifier (DistributionClassifier classifier) {
    m_classifier = classifier;
  }

  /**
   * Get the classifier
   * @return the classifier that this metric employs
   */
  public DistributionClassifier getClassifier () {
    return m_classifier;
  }

  /** Turn IDF weighting on/off
   * @param useIDF if true, all token weights will be weighted by IDF
   */
  public void setUseIDF(boolean useIDF) {
    m_useIDF = useIDF;
  }

  /** Check whether IDF weighting is on/off
   * @return if true, all token weights are weighted by IDF
   */
  public boolean getUseIDF() {
    return m_useIDF;
  }

  /** Turn using individual components on/off
   * @param useIndividualStrings if true, individual token weights are included
   *        in the pairwise representation
   */
  public void setUseIndividualStrings(boolean useIndividualStrings) {
    m_useIndividualWeights = useIndividualStrings;
  }

  /** Turn using individual components on/off
   * @return true if individual token weights are included in the pairwise representation
   */
  public boolean getUseIndividualStrings() {
    return m_useIndividualWeights;
  }

  /** Turn adding a special all-features example on/off
@param useAllFeaturesExample if true, a special training example will be constructed that incorporates all features */ public void setUseAllFeaturesExample(boolean useAllFeaturesExample) { m_useAllFeaturesExample = useAllFeaturesExample; } /** Check whether a special all-features example is being added * @return true if a special training example will be constructed that incorporates all features */ public boolean getUseAllFeaturesExample() { return m_useAllFeaturesExample; } /** Return the number of tokens indexed. * @return the number of tokens indexed*/ public int size() { return m_tokenHash.size(); } /** * Returns distance between two strings using the current conversion * type (CONVERSION_LAPLACIAN, CONVERSION_EXPONENTIAL, CONVERSION_UNIT, ...) * @param string1 First string. * @param string2 Second string. * @exception Exception if distance could not be estimated. */ public double distance (String string1, String string2) throws Exception { switch (m_conversionType) { case CONVERSION_LAPLACIAN: return 1 / (1 + similarity(string1, string2)); case CONVERSION_UNIT: return 2 * (1 - similarity(string1, string2)); case CONVERSION_EXPONENTIAL: return Math.exp(-similarity(string1, string2)); default: throw new Exception ("Unknown similarity to distance conversion method"); } } /** * Set the type of similarity to distance conversion. 
Values other * than CONVERSION_LAPLACIAN, CONVERSION_UNIT, or CONVERSION_EXPONENTIAL will be ignored * * @param type type of the similarity to distance conversion to use */ public void setConversionType(SelectedTag conversionType) { if (conversionType.getTags() == TAGS_CONVERSION) { m_conversionType = conversionType.getSelectedTag().getID(); } } /** * return the type of similarity to distance conversion * @return one of CONVERSION_LAPLACIAN, CONVERSION_UNIT, or CONVERSION_EXPONENTIAL */ public SelectedTag getConversionType() { return new SelectedTag(m_conversionType, TAGS_CONVERSION); } /** Create a copy of this metric * @return another KernelVSMetric with the same exact parameters as this metric */ public Object clone() { KernelVSMetric metric = new KernelVSMetric(); metric.setConversionType(new SelectedTag(m_conversionType, TAGS_CONVERSION)); metric.setTokenizer(m_tokenizer); metric.setUseIDF(m_useIDF); metric.setUseIndividualStrings(m_useIndividualWeights); metric.setUseAllFeaturesExample(m_useAllFeaturesExample); try { DistributionClassifier classifier = (DistributionClassifier) Class.forName(m_classifier.getClass().getName()).newInstance(); if (m_classifier instanceof OptionHandler) { ((OptionHandler)classifier).setOptions(((OptionHandler)m_classifier).getOptions()); } metric.setClassifier(classifier); } catch (Exception e) { System.err.println("Problems cloning metric " + this.getClass().getName() + ": " + e.toString()); e.printStackTrace(); System.exit(1); } return metric; } /** * Gets the current settings of NGramTokenizer. 
* * @return an array of strings suitable for passing to setOptions() */ public String [] getOptions() { String [] options = new String [40]; int current = 0; if (m_conversionType == CONVERSION_EXPONENTIAL) { options[current++] = "-E"; } else if (m_conversionType == CONVERSION_UNIT) { options[current++] = "-U"; } if (m_useAllFeaturesExample) { options[current++] = "-AF"; } if (m_useIDF) { options[current++] = "-I"; } if (m_useIndividualWeights) { options[current++] = "-V"; } options[current++] = "-T"; options[current++] = Utils.removeSubstring(m_tokenizer.getClass().getName(), "weka.deduping.metrics."); if (m_tokenizer instanceof OptionHandler) { String[] tokenizerOptions = ((OptionHandler)m_tokenizer).getOptions(); for (int i = 0; i < tokenizerOptions.length; i++) { options[current++] = tokenizerOptions[i]; } } options[current++] = "-C"; options[current++] = Utils.removeSubstring(m_classifier.getClass().getName(), "weka.classifiers."); if (m_classifier instanceof OptionHandler) { String[] classifierOptions = ((OptionHandler)m_classifier).getOptions(); for (int i = 0; i < classifierOptions.length; i++) { options[current++] = classifierOptions[i]; } } while (current < options.length) { options[current++] = ""; } return options; } /** * Parses a given list of options. Valid options are:<p> * * -S use stemming * -R remove stopwords * -N gram size */ public void setOptions(String[] options) throws Exception { // TODO } /** * Returns an enumeration describing the available options. * * @return an enumeration of all the available options. */ public Enumeration listOptions() { Vector newVector = new Vector(0); return newVector.elements(); } /** Given an instance, normalize it to be a unit vector. Destructive! 
* @param instance instance to be normalized */ protected void normalizeInstance(Instance instance) { double norm = 0; double values [] = instance.toDoubleArray(); for (int i=0; i < values.length; i++) { if (i != instance.classIndex()) { // don't normalize the class index norm += values[i] * values[i]; } } norm = Math.sqrt(norm); if (norm != 0) { for (int i=0; i<values.length; i++) { if (i != instance.classIndex()) { // don't normalize the class index values[i] /= norm; } } instance.setValueArray(values); } } }
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?