📄 rbfnetwork.java
字号:
switch((int)Idata.getPoint()[q].DefaultClassify) {
case 1:
if(CentralPoint2[0].MaxDistance< distance(Idata.getPoint()[q],CentralPoint2[0]))
CentralPoint2[0].MaxDistance=distance(Idata.getPoint()[q],CentralPoint2[0]);
CentralPoint2[0].EspectedOutputSum=CentralPoint2[0].EspectedOutputSum+Idata.getPoint()[q].DefaultClassify;
break;
case 2:
if(CentralPoint2[1].MaxDistance< distance(Idata.getPoint()[q],CentralPoint2[1]))
CentralPoint2[1].MaxDistance=distance(Idata.getPoint()[q],CentralPoint2[1]);
CentralPoint2[1].EspectedOutputSum=CentralPoint2[1].EspectedOutputSum+Idata.getPoint()[q].DefaultClassify;
break;
case 3:
if(CentralPoint2[2].MaxDistance< distance(Idata.getPoint()[q],CentralPoint2[2]))
CentralPoint2[2].MaxDistance=distance(Idata.getPoint()[q],CentralPoint2[2]);
CentralPoint2[2].EspectedOutputSum=CentralPoint2[2].EspectedOutputSum+Idata.getPoint()[q].DefaultClassify;
break;
case 4:
if(CentralPoint2[3].MaxDistance< distance(Idata.getPoint()[q],CentralPoint2[3]))
CentralPoint2[3].MaxDistance=distance(Idata.getPoint()[q],CentralPoint2[3]);
CentralPoint2[3].EspectedOutputSum=CentralPoint2[3].EspectedOutputSum+Idata.getPoint()[q].DefaultClassify;
break;
}
}
for(int f=0;f<Idata.getMaxClusterNum();f++) {
hnode[f].center=CentralPoint[f];
hnode[f].sigma=CentralPoint[f].MaxDistance;
switch((int)(CentralPoint[f].EspectedOutputSum/CentralPoint[f].NCenterCluster)) {
case 1:
hnode[f].weight0=0;
hnode[f].weight1=0;
break;
case 2:
hnode[f].weight0=1;
hnode[f].weight1=0;
break;
case 3:
hnode[f].weight0=0;
hnode[f].weight1=1;
break;
case 4:
hnode[f].weight0=1;
hnode[f].weight1=1;
break;
}
// hnode[f].weight0=;
}
}*/
/**
 * Evaluates the network on one point.
 * <p>
 * The output is the sum, over the first {@code Idata.getMaxClusterNum()}
 * hidden nodes, of {@code weight * basis(g, node)}, plus the bias term
 * {@code sita}.
 *
 * @param g the input point
 * @return the raw (un-thresholded) network output
 */
public double F(Point g) {
    double total = 0;
    int k = 0;
    while (k < Idata.getMaxClusterNum()) {
        total += hnode[k].weight * basis(g, hnode[k]);
        k++;
    }
    return total + sita;
}
/**
 * Gaussian radial-basis response of hidden node {@code h} to point {@code g}.
 * <p>
 * Computes {@code exp(-d^2 / (2 * sigma^2))}, where {@code d} is the distance
 * from {@code g} to {@code h.center} and {@code sigma} is {@code h.sigma}.
 * <p>
 * A zero width would otherwise evaluate {@code 0/0 = NaN} when {@code g} lies
 * exactly on the center. That NaN would propagate into the network output and
 * every weight updated from it. A zero width is therefore treated as the
 * {@code sigma -> 0} limit of the Gaussian: 1 at the center, 0 everywhere else.
 *
 * @param g the input point
 * @param h the hidden node supplying the center and width
 * @return the basis response, in [0, 1] for finite inputs
 */
public double basis(Point g, HiddenNode h) {
    double d = distance(g, h.center);
    double sigma = h.sigma;
    if (sigma == 0) {
        // Degenerate node (e.g. a single-member class): indicator of the center.
        return d == 0 ? 1.0 : 0.0;
    }
    return Math.exp(-(d * d) / (2 * sigma * sigma));
}
/**
 * Placeholder for a standalone weight-update step.
 * <p>
 * Intentionally does nothing; in this class the weights are updated inline
 * by {@code training()} and {@code training2()}.
 */
public void WeightUpdate() {
    // no-op
}
/**
 * Placeholder for a standalone bias ({@code sita}) update step.
 * <p>
 * Intentionally does nothing; in this class the bias is updated inline by
 * {@code training()} and {@code training2()}.
 */
public void SitaUpdate() {
    // no-op
}
/**
 * Fits the output weights and the bias {@code sita} by per-sample gradient
 * descent, using each training point's {@code DefaultClassify} as the target.
 * <p>
 * One epoch visits every training point. For each point it:
 * <ul>
 *   <li>evaluates {@code F};</li>
 *   <li>adds half the squared error to the epoch loss;</li>
 *   <li>logs the prediction to stdout and {@code output};</li>
 *   <li>moves every hidden node's weight by
 *       {@code LearningRate * error * basis};</li>
 *   <li>moves the bias by {@code LearningRate * error}.</li>
 * </ul>
 * Epochs repeat until {@code Times} have run, or until the mean loss of the
 * last epoch is no longer above {@code tolerance}.
 */
public void training() {
    double miss = Double.MAX_VALUE;
    for (int epoch = 0; epoch < Times && miss > tolerance; epoch++) {
        miss = 0.0;
        for (int g = 0; g < Idata.getTrainingCount(); g++) {
            Point sample = Idata.getPoint()[g];
            double prediction = F(sample);
            // Target minus output; reused by the loss and both update rules.
            double error = sample.DefaultClassify - prediction;
            miss += (error * error) / 2;

            String line = "F_points[" + g + "]" + prediction + "| |" + "default:" + sample.DefaultClassify;
            System.out.println(line);
            output.append(line + "\n");

            for (int d = 0; d < Idata.getMaxClusterNum(); d++) {
                System.out.println("weight:" + hnode[d].weight);
                hnode[d].weight += LearningRate * error * basis(sample, hnode[d]);
            }
            sita += LearningRate * error;
        }
        miss /= Idata.getTrainingCount();
    }
}
/**
 * Runs every test point through the network and prints the output.
 * <p>
 * Each printed line shows the network output next to the test point's
 * {@code FinalClassify} value. Nothing is written to {@code output} and no
 * state is modified.
 */
public void testing() {
    for (int i = 0; i < Idata.getTestingCount(); i++) {
        Point sample = Idata.getTestPoint()[i];
        double prediction = F(sample);
        System.out.println("test_points[" + i + "]" + prediction + "| |" + "default:" + sample.FinalClassify);
    }
}
/**
 * Fits the output weights and the bias {@code sita} by per-sample gradient
 * descent, using each training point's {@code FinalClassify} as the target.
 * <p>
 * {@code Init2()} stores each label rescaled to [0, 1] in this field.
 * <p>
 * One epoch visits every training point. For each point it:
 * <ul>
 *   <li>evaluates {@code F};</li>
 *   <li>adds half the squared error to the epoch loss;</li>
 *   <li>logs the prediction to stdout and {@code output};</li>
 *   <li>moves every hidden node's weight by
 *       {@code LearningRate * error * basis};</li>
 *   <li>moves the bias by {@code LearningRate * error}.</li>
 * </ul>
 * Epochs repeat until {@code Times} have run, or until the mean loss of the
 * last epoch is no longer above {@code tolerance}.
 */
public void training2() {
    double miss = Double.MAX_VALUE;
    for (int epoch = 0; epoch < Times && miss > tolerance; epoch++) {
        miss = 0.0;
        for (int g = 0; g < Idata.getTrainingCount(); g++) {
            Point sample = Idata.getPoint()[g];
            double prediction = F(sample);
            // Target minus output; reused by the loss and both update rules.
            double error = sample.FinalClassify - prediction;
            miss += (error * error) / 2;

            String line = "F_points[" + g + "]" + prediction + "| |" + "default:" + sample.FinalClassify;
            System.out.println(line);
            output.append(line + "\n");

            for (int d = 0; d < Idata.getMaxClusterNum(); d++) {
                System.out.println("weight:" + hnode[d].weight);
                hnode[d].weight += LearningRate * error * basis(sample, hnode[d]);
            }
            sita += LearningRate * error;
        }
        miss /= Idata.getTrainingCount();
    }
}
/**
 * Initialises the RBF hidden layer from the labelled training data.
 * <p>
 * Training points are grouped by their integer class label
 * {@code (int) DefaultClassify}; label {@code k} maps to
 * {@code CentralPoint2[k - 1]}. Labels with no matching slot are reported
 * ("Default problem!") and left out of the per-class statistics.
 * <p>
 * The method then:
 * <ol>
 *   <li>accumulates each class's vector sum ({@code SumOfVector}) and member
 *       count ({@code NCenterCluster}), then calls {@code Divide} with that
 *       count (presumably turning the sum into the centroid; verify);</li>
 *   <li>records, per class, the largest member-to-centroid distance
 *       ({@code MaxDistance}) and the sum of member labels
 *       ({@code EspectedOutputSum});</li>
 *   <li>stores each training point's label rescaled as
 *       {@code (label - 1) / (K - 1)} in {@code FinalClassify}, with
 *       {@code K = Idata.getMaxClusterNum()};</li>
 *   <li>copies centroid, spread ({@code sigma}) and the identically rescaled
 *       mean member label ({@code weight}) into the first {@code K} hidden
 *       nodes.</li>
 * </ol>
 * NOTE(review): the accumulators are never reset here, so calling this twice
 * would double-count; confirm they start at zero.
 */
public void Init2() {
    // --- pass 1: per-class vector sums and member counts ----------------
    System.out.println("init finish");
    for (int s = 0; s < Idata.getTrainingCount(); s++) {
        Point p = Idata.getPoint()[s];
        int label = (int) p.DefaultClassify;
        System.out.println("tmp:" + label);
        int idx = label - 1;
        if (idx < 0 || idx >= CentralPoint2.length) {
            System.out.println("Default problem!");
            continue;
        }
        CentralPoint2[idx].SumOfVector = CentralPoint2[idx].sum(CentralPoint2[idx].SumOfVector, p);
        CentralPoint2[idx].NCenterCluster++;
        System.out.println(label + " problem:" + CentralPoint2[idx].NCenterCluster);
    }

    // --- centroids --------------------------------------------------------
    System.out.println("init finish");
    for (int w = 0; w < Idata.getMaxClusterNum(); w++) {
        System.out.println("NCenterCluster:" + CentralPoint2[w].NCenterCluster);
        // NOTE(review): a class with no training points reaches this line with
        // NCenterCluster == 0; confirm how Divide handles a zero divisor.
        CentralPoint2[w].Divide(CentralPoint2[w], CentralPoint2[w].NCenterCluster);
    }

    // --- pass 2: per-class spread, label sums and rescaled targets ------
    System.out.println("init finish");
    for (int q = 0; q < Idata.getTrainingCount(); q++) {
        Point p = Idata.getPoint()[q];
        int idx = (int) p.DefaultClassify - 1;
        if (idx >= 0 && idx < CentralPoint2.length) {
            double d = distance(p, CentralPoint2[idx]);
            if (CentralPoint2[idx].MaxDistance < d) {
                CentralPoint2[idx].MaxDistance = d;
            }
            CentralPoint2[idx].EspectedOutputSum = CentralPoint2[idx].EspectedOutputSum + p.DefaultClassify;
        }
        // Map labels 1..K onto [0, 1]. With a single class there is nothing to
        // spread, so use 0 rather than the 0/0 the formula would produce.
        p.FinalClassify = Idata.getMaxClusterNum() > 1
                ? (p.DefaultClassify - 1) / (Idata.getMaxClusterNum() - 1)
                : 0;
    }

    // --- hidden nodes -----------------------------------------------------
    System.out.println("init finish");
    for (int f = 0; f < Idata.getMaxClusterNum(); f++) {
        System.out.println("Update_Weight_NCenterCluster:" + CentralPoint2[f].NCenterCluster);
        hnode[f].center = CentralPoint2[f];

        // Print every centroid coordinate (the old code assumed exactly 4
        // features and threw ArrayIndexOutOfBoundsException with fewer).
        StringBuilder coords = new StringBuilder();
        for (int i = 0; i < CentralPoint2[f].position.length; i++) {
            if (i > 0) {
                coords.append(',');
            }
            coords.append(CentralPoint2[f].position[i]);
        }
        System.out.println("Center_point" + f + ":" + coords);

        hnode[f].sigma = CentralPoint2[f].MaxDistance;
        // Mean member label of the class (for integer labels and zeroed
        // accumulators, the class label itself), rescaled like FinalClassify.
        hnode[f].weight = (CentralPoint2[f].EspectedOutputSum / CentralPoint2[f].NCenterCluster);
        hnode[f].weight = Idata.getMaxClusterNum() > 1
                ? (hnode[f].weight - 1) / (Idata.getMaxClusterNum() - 1)
                : 0;
    }
    System.out.println("init finish");
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -