📄 wlsvm.java
字号:
*
* @return
*/
public double[] getWeights() {
    // Hands back the weight array held by the current parameter set; no copy is made.
    final double[] classWeights = param.weight;
    return classWeights;
}
/**
 * Sets the WLSVM classifier options.
 *
 * <p>Each call starts from a fresh {@code svm_parameter}. A flag that is absent
 * falls back to the default shown here:
 * <pre>
 * -S svm type          (svm_parameter.C_SVC)
 * -K kernel type       (svm_parameter.RBF)
 * -D degree            (3)
 * -G gamma             (0)
 * -R coef0             (0)
 * -N nu                (0.5)
 * -M cache size        (40)
 * -C cost              (1)
 * -E eps               (1e-3)
 * -Z normalize         (0)
 * -P p                 (0.1)
 * -H shrinking         (1)
 * -B probability       (0)
 * -W space separated class weights (none)
 * </pre>
 *
 * @param options the option array, read through {@code Utils.getOption}
 * @throws Exception if {@code Utils.getOption} fails; a value that is not a
 *         valid number raises {@code NumberFormatException}
 */
public void setOptions(String[] options) throws Exception {
    param = new svm_parameter();

    param.svm_type = intOption('S', options, svm_parameter.C_SVC);
    param.kernel_type = intOption('K', options, svm_parameter.RBF);
    param.degree = doubleOption('D', options, 3);
    param.gamma = doubleOption('G', options, 0);
    param.coef0 = doubleOption('R', options, 0);
    param.nu = doubleOption('N', options, 0.5);
    param.cache_size = doubleOption('M', options, 40);
    param.C = doubleOption('C', options, 1);
    param.eps = doubleOption('E', options, 1e-3);
    normalize = intOption('Z', options, 0);
    param.p = doubleOption('P', options, 0.1);
    param.shrinking = intOption('H', options, 1);
    param.probability = intOption('B', options, 0);

    // An absent or blank -W value tokenizes to zero weights, which leaves the
    // weight arrays empty instead of indexing into a zero-length array.
    String weightsString = Utils.getOption('W', options);
    StringTokenizer st = new StringTokenizer(weightsString, " ");
    int n_classes = st.countTokens();
    param.weight_label = new int[n_classes];
    param.weight = new double[n_classes];
    for (int i = 0; i < n_classes; i++) {
        param.weight[i] = Double.parseDouble(st.nextToken());
        // The first class is labelled -1 (InstanceToSparse writes class value 0
        // as -1); every other class keeps its position as its label.
        param.weight_label[i] = (i == 0) ? -1 : i;
    }
    param.nr_weight = n_classes;
}

/** Reads the integer value of {@code flag}, or returns {@code fallback} when the flag is absent. */
private static int intOption(char flag, String[] options, int fallback) throws Exception {
    String text = Utils.getOption(flag, options);
    return (text.length() != 0) ? Integer.parseInt(text) : fallback;
}

/** Reads the floating-point value of {@code flag}, or returns {@code fallback} when the flag is absent. */
private static double doubleOption(char flag, String[] options, double fallback) throws Exception {
    String text = Utils.getOption(flag, options);
    return (text.length() != 0) ? Double.parseDouble(text) : fallback;
}
/**
 * Returns the current WLSVM options.
 *
 * <p>The result always has 40 slots: flag/value pairs first, then an optional
 * {@code -W} pair when class weights are set, then empty strings as padding.
 * If no parameters exist yet, the defaults are created first by calling
 * {@code setOptions} with an empty option list.
 *
 * @return the option array
 */
public String[] getOptions() {
    if (param == null) {
        try {
            setOptions(new String[0]);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    // Flag/value pairs in the order they are reported.
    String[][] flagValuePairs = {
        { "-S", String.valueOf(param.svm_type) },
        { "-K", String.valueOf(param.kernel_type) },
        { "-D", String.valueOf(param.degree) },
        { "-G", String.valueOf(param.gamma) },
        { "-R", String.valueOf(param.coef0) },
        { "-N", String.valueOf(param.nu) },
        { "-M", String.valueOf(param.cache_size) },
        { "-C", String.valueOf(param.C) },
        { "-E", String.valueOf(param.eps) },
        { "-P", String.valueOf(param.p) },
        { "-H", String.valueOf(param.shrinking) },
        { "-B", String.valueOf(param.probability) },
        { "-Z", String.valueOf(normalize) },
    };

    String[] result = new String[40];
    int next = 0;
    for (int i = 0; i < flagValuePairs.length; i++) {
        result[next++] = flagValuePairs[i][0];
        result[next++] = flagValuePairs[i][1];
    }

    if (param.nr_weight > 0) {
        // Weights are reported as one space-separated string.
        StringBuilder joined = new StringBuilder();
        for (int i = 0; i < param.nr_weight; i++) {
            if (i > 0) {
                joined.append(' ');
            }
            joined.append(param.weight[i]);
        }
        result[next++] = "-W";
        result[next++] = joined.toString();
    }

    // Pad the unused tail with empty strings.
    for (; next < result.length; next++) {
        result[next] = "";
    }
    return result;
}
protected static double atof(String s) {
    // Parse directly to a primitive; malformed text raises NumberFormatException.
    final double parsed = Double.parseDouble(s);
    return parsed;
}
protected static int atoi(String s) {
    // Decimal parse; malformed text raises NumberFormatException.
    final int parsed = Integer.parseInt(s, 10);
    return parsed;
}
/**
 * Converts an ARFF Instance into a string in the sparse format accepted by
 * LIBSVM.
 *
 * <p>The line starts with the class value cast to an int (0 is written as -1),
 * followed by one {@code index:value} entry for every attribute that is not
 * the class attribute, not missing and not zero. Indices are the 1-based
 * attribute positions. The line ends with a newline.
 *
 * @param instance the instance to convert
 * @return the sparse text line for the instance
 */
protected String InstanceToSparse(Instance instance) {
    int label = (int) instance.classValue();
    if (label == 0) {
        label = -1;
    }

    StringBuilder line = new StringBuilder();
    line.append(label).append(' ');

    final int classIndex = instance.classIndex();
    final int attributeCount = instance.numAttributes();
    // Visit every attribute, including the last one: when the class attribute
    // is not the last attribute, the final feature must still be emitted.
    for (int attr = 0; attr < attributeCount; attr++) {
        if (attr == classIndex) {
            continue;
        }
        if (instance.isMissing(attr)) {
            continue;
        }
        double value = instance.value(attr);
        if (value != 0) {
            line.append(' ').append(attr + 1).append(':').append(value);
        }
    }
    return line.append('\n').toString();
}
/**
 * Converts an ARFF dataset into sparse format.
 *
 * @param data the dataset to convert
 * @return a vector holding one sparse text line per instance, in dataset order
 */
protected Vector DataToSparse(Instances data) {
    final int total = data.numInstances();
    Vector lines = new Vector(total + 1);
    int row = 0;
    while (row < total) {
        lines.add(InstanceToSparse(data.instance(row)));
        row++;
    }
    return lines;
}
/**
 * Computes the class distribution for one instance.
 *
 * <p>With probability estimation switched on and a C_SVC or NU_SVC model, the
 * LIBSVM probability estimates are copied into the slot of their label
 * (label -1 goes to slot 0). Otherwise the slot of the predicted label gets 1.
 * When probability estimation is on and the model is a regression model, a
 * warning is printed and {@code null} is returned.
 *
 * @param instance the instance to classify; it is passed through the filter
 *        first when a filter is set
 * @return the distribution, one entry per class of the model, or {@code null}
 * @throws Exception if filtering the instance fails
 */
public double[] distributionForInstance (Instance instance) throws Exception {
    final int svmType = svm.svm_get_svm_type(model);
    final int classCount = svm.svm_get_nr_class(model);
    final int[] svmLabels = new int[classCount];
    final boolean probabilityEnabled = (param.probability == 1);

    double[] estimates = null;
    if (probabilityEnabled) {
        boolean regression = (svmType == svm_parameter.EPSILON_SVR)
                || (svmType == svm_parameter.NU_SVR);
        if (regression) {
            System.err.println("Do not use distributionForInstance for regression models!");
            return null;
        }
        svm.svm_get_labels(model, svmLabels);
        estimates = new double[classCount];
    }

    if (filter != null) {
        filter.input(instance);
        filter.batchFinished();
        instance = filter.output();
    }

    // Re-read the sparse text form: first token is the label, then index/value pairs.
    StringTokenizer tokens = new StringTokenizer(InstanceToSparse(instance), " \t\n\r\f:");
    atof(tokens.nextToken()); // consume the label; its value is not needed here
    final int pairCount = tokens.countTokens() / 2;
    svm_node[] nodes = new svm_node[pairCount];
    for (int i = 0; i < nodes.length; i++) {
        svm_node node = new svm_node();
        node.index = atoi(tokens.nextToken());
        node.value = atof(tokens.nextToken());
        nodes[i] = node;
    }

    double[] distribution = new double[classCount];
    boolean classification = (svmType == svm_parameter.C_SVC)
            || (svmType == svm_parameter.NU_SVC);
    if (probabilityEnabled && classification) {
        svm.svm_predict_probability(model, nodes, estimates);
        // Put each estimate into the slot of its label; label -1 maps to slot 0.
        for (int k = 0; k < estimates.length; k++) {
            int slot = (svmLabels[k] == -1) ? 0 : svmLabels[k];
            distribution[slot] = estimates[k];
        }
    } else {
        double predicted = svm.svm_predict(model, nodes);
        int slot = (predicted == -1) ? 0 : (int) predicted;
        distribution[slot] = 1;
    }
    return distribution;
}
/**
 * Builds the model.
 *
 * <p>Steps: optionally normalize the data (when {@code normalize == 1}),
 * convert it to LIBSVM sparse text, tokenize that text into an
 * {@code svm_problem}, validate the parameters and train.
 *
 * @param insts the training data
 * @throws Exception if filtering fails; an {@code IllegalArgumentException}
 *         is thrown when LIBSVM rejects the parameters, and any exception
 *         raised during training is propagated to the caller
 */
public void buildClassifier(Instances insts) throws Exception {
    if (normalize == 1) {
        if (getDebug())
            System.err.println("Normalizing...");
        filter = new Normalize();
        filter.setInputFormat(insts);
        insts = Filter.useFilter(insts, filter);
    }

    if (getDebug())
        System.err.println("Converting to libsvm format...");
    Vector sparseData = DataToSparse(insts);

    if (getDebug())
        System.err.println("Tokenizing libsvm data...");
    final int rows = sparseData.size();
    svm_node[][] features = new svm_node[rows][];
    double[] targets = new double[rows];
    int max_index = 0;
    for (int d = 0; d < rows; d++) {
        String line = (String) sparseData.get(d);
        StringTokenizer st = new StringTokenizer(line, " \t\n\r\f:");
        targets[d] = atof(st.nextToken());
        int m = st.countTokens() / 2;
        svm_node[] x = new svm_node[m];
        for (int j = 0; j < m; j++) {
            x[j] = new svm_node();
            x[j].index = atoi(st.nextToken());
            x[j].value = atof(st.nextToken());
        }
        // Entries are written in ascending index order, so the last one is the largest.
        if (m > 0)
            max_index = Math.max(max_index, x[m - 1].index);
        features[d] = x;
    }

    prob = new svm_problem();
    prob.l = rows;
    prob.x = features;
    prob.y = targets;

    // Default gamma is 1/max_index; skip it when there are no features at all,
    // otherwise the division would produce an infinite gamma.
    if (param.gamma == 0 && max_index > 0)
        param.gamma = 1.0 / max_index;

    error_msg = svm.svm_check_parameter(prob, param);
    if (error_msg != null) {
        // Report the problem to the caller instead of terminating the JVM.
        throw new IllegalArgumentException("Error: " + error_msg);
    }

    if (getDebug())
        System.err.println("Training model");
    // A training failure propagates to the caller rather than leaving a null model behind.
    model = svm.svm_train(prob, param);
}
public String toString() {
    // Fixed, human-readable identification of this classifier.
    final String description = "WLSVM Classifier By Yasser EL-Manzalawy";
    return description;
}
/**
 * Command-line entry point: runs a Weka evaluation of WLSVM on one ARFF file
 * with a fixed set of evaluation and classifier options, and prints the result.
 *
 * @param argv {@code argv[0]} is the path of the ARFF training file
 * @throws Exception if the evaluation fails
 */
public static void main(String[] argv) throws Exception {
    if (argv.length < 1) {
        System.out.println("Usage: WLSVM <arff file>");
        System.exit(1);
    }
    String dataFile = argv[0];
    WLSVM lib = new WLSVM();
    String[] ops = {
        // Evaluation options
        "-t", dataFile,
        "-x", "5",
        "-i",
        // WLSVM options
        "-S", "0",
        "-K", "2",
        "-G", "1",
        "-C", "7",
        "-M", "100",
        // Optional extras, disabled: "-B", "1" and "-W", "1.0 1.0"
    };
    System.out.println(Evaluation.evaluateModel(lib, ops));
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -