// File: rbfnetwork.java
// (The second header line was a "font size" control label left over from the web code viewer.)
import java.io.*;
import java.util.*;
import java.lang.*;
import javax.swing.JTextArea;
public class RBFNetwork{
// Iteration count; the constructor overwrites this default with the parsed IterationTime argument.
private int Times=10;
// Not referenced by any method visible in this part of the file -- presumably a threshold; TODO confirm.
private double sita=0.5;
// Error tolerance; the constructor overwrites this default with the parsed EavTolerance argument.
private double tolerance=0.5;
// Number of assignment passes performed by Kmean().
private int KMeanTimes=20;
// Learning rate; the constructor overwrites this default with the parsed learningRate argument.
private double LearningRate=0.5;
// Data set handed to the constructor; supplies the samples, training count and cluster count.
private DataPoints Idata;
// Hidden nodes, one per cluster (sized by Idata.getMaxClusterNum()); filled in by Init().
private HiddenNode hnode[];
// Cluster centres maintained by Kmean().
private Point CentralPoint[];
// Per-class centres and running statistics maintained by Init().
private Point CentralPoint2[];
// Text area supplied by the caller; only stored here, not written to by the code visible in this chunk.
private JTextArea output;
// private double sum[]; //for update KMean
public RBFNetwork(DataPoints idata,String learningRate,String IterationTime,String EavTolerance,JTextArea out){
output=out;
LearningRate=Double.parseDouble(learningRate);
Times=Integer.parseInt(IterationTime);
tolerance=Double.parseDouble(EavTolerance);
Idata=idata;
hnode=new HiddenNode[Idata.getMaxClusterNum()];
CentralPoint=new Point[Idata.getMaxClusterNum()];
CentralPoint2=new Point[Idata.getMaxClusterNum()];
for(int i=0;i<hnode.length;i++) {
hnode[i]=new HiddenNode(/*idata.getDim()-1*/);
}
for(int h=0;h<Idata.getMaxClusterNum();h++) {
CentralPoint[h]=new Point();
}
for(int g=0;g<Idata.getMaxClusterNum();g++) {
CentralPoint2[g]=new Point();
CentralPoint2[g].SumOfVector=new Point();
}
}
/**
 * Clusters the training samples with k-means and leaves the centres in {@code CentralPoint}.
 *
 * <p>Centres are seeded from training samples drawn with a fixed-seed generator (80427), so
 * repeated runs pick the same samples. Each of the {@code KMeanTimes} passes assigns every
 * sample to its nearest centre (stored in the sample's {@code center} field) and accumulates
 * the per-centre sum and member count. After every pass except the last, each non-empty
 * centre is moved by calling {@code Divide(SumOfVector, NCenterCluster)} on it, so on return
 * the stored assignments and accumulators match the centres as they stand.
 *
 * <p>Corrections relative to the earlier implementation:
 * <ul>
 *   <li>the nearest-centre search starts from infinity instead of 10000, so samples farther
 *       than 10000 from every centre are still assigned;</li>
 *   <li>the sum/count accumulators are cleared at the start of every pass, so a centre is
 *       computed from its current members only rather than from all passes combined;</li>
 *   <li>each centre is a separate {@code Point} holding a copy of the seed sample's
 *       coordinates, so updating a centre no longer operates on a training sample;</li>
 *   <li>a centre with no members keeps its position instead of being divided by zero.</li>
 * </ul>
 *
 * @throws IllegalArgumentException from {@link Random#nextInt(int)} when at least one centre
 *         is requested but the training count is not positive
 */
public void Kmean() {
    final int clusterCount = Idata.getMaxClusterNum();
    final int sampleCount = Idata.getTrainingCount();

    // ---- seed the centres ----
    Random generator2 = new Random( 80427 );
    for (int c = 0; c < clusterCount; c++) {
        int pick = generator2.nextInt(sampleCount);
        // Divide() is called on the centre for its effect on that object (its result is
        // discarded below), so the centre must not be the sample object itself.
        Point seed = new Point();
        seed.position = Idata.getPoint()[pick].position.clone();
        CentralPoint[c] = seed;
    }

    for (int pass = 0; pass < KMeanTimes; pass++) {
        // ---- start every pass with empty accumulators ----
        for (int c = 0; c < clusterCount; c++) {
            CentralPoint[c].SumOfVector = new Point();
            CentralPoint[c].NCenterCluster = 0.0;
        }

        // ---- assign each sample to its nearest centre and accumulate ----
        for (int v = 0; v < sampleCount; v++) {
            double best = Double.POSITIVE_INFINITY;
            int nearest = -1;
            for (int c = 0; c < clusterCount; c++) {
                double d = distance(Idata.getPoint()[v], CentralPoint[c]);
                if (d < best) { // strict '<': ties go to the lowest-indexed centre
                    best = d;
                    nearest = c;
                }
            }
            if (nearest < 0) {
                // No finite distance was found (or there are no centres); leave the sample as is.
                continue;
            }
            Point centre = CentralPoint[nearest];
            Idata.getPoint()[v].center = centre;
            centre.SumOfVector = centre.sum(centre.SumOfVector, Idata.getPoint()[v]);
            centre.NCenterCluster = centre.NCenterCluster + 1.0;
        }

        // The last pass only refreshes assignments and accumulators; centres stay put so the
        // stored assignments remain consistent with them.
        if ((pass + 1) >= KMeanTimes) {
            break;
        }

        // ---- move every non-empty centre ----
        for (int c = 0; c < clusterCount; c++) {
            if (CentralPoint[c].NCenterCluster > 0) {
                CentralPoint[c].Divide(CentralPoint[c].SumOfVector, CentralPoint[c].NCenterCluster);
            }
        }
    }
}
/**
 * Returns the Euclidean distance between two points.
 *
 * <p>The number of coordinates compared is taken from {@code p1.position}; {@code p2.position}
 * is indexed over the same range.
 *
 * @param p1 first point; its {@code position} length determines how many coordinates are used
 * @param p2 second point
 * @return square root of the summed squared coordinate differences
 */
public double distance(Point p1,Point p2) {
    final int dims = p1.position.length;
    double squaredTotal = 0;
    int axis = 0;
    while (axis < dims) {
        squaredTotal += (p1.position[axis]-p2.position[axis])*(p1.position[axis]-p2.position[axis]);
        axis++;
    }
    return Math.sqrt(squaredTotal);
}
/* public void InitializeHiddenNode() {
for(int w=0;w<Idata.getTrainingCount();w++) {
//set sigma
if(Distance>CentralPoint[u].MaxDistance) {
CentralPoint[u].MaxDistance=Distance;
}
//
//set center Point
CentralPoint[]
}
for(int f=0;f<Idata.getMaxClusterNum();f++) {
hnode[f].center=CentralPoint[f];
hnode[f].sigma=CentralPoint[f].MaxDistance;
hnode[f].weight0=;
}
}*/
/**
 * Builds the hidden nodes from per-class statistics of the labelled training samples.
 *
 * <p>A sample whose {@code DefaultClassify}, truncated to an int, equals {@code k} belongs to
 * class slot {@code k-1} of {@code CentralPoint2}. The method then:
 * <ol>
 *   <li>adds every sample into its class's {@code SumOfVector} and counts it;</li>
 *   <li>calls {@code Divide} on every class centre with the class's member count;</li>
 *   <li>records, per class, the largest distance from a member to the centre and the sum of
 *       the members' labels;</li>
 *   <li>copies centre, sigma (that largest distance) and weight (label sum divided by member
 *       count) into the matching hidden node.</li>
 * </ol>
 *
 * <p>Any label from 1 up to the number of allocated class slots is accepted, and the centre
 * printout covers however many coordinates the centre has. Earlier this was hard-wired to
 * four classes and four coordinates, which failed with fewer than four clusters or dimensions
 * and ignored classes above 4. Progress and debug lines are written to standard output, in the
 * same form as before for the four-class, four-dimension case.
 */
public void Init() {
    final int slots = CentralPoint2.length;

    // ---- 1. per-class sums and counts ----
    System.out.println("init finish");
    for (int s = 0; s < Idata.getTrainingCount(); s++) {
        int label = (int) Idata.getPoint()[s].DefaultClassify;
        int slot = label - 1; // labels are 1-based, slots 0-based
        if (slot < 0 || slot >= slots) {
            System.out.println("Default problem!");
            continue;
        }
        Point centre = CentralPoint2[slot];
        centre.SumOfVector = centre.sum(centre.SumOfVector, Idata.getPoint()[s]);
        centre.NCenterCluster++;
        System.out.println(label + " problem:" + centre.NCenterCluster);
        if (label == 3) {
            // Extra debug line that has always accompanied class 3; kept so the log is unchanged.
            System.out.println("Cluster number"+Idata.getMaxClusterNum());
        }
    }

    // ---- 2. turn sums into centres ----
    System.out.println("init finish");
    for (int w = 0; w < Idata.getMaxClusterNum(); w++) {
        System.out.println("NCenterCluster:"+CentralPoint2[w].NCenterCluster);
        // NOTE(review): Kmean() passes the centre's SumOfVector to Divide, whereas this call
        // passes the centre itself. Point.Divide is not visible here -- confirm which argument
        // it expects before changing either call.
        CentralPoint2[w].Divide(CentralPoint2[w], CentralPoint2[w].NCenterCluster);
    }

    // ---- 3. per-class spread (sigma) and label sum ----
    System.out.println("init finish");
    for (int q = 0; q < Idata.getTrainingCount(); q++) {
        int slot = (int) Idata.getPoint()[q].DefaultClassify - 1;
        if (slot < 0 || slot >= slots) {
            continue; // unlabelled / out-of-range samples take no part in this step
        }
        Point centre = CentralPoint2[slot];
        double reach = distance(Idata.getPoint()[q], centre);
        if (centre.MaxDistance < reach) {
            centre.MaxDistance = reach;
        }
        centre.EspectedOutputSum = centre.EspectedOutputSum + Idata.getPoint()[q].DefaultClassify;
    }

    // ---- 4. publish to the hidden nodes ----
    System.out.println("init finish");
    for (int f = 0; f < Idata.getMaxClusterNum(); f++) {
        Point centre = CentralPoint2[f];
        System.out.println("Update_Weight_NCenterCluster:"+centre.NCenterCluster);
        hnode[f].center = centre;

        // Comma-separated coordinates, for any dimensionality.
        StringBuilder coords = new StringBuilder();
        for (int d = 0; d < centre.position.length; d++) {
            if (d > 0) {
                coords.append(",");
            }
            coords.append(centre.position[d]);
        }
        System.out.println("Center_point"+f+":"+coords);

        hnode[f].sigma = centre.MaxDistance;
        // Mean label of the class members.
        hnode[f].weight = (centre.EspectedOutputSum/centre.NCenterCluster);
    }
    System.out.println("init finish");
}
/*
public void InitHiddenFake() {
//find center point
for(int s=0;s<Idata.getTrainingCount();s++) {
switch((int)Idata.getPoint()[s].DefaultClassify) {
case 1:
CentralPoint2[0].SumOfVector=CentralPoint2[0].sum(CentralPoint2[0].SumOfVector,
Idata.getPoint()[s]);
CentralPoint2[0].NCenterCluster++;
break;
case 2:
CentralPoint2[1].SumOfVector=CentralPoint2[1].sum(CentralPoint2[1].SumOfVector,
Idata.getPoint()[s]);
CentralPoint2[1].NCenterCluster++;
break;
case 3:
CentralPoint2[2].SumOfVector=CentralPoint2[2].sum(CentralPoint2[2].SumOfVector,
Idata.getPoint()[s]);
CentralPoint2[2].NCenterCluster++;
break;
case 4:
CentralPoint2[3].SumOfVector=CentralPoint2[3].sum(CentralPoint2[3].SumOfVector,
Idata.getPoint()[s]);
CentralPoint2[4].NCenterCluster++;
break;
}
}
for(int w=0;w<Idata.getMaxClusterNum();w++) {
CentralPoint2[w]=CentralPoint2[w].Divide(CentralPoint2[w].SumOfVector,
CentralPoint2[w].NCenterCluster);
}
//
//find sigma ,W
for(int q=0;q<Idata.getTrainingCount();q++) {
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -