⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 wlsvm.java

📁 基于支持向量机的分类算法
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
	 * 
	 * @return
	 */
	public double[] getWeights() {
		// Hands back the live weight array held by the current parameter object.
		return this.param.weight;
	}
	
	/**
	 * Sets the WLSVM classifier options.
	 * <p>
	 * Every call starts from a fresh {@code svm_parameter}. Any option that is
	 * absent from {@code options} is reset to its default:
	 * <ul>
	 * <li>-S svm type (default {@code svm_parameter.C_SVC})</li>
	 * <li>-K kernel type (default {@code svm_parameter.RBF})</li>
	 * <li>-D degree (default 3)</li>
	 * <li>-G gamma (default 0; buildClassifier derives a value from the data when it is 0)</li>
	 * <li>-R coef0 (default 0)</li>
	 * <li>-N nu (default 0.5)</li>
	 * <li>-M cache size (default 40)</li>
	 * <li>-C cost (default 1)</li>
	 * <li>-E tolerance eps (default 1e-3)</li>
	 * <li>-Z 1 to normalize the data before training (default 0)</li>
	 * <li>-P loss epsilon p (default 0.1)</li>
	 * <li>-H shrinking flag (default 1)</li>
	 * <li>-B probability-estimates flag (default 0)</li>
	 * <li>-W space-separated per-class weights (default: none)</li>
	 * </ul>
	 *
	 * @param options the option array to read
	 * @throws Exception if an option value cannot be parsed as a number
	 */
	public void setOptions(String[] options) throws Exception {
		param = new svm_parameter();

		param.svm_type = parseIntOption('S', options, svm_parameter.C_SVC);
		param.kernel_type = parseIntOption('K', options, svm_parameter.RBF);
		param.degree = parseDoubleOption('D', options, 3);
		param.gamma = parseDoubleOption('G', options, 0);
		param.coef0 = parseDoubleOption('R', options, 0);
		param.nu = parseDoubleOption('N', options, 0.5);
		param.cache_size = parseDoubleOption('M', options, 40);
		param.C = parseDoubleOption('C', options, 1);
		param.eps = parseDoubleOption('E', options, 1e-3);
		normalize = parseIntOption('Z', options, 0);
		param.p = parseDoubleOption('P', options, 0.1);
		param.shrinking = parseIntOption('H', options, 1);
		param.probability = parseIntOption('B', options, 0);

		parseWeights(Utils.getOption('W', options));
	}

	/** Reads option {@code flag} as a double, or returns {@code defaultValue} when it is absent. */
	private static double parseDoubleOption(char flag, String[] options, double defaultValue)
			throws Exception {
		String value = Utils.getOption(flag, options);
		return (value.length() != 0) ? Double.parseDouble(value) : defaultValue;
	}

	/** Reads option {@code flag} as an int, or returns {@code defaultValue} when it is absent. */
	private static int parseIntOption(char flag, String[] options, int defaultValue)
			throws Exception {
		String value = Utils.getOption(flag, options);
		return (value.length() != 0) ? Integer.parseInt(value) : defaultValue;
	}

	/**
	 * Fills param.weight / weight_label / nr_weight from the space-separated -W value.
	 * An empty (or whitespace-only) value yields no weights instead of failing.
	 */
	private void parseWeights(String weightsString) {
		StringTokenizer st = new StringTokenizer(weightsString, " ");
		int count = st.countTokens();
		param.nr_weight = count;
		param.weight_label = new int[count];
		param.weight = new double[count];
		for (int i = 0; i < count; i++) {
			param.weight[i] = atof(st.nextToken());
			// Labels mirror InstanceToSparse, which writes class 0 as -1 and class i as i.
			param.weight_label[i] = (i == 0) ? -1 : i;
		}
	}
	
	/**
	 * Returns the current WLSVM options in the form accepted by setOptions.
	 * <p>
	 * If no options have been set yet, the defaults are installed first. The
	 * returned array always has length 40; slots after the last option hold "".
	 * -W is emitted only when class weights are configured.
	 * <p>
	 * NOTE(review): buildClassifier overwrites gamma = 0 with a data-derived
	 * value, so after training this reports that value rather than 0.
	 *
	 * @return the option flags and their values
	 */
	public String[] getOptions() {
		if (param == null) {
			try {
				setOptions(new String[0]);
			} catch (Exception e) {
				// With an empty option array setOptions only installs defaults.
				e.printStackTrace();
			}
		}

		String[] options = new String[40];
		int current = 0;

		options[current++] = "-S";
		options[current++] = String.valueOf(param.svm_type);
		options[current++] = "-K";
		options[current++] = String.valueOf(param.kernel_type);
		options[current++] = "-D";
		options[current++] = String.valueOf(param.degree);
		options[current++] = "-G";
		options[current++] = String.valueOf(param.gamma);
		options[current++] = "-R";
		options[current++] = String.valueOf(param.coef0);
		options[current++] = "-N";
		options[current++] = String.valueOf(param.nu);
		options[current++] = "-M";
		options[current++] = String.valueOf(param.cache_size);

		options[current++] = "-C";
		options[current++] = String.valueOf(param.C);
		options[current++] = "-E";
		options[current++] = String.valueOf(param.eps);
		options[current++] = "-P";
		options[current++] = String.valueOf(param.p);
		options[current++] = "-H";
		options[current++] = String.valueOf(param.shrinking);
		options[current++] = "-B";
		options[current++] = String.valueOf(param.probability);
		options[current++] = "-Z";
		options[current++] = String.valueOf(normalize);

		if (param.nr_weight > 0) {
			// Space-separated list, the same format setOptions parses for -W.
			StringBuffer weights = new StringBuffer();
			for (int i = 0; i < param.nr_weight; i++) {
				if (i > 0)
					weights.append(' ');
				weights.append(param.weight[i]);
			}
			options[current++] = "-W";
			options[current++] = weights.toString();
		}

		// Pad unused slots with empty strings.
		while (current < options.length) {
			options[current++] = "";
		}
		return options;
	}
	
	/**
	 * Parses a decimal string into a double (named after C's atof, but it throws
	 * on malformed input instead of returning 0).
	 *
	 * @param s the text to parse
	 * @return the parsed value
	 */
	protected static double atof(String s) {
		return Double.parseDouble(s);
	}
	
	/**
	 * Parses a decimal string into an int (named after C's atoi, but it throws
	 * NumberFormatException on malformed input instead of returning 0).
	 *
	 * @param s the text to parse
	 * @return the parsed value
	 */
	protected static int atoi(String s) {
		return Integer.valueOf(s).intValue();
	}
	
	/**
	 * Converts a Weka Instance into one line of the sparse text format accepted
	 * by LIBSVM: {@code label index:value index:value ...} followed by a newline.
	 * <p>
	 * Attribute {@code a} is written with index {@code a + 1}. The class
	 * attribute, missing values and zero values are omitted; every other
	 * attribute (including the last one) is written.
	 * <p>
	 * For a nominal class the label is the class index, except that class 0 is
	 * written as -1. For a numeric class the class value is written unchanged.
	 *
	 * @param instance the instance to convert
	 * @return the sparse line, terminated by "\n"
	 */
	protected String InstanceToSparse(Instance instance) {
		StringBuffer line = new StringBuffer();
		if (instance.classAttribute().isNumeric()) {
			// Regression targets must be kept as real values, not truncated or remapped.
			line.append(instance.classValue());
		} else {
			int c = (int) instance.classValue();
			if (c == 0)
				c = -1;
			line.append(c);
		}
		line.append(' ');

		for (int a = 0; a < instance.numAttributes(); a++) {
			if (a == instance.classIndex() || instance.isMissing(a))
				continue;
			double value = instance.value(a);
			if (value != 0)
				line.append(' ').append(a + 1).append(':').append(value);
		}
		line.append('\n');
		return line.toString();
	}
	
	/**
	 * Converts every instance of a Weka dataset into a LIBSVM sparse line.
	 *
	 * @param data the dataset to convert
	 * @return a Vector of Strings produced by InstanceToSparse, one per instance,
	 *         in dataset order
	 */
	protected Vector DataToSparse(Instances data) {
		int total = data.numInstances();
		Vector lines = new Vector(total + 1);
		int idx = 0;
		while (idx < total) {
			lines.add(InstanceToSparse(data.instance(idx)));
			idx++;
		}
		return lines;
	}
	
	
	/**
	 * Predicts with the trained LIBSVM model for a single instance.
	 * <p>
	 * The instance is first passed through the normalization filter (when one
	 * was built), then converted with InstanceToSparse and scored.
	 * <ul>
	 * <li>Regression models (EPSILON_SVR / NU_SVR): returns a one-element array
	 * holding the predicted value, as Weka expects for numeric classes.</li>
	 * <li>C_SVC / NU_SVC with probability estimates enabled (-B 1): returns
	 * LIBSVM's probability estimates, reordered into Weka class order.</li>
	 * <li>Otherwise: returns a vector with 1 at the predicted class and 0 elsewhere.</li>
	 * </ul>
	 * LIBSVM label -1 is mapped back to Weka class 0, matching InstanceToSparse.
	 *
	 * @param instance the instance to predict
	 * @return the class distribution, or the predicted value for regression
	 * @throws Exception if filtering the instance fails
	 */
	public double[] distributionForInstance(Instance instance) throws Exception {
		int svm_type = svm.svm_get_svm_type(model);
		int nr_class = svm.svm_get_nr_class(model);

		if (filter != null) {
			filter.input(instance);
			filter.batchFinished();
			instance = filter.output();
		}

		StringTokenizer st = new StringTokenizer(InstanceToSparse(instance), " \t\n\r\f:");
		st.nextToken(); // the target label is not needed for prediction
		int m = st.countTokens() / 2;
		svm_node[] x = new svm_node[m];
		for (int j = 0; j < m; j++) {
			x[j] = new svm_node();
			x[j].index = atoi(st.nextToken());
			x[j].value = atof(st.nextToken());
		}

		if (svm_type == svm_parameter.EPSILON_SVR || svm_type == svm_parameter.NU_SVR) {
			// Weka's contract for numeric classes: a single element holding the prediction.
			return new double[] { svm.svm_predict(model, x) };
		}

		double[] weka_probs = new double[nr_class];
		if (param.probability == 1
				&& (svm_type == svm_parameter.C_SVC || svm_type == svm_parameter.NU_SVC)) {
			int[] labels = new int[nr_class];
			svm.svm_get_labels(model, labels);
			double[] prob_estimates = new double[nr_class];
			svm.svm_predict_probability(model, x, prob_estimates);

			// LIBSVM orders estimates by its own label list; map back to Weka class order.
			for (int k = 0; k < prob_estimates.length; k++) {
				int wekaClass = (labels[k] == -1) ? 0 : labels[k];
				weka_probs[wekaClass] = prob_estimates[k];
			}
		} else {
			double v = svm.svm_predict(model, x);
			int wekaClass = (v == -1) ? 0 : (int) v;
			weka_probs[wekaClass] = 1;
		}
		return weka_probs;
	}
	
	/**
	 * Builds the LIBSVM model from the given training instances.
	 * <p>
	 * When -Z 1 is set, the data is first normalized with Weka's Normalize
	 * filter, and the filter is kept for use at prediction time. Each instance is
	 * converted with DataToSparse / InstanceToSparse into an svm_problem. If gamma
	 * is 0 and at least one feature index occurs, gamma is set to 1 / (largest
	 * feature index).
	 * <p>
	 * NOTE(review): the computed gamma is written into {@code param} itself and
	 * persists into later builds and getOptions(). It is presumably read by the
	 * trained model at prediction time, so it is not restored here — confirm
	 * before changing.
	 *
	 * @param insts the training data
	 * @throws IllegalArgumentException if LIBSVM rejects the parameter settings
	 * @throws Exception if normalization or training fails
	 */
	public void buildClassifier(Instances insts) throws Exception {
		// Clear any filter from a previous build so an unnormalized model is never
		// fed normalized instances in distributionForInstance.
		filter = null;
		if (normalize == 1) {
			if (getDebug())
				System.err.println("Normalizing...");
			filter = new Normalize();
			filter.setInputFormat(insts);
			insts = Filter.useFilter(insts, filter);
		}

		if (getDebug())
			System.err.println("Converting to libsvm format...");
		Vector sparseData = DataToSparse(insts);

		if (getDebug())
			System.err.println("Tokenizing libsvm data...");
		prob = new svm_problem();
		prob.l = sparseData.size();
		prob.x = new svm_node[prob.l][];
		prob.y = new double[prob.l];
		int max_index = 0;
		for (int d = 0; d < prob.l; d++) {
			StringTokenizer st = new StringTokenizer((String) sparseData.get(d), " \t\n\r\f:");

			prob.y[d] = atof(st.nextToken());
			int m = st.countTokens() / 2;
			svm_node[] x = new svm_node[m];
			for (int j = 0; j < m; j++) {
				x[j] = new svm_node();
				x[j].index = atoi(st.nextToken());
				x[j].value = atof(st.nextToken());
			}
			// Indices are emitted in increasing order, so the last node has the largest.
			if (m > 0)
				max_index = Math.max(max_index, x[m - 1].index);
			prob.x[d] = x;
		}

		// Guard max_index == 0 (no non-zero features), which would give gamma = Infinity.
		if (param.gamma == 0 && max_index > 0)
			param.gamma = 1.0 / max_index;

		error_msg = svm.svm_check_parameter(prob, param);
		if (error_msg != null) {
			// Report to the caller instead of terminating the host JVM.
			throw new IllegalArgumentException("Error: " + error_msg);
		}

		if (getDebug())
			System.err.println("Training model");
		// Let training failures propagate; swallowing them left a stale or null model.
		model = svm.svm_train(prob, param);
	}
	
	/**
	 * Returns a short description naming this classifier and its author.
	 *
	 * @return the fixed description string
	 */
	public String toString() {
		final String description = "WLSVM Classifier By Yasser EL-Manzalawy";
		return description;
	}
	
	/**
	 * Command-line demo: evaluates WLSVM on an ARFF file with Weka's Evaluation
	 * and prints the result.
	 * <p>
	 * Evaluation options passed: -t &lt;file&gt; -x 5 -i. WLSVM options passed:
	 * -S 0 -K 2 -G 1 -C 7 -M 100 (see setOptions for their meaning).
	 *
	 * @param argv argv[0] is the path of the ARFF file
	 * @throws Exception if Weka's evaluation fails
	 */
	public static void main(String[] argv) throws Exception {
		if (argv.length < 1) {
			System.out.println("Usage: WLSVM <arff file>");
			System.exit(1);
		}
		String dataFile = argv[0];

		WLSVM lib = new WLSVM();

		String[] ops = {
			// Weka Evaluation options
			"-t", dataFile,
			"-x", "5",
			"-i",
			// WLSVM options
			"-S", "0",
			"-K", "2",
			"-G", "1",
			"-C", "7",
			// Optional: add "-B", "1" to enable probability estimates,
			// or "-W", "1.0 1.0" to set per-class weights.
			"-M", "100",
		};

		System.out.println(Evaluation.evaluateModel(lib, ops));
	}
	
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -