// algorithmlbg.java
//--------------------------------------------------------------------------
// AlgorithmLBG.java 6.0 03/15/2005
// Created : Phil Trasatti Edited : Daniel May
// Last Edited : Sanjay Patil
//
// Description : Describes the LBG algorithms
// Remarks : Code unchanged since created. Created 07/15/2003
//--------------------------------------------------------------------------
//----------------------
// import java packages
//----------------------
import java.util.*;
import java.awt.*;
/**
* implements the LBG algorithm
*/
public class AlgorithmLBG extends Algorithm
{
// NOTE(review): output canvas buffer; not referenced by initialize() through
// getClosestSet() -- confirm whether it is still used before removing
int output_canvas_d[][];
// number of LBG iterations step3() runs; read from
// Classify.main_menu_d.iterations in initialize()
int iterations;
// number of iterations step3() has completed so far; reset to 0 in initialize()
int currObject;
// identifier printed by the (commented-out) debug traces
String algo_id = "AlgorithmLBG";
// recreated empty in initialize(); NOTE(review): no other use seen in the
// methods of this class shown here
Vector<MyPoint> support_vectors_d;
// every point of set1_d..set4_d combined; built once by generatePool()
Vector<MyPoint> data_pool_d;
// recreated empty in initialize(); its only other use in step2() is commented out
Vector<DecisionRegion> cbinary_d;
// current decision regions and codebook guesses; step3() replaces it with a
// fresh DecisionRegion on every iteration
DecisionRegion decision_regions_d;
/**
 * Overrides the initialize() method in the base class. Resets all
 * per-run state, rebuilds the step descriptions, binds the four input
 * data sets and writes the first step description to the process box.
 * Calling this method "resets" the algorithm.
 *
 * @return Returns true
 */
public boolean initialize()
{
    // Debug
    //
    // System.out.println(algo_id + ": initialize()");

    // reset the per-run state
    //
    scale = output_panel_d.disp_area_d.getDisplayScale();
    data_pool_d = new Vector<MyPoint>();
    decision_regions_d = new DecisionRegion();
    point_means_d = new Vector<MyPoint>();
    support_vectors_d = new Vector<MyPoint>();
    iterations = Classify.main_menu_d.iterations;
    description_d = new Vector<String>();
    step_count = 3;
    cbinary_d = new Vector<DecisionRegion>();
    currObject = 0;

    // add the process description for the LBG algorithm; description_d
    // was recreated just above, so it is always empty at this point
    //
    description_d.addElement(" 0. Initialize the original data.");
    description_d.addElement(" 1. Displaying the original data.");
    description_d.addElement(" 2. Computing the means for each data set.");
    description_d.addElement(" 3. Stepping through iterations ");
    description_d.addElement("Iteration");

    // append message to process box
    //
    pro_box_d.appendMessage("LBG Analysis:" + "\n");

    // bind the data points for this algorithm; the input sets are shared
    // by reference, not copied
    //
    set1_d = data_points_d.dset1;
    set2_d = data_points_d.dset2;
    set3_d = data_points_d.dset3;
    set4_d = data_points_d.dset4;

    // start at step 0, whose description is shown below
    //
    step_index_d = 0;

    // append the step description and scroll; the large value presumably
    // pins the scroll bar to the bottom of the process box
    //
    pro_box_d.appendMessage((String)description_d.get(step_index_d));
    pro_box_d.scrollPane.getVerticalScrollBar().setValue(1000000);

    // exit gracefully
    //
    return true;
}
/**
 * Step 1: plots the four original input data sets in the output panel,
 * each in its own data-set colour, after rescaling the display to fit
 * the data. Uses a one-unit progress bar that completes once the plots
 * have been queued.
 *
 * @return Returns true
 */
boolean step1()
{
    // System.out.println(algo_id + " : step1()");

    // progress runs from 0 to 1 for this step
    //
    final int progressDone = 1;
    pro_box_d.setProgressMin(0);
    pro_box_d.setProgressMax(progressDone);
    pro_box_d.setProgressCurr(0);

    scaleToFitData();

    // queue each input set, in order, with its matching colour
    //
    output_panel_d.addOutput(set1_d, Classify.PTYPE_INPUT, data_points_d.color_dset1);
    output_panel_d.addOutput(set2_d, Classify.PTYPE_INPUT, data_points_d.color_dset2);
    output_panel_d.addOutput(set3_d, Classify.PTYPE_INPUT, data_points_d.color_dset3);
    output_panel_d.addOutput(set4_d, Classify.PTYPE_INPUT, data_points_d.color_dset4);

    // mark the step finished, then redraw
    //
    pro_box_d.setProgressCurr(progressDone);
    output_panel_d.repaint();
    return true;
}
/**
 * Step 2: pools all input points into a single region and seeds the
 * codebook with one guess, the mean of that pool. The guess is drawn
 * in black. computeMeans() is also invoked here; its results are read
 * later by getClosestSet() through point_means_d.
 *
 * @return Returns true
 */
boolean step2()
{
    // System.out.println(algo_id + " : step2()");

    // gather every input point and register the pool as one region
    //
    generatePool();
    decision_regions_d.addRegion(data_pool_d);

    // the initial codebook holds a single vector: the mean of the pool
    //
    MyPoint poolMean = MathUtil.computeClusterMean(data_pool_d);
    computeMeans();

    Vector<MyPoint> initialCodebook = new Vector<MyPoint>(2, 2);
    initialCodebook.addElement(new MyPoint(poolMean.x, poolMean.y));
    decision_regions_d.setGuesses(initialCodebook);

    // plot the codebook on top of the input data
    //
    output_panel_d.addOutput(decision_regions_d.getGuesses(),
                             Classify.PTYPE_OUTPUT_LARGE,
                             Color.black);

    // NOTE(review): step1() set the progress maximum to 1; confirm that
    // 20 is the intended value here
    //
    pro_box_d.setProgressCurr(20);
    output_panel_d.repaint();
    return true;
}
/**
 * Step 3: runs one LBG iteration while fewer than {@code iterations} have
 * completed. It derives new guesses from the previous decision regions
 * via computeBinaryDeviates(), then classifies the data against them. It
 * then reports each resulting cluster's mean, covariance matrix and
 * misclassification count, followed by the overall classification error.
 * The decision regions are redrawn on every call, including calls made
 * after the last iteration.
 *
 * @return Returns true
 */
boolean step3()
{
    // Debug
    //
    // System.out.println(algo_id + " : step3()");
    if (currObject < iterations)
    {
        pro_box_d.appendMessage(" Iteration "
                                + (currObject + 1) + "\n");

        // until the final iteration, step back one step index; presumably
        // the step driver advances step_index_d after each run(), so this
        // re-enters step 3 next time -- confirm against the caller
        //
        if ((currObject + 1) != iterations)
        {
            step_index_d--;
        }
        currObject++;

        // derive new guesses from the previous regions, then rebuild the
        // regions by classifying the data against those guesses
        //
        Vector<DataSet> decisionRegions = decision_regions_d.getRegions();
        decision_regions_d = new DecisionRegion();
        computeBinaryDeviates(decisionRegions);
        Vector<MyPoint> guesses = decision_regions_d.getGuesses();
        classify(guesses);

        // report every cluster and accumulate sample / error counts
        //
        int numclusters = decision_regions_d.getNumRegions();
        double error = 0.0;
        double total = 0.0;
        for (int i = 0; i < numclusters; i++)
        {
            Vector cluster = decision_regions_d.getRegion(i);

            MyPoint mean = MathUtil.computeClusterMean(cluster);
            pro_box_d.appendMessage(" Mean for cluster " +
                                    i + ": "
                                    + MathUtil.setDecimal(mean.x,4)
                                    + ", "
                                    + MathUtil.setDecimal(mean.y,4)
                                    + "\n");

            appendClusterCovariance(cluster);

            // score the cluster against the closest original data set
            //
            int closest = getClosestSet(mean);
            total += (double)cluster.size();
            error += (double)displayClusterError(closest, cluster, i);
        }

        // an empty classification would otherwise divide 0 by 0 and
        // report NaN as the error rate
        //
        double err = (total > 0.0) ? (error / total) * 100.0 : 0.0;

        // display the classification error
        //
        String message = " Overall results:\n" +
                         " Total number of samples: "
                         + total + "\n" +
                         " Misclassified samples: "
                         + error + "\n" +
                         " Classification error: "
                         + MathUtil.setDecimal(err, 2) + "%";
        pro_box_d.appendMessage(message + "\n");
    }
    outputDecisionRegion();
    return true;
}

/**
 * Computes the 2x2 covariance matrix of a cluster's points and appends
 * it, rounded to two decimals, to the process box.
 *
 * @param cluster points of the cluster; each element is cast to MyPoint
 */
private void appendClusterCovariance(Vector cluster)
{
    int n = cluster.size();
    double x[] = new double[n];
    double y[] = new double[n];
    for (int j = 0; j < n; j++)
    {
        MyPoint covarPoint = (MyPoint)cluster.elementAt(j);
        x[j] = covarPoint.x;
        y[j] = covarPoint.y;
    }

    Covariance cov = new Covariance();
    double[][] covar = cov.computeCovariance(x, y);

    double c11 = MathUtil.setDecimal(covar[0][0], 2);
    double c12 = MathUtil.setDecimal(covar[0][1], 2);
    double c21 = MathUtil.setDecimal(covar[1][0], 2);
    double c22 = MathUtil.setDecimal(covar[1][1], 2);
    String message = " Covariance matrix:\n"
                     + " " + c11
                     + " " + c12 + "\n"
                     + " " + c21
                     + " " + c22;
    pro_box_d.appendMessage(message + "\n");
}
/**
 * Implementation of the run function from the Runnable interface.
 * Dispatches to step1(), step2() or step3() according to the current
 * step index; any other index does nothing. The controls are disabled
 * for the duration of the step and are always re-enabled afterwards,
 * even if the step throws.
 */
public void run()
{
    // Debug
    //
    // System.out.println(algo_id + " : run()");

    // read the step index once, as the original if/else chain did
    //
    final int step = step_index_d;
    if (step < 1 || step > 3)
    {
        return;
    }

    disableControl();
    try
    {
        switch (step)
        {
            case 1:
                step1();
                break;
            case 2:
                step2();
                break;
            default:
                step3();
                break;
        }
    }
    finally
    {
        // without this the UI stays locked if a step throws
        //
        enableControl();
    }
}
/**
 * Collects the points of set1_d through set4_d, in that order, into
 * data_pool_d. The pool is built only once: if it already holds points,
 * this method returns without changing it.
 */
public void generatePool()
{
    // a non-empty pool has already been built; reuse it
    //
    if (!data_pool_d.isEmpty())
    {
        return;
    }

    Vector<?>[] sources = { set1_d, set2_d, set3_d, set4_d };
    for (Vector<?> source : sources)
    {
        for (Object element : source)
        {
            data_pool_d.addElement((MyPoint)element);
        }
    }
}
/**
 * Determines which original data set lies closest to a cluster.
 *
 * The distance from the cluster mean to each non-empty data set's mean
 * is compared. point_means_d is indexed as one entry per non-empty set,
 * in set order (presumably as populated by computeMeans()), so empty
 * sets consume no entry and can never be chosen. On an exact tie the
 * lowest-numbered set wins. If no set qualifies, 4 is returned, as
 * before.
 *
 * @param mean mean point of the cluster
 *
 * @return closest data set to the cluster, 1 through 4
 */
public int getClosestSet(MyPoint mean)
{
    Vector<?>[] sets = { set1_d, set2_d, set3_d, set4_d };

    // previously, a tie between the smallest distances fell through to
    // "return 4" even when set 4 was empty; a running minimum fixes that
    //
    int closest = 4;
    double bestDist = Double.POSITIVE_INFINITY;
    int meanIndex = 0;
    for (int k = 0; k < sets.length; k++)
    {
        if (sets[k].size() == 0)
        {
            continue;
        }
        MyPoint setMean = (MyPoint)point_means_d.elementAt(meanIndex);
        meanIndex++;
        double dist = MathUtil.distance(mean.x, mean.y,
                                        setMean.x, setMean.y);
        if (dist < bestDist)
        {
            bestDist = dist;
            closest = k + 1;
        }
    }
    return closest;
}
/**
* Finds the datapoints in error, for all datasets
*
* @param closest Variable can be int values 1-4. Marks which
* set of data is closest
*
* @param cluster Stores the points of a cluster
* @param id ID number
*
* @return Returns the error number of the misclassified samples
*/
public int displayClusterError(int closest, Vector cluster, int id)
{
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -