📄 main form's logic.cs
字号:
using System;
using System.IO;
using System.IO.IsolatedStorage;
using System.Text;
using XOR_ANN.ANN;
using XOR_ANN.DataStructures;
using System.Collections;
namespace XOR_ANN.GUI
{
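/// <summary>
/// Logic behind the main form: loads sample data, trains the network,
/// tests the trained weights and formats samples and results for display.
/// </summary>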
public class MainFormLogic
{
private XOR_ANN.DataStructures.ANN_Data ann_Data;
/// <summary>
/// Trains the network on the currently held data using the global learning
/// parameters and reports how long the training call took.
/// </summary>
/// <param name="pEpcohs">Receives the epoch count reported by the execution engine.</param>
/// <param name="pSeconds">Receives the elapsed wall-clock time of the training call, in whole seconds.</param>
/// <returns>The success flag returned by the execution engine's <c>train</c> call.</returns>
public bool train(out int pEpcohs, out int pSeconds)
{
    XOR_ANN.ANN.ExecutionEngine annEngine = new XOR_ANN.ANN.ExecutionEngine(
        ref this.ann_Data,
        SingletonGlobalParameters.instance().LEARNING_RATE,
        SingletonGlobalParameters.instance().MOMENTUM,
        SingletonGlobalParameters.instance().OUTPUT_TOLERANCE,
        SingletonGlobalParameters.instance().MAXIMUM_EPOCHS);

    // UtcNow is not affected by daylight-saving or time-zone changes that
    // could occur while a long training run is in progress.
    DateTime startTime = DateTime.UtcNow;
    bool success = annEngine.train(out pEpcohs);
    TimeSpan totalTime = DateTime.UtcNow - startTime;

    // TotalSeconds covers the whole span. The previous Minutes*60 + Seconds
    // only used two components of the TimeSpan and silently dropped any
    // hours (and days), under-reporting runs longer than an hour.
    pSeconds = (int)totalTime.TotalSeconds;
    return success;
}
/// <summary>
/// Trains the network one epoch at a time for MAXIMUM_EPOCHS iterations and,
/// after every epoch, appends the accuracy on the test file and on the
/// training file to a CSV file.
/// </summary>
/// <remarks>
/// BATCH_MODE is switched on for the duration of the run so the helper
/// methods do not show message boxes, and is switched off again afterwards
/// even if an exception is thrown.
/// </remarks>
public void trainAndSaveEpochData()
{
    const string epochDataFile = @"C:\ANN_EpochData.csv";

    // CSV uses ',' as the column separator, so numbers must never be written
    // with a culture whose decimal separator is also ','.
    System.Globalization.CultureInfo csvCulture = System.Globalization.CultureInfo.InvariantCulture;

    // Prevent message boxes from popping up
    SingletonGlobalParameters.instance().BATCH_MODE = true;
    try
    {
        // Make space for the output file and write the column headers
        File.Delete(epochDataFile);
        using (StreamWriter headerWriter = new StreamWriter(epochDataFile))
        {
            headerWriter.WriteLine("Epoch No, Converged % on Training Data, Best Guess % on Training Data,Converged % on Test Data,Best Guess % on Test Data");
        }

        for (int count = 0; count < SingletonGlobalParameters.instance().MAXIMUM_EPOCHS; count++)
        {
            // Train for one epoch (the engine is created with an epoch limit of 1)
            XOR_ANN.ANN.ExecutionEngine annEngine = new XOR_ANN.ANN.ExecutionEngine(
                ref this.ann_Data,
                SingletonGlobalParameters.instance().LEARNING_RATE,
                SingletonGlobalParameters.instance().MOMENTUM,
                SingletonGlobalParameters.instance().OUTPUT_TOLERANCE,
                1);
            int dummy = 0;
            annEngine.train(out dummy);

            // Test with the test data first (this loads the test samples) ...
            float convergedOnTestData = 0;
            float guessOnTestData = 0;
            this.TestWeightsWithADataFile(@"C:\temp\16 by 16\last 800 Samples.dat",
                out convergedOnTestData, out guessOnTestData);

            // ... then with the training data. The order matters: this call
            // reloads the training samples, so the next epoch trains on them.
            float convergedOnTrainingData = 0;
            float guessOnTrainingData = 0;
            this.TestWeightsWithADataFile(@"C:\temp\16 by 16\first 4000 Samples.dat",
                out convergedOnTrainingData, out guessOnTrainingData);

            // Re-open in append mode for each row so the file is complete up
            // to the last finished epoch; 'using' closes it even on failure.
            using (StreamWriter rowWriter = new StreamWriter(epochDataFile, true))
            {
                rowWriter.WriteLine(
                    count.ToString(csvCulture) + "," +
                    convergedOnTrainingData.ToString(csvCulture) + "," +
                    guessOnTrainingData.ToString(csvCulture) + "," +
                    convergedOnTestData.ToString(csvCulture) + "," +
                    guessOnTestData.ToString(csvCulture));
            }
        }
    }
    finally
    {
        // Always leave batch mode, otherwise a failure here would suppress
        // message boxes for the rest of the session.
        SingletonGlobalParameters.instance().BATCH_MODE = false;
    }
}
/// <summary>
/// Builds a text picture of one sample's input values and a string of its
/// expected output values.
/// </summary>
/// <param name="pSample">Zero-based index of the sample to display.</param>
/// <param name="expectedResult">
/// Receives the sample's expected output values concatenated in order, or an
/// empty string when the sample index is not valid.
/// </param>
/// <returns>
/// INPUT_HEIGHT lines of INPUT_WIDTH values, each line ending in "\r\n", with
/// every '0' character replaced by a space and every '1' by '*'; an empty
/// string when the sample index is not valid.
/// </returns>
public string GetSampleDisplayData(int pSample, out string expectedResult)
{
    // No data loaded or a negative index is just as invalid as an index past
    // the end; route all of them through the same error path instead of
    // letting them surface as a NullReference/IndexOutOfRange exception.
    if (this.ann_Data == null || pSample < 0 || pSample >= this.ann_Data.sampleCount)
    {
        System.Windows.Forms.MessageBox.Show ("**ERROR ** That sample is out of range");
        expectedResult = "";
        return "";
    }

    int INPUT_HEIGHT = XOR_ANN.ANN.SingletonGlobalParameters.instance().INPUT_HEIGHT;
    int INPUT_WIDTH = XOR_ANN.ANN.SingletonGlobalParameters.instance().INPUT_WIDTH;

    // The inputs are stored row-major: INPUT_HEIGHT rows of INPUT_WIDTH values.
    StringBuilder picture = new StringBuilder();
    for (int row = 0; row < INPUT_HEIGHT; row++)
    {
        for (int column = 0; column < INPUT_WIDTH; column++)
        {
            picture.Append(this.ann_Data.sampleInput[pSample, row * INPUT_WIDTH + column]);
        }
        picture.Append("\r\n");
    }
    picture.Replace('0', ' ');
    picture.Replace('1', '*');

    // Use the real width of the expected-output array rather than assuming 10
    // columns, so a differently sized array cannot be over-run.
    int outputCount = this.ann_Data.expectedOutput.GetLength(1);
    StringBuilder expected = new StringBuilder();
    for (int output = 0; output < outputCount; output++)
    {
        expected.Append(this.ann_Data.expectedOutput[pSample, output]);
    }

    expectedResult = expected.ToString();
    return picture.ToString();
}
/// <summary>
/// Runs one stored sample through the network and renders each output node
/// as the integer part of ten times its value, concatenated in node order.
/// </summary>
/// <param name="pSample">Zero-based index of the sample whose inputs are fed to the network.</param>
/// <returns>The concatenated per-node integers described above.</returns>
public string GetDisplayResult(int pSample)
{
    // Copy the requested sample's row out of the 2D input array.
    float[] inputRow = new float[ann_Data.inputNodes];
    for (int node = 0; node < inputRow.Length; node++)
    {
        inputRow[node] = ann_Data.sampleInput[pSample, node];
    }

    XOR_ANN.ANN.ExecutionEngine annEngine = new XOR_ANN.ANN.ExecutionEngine(
        ref this.ann_Data,
        SingletonGlobalParameters.instance().LEARNING_RATE,
        SingletonGlobalParameters.instance().MOMENTUM,
        SingletonGlobalParameters.instance().OUTPUT_TOLERANCE,
        SingletonGlobalParameters.instance().MAXIMUM_EPOCHS);
    float[] networkOutput = annEngine.calculate(inputRow);

    // Scale each output by ten and keep only the integer part for display.
    StringBuilder display = new StringBuilder();
    for (int node = 0; node < ann_Data.outputNodes; node++)
    {
        float scaled = networkOutput[node] * 10F;
        int truncated = (int)scaled;
        display.Append(truncated);
    }
    return display.ToString();
}
/// <summary>
/// Reads samples from a data file and installs them as the current sample
/// set, creating the network data object first if none exists yet.
/// </summary>
/// <param name="pFileName">
/// Path of the data file; passed unchanged to <c>readDataFromFile</c>.
/// </param>
/// <remarks>
/// When nothing is read (the arrays come back null) the current data is left
/// untouched. Outside BATCH_MODE a message box reports the sample count.
/// </remarks>
public void LoadSampleData(string pFileName)
{
    float[,] newSampleInput;
    float[,] newExpectedResults;

    // Get the data from the file (parsing etc)
    this.readDataFromFile(pFileName, out newSampleInput, out newExpectedResults);

    // Nothing was read: keep whatever is currently loaded.
    if (newSampleInput == null)
    {
        return;
    }

    if (this.ann_Data == null)
    {
        // First load: size the network from the shape of the data just read.
        int hiddenNodes = XOR_ANN.ANN.SingletonGlobalParameters.instance().HIDDEN_NODES;
        this.ann_Data = new XOR_ANN.DataStructures.ANN_Data(
            newSampleInput.GetLength(1),
            hiddenNodes,
            newExpectedResults.GetLength(1),
            newExpectedResults.GetLength(0));
    }
    else
    {
        // Existing network: only the number of samples changes.
        this.ann_Data.sampleCount = newSampleInput.GetLength(0);
    }

    // Install the new samples once, for both cases (previously the else
    // branch assigned them and then they were assigned a second time).
    this.ann_Data.sampleInput = newSampleInput;
    this.ann_Data.expectedOutput = newExpectedResults;

    // Report how many samples were read
    if (!XOR_ANN.ANN.SingletonGlobalParameters.instance().BATCH_MODE)
    {
        System.Windows.Forms.MessageBox.Show("Read " + this.ann_Data.sampleCount + " samples");
    }
}
/// <summary>
/// Reads every sample from a data file and converts them into the two 2D
/// arrays used by the network: input values and one-hot expected outputs.
/// </summary>
/// <param name="pFileName">
/// Path of the file to read. When null or empty the user is asked to pick a
/// file with an open-file dialog.
/// </param>
/// <param name="pSampleInput">
/// Receives a [sample, INPUT_HEIGHT*INPUT_WIDTH] array holding 1 where the
/// sample text has '*' and 0 elsewhere; null when no file was chosen.
/// </param>
/// <param name="pExpectedOutput">
/// Receives a [sample, 10] array holding 1 in the column given by the
/// sample's label and 0 elsewhere; null when no file was chosen.
/// </param>
private void readDataFromFile(string pFileName, out float[,] pSampleInput, out float[,] pExpectedOutput)
{
    const int outputCount = 10;
    int INPUT_HEIGHT = SingletonGlobalParameters.instance().INPUT_HEIGHT;
    int INPUT_WIDTH = SingletonGlobalParameters.instance().INPUT_WIDTH;
    int INPUT_COUNT = INPUT_HEIGHT*INPUT_WIDTH;

    if (pFileName == null || pFileName.Length <= 0)
    {
        // Get the user to choose the file; dispose the dialog when done.
        using (System.Windows.Forms.OpenFileDialog fileDialog = new System.Windows.Forms.OpenFileDialog())
        {
            fileDialog.Filter = "Neural Network Data files (*.dat)|*.dat";
            fileDialog.InitialDirectory = @"C:\temp";
            if (fileDialog.ShowDialog() == System.Windows.Forms.DialogResult.OK)
                pFileName = fileDialog.FileName;
            else
                pFileName = "";
        }
    }

    // Still no file (dialog cancelled): signal "nothing read" to the caller.
    if (pFileName == null || pFileName.Length <= 0)
    {
        pSampleInput = null;
        pExpectedOutput = null;
        return;
    }

    // Read in the whole file to memory; 'using' closes the file even if a
    // sample fails to parse.
    ArrayList sampleData = new ArrayList();
    ArrayList targetOutput = new ArrayList();
    using (StreamReader fileReader = new StreamReader(pFileName))
    {
        bool endOfFile = false;
        while (!endOfFile)
            endOfFile = readOneSampleFromFile(fileReader, sampleData, targetOutput);
    }

    // Convert the ArrayLists to 2D arrays (new arrays start out all zero)
    int numberOfSamples = sampleData.Count;
    pSampleInput = new float[numberOfSamples,INPUT_COUNT];
    pExpectedOutput = new float[numberOfSamples,outputCount];

    for (int sample=0; sample<numberOfSamples; sample++)
    {
        // One-hot encode the label: only its own column becomes 1.
        int expectedResult = (int)targetOutput[sample];
        pExpectedOutput[sample,expectedResult] = 1;

        // A sample string shorter than INPUT_COUNT leaves the missing
        // trailing inputs at 0 instead of indexing past the end of the string.
        string sampleLineOfData = sampleData[sample].ToString();
        int usableInputs = Math.Min(INPUT_COUNT, sampleLineOfData.Length);
        for (int inputNode=0; inputNode<usableInputs; inputNode++)
        {
            if (sampleLineOfData[inputNode] == '*')
                pSampleInput[sample,inputNode] = 1;
        }
    }
}
/// <summary>
/// Reads one sample from the reader: a label line followed by INPUT_HEIGHT
/// lines of sample text, which are joined into a single string.
/// </summary>
/// <param name="pFileReader">Reader positioned at the start of a sample's label line.</param>
/// <param name="pSampleData">List that receives the joined sample text.</param>
/// <param name="pTargetOutput">List that receives the sample's integer label.</param>
/// <returns>
/// True when no data follows the label line (end of file, nothing added);
/// false when a sample was read and added to both lists.
/// </returns>
/// <exception cref="EndOfStreamException">
/// The file ends part-way through a sample's INPUT_HEIGHT lines.
/// </exception>
private bool readOneSampleFromFile (StreamReader pFileReader, ArrayList pSampleData, ArrayList pTargetOutput)
{
    int INPUT_HEIGHT = SingletonGlobalParameters.instance().INPUT_HEIGHT;
    int INPUT_WIDTH = SingletonGlobalParameters.instance().INPUT_WIDTH;

    // Convert.ToInt32 maps a null string to 0, so an exhausted reader does
    // not throw here and is caught by the Peek check below.
    int expectedResult = Convert.ToInt32(pFileReader.ReadLine());

    // Is there any more data?
    if (pFileReader.Peek() == -1)
        return true;

    StringBuilder sampleString = new StringBuilder();
    for (int row=0; row<INPUT_HEIGHT; row++)
    {
        string readLine = pFileReader.ReadLine();
        if (readLine == null)
            throw new EndOfStreamException(
                "Sample data is truncated: expected " + INPUT_HEIGHT +
                " lines for a sample but the file ended after " + row + ".");

        // Pad to INPUT_WIDTH characters. Strings are immutable, so the padded
        // value is the RETURN value of PadRight; the original call discarded
        // it, leaving short lines unpadded and shifting every later row.
        sampleString.Append(readLine.PadRight(INPUT_WIDTH,' '));
    }

    pSampleData.Add(sampleString.ToString());
    pTargetOutput.Add(expectedResult);
    return false;
}
/// <summary>
/// Delegates to <c>FileCommands.loadSerializationData</c>, passing the
/// network data field by reference so the callee can assign or replace it.
/// </summary>
/// <remarks>
/// NOTE(review): presumably this restores previously saved network data;
/// where it is read from is decided by the callee - confirm there.
/// </remarks>
public void loadData()
{
XOR_ANN.GUI.FileCommands.loadSerializationData(ref this.ann_Data);
}
/// <summary>
/// Delegates to <c>FileCommands.saveSerializationData</c>, handing it the
/// current network data and the given file argument.
/// </summary>
/// <param name="pFile">
/// Passed through unchanged to the callee; presumably the destination file
/// path - confirm against <c>FileCommands</c>.
/// </param>
public void saveData(string pFile)
{
XOR_ANN.GUI.FileCommands.saveSerializationData(this.ann_Data, pFile);
}
/// <summary>
/// Discards the current network data by setting the field to null.
/// The next call to <c>LoadSampleData</c> that reads samples will then
/// create a fresh data object sized from that file.
/// </summary>
public void ResetANNData()
{
this.ann_Data = null;
}
/// <summary>
/// Loads the samples from a data file, runs each one through the network and
/// measures how many are reproduced correctly.
/// </summary>
/// <param name="pFile">Data file to load; passed unchanged to <c>LoadSampleData</c>.</param>
/// <param name="pConvergedPercentage">
/// Receives the percentage of samples for which every output node is within
/// OUTPUT_TOLERANCE of its expected value. 0 when there are no samples.
/// </param>
/// <param name="pBestguessPercentage">
/// Receives the percentage of samples for which the expected output is
/// exactly 1 at the highest-valued output node and 0 everywhere else.
/// 0 when there are no samples.
/// </param>
/// <remarks>
/// Outside BATCH_MODE a message box shows both percentages and the average
/// time of one forward propagation in microseconds.
/// </remarks>
public void TestWeightsWithADataFile(string pFile, out float pConvergedPercentage, out float pBestguessPercentage)
{
    // A DateTime tick is 100 ns, so there are 10 ticks in a microsecond.
    const long ticksPerMicrosecond = TimeSpan.TicksPerMillisecond / 1000;

    // Load the file to test with (an empty name makes the user select one)
    this.LoadSampleData(pFile);

    pConvergedPercentage = 0;
    pBestguessPercentage = 0;

    // Nothing to test: no data was ever loaded (e.g. the file dialog was
    // cancelled) or the file held no samples. Returning here also avoids the
    // divisions by a zero sample count further down.
    if (this.ann_Data == null || this.ann_Data.sampleCount <= 0)
        return;

    int correctCalculations = 0;
    int correctGuesses = 0;
    float[] sampleInput = new float[ann_Data.inputNodes];
    long totalCpuTicks = 0;

    // Create an instance of the Execution Engine to calculate the results
    XOR_ANN.ANN.ExecutionEngine annEngine = new XOR_ANN.ANN.ExecutionEngine(
        ref this.ann_Data,
        SingletonGlobalParameters.instance().LEARNING_RATE,
        SingletonGlobalParameters.instance().MOMENTUM,
        SingletonGlobalParameters.instance().OUTPUT_TOLERANCE,
        SingletonGlobalParameters.instance().MAXIMUM_EPOCHS);

    // Now test each sample in turn
    for (int sample=0; sample<this.ann_Data.sampleCount; sample++)
    {
        // get the input for the requested sample
        for (int i=0; i<ann_Data.inputNodes; i++)
            sampleInput[i] = ann_Data.sampleInput[sample,i];

        // Forward propagate and time the calculation.
        // NOTE(review): DateTime has a coarse clock resolution, so single
        // calls may measure as 0; only the average over many samples is
        // meaningful. Consider a high-resolution timer if one is available.
        long startTick = DateTime.UtcNow.Ticks;
        float[] sampleOutput = annEngine.calculate(sampleInput);
        totalCpuTicks += DateTime.UtcNow.Ticks - startTick;

        // Find the output node with the highest value (the network's guess)
        float highestOutput = 0;
        int indexForHighest = 0;
        for (int i=0; i<sampleOutput.Length; i++)
        {
            if (sampleOutput[i] > highestOutput)
            {
                highestOutput = sampleOutput[i];
                indexForHighest = i;
            }
        }

        // Converged: every output node is within tolerance of its target
        bool match = true;
        for (int i=0; i<ann_Data.outputNodes; i++)
            match &= (Math.Abs(ann_Data.expectedOutput[sample,i] - sampleOutput[i])
                        <= SingletonGlobalParameters.instance().OUTPUT_TOLERANCE);
        if (match)
            correctCalculations++;

        // Best guess: the expected output equals "1 at the highest node and
        // 0 elsewhere". Both sides are exact 0/1 values, so == is safe here.
        bool guessMatch = true;
        for (int i=0; i<sampleOutput.Length; i++)
            guessMatch &= (ann_Data.expectedOutput[sample,i] == ((i == indexForHighest) ? 1F : 0F));
        if (guessMatch)
            correctGuesses++;
    }

    // Calculate the % of correct samples
    pConvergedPercentage = (correctCalculations * 100) / (float)this.ann_Data.sampleCount;
    pBestguessPercentage = (correctGuesses * 100) / (float)this.ann_Data.sampleCount;

    // Average forward-propagation time in microseconds. The division is done
    // in floating point (the old integer division truncated the average) and
    // by ticks-per-microsecond (the old /10000 produced milliseconds while
    // the message below labels the value as microseconds).
    float averageTimeToCalculate =
        ((float)totalCpuTicks / this.ann_Data.sampleCount) / ticksPerMicrosecond;

    // Convert to two decimal places
    averageTimeToCalculate = (float)Decimal.Round(Convert.ToDecimal(averageTimeToCalculate),2);

    if (!XOR_ANN.ANN.SingletonGlobalParameters.instance().BATCH_MODE)
        System.Windows.Forms.MessageBox.Show("Converged Percentage % = " +
            pConvergedPercentage + "\r\n" +
            "Best Guess Percentage % = " + pBestguessPercentage + "\r\n" +
            "Average Forward Propagation Time = " + averageTimeToCalculate + " microseconds");
}
/// <summary>
/// Runs a single input vector through the network by constructing an
/// execution engine with the global parameters and calling its
/// <c>calculate</c> method.
/// </summary>
/// <param name="pInput">Input values handed unchanged to the engine.</param>
/// <returns>The array returned by the engine's <c>calculate</c> call.</returns>
/// <remarks>
/// NOTE(review): this private method is not called anywhere in this class
/// as shown; the public methods build their own engine instead.
/// </remarks>
private float[] calculate(float[] pInput)
{
XOR_ANN.ANN.ExecutionEngine annEngine = new XOR_ANN.ANN.ExecutionEngine(
ref this.ann_Data,
SingletonGlobalParameters.instance().LEARNING_RATE,
SingletonGlobalParameters.instance().MOMENTUM,
SingletonGlobalParameters.instance().OUTPUT_TOLERANCE,
SingletonGlobalParameters.instance().MAXIMUM_EPOCHS);
// Use the ANN to calculate the result for the given input
return (annEngine.calculate(pInput));
}
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -