📄 kmeans.java
字号:
import java.util.Vector;import java.awt.*;/************************************************************* * This class implements the K-means algorithm */public class Kmeans{ final int PtSize = 10; int nk; // number of the clusters Database db; // object containing the database values Vector means; // vector containing the means Vector clusters; // vector to store the members of individual clusters // each member will be a vector of points // with the index telling the cluster number Color [] cl = { Color.gray, Color.blue, Color.orange, Color.pink, Color.cyan}; // color for printout double error; // stores the current error value boolean firststep; // bool to store if the first step in the algorithm PointGenerator pgen; // references to a random point generator /** Creates new Kmeans algoirthmic object * sets the number of clusters to 1 * @param DB a database storing the points * @param p a random point generator */ public Kmeans(Database DB,PointGenerator p) { db = DB; means = new Vector(); error = 0.0; firststep = true; pgen = p; clusters = new Vector(); setK(1); // by default the number of clusters is 1 } /** reset the object - randomize means and start the process again with the * previously stores k-value */ public void reset() { setK(nk); } /** The method is used to set the number of clusters * It initializes the related objects as well * @param kval the number of clusters */ public void setK(int kval) { nk = kval; randomizeMeans(); // reset state firststep = true; error = 0.0; // clear the clusters clusters.removeAllElements(); for (int i=0; i < nk; i++) { clusters.add(new Vector()); } } /** This method is used to create random means */ void randomizeMeans() { means.removeAllElements(); for (int i = 0; i < nk; i++) { means.add(pgen.RandomPoint()); } } /** This method performs a single step of the k-means algorithm */ public void kmeansStep() { double minDist, dist; int clusterindex; // reset the cluster clusters.removeAllElements(); for (int i=0; i < nk;i++) clusters.add(new Vector()); for (int i=0;i<db.nPoints();i++) { // get the point Point currp = db.getPoint(i); // find min distance minDist = currp.EuclideanDistance((Point) means.elementAt(0)); clusterindex = 0; for (int j=1; j < nk; j++) { dist = currp.EuclideanDistance((Point) means.elementAt(j)); if (dist < minDist) { minDist = dist; clusterindex = j; } } //System.out.println("minDist = "+minDist+" cluster = "+clusterindex); // categorize to a cluster Vector clust = (Vector) clusters.elementAt(clusterindex); clust.add(currp); } // calculate new means means.removeAllElements(); for (int i = 0 ; i < nk; i++) { Vector clust = (Vector) clusters.elementAt(i); if (clust.size() != 0) means.add(Point.mean(clust)); else means.add(pgen.RandomPoint()); // reinitialize mean if no cluster members } // calculate the new Error dist = 0.0; for (int i = 0; i < nk ; i++ ) { // get the cluster points Vector clust = (Vector) clusters.elementAt(i); int size = clust.size(); Point mean = (Point) means.elementAt(i); // add the distance to the means for (int j=0; j < size ; j++) { Point p = (Point) clust.elementAt(j); dist += p.EuclideanDistance(mean); } } // dist is the new error error = dist; firststep = false; //System.out.println("out"); } /** This method is used to print the current clusters * @param g The graphics object to draw to */ public void paint(Graphics g) { // paint the means g.clearRect(0, 0, 600, 300); if (firststep) db.paint(g); else { for(int i = 0; i < nk; i++) { g.setColor(cl[i]); Point p = (Point) means.elementAt(i); //System.out.println("mean = "+p); g.fillRect( (int) (p.xVal()*600)-PtSize/2,(int) (p.yVal()*300)-PtSize/2,PtSize,PtSize); // print the cluster Vector cluster = (Vector) clusters.elementAt(i); int size = cluster.size(); for (int j = 0; j < size ; j++ ) { p = (Point) cluster.elementAt(j); //g.setColor(Color.green); //g.fillRect( (int) (p.xVal()*600)-PtSize/4,(int) (p.yVal()*300)-PtSize/4,PtSize/2,PtSize/2); //g.setColor(cl[i]); g.fillRect( (int) (p.xVal()*600)-PtSize/4,(int) (p.yVal()*300)-PtSize/4,PtSize/2,PtSize/2); //g.drawString(i+"",(int)p.xVal(), (int)p.yVal()); } } } }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -