📄 libsvm.java
字号:
String result;
    int i;

    result = "";
    for (i = 0; i < m_Weight.length; i++) {
      if (i > 0)
        result += " ";
      result += Double.toString(m_Weight[i]);
    }

    return result;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String weightsTipText() {
    return "The weights to use for the classes, if empty 1 is used by default.";
  }

  /**
   * Sets whether to generate probability estimates instead of -1/+1 for
   * classification problems.
   *
   * @param value whether to predict probabilities
   */
  public void setProbabilityEstimates(boolean value) {
    m_ProbabilityEstimates = value;
  }

  /**
   * Returns whether probability estimates are generated instead of -1/+1 for
   * classification problems.
   *
   * @return true, if probability estimates should be returned
   */
  public boolean getProbabilityEstimates() {
    return m_ProbabilityEstimates;
  }

  /**
   * Returns the tip text for this property.
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String probabilityEstimatesTipText() {
    return "Whether to generate probability estimates instead of -1/+1 for classification problems.";
  }

  /**
   * Sets the specified public field of an object via reflection.
   * Any failure is reported by printing the stack trace and is otherwise
   * ignored (best effort).
   *
   * @param o the object to set the field for
   * @param name the name of the field
   * @param value the new value of the field
   */
  protected void setField(Object o, String name, Object value) {
    Field f;

    try {
      f = o.getClass().getField(name);
      f.set(o, value);
    }
    catch (Exception e) {
      e.printStackTrace();
    }
  }

  /**
   * Sets one element of the array stored in the specified public field of an
   * object via reflection. Any failure is reported by printing the stack
   * trace and is otherwise ignored (best effort).
   *
   * @param o the object to set the field for
   * @param name the name of the (array-typed) field
   * @param index the index in the array
   * @param value the new value of the array element
   */
  protected void setField(Object o, String name, int index, Object value) {
    Field f;

    try {
      f = o.getClass().getField(name);
      Array.set(f.get(o), index, value);
    }
    catch (Exception e) {
      e.printStackTrace();
    }
  }

  /**
   * Returns the current value of the specified public field via reflection.
   *
   * @param o the object the field is member of
   * @param name the name of the field
   * @return the value, or null if the field could not be read (the stack
   *         trace is printed in that case)
   */
  protected Object getField(Object o, String name) {
    Field f;
    Object result;

    try {
      f = o.getClass().getField(name);
      result = f.get(o);
    }
    catch (Exception e) {
      e.printStackTrace();
      result = null;
    }

    return result;
  }

  /**
   * Assigns a new one-dimensional array to the specified public field.
   *
   * @param o the object to set the array for
   * @param name the name of the field
   * @param type the component type of the array
   * @param length the length of the one-dimensional array
   */
  protected void newArray(Object o, String name, Class type, int length) {
    newArray(o, name, type, new int[]{length});
  }

  /**
   * Assigns a new (possibly multi-dimensional) array to the specified public
   * field. Any failure is reported by printing the stack trace and is
   * otherwise ignored (best effort).
   *
   * @param o the object to set the array for
   * @param name the name of the field
   * @param type the component type of the array
   * @param dimensions the dimensions of the array
   */
  protected void newArray(Object o, String name, Class type, int[] dimensions) {
    Field f;

    try {
      f = o.getClass().getField(name);
      f.set(o, Array.newInstance(type, dimensions));
    }
    catch (Exception e) {
      e.printStackTrace();
    }
  }

  /**
   * Executes the specified public method via reflection and returns its
   * result.
   *
   * @param o the object the method should be called on
   * @param name the name of the method
   * @param paramClasses the classes of the parameters
   * @param paramValues the values of the parameters
   * @return the return value of the method, or null if the method returns
   *         nothing or the invocation failed (the stack trace is printed in
   *         the latter case)
   */
  protected Object invokeMethod(Object o, String name, Class[] paramClasses, Object[] paramValues) {
    Method m;
    Object result;

    result = null;

    try {
      m = o.getClass().getMethod(name, paramClasses);
      result = m.invoke(o, paramValues);
    }
    catch (Exception e) {
      e.printStackTrace();
      result = null;
    }

    return result;
  }

  /**
   * Transfers the local settings into a newly created svm_parameter object.
   * The gamma value used is m_GammaActual, which buildClassifier computes
   * before calling this method.
   *
   * @return the configured svm_parameter object, or null if it could not be
   *         instantiated
   */
  protected Object getParameters() {
    Object result;
    int i;

    try {
      result = Class.forName(CLASS_SVMPARAMETER).newInstance();

      setField(result, "svm_type", new Integer(m_SVMType));
      setField(result, "kernel_type", new Integer(m_KernelType));
      setField(result, "degree", new Integer(m_Degree));
      setField(result, "gamma", new Double(m_GammaActual));
      setField(result, "coef0", new Double(m_Coef0));
      setField(result, "nu", new Double(m_nu));
      setField(result, "cache_size", new Double(m_CacheSize));
      setField(result, "C", new Double(m_Cost));
      setField(result, "eps", new Double(m_eps));
      setField(result, "p", new Double(m_Loss));
      setField(result, "shrinking", new Integer(m_Shrinking ? 1 : 0));
      setField(result, "nr_weight", new Integer(m_Weight.length));
      setField(result, "probability", new Integer(m_ProbabilityEstimates ? 1 : 0));

      // per-class weights: weight[i] applies to the class label weight_label[i]
      newArray(result, "weight", Double.TYPE, m_Weight.length);
      newArray(result, "weight_label", Integer.TYPE, m_Weight.length);
      for (i = 0; i < m_Weight.length; i++) {
        setField(result, "weight", i, new Double(m_Weight[i]));
        setField(result, "weight_label", i, new Integer(m_WeightLabel[i]));
      }
    }
    catch (Exception e) {
      e.printStackTrace();
      result = null;
    }

    return result;
  }

  /**
   * Builds an svm_problem from the collected training vectors.
   *
   * @param vx the x values (one svm_node array per instance)
   * @param vy the y values (one Double per instance)
   * @return the svm_problem object, or null if it could not be created
   */
  protected Object getProblem(Vector vx, Vector vy) {
    Object result;

    try {
      result = Class.forName(CLASS_SVMPROBLEM).newInstance();

      setField(result, "l", new Integer(vy.size()));

      newArray(result, "x", Class.forName(CLASS_SVMNODE), new int[]{vy.size(), 0});
      for (int i = 0; i < vy.size(); i++)
        setField(result, "x", i, vx.elementAt(i));

      newArray(result, "y", Double.TYPE, vy.size());
      for (int i = 0; i < vy.size(); i++)
        setField(result, "y", i, vy.elementAt(i));
    }
    catch (Exception e) {
      e.printStackTrace();
      result = null;
    }

    return result;
  }

  /**
   * Converts an instance into a sparse libsvm array (svm_node[]).
   * Only non-zero, non-class attributes are stored; libsvm indices are the
   * Weka attribute indices plus one.
   *
   * @param instance the instance to work on
   * @return the libsvm array
   * @throws Exception if setup of array fails
   */
  protected Object instanceToArray(Instance instance) throws Exception {
    int index;
    int count;
    int i;
    Object result;
    Class nodeClass = Class.forName(CLASS_SVMNODE);
    int classIndex = instance.classIndex();

    // determine number of non-zero attributes
    // NOTE(review): if missing values are encoded as NaN they compare unequal
    // to 0 and are passed on to libsvm as NaN
    count = 0;
    for (i = 0; i < instance.numAttributes(); i++) {
      if (i == classIndex)
        continue;
      if (instance.value(i) != 0)
        count++;
    }

    // fill array
    result = Array.newInstance(nodeClass, count);
    index  = 0;
    for (i = 0; i < instance.numAttributes(); i++) {
      if (i == classIndex)
        continue;
      if (instance.value(i) == 0)
        continue;

      Object node = nodeClass.newInstance();
      setField(node, "index", new Integer(i + 1));
      setField(node, "value", new Double(instance.value(i)));
      Array.set(result, index, node);
      index++;
    }

    return result;
  }

  /**
   * Computes the distribution for a given instance.
   * In case of 1-class classification, 1 is returned at index 0 if libsvm
   * returns 1 and NaN (= missing) if libsvm returns -1.
   *
   * @param instance the instance for which distribution is computed
   * @return the distribution
   * @throws Exception if the distribution can't be computed successfully
   */
  public double[] distributionForInstance (Instance instance) throws Exception {
    int[] labels = new int[instance.numClasses()];
    double[] prob_estimates = null;

    if (m_ProbabilityEstimates) {
      // libsvm's svm_get_labels only writes one entry per class contained in
      // the model; pre-fill with -1 so unused slots (classes absent from the
      // training data) can be recognized and skipped below
      for (int k = 0; k < labels.length; k++)
        labels[k] = -1;

      invokeMethod(
          Class.forName(CLASS_SVM).newInstance(),
          "svm_get_labels",
          new Class[]{
            Class.forName(CLASS_SVMMODEL),
            int[].class},
          new Object[]{
            m_Model,
            labels});

      prob_estimates = new double[instance.numClasses()];
    }

    // apply the same normalization that was used during training
    if (m_Filter != null) {
      m_Filter.input(instance);
      m_Filter.batchFinished();
      instance = m_Filter.output();
    }

    Object x = instanceToArray(instance);
    Class nodeArrayClass = Array.newInstance(
        Class.forName(CLASS_SVMNODE), Array.getLength(x)).getClass();
    double v;
    double[] result = new double[instance.numClasses()];

    if (   m_ProbabilityEstimates
        && ((m_SVMType == SVMTYPE_C_SVC) || (m_SVMType == SVMTYPE_NU_SVC)) ) {
      v = ((Double) invokeMethod(
            Class.forName(CLASS_SVM).newInstance(),
            "svm_predict_probability",
            new Class[]{
              Class.forName(CLASS_SVMMODEL),
              nodeArrayClass,
              double[].class},
            new Object[]{
              m_Model,
              x,
              prob_estimates})).doubleValue();

      // Return order of probabilities to canonical weka attribute order;
      // slots without a label belong to no model class and are skipped, so
      // they cannot overwrite result[0] with 0
      for (int k = 0; k < prob_estimates.length; k++) {
        if (labels[k] < 0)
          continue;
        result[labels[k]] = prob_estimates[k];
      }
    }
    else {
      v = ((Double) invokeMethod(
            Class.forName(CLASS_SVM).newInstance(),
            "svm_predict",
            new Class[]{
              Class.forName(CLASS_SVMMODEL),
              nodeArrayClass},
            new Object[]{
              m_Model,
              x})).doubleValue();

      if (instance.classAttribute().isNominal()) {
        if (m_SVMType == SVMTYPE_ONE_CLASS_SVM) {
          if (v > 0)
            result[0] = 1;
          else
            result[0] = Double.NaN;  // outlier
        }
        else {
          // the predicted value is the index of the predicted class
          result[(int) v] = 1;
        }
      }
      else {
        result[0] = v;
      }
    }

    return result;
  }

  /**
   * Returns default capabilities of the classifier. The supported class
   * types depend on the selected SVM type.
   *
   * @return the capabilities of this classifier
   * @throws IllegalArgumentException if the SVM type is not supported
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);

    // class
    result.enableDependency(Capability.UNARY_CLASS);
    result.enableDependency(Capability.NOMINAL_CLASS);
    result.enableDependency(Capability.NUMERIC_CLASS);
    result.enableDependency(Capability.DATE_CLASS);

    switch (m_SVMType) {
      case SVMTYPE_C_SVC:
      case SVMTYPE_NU_SVC:
        result.enable(Capability.NOMINAL_CLASS);
        break;

      case SVMTYPE_ONE_CLASS_SVM:
        result.enable(Capability.UNARY_CLASS);
        break;

      case SVMTYPE_EPSILON_SVR:
      case SVMTYPE_NU_SVR:
        result.enable(Capability.NUMERIC_CLASS);
        result.enable(Capability.DATE_CLASS);
        break;

      default:
        throw new IllegalArgumentException("SVMType " + m_SVMType + " is not supported!");
    }
    result.enable(Capability.MISSING_CLASS_VALUES);

    return result;
  }

  /**
   * Builds the classifier.
   *
   * @param insts the training instances
   * @throws Exception if libsvm classes not in classpath or libsvm
   *         encountered a problem
   */
  public void buildClassifier(Instances insts) throws Exception {
    if (!isPresent())
      throw new Exception("libsvm classes not in CLASSPATH!");

    // can classifier handle the data?
    getCapabilities().testWithFail(insts);

    // remove instances with missing class
    insts = new Instances(insts);
    insts.deleteWithMissingClass();

    // discard a filter from a previous build; otherwise distributionForInstance
    // would keep normalizing with it even when normalization is now disabled
    m_Filter = null;
    if (getNormalize()) {
      m_Filter = new Normalize();
      m_Filter.setInputFormat(insts);
      insts = Filter.useFilter(insts, m_Filter);
    }

    Vector vy = new Vector();
    Vector vx = new Vector();
    int max_index = 0;

    for (int d = 0; d < insts.numInstances(); d++) {
      Instance inst = insts.instance(d);
      Object x = instanceToArray(inst);
      int m = Array.getLength(x);

      // the last node holds the highest libsvm index of this instance
      if (m > 0)
        max_index = Math.max(max_index, ((Integer) getField(Array.get(x, m - 1), "index")).intValue());
      vx.addElement(x);
      vy.addElement(new Double(inst.classValue()));
    }

    // calculate actual gamma (default 1/max_index, as libsvm's svm-train does);
    // without any non-zero attribute 1.0/0 would give an infinite gamma, which
    // turns kernel values into NaN, so gamma stays 0 in that case
    if (getGamma() == 0)
      m_GammaActual = (max_index > 0) ? 1.0 / max_index : 0;
    else
      m_GammaActual = m_Gamma;

    // build problem and parameters once; they are used for checking and training
    Object problem = getProblem(vx, vy);
    Object parameters = getParameters();

    // check parameter
    String error_msg = (String) invokeMethod(
        Class.forName(CLASS_SVM).newInstance(),
        "svm_check_parameter",
        new Class[]{
          Class.forName(CLASS_SVMPROBLEM),
          Class.forName(CLASS_SVMPARAMETER)},
        new Object[]{
          problem,
          parameters});

    if (error_msg != null)
      throw new Exception("Error: " + error_msg);

    // train model
    m_Model = invokeMethod(
        Class.forName(CLASS_SVM).newInstance(),
        "svm_train",
        new Class[]{
          Class.forName(CLASS_SVMPROBLEM),
          Class.forName(CLASS_SVMPARAMETER)},
        new Object[]{
          problem,
          parameters});
  }

  /**
   * Returns a string representation.
   *
   * @return a string representation
   */
  public String toString() {
    return "LibSVM wrapper, original code by Yasser EL-Manzalawy (= WLSVM)";
  }

  /**
   * Main method for testing this class.
   *
   * @param args the options
   */
  public static void main(String[] args) {
    runClassifier(new LibSVM(), args);
  }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -