⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 algorithmnn.java

📁 包含了模式识别中常用的一些分类器设计算法
💻 JAVA
字号:
/**
* AlgorithmNN.java v6.0 03/15/2005
*
* Author Phil Trasatti, Last edited Sanjay Created: 7/15/03
* 
* Nearest Neighbor algorithm. Labels each point of the output grid with
* the class of its nearest training sample and marks the decision
* boundary wherever that label changes between neighboring grid points.
*/


//----------------------
// import java packages
//----------------------
import java.util.*;
import java.awt.*;

public class AlgorithmNN extends Algorithm
{

    //-----------------------------------------------------------------
    //
    // static data members
    //
    //-----------------------------------------------------------------

    //-----------------------------------------------------------------
    //
    // primitive data members
    //
    //-----------------------------------------------------------------

    // class label (0..3) of every cell of the output grid, indexed
    // [x][y]; allocated and filled by computeDecisionRegions() and read
    // back by computeErrors()
    //
    int output_canvas_d[][];   

    //-----------------------------------------------------------------
    //
    // instance data members
    //
    //-----------------------------------------------------------------
        
    // identifier of this algorithm; within this file it is only
    // referenced by the commented-out debug statements
    //
    String algo_id = "AlgorithmNN"; 

    // re-created empty by initialize() and handed to the output panel in
    // step2(); nothing in this file adds elements to it
    //
    Vector<MyPoint> support_vectors_d;

    // grid points at which the nearest-neighbor label changes; filled by
    // computeDecisionRegions() and drawn by step3()
    //
    Vector<MyPoint> decision_regions_d;


    //---------------------------------------------------------------
    //
    // class methods
    //
    //---------------------------------------------------------------


    /**
     * Resets the algorithm so that it can be stepped through from the
     * start: the result containers and the step descriptions are
     * re-created, the four input data sets are bound, the step index is
     * set to 0 and the description of step 0 is written to the process
     * box.
     *
     * @return   true
     *
     */
    public boolean initialize()
    {
        // start from empty result containers
        //
        point_means_d = new Vector<MyPoint>();
        decision_regions_d = new Vector<MyPoint>();
        support_vectors_d = new Vector<MyPoint>();
        description_d = new Vector<String>();
        step_count = 3;

        // one description per step; the list was re-created just above,
        // so it is always empty at this point and is filled
        // unconditionally
        //
        description_d.addElement("   0. Initialize the original data.");
        description_d.addElement("   1. Displaying the original data.");
        description_d.addElement("   2. Computing the means for each class.");
        description_d.addElement(
            "   3. Computing the decision regions based on the Nearest Neighbor algorithm.");

        // announce the algorithm in the process box
        //
        pro_box_d.appendMessage("Nearest Neighbor Analysis:\n");

        // bind the input data sets; the vectors are shared with
        // data_points_d, not copied
        //
        set1_d = data_points_d.dset1;
        set2_d = data_points_d.dset2;
        set3_d = data_points_d.dset3;
        set4_d = data_points_d.dset4;

        // rewind to step 0 and show its description
        //
        step_index_d = 0;
        pro_box_d.appendMessage((String)description_d.get(step_index_d));

        return true;
    }

    
    /**
     * Step 1: rescales the display to the data and draws the four input
     * data sets on the output panel, each in the color of its data set.
     *
     * @return   true
     *
     */
    boolean step1()
    {
        // this step is a single unit of work for the progress bar
        //
        pro_box_d.setProgressMin(0);
        pro_box_d.setProgressMax(1);
        pro_box_d.setProgressCurr(0);

        scaleToFitData();

        // hand every input set to the output panel
        //
        output_panel_d.addOutput(
            set1_d, Classify.PTYPE_INPUT, data_points_d.color_dset1);
        output_panel_d.addOutput(
            set2_d, Classify.PTYPE_INPUT, data_points_d.color_dset2);
        output_panel_d.addOutput(
            set3_d, Classify.PTYPE_INPUT, data_points_d.color_dset3);
        output_panel_d.addOutput(
            set4_d, Classify.PTYPE_INPUT, data_points_d.color_dset4);

        // mark the step as finished and redraw
        //
        pro_box_d.setProgressCurr(1);
        output_panel_d.repaint();

        return true;
    }
    
    /**
     * Step 2: computes the mean of each data set and draws the means,
     * together with the contents of support_vectors_d, on the output
     * panel.
     *
     * @return   true
     *
     */
    boolean step2()
    {
        // compute the per-class means
        //
        computeMeans();

        // draw the means as large black output points
        //
        output_panel_d.addOutput(
            point_means_d, Classify.PTYPE_OUTPUT_LARGE, Color.black);

        // draw the support vector list as black input-type points
        //
        output_panel_d.addOutput(
            support_vectors_d, Classify.PTYPE_INPUT, Color.black);

        // update the progress indicator and redraw
        //
        pro_box_d.setProgressCurr(20);
        output_panel_d.repaint();

        return true;
    }
    
    /**
     * Step 3: computes the nearest-neighbor decision regions, reports the
     * classification errors in the process box and draws the decision
     * boundary points on the output panel.
     *
     * @return   true
     *
     */
    boolean step3()
    {
        // label the output grid and collect the boundary points
        //
        computeDecisionRegions();

        // report per-class and overall classification errors
        //
        computeErrors();

        // draw the decision boundary points
        //
        Color boundaryColor = new Color(255, 200, 0);
        output_panel_d.addOutput(
            decision_regions_d, Classify.PTYPE_INPUT, boundaryColor);

        output_panel_d.repaint();

        return true;
    }

    /**
     *
     * Implements the run function from the Runnable interface.
     * Executes the step selected by step_index_d (1, 2 or 3); any other
     * index does nothing. The controls are disabled while a step runs and
     * are re-enabled afterwards even if the step throws, so a failing
     * step cannot leave the controls disabled. The completion message is
     * appended only after step 3 finishes normally.
     *
     */
    public void run()
    {
        if (step_index_d == 1)
        {
            disableControl();
            try
            {
                step1();
            }
            finally
            {
                enableControl();
            }
        }

        if (step_index_d == 2)
        {
            disableControl();
            try
            {
                step2();
            }
            finally
            {
                enableControl();
            }
        }

        if (step_index_d == 3)
        {
            disableControl();
            try
            {
                step3();
            }
            finally
            {
                enableControl();
            }

            // only reached when step 3 completed without an exception
            //
            pro_box_d.appendMessage("   Algorithm Complete");
        }
    }
    
    /**
     *
     * Labels every cell of the output grid with the class of its nearest
     * data point and collects the decision boundary points.
     *
     * The grid has getXPrecision() by getYPrecision() cells spanning the
     * current display scale. For each cell the distance to every point of
     * the four data sets is evaluated with MathUtil.distance(); the cell
     * receives the label (0..3) of the set holding the closest point,
     * with ties going to the lower label. The labels are stored in
     * output_canvas_d. A grid point whose label differs from the cell
     * below it or to its left is appended to decision_regions_d; cells in
     * the first row and first column are never tested.
     *
     */
    public void computeDecisionRegions()
    {
        DisplayScale scale = output_panel_d.disp_area_d.getDisplayScale();

        // grid resolution and the data-space size of one grid cell
        //
        int outputWidth = output_panel_d.disp_area_d.getXPrecision();
        int outputHeight = output_panel_d.disp_area_d.getYPrecision();

        double incrementX = (scale.xmax - scale.xmin) / outputWidth;
        double incrementY = (scale.ymax - scale.ymin) / outputHeight;

        // class label of every grid cell
        //
        output_canvas_d = new int[outputWidth][outputHeight];

        // the position of a set in this array is its class label
        //
        Vector<?>[] classSets = { set1_d, set2_d, set3_d, set4_d };

        // one progress tick per grid column
        //
        pro_box_d.setProgressMin(0);
        pro_box_d.setProgressMax(outputWidth);
        pro_box_d.setProgressCurr(0);

        // kept across cells on purpose: when no point is closer than
        // Double.MAX_VALUE (e.g. all sets are empty) the previous label,
        // initially 0, is carried over
        //
        int associated = 0;

        double currentX = scale.xmin;

        for (int i = 0; i < outputWidth; i++)
        {
            // the coordinate is advanced before it is used, so column i
            // is sampled one increment past its lower edge
            //
            currentX += incrementX;
            double currentY = scale.ymin;
            pro_box_d.setProgressCurr(i);

            for (int j = 0; j < outputHeight; j++)
            {
                currentY += incrementY;
                MyPoint pixel = new MyPoint(currentX, currentY);

                // find the closest data point over all classes; the
                // strict comparison keeps the lower label on ties
                //
                double smallestSoFar = Double.MAX_VALUE;

                for (int label = 0; label < classSets.length; label++)
                {
                    Vector<?> members = classSets[label];

                    for (int k = 0; k < members.size(); k++)
                    {
                        MyPoint point = (MyPoint)members.elementAt(k);
                        double dist = MathUtil.distance(pixel.x, pixel.y,
                                                        point.x, point.y);
                        if (dist < smallestSoFar)
                        {
                            associated = label;
                            smallestSoFar = dist;
                        }
                    }
                }

                output_canvas_d[i][j] = associated;

                // a transition point: the label differs from the cell
                // below or the cell to the left
                //
                if (j > 0 && i > 0)
                {
                    if (associated != output_canvas_d[i][j - 1]
                        || associated != output_canvas_d[i - 1][j])
                    {
                        decision_regions_d.add(pixel);
                    }
                }
            }
        }

        // all columns processed: let the progress bar reach its maximum
        //
        pro_box_d.setProgressCurr(outputWidth);
    }

    /**
     *
     * Computes the classification error of each data set and of all sets
     * together and appends the results to the process box.
     *
     * A point of class c is counted as misclassified when the grid cell
     * of output_canvas_d that contains it carries a label other than c.
     * Points lying on or outside the border of the display scale are
     * counted as samples but are not checked against the grid. A result
     * block is written for every non-empty set, followed by the overall
     * result. Expects computeDecisionRegions() to have filled
     * output_canvas_d.
     *
     */
    public void computeErrors()
    {
        DisplayScale scale = output_panel_d.disp_area_d.getDisplayScale();

        // data-space size of one grid cell
        //
        int outputWidth = output_panel_d.disp_area_d.getXPrecision();
        int outputHeight = output_panel_d.disp_area_d.getYPrecision();

        double incrementX = (scale.xmax - scale.xmin) / outputWidth;
        double incrementY = (scale.ymax - scale.ymin) / outputHeight;

        // the position of a set in this array is its class label
        //
        Vector<?>[] classSets = { set1_d, set2_d, set3_d, set4_d };

        int samples = 0;
        int incorrect = 0;

        for (int label = 0; label < classSets.length; label++)
        {
            Vector<?> members = classSets[label];
            int classSamples = members.size();
            int classIncorrect = countMisclassified(members, label, scale,
                                                    incrementX, incrementY);

            // empty classes are not reported
            //
            if (classSamples > 0)
            {
                pro_box_d.appendMessage(
                    formatErrorReport("Results for class " + label + ":",
                                      classSamples, classIncorrect));
            }

            samples += classSamples;
            incorrect += classIncorrect;
        }

        // overall classification error
        //
        pro_box_d.appendMessage(
            formatErrorReport("Overall results:", samples, incorrect));
    }

    /**
     * Counts the points of one data set whose grid cell in
     * output_canvas_d carries a label other than the expected one.
     *
     * @param members     points of the data set (MyPoint elements)
     * @param label       class label expected for these points
     * @param scale       display scale the grid spans
     * @param incrementX  width of one grid cell in data units
     * @param incrementY  height of one grid cell in data units
     *
     * @return   number of misclassified points
     */
    private int countMisclassified(Vector<?> members, int label,
                                   DisplayScale scale,
                                   double incrementX, double incrementY)
    {
        int incorrect = 0;

        for (int i = 0; i < members.size(); i++)
        {
            MyPoint point = (MyPoint)members.elementAt(i);

            // only points strictly inside the display area are checked
            //
            boolean insideX = point.x > scale.xmin && point.x < scale.xmax;
            boolean insideY = point.y > scale.ymin && point.y < scale.ymax;
            if (!insideX || !insideY)
            {
                continue;
            }

            // floating point rounding can push the quotient of a point
            // just inside the upper border up to the grid size, so the
            // index is clamped to the last cell
            //
            int col = Math.min((int)((point.x - scale.xmin) / incrementX),
                               output_canvas_d.length - 1);
            int row = Math.min((int)((point.y - scale.ymin) / incrementY),
                               output_canvas_d[col].length - 1);

            if (output_canvas_d[col][row] != label)
            {
                incorrect++;
            }
        }

        return incorrect;
    }

    /**
     * Builds the text block reporting sample count, misclassified count
     * and error percentage under the given heading.
     *
     * @param heading    first line of the block, without indentation
     * @param samples    total number of samples
     * @param incorrect  number of misclassified samples
     *
     * @return   the formatted result text
     */
    private String formatErrorReport(String heading, int samples,
                                     int incorrect)
    {
        // with no samples the ratio would be 0/0; report 0 instead
        //
        double error = 0.0;
        if (samples > 0)
        {
            error = ((double)incorrect / (double)samples) * 100.0;
        }

        return "       " + heading + "\n"
            + "          Total number of samples: "
            + samples
            + "\n"
            + "          Misclassified samples: "
            + incorrect
            + "\n"
            + "          Classification error: "
            + MathUtil.setDecimal(error, 2)
            + "%";
    }
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -