⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 algorithmlbg.java

📁 包含了模式识别中常用的一些分类器设计算法
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
//--------------------------------------------------------------------------
// AlgorithmLBG.java 6.0 03/15/2005
// Created       : Phil Trasatti     Edited : Daniel May
//                              Last Edited : Sanjay Patil
//
// Description   : Implements the LBG clustering algorithm
// Remarks       : Code unchanged since created. Created 07/15/2003
//--------------------------------------------------------------------------

//----------------------
// import java packages
//----------------------
import java.util.*;
import java.awt.*;

/**
 * implements the LBG algorithm
 */
public class AlgorithmLBG extends Algorithm
{

    int output_canvas_d[][];    
    int iterations;
    int currObject;

    String algo_id = "AlgorithmLBG"; 
    Vector<MyPoint> support_vectors_d;
    Vector<MyPoint> data_pool_d;
    Vector<DecisionRegion> cbinary_d;
    DecisionRegion decision_regions_d;

    /**
     * Overrides the initialize() method in the base class. Resets all
     * member data so the algorithm can be run again from the start,
     * loads the step descriptions shown in the process box, binds the
     * working data sets to the four input data sets and prints the
     * description of step 0.
     *
     * @return   Returns true
     */
    public boolean initialize()
    {
        // Debug
        //
        // System.out.println(algo_id + ": initialize()");

        scale = output_panel_d.disp_area_d.getDisplayScale();
        data_pool_d = new Vector<MyPoint>();
        decision_regions_d = new DecisionRegion();
        point_means_d = new Vector<MyPoint>();
        support_vectors_d = new Vector<MyPoint>();
        iterations = Classify.main_menu_d.iterations;
        description_d = new Vector<String>();
        step_count = 3;

        cbinary_d = new Vector<DecisionRegion>();
        currObject = 0;

        // add the process description for the LBG algorithm; description_d
        // was freshly allocated above, so it is always empty at this point
        // and no emptiness check is needed
        //
        description_d.addElement("   0. Initialize the original data.");
        description_d.addElement("   1. Displaying the original data.");
        description_d.addElement("   2. Computing the means for each data set.");
        description_d.addElement("   3. Stepping through iterations ");
        description_d.addElement("Iteration");

        // append message to process box
        //
        pro_box_d.appendMessage("LBG Analysis:" + "\n");

        // use the input data sets directly: they are shared by reference,
        // not cloned, so this algorithm works on the same Vector objects
        //
        set1_d = data_points_d.dset1;
        set2_d = data_points_d.dset2;
        set3_d = data_points_d.dset3;
        set4_d = data_points_d.dset4;

        // start at step 0 (initialization); run() handles steps 1-3
        //
        step_index_d = 0;

        // show the description of the current step, then scroll the process
        // box down (the large value presumably clamps to the end)
        //
        pro_box_d.appendMessage((String)description_d.get(step_index_d));
        pro_box_d.scrollPane.getVerticalScrollBar().setValue(1000000);

        // exit gracefully
        //
        return true;
    }

    /**
     * Step 1: shows the original input data. Each of the four input data
     * sets is added to the output panel in its own color, and the progress
     * bar goes from 0 to 1 over the course of the step.
     *
     * @return Returns true
     */
    boolean step1()
    {
        // Debug
        //
        // System.out.println(algo_id + " : step1()");

        // this step is a single unit of work: progress runs 0 -> 1
        //
        pro_box_d.setProgressMin(0);
        pro_box_d.setProgressMax(1);
        pro_box_d.setProgressCurr(0);

        scaleToFitData();

        // plot the four input sets, each with its own color
        //
        output_panel_d.addOutput(set1_d, Classify.PTYPE_INPUT, data_points_d.color_dset1);
        output_panel_d.addOutput(set2_d, Classify.PTYPE_INPUT, data_points_d.color_dset2);
        output_panel_d.addOutput(set3_d, Classify.PTYPE_INPUT, data_points_d.color_dset3);
        output_panel_d.addOutput(set4_d, Classify.PTYPE_INPUT, data_points_d.color_dset4);

        // mark the step finished and redraw the panel
        //
        pro_box_d.setProgressCurr(1);
        output_panel_d.repaint();

        return true;
    }

    /**
     * Step 2: pools every input point into a single initial region, calls
     * computeMeans(), and seeds the algorithm with one initial guess located
     * at the mean of the pooled data. The guess is drawn in the output panel
     * as a large black point.
     *
     * @return   Returns true
     */
    boolean step2()
    {
        // Debug
        //
        // System.out.println(algo_id + " : step2()");

        // gather all input points and register them as one region; the
        // pool Vector itself is handed over (shared), not cloned
        //
        generatePool();
        decision_regions_d.addRegion(data_pool_d);

        // the single starting guess is the mean of the whole pool
        //
        MyPoint mean = MathUtil.computeClusterMean(data_pool_d);

        // presumably fills point_means_d with the per-set means, which
        // getClosestSet() reads later -- TODO confirm
        //
        computeMeans();

        Vector<MyPoint> guesses = new Vector<MyPoint>(2, 2);
        guesses.addElement(new MyPoint(mean.x, mean.y));
        decision_regions_d.setGuesses(guesses);

        // draw the initial guess
        //
        output_panel_d.addOutput(decision_regions_d.getGuesses(),
                                 Classify.PTYPE_OUTPUT_LARGE,
                                 Color.black);

        // NOTE(review): step1() set the progress maximum to 1; confirm how
        // the progress box treats a current value of 20
        //
        pro_box_d.setProgressCurr(20);
        output_panel_d.repaint();

        return true;
    }
    
    /**
     * Step 3: performs one LBG iteration per call while fewer than
     * {@code iterations} iterations have run. The previous regions are
     * passed to computeBinaryDeviates() to obtain new guesses, the data is
     * reclassified against those guesses, and for every resulting cluster
     * its mean, covariance matrix and classification error are written to
     * the process box, followed by the overall results. Until the last
     * iteration, step_index_d is decremented so that (assuming the step
     * driver increments it) step 3 runs again on the next request. The
     * decision regions are redrawn on every call.
     *
     * @return Returns true
     */
    boolean step3()
    {
        // Debug
        //
        // System.out.println(algo_id + " : step3()");

        if (currObject < iterations)
        {
            pro_box_d.appendMessage("      Iteration "
                                    + (currObject + 1) + "\n");

            // stay on step 3 until the final iteration has been performed
            //
            if ((currObject + 1) != iterations)
                step_index_d--;

            currObject++;

            Vector<DataSet> decisionRegions = decision_regions_d.getRegions();
            decision_regions_d = new DecisionRegion();

            // derive the new guesses from the previous regions
            //
            computeBinaryDeviates(decisionRegions);
            Vector<MyPoint> guesses = decision_regions_d.getGuesses();

            // classify the data based on the guesses
            //
            classify(guesses);

            int numclusters = decision_regions_d.getNumRegions();

            // accumulated over all clusters: misclassified and total samples
            //
            double error = 0.0;
            double total = 0.0;

            for (int i = 0; i < numclusters; i++)
            {
                Vector cluster = decision_regions_d.getRegion(i);

                // report the mean of the current cluster
                //
                MyPoint mean = MathUtil.computeClusterMean(cluster);
                pro_box_d.appendMessage("      Mean for cluster " + i + ": "
                                        + MathUtil.setDecimal(mean.x, 4)
                                        + ", "
                                        + MathUtil.setDecimal(mean.y, 4)
                                        + "\n");

                // split the cluster's points into coordinate arrays for the
                // covariance computation
                //
                double[] x = new double[cluster.size()];
                double[] y = new double[cluster.size()];
                for (int j = 0; j < cluster.size(); j++)
                {
                    MyPoint p = (MyPoint) cluster.elementAt(j);
                    x[j] = p.x;
                    y[j] = p.y;
                }

                // covariance matrix of the current cluster (2x2)
                //
                double[][] covar = new Covariance().computeCovariance(x, y);

                double c11 = MathUtil.setDecimal(covar[0][0], 2);
                double c12 = MathUtil.setDecimal(covar[0][1], 2);
                double c21 = MathUtil.setDecimal(covar[1][0], 2);
                double c22 = MathUtil.setDecimal(covar[1][1], 2);

                pro_box_d.appendMessage("      Covariance matrix:\n"
                                        + "         " + c11
                                        + "    "      + c12 + "\n"
                                        + "         " + c21
                                        + "    "      + c22 + "\n");

                // attribute the cluster to the nearest input data set and
                // count its misclassified samples
                //
                int closest = getClosestSet(mean);
                total += (double) cluster.size();
                error += (double) displayClusterError(closest, cluster, i);
            }

            // guard against an empty classification (no clusters or no
            // samples), where error / total would be 0/0 = NaN
            //
            double err = (total > 0.0) ? (error / total) * 100.0 : 0.0;

            // display the classification error
            //
            pro_box_d.appendMessage("       Overall results:\n"
                                    + "          Total number of samples: "
                                    + total + "\n"
                                    + "          Misclassified samples: "
                                    + error + "\n"
                                    + "          Classification error: "
                                    + MathUtil.setDecimal(err, 2) + "%"
                                    + "\n");
        }

        outputDecisionRegion();

        return true;
    }

    /**
     * Implementation of the run function from the Runnable interface.
     * Dispatches on step_index_d: steps 1, 2 and 3 call step1(), step2()
     * and step3() respectively, with the controls disabled while the step
     * executes. Any other step index is ignored.
     */
    public void run()
    {
        // Debug
        //
        // System.out.println(algo_id + " : run()");

        switch (step_index_d)
        {
            case 1:
                disableControl();
                step1();
                enableControl();
                break;

            case 2:
                disableControl();
                step2();
                enableControl();
                break;

            case 3:
                disableControl();
                step3();
                enableControl();
                break;

            default:
                // nothing to do for other step indices
                break;
        }
    }
    
    /**
     * Appends every point of the four input data sets, in set order, to
     * data_pool_d. The pool is built only once: if it already holds points,
     * the method returns without changing it.
     */
    public void generatePool() 
    {
        // already built on an earlier call -- keep the existing pool
        //
        if (!data_pool_d.isEmpty())
            return;

        Vector[] sources = { set1_d, set2_d, set3_d, set4_d };

        for (Vector source : sources)
        {
            for (int idx = 0; idx < source.size(); idx++)
            {
                data_pool_d.addElement((MyPoint) source.elementAt(idx));
            }
        }
    }

    /**
     * Determines which of the four input data sets has its mean closest to
     * the given cluster mean.
     *
     * The set means are read from point_means_d, consuming one entry per
     * non-empty data set in set order (empty sets are skipped, exactly as
     * before), so point_means_d is assumed to hold means only for non-empty
     * sets -- TODO confirm against computeMeans(). Empty sets are never
     * chosen. When several sets are equally close, the lowest-numbered one
     * is returned (previously any tie fell through to set 4, even when set
     * 4 was empty). If every set is empty, 4 is returned as before.
     *
     * @param  mean mean point of the cluster
     *
     * @return number (1-4) of the data set closest to the cluster
     */
    public int getClosestSet(MyPoint mean) 
    {
        Vector[] sets = { set1_d, set2_d, set3_d, set4_d };

        // default kept from the original fall-through when no set qualifies
        //
        int closest = sets.length;
        double bestDist = Double.POSITIVE_INFINITY;

        // index into point_means_d; advances only for non-empty sets
        //
        int meanIndex = 0;

        for (int k = 0; k < sets.length; k++)
        {
            if (sets[k].size() == 0)
                continue;

            MyPoint setMean = (MyPoint) point_means_d.elementAt(meanIndex);
            meanIndex++;

            double dist = MathUtil.distance(mean.x, mean.y,
                                            setMean.x, setMean.y);

            // strict comparison: on ties the earlier set is kept
            //
            if (dist < bestDist)
            {
                bestDist = dist;
                closest = k + 1;
            }
        }

        return closest;
    }

    /**
     * Finds the datapoints in error, for all datasets
     *
     * @param closest Variable can be int values 1-4.  Marks which
     * set of data is closest
     *
     * @param cluster Stores the points of a cluster
     * @param id ID number
     *
     * @return Returns the error number of the misclassified samples
     */
    public int displayClusterError(int closest, Vector cluster, int id) 
    {

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -