📄 kmean.cpp
字号:
#include "kmean.h"
void Kmean::initClusters()
{
for(int i = 0; i < this->patterns.numOfCluster; i++)
{
Kmean::Cluster cluster;
cluster.setCenter(this->patterns.getPattern(i));
cluster.addPattern(i);
this->clusters.push_back(cluster);
}
}
/**
 * Rebuilds cluster membership from scratch.
 *
 * Calls clear() on each of the first patterns.numOfCluster clusters, then
 * adds every pattern index to the cluster chosen by getClosestCluster().
 */
void Kmean::classifyPattern()
{
    for(int i = 0; i < this->patterns.numOfCluster; i++)
    {
        this->clusters[i].clear();
    }
    for(int i = 0; i < this->patterns.numOfPatterns; i++)
    {
        int indexOfCluster = this->getClosestCluster(i);
        // getClosestCluster() yields -1 when no cluster qualified (no clusters
        // at all, or no distance compared below its sentinel). Indexing
        // clusters with -1 would be out of bounds, so such a pattern is left
        // unassigned for this pass.
        if(indexOfCluster < 0)
        {
            continue;
        }
        this->clusters[indexOfCluster].addPattern(i);
    }
}
int Kmean::getClosestCluster(int i)
{
double dist = 9.9e+99;
int clustId = -1;
Pattern pattern = this->patterns.getPattern(i);
for(int j = 0; j < this->patterns.numOfCluster; j++)
{
double d = this->patterns.caluDistance(this->clusters[j].getCenter(),pattern);
if(d < dist)
{
clustId = j;
dist = d;
}
}
return clustId;
}
bool Kmean::caluCenter()
{
bool classfied = true;
Pattern tempPattern;
for(int i = 0; i < this->patterns.numOfCluster; i++)
{
tempPattern.clear();
for(int j = 0;j < this->patterns.numOfDim; j++)
{
tempPattern.push_back(0);
}
Kmean::Cluster cluster = this->clusters[i];
if(cluster.numOfPatterns<=0)
{
cluster.setCenter(tempPattern);
}
else
{
for(int j = 0; j < cluster.numOfPatterns; j++)
{
for(int k = 0; k < this->patterns.numOfDim; k++)
{
Pattern pattern = this->patterns.getPattern(cluster.getPattern(j));
tempPattern[k] += pattern[k];
}
}
for(int j = 0; j < this->patterns.numOfDim; j++)
{
tempPattern[j] = tempPattern[j] / cluster.numOfPatterns;
double min = tempPattern[j] - cluster.getCenter()[j];
if((tempPattern[j] != cluster.getCenter()[j]))
classfied = false;
}
this->clusters[i].setCenter(tempPattern);
}
}
return classfied;
}
/**
 * Interactively offers to write the clustering result to a file.
 *
 * Prompts on cout and reads the answer from cin. On 'y'/'Y' it reads a file
 * name and writes:
 *   - numOfPatterns, numOfDim and numOfCluster, one per line;
 *   - then, for each cluster and each dimension, that dimension's value for
 *     every member pattern (one per line) followed by a line of asterisks.
 *
 * Any other answer, a failed read, or a file that cannot be opened leaves
 * nothing written.
 */
void Kmean::saveClusters()
{
    // Default to "no" so a failed read from cin cannot leave this uninitialized.
    char ifsave = 'n';
    cout<<"the sample has been classfied...."<<endl;
    cout<<"save the clusters?[y/n]: ";
    cin>>ifsave;
    if(ifsave!='y' && ifsave!='Y')
    {
        return;
    }

    cout<<"input the target file name: ";
    string fileName;
    cin>>fileName;
    ofstream fout;
    fout.open(fileName.c_str(), ios::out);
    if(!fout.is_open())
    {
        // Report the failure instead of silently discarding every write below.
        cout<<"cannot open file: "<<fileName<<endl;
        return;
    }

    // '\n' rather than endl: no need to flush the file after every value.
    fout<<this->patterns.numOfPatterns<<'\n';
    fout<<this->patterns.numOfDim<<'\n';
    fout<<this->patterns.numOfCluster<<'\n';
    for(int i = 0; i < this->patterns.numOfCluster; i++)
    {
        Kmean::Cluster& cluster = this->clusters[i];
        /* Disabled alternative layout: one row per pattern, followed by the
           1-based cluster number.
        for(int j = 0; j < cluster.numOfPatterns; j++)
        {
            Pattern pattern = this->patterns.getPattern(cluster.getPattern(j));
            for(int k = 0; k < this->patterns.numOfDim; k++)
            {
                fout<<pattern[k]<<" ";
            }
            fout<<i+1<<endl;
        }*/
        for(int j = 0; j < this->patterns.numOfDim; j++)
        {
            // Every member is written; stopping at numOfPatterns-1 would drop
            // the last pattern of each cluster from the file.
            for(int k = 0; k < cluster.numOfPatterns; k++)
            {
                Pattern pattern = this->patterns.getPattern(cluster.getPattern(k));
                fout<<pattern[j]<<'\n';
            }
            fout<<"*************************************"<<'\n';
        }
    }
    fout.close();
}
/**
 * Prints every cluster to cout.
 *
 * For each cluster: a "center:" line listing the center's coordinates, then
 * one line per member pattern with its coordinates followed by the 1-based
 * cluster number. Two blank lines are printed after the last cluster.
 */
void Kmean::display()
{
    for(int i = 0; i < this->patterns.numOfCluster; i++)
    {
        Kmean::Cluster& cluster = this->clusters[i];
        Pattern center = cluster.getCenter();
        // Print all numOfDim coordinates instead of a fixed first two, so
        // one-dimensional data is not read out of bounds and higher
        // dimensions are not cut off. Assumes the center holds numOfDim
        // values, like the patterns it was computed from.
        cout<<"center:";
        for(int k = 0; k < this->patterns.numOfDim; k++)
        {
            cout<<" "<<center[k];
        }
        cout<<endl;
        for(int j = 0; j < cluster.numOfPatterns; j++)
        {
            Pattern pattern = this->patterns.getPattern(cluster.getPattern(j));
            for(int k = 0; k < this->patterns.numOfDim; k++)
            {
                cout<<pattern[k]<<" ";
            }
            cout<<i+1<<endl;
        }
    }
    cout<<endl<<endl;
}
void Kmean::run()
{
bool classfied = false;
this->patterns.load(KMEAN);
this->initClusters();
this->patterns.display();
while(!classfied)
{
this->classifyPattern();
classfied = this->caluCenter();
}
this->saveClusters();
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -