⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 svm.java

📁 一个java程序编写的svm支持向量机小程序
💻 JAVA
📖 第 1 页 / 共 3 页
字号:
package edu.udo.cs.mySVMdb.SVM;import edu.udo.cs.mySVMdb.Optimizer.*;import edu.udo.cs.mySVMdb.Container.*;import edu.udo.cs.mySVMdb.Kernel.*;import edu.udo.cs.mySVMdb.Util.*;// import java.util.BitSet;import java.lang.Integer;import java.lang.Double;import java.util.Random;import java.lang.Math;public abstract class SVM{    /**     * Abstract base class for all SVMs     * @author Stefan R黳ing     * @version 1.0     */        protected Kernel the_kernel;    protected JDBCDatabaseContainer the_container;    protected int examples_total;    protected int verbosity;    protected int working_set_size;    protected int parameters_working_set_size; // wss set in parameters    protected int target_count;    protected double convergence_epsilon;    protected double is_zero;    protected int shrink_const;    protected double lambda_factor;    protected int[] at_bound;    protected double[] sum;    protected boolean[] which_alpha;    protected int[] working_set;    protected double[] primal;    protected double Cpos;    protected double Cneg;    protected double sum_alpha;    protected double lambda_eq;    protected double epsilon_pos;    protected double epsilon_neg;    protected int to_shrink;    protected double feasible_epsilon;    protected double lambda_WS;    protected boolean quadraticLossPos;    protected boolean quadraticLossNeg;    protected double descend;    boolean shrinked;    MinHeap heap_min;    MaxHeap heap_max;    protected quadraticProblem qp;    /**     * class constructor     */    public SVM()    {    };            /**     * Init the SVM     * @param Kernel new kernel function.     * @param JDBCDatabaseContainer the data container     * @exception Exception on any error     */    public void init(Kernel new_kernel, JDBCDatabaseContainer new_container)    {	String dummy;	the_kernel = new_kernel;	the_container = new_container;	examples_total = the_container.count_examples();	try{	  verbosity = (new Integer(the_container.get_param("verbosity"))).intValue();	}	catch(Exception e){	  verbosity = 3;	};	try{	  working_set_size = (new Integer(the_container.get_param("working_set_size"))).intValue();	  if(working_set_size < 2){	      working_set_size = 2;	  };	}	catch(Exception e){	    working_set_size = 10; // !!! has to be identical to JDBCDatabaseContainer::prepareKisStatement	};	parameters_working_set_size = working_set_size;	try{	  is_zero = (new Double(the_container.get_param("is_zero"))).doubleValue();	  if(is_zero <= 0){	      is_zero = 1e-10;	  };	}	catch(Exception e){	  is_zero = 1e-10;	};	try{	  convergence_epsilon = (new Double(the_container.get_param("convergence_epsilon"))).doubleValue();	  if(convergence_epsilon <= 0){	      convergence_epsilon = 1e-3;	  };	}	catch(Exception e){	  convergence_epsilon = 1e-3;	};	try{	  shrink_const = (new Integer(the_container.get_param("shrink_const"))).intValue();	  if(shrink_const <= 0){	      shrink_const = 50;	  };	}	catch(Exception e){	  shrink_const = 50;	};	try{	  descend = (new Double(the_container.get_param("descend"))).doubleValue();	  if(descend < 0){	      descend = 1e-15;	  };	}	catch(Exception e){	  descend = 1e-15;	};	quadraticLossPos = false;	try{	  dummy = the_container.get_param("quadraticLossPos");	  if(dummy.equals("true")){	      quadraticLossPos = true;	  };	}	catch(Exception e){	};	quadraticLossNeg = false;	try{	  dummy = the_container.get_param("quadraticLossNeg");	  if(dummy.equals("true")){	      quadraticLossNeg = true;	  };	}	catch(Exception e){	};	try{	  Cpos = (new Double(the_container.get_param("C"))).doubleValue();	  Cneg = Cpos;	}	catch(Exception e){	    Cpos = 1.0;	    Cneg = 1.0;	};	try{	  Cpos = (new Double(the_container.get_param("Cpos"))).doubleValue();	}	catch(Exception e){	};	try{	  Cneg = (new Double(the_container.get_param("Cneg"))).doubleValue();	}	catch(Exception e){	};	// better in subclass:	try{	  epsilon_pos = (new Double(the_container.get_param("epsilon"))).doubleValue();	  epsilon_neg = epsilon_pos;	}	catch(Exception e){	    epsilon_pos = 0.0;	    epsilon_neg = 0.0;	};	try{	  epsilon_pos = (new Double(the_container.get_param("epsilon_pos"))).doubleValue();	}	catch(Exception e){	};	try{	  epsilon_neg = (new Double(the_container.get_param("epsilon_neg"))).doubleValue();	}	catch(Exception e){	};	try{	    dummy = the_container.get_param("balance_cost");	    if(dummy.equals("true")){		Cpos *= ((double)the_container.count_pos_examples())/((double)the_container.count_examples());		Cneg *= ((double)(the_container.count_examples()-the_container.count_pos_examples()))/((double)the_container.count_examples());	    };	}	catch(Exception e){	};	//	System.out.println("Cpos "+Cpos);	//	System.out.println("Cneg "+Cneg);	//	System.out.println("epsilon_pos "+epsilon_pos);	//	System.out.println("epsilon_neg "+epsilon_neg);	lambda_factor = 1.0;	lambda_eq=0;	target_count=0;	sum_alpha = 0;	feasible_epsilon = convergence_epsilon;	at_bound = new int[examples_total];	sum = new double[examples_total];	which_alpha = new boolean[examples_total];	primal = new double[working_set_size];    };            /**     * Train the SVM     * @exception Exception on any error     */    public void train()	throws Exception    {	target_count = 0;	shrinked = false;	init_optimizer();	init_working_set();	int iteration = 0;	int max_iterations;	try{	  max_iterations = (new Integer(the_container.get_param("max_iterations"))).intValue();	}	catch(Exception e){	    max_iterations=30000; 	};	boolean converged=false;	//long time_train_loop = System.currentTimeMillis();	//long time_dummy = 0;	//long time_resetshrink = 0;	M:while(iteration < max_iterations){	    iteration++;	    logln(4,"optimizer iteration "+iteration);	    log(4,".");	    optimize(); 	    put_optimizer_values();	    converged = convergence();	    if(converged){		logln(4,"");  // dots		project_to_constraint();		if(shrinked){		    // check convergence for all alphas  		    logln(2,"***** Checking convergence for all variables");		    //		    time_resetshrink -= System.currentTimeMillis();      		    reset_shrinked();		    //		    time_resetshrink += System.currentTimeMillis();		    converged = convergence();		};				if(converged){		    logln(1,"*** Convergence");		    break M;		};				// set variables free again		shrink_const += 10;		target_count = 0;		for(int i=0;i<examples_total;i++){		    at_bound[i]=0;		};	    };	    shrink();	    calculate_working_set();	    update_working_set();	};	//time_train_loop = (System.currentTimeMillis() - time_train_loop)/1000;	int i;	if((iteration >= max_iterations) && (! converged)){	  logln(1,"*** No convergence: Time up.");	  if(shrinked){	    // set sums for all variables for statistics	    //time_resetshrink -= System.currentTimeMillis();	    reset_shrinked();	    //time_resetshrink += System.currentTimeMillis();	  };	};		// calculate b	double new_b=0;	int new_b_count=0;	double[] my_sum = sum;	double[] my_y = the_container.get_ys();	double[] my_alphas = the_container.get_alphas();	for(i=0;i<examples_total;i++){	  if((my_alphas[i]-Cneg < -is_zero) && 	     (my_alphas[i] > is_zero)){	    new_b +=  my_y[i] - my_sum[i]-epsilon_neg;	    new_b_count++;	  }	  else if((my_alphas[i]+Cpos > is_zero) && 		  (my_alphas[i] < -is_zero)){	    new_b +=  my_y[i] - my_sum[i]+epsilon_pos;	    new_b_count++;	  };	};		if(new_b_count>0){	  the_container.set_b(new_b/((double)new_b_count));	}	else{	  // unlikely	  for(i=0;i<examples_total;i++){	    if((my_alphas[i]<is_zero) && 	       (my_alphas[i]>-is_zero)) {	      new_b += my_y[i] - my_sum[i];	      new_b_count++;	    };	  };	  if(new_b_count>0){	    the_container.set_b(new_b/((double)new_b_count));	  }	  else{	    // even unlikelier	    for(i=0;i<examples_total;i++){	      new_b += my_y[i] - my_sum[i];	      new_b_count++;	    };	    the_container.set_b(new_b/((double)new_b_count));	  };	};		if(verbosity>= 2){	  logln(2,"Done training: "+iteration+" iterations.");	  if(verbosity>= 3){	    double now_target=0;	    double now_target_dummy=0;	    for(i=0;i<examples_total;i++){	      now_target_dummy=sum[i]/2-the_container.get_y(i);	      if(is_alpha_neg(i)){		now_target_dummy+= epsilon_pos;	      }	      else{		now_target_dummy-= epsilon_neg;	      };	      now_target+=the_container.get_alpha(i)*now_target_dummy;	    };	    logln(3,"Target function: "+now_target);	  };	};		print_statistics();		exit_optimizer();	//	System.out.println("Time in resetshrink: "+(time_resetshrink/1000)+"s");	//	System.out.println("Time in train loop: "+time_train_loop+"s");    };        /**     * print statistics about result     */    protected void print_statistics()	throws Exception    {      int dim = the_container.get_dim();      int i,j;      double alpha;      double[] x;      int svs=0;      int bsv = 0;      double mae=0;      double mse = 0;      int countpos = 0;      int countneg = 0;      double y;      double prediction;      double min_lambda = Double.MAX_VALUE;      double b = the_container.get_b();      for(i=0;i<examples_total;i++){	  if(lambda(i) < min_lambda){	      min_lambda = lambda(i);	  };	  y = the_container.get_y(i);	  prediction = sum[i]+b;	  mae += Math.abs(y-prediction);	  mse += (y-prediction)*(y-prediction);	  alpha = the_container.get_alpha(i);	  if(y < prediction-epsilon_pos){	      countpos++;	  }	  else if(y > prediction+epsilon_neg){	      countneg++;	  };	  if(alpha != 0){	      svs++;	      if((alpha == Cpos) || (alpha == -Cneg)){		  bsv++;	      };	  };      };      mae /= (double)examples_total;      mse /= (double)examples_total;      min_lambda = -min_lambda;      logln(1,"Error on KKT is "+min_lambda);      logln(1,svs+" SVs");      logln(1,bsv+" BSVs");      logln(1,"MAE "+mae);      logln(1,"MSE "+mse);      logln(1,countpos+" pos loss");      logln(1,countneg+" neg loss");      if(verbosity >= 2){	  // print hyperplane	  double[] w = new double[dim];	  for(j=0;j<dim;j++) w[j] = 0;	  for(i=0;i<examples_total;i++){	      x = the_container.get_example(i);	      alpha = the_container.get_alpha(i);	      for(j=0;j<dim;j++){		  w[j] += alpha*x[j];	      };	  };	  double[] Exp = the_container.Exp;	  double[] Dev = the_container.Dev;	  if(Exp != null){	      for(j=0;j<dim;j++){		  if(Dev[j] != 0){		      w[j] /= Dev[j];		  };		  if(0 != Dev[dim]){		      w[j] *= Dev[dim];		  };		  b -= w[j]*Exp[j];	      };	      b += Exp[dim];	  };	  logln(2," ");	  for(j=0;j<dim;j++){	      logln(2,"w["+j+"] = "+w[j]);	  };	  logln(2,"b = "+b);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -