📄 libsvm.java
字号:
} } /** * Gets the parameters C of class i to weight[i]*C, for C-SVC (default 1). * Blank separated doubles. * * @return the weights (doubles separated by blanks) */ public String getWeights() { String result; int i; result = ""; for (i = 0; i < m_Weight.length; i++) { if (i > 0) result += " "; result += Double.toString(m_Weight[i]); } return result; } /** * Returns the tip text for this property * * @return tip text for this property suitable for * displaying in the explorer/experimenter gui */ public String weightsTipText() { return "The weights to use for the classes, if empty 1 is used by default."; } /** * sets the specified field * * @param o the object to set the field for * @param name the name of the field * @param value the new value of the field */ protected void setField(Object o, String name, Object value) { Field f; try { f = o.getClass().getField(name); f.set(o, value); } catch (Exception e) { e.printStackTrace(); } } /** * sets the specified field in an array * * @param o the object to set the field for * @param name the name of the field * @param index the index in the array * @param value the new value of the field */ protected void setField(Object o, String name, int index, Object value) { Field f; try { f = o.getClass().getField(name); Array.set(f.get(o), index, value); } catch (Exception e) { e.printStackTrace(); } } /** * returns the current value of the specified field * * @param o the object the field is member of * @param name the name of the field * @return the value */ protected Object getField(Object o, String name) { Field f; Object result; try { f = o.getClass().getField(name); result = f.get(o); } catch (Exception e) { e.printStackTrace(); result = null; } return result; } /** * sets a new array for the field * * @param o the object to set the array for * @param name the name of the field * @param type the type of the array * @param length the length of the one-dimensional array */ protected void newArray(Object o, String name, Class type, int length) { newArray(o, name, type, new int[]{length}); } /** * sets a new array for the field * * @param o the object to set the array for * @param name the name of the field * @param type the type of the array * @param dimensions the dimensions of the array */ protected void newArray(Object o, String name, Class type, int[] dimensions) { Field f; try { f = o.getClass().getField(name); f.set(o, Array.newInstance(type, dimensions)); } catch (Exception e) { e.printStackTrace(); } } /** * executes the specified method and returns the result, if any * * @param o the object the method should be called from * @param name the name of the method * @param paramClasses the classes of the parameters * @param paramValues the values of the parameters * @return the return value of the method, if any (in that case null) */ protected Object invokeMethod(Object o, String name, Class[] paramClasses, Object[] paramValues) { Method m; Object result; result = null; try { m = o.getClass().getMethod(name, paramClasses); result = m.invoke(o, paramValues); } catch (Exception e) { e.printStackTrace(); result = null; } return result; } /** * transfers the local variables into a svm_parameter object * * @return the configured svm_parameter object */ protected Object getParameters() { Object result; int i; try { result = Class.forName(CLASS_SVMPARAMETER).newInstance(); setField(result, "svm_type", new Integer(m_SVMType)); setField(result, "kernel_type", new Integer(m_KernelType)); setField(result, "degree", new Integer(m_Degree)); setField(result, "gamma", new Double(m_GammaActual)); setField(result, "coef0", new Double(m_Coef0)); setField(result, "nu", new Double(m_nu)); setField(result, "cache_size", new Double(m_CacheSize)); setField(result, "C", new Double(m_Cost)); setField(result, "eps", new Double(m_eps)); setField(result, "p", new Double(m_Loss)); setField(result, "shrinking", new Integer(m_Shrinking ? 1 : 0)); setField(result, "nr_weight", new Integer(m_Weight.length)); newArray(result, "weight", Double.TYPE, m_Weight.length); newArray(result, "weight_label", Integer.TYPE, m_Weight.length); for (i = 0; i < m_Weight.length; i++) { setField(result, "weight", i, new Double(m_Weight[i])); setField(result, "weight_label", i, new Integer(m_WeightLabel[i])); } } catch (Exception e) { e.printStackTrace(); result = null; } return result; } /** * returns the svm_problem * * @param vx the x values * @param vy the y values * @return the svm_problem object */ protected Object getProblem(Vector vx, Vector vy) { Object result; try { result = Class.forName(CLASS_SVMPROBLEM).newInstance(); setField(result, "l", new Integer(vy.size())); newArray(result, "x", Class.forName(CLASS_SVMNODE), new int[]{vy.size(), 0}); for (int i = 0; i < vy.size(); i++) setField(result, "x", i, vx.elementAt(i)); newArray(result, "y", Double.TYPE, vy.size()); for (int i = 0; i < vy.size(); i++) setField(result, "y", i, new Double((String) vy.elementAt(i))); } catch (Exception e) { e.printStackTrace(); result = null; } return result; } /** * Converts and ARFF Instance into a string in the sparse format accepted by * LIBSVM * * @param instance the instance to turn into sparse format * @return the sparse String representation */ protected String instanceToSparse(Instance instance) { String line = new String(); int c = (int) instance.classValue(); if (c == 0) c = -1; line = c + " "; for (int j = 1; j < instance.numAttributes(); j++) { if (instance.value(j - 1) != 0) line += " " + j + ":" + instance.value(j - 1); } return (line + "\n"); } /** * converts an ARFF dataset into sparse format * * @param data the dataset to process * @return the processed data */ protected Vector dataToSparse(Instances data) { Vector sparse = new Vector(data.numInstances() + 1); for (int i = 0; i < data.numInstances(); i++) sparse.add(instanceToSparse(data.instance(i))); return sparse; } /** * classifies the given instance * * @param instance the instance to classify * @return the class label * @throws Exception if an error occurs */ public double classifyInstance(Instance instance) throws Exception { int[] labels = new int[instance.numClasses()]; double[] prob_estimates = null; // FracPete: the following block is NOT tested! if (m_predict_probability) { if ( (m_SVMType == SVMTYPE_EPSILON_SVR) || (m_SVMType == SVMTYPE_NU_SVR) ) { double prob = ((Double) invokeMethod( Class.forName(CLASS_SVM).newInstance(), "svm_get_svr_probability", new Class[]{Class.forName(CLASS_SVMMODEL)}, new Object[]{m_Model})).doubleValue(); System.out.print( "Prob. model for test data: target value = predicted value + z,\n" + "z: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=" + prob + "\n"); } else { invokeMethod( Class.forName(CLASS_SVM).newInstance(), "svm_get_labels", new Class[]{ Class.forName(CLASS_SVMMODEL), Array.newInstance(Integer.TYPE, instance.numClasses()).getClass()}, new Object[]{ m_Model, labels}); prob_estimates = new double[instance.numClasses()]; } } if (m_Filter != null) { m_Filter.input(instance); m_Filter.batchFinished(); instance = m_Filter.output(); } String line = instanceToSparse(instance); StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:"); st.nextToken(); // skip class int m = st.countTokens() / 2; Object x = Array.newInstance(Class.forName(CLASS_SVMNODE), m); for (int j = 0; j < m; j++) { Array.set(x, j, Class.forName(CLASS_SVMNODE).newInstance()); setField(Array.get(x, j), "index", new Integer(st.nextToken())); setField(Array.get(x, j), "value", new Double(st.nextToken())); } double v; if ( m_predict_probability && ( (m_SVMType == SVMTYPE_C_SVC) || (m_SVMType == SVMTYPE_NU_SVC) ) ) { v = ((Double) invokeMethod( Class.forName(CLASS_SVM).newInstance(), "svm_predict_probability", new Class[]{ Class.forName(CLASS_SVMMODEL), Array.newInstance(Class.forName(CLASS_SVMNODE), Array.getLength(x)).getClass(), Array.newInstance(Double.TYPE, prob_estimates.length).getClass()}, new Object[]{ m_Model, x, prob_estimates})).doubleValue(); } else { v = ((Double) invokeMethod( Class.forName(CLASS_SVM).newInstance(), "svm_predict", new Class[]{ Class.forName(CLASS_SVMMODEL), Array.newInstance(Class.forName(CLASS_SVMNODE), Array.getLength(x)).getClass()}, new Object[]{ m_Model, x})).doubleValue(); } // transform frist class label into Weka format if (v == -1) v = 0; return v; } /** * Returns default capabilities of the classifier. * * @return the capabilities of this classifier */ public Capabilities getCapabilities() { Capabilities result = super.getCapabilities(); // attributes result.enable(Capability.NOMINAL_ATTRIBUTES); result.enable(Capability.NUMERIC_ATTRIBUTES); result.enable(Capability.DATE_ATTRIBUTES); result.enable(Capability.MISSING_VALUES); // class result.enable(Capability.NOMINAL_CLASS); result.enable(Capability.NUMERIC_CLASS); result.enable(Capability.DATE_CLASS); result.enable(Capability.MISSING_CLASS_VALUES); return result; } /** * builds the classifier * * @param insts the training instances * @throws Exception if libsvm classes not in classpath or libsvm * encountered a problem */ public void buildClassifier(Instances insts) throws Exception { if (!isPresent()) throw new Exception("libsvm classes not in CLASSPATH!"); // can classifier handle the data? getCapabilities().testWithFail(insts); // remove instances with missing class insts = new Instances(insts); insts.deleteWithMissingClass(); if (getNormalize()) { m_Filter = new Normalize(); m_Filter.setInputFormat(insts); insts = Filter.useFilter(insts, m_Filter); } Vector sparseData = dataToSparse(insts); Vector vy = new Vector(); Vector vx = new Vector(); int max_index = 0; for (int d = 0; d < sparseData.size(); d++) { String line = (String) sparseData.get(d); StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:"); vy.addElement(st.nextToken()); int m = st.countTokens() / 2; Object x = Array.newInstance(Class.forName(CLASS_SVMNODE), m); for (int j = 0; j < m; j++) { Array.set(x, j, Class.forName(CLASS_SVMNODE).newInstance()); setField(Array.get(x, j), "index", new Integer(st.nextToken())); setField(Array.get(x, j), "value", new Double(st.nextToken())); } if (m > 0) max_index = Math.max(max_index, ((Integer) getField(Array.get(x, m - 1), "index")).intValue()); vx.addElement(x); } // calculate actual gamma if (getGamma() == 0) m_GammaActual = 1.0 / max_index; else m_GammaActual = m_Gamma; // check parameter String error_msg = (String) invokeMethod( Class.forName(CLASS_SVM).newInstance(), "svm_check_parameter", new Class[]{ Class.forName(CLASS_SVMPROBLEM), Class.forName(CLASS_SVMPARAMETER)}, new Object[]{ getProblem(vx, vy), getParameters()}); if (error_msg != null) throw new Exception("Error: " + error_msg); // train model m_Model = invokeMethod( Class.forName(CLASS_SVM).newInstance(), "svm_train", new Class[]{ Class.forName(CLASS_SVMPROBLEM), Class.forName(CLASS_SVMPARAMETER)}, new Object[]{ getProblem(vx, vy), getParameters()}); } /** * returns a string representation * * @return a string representation */ public String toString() { return "LibSVM wrapper, original code by Yasser EL-Manzalawy (= WLSVM)"; } /** * Main method for testing this class. * * @param args the options */ public static void main(String[] args) { try { System.out.println(Evaluation.evaluateModel(new LibSVM(), args)); } catch (Exception e) { e.printStackTrace(); System.err.println(e.getMessage()); } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -