
📄 javamlp.java

📁 Java Source Code for an MLP Neural Network
💻 JAVA
/*
MLP neural network in Java
by Phil Brierley
www.philbrierley.com

This code may be freely used and modified at will

Tanh hidden neurons
Linear output neuron

To include an input bias create an
extra input in the training data
and set to 1

Routines included:

calcNet()
WeightChangesHO()
WeightChangesIH()
initWeights()
initData()
tanh(double x)
displayResults()
calcOverallError()

compiled and tested on
Symantec Cafe Lite
*/

/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package javaapplication8;

/**
 *
 * @author Maninder
 */
public class JavaMLP {

    //user defineable variables
    public static int numEpochs = 500;   //number of training cycles
    public static int numInputs  = 3;    //number of inputs - this includes the input bias
    public static int numHidden  = 4;    //number of hidden units
    public static int numPatterns = 4;   //number of training patterns
    public static double LR_IH = 0.7;    //learning rate
    public static double LR_HO = 0.07;   //learning rate

    //process variables
    public static int patNum;
    public static double errThisPat;
    public static double outPred;
    public static double RMSerror;

    //training data
    public static double[][] trainInputs  = new double[numPatterns][numInputs];
    public static double[] trainOutput = new double[numPatterns];

    //the outputs of the hidden neurons
    public static double[] hiddenVal = new double[numHidden];

    //the weights
    public static double[][] weightsIH = new double[numInputs][numHidden];
    public static double[] weightsHO = new double[numHidden];

    //==============================================================
    //********** THIS IS THE MAIN PROGRAM **************************
    //==============================================================

    public static void main(String[] args)
    {
        //initiate the weights
        initWeights();

        //load in the data
        initData();

        //train the network
        for(int j = 0; j <= numEpochs; j++)
        {
            for(int i = 0; i < numPatterns; i++)
            {
                //select a pattern at random
                patNum = (int)((Math.random() * numPatterns) - 0.001);

                //calculate the current network output
                //and error for this pattern
                calcNet();

                //change network weights
                WeightChangesHO();
                WeightChangesIH();
            }

            //display the overall network error
            //after each epoch
            calcOverallError();
            System.out.println("epoch = " + j + "  RMS Error = " + RMSerror);
        }

        //training has finished
        //display the results
        displayResults();
    }

    //==============================================================
    //********** END OF THE MAIN PROGRAM ***************************
    //==============================================================

    //************************************
    public static void calcNet()
    {
        //calculate the outputs of the hidden neurons
        //the hidden neurons are tanh
        for(int i = 0; i < numHidden; i++)
        {
            hiddenVal[i] = 0.0;
            for(int j = 0; j < numInputs; j++)
                hiddenVal[i] = hiddenVal[i] + (trainInputs[patNum][j] * weightsIH[j][i]);
            hiddenVal[i] = tanh(hiddenVal[i]);
        }

        //calculate the output of the network
        //the output neuron is linear
        outPred = 0.0;
        for(int i = 0; i < numHidden; i++)
            outPred = outPred + hiddenVal[i] * weightsHO[i];

        //calculate the error
        errThisPat = outPred - trainOutput[patNum];
    }

    //************************************
    public static void WeightChangesHO()
    //adjust the weights hidden-output
    {
        for(int k = 0; k < numHidden; k++)
        {
            double weightChange = LR_HO * errThisPat * hiddenVal[k];
            weightsHO[k] = weightsHO[k] - weightChange;

            //regularisation on the output weights
            if (weightsHO[k] < -5)
                weightsHO[k] = -5;
            else if (weightsHO[k] > 5)
                weightsHO[k] = 5;
        }
    }

    //************************************
    public static void WeightChangesIH()
    //adjust the weights input-hidden
    {
        for(int i = 0; i < numHidden; i++)
        {
            for(int k = 0; k < numInputs; k++)
            {
                double x = 1 - (hiddenVal[i] * hiddenVal[i]);
                x = x * weightsHO[i] * errThisPat * LR_IH;
                x = x * trainInputs[patNum][k];
                double weightChange = x;
                weightsIH[k][i] = weightsIH[k][i] - weightChange;
            }
        }
    }

    //************************************
    public static void initWeights()
    {
        for(int j = 0; j < numHidden; j++)
        {
            weightsHO[j] = (Math.random() - 0.5) / 2;
            for(int i = 0; i < numInputs; i++)
                weightsIH[i][j] = (Math.random() - 0.5) / 5;
        }
    }

    //************************************
    public static void initData()
    {
        System.out.println("initialising data");

        // the data here is the XOR data
        // it has been rescaled to the range
        // [-1][1]
        // an extra input valued 1 is also added
        // to act as the bias

        trainInputs[0][0] = 1;
        trainInputs[0][1] = -1;
        trainInputs[0][2] = 1;   //bias
        trainOutput[0] = 1;

        trainInputs[1][0] = -1;
        trainInputs[1][1] = 1;
        trainInputs[1][2] = 1;   //bias
        trainOutput[1] = 1;

        trainInputs[2][0] = 1;
        trainInputs[2][1] = 1;
        trainInputs[2][2] = 1;   //bias
        trainOutput[2] = -1;

        trainInputs[3][0] = -1;
        trainInputs[3][1] = -1;
        trainInputs[3][2] = 1;   //bias
        trainOutput[3] = -1;
    }

    //************************************
    public static double tanh(double x)
    {
        if (x > 20)
            return 1;
        else if (x < -20)
            return -1;
        else
        {
            double a = Math.exp(x);
            double b = Math.exp(-x);
            return (a - b) / (a + b);
        }
    }

    //************************************
    public static void displayResults()
    {
        for(int i = 0; i < numPatterns; i++)
        {
            patNum = i;
            calcNet();
            System.out.println("pat = " + (patNum + 1) + " actual = " + trainOutput[patNum] + " neural model = " + outPred);
        }
    }

    //************************************
    public static void calcOverallError()
    {
        RMSerror = 0.0;
        for(int i = 0; i < numPatterns; i++)
        {
            patNum = i;
            calcNet();
            RMSerror = RMSerror + (errThisPat * errThisPat);
        }
        RMSerror = RMSerror / numPatterns;
        RMSerror = java.lang.Math.sqrt(RMSerror);
    }
}
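Usage note: a minimal way to compile and run this class from the command line, assuming the file is saved as JavaMLP.java inside a javaapplication8/ directory so that it matches the package declaration above (the directory layout is an assumption, not part of the original listing):

    javac javaapplication8/JavaMLP.java
    java javaapplication8.JavaMLP

The program prints the RMS error after each training epoch and, once training finishes, the target value and the network's prediction for each of the four rescaled XOR patterns.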
