⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 backpropxor.java

📁 CroftSoft Code Library是一个开源的可移植的纯Java游戏库
💻 JAVA
📖 第 1 页 / 共 2 页
字号:
       plotGraph  ( g, samples );

       g.drawString ( "Inputs", 10 + 0 * XTAB, 10 + 0 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l0_activations.data    [ 0 ] [ 0 ] ), 10 + 0 * XTAB, 10 + 1 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l0_activations.data    [ 1 ] [ 0 ] ), 10 + 0 * XTAB, 10 + 2 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l0_activations.data    [ 2 ] [ 0 ] ), 10 + 0 * XTAB, 10 + 3 * YTAB + y2 );

       g.drawString ( "Weights", 10 + 1 * XTAB, 10 + 0 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l1_weights.data        [ 0 ] [ 0 ] ), 10 + 1 * XTAB, 10 + 1 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l1_weights.data        [ 1 ] [ 0 ] ), 10 + 1 * XTAB, 10 + 2 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l1_weights.data        [ 2 ] [ 0 ] ), 10 + 1 * XTAB, 10 + 3 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l1_weights.data        [ 0 ] [ 1 ] ), 10 + 1 * XTAB, 10 + 4 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l1_weights.data        [ 1 ] [ 1 ] ), 10 + 1 * XTAB, 10 + 5 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l1_weights.data        [ 2 ] [ 1 ] ), 10 + 1 * XTAB, 10 + 6 * YTAB + y2 );

       g.drawString ( "Weighted Sums", 10 + 2 * XTAB, 10 + 0 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l1_weighted_sums.data  [ 0 ] [ 0 ] ), 10 + 2 * XTAB, 10 + 1 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l1_weighted_sums.data  [ 1 ] [ 0 ] ), 10 + 2 * XTAB, 10 + 2 * YTAB + y2 );
                                                                                     
       g.drawString ( "Hidden", 10 + 3 * XTAB, 10 + 0 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l2_inputs.data         [ 0 ] [ 0 ] ), 10 + 3 * XTAB, 10 + 1 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l2_inputs.data         [ 1 ] [ 0 ] ), 10 + 3 * XTAB, 10 + 2 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l2_inputs.data         [ 2 ] [ 0 ] ), 10 + 3 * XTAB, 10 + 3 * YTAB + y2 );

       g.drawString ( "Weights", 10 + 4 * XTAB, 10 + 0 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l2_weights.data        [ 0 ] [ 0 ] ), 10 + 4 * XTAB, 10 + 1 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l2_weights.data        [ 1 ] [ 0 ] ), 10 + 4 * XTAB, 10 + 2 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l2_weights.data        [ 2 ] [ 0 ] ), 10 + 4 * XTAB, 10 + 3 * YTAB + y2 );

       g.drawString ( "Weighted Sum", 10 + 5 * XTAB, 10 + 0 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l2_weighted_sums.data  [ 0 ] [ 0 ] ), 10 + 5 * XTAB, 10 + 1 * YTAB + y2 );

       g.drawString ( "Output", 10 + 6 * XTAB, 10 + 0 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( l2_activations.data    [ 0 ] [ 0 ] ), 10 + 6 * XTAB, 10 + 1 * YTAB + y2 );

       g.drawString ( "Output Desired", 10 + 6 * XTAB, 10 + 2 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( output_desired ), 10 + 6 * XTAB, 10 + 3 * YTAB + y2 );

       g.drawString ( "Output Error", 10 + 6 * XTAB, 10 + 4 * YTAB + y2 );
       g.drawString ( decimalFormat.format ( output_error ), 10 + 6 * XTAB, 10 + 5 * YTAB + y2 );

       g.drawString ( "Iterations", 10 + 7 * XTAB,  10 + 0 * YTAB + y2 );
       g.drawString ( "" + total_samples, 10 + 7 * XTAB,  10 + 1 * YTAB + y2 );

       g.drawString ( "Function", 10 + 7 * XTAB,  10 + 2 * YTAB + y2 );
       g.drawString ( function_String, 10 + 7 * XTAB,  10 + 3 * YTAB + y2 );

       g.drawString ( "L2 Gradient", 10 + 0 * XTAB, 0 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l2_local_gradient.data [ 0 ] [ 0 ] ), 10 + 0 * XTAB, 1 * YTAB + y4 );

       g.drawString ( "W2 Deltas", 10 + 1 * XTAB, 0 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l2_weights_delta.data  [ 0 ] [ 0 ] ), 10 + 1 * XTAB, 1 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l2_weights_delta.data  [ 1 ] [ 0 ] ), 10 + 1 * XTAB, 2 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l2_weights_delta.data  [ 2 ] [ 0 ] ), 10 + 1 * XTAB, 3 * YTAB + y4 );

       g.drawString ( "W2 Momentum", 10 + 2 * XTAB, 0 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l2_weights_momentum.data [ 0 ] [ 0 ] ), 10 + 2 * XTAB, 1 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l2_weights_momentum.data [ 1 ] [ 0 ] ), 10 + 2 * XTAB, 2 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l2_weights_momentum.data [ 2 ] [ 0 ] ), 10 + 2 * XTAB, 3 * YTAB + y4 );

       g.drawString ( "Sum W Deltas", 10 + 3 * XTAB, 0 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( sum_weighted_deltas.data [ 0 ] [ 0 ] ), 10 + 3 * XTAB, 1 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( sum_weighted_deltas.data [ 1 ] [ 0 ] ), 10 + 3 * XTAB, 2 * YTAB + y4 );

       g.drawString ( "L1 Gradients", 10 + 4 * XTAB, 0 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_local_gradients.data [ 0 ] [ 0 ] ), 10 + 4 * XTAB, 1 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_local_gradients.data [ 1 ] [ 0 ] ), 10 + 4 * XTAB, 2 * YTAB + y4 );

       g.drawString ( "W1 Deltas", 10 + 5 * XTAB, 0 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_weights_delta.data  [ 0 ] [ 0 ] ), 10 + 5 * XTAB, 1 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_weights_delta.data  [ 1 ] [ 0 ] ), 10 + 5 * XTAB, 2 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_weights_delta.data  [ 2 ] [ 0 ] ), 10 + 5 * XTAB, 3 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_weights_delta.data  [ 0 ] [ 1 ] ), 10 + 5 * XTAB, 4 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_weights_delta.data  [ 1 ] [ 1 ] ), 10 + 5 * XTAB, 5 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_weights_delta.data  [ 2 ] [ 1 ] ), 10 + 5 * XTAB, 6 * YTAB + y4 );

       g.drawString ( "W1 Momentum", 10 + 6 * XTAB, 0 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 0 ] [ 0 ] ), 10 + 6 * XTAB, 1 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 1 ] [ 0 ] ), 10 + 6 * XTAB, 2 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 2 ] [ 0 ] ), 10 + 6 * XTAB, 3 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 0 ] [ 1 ] ), 10 + 6 * XTAB, 4 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 1 ] [ 1 ] ), 10 + 6 * XTAB, 5 * YTAB + y4 );
       g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 2 ] [ 1 ] ), 10 + 6 * XTAB, 6 * YTAB + y4 );
     }

     /**
      * Background training loop for the network shown by this component.
      * Judging by the indexing in this file, it has a 1.0 bias plus two
      * inputs, two sigmoid hidden units, and one sigmoid output.
      * <p>
      * Each pass of the loop is one online (per-sample) training step:
      * <ol>
      * <li>applies the weight deltas computed on the previous pass;</li>
      * <li>draws a random input point, forces the bias input to 1.0, and
      *     records the point in {@code samples} for plotting;</li>
      * <li>runs the forward pass through both sigmoid layers;</li>
      * <li>compares the output with {@code target_function} and
      *     back-propagates the error into new weight deltas, including a
      *     momentum term;</li>
      * <li>records the squared error and, every ITERATIONS_PER_EPOCH
      *     passes, stores that epoch's RMS error in epoch_rms_error.</li>
      * </ol>
      * At the start of each epoch a repaint is requested, at most once per
      * REPAINT_PERIOD milliseconds. The loop runs until {@code pleaseStop}
      * is true, then requests one final repaint. Any exception, including
      * an InterruptedException from Thread.sleep, is printed and ends the
      * loop. {@code runner} is set to null in every case.
      * <p>
      * NOTE(review): pleaseStop and runner are assigned outside this view,
      * presumably by the component's start/stop logic. Confirm that
      * pleaseStop is volatile (or otherwise safely published) so this
      * thread is guaranteed to observe the stop request.
      */
     public void  run ( )
     //////////////////////////////////////////////////////////////////////
     {
       try
       {

       long  lastRepaintTime = 0;

       while ( !pleaseStop )
       {
         // Repaint only on the first pass of an epoch, and no more often
         // than once every REPAINT_PERIOD milliseconds.
         if ( iteration == 0 )
         {
           long  currentTime = System.currentTimeMillis ( );

           if ( currentTime >= lastRepaintTime + REPAINT_PERIOD )
           {
             lastRepaintTime = currentTime;

             repaint ( );
           }
         }

         // Apply the deltas computed at the end of the previous pass.

         l2_weights = l2_weights.add ( l2_weights_delta );

//         l2_weights = l2_weights.clip ( -1.0, 1.0 );

         l1_weights = l1_weights.add ( l1_weights_delta );

//         l1_weights = l1_weights.clip ( -1.0, 1.0 );

         // New random training point. Row 0 is the constant 1.0 bias input.
         // Rows 1-2 are the two real inputs, saved as the x/y coordinates
         // that plotGraph draws.

         l0_activations = l0_activations.randomizeUniform ( 0.0, 1.0 );

         l0_activations.data [ 0 ] [ 0 ] = 1.0;

         samples [ iteration ] [ 0 ] = l0_activations.data [ 1 ] [ 0 ];

         samples [ iteration ] [ 1 ] = l0_activations.data [ 2 ] [ 0 ];

         // Forward pass, hidden layer. l1_weights is displayed elsewhere in
         // this file as [0..2][0..1] (rows = bias + inputs, columns = hidden
         // units), hence the transpose.

         l1_weighted_sums = Matrix.multiply (
           l1_weights.transpose ( ), l0_activations );

         l1_activations = l1_weighted_sums.sigmoid ( );

         // Output-layer input vector: its own 1.0 bias in row 0, followed
         // by the two hidden activations.

         l2_inputs.data [ 0 ] [ 0 ] = 1.0;

         l2_inputs.data [ 1 ] [ 0 ] = l1_activations.data [ 0 ] [ 0 ];

         l2_inputs.data [ 2 ] [ 0 ] = l1_activations.data [ 1 ] [ 0 ];

         // Forward pass, output layer; the output is also kept for plotting.

         l2_weighted_sums = Matrix.multiply (
           l2_weights.transpose ( ), l2_inputs );

         l2_activations = l2_weighted_sums.sigmoid ( );

         samples [ iteration ] [ 2 ] = l2_activations.data [ 0 ] [ 0 ];

         // Error signal: desired output minus actual network output.

         output_desired = target_function ( l0_activations );

         output_error = output_desired - l2_activations.data [ 0 ] [ 0 ];

         // Output-layer local gradient: sigmoid'(weighted sum) * error.
         // Presumably sigmoidDerivative() evaluates the sigmoid's
         // derivative element-wise; confirm in Matrix.

         l2_local_gradient = l2_weighted_sums.sigmoidDerivative ( );

         l2_local_gradient = l2_local_gradient.multiply ( output_error );

// I'm not sure about the transpose below here.
// NOTE(review): the shapes appear consistent. l2_inputs is 3x1 and the
// gradient is 1x1, so the product is 3x1, like l2_weights (displayed as
// [0..2][0]). This assumes Matrix.multiply is an ordinary matrix product.

         // Generalized delta rule with momentum:
         //   delta(t) = learning_rate * input * gradient + momentum
         //   momentum = momentum_constant * delta(t)
         // delta(t) is added to the weights at the top of the next pass.

         l2_weights_delta
           = Matrix.multiply ( l2_inputs, l2_local_gradient.transpose ( ) );

         l2_weights_delta = l2_weights_delta.multiply ( learning_rate );

         l2_weights_delta = l2_weights_delta.add ( l2_weights_momentum );

         l2_weights_momentum
           = l2_weights_delta.multiply ( momentum_constant );

// needs transpose or sum below?
// NOTE(review): this is the error propagated back to each hidden unit.
// Rows 1-2 of l2_weights skip the bias weight, assuming
// submatrix ( 1, 2, 0, 0 ) selects rows 1..2 and column 0 inclusive.
// There is a single output unit, so the sum over downstream units has one
// term. The (2x1)*(1x1) product already gives it, and no transpose is
// needed. The l2_weights used here are the ones from this pass's forward
// computation; the new delta has not been applied yet.

         sum_weighted_deltas = l2_weights.submatrix ( 1, 2, 0, 0 );

         sum_weighted_deltas = Matrix.multiply (
           sum_weighted_deltas, l2_local_gradient );

         // Hidden-layer local gradients: a * (1 - a) * propagated error.
         // Here a * (1 - a) is the logistic-sigmoid derivative written in
         // terms of the activation. This assumes Matrix.sigmoid is the
         // logistic function and multiplyPairwise is element-wise.

         l1_local_gradients = new Matrix (
           l1_activations.rows, l1_activations.cols, 1.0 );

         l1_local_gradients = l1_local_gradients.subtract ( l1_activations );

         l1_local_gradients = Matrix.multiplyPairwise (
           l1_activations, l1_local_gradients );

         l1_local_gradients = Matrix.multiplyPairwise (
           l1_local_gradients, sum_weighted_deltas );

         // Hidden-layer deltas: inputs (3x1) times transposed gradients
         // (1x2) gives 3x2, matching l1_weights. Learning rate and
         // momentum are applied exactly as for the output layer.

         l1_weights_delta = Matrix.multiply (
           l0_activations, l1_local_gradients.transpose ( ) );

         l1_weights_delta = l1_weights_delta.multiply ( learning_rate );

         l1_weights_delta = l1_weights_delta.add ( l1_weights_momentum );

         l1_weights_momentum
           = l1_weights_delta.multiply ( momentum_constant );

         squared_errors [ iteration ] = output_error * output_error;

         iteration++;

         total_samples++;

         // End of epoch (iteration is reset to 0 here, so this fires when
         // iteration reaches ITERATIONS_PER_EPOCH). It stores
         // sqrt(mean(squared_errors)) in epoch_rms_error [ epoch - 1 ].
         // NOTE(review): epoch wraps to 1 as soon as it reaches epochs_max.
         // Slot epochs_max - 1 is therefore never written, and slot 0 is
         // overwritten after the wrap. epoch_rms_error is allocated outside
         // this view; check its length to see whether `>` was intended.
         if ( iteration % ITERATIONS_PER_EPOCH == 0 )
         {
           epoch++;

           if ( epoch >= epochs_max ) epoch = 1;

           iteration = 0;

           epoch_rms_error [ epoch - 1 ] = 0.0;

           for ( int index_iteration = 0;
                     index_iteration < ITERATIONS_PER_EPOCH;
                     index_iteration++ )
           {
             epoch_rms_error [ epoch - 1 ]
               += squared_errors [ index_iteration ];
           }

           epoch_rms_error [ epoch - 1 ] /= ( double ) ITERATIONS_PER_EPOCH;

           epoch_rms_error [ epoch - 1 ]
             = Math.sqrt ( epoch_rms_error [ epoch - 1 ] );
         }

         // NOTE(review): sleep(0) presumably gives other threads (such as
         // the painter) a chance to run; its exact effect is platform
         // dependent. It can throw InterruptedException, handled below.
         Thread.sleep ( 0 );
       }

       // Show the final state once a stop has been requested.
       repaint ( );

       }
       catch ( Exception  ex )
       {
         // Any failure, including an interrupt, is printed and ends
         // training. NOTE(review): the thread's interrupt status is not
         // restored here.
         ex.printStackTrace ( );
       }
       finally
       {
         // Clear the thread reference on every exit path.
         runner = null;
       }
     }

     //////////////////////////////////////////////////////////////////////
     // private methods
     //////////////////////////////////////////////////////////////////////

     /**
      * Paints the per-epoch RMS error history into the rectangle {@code r}.
      * <p>
      * Drawing steps:
      * <ol>
      * <li>Clear the area to black.</li>
      * <li>Outline its upper half in white. The bottom edge of that outline
      *     presumably marks y = 0.5 if PlotLib maps y = 0 to the bottom.</li>
      * <li>Plot one red point per epoch from 1 up to the current value of
      *     the {@code epoch} field. Entry {@code n - 1} of {@code epochs}
      *     is drawn at x = n, with x scaled over [1, epoch] and y over
      *     [0, 1].</li>
      * <li>Draw a white border around the whole area.</li>
      * <li>Restore the component's foreground color.</li>
      * </ol>
      * NOTE(review): when epoch == 1 the x range is [1, 1]. Whether
      * PlotLib.xy tolerates a zero-width range is not visible here.
      *
      * @param  r       area of the graphics context to draw into
      * @param  g       graphics context to draw with
      * @param  epochs  per-epoch RMS errors, indexed by epoch number - 1
      */
     private void  plot_epochs (
       Rectangle  r, Graphics  g, double [ ]  epochs )
     //////////////////////////////////////////////////////////////////////
     {
       // Black background, then a white outline of the upper half.
       g.setColor ( java.awt.Color.black );
       g.fillRect ( r.x, r.y, r.width, r.height );
       g.setColor ( java.awt.Color.white );
       g.drawRect ( r.x, r.y, r.width, r.height / 2 );

       // One point per epoch. The epoch field is deliberately re-read on
       // every pass, both as the loop bound and as the x-range maximum.
       int  epochNumber = 1;
       while ( epochNumber <= epoch )
       {
         final double  x = ( double ) epochNumber;
         final double  y = epochs [ epochNumber - 1 ];
         PlotLib.xy ( java.awt.Color.red, x, y,
           r, g, 1.0, ( double ) epoch, 0.0, 1.0, OVAL_SIZE, true );
         epochNumber++;
       }

       // Full white border on top, then restore the drawing color.
       g.setColor ( java.awt.Color.white );
       g.drawRect ( r.x, r.y, r.width, r.height );
       g.setColor ( this.getForeground ( ) );
     }

     /**
      * Paints the current epoch's training samples into the area {@code rs}.
      * <p>
      * The area is cleared to black. Each of the first ITERATIONS_PER_EPOCH
      * rows of {@code samples} is then plotted at (row[0], row[1]) on a
      * [0, 1] x [0, 1] scale. Run stores the two inputs in columns 0 and 1
      * and the network output in column 2. A point is green when the output
      * is {@code >= 0.5} and red otherwise, which includes NaN. Finally a
      * white border is drawn and the foreground color is restored.
      * <p>
      * NOTE(review): rows not yet written during the very first epoch are
      * plotted with whatever values the array was initialized with.
      *
      * @param  g        graphics context to draw with
      * @param  samples  rows of {input1, input2, output}; must have at least
      *                  ITERATIONS_PER_EPOCH rows
      */
     private void  plotGraph ( Graphics  g, double [ ] [ ]  samples )
     //////////////////////////////////////////////////////////////////////
     {
       g.setColor ( Color.black );
       g.fillRect ( rs.x, rs.y, rs.width, rs.height );

       for ( int  row = 0; row < ITERATIONS_PER_EPOCH; row++ )
       {
         final double [ ]  sample = samples [ row ];

         // Color by the network's output, thresholded at 0.5.
         final Color  dotColor;
         if ( sample [ 2 ] >= 0.5 )
         {
           dotColor = Color.green;
         }
         else
         {
           dotColor = Color.red;
         }

         PlotLib.xy ( dotColor, sample [ 0 ], sample [ 1 ],
           rs, g, 0.0, 1.0, 0.0, 1.0, OVAL_SIZE );
       }

       g.setColor ( Color.white );
       g.drawRect ( rs.x, rs.y, rs.width, rs.height );
       g.setColor ( this.getForeground ( ) );
     }

     /**
      * Replaces both weight matrices with fresh random values.
      * <p>
      * Each matrix is reassigned to the result of
      * {@code randomizeUniform ( -1.0, 1.0 )}, presumably uniform over that
      * range. The layer-1 matrix is drawn before the layer-2 matrix. That
      * order matters if both draw from a shared random generator.
      */
     private void  randomize_weights ( )
     //////////////////////////////////////////////////////////////////////
     {
       final double  low  = -1.0;
       final double  high =  1.0;

       this.l1_weights = this.l1_weights.randomizeUniform ( low, high );
       this.l2_weights = this.l2_weights.randomizeUniform ( low, high );
     }

     /**
      * Evaluates the selected two-input Boolean function on an input vector.
      * <p>
      * Rows 1 and 2 of {@code inputs} (row 0 is the bias) are each rounded
      * with Math.round, so values of 0.5 and above count as 1. For inputs in
      * [0, 1] this yields bits a and b. Together they select bit
      * {@code 2 * b + a} of {@code function_selected}, which therefore acts
      * as a 4-bit truth table. For example, a value of 6 (bits 1 and 2)
      * encodes XOR.
      *
      * @param  inputs  column vector whose rows 1 and 2 hold the two inputs
      * @return 1.0 if the selected truth-table bit is set, otherwise 0.0
      */
     private double  target_function ( Matrix  inputs )
     //////////////////////////////////////////////////////////////////////
     {
       final long  a = Math.round ( inputs.data [ 1 ] [ 0 ] );

       final long  b = Math.round ( inputs.data [ 2 ] [ 0 ] );

       // Truth-table row: bit 0 comes from a, bit 1 from b.
       final long  bit_num = 2 * b + a;

       // Use 1L, not 1. An int shift is computed in 32 bits, with the
       // distance masked to 5 bits, and only then widened to long. That
       // produces a wrong mask once bit_num reaches 31.
       final long  mask = 1L << bit_num;

       final long  masked = function_selected & mask;

       return ( masked == mask ) ? 1.0 : 0.0;
     }

     //////////////////////////////////////////////////////////////////////
     //////////////////////////////////////////////////////////////////////
     }

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -