⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 rbf2.hpp

📁 现提供一个径向基网络的C++源程序
💻 HPP
📖 第 1 页 / 共 2 页
字号:
#if !defined(__RBF2_HPP)
#define __RBF2_HPP

#include <time.h>

#if !defined(__ANN2_HPP)
#include "ANN2.HPP"
#endif

#ifndef __OLS_HPP
#include "OLS.HPP"
#endif

#ifndef __KJUL_HPP
#include "kjul.hpp"
#endif

// Gaussian function node.
// Work() computes
//    Output = exp( -sum_i (Inputs[i]-Weight[i])^2 / (2*Threshold[0]^2) )
// and Adjustl() accumulates the matching gradients into Learn[] and
// Deri_in[] (both inherited from ANN_NODE).
class GAU_NODE : public ANN_NODE
{
   protected:
      // Scaled squared distance produced by the most recent Work() call;
      // Adjustl() reads it when updating the threshold gradient.
      double Sum_input;

   public:
      // Default construction: Sum_input is zeroed so that it never holds an
      // indeterminate value if Adjustl() is reached before Work().
      GAU_NODE()
         : ANN_NODE(), Sum_input(0) {}

      // in_num is forwarded to ANN_NODE as both the first and second
      // argument, td_num as the third (see the out-of-class definition).
      GAU_NODE(short in_num,short td_num);

      ~GAU_NODE() {}

      // Evaluates the node and returns Output.
      virtual double Work();

      // Back-propagates deri_out (derivative w.r.t. this node's output),
      // scaled by w for the parameter gradients.
      virtual void Adjustl(double deri_out,double w);
};


// Regressive function node.
// Work() computes the input-weighted average
//    Output = sum_i (Inputs[i]*Weight[i]) / sum_i Inputs[i]
// and Adjustl() accumulates the matching gradients into Learn[] and
// Deri_in[] (both inherited from ANN_NODE).
class REG_NODE : public ANN_NODE
{
   private:
      // Sum of the inputs from the most recent Work() call; used as the
      // denominator in both Work() and Adjustl().
      double Sum_input;

   public:
      // Default construction: Sum_input is zeroed so that it never holds an
      // indeterminate value if Adjustl() is reached before Work().
      REG_NODE()
         : ANN_NODE(), Sum_input(0) {}

      // in_num is forwarded to ANN_NODE as both the first and second
      // argument, td_num as the third (see the out-of-class definition).
      REG_NODE(short in_num,short td_num);

      ~REG_NODE() {}

      // Evaluates the node and returns Output.
      virtual double Work();

      // Back-propagates deri_out (derivative w.r.t. this node's output),
      // scaled by w for the weight gradients.
      virtual void Adjustl(double deri_out,double w);
};



// Radial-basis-function network built on ANN_NET.
// The four-argument constructor populates three layers: IN_NODE inputs,
// GAU_NODE hidden units and REG_NODE outputs.
class RBF_NET : public ANN_NET
{
   public:
      // Builds the base net and fills layers 0..2 with lay0_num / lay1_num /
      // lay2_num nodes respectively (lay_num is expected to be 3).
      RBF_NET(short lay_num,
              short lay0_num,
              short lay1_num,
              short lay2_num);

      // Builds only the ANN_NET base and sets the label; no nodes are created.
      RBF_NET(short lay_num);

      ~RBF_NET() {}

      // Trains the network by delegating to BpLearn(); returns its result.
      // The definition supplies defaults maxinum=1500 and erro_type=1.
      virtual short Learn(double * sam_in,
                          double * sam_out,
                          double * w,
                          const short sam_number,
                          const short sam_inNum,
                          const short sam_outNum,
                          double & diff,
                          short maxinum,
                          short erro_type);

      // Initialises the hidden layer from the samples; type selects the
      // method (0: k_jul(), 1: ols() selection, per the definition's switch).
      virtual short InitNet(double sam_in[],
                            double sam_out[],
                            short number,
                            short in_num,
                            short out_num,
                            short type);
};
        // Prototype only; the definition is not in this part of the file.
        // RBF_NET::InitNet calls it for type==0 as
        //    k_jul(sam_in, jul, cla, number, in_num, lay2_num)
        // i.e. sample data, a jul buffer of lay2_num*in_num doubles, a cla
        // buffer of `number` shorts, then the sample count, sample dimension
        // and lay2_num.
        // NOTE(review): presumably a clustering routine that writes centres
        // into jul and per-sample class indices into cla -- confirm against
        // its definition (kjul.hpp is included above).
        void k_jul(double * sample,double * jul,short * cla,short sample_num,
                   short sample_dim,short ju_num);


// RBF_NET variant that re-declares the learning, forward-pass, backward-pass
// and initialisation hooks. Their definitions are outside this part of the
// file.
// NOTE(review): the name suggests least-squares estimation of the output
// weights (see lse()/seq_method() declared below) -- confirm.
class LSE_RBF_NET : public RBF_NET
{
   public:
      LSE_RBF_NET(short lay_num,
                  short lay0_num,
                  short lay1_num,
                  short lay2_num);

      LSE_RBF_NET(short lay_num);

      ~LSE_RBF_NET() {}

      // Same signature as RBF_NET::Learn.
      virtual short Learn(double * sam_in,
                          double * sam_out,
                          double * w,
                          const short sam_number,
                          const short sam_inNum,
                          const short sam_outNum,
                          double & diff,
                          short maxinum,
                          short erro_type);

      // Forward pass over sam_num samples.
      virtual void For_pro(double * sam_in,
                           double * sam_out,
                           const short sam_num,
                           const short sam_inNum,
                           const short sam_outNum);

      // Backward pass; returns a double (presumably the error measure
      // selected by erro_type -- confirm).
      virtual double Back_pro(double * sam_in,
                              double * sam_out,
                              double * w,
                              const short sam_number,
                              const short sam_inNum,
                              const short sam_outNum,
                              short erro_type);

      // Same signature as RBF_NET::InitNet.
      virtual short InitNet(double sam_in[],
                            double sam_out[],
                            short number,
                            short in_num,
                            short out_num,
                            short type);
};
	// Prototypes only; neither function is defined or called in this part
	// of the file.
	// NOTE(review): presumably the sequential / batch least-squares solvers
	// used by LSE_RBF_NET -- confirm parameter meanings (m, n, p look like
	// matrix dimensions) against the definitions.
	void seq_method(double * xi,double * ai,double * bi,double * si,
                        const short m,const short n)  ;
        void lse(double * x,double * a,double * b, short m,short n,short p);

//....................................................................
// .................GAU_NODE.................

// Constructs a Gaussian node.
// in_num is passed to ANN_NODE twice (presumably as input count and weight
// count -- confirm against ANN_NODE), td_num as the third argument; td_num
// defaults to 1.
// NOTE(review): Threshold[0] is written unconditionally, so td_num is
// assumed to be at least 1.
GAU_NODE :: GAU_NODE(short in_num,short td_num=1) // use Gaussian function
   : ANN_NODE(in_num,in_num,td_num),
     Sum_input(0)
{
   strcpy(Lable,"Gaussian Function Node");
   Threshold[0]=1;   // initial width parameter
}


// Evaluates the Gaussian activation:
//    Sum_input = sum_i (Inputs[i]-Weight[i])^2 / (2*Threshold[0]^2)
//    Output    = exp(-Sum_input)
// Stores both in the node and returns Output.
double
GAU_NODE :: Work()
{
   Sum_input=0;

   short k=0;
   while(k<InNum)
   {
      Sum_input+=pow((Inputs[k]-Weight[k]),2);
      k++;
   }

   // Normalise by twice the squared width before exponentiating.
   Sum_input=Sum_input/(2*pow(Threshold[0],2));

   Output=exp(-Sum_input);
   return Output;
}

// Accumulates gradients for the Gaussian node and computes the derivative
// passed back to the previous layer.
//   deri_out : derivative of the error w.r.t. this node's Output
//   w        : scale applied to the parameter gradients (defaults to 1.0)
// Uses Output and Sum_input as left by the last Work() call.
//   Learn[i]        += w*deri_out*Output*(Inputs[i]-Weight[i])/Threshold[0]^2
//   Learn[2*WtNum]  += 2*w*deri_out*Output*Sum_input/Threshold[0]
//   Deri_in[i]       =  -deri_out*Output*(Inputs[i]-Weight[i])/Threshold[0]^2
// NOTE(review): the layout of Learn[] (why the threshold slot is 2*WtNum)
// is defined by ANN_NODE, which is not visible here.
void
GAU_NODE :: Adjustl(double deri_out,double w=1.0)
{
   // Loop-invariant squared width.
   const double sigma_sq=pow(Threshold[0],2);

   // Declared once for both loops: a variable declared in a for-header is
   // not visible after that loop under ISO C++ scoping rules.
   short i;

   for(i=0;i<WtNum;i++)
      Learn[i]+=w*deri_out*Output*(Inputs[i]-Weight[i])/sigma_sq;

   Learn[2*WtNum]+=2*w*deri_out*Output*Sum_input/Threshold[0];

   for(i=0;i<InNum;i++)
      Deri_in[i]=-deri_out*Output*(Inputs[i]-Weight[i])/sigma_sq;
}

//................................................................
//...................REG_NODE...............

// Constructs a regressive node.
// in_num is passed to ANN_NODE twice (presumably as input count and weight
// count -- confirm against ANN_NODE), td_num as the third argument; td_num
// defaults to 0.
REG_NODE :: REG_NODE(short in_num,short td_num=0) // use regressive function
           : ANN_NODE(in_num,in_num,td_num)
           {
             // Zeroed like GAU_NODE's constructor does, so Adjustl() never
             // reads an indeterminate value if it runs before Work().
             Sum_input=0;
             strcpy(Lable,"Regressive Function Node");
           }

// Evaluates the node as an input-weighted average of its weights:
//    Sum_input = sum_i Inputs[i]
//    Output    = sum_i (Inputs[i]*Weight[i]) / Sum_input
// Stores both in the node and returns Output.
// NOTE(review): no guard for Sum_input == 0; the division then yields a
// non-finite Output -- confirm whether callers can produce all-zero inputs.
double
REG_NODE :: Work()
{
   Output=0;
   Sum_input=0;

   short k=0;
   while(k<InNum)
   {
      Output+=Inputs[k]*Weight[k];
      Sum_input+=Inputs[k];
      k++;
   }

   Output=Output/Sum_input;
   return Output;
}



// Accumulates weight gradients for the regressive node and computes the
// derivative passed back to the previous layer.
//   deri_out : derivative of the error w.r.t. this node's Output
//   w        : scale applied to the weight gradients (defaults to 1.0)
// Uses Output and Sum_input as left by the last Work() call.
//   Learn[i]  += w*deri_out*Inputs[i]/Sum_input
//   Deri_in[i] =  deri_out*(Weight[i]-Output)/Sum_input
void
REG_NODE :: Adjustl(double deri_out,double w=1.0)
{
   // Declared once for both loops: a variable declared in a for-header is
   // not visible after that loop under ISO C++ scoping rules.
   short i;

   for(i=0;i<WtNum;i++)
      Learn[i]+=w*deri_out*Inputs[i]/Sum_input;

   for(i=0;i<InNum;i++)
      Deri_in[i]=deri_out*(Weight[i]-Output)/Sum_input;
}

//........................................................................
//...................RBF_NET..............................

// Builds a three-layer RBF network.
//   lay_num  : number of layers handed to ANN_NET (expected to be 3, since
//              layers 0..2 are populated unconditionally below)
//   lay0_num : number of IN_NODE input nodes      (layer 0)
//   lay1_num : number of GAU_NODE hidden nodes    (layer 1), each with
//              lay0_num inputs
//   lay2_num : number of REG_NODE output nodes    (layer 2), each with
//              lay1_num inputs
// Also allocates the per-layer output buffers and the network-level
// Inputs / Outputs arrays.
RBF_NET :: RBF_NET(short lay_num,short lay0_num,short lay1_num,short lay2_num)
           :ANN_NET(lay_num) //lay_num==3
	  {
           Layer_node_num[0]=lay0_num;
           Layer_node_num[1]=lay1_num;
           Layer_node_num[2]=lay2_num;

           // Declared once for the three node-creation loops: a variable
           // declared in a for-header is not visible after that loop under
           // ISO C++ scoping rules.
           short j;
           for(j=0;j<lay0_num;j++)
              Layer[0].Lay_node[j]=new IN_NODE(1);
           for(j=0;j<lay1_num;j++)
              Layer[1].Lay_node[j]=new GAU_NODE(lay0_num);
           for(j=0;j<lay2_num;j++)
              Layer[2].Lay_node[j]=new REG_NODE(lay1_num);

           // One output buffer per layer, sized to that layer's node count.
           for(short i=0;i<Layer_num;i++)
              Layer_out[i]=new double[Layer_node_num[i]];

	   Net_inNum=lay0_num;
	   Net_outNum=lay2_num;
	   Inputs=new double[Net_inNum];
	   Outputs=new double[Net_outNum];
           strcpy(Lable,"Radical Basis Function  Nenual Network");
	  }
// Lightweight constructor: initialises only the ANN_NET base with lay_num
// layers and sets the network label. No nodes or buffers are created here.
RBF_NET :: RBF_NET(short lay_num)
   : ANN_NET(lay_num)   // lay_num==3
{
   strcpy(Lable,"Radical Basis Function  Nenual Network");
}






// Trains the network: sets PARA1 to 0.01 and PARA2 to 0.0, then hands every
// argument unchanged to BpLearn() and returns its result as a short.
// maxinum defaults to 1500 and erro_type to 1.
// NOTE(review): PARA1 / PARA2 are declared outside this file section;
// presumably learning-rate and momentum -- confirm.
short
RBF_NET :: Learn(double * sam_in,double * sam_out,double * w,const short sam_number,
                 const short sam_inNum,const short sam_outNum,double & diff,
                 short maxinum=1500,short erro_type=1)
{
   PARA1=0.01;
   PARA2=0.0;

   const short iterations=BpLearn(sam_in,sam_out,w,
                                  sam_number,sam_inNum,sam_outNum,
                                  diff,maxinum,erro_type);
   return iterations;
}
short
RBF_NET :: InitNet(double sam_in[],double sam_out[],short number,
                   short in_num,short out_num,short type=1)
{   short i=0,j=0,lable;
    short lay2_num=Layer_node_num[1];
    short w_num=Layer[1].Lay_node[0]->GetWtNum();
    short t_num=Layer[1].Lay_node[0]->GetTdNum();

    if(number<lay2_num)
      {
        cout<<"\n samples are not enough";
        exit(-1);
      }
    if(in_num!=w_num||t_num!=1)
      { cout<<"\n Data is not equivalance";
        exit(0);
      }
    double * jul=new double[lay2_num*in_num];
    short    * cla=new short[number];
    double div[Layer_node_max],weight[Layer_node_max],t[Layer_node_max],erro=0.0;
    switch(type)
          {
            case 0:
               k_jul(sam_in,jul,cla,number,in_num,lay2_num);
               break;
            case 1:
               double * P=new double[number*number];
               for(i=0;i<number;i++)
                  { double min=1000000.0;
	            for(j=0;j<number;j++)
	               { double dl=0.0;
	                 if(j!=i)
		           { for(short k=0;k<in_num;k++)
		             dl+=pow((sam_in[j*in_num+k]-sam_in[i*in_num+k]),2);
		             if(min>dl)
		               min=dl;
		           }
	               }
	            div[i]=sqrt(min)/sqrt(2);
                  }
               for(i=0;i<number;i++)
                  { for(j=0;j<number;j++)
                       { double dl=0.0;
                         for(short k=0;k<in_num;k++)
                           dl+=pow(sam_in[j*in_num+k]-sam_in[i*in_num+k],2);
                         dl/=-2*pow(div[i],2);
                         P[j*number+i]=exp(dl);
                       }
                  }
               short count=ols(cla,P,sam_out,erro,number,number,lay2_num);

               if(count!=lay2_num)
                  count=lay2_num;
               for(i=0;i<count;i++)
                  for(j=0;j<in_num;j++)
                     jul[i*in_num+j]=sam_in[cla[i]*in_num+j];

               delete [] P;
               break;
          }

    for(i=0;i<lay2_num;i++)
       { double min=1000000.0;
	 for(j=0;j<lay2_num;j++)
	    { double dl=0.0;
	      if(j!=i)
		{ for(short k=0;k<in_num;k++)
		     dl+=pow((jul[j*in_num+k]-jul[i*in_num+k]),2);
		 if(min>dl)
		   min=dl;
		}
	    }
	 div[i]=sqrt(min)/sqrt(2);

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -