📄 algorithmkmeans.java,v
字号:
head 1.11;access;symbols;locks; strict;comment @# @;1.11date 2005.06.10.16.47.10; author rirwin; state Exp;branches;next 1.10;1.10date 2005.06.10.16.33.25; author rirwin; state Exp;branches;next 1.9;1.9date 2005.06.10.15.55.32; author rirwin; state Exp;branches;next 1.8;1.8date 2005.05.23.16.41.07; author rirwin; state Exp;branches;next 1.7;1.7date 2005.04.11.21.54.37; author patil; state Exp;branches;next 1.6;1.6date 2005.03.16.18.21.34; author patil; state Exp;branches;next 1.5;1.5date 2005.03.16.18.17.59; author patil; state Exp;branches;next 1.4;1.4date 2005.03.15.01.18.08; author patil; state Exp;branches;next 1.3;1.3date 2005.03.08.00.59.27; author patil; state Exp;branches;next 1.2;1.2date 2005.01.19.22.03.04; author patil; state Exp;branches;next 1.1;1.1date 2004.12.28.00.04.32; author patil; state Exp;branches;next ;desc@No changes made@1.11log@autoscroll on init, works not great though.@text@//--------------------------------------------------------------------------
// AlgorithmKMeans.java 6.0 03/15/2005
// Created : Phil Trasatti
// Edited : Sanjay Patil
// Last Edited : Ryan Irwin
//
// Description : Implements the K-means clustering algorithm
// Remarks : Code unchanged since created. Created 07/15/2003
//--------------------------------------------------------------------------
//----------------------
// import java packages
//----------------------
import java.util.*;
import java.awt.*;
/**
* Implements the K-means clustering algorithm: random initial means,
* then repeated nearest-mean classification and mean re-estimation.
*/
public class AlgorithmKMeans extends Algorithm
{
//-------------------------------------------------------------------
//
// Data Members (package-private except the RAND_SEED constant)
//
//-------------------------------------------------------------------
/**
* Number of k-means iterations completed so far. step3() compares it
* against iterations and increments it; initialize() resets it to 0.
*/
int currObject;
/**
* Current cluster centres ("guesses"). Filled with random points by
* generateMeans() and cleared at the start of computeMeans(DecisionRegion),
* which presumably refills it with the re-estimated means.
*/
Vector<MyPoint> guesses;
/**
* All points of the four input data sets, collected by generatePool()
* (only once per initialize(), which allocates a fresh vector).
*/
Vector<MyPoint> data_pool_d;
/**
* Clusters produced by the most recent classification; step3() replaces
* it with a new DecisionRegion on every iteration.
*/
DecisionRegion decision_regions_d;
/**
* Seed for the random number generator. A fixed seed makes the sequence
* of initial guesses deterministic for each new instance (note that
* initialize() does not reseed the generator).
*/
public final static long RAND_SEED = 20022000;
/**
* Random number generator, seeded with RAND_SEED, used by generateMeans()
* to draw the initial guesses.
*/
Random random = new Random(RAND_SEED);
/**
* ID used for message appending. Initialized to "AlgorithmKMeans".
*/
String algo_id = "AlgorithmKMeans";
/**
* Number of k-means iterations to run. Defaults to 10; initialize()
* overwrites it with Classify.main_menu_d.iterations.
*/
int iterations = 10;
/**
* Number of clusters (initial guesses). Defaults to 4; initialize()
* overwrites it with Classify.main_menu_d.clusters.
*/
int numguesses = 4;
//-------------------------------------------------------------------
//
// Public methods
//
//-------------------------------------------------------------------
/**
 * Overrides the initialize() method in the base class. Initializes
 * member data and prepares for execution of the first step. This method
 * "resets" the algorithm: it allocates fresh working storage, reads the
 * iteration and cluster counts from Classify.main_menu_d, loads the step
 * descriptions, binds the four input data sets and prints the
 * description of step 0.
 *
 * @@return Returns false
 */
public boolean initialize()
{
    // Debug
    //
    // System.out.println(algo_id + ": initialize()");

    // fresh working storage for this run
    //
    scale = output_panel_d.disp_area_d.getDisplayScale();
    data_pool_d = new Vector<MyPoint>();
    decision_regions_d = new DecisionRegion();
    point_means_d = new Vector<MyPoint>();

    // user-selected iteration and cluster counts
    //
    iterations = Classify.main_menu_d.iterations;
    numguesses = Classify.main_menu_d.clusters;

    // add the process description for the k-means algorithm; the vector
    // is newly allocated here, so no emptiness check is needed
    //
    description_d = new Vector<String>();
    description_d.addElement(" 0. Initialize the original data.");
    description_d.addElement(" 1. Displaying the original data.");
    description_d.addElement(" 2. Computing the means for each data set.");
    description_d.addElement(" 3. Stepping through iterations ");
    description_d.addElement("Iteration");

    step_count = 3;
    currObject = 0;

    // append message to process box
    //
    pro_box_d.appendMessage("KMeans Analysis:" + "\n");

    // the input data sets are shared with data_points_d, not copied
    //
    set1_d = data_points_d.dset1;
    set2_d = data_points_d.dset2;
    set3_d = data_points_d.dset3;
    set4_d = data_points_d.dset4;

    // start at step 0
    //
    step_index_d = 0;

    // append the step description and scroll the process box
    // NOTE(review): 1000000 is just "far past the end" to reach the bottom;
    // the revision log notes that this autoscroll does not work well
    //
    pro_box_d.appendMessage((String)description_d.get(step_index_d));
    pro_box_d.scrollPane.getVerticalScrollBar().setValue(1000000);

    // exit gracefully
    //
    return false;
}
/**
 * Step 1: shows the four input data sets in the output panel. The
 * display is rescaled to fit the data, each set is plotted as input
 * points in its own color, and the progress bar runs from 0 to 1.
 *
 * @@return Returns true
 */
boolean step1()
{
    // Debug
    //
    // System.out.println(algo_id + ": step1()");

    // this step is one unit of work: progress goes 0 -> 1
    //
    pro_box_d.setProgressMin(0);
    pro_box_d.setProgressMax(1);
    pro_box_d.setProgressCurr(0);

    scaleToFitData();

    // plot every input set in its own color
    //
    output_panel_d.addOutput(set1_d, Classify.PTYPE_INPUT, data_points_d.color_dset1);
    output_panel_d.addOutput(set2_d, Classify.PTYPE_INPUT, data_points_d.color_dset2);
    output_panel_d.addOutput(set3_d, Classify.PTYPE_INPUT, data_points_d.color_dset3);
    output_panel_d.addOutput(set4_d, Classify.PTYPE_INPUT, data_points_d.color_dset4);

    // mark the step finished and refresh the canvas
    //
    pro_box_d.setProgressCurr(1);
    output_panel_d.repaint();
    return true;
}
/**
 * Step 2: prepares the iterative stage. Calls the inherited
 * computeMeans() (presumably the per-data-set means), pools all input
 * points, draws numguesses random initial guesses, seeds the decision
 * regions with the original sets, and plots the guesses as large black
 * points.
 *
 * @@return Returns true
 */
boolean step2()
{
    // Debug
    //
    // System.out.println(algo_id + ": step2()");

    // statistics and the pooled data used by every iteration
    //
    computeMeans();
    generatePool();

    // initial cluster centres and the seeded regions
    //
    generateMeans(numguesses);
    initializeKmeans();

    // show the initial guesses
    //
    output_panel_d.addOutput(guesses, Classify.PTYPE_OUTPUT_LARGE, Color.black);
    output_panel_d.repaint();
    return true;
}
/**
 * Step 3: performs one k-means iteration per call and reports it.
 * While fewer than iterations iterations have run, each call reclassifies
 * the data pool against the current guesses, re-estimates the means and
 * prints, for every cluster, its mean, covariance matrix and
 * classification error, followed by overall totals. The output canvas
 * is refreshed on every call, including calls after the last iteration.
 *
 * @@return Returns true
 */
boolean step3()
{
    // Debug
    //
    // System.out.println(algo_id + ": step3()");
    if (currObject < iterations)
    {
        pro_box_d.appendMessage(" Iteration "
                                + (currObject+1) + "\n");

        // step back so the next run() lands on step 3 again, unless this
        // is the final iteration
        // NOTE(review): assumes the caller advances step_index_d after each
        // step -- confirm against the Algorithm base class
        //
        if ((currObject + 1) != iterations)
            step_index_d--;
        currObject++;

        // classify the data based on the guesses, then compute the new
        // means of the classified data
        //
        decision_regions_d = new DecisionRegion();
        classify(decision_regions_d);
        computeMeans(decision_regions_d);

        // report every cluster and accumulate the classification error
        //
        int numclusters = decision_regions_d.getNumRegions();
        double error = 0.0;
        double total = 0.0;
        for (int i = 0; i < numclusters; i++)
        {
            Vector cluster = decision_regions_d.getRegion(i);
            MyPoint mean = MathUtil.computeClusterMean(cluster);
            appendClusterMean(i, mean);
            appendClusterCovariance(cluster);

            // compare against the closest of the original data sets
            //
            int closest = getClosestSet(mean);
            total += (double)cluster.size();
            error += (double)displayClusterError(closest, cluster, i);
        }

        // with no samples at all, error/total would be 0/0 = NaN
        //
        double err = (total > 0.0) ? (error / total) * 100.0 : 0.0;

        // display the classification error
        //
        String message = " Overall results:\n" +
                         " Total number of samples: "
                         + total + "\n" +
                         " Misclassified samples: "
                         + error + "\n" +
                         " Classification error: " +
                         MathUtil.setDecimal(err, 2) + "%";
        pro_box_d.appendMessage(message + "\n");
    }

    // update the output canvas
    //
    outputDecisionRegion();
    return true;
}

/**
 * Appends the mean of cluster number i, rounded to two decimals, to the
 * process box.
 */
private void appendClusterMean(int i, MyPoint mean)
{
    double xval = MathUtil.setDecimal(mean.x, 2);
    double yval = MathUtil.setDecimal(mean.y, 2);
    pro_box_d.appendMessage(" Mean for cluster " + i + ": "
                            + xval + ", " + yval + "\n");
}

/**
 * Appends the 2x2 covariance matrix of the cluster's x/y coordinates,
 * rounded to two decimals, to the process box.
 */
private void appendClusterCovariance(Vector cluster)
{
    int n = cluster.size();
    double x[] = new double[n];
    double y[] = new double[n];
    for (int j = 0; j < n; j++)
    {
        MyPoint covarPoint = (MyPoint)cluster.elementAt(j);
        x[j] = covarPoint.x;
        y[j] = covarPoint.y;
    }

    double[][] covar = new Covariance().computeCovariance(x, y);

    double c11 = MathUtil.setDecimal(covar[0][0], 2);
    double c12 = MathUtil.setDecimal(covar[0][1], 2);
    double c21 = MathUtil.setDecimal(covar[1][0], 2);
    double c22 = MathUtil.setDecimal(covar[1][1], 2);
    pro_box_d.appendMessage(" Covariance matrix:\n" +
                            " " + c11
                            + " " + c12 + "\n" +
                            " " + c21
                            + " " + c22 + "\n");
}
/**
 * Implementation of the run function from the Runnable interface.
 * Dispatches to step1(), step2() or step3() according to step_index_d,
 * keeping the controls disabled while the step executes. Any other step
 * index does nothing.
 */
public void run()
{
    // Debug
    //
    // System.out.println(algo_id + ": run()");
    int step = step_index_d;
    if (step < 1 || step > 3)
    {
        return;
    }

    disableControl();
    try
    {
        switch (step)
        {
        case 1:
            step1();
            break;
        case 2:
            step2();
            break;
        case 3:
            step3();
            break;
        }
    }
    finally
    {
        // re-enable the controls even if a step throws; otherwise the
        // user interface would stay locked
        enableControl();
    }
}
/**
 * Collects all the data points of all the data sets into data_pool_d,
 * in set order (set 1 through set 4). The pool is built only once: if
 * it already holds points, the method returns immediately.
 */
public void generatePool()
{
    // initialize() allocates a fresh pool, so this fills it once per run
    if (!data_pool_d.isEmpty())
    {
        return;
    }

    Vector[] sources = { set1_d, set2_d, set3_d, set4_d };
    for (Vector source : sources)
    {
        for (int k = 0; k < source.size(); k++)
        {
            data_pool_d.addElement((MyPoint)source.elementAt(k));
        }
    }
}
/**
 * Generates random initial guesses (means) for the data set. Each guess
 * is drawn uniformly inside the current display scale bounds and the
 * previous guesses are discarded.
 *
 * @@param numMeans number of mean points
 */
public void generateMeans(int numMeans)
{
    // initial capacity numMeans, growing by 10 elements when exceeded
    guesses = new Vector<MyPoint>(numMeans, 10);

    double width = scale.xmax - scale.xmin;
    double height = scale.ymax - scale.ymin;

    int remaining = numMeans;
    while (remaining-- > 0)
    {
        // x is drawn before y to keep the random sequence unchanged
        double gx = scale.xmin + (random.nextDouble() * width);
        double gy = scale.ymin + (random.nextDouble() * height);
        guesses.addElement(new MyPoint(gx, gy));
    }
}
/**
 * Seeds decision_regions_d with each non-empty original data set (in set
 * order) and attaches the current initial guesses.
 */
public void initializeKmeans()
{
    Vector[] originals = { set1_d, set2_d, set3_d, set4_d };
    for (Vector original : originals)
    {
        // empty data sets do not become regions
        if (!original.isEmpty())
        {
            decision_regions_d.addRegion(original);
        }
    }

    decision_regions_d.setGuesses(guesses);
}
/**
 * Classifies the data sets based on the k-means iterative algorithm.
 * Every point of the data pool is assigned to its nearest guess and
 * stored in region under the name "cluster" followed by that guess's
 * index.
 *
 * @@param region - stored data sets from the classification
 * @@see DecisionRegion
 */
public void classify(DecisionRegion region)
{
    for (int p = 0; p < data_pool_d.size(); p++)
    {
        MyPoint sample = data_pool_d.elementAt(p);
        int nearest = indexOfNearestGuess(sample);
        region.addPoint(sample, "cluster" + nearest);
    }
}

/**
 * Returns the index of the guess with the smallest MathUtil.distance to
 * sample. Ties keep the lowest index, and 0 is returned when no guess is
 * strictly closer than Double.MAX_VALUE (e.g. when there are no guesses).
 */
private int indexOfNearestGuess(MyPoint sample)
{
    int best = 0;
    double bestDist = Double.MAX_VALUE;
    for (int g = 0; g < guesses.size(); g++)
    {
        MyPoint centre = guesses.elementAt(g);
        double d = MathUtil.distance(sample.x, sample.y, centre.x, centre.y);
        if (d < bestDist)
        {
            best = g;
            bestDist = d;
        }
    }
    return best;
}
/**
* Computes the means of the data sets after each iteraion
*
* @@param region - classified data sets
* @@see DecisionRegion
*/
public void computeMeans(DecisionRegion region)
{
// determine the number of classified data sets
//
int numsets = region.getNumRegions();
// remove all previous guesses - mean points
//
guesses.removeAllElements();
// iterate over the data sets to determine the new means
//
for (int i = 0; i < numsets; i++)
{
// retrieve the classified region
//
Vector dataset = (Vector)region.getRegion(i);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -