📄 ga_trainer_xor.cs
using System;
using System.Collections.Generic;
using System.Text;
namespace GA_ANN_XOR
{
#region NN_Trainer_XOR CLASS
/// <summary>
/// Provides a GA trainer for a
/// <see cref="NeuralNetwork">NeuralNetwork</see> class
/// with 2 inputs, 2 hidden neurons, and 1 output, which is trying
/// to approximate the XOR problem
/// </summary>
public class GA_Trainer_XOR
{
#region Instance fields
private Random gen = new Random(5);
private int training_times = 10000;
private double[,] train_set =
{{0, 0},
{0, 1},
{1, 0},
{1, 1}};
//population size
private int POPULATION = 15;
//ANN's
private NeuralNetwork[] networks;
//Mutation
private double MUTATION = 0.5;
//Recombination
private double RECOMBINE = 0.4;
//flag to detect when we have found a good ANN
private bool foundGoodANN = false;
//current training loop iteration
private int trainLoop = 0;
//best configuration index
private int bestConfiguration = -1;
//acceptable overall Neural Network error
private double acceptableNNError = 0.1;
//events for gui, generated by the GA trainer (sketches of the raiser methods follow this region)
public delegate void GAChangeHandler(Object sender, TrainerEventArgs te);
public event GAChangeHandler GAChange;
public event EventHandler GATrainingDone;
//events for gui, generated by the NeuralNetwork, but propagated up to the gui
//by the GA trainer; that's why this event is here. The gui knows nothing
//about the array of NeuralNetworks, so the event must come through the trainer
public delegate void ChangeHandler(Object sender, NeuralNetworkEventArgs nne);
public event ChangeHandler NNChange;
#endregion
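#region Event raisers (illustrative sketch)
//NOTE: the On_GAChange and On_GATrainingDone raiser methods called from
//doTraining below are not part of this excerpt. The following is a minimal
//sketch of how they might look, assuming the standard .NET event-raising
//pattern; the NNChange event would be forwarded to the gui in the same way
//whenever a NeuralNetwork reports a change.
protected virtual void On_GAChange(TrainerEventArgs te)
{
    //only raise the event if the gui has subscribed to it
    if (GAChange != null)
    {
        GAChange(this, te);
    }
}
protected virtual void On_GATrainingDone(EventArgs e)
{
    //tell any subscriber (the gui) that the training loop has finished
    if (GATrainingDone != null)
    {
        GATrainingDone(this, e);
    }
}
#endregion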
#region Public Properties/Methods
/// <summary>
/// Performs a microbial GA (best of last breeding cycle stays in population)
/// on an array of <see cref="NeuralNetwork"> NeuralNetworks</see> in an attempt
/// to find a solution to the XOR logic problem. The training presents the entire
/// training set to a random pair of <see cref="NeuralNetwork"> NeuralNetworks,</see>
/// and evaluates which one does best. The winner's genes, plus some mutation, are used
/// to reshape the loser's genes, in the hope that the new population will keep moving
/// towards a closer solution.
/// </summary>
/// <param name="training_times">the number of times to carry out the
/// training loop</param>
/// <returns>The best <see cref="NeuralNetwork"> NeuralNetwork </see>
/// configuration found</returns>
public NeuralNetwork doTraining(int training_times)
{
int a = 0;
int b = 0;
int WINNER = 0;
int LOSER = 0;
#region Training
//loop for the trainingPeriod
for (trainLoop = 0; trainLoop < training_times; trainLoop++)
{
//fire training loop event
TrainerEventArgs te = new TrainerEventArgs(trainLoop);
On_GAChange(te);
NeuralNetwork.isInTraining = true;
//if the previous evaluation cycle found a good ANN configuration,
//quit the training cycle; otherwise, let the breeding continue
if (foundGoodANN)
{
break;
}
//pick 2 ANN's at random, GA - SELECTION
a = (int)(gen.NextDouble() * POPULATION);
b = (int)(gen.NextDouble() * POPULATION);
//work out which was the WINNER and LOSER, GA - EVALUATION (evaluate is sketched after this method)
if (evaluate(a) < evaluate(b))
{
WINNER = a;
LOSER = b;
}
else
{
WINNER = b;
LOSER = a;
}
//get the current values of the ANN weights
double[,] WINNER_i_to_h_wts = networks[WINNER].InputToHiddenWeights;
double[,] LOSER_i_to_h_wts = networks[LOSER].InputToHiddenWeights;
double[,] WINNER_h_to_o_wts = networks[WINNER].HiddenToOutputWeights;
double[,] LOSER_h_to_o_wts = networks[LOSER].HiddenToOutputWeights;
//i_to_h_wts RECOMBINATION LOOP
for (int k = 0; k < networks[WINNER].NumberOfInputs + 1; k++)
{
for (int l = 0; l < networks[WINNER].NumberOfHidden; l++)
{
//take genes from the winner at random for the i_to_h_wts weights
if (gen.NextDouble() < RECOMBINE)
{
// copy the winner's input-to-hidden weight into the loser's genotype
LOSER_i_to_h_wts[k,l] = WINNER_i_to_h_wts[k,l];
}
}
}
//h_to_o_wts RECOMBINATION LOOP
for (int k = 0; k < networks[WINNER].NumberOfHidden + 1; k++)
{
for (int l = 0; l < networks[WINNER].NumberOfOutputs; l++)
{
//take genes from the winner at random for the h_to_o_wts weights
if (gen.NextDouble() < RECOMBINE)
{
// copy the winner's hidden-to-output weight into the loser's genotype
LOSER_h_to_o_wts[k,l] = WINNER_h_to_o_wts[k,l];
}
}
}
//i_to_h_wts MUTATION LOOP
for (int k = 0; k < networks[WINNER].NumberOfInputs + 1; k++)
{
for (int l = 0; l < networks[WINNER].NumberOfHidden; l++)
{
//add some mutation randomly
if (gen.NextDouble() < MUTATION)
{
LOSER_i_to_h_wts[k,l] += ((gen.NextDouble() * 0.2) - 0.1);
}
}
}
//h_to_o_wts MUTATION LOOP
for (int k = 0; k < networks[WINNER].NumberOfHidden + 1; k++)
{
for (int l = 0; l < networks[WINNER].NumberOfOutputs; l++)
{
//add some mutation randomly
if (gen.NextDouble() < MUTATION)
{
LOSER_h_to_o_wts[k,l] += ((gen.NextDouble() * 0.2) - 0.1);
}
}
}
//update the loser's i_to_h_wts genotype
networks[LOSER].InputToHiddenWeights = LOSER_i_to_h_wts;
//update the loser's h_to_o_wts genotype
networks[LOSER].HiddenToOutputWeights = LOSER_h_to_o_wts;
}
#endregion
//AT THIS POINT IT'S EITHER THE END OF TRAINING OR WE HAVE
//FOUND AN ACCEPTABLE ANN, WHOSE ERROR IS BELOW THE ACCEPTABLE VALUE
//tell gui that training is now done
On_GATrainingDone(new EventArgs());
NeuralNetwork.isInTraining = false;
//check whether a best configuration was found; we may not have done
//enough training to find a good NeuralNetwork configuration, in which case
//we simply have to return the last WINNER
if (bestConfiguration == -1)
{
bestConfiguration = WINNER;
}
//return the best Neural network
return networks[bestConfiguration];
}
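//NOTE: the evaluate method used for GA - EVALUATION above is not included in
//this excerpt. The sketch below shows one plausible implementation, assuming
//that evaluate presents every training pattern to the chosen network, sums the
//absolute error against the XOR targets (lower is fitter), and records the
//configuration when the error drops below acceptableNNError. The
//networks[index].Outputs property used here is also an assumption; it does
//not appear in this listing.
private double evaluate(int index)
{
    double totalError = 0.0;
    //present each training pattern to the candidate network
    for (int i = 0; i <= train_set.GetUpperBound(0); i++)
    {
        double[] inputs = getTrainSet(i);
        //forward the inputs through the candidate network
        forwardWeights(index, inputs);
        double[] targets = getTargetValues(inputs);
        //accumulate the absolute error between actual and target outputs
        for (int j = 0; j < targets.Length; j++)
        {
            totalError += Math.Abs(networks[index].Outputs[j] - targets[j]);
        }
    }
    //if this configuration is good enough, remember it and flag the
    //training loop to stop breeding
    if (totalError < acceptableNNError)
    {
        foundGoodANN = true;
        bestConfiguration = index;
    }
    return totalError;
}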
/// <summary>
/// Is called after the initial training is completed.
/// Simply presents 1 complete pass of the training set to
/// the trained network, which should hopefully get it pretty
/// much correct now that it is trained
/// </summary>
public void doActualRun()
{
//loop through the entire training set
for (int i = 0; i <= train_set.GetUpperBound(0); i++)
{
//forward these new input values through the trained network
//(forwardWeights, getTrainSet and getTargetValues are sketched after this method)
forwardWeights(bestConfiguration, getTrainSet(i));
double[] targetValues = getTargetValues(getTrainSet(i));
}
}
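//NOTE: the getTrainSet, getTargetValues and forwardWeights helpers used above
//do not appear in this excerpt. The sketches below show the assumed behaviour:
//getTrainSet copies one row of train_set, getTargetValues derives the single
//XOR target from the two inputs, and forwardWeights runs a forward pass
//through the chosen network (the Inputs property and pass() method on
//NeuralNetwork are assumptions, not shown in this listing).
private double[] getTrainSet(int row)
{
    //copy the two inputs of the requested training pattern
    return new double[] { train_set[row, 0], train_set[row, 1] };
}
private double[] getTargetValues(double[] inputs)
{
    //XOR target: 1 when exactly one of the two inputs is 1
    double target = (inputs[0] != inputs[1]) ? 1.0 : 0.0;
    return new double[] { target };
}
private void forwardWeights(int index, double[] inputs)
{
    //present the inputs to the chosen network and run a forward pass
    networks[index].Inputs = inputs;
    networks[index].pass();
}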
#endregion
#region Constructor
/// <summary>
/// Constructs a new GA_Trainer_XOR. The constructor creates
/// the population of <see cref="NeuralNetwork">NeuralNetworks</see>