⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 smo.java

📁 Java下面的支持向量机源程序，可应用于各种领域
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
				{
					tar = Float.valueOf(v.lastElement().toString()).intValue();
					target.add(new Integer (tar));      

					v.remove(v.size()-1);
					n = v.size();
				}
				if (is_sparse_data && is_binary ) 
				{
					sparse_binary_vector x  =  new sparse_binary_vector();
      
					for (int i=0; i<n; i++) 
					{
						if (object2float(v.elementAt(i)) < 1 || object2float(v.elementAt(i)) > d) 
						{
     
							int line2 = n_lines +1;
							System.out.println("error: line " + line2 + ": attribute index "+ (int)object2float(v.elementAt(i))+ " out of range.\n");
							System.exit(1);
						}            
						x.id.add(new Integer((int)object2float(v.elementAt(i)) -1));			
					}
					sparse_binary_points.add(x);
				}
				else if (is_sparse_data && !is_binary) 
				{
					sparse_vector x = new sparse_vector();
    
					if (this.is_libsvm_file)
					{
						for (int i=0; i<n; i+=2) 
						{
							if (object2float(v.elementAt(i)) < 1 || object2float(v.elementAt(i)) > d) 
							{
								int line3 = n_lines +1;
								System.out.println("data file error: line " + line3 + ": attribute index " + (int)object2float(v.elementAt(i)) + " out of range.\n");
								System.exit(1);
							}
							int id = (int)object2float(v.elementAt(i)) -1;
							float value = (float)object2float(v.elementAt(i+1));      
							x.id.add(new Integer(id));
							x.val.add(new Float(value));      
						}
						sparse_points.add(x);
					}
        
					else
					{
						for (int i=0; i<n; i+=2) 
						{
							if (object2float(v.elementAt(i)) < 1 || object2float(v.elementAt(i)) > d) 
							{
								int line3 = n_lines +1;
								System.out.println("data file error: line " + line3 + ": attribute index " + (int)object2float(v.elementAt(i)) + " out of range.\n");
								System.exit(1);
							}
							int id = (int)object2float(v.elementAt(i)) -1;
							float value = (float)object2float(v.elementAt(i+1));      
							x.id.add(new Integer(id));
							x.val.add(new Float(value));      
						}
						sparse_points.add(x);
					}
				}
				else 
				{
					if (v.size() != d) 
					{
          
						int line4 = n_lines +1;            
						System.out.println("Data file error: line "+line4+ " has "+ v.size() +" attributes; should be d=" + d);
						System.exit(1);
					}
					for ( int i=0; i<d; i++)
					{
						dense_points[N][i] = object2float(v.elementAt(i));	
					}		
					N= N+1;
				}
			}
		}

		catch(Exception e)
		{
			e.printStackTrace();
		}
	


		return n_lines;
  
	}

	/**
	 * Writes the model to {@code os} as plain text.
	 *
	 * Header (one value per line): d, is_sparse_data, is_binary, is_linear_kernel, b.
	 * Linear kernel: followed by the d entries of w, one per line.
	 * Otherwise: two_sigma_squared, the count of points in [0, end_support_i) whose
	 * alpha is > 0, those alphas (one per line), then one line per such support vector:
	 *   sparse binary     : "label idx idx ... "           (indices written 1-based)
	 *   sparse non-binary : "label idx value idx value ... " (indices written 1-based)
	 *   dense             : "x_1 x_2 ... x_d label"
	 *
	 * NOTE(review): sparse rows are written label-first while dense rows are written
	 * label-last; confirm each layout matches what read_data expects on reload.
	 * The stream is neither flushed nor closed here.
	 *
	 * @param os destination stream
	 */
	void write_svm(PrintStream os)
	{
		// Header shared by both kernel types.
		os.println(d);
		os.println(is_sparse_data);
		os.println(is_binary);
		os.println(is_linear_kernel);
		os.println(b);

		if (is_linear_kernel)
		{
			// The weight vector alone describes a linear decision function.
			for (int k = 0; k < d; k++)
				os.println(object2float(w.elementAt(k)));
			return;
		}

		os.println(two_sigma_squared);

		// Only points with a strictly positive multiplier are support vectors.
		int svCount = 0;
		for (int k = 0; k < end_support_i; k++)
		{
			if (object2float(alph.elementAt(k)) > 0)
				svCount++;
		}
		os.println(svCount);

		for (int k = 0; k < end_support_i; k++)
		{
			float alpha = object2float(alph.elementAt(k));
			if (alpha > 0)
				os.println(alpha);
		}

		for (int k = 0; k < end_support_i; k++)
		{
			if (!(object2float(alph.elementAt(k)) > 0))
				continue;

			if (is_sparse_data && is_binary)
			{
				os.print(object2int(target.elementAt(k)));
				os.print(" ");
				sparse_binary_vector row = (sparse_binary_vector) sparse_binary_points.elementAt(k);
				for (int j = 0; j < row.id.size(); j++)
				{
					// Stored indices are 0-based; the file format is 1-based.
					os.print(object2int(row.id.elementAt(j)) + 1);
					os.print(" ");
				}
			}
			else if (is_sparse_data)
			{
				os.print(object2int(target.elementAt(k)));
				os.print(" ");
				sparse_vector row = (sparse_vector) sparse_points.elementAt(k);
				for (int j = 0; j < row.id.size(); j++)
				{
					int oneBasedId = object2int(row.id.elementAt(j)) + 1;
					float component = object2float(row.val.elementAt(j));
					os.print(oneBasedId + " " + component + " ");
				}
			}
			else
			{
				for (int j = 0; j < d; j++)
				{
					os.print(dense_points[k][j]);
					os.print(" ");
				}
				os.print(object2int(target.elementAt(k)));
			}

			os.print("\n");
		}
	}

	/**
	 * Loads a model in the layout produced by write_svm.
	 *
	 * Reads the header (d, is_sparse_data, is_binary, is_linear_kernel, b).
	 * For a linear kernel it then reads d weights into w and returns 0.
	 * Otherwise it reads two_sigma_squared, the support-vector count and that many
	 * alphas into alph, then hands the rest of the stream to read_data and returns
	 * its result.
	 *
	 * A failure in the header or in the alpha section means the remaining lines can
	 * no longer be interpreted. In that case the stack trace is printed and 0 is
	 * returned, instead of continuing on a misaligned stream.
	 *
	 * @param is model stream positioned at the first header line
	 * @return read_data's result for a non-linear model; 0 for a linear model or on
	 *         a read failure
	 */
	int read_svm(DataInputStream is) 
	{
		// ---- Header, in the order write_svm emits it ----
		try
		{
			// readLine() returns null at EOF; trim() then throws an NPE, handled below.
			d = Integer.parseInt(is.readLine().trim());
			is_sparse_data = Boolean.valueOf(is.readLine().trim()).booleanValue();
			is_binary = Boolean.valueOf(is.readLine().trim()).booleanValue();
			is_linear_kernel = Boolean.valueOf(is.readLine().trim()).booleanValue();
			b = Float.parseFloat(is.readLine().trim());
		}
		catch (Exception e)
		{
			e.printStackTrace();
			return 0;
		}

		if (is_linear_kernel)
		{
			resize(w, d, 2);
			for (int i = 0; i < d; i++)
			{
				String line;
				try
				{
					line = is.readLine();
				}
				catch (Exception e)
				{
					// The stream itself failed; every further read would fail the same way.
					e.printStackTrace();
					break;
				}
				if (line == null)
				{
					System.out.println("model file error: expected " + d + " weights, found only " + i);
					break;
				}
				try
				{
					w.set(i, Float.valueOf(Float.parseFloat(line.trim())));
				}
				catch (NumberFormatException e)
				{
					// One malformed weight keeps its current value. The following lines
					// still line up with the remaining weights, so keep reading.
					e.printStackTrace();
				}
			}
			return 0;
		}

		// ---- Non-linear kernel: kernel width and support-vector multipliers ----
		try
		{
			two_sigma_squared = Float.parseFloat(is.readLine().trim());
			int n_support_vectors = Integer.parseInt(is.readLine().trim());

			resize(alph, n_support_vectors, 2);
			for (int i = 0; i < n_support_vectors; i++)
				alph.set(i, Float.valueOf(Float.parseFloat(is.readLine().trim())));
		}
		catch (Exception e)
		{
			// Missing or malformed alphas: read_data would misread the leftover alpha
			// lines as support vectors, so stop here.
			e.printStackTrace();
			return 0;
		}

		// The support vectors themselves use the ordinary data-file format.
		return read_data(is);
	}

	/**
	 * Fraction of points in [first_test_i, N) whose predicted sign disagrees with
	 * their label.
	 *
	 * A point counts as predicted-positive when learned_func(idx, learned_func_flag)
	 * is > 0. It counts as labelled-positive when its target is > 0.
	 * An empty range yields NaN (0f / 0f).
	 *
	 * @return misclassification rate over the test range
	 */
	float error_rate()
	{
		int misclassified = 0;
		int examined = 0;
		for (int idx = first_test_i; idx < N; idx++, examined++)
		{
			boolean predictedPositive = learned_func(idx, learned_func_flag) > 0;
			boolean labelledPositive = object2int(target.elementAt(idx)) > 0;
			if (predictedPositive != labelledPositive)
				misclassified++;
		}
		return (float) misclassified / (float) examined;
	}


	/**
	 * Dot product of points i and j, using the representation selected by flag.
	 *
	 * @param flag 1 = sparse binary, 2 = sparse non-binary, 3 = dense
	 * @return the dot product, or 0 for any other flag value
	 */
	float dot_product_func(int i,int j,int flag)
	{
		switch (flag)
		{
			case 1:
				return dot_product_sparse_binary(i, j);
			case 2:
				return dot_product_sparse_nonbinary(i, j);
			case 3:
				return dot_product_dense(i, j);
			default:
				return 0;
		}
	}
	
	/**
	 * Evaluates the current decision function on point i.
	 *
	 * @param flag 1 = linear/sparse binary, 2 = linear/sparse non-binary,
	 *             3 = linear/dense, 4 = non-linear kernel
	 * @return the decision value, or 0 for any other flag value
	 */
	float learned_func(int i, int flag)
	{
		switch (flag)
		{
			case 1:
				return learned_func_linear_sparse_binary(i);
			case 2:
				return learned_func_linear_sparse_nonbinary(i);
			case 3:
				return learned_func_linear_dense(i);
			case 4:
				return learned_func_nonlinear(i);
			default:
				return 0;
		}
	}


	/**
	 * Kernel value K(i, j).
	 *
	 * @param flag 1 = plain dot product (dispatched by this.dot_product_flag),
	 *             2 = rbf_kernel
	 * @return the kernel value, or 0 for any other flag value
	 */
	float kernel_func(int i, int j, int flag)
	{
		switch (flag)
		{
			case 1:
				return dot_product_func(i, j, this.dot_product_flag);
			case 2:
				return rbf_kernel(i, j);
			default:
				return 0;
		}
	}

	/**
	 * Makes {@code v} contain exactly {@code newSize} elements.
	 *
	 * Shrinking truncates the tail. Growing keeps every existing element as it is and
	 * appends zeros, boxed according to {@code type}: 1 = Integer, 2 = Float.
	 *
	 * @param v       vector to adjust in place
	 * @param newSize target size
	 * @param type    element type used for padding (1 = Integer, 2 = Float)
	 * @throws IllegalArgumentException if padding is needed and type is neither 1
	 *         nor 2. Previously the vector was silently left shorter than newSize.
	 */
	void resize(Vector v, int newSize, int type)
	{
		int original = v.size();
		if (original >= newSize)
		{
			if (original > newSize)
				v.setSize(newSize);
			return;
		}

		Object zero;
		if (type == 1)
			zero = Integer.valueOf(0);
		else if (type == 2)
			zero = Float.valueOf(0f);
		else
			throw new IllegalArgumentException("resize: unknown element type " + type
				+ " (expected 1 = Integer or 2 = Float)");

		// Boxed zeros are immutable, so one shared instance can pad every slot.
		v.ensureCapacity(newSize);
		for (int i = original; i < newSize; i++)
			v.add(zero);
	}
	/**
	 * Inserts {@code size} boxed zeros at positions 0..size-1 of {@code v}
	 * (1 = Integer, 2 = Float; any other type inserts nothing).
	 *
	 * NOTE(review): unlike C++ vector::reserve, this adds real elements at the FRONT.
	 * Each add(pos, ...) shifts existing elements right. Anything already in v, or
	 * later appended with add(), ends up at index >= size. Verify callers expect this.
	 */
	void reserve (Vector v, int size, int type)
	{
		for (int pos = 0; pos < size; pos++)
		{
			switch (type)
			{
				case 1:
					v.add(pos, new Integer(0));
					break;
				case 2:
					v.add(pos, new Float(0));
					break;
				default:
					break;
			}
		}
	}
	
	/**
	 * Inserts {@code size} fresh, empty sparse_vector instances at positions
	 * 0..size-1 of {@code v}, shifting existing elements right.
	 */
	void reserveSparse(Vector v, int size)
	{
		int slot = 0;
		while (slot < size)
		{
			v.add(slot, new sparse_vector());
			slot++;
		}
	}

	/**
	 * Inserts {@code size} fresh, empty sparse_binary_vector instances at positions
	 * 0..size-1 of {@code v}, shifting existing elements right.
	 */
	void reserveSparseBinary(Vector v, int size)
	{
		int slot = 0;
		while (slot < size)
		{
			v.add(slot, new sparse_binary_vector());
			slot++;
		}
	}
	
	/**
	 * Zeroes the first {@code d} columns of rows 0..size-1 of {@code array}.
	 * The rows must already be allocated; nothing is allocated here.
	 */
	void reserve (float[][] array, int size)
	{
		int row = 0;
		while (row < size)
		{
			for (int col = 0; col < d; col++)
				array[row][col] = 0;
			row++;
		}
	}

	/**
	 * Command-line entry point.
	 *
	 * Parses the options and loads the data file. With -a it first loads a saved
	 * model from the -m file; without -a it trains with SMO and saves the model there.
	 * It then prints the error rate over [first_test_i, N) and writes one
	 * learned_func value per test point to the -o file.
	 *
	 * Options ("n:d:c:t:e:p:f:m:o:r:lsbai"):
	 *   -f data file, -m model file, -o output file,
	 *   -n N, -d d, -c C, -t tolerance, -e epsilon, -p two_sigma_squared,
	 *   -r ARG (accepted, only prints "Random"),
	 *   -l linear kernel, -s sparse data, -b binary (also sets sparse),
	 *   -a test only, -i libsvm-format data file.
	 *
	 * Every file stream opened here is closed before the method returns.
	 */
	public static void main(String[] args)
	{
		long time, newTime;
		time = System.currentTimeMillis();
		try
		{		
			String data_file_name = "java-svm.data";
			String svm_file_name = "java-svm.model";
			String output_file_name = "java-svm.output";
			Smo my = new Smo();
			int numChanged =0;
			int examineAll =0;

			// ---- Command-line options ----
			{
				GetOpt go =  new GetOpt(args,"n:d:c:t:e:p:f:m:o:r:lsbai");
				go.optErr= true;
				int ch = -1;
				int errflg = 0;
				while ((ch = go.getopt()) != go.optEOF) 
					switch (ch)
					{
						case 'n':
							my.N = go.processArg(go.optArgGet(),my.N);
							break;
						case 'd':
							my.d = go.processArg(go.optArgGet(),my.d);
							break;
						case 'c':
							my.C = go.processArg(go.optArgGet(),my.C);
							break;
						case 't':
							my.tolerance = go.processArg(go.optArgGet(),my.tolerance);
							break;
						case 'e':
							my.eps = go.processArg(go.optArgGet(),my.eps);
							break;
						case 'p':
							my.two_sigma_squared = go.processArg(go.optArgGet(),my.two_sigma_squared);
							break;
						case 'f':
							data_file_name = go.optArgGet();
							break;
						case 'm':
							svm_file_name = go.optArgGet();
							break;
						case 'o':
							output_file_name = go.optArgGet();
							break;
						case 'r':
							System.out.println("Random");
							break;
						case 'l':
							my.is_linear_kernel = true;
							break;
						case 's':
							my.is_sparse_data = true;
							break;
						case 'b':
							my.is_binary = true;
							my.is_sparse_data =true;
							break;
						case 'a':
							my.is_test_only = true;
							break;
						case 'i':
							my.is_libsvm_file = true;
							break;  
						case '?':
							errflg++;
					}
				if (errflg >0 )
				{
					// Java's args[] holds no program name (unlike C's argv[0]); args[0] is the
					// user's first option, so name the launcher explicitly instead.
					System.out.println("usage: java Smo " +
						"\n-f  data_file_name\n" +
						"-m  svm_file_name\n"  +
						"-o  output_file_name\n" +
						"-n  N\n" +
						"-d  d\n" +
						"-c  C\n" +
						"-t  tolerance\n" +
						"-e  epsilon\n" +
						"-p  two_sigma_squared\n" +
						"-l  (is_linear_kernel)\n"+
						"-s  (is_sparse_data)\n" +
						"-b  (is_binary)\n" +
						"-a  (is_test_only)\n" );
					System.exit(2);
				}
			}

			// ---- Load model (test-only mode) and data ----
			{
				int n =0;
				if (my.is_test_only) 
				{
					DataInputStream svm_file = null;
					try
					{
						svm_file = new DataInputStream(new FileInputStream(svm_file_name));
						// The support vectors occupy indices [0, n); test points follow them.
						my.end_support_i = my.first_test_i = n = my.read_svm(svm_file);
					}
					catch (Exception e)
					{
						e.printStackTrace();
					}
					finally
					{
						close_quietly(svm_file);
					}
				}

				// NOTE(review): reserve*/reserveSparse* insert N placeholder entries at the
				// FRONT of these vectors, while the loader appends with add(). With -n > 0
				// the loaded labels/points therefore land after the placeholders.
				// Confirm this is intended.
				if (my.N > 0) 
				{   
					my.reserve(my.target,my.N,1);      
					if (my.is_sparse_data && my.is_binary)
						my.reserveSparseBinary(my.sparse_binary_points,my.N);
					else if (my.is_sparse_data && !my.is_binary)
						my.reserveSparse(my.sparse_points,my.N);
					else
						my.reserve(my.dense_points,my.N);
				}

				System.out.println(data_file_name);

				// Opening failures propagate to the outer catch, as before.
				DataInputStream data_file = new DataInputStream(new FileInputStream(data_file_name));
				try
				{
					n = my.read_data(data_file);
				}
				finally
				{
					close_quietly(data_file);
				}

				if (my.is_test_only) 
				{
					my.N = my.first_test_i + n;
				}
				else 
				{
					my.N = n;
					my.first_test_i = 0;
					my.end_support_i = my.N;
				}
			}

			if (!my.is_test_only) 
			{
				my.resize(my.alph,my.end_support_i,2);
				my.b = 0;
				my.resize(my.error_cache,my.N,2);
				if (my.is_linear_kernel)
					my.resize(my.w,my.d,2);
			}

			// ---- Select the decision-function / dot-product / kernel implementations ----
			if (my.is_linear_kernel && my.is_sparse_data && my.is_binary)
				my.learned_func_flag = 1; 
			if (my.is_linear_kernel && my.is_sparse_data && !my.is_binary)
				my.learned_func_flag = 2; 
			if (my.is_linear_kernel && !my.is_sparse_data)
				my.learned_func_flag = 3;           
			if (!my.is_linear_kernel)
				my.learned_func_flag = 4;    

			if (my.is_sparse_data && my.is_binary)
				my.dot_product_flag = 1;
			if (my.is_sparse_data && !my.is_binary)
				my.dot_product_flag = 2;          
			if (!my.is_sparse_data)
				my.dot_product_flag = 3;

			if (my.is_linear_kernel)
				my.kernel_flag = 1;
			if (!my.is_linear_kernel)
				my.kernel_flag = 2;

			// ---- Non-linear kernel: precompute all pairwise dot products ----
			if (!my.is_linear_kernel) 
			{
				my.resize(my.precomputed_self_dot_product,my.N,2);
				for (int i=0; i<my.N; i++)
					for (int j=0; j<my.N; j++)
					{
						if (i != j)
							my.precomputed_dot_product[i][j] = my.dot_product_func(i,j,my.dot_product_flag);
						else
						{
							float temp = my.dot_product_func(i,i,my.dot_product_flag);
							my.precomputed_self_dot_product.set(i,new Float(temp));
							my.precomputed_dot_product[i][i] =temp;
						}
					}
			}

			// ---- SMO outer loop: alternate full sweeps with sweeps over non-bound alphas ----
			if (!my.is_test_only) 
			{
				numChanged = 0;
				examineAll = 1;
				while (numChanged > 0 || examineAll >0) 
				{
					numChanged = 0;
					if (examineAll>0) 
					{ 
						for (int k = 0; k < my.N; k++)
							numChanged += my.examineExample (k);
					}
					else 
					{ 
						for (int k = 0; k < my.N; k++)
						{
							if (my.object2float(my.alph.elementAt(k)) != 0 && my.object2float(my.alph.elementAt(k)) != my.C)
								numChanged += my.examineExample (k);
						}
					}
					if (examineAll == 1)
						examineAll = 0;
					else if (numChanged == 0)
						examineAll = 1;                    
					{
						int non_bound_support =0;
						int bound_support =0;
						for (int i=0; i<my.N; i++)
							if (my.object2float(my.alph.elementAt(i)) > 0) 
							{
								if (my.object2float(my.alph.elementAt(i)) < my.C)
									non_bound_support++;
								else
									bound_support++;
							}
						System.out.println("non_bound= " +non_bound_support+"\t"+"bound_support= "+bound_support);
					}
				}

				if (!my.is_test_only && svm_file_name != null) 
				{
					PrintStream svm_file = null;
					try
					{
						svm_file = new PrintStream(new FileOutputStream(svm_file_name));
						my.write_svm(svm_file);
					}
					catch(Exception e)
					{
						e.printStackTrace();
					}
					finally
					{
						close_quietly(svm_file);
					}
				}

				System.out.println("Threshold=" + my.b);
			}
			System.out.println("Error_rate="+my.error_rate());
			newTime = System.currentTimeMillis();

			// ---- One decision value per test point ----
			PrintStream output_file = null;
			try
			{
				output_file = new PrintStream(new FileOutputStream(output_file_name));
				for (int i=my.first_test_i; i<my.N; i++)
					output_file.println(my.learned_func(i,my.learned_func_flag));
			}
			catch(Exception e)
			{
				e.printStackTrace();
			}
			finally
			{
				close_quietly(output_file);
			}
			System.out.println("Time cost = "+(newTime - time)*1.0/1000);
		}  
		catch(Exception e)
		{
			e.printStackTrace();
		}
	}

	/**
	 * Closes {@code c} if it is non-null and ignores any failure. Used for best-effort
	 * cleanup of the file streams opened by main.
	 */
	private static void close_quietly(java.io.Closeable c)
	{
		if (c == null)
			return;
		try
		{
			c.close();
		}
		catch (Exception ignored)
		{
			// Nothing useful can be done if closing fails during cleanup.
		}
	}

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -