📄 backpropxor.java
字号:
plotGraph ( g, samples );
g.drawString ( "Inputs", 10 + 0 * XTAB, 10 + 0 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l0_activations.data [ 0 ] [ 0 ] ), 10 + 0 * XTAB, 10 + 1 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l0_activations.data [ 1 ] [ 0 ] ), 10 + 0 * XTAB, 10 + 2 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l0_activations.data [ 2 ] [ 0 ] ), 10 + 0 * XTAB, 10 + 3 * YTAB + y2 );
g.drawString ( "Weights", 10 + 1 * XTAB, 10 + 0 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l1_weights.data [ 0 ] [ 0 ] ), 10 + 1 * XTAB, 10 + 1 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l1_weights.data [ 1 ] [ 0 ] ), 10 + 1 * XTAB, 10 + 2 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l1_weights.data [ 2 ] [ 0 ] ), 10 + 1 * XTAB, 10 + 3 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l1_weights.data [ 0 ] [ 1 ] ), 10 + 1 * XTAB, 10 + 4 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l1_weights.data [ 1 ] [ 1 ] ), 10 + 1 * XTAB, 10 + 5 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l1_weights.data [ 2 ] [ 1 ] ), 10 + 1 * XTAB, 10 + 6 * YTAB + y2 );
g.drawString ( "Weighted Sums", 10 + 2 * XTAB, 10 + 0 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l1_weighted_sums.data [ 0 ] [ 0 ] ), 10 + 2 * XTAB, 10 + 1 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l1_weighted_sums.data [ 1 ] [ 0 ] ), 10 + 2 * XTAB, 10 + 2 * YTAB + y2 );
g.drawString ( "Hidden", 10 + 3 * XTAB, 10 + 0 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l2_inputs.data [ 0 ] [ 0 ] ), 10 + 3 * XTAB, 10 + 1 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l2_inputs.data [ 1 ] [ 0 ] ), 10 + 3 * XTAB, 10 + 2 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l2_inputs.data [ 2 ] [ 0 ] ), 10 + 3 * XTAB, 10 + 3 * YTAB + y2 );
g.drawString ( "Weights", 10 + 4 * XTAB, 10 + 0 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l2_weights.data [ 0 ] [ 0 ] ), 10 + 4 * XTAB, 10 + 1 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l2_weights.data [ 1 ] [ 0 ] ), 10 + 4 * XTAB, 10 + 2 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l2_weights.data [ 2 ] [ 0 ] ), 10 + 4 * XTAB, 10 + 3 * YTAB + y2 );
g.drawString ( "Weighted Sum", 10 + 5 * XTAB, 10 + 0 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l2_weighted_sums.data [ 0 ] [ 0 ] ), 10 + 5 * XTAB, 10 + 1 * YTAB + y2 );
g.drawString ( "Output", 10 + 6 * XTAB, 10 + 0 * YTAB + y2 );
g.drawString ( decimalFormat.format ( l2_activations.data [ 0 ] [ 0 ] ), 10 + 6 * XTAB, 10 + 1 * YTAB + y2 );
g.drawString ( "Output Desired", 10 + 6 * XTAB, 10 + 2 * YTAB + y2 );
g.drawString ( decimalFormat.format ( output_desired ), 10 + 6 * XTAB, 10 + 3 * YTAB + y2 );
g.drawString ( "Output Error", 10 + 6 * XTAB, 10 + 4 * YTAB + y2 );
g.drawString ( decimalFormat.format ( output_error ), 10 + 6 * XTAB, 10 + 5 * YTAB + y2 );
g.drawString ( "Iterations", 10 + 7 * XTAB, 10 + 0 * YTAB + y2 );
g.drawString ( "" + total_samples, 10 + 7 * XTAB, 10 + 1 * YTAB + y2 );
g.drawString ( "Function", 10 + 7 * XTAB, 10 + 2 * YTAB + y2 );
g.drawString ( function_String, 10 + 7 * XTAB, 10 + 3 * YTAB + y2 );
g.drawString ( "L2 Gradient", 10 + 0 * XTAB, 0 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l2_local_gradient.data [ 0 ] [ 0 ] ), 10 + 0 * XTAB, 1 * YTAB + y4 );
g.drawString ( "W2 Deltas", 10 + 1 * XTAB, 0 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l2_weights_delta.data [ 0 ] [ 0 ] ), 10 + 1 * XTAB, 1 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l2_weights_delta.data [ 1 ] [ 0 ] ), 10 + 1 * XTAB, 2 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l2_weights_delta.data [ 2 ] [ 0 ] ), 10 + 1 * XTAB, 3 * YTAB + y4 );
g.drawString ( "W2 Momentum", 10 + 2 * XTAB, 0 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l2_weights_momentum.data [ 0 ] [ 0 ] ), 10 + 2 * XTAB, 1 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l2_weights_momentum.data [ 1 ] [ 0 ] ), 10 + 2 * XTAB, 2 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l2_weights_momentum.data [ 2 ] [ 0 ] ), 10 + 2 * XTAB, 3 * YTAB + y4 );
g.drawString ( "Sum W Deltas", 10 + 3 * XTAB, 0 * YTAB + y4 );
g.drawString ( decimalFormat.format ( sum_weighted_deltas.data [ 0 ] [ 0 ] ), 10 + 3 * XTAB, 1 * YTAB + y4 );
g.drawString ( decimalFormat.format ( sum_weighted_deltas.data [ 1 ] [ 0 ] ), 10 + 3 * XTAB, 2 * YTAB + y4 );
g.drawString ( "L1 Gradients", 10 + 4 * XTAB, 0 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_local_gradients.data [ 0 ] [ 0 ] ), 10 + 4 * XTAB, 1 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_local_gradients.data [ 1 ] [ 0 ] ), 10 + 4 * XTAB, 2 * YTAB + y4 );
g.drawString ( "W1 Deltas", 10 + 5 * XTAB, 0 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_weights_delta.data [ 0 ] [ 0 ] ), 10 + 5 * XTAB, 1 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_weights_delta.data [ 1 ] [ 0 ] ), 10 + 5 * XTAB, 2 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_weights_delta.data [ 2 ] [ 0 ] ), 10 + 5 * XTAB, 3 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_weights_delta.data [ 0 ] [ 1 ] ), 10 + 5 * XTAB, 4 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_weights_delta.data [ 1 ] [ 1 ] ), 10 + 5 * XTAB, 5 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_weights_delta.data [ 2 ] [ 1 ] ), 10 + 5 * XTAB, 6 * YTAB + y4 );
g.drawString ( "W1 Momentum", 10 + 6 * XTAB, 0 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 0 ] [ 0 ] ), 10 + 6 * XTAB, 1 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 1 ] [ 0 ] ), 10 + 6 * XTAB, 2 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 2 ] [ 0 ] ), 10 + 6 * XTAB, 3 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 0 ] [ 1 ] ), 10 + 6 * XTAB, 4 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 1 ] [ 1 ] ), 10 + 6 * XTAB, 5 * YTAB + y4 );
g.drawString ( decimalFormat.format ( l1_weights_momentum.data [ 2 ] [ 1 ] ), 10 + 6 * XTAB, 6 * YTAB + y4 );
}
/**
 * Training loop for the 2-2-1 network, executed until {@code pleaseStop}
 * is set.
 *
 * Each pass:
 * <ol>
 * <li>applies the weight deltas computed on the previous pass;</li>
 * <li>draws random inputs and runs the forward pass;</li>
 * <li>computes the error against {@code target_function};</li>
 * <li>back-propagates to produce the deltas applied on the next pass;</li>
 * <li>records the sample and squared error.</li>
 * </ol>
 * Every {@code ITERATIONS_PER_EPOCH} passes, the epoch's RMS error is
 * stored in {@code epoch_rms_error}. A repaint is requested at most once
 * per {@code REPAINT_PERIOD} ms, and only at the start of an epoch.
 * On exit (normal, interrupted or failed) {@code runner} is cleared.
 */
public void run ( )
//////////////////////////////////////////////////////////////////////
{
  try
  {
    long lastRepaintTime = 0;
    while ( !pleaseStop )
    {
      // Throttle repaints: only at iteration 0 and only after
      // REPAINT_PERIOD ms have elapsed since the last request.
      if ( iteration == 0 )
      {
        long currentTime = System.currentTimeMillis ( );
        if ( currentTime >= lastRepaintTime + REPAINT_PERIOD )
        {
          lastRepaintTime = currentTime;
          repaint ( );
        }
      }

      // Apply the deltas computed at the end of the previous pass.
      l2_weights = l2_weights.add ( l2_weights_delta );
      // l2_weights = l2_weights.clip ( -1.0, 1.0 );
      l1_weights = l1_weights.add ( l1_weights_delta );
      // l1_weights = l1_weights.clip ( -1.0, 1.0 );

      // Forward pass. Row 0 of both input vectors is a constant 1.0
      // (bias) input; rows 1..2 carry the actual values.
      l0_activations = l0_activations.randomizeUniform ( 0.0, 1.0 );
      l0_activations.data [ 0 ] [ 0 ] = 1.0;
      samples [ iteration ] [ 0 ] = l0_activations.data [ 1 ] [ 0 ];
      samples [ iteration ] [ 1 ] = l0_activations.data [ 2 ] [ 0 ];
      l1_weighted_sums = Matrix.multiply (
        l1_weights.transpose ( ), l0_activations );
      l1_activations = l1_weighted_sums.sigmoid ( );
      l2_inputs.data [ 0 ] [ 0 ] = 1.0;
      l2_inputs.data [ 1 ] [ 0 ] = l1_activations.data [ 0 ] [ 0 ];
      l2_inputs.data [ 2 ] [ 0 ] = l1_activations.data [ 1 ] [ 0 ];
      l2_weighted_sums = Matrix.multiply (
        l2_weights.transpose ( ), l2_inputs );
      l2_activations = l2_weighted_sums.sigmoid ( );
      samples [ iteration ] [ 2 ] = l2_activations.data [ 0 ] [ 0 ];

      // Error against the selected Boolean target function.
      output_desired = target_function ( l0_activations );
      output_error = output_desired - l2_activations.data [ 0 ] [ 0 ];

      // Output layer: local gradient = (presumably) sigmoid'(weighted sum)
      // scaled by the output error.
      l2_local_gradient = l2_weighted_sums.sigmoidDerivative ( );
      l2_local_gradient = l2_local_gradient.multiply ( output_error );
      // NOTE(review): assuming l2_weights is 3x1 (only column 0 is ever
      // read), the gradient is 1x1. l2_inputs (3x1) times its transpose
      // then yields a 3x1 delta matching l2_weights, so the transpose
      // is a no-op here.
      l2_weights_delta
        = Matrix.multiply ( l2_inputs, l2_local_gradient.transpose ( ) );
      // Momentum: delta = learning_rate * raw + momentum_constant * prev.
      l2_weights_delta = l2_weights_delta.multiply ( learning_rate );
      l2_weights_delta = l2_weights_delta.add ( l2_weights_momentum );
      l2_weights_momentum
        = l2_weights_delta.multiply ( momentum_constant );

      // Back-propagate through the output weights, skipping the bias
      // weight in row 0. NOTE(review): submatrix(1, 2, 0, 0) is assumed
      // to mean rows 1..2 of column 0 (2x1), which multiplied by the 1x1
      // gradient gives one weighted delta per hidden unit.
      sum_weighted_deltas = l2_weights.submatrix ( 1, 2, 0, 0 );
      sum_weighted_deltas = Matrix.multiply (
        sum_weighted_deltas, l2_local_gradient );

      // Hidden layer: a * (1 - a) * weighted delta, element-wise.
      // Presumably the Matrix constructor fills with 1.0, giving (1 - a)
      // after the subtraction.
      l1_local_gradients = new Matrix (
        l1_activations.rows, l1_activations.cols, 1.0 );
      l1_local_gradients = l1_local_gradients.subtract ( l1_activations );
      l1_local_gradients = Matrix.multiplyPairwise (
        l1_activations, l1_local_gradients );
      l1_local_gradients = Matrix.multiplyPairwise (
        l1_local_gradients, sum_weighted_deltas );
      // l0_activations (3x1) times the transposed 2x1 gradients gives a
      // 3x2 delta, matching the [0..2][0..1] indices used for l1_weights.
      l1_weights_delta = Matrix.multiply (
        l0_activations, l1_local_gradients.transpose ( ) );
      l1_weights_delta = l1_weights_delta.multiply ( learning_rate );
      l1_weights_delta = l1_weights_delta.add ( l1_weights_momentum );
      l1_weights_momentum
        = l1_weights_delta.multiply ( momentum_constant );

      squared_errors [ iteration ] = output_error * output_error;
      iteration++;
      total_samples++;
      if ( iteration % ITERATIONS_PER_EPOCH == 0 )
      {
        // Compute the epoch's RMS error in a local and publish it with a
        // single store. Accumulating directly into epoch_rms_error (after
        // epoch has been advanced) let the plotting code see 0.0 or a
        // partial sum for the newest epoch.
        double sum_squared_errors = 0.0;
        for ( int index_iteration = 0;
          index_iteration < ITERATIONS_PER_EPOCH;
          index_iteration++ )
        {
          sum_squared_errors += squared_errors [ index_iteration ];
        }
        double rms_error = Math.sqrt (
          sum_squared_errors / ( double ) ITERATIONS_PER_EPOCH );
        epoch++;
        // epoch is 1-based; once it reaches epochs_max it wraps back to 1
        // and the error history is overwritten from slot 0.
        if ( epoch >= epochs_max ) epoch = 1;
        iteration = 0;
        epoch_rms_error [ epoch - 1 ] = rms_error;
      }
      Thread.sleep ( 0 );
    }
    repaint ( );
  }
  catch ( InterruptedException ex )
  {
    // Interrupted (e.g. during Thread.sleep): treat it as a stop request
    // and restore the interrupt status instead of reporting a failure.
    Thread.currentThread ( ).interrupt ( );
  }
  catch ( Exception ex )
  {
    ex.printStackTrace ( );
  }
  finally
  {
    runner = null;
  }
}
//////////////////////////////////////////////////////////////////////
// private methods
//////////////////////////////////////////////////////////////////////
/**
 * Draws the per-epoch error history into {@code r}.
 *
 * The region is cleared to black and outlined in white. One red point is
 * plotted for every epoch from 1 up to the current {@code epoch}, using
 * {@code epochs[k - 1]} as the value for epoch {@code k}. A white
 * rectangle over the top half of the region is also outlined, presumably
 * as a 0.5 guide line. PlotLib's axis ranges are passed as x in
 * [1, epoch] and y in [0, 1] (assumed argument order:
 * xMin, xMax, yMin, yMax).
 *
 * NOTE(review): with epoch == 1 the x range collapses to [1, 1]; confirm
 * PlotLib.xy tolerates a zero-width range.
 *
 * @param r      target region
 * @param g      graphics context
 * @param epochs per-epoch values, indexed from 0
 */
private void plot_epochs (
  Rectangle r, Graphics g, double [ ] epochs )
//////////////////////////////////////////////////////////////////////
{
  // Background plus the half-height guide outline.
  g.setColor ( java.awt.Color.black );
  g.fillRect ( r.x, r.y, r.width, r.height );
  g.setColor ( java.awt.Color.white );
  g.drawRect ( r.x, r.y, r.width, r.height / 2 );

  // Walk the history with a 0-based slot; epoch is re-read on every pass.
  for ( int slot = 0; slot < epoch; slot++ )
  {
    final double epochNumber = slot + 1;
    PlotLib.xy ( java.awt.Color.red,
      epochNumber, epochs [ slot ],
      r, g, 1.0, ( double ) epoch, 0.0, 1.0, OVAL_SIZE, true );
  }

  // Full border, then hand the component's foreground color back.
  g.setColor ( java.awt.Color.white );
  g.drawRect ( r.x, r.y, r.width, r.height );
  g.setColor ( this.getForeground ( ) );
}
/**
 * Scatter-plots the sample buffer into the {@code rs} region.
 *
 * Each of the first {@code ITERATIONS_PER_EPOCH} rows of {@code samples}
 * holds the two network inputs (columns 0 and 1) and the network output
 * (column 2), as filled in by the training loop. A point is drawn at
 * (input 1, input 2): green when the output is at least 0.5, otherwise
 * red (including NaN). The plot range is presumably the unit square.
 *
 * @param g       graphics context
 * @param samples sample buffer, one row per training iteration
 */
private void plotGraph ( Graphics g, double [ ] [ ] samples )
//////////////////////////////////////////////////////////////////////
{
  g.setColor ( Color.black );
  g.fillRect ( rs.x, rs.y, rs.width, rs.height );

  int row = 0;
  while ( row < ITERATIONS_PER_EPOCH )
  {
    final double [ ] sample = samples [ row ];
    final Color dotColor = sample [ 2 ] >= 0.5 ? Color.green : Color.red;
    PlotLib.xy ( dotColor, sample [ 0 ], sample [ 1 ],
      rs, g, 0.0, 1.0, 0.0, 1.0, OVAL_SIZE );
    row++;
  }

  // Border, then restore the component's foreground color.
  g.setColor ( Color.white );
  g.drawRect ( rs.x, rs.y, rs.width, rs.height );
  g.setColor ( this.getForeground ( ) );
}
/**
 * Replaces both weight matrices with fresh random values.
 *
 * Values are presumably drawn uniformly from [-1.0, 1.0] (see
 * Matrix.randomizeUniform). The l1 weights are drawn before the l2
 * weights.
 */
private void randomize_weights ( )
//////////////////////////////////////////////////////////////////////
{
  final double lowest = -1.0;
  final double highest = 1.0;
  this.l1_weights = this.l1_weights.randomizeUniform ( lowest, highest );
  this.l2_weights = this.l2_weights.randomizeUniform ( lowest, highest );
}
/**
 * Evaluates the currently selected two-input Boolean function.
 *
 * Rows 1 and 2 of column 0 of {@code inputs} are rounded to the bits
 * {@code a} and {@code b}. {@code function_selected} is read as a truth
 * table: bit {@code 2 * b + a} holds the desired output for that input
 * combination.
 *
 * @param inputs input column vector; only rows 1 and 2 are read
 * @return 1.0 if the selected truth-table bit is set, otherwise 0.0
 */
private double target_function ( Matrix inputs )
//////////////////////////////////////////////////////////////////////
{
  long a = Math.round ( inputs.data [ 1 ] [ 0 ] );
  long b = Math.round ( inputs.data [ 2 ] [ 0 ] );
  long bit_num = 2 * b + a;
  // Use a long literal: "1 << bit_num" is an int shift (distance taken
  // mod 32) that is only widened afterwards, so it yields a wrong mask
  // for bit_num >= 31.
  long mask = 1L << bit_num;
  // mask has exactly one bit set, so testing for non-zero is equivalent
  // to comparing the masked value against the mask.
  return ( ( function_selected & mask ) != 0 ) ? 1.0 : 0.0;
}
//////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -