📄 algorithmnn.java
字号:
/**
* AlgorithmNN.java v6.0 03/15/2005
*
* Author: Phil Trasatti. Last edited by: Sanjay. Created: 7/15/03
*
* Nearest Neighbor algorithm. Determines the line of
* discrimination between data sets based on the points that are equal
* in distance from the nearest point of each data set.
*/
//----------------------
// import java packages
//----------------------
import java.util.*;
import java.awt.*;
/**
 * Nearest Neighbor classifier for up to four input data sets.
 *
 * Step 1 displays the input data, step 2 computes and displays the class
 * means, and step 3 labels every cell of the output grid with the class of
 * its nearest input point, draws the cells where that label changes (the
 * decision boundary), and reports per-class and overall classification error.
 */
public class AlgorithmNN extends Algorithm
{
    //-----------------------------------------------------------------
    //
    // primitive data members
    //
    //-----------------------------------------------------------------

    // class label (0-3) of the nearest input point for every grid cell,
    // indexed [column][row]; written by computeDecisionRegions() and read
    // by computeErrors()
    int[][] output_canvas_d;

    //-----------------------------------------------------------------
    //
    // instance data members
    //
    //-----------------------------------------------------------------

    String algo_id = "AlgorithmNN";

    // displayed in step 2; nothing in this class adds points to it
    Vector<MyPoint> support_vectors_d;

    // grid points at which the nearest class changes (decision boundary)
    Vector<MyPoint> decision_regions_d;

    //---------------------------------------------------------------
    //
    // class methods
    //
    //---------------------------------------------------------------

    /**
     * Implements the initialize() method in the base class. Resets all
     * member data, builds the step descriptions and prepares for the
     * execution of the first step.
     *
     * @return true
     */
    public boolean initialize()
    {
        point_means_d = new Vector<MyPoint>();
        decision_regions_d = new Vector<MyPoint>();
        support_vectors_d = new Vector<MyPoint>();
        description_d = new Vector<String>();
        step_count = 3;

        // process description for the NN algorithm, one entry per step;
        // description_d was just recreated, so it is always empty here
        description_d.addElement(" 0. Initialize the original data.");
        description_d.addElement(" 1. Displaying the original data.");
        description_d.addElement(" 2. Computing the means for each class.");
        description_d.addElement(" 3. Computing the decision regions based on the Nearest Neighbor algorithm.");

        pro_box_d.appendMessage("Nearest Neighbor Analysis:" + "\n");

        // the working sets alias the input vectors (they are not copies)
        set1_d = data_points_d.dset1;
        set2_d = data_points_d.dset2;
        set3_d = data_points_d.dset3;
        set4_d = data_points_d.dset4;

        step_index_d = 0;
        pro_box_d.appendMessage((String)description_d.get(step_index_d));
        return true;
    }

    /**
     * Displays the data sets from the input box in the output box.
     *
     * @return true
     */
    boolean step1()
    {
        pro_box_d.setProgressMin(0);
        pro_box_d.setProgressMax(1);
        pro_box_d.setProgressCurr(0);

        scaleToFitData();

        // display the original data
        output_panel_d.addOutput(set1_d, Classify.PTYPE_INPUT,
                                 data_points_d.color_dset1);
        output_panel_d.addOutput(set2_d, Classify.PTYPE_INPUT,
                                 data_points_d.color_dset2);
        output_panel_d.addOutput(set3_d, Classify.PTYPE_INPUT,
                                 data_points_d.color_dset3);
        output_panel_d.addOutput(set4_d, Classify.PTYPE_INPUT,
                                 data_points_d.color_dset4);

        // step 1 completed
        pro_box_d.setProgressCurr(1);
        output_panel_d.repaint();
        return true;
    }

    /**
     * Computes the mean of each data set and displays the means, followed
     * by the (empty) support-vector set.
     *
     * @return true
     */
    boolean step2()
    {
        // compute the class means (base-class helper)
        computeMeans();

        // display means
        output_panel_d.addOutput(point_means_d, Classify.PTYPE_OUTPUT_LARGE,
                                 Color.black);

        // display support vectors
        output_panel_d.addOutput(support_vectors_d, Classify.PTYPE_INPUT,
                                 Color.black);

        // NOTE(review): the progress maximum was last set to 1 in step1();
        // 20 exceeds it -- confirm the intended value
        pro_box_d.setProgressCurr(20);
        output_panel_d.repaint();
        return true;
    }

    /**
     * Computes the decision regions, reports the classification errors and
     * displays the decision boundary.
     *
     * @return true
     */
    boolean step3()
    {
        // compute the decision regions
        computeDecisionRegions();

        // compute and report errors
        computeErrors();

        // display the decision boundary
        output_panel_d.addOutput(decision_regions_d, Classify.PTYPE_INPUT,
                                 new Color(255, 200, 0));
        output_panel_d.repaint();
        return true;
    }

    /**
     * Implements the run function from the Runnable interface.
     * Determines what the current step is and calls the appropriate method.
     */
    public void run()
    {
        // independent checks, as in the original: the control helpers are
        // base-class code and are not assumed to leave step_index_d alone
        if (step_index_d == 1)
        {
            disableControl();
            step1();
            enableControl();
        }
        if (step_index_d == 2)
        {
            disableControl();
            step2();
            enableControl();
        }
        if (step_index_d == 3)
        {
            disableControl();
            step3();
            enableControl();
            pro_box_d.appendMessage(" Algorithm Complete");
        }
    }

    /**
     * Labels every cell of the output grid with the class of its nearest
     * input point (Euclidean distance via MathUtil.distance) and collects
     * the cells whose label differs from the cell below or to the left.
     */
    public void computeDecisionRegions()
    {
        DisplayScale scale = output_panel_d.disp_area_d.getDisplayScale();

        // grid resolution
        int outputWidth = output_panel_d.disp_area_d.getXPrecision();
        int outputHeight = output_panel_d.disp_area_d.getYPrecision();
        double incrementX = (scale.xmax - scale.xmin) / outputWidth;
        double incrementY = (scale.ymax - scale.ymin) / outputHeight;

        output_canvas_d = new int[outputWidth][outputHeight];
        Vector<?>[] sets = classSets();

        pro_box_d.setProgressMin(0);
        pro_box_d.setProgressMax(outputWidth);
        pro_box_d.setProgressCurr(0);

        // carried across cells: if no set contains a point, the previous
        // label is kept (initially class 0)
        int associated = 0;
        for (int i = 0; i < outputWidth; i++)
        {
            pro_box_d.setProgressCurr(i);

            // each cell is sampled at its upper/right edge; computed
            // directly rather than accumulated to avoid rounding drift
            double currentX = scale.xmin + (i + 1) * incrementX;
            for (int j = 0; j < outputHeight; j++)
            {
                double currentY = scale.ymin + (j + 1) * incrementY;
                associated = nearestClass(sets, currentX, currentY, associated);
                output_canvas_d[i][j] = associated;

                // a transition point differs from either already-labelled
                // neighbour; each neighbour is checked on its own so that
                // transitions along the first row and column are kept
                boolean changedVertically =
                    j > 0 && associated != output_canvas_d[i][j - 1];
                boolean changedHorizontally =
                    i > 0 && associated != output_canvas_d[i - 1][j];
                if (changedVertically || changedHorizontally)
                {
                    decision_regions_d.add(new MyPoint(currentX, currentY));
                }
            }
        }
        pro_box_d.setProgressCurr(outputWidth);
    }

    /**
     * Computes and displays the classification error of each non-empty set
     * and the overall error, by looking up each input point's grid cell.
     * Points on or outside the display bounds are counted as samples but
     * cannot be looked up, so they are never counted as misclassified.
     */
    public void computeErrors()
    {
        DisplayScale scale = output_panel_d.disp_area_d.getDisplayScale();

        // must match the grid built in computeDecisionRegions()
        int outputWidth = output_panel_d.disp_area_d.getXPrecision();
        int outputHeight = output_panel_d.disp_area_d.getYPrecision();
        double incrementX = (scale.xmax - scale.xmin) / outputWidth;
        double incrementY = (scale.ymax - scale.ymin) / outputHeight;

        Vector<?>[] sets = classSets();
        int samples = 0;
        int incorrect = 0;
        for (int label = 0; label < sets.length; label++)
        {
            int classSamples = sets[label].size();
            int classIncorrect = countMisclassified(sets[label], label, scale,
                                                    incrementX, incrementY);
            if (classSamples > 0)
            {
                pro_box_d.appendMessage(
                    formatResults(" Results for class " + label + ":\n",
                                  classSamples, classIncorrect));
            }
            samples += classSamples;
            incorrect += classIncorrect;
        }

        // skip the overall summary when there is no data (avoids "NaN%")
        if (samples > 0)
        {
            pro_box_d.appendMessage(
                formatResults(" Overall results:\n", samples, incorrect));
        }
    }

    /**
     * Returns the four working data sets, indexed by class label.
     */
    private Vector<?>[] classSets()
    {
        return new Vector<?>[] { set1_d, set2_d, set3_d, set4_d };
    }

    /**
     * Returns the label (index into sets) of the set containing the point
     * nearest to (x, y). On equal distances the earlier set wins. Returns
     * fallback when no point is strictly closer than Double.MAX_VALUE.
     */
    private static int nearestClass(Vector<?>[] sets, double x, double y,
                                    int fallback)
    {
        int nearest = fallback;
        double smallestSoFar = Double.MAX_VALUE;
        for (int label = 0; label < sets.length; label++)
        {
            Vector<?> set = sets[label];
            for (int k = 0; k < set.size(); k++)
            {
                MyPoint point = (MyPoint)set.elementAt(k);
                double dist = MathUtil.distance(x, y, point.x, point.y);
                if (dist < smallestSoFar)
                {
                    nearest = label;
                    smallestSoFar = dist;
                }
            }
        }
        return nearest;
    }

    /**
     * Counts the points of set that lie strictly inside the display area
     * and whose grid cell is labelled differently from label.
     */
    private int countMisclassified(Vector<?> set, int label,
                                   DisplayScale scale,
                                   double incrementX, double incrementY)
    {
        int incorrect = 0;
        for (int k = 0; k < set.size(); k++)
        {
            MyPoint point = (MyPoint)set.elementAt(k);
            if (point.x > scale.xmin && point.x < scale.xmax
                && point.y > scale.ymin && point.y < scale.ymax)
            {
                // clamp: for points just inside xmax/ymax the quotient can
                // round up to the grid size and index past the canvas
                int col = Math.min((int)((point.x - scale.xmin) / incrementX),
                                   output_canvas_d.length - 1);
                int row = Math.min((int)((point.y - scale.ymin) / incrementY),
                                   output_canvas_d[col].length - 1);
                if (output_canvas_d[col][row] != label)
                {
                    incorrect++;
                }
            }
        }
        return incorrect;
    }

    /**
     * Formats one block of classification results; samples must be > 0.
     */
    private static String formatResults(String title, int samples,
                                        int incorrect)
    {
        double error = ((double)incorrect / (double)samples) * 100.0;
        return title
            + " Total number of samples: "
            + samples
            + "\n"
            + " Misclassified samples: "
            + incorrect
            + "\n"
            + " Classification error: "
            + MathUtil.setDecimal(error, 2)
            + "%";
    }
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -