/*
 * kmeans.java - k-means clustering of the numeric data in iris_data.txt.
 *
 * Author: Chalres
 */
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Random;
/**
 * K-means clustering of a comma-separated numeric data set (by default {@code iris_data.txt}).
 *
 * <p>{@link #init()} reads {@code M} rows of {@code N} features and min-max normalizes every
 * feature. It then starts from a random assignment of rows to {@code k} clusters and repeats two
 * steps: assign every row to its nearest centroid (Euclidean distance), then recompute each
 * centroid as the mean of its rows. Finally it prints the raw data, the cluster symbol of every
 * row and the final centroids.
 *
 * <p>Instances hold mutable state and are not synchronized.
 */
public class kmeans {
    /** Number of data rows (instances) read from the input file. */
    private static final int M = 150;
    /** Number of clusters. */
    private static final int k = 3;
    /** Number of numeric features (dimensions) per row. */
    private static final int N = 4;
    /** Input file read by {@link #init()}. */
    private static final String str = "iris_data.txt";

    /** Rows whose cluster number differs from the reference class in {@link #class_data}. */
    int err = 0;
    /** Number of assign/update iterations performed by the last {@link #calculate()} call. */
    private int Z = 0;
    /**
     * Data matrix of {@code M} rows. Columns {@code 0..N-1} hold the features; column {@code N}
     * holds the row's current cluster number (stored as a float).
     */
    private float[][] arr;
    /** Symbol printed for each row's cluster: "*" for 0, "-" for 1, "+" for 2, "?" otherwise. */
    public String[] target_class;
    /** Current centroids: {@code k} rows of {@code N} features. */
    float[][] cent;
    /** Centroids as they were before the most recent {@link #calulatecent()} call. */
    float[][] ocent;
    /** Loop flag for {@link #calculate()}: true while at least one centroid is still moving. */
    boolean r = true;
    /** Sum of nearest-centroid distances of the previous ({@code P1}) and latest ({@code P2}) pass. */
    double P1, P2 = 0;
    /** Reference class of every row (see {@link #assignReferenceClasses()}). */
    float[] class_data;

    /**
     * Min-max normalizes every feature column of the data matrix to [0, 1], in place.
     *
     * <p>A column whose values are all equal is left unchanged, which avoids dividing by zero.
     * The cluster column is not touched.
     */
    void Normalize() {
        float[] max = new float[N];
        float[] min = new float[N];
        for (int j = 0; j < N; j++) {
            max[j] = min[j] = arr[0][j];
        }
        for (int i = 0; i < M; i++) {
            for (int j = 0; j < N; j++) {
                if (max[j] < arr[i][j]) {
                    max[j] = arr[i][j];
                }
                if (min[j] > arr[i][j]) {
                    min[j] = arr[i][j];
                }
            }
        }
        for (int j = 0; j < N; j++) {
            float range = max[j] - min[j];
            if (range == 0) {
                continue; // constant column: nothing to scale
            }
            for (int i = 0; i < M; i++) {
                arr[i][j] = (arr[i][j] - min[j]) / range;
            }
        }
    }

    /**
     * Runs the whole pipeline: allocate state, read {@code iris_data.txt}, print the raw data,
     * normalize, cluster, then print the cluster symbols, the centroids and the iteration count.
     */
    public void init() {
        arr = new float[M][N + 1];
        cent = new float[k][N];
        ocent = new float[k][N];
        target_class = new String[M];
        class_data = new float[M];
        assignReferenceClasses();
        readFile(str);
        printDataset();
        Normalize();
        calculate();
        printClusters();
        printCentroids();
        System.out.print("No. of Loop:" + Z);
    }

    /**
     * Fills {@link #class_data} with a reference class per row. It assumes the file lists the
     * classes in {@code k} consecutive blocks of {@code M / k} rows each; the last block absorbs
     * any remainder.
     *
     * <p>NOTE(review): assumes iris_data.txt is the standard Iris file (50 rows per class, in
     * order) - confirm against the actual data file.
     */
    private void assignReferenceClasses() {
        int blockSize = Math.max(1, M / k); // guard against M < k
        for (int i = 0; i < M; i++) {
            class_data[i] = Math.min(i / blockSize, k - 1);
        }
    }

    /** Prints every row as "No. i", its raw features and its reference class, tab-separated. */
    private void printDataset() {
        for (int i = 0; i < M; i++) {
            StringBuilder line = new StringBuilder("No. ").append(i).append('\t');
            for (int j = 0; j < N; j++) {
                line.append(arr[i][j]).append('\t');
            }
            line.append(class_data[i]).append('\n');
            System.out.print(line);
        }
    }

    /** Records and prints the cluster symbol of every row, and counts label mismatches in {@link #err}. */
    private void printClusters() {
        System.out.print("Cluster:" + "\n");
        for (int i = 0; i < M; i++) {
            float label = arr[i][N];
            target_class[i] = symbolFor(label);
            System.out.print(target_class[i]);
            // NOTE(review): cluster numbers are an arbitrary permutation of the reference
            // classes, so this count is only meaningful after matching clusters to classes.
            if (label - class_data[i] != 0) {
                err++;
            }
        }
        System.out.println();
    }

    /** Maps a cluster number to its display symbol. */
    private static String symbolFor(float label) {
        if (label == 2.0f) {
            return "+";
        }
        if (label == 1.0f) {
            return "-";
        }
        if (label == 0.0f) {
            return "*";
        }
        return "?";
    }

    /** Prints each centroid on its own line, features separated by tabs. */
    private void printCentroids() {
        System.out.print("Centroid: " + "\n");
        for (float[] centroid : cent) {
            for (float value : centroid) {
                System.out.print(value + "\t");
            }
            System.out.print("\n");
        }
    }

    /**
     * Reads the first {@code M} non-blank lines of a comma-separated file into the feature
     * columns of the data matrix. Only the first {@code N} fields of each line are used; extra
     * fields (e.g. a class name) and lines after the {@code M}-th data row are ignored. Blank
     * lines, such as a trailing empty line, are skipped.
     *
     * <p>If the data matrix has not been allocated yet (i.e. {@link #init()} has not run), it is
     * allocated here.
     *
     * @param fileName path of the file to read
     * @throws UncheckedIOException if the file cannot be opened or read
     * @throws IllegalArgumentException if a line has fewer than {@code N} fields
     * @throws NumberFormatException if a feature field is not a number
     * @throws IllegalStateException if the file holds fewer than {@code M} data rows
     */
    public void readFile(String fileName) {
        if (arr == null) {
            arr = new float[M][N + 1];
        }
        int row = 0;
        int lineNo = 0;
        try (BufferedReader br = new BufferedReader(new FileReader(fileName))) {
            String line;
            while (row < M && (line = br.readLine()) != null) {
                lineNo++;
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] fields = line.split(",");
                if (fields.length < N) {
                    throw new IllegalArgumentException(fileName + ":" + lineNo + ": expected at least "
                            + N + " comma-separated values but found " + fields.length);
                }
                for (int j = 0; j < N; j++) {
                    arr[row][j] = Float.parseFloat(fields[j].trim());
                }
                row++;
            }
        } catch (IOException e) {
            throw new UncheckedIOException("Cannot read " + fileName, e);
        }
        if (row < M) {
            throw new IllegalStateException(fileName + " contains " + row + " data rows; expected " + M);
        }
    }

    /**
     * Clusters the data matrix. It assigns every row to a uniformly random cluster, then repeats
     * nearest-centroid assignment ({@link #distance()}) and centroid recomputation
     * ({@link #calulatecent()}). It stops when no centroid moves or when the total
     * nearest-centroid distance stops changing. {@link #Z} records the number of iterations.
     */
    public void calculate() {
        Random rnd = new Random();
        for (int i = 0; i < M; i++) {
            arr[i][N] = rnd.nextInt(k); // always in [0, k), unlike Math.abs(nextInt()) % k
        }
        calulatecent();
        // Reset loop state so a second call on the same instance clusters again.
        r = true;
        Z = 0;
        P2 = 0;
        while (r) {
            Z++;
            distance();
            calulatecent();
            if (Math.abs(P1 - P2) < 0.0000001) {
                break; // total distance unchanged: treat as converged
            }
            r = centroidsMoved();
        }
    }

    /** Returns true if any feature of any centroid differs from its previous value. */
    private boolean centroidsMoved() {
        for (int a = 0; a < k; a++) {
            for (int j = 0; j < N; j++) {
                if (cent[a][j] != ocent[a][j]) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Saves the current centroids into {@link #ocent}, then recomputes each centroid as the mean
     * of the rows currently labelled with that cluster.
     *
     * <p>A cluster with no rows keeps its previous centroid instead of becoming 0/0 = NaN.
     */
    public void calulatecent() {
        for (int c = 0; c < k; c++) {
            System.arraycopy(cent[c], 0, ocent[c], 0, N);
        }
        for (int c = 0; c < k; c++) {
            float[] sum = new float[N];
            int members = 0;
            for (int i = 0; i < M; i++) {
                if (arr[i][N] == c) {
                    members++;
                    for (int j = 0; j < N; j++) {
                        sum[j] += arr[i][j];
                    }
                }
            }
            if (members == 0) {
                continue; // empty cluster: keep previous centroid
            }
            for (int j = 0; j < N; j++) {
                cent[c][j] = sum[j] / members;
            }
        }
    }

    /**
     * Assigns every row to its nearest centroid by Euclidean distance. Ties go to the lowest
     * cluster number.
     *
     * <p>Moves the previous total into {@link #P1} and stores in {@link #P2} the sum, over all
     * rows, of each row's distance to its assigned centroid.
     */
    public void distance() {
        P1 = P2;
        P2 = 0;
        for (int i = 0; i < M; i++) {
            double best = Double.MAX_VALUE;
            boolean assigned = false;
            for (int c = 0; c < k; c++) {
                double sq = 0;
                for (int w = 0; w < N; w++) {
                    double diff = cent[c][w] - arr[i][w];
                    sq += diff * diff;
                }
                double dist = Math.sqrt(sq);
                if (dist < best) {
                    best = dist;
                    arr[i][N] = c;
                    assigned = true;
                }
            }
            if (assigned) {
                P2 += best; // count only the winning (minimum) distance once per row
            }
        }
    }

    /** Entry point: clusters {@code iris_data.txt} and prints the results. */
    public static void main(String[] args) {
        new kmeans().init();
    }
}
// end of kmeans.java