interpolate.cpp
来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C++ 代码 · 共 452 行
CPP
452 行
// --------------------------------------------------------
// This demo file is dedicated to the Public Domain. See:
// http://creativecommons.org/licenses/publicdomain
// --------------------------------------------------------
#ifndef WIN32
#include <unistd.h>
#endif // !WIN32
#include <string>
#include "Interpolate.h"
#include "../GClasses/GMacros.h"
#include "../GClasses/GTime.h"
#include "../GClasses/GArff.h"
#include "../GClasses/GNeuralNet.h"
#include "../GClasses/GMath.h"
#include "../GClasses/GTime.h"
#include "../GClasses/GSearch.h"
#include "../GClasses/GGreedySearch.h"
#include "../GClasses/GParticleSwarm.h"
#include "../GClasses/GConfSearch.h"
#include "../GClasses/GGenetic.h"
#include "../GClasses/GStabSearch.h"
#include "../GClasses/GKNN.h"
#include "../GClasses/GBag.h"
#include "../GClasses/GPCTree.h"
// Background color of the demo dialog (ARGB).
static const unsigned int BACKGROUND_COLOR = 0xffaaaa33;

// Number of neural-network training algorithms compared side by side
// (slot 0 = backpropagation, slots 1.. = search-based trainers). The constructor,
// dialog labels and Draw layout only provide for up to 6 slots.
static const int SEARCH_ALGORITHM_COUNT = 6;
// Dialog drawn as the background of the interpolation demo. It holds the "manual"
// image button and the static text captions naming each learner's output image.
class InterpolateDialog : public GWidgetDialog
{
protected:
InterpolateController* m_pController; // controller passed in at construction (stored, not deleted here)
GWidgetImageButton* m_pButtonManual; // opens the HTML manual when released
int m_nValidationTechnique; // NOTE(review): never initialized or referenced in the visible code
double m_dTrainingPercent; // NOTE(review): never initialized or referenced in the visible code
public:
InterpolateDialog(InterpolateController* pController, int w, int h);
virtual ~InterpolateDialog();
// Handles release of an image button; only the manual button is recognized.
virtual void OnReleaseImageButton(GWidgetImageButton* pButton);
// Handles release of a text button; the dialog currently creates none, so this does nothing.
virtual void OnReleaseTextButton(GWidgetTextButton* pButton);
};
// Lays out the dialog: the manual button near the right edge, followed by one
// caption above each learner's output image. Widgets are created in a fixed order,
// with this dialog as their parent.
InterpolateDialog::InterpolateDialog(InterpolateController* pController, int w, int h)
: GWidgetDialog(w, h, BACKGROUND_COLOR), m_pController(pController)
{
	GImage* pManualImage = ControllerBase::GetManualImage();
	m_pButtonManual = new GWidgetImageButton(this, w - pManualImage->GetWidth() / 2 - 5, 5, pManualImage);

	// Caption position and text for every output image, in creation order.
	struct LabelSpec
	{
		int x;
		int y;
		const char* szText;
	};
	static const LabelSpec labels[] =
	{
		{ 200, 10, "Backpropagation (on-line)" },
		{ 400, 10, "Stochastic Greedy (Batch)" },
		{ 600, 10, "Momentum Greedy (Batch)" },
		{ 200, 210, "Particle Swarm (Batch)" },
		{ 400, 210, "Evolutionary Search (Batch)" },
		{ 600, 210, "Stab Search (Batch)" },
		{ 5, 410, "KNN (k=4)" },
		{ 200, 410, "Axis Aligned Forest (30 trees)" },
		{ 400, 410, "Arbitrary Arboretum (30 trees)" },
		{ 600, 410, "PC Forest (30 trees)" },
	};
	const int nLabelCount = (int)(sizeof(labels) / sizeof(labels[0]));
	for(int i = 0; i < nLabelCount; i++)
		new GWidgetTextLabel(this, labels[i].x, labels[i].y, 200, 20, labels[i].szText, 0xff442266);
}
// Destroys the dialog. The widgets created in the constructor are not deleted here.
// NOTE(review): presumably the GWidgetDialog base releases its child widgets — confirm.
/*virtual*/ InterpolateDialog::~InterpolateDialog()
{
	// intentionally empty
}
// Opens the demo's HTML manual when the manual button is released. Any other
// image button is unexpected and trips a debug assertion.
/*virtual*/ void InterpolateDialog::OnReleaseImageButton(GWidgetImageButton* pButton)
{
	if(pButton != m_pButtonManual)
	{
		GAssert(false, "unrecognized image button");
		return;
	}
	OpenAppFile("../doc/Waffles/Interpolate.html");
}
/*virtual*/ void InterpolateDialog::OnReleaseTextButton(GWidgetTextButton* pButton)
{
}
// -------------------------------------------------------------------------------
//#define BATCH_SIZE 5
class NeuralNetTrainingCritic : public GRealVectorCritic
{
protected:
GNeuralNet* m_pNN;
GArffRelation* m_pRel;
GArffData* m_pData;
// GArffData* m_pDataBatch;
public:
NeuralNetTrainingCritic(GNeuralNet* pNN, GArffData* pData)
: GRealVectorCritic(pNN->GetWeightCount())
{
m_pNN = pNN;
m_pRel = pNN->GetRelation();
m_pData = pData;
// m_pDataBatch = new GArffData(BATCH_SIZE);
}
virtual ~NeuralNetTrainingCritic()
{
// m_pDataBatch->DropAllVectors();
// delete(m_pDataBatch);
}
protected:
virtual double ComputeError(double* pVector)
{
m_pNN->SetWeights(pVector);
// m_pDataBatch->DropAllVectors();
// int i;
// int nVectors = m_pData->GetSize();
// for(i = 0; i < BATCH_SIZE; i++)
// m_pDataBatch->AddVector(m_pData->GetVector(rand() % nVectors));
// return m_pNN->MeasureMeanSquaredError(m_pDataBatch);
return m_pNN->MeasureMeanSquaredError(m_pData);
}
};
// -------------------------------------------------------------------------------
// Main view of the interpolation demo. It trains several learners to map (x, y)
// pixel coordinates of a small input image to its green channel, then draws each
// learner's upscaled reconstruction side by side.
class InterpolateView : public ViewBase
{
protected:
InterpolateDialog* m_pDialog; // background dialog (captions + manual button); blitted first in Draw
GImage* m_pImageIn; // small source image loaded from w.png
// Rendered outputs: indices [0, SEARCH_ALGORITHM_COUNT) hold the neural-net result of each
// training algorithm (0 = backpropagation); the last 4 hold KNN and the three bagged forests.
GImage* m_pImageOut[SEARCH_ALGORITHM_COUNT + 4];
GArffRelation* m_pRelation; // attributes: X position, Y position, pixel value
GArffData* m_pTrainingData; // one row per source pixel
GNeuralNet* m_pNN1; // trained by backpropagation (algorithm slot 0)
GNeuralNet* m_pNN2; // shared by every search-based algorithm (slots 1..)
// Per-algorithm critic and search object; the constructor creates only indices 1.. (backprop uses m_pNN1).
NeuralNetTrainingCritic* m_pCritics[SEARCH_ALGORITHM_COUNT];
GRealVectorSearch* m_pSearchAlgs[SEARCH_ALGORITHM_COUNT];
int m_nCurrentAlgorithm; // index of the algorithm whose time slice is active
double m_dRunningTime[SEARCH_ALGORITHM_COUNT]; // cumulative GTime::GetTime() time each algorithm has run
double m_dThisStartTime; // start time of the current slice; 0 means no slice is in progress
double m_dNextStopTime; // cumulative-time target for the current round; grows by 1 each round
public:
InterpolateView(InterpolateController* pController);
virtual ~InterpolateView();
virtual void OnChar(char c);
virtual void OnMouseDown(int nButton, int x, int y);
virtual void OnMouseUp(int nButton, int x, int y);
virtual bool OnMousePos(int x, int y);
// Runs one unit of training for the current algorithm, or renders its result when its
// time slice ends. Returns true when an output image was updated.
bool DoSomeTraining();
protected:
virtual void Draw(SDL_Surface *pScreen);
// Renders pNN's prediction for every pixel of pImage.
void UpdateImage(GSupervisedLearner* pNN, GImage* pImage);
};
// Builds the demo.
// - Loads the small source image (w.png from the application path).
// - Turns each of its pixels into a training row (x, y -> green channel).
// - Prepares two neural networks and the search algorithms that train m_pNN2.
// - Immediately trains and renders the four non-iterative learners (KNN and three
//   bagged forests).
InterpolateView::InterpolateView(InterpolateController* pController)
: ViewBase()
{
	m_pDialog = new InterpolateDialog(pController, m_screenRect.w, m_screenRect.h);
	m_pImageIn = new GImage();

	// Build the path in a std::string. The previous fixed 256-byte buffer, filled with
	// strcpy/strcat, overflowed whenever the application path was long.
	std::string sPath(ControllerBase::GetAppPath());
	sPath += "w.png";
	if(!m_pImageIn->LoadPNGFile(sPath.c_str()))
		GAssert(false, "failed to load input image");

	// Output images are 8x the source size in each dimension.
	int n;
	for(n = 0; n < SEARCH_ALGORITHM_COUNT + 4; n++)
	{
		m_pImageOut[n] = new GImage();
		m_pImageOut[n]->SetSize(m_pImageIn->GetWidth() * 8, m_pImageIn->GetHeight() * 8);
	}

	// Create the relation
	m_pRelation = new GArffRelation();
	m_pRelation->AddAttribute(new GArffAttribute(true, 0, NULL)); // X position
	m_pRelation->AddAttribute(new GArffAttribute(true, 0, NULL)); // Y position
	m_pRelation->AddAttribute(new GArffAttribute(false, 0, NULL)); // Pixel value

	// Extract the training data from the image: one row per pixel holding the
	// coordinates and the green channel, each mapped through GMath::digitalToAnalog.
	m_pTrainingData = new GArffData(m_pImageIn->GetHeight() * m_pImageIn->GetWidth());
	int x, y;
	GColor c;
	for(y = 0; y < (int)m_pImageIn->GetHeight(); y++)
	{
		for(x = 0; x < (int)m_pImageIn->GetWidth(); x++)
		{
			c = m_pImageIn->GetPixel(x, y);
			double* pRow = new double[3];
			pRow[0] = GMath::digitalToAnalog(x, m_pImageIn->GetWidth());
			pRow[1] = GMath::digitalToAnalog(y, m_pImageIn->GetHeight());
			pRow[2] = GMath::digitalToAnalog(gGreen(c), 256);
			m_pTrainingData->AddVector(pRow);
		}
	}
	m_pTrainingData->Shuffle();

	// Backpropagation gets its own neural net
	m_pNN1 = new GNeuralNet(m_pRelation);
	m_pNN1->AddLayer(10);
	m_pNN1->AddLayer(10);
	m_pNN1->TrainInit(m_pTrainingData, m_pTrainingData);

	// All the other search algorithms share this one
	m_pNN2 = new GNeuralNet(m_pRelation);
	m_pNN2->AddLayer(10);
	m_pNN2->AddLayer(10);
	m_pNN2->TrainInit(m_pTrainingData, m_pTrainingData);

	// Slot 0 is backpropagation. It trains m_pNN1 directly and has no critic or search
	// object, so null those slots to ensure the arrays never hold indeterminate pointers.
	m_pCritics[0] = NULL;
	m_pSearchAlgs[0] = NULL;

	// Make the critics
	for(n = 1; n < SEARCH_ALGORITHM_COUNT; n++)
		m_pCritics[n] = new NeuralNetTrainingCritic(m_pNN2, m_pTrainingData);

	// Make the search algorithms
	if(SEARCH_ALGORITHM_COUNT > 1)
		m_pSearchAlgs[1] = new GStochasticGreedySearch(m_pCritics[1], -.5, 1);
	if(SEARCH_ALGORITHM_COUNT > 2)
		m_pSearchAlgs[2] = new GMomentumGreedySearch(m_pCritics[2]);
	if(SEARCH_ALGORITHM_COUNT > 3)
		m_pSearchAlgs[3] = new GParticleSwarm(m_pCritics[3], 80, -1, 2);
	if(SEARCH_ALGORITHM_COUNT > 4)
		m_pSearchAlgs[4] = new GEvolutionarySearch(m_pCritics[4], 30, 12);
	if(SEARCH_ALGORITHM_COUNT > 5)
		m_pSearchAlgs[5] = new GStabSearch(m_pCritics[5], -5, 10);

	// Init the algorithm cycle system. Starting at the last slot makes the first
	// DoSomeTraining call wrap around to slot 0 and open the first round.
	for(n = 0; n < SEARCH_ALGORITHM_COUNT; n++)
		m_dRunningTime[n] = 0;
	m_nCurrentAlgorithm = SEARCH_ALGORITHM_COUNT - 1;
	m_dNextStopTime = 0;
	m_dThisStartTime = 0;

	// Train the extra 4 fast algorithm images
	{
		GKNN learner(m_pRelation, 4, false);
		learner.Train(m_pTrainingData);
		UpdateImage(&learner, m_pImageOut[SEARCH_ALGORITHM_COUNT]);
	}
	{
		GBag bag(m_pRelation, 30);
		int i;
		for(i = 0; i < 30; i++)
			bag.AddLearner(new GArbitraryTree(m_pRelation, true));
		bag.Train(m_pTrainingData);
		UpdateImage(&bag, m_pImageOut[SEARCH_ALGORITHM_COUNT + 1]);
	}
	{
		GBag bag(m_pRelation, 30);
		int i;
		for(i = 0; i < 30; i++)
			bag.AddLearner(new GArbitraryTree(m_pRelation, false));
		bag.Train(m_pTrainingData);
		UpdateImage(&bag, m_pImageOut[SEARCH_ALGORITHM_COUNT + 2]);
	}
	{
		GBag bag(m_pRelation, 30);
		int i;
		for(i = 0; i < 30; i++)
			bag.AddLearner(new GPCTree(m_pRelation));
		bag.Train(m_pTrainingData);
		UpdateImage(&bag, m_pImageOut[SEARCH_ALGORITHM_COUNT + 3]);
	}
}
// Releases everything the constructor allocated, in reverse dependency order:
// - each search object holds its critic;
// - each critic holds m_pNN2 and the training data;
// - both networks were built over m_pRelation.
// The previous order freed the relation before the networks and each critic before
// its search object. That risked use-after-free if any of those destructors touches
// what it references.
InterpolateView::~InterpolateView()
{
	int n;
	// Slot 0 (backpropagation) never had a critic or search object.
	for(n = 1; n < SEARCH_ALGORITHM_COUNT; n++)
		delete(m_pSearchAlgs[n]);
	for(n = 1; n < SEARCH_ALGORITHM_COUNT; n++)
		delete(m_pCritics[n]);
	delete(m_pNN2);
	delete(m_pNN1);
	delete(m_pTrainingData);
	delete(m_pRelation);
	for(n = 0; n < SEARCH_ALGORITHM_COUNT + 4; n++)
		delete(m_pImageOut[n]);
	delete(m_pImageIn);
	delete(m_pDialog);
}
// Draws the dialog, the small source image, and every learner's output image.
// Image positions come from tables indexed relative to SEARCH_ALGORITHM_COUNT. The
// previous hard-coded indices 6..9 were only correct when SEARCH_ALGORITHM_COUNT == 6:
// smaller counts above 1 read past m_pImageOut, and a count of 1 skipped the
// extra-learner images. For the current count of 6 the output is unchanged.
/*virtual*/ void InterpolateView::Draw(SDL_Surface *pScreen)
{
	// Clear the screen
	SDL_FillRect(pScreen, NULL/*&r*/, 0x000000);
	// Draw the dialog
	BlitImage(pScreen, m_screenRect.x, m_screenRect.y, m_pDialog->GetImage());
	// Draw the small image
	BlitImage(pScreen, m_screenRect.x + 150, m_screenRect.y + 25, m_pImageIn);

	// Offsets of the neural-net images, one per algorithm slot (slot 0 = backpropagation).
	// Each sits just below its dialog caption.
	static const int searchPos[6][2] =
	{
		{ 200, 25 }, { 400, 25 }, { 600, 25 },
		{ 200, 225 }, { 400, 225 }, { 600, 225 },
	};
	// Offsets of the four extra learners (KNN and three forests). Their images follow the
	// search-algorithm images in m_pImageOut.
	static const int extraPos[4][2] =
	{
		{ 10, 425 }, { 200, 425 }, { 400, 425 }, { 600, 425 },
	};

	// Draw the big interpolated images
	int n;
	for(n = 0; n < SEARCH_ALGORITHM_COUNT && n < 6; n++)
		BlitImage(pScreen, m_screenRect.x + searchPos[n][0], m_screenRect.y + searchPos[n][1], m_pImageOut[n]);
	for(n = 0; n < 4; n++)
		BlitImage(pScreen, m_screenRect.x + extraPos[n][0], m_screenRect.y + extraPos[n][1], m_pImageOut[SEARCH_ALGORITHM_COUNT + n]);
}
// Keyboard input is handled entirely by the dialog.
void InterpolateView::OnChar(char c)
{
	m_pDialog->HandleChar(c);
}
// Converts the press to dialog-local coordinates and lets the dialog grab whichever
// atomic widget lies under the cursor.
void InterpolateView::OnMouseDown(int nButton, int x, int y)
{
	int localX = x - m_screenRect.x;
	int localY = y - m_screenRect.y;
	m_pDialog->GrabWidget(m_pDialog->FindAtomicWidget(localX, localY), nButton, localX, localY);
}
// Releases whatever widget the dialog grabbed for this button; the position is not needed.
void InterpolateView::OnMouseUp(int nButton, int /*x*/, int /*y*/)
{
	m_pDialog->ReleaseWidget(nButton);
}
// Forwards mouse motion to the dialog in dialog-local coordinates and returns the
// dialog's result.
bool InterpolateView::OnMousePos(int x, int y)
{
	int localX = x - m_screenRect.x;
	int localY = y - m_screenRect.y;
	return m_pDialog->HandleMousePos(localX, localY);
}
// Renders pNN's prediction for every pixel of pImage.
// - Each output pixel is mapped back to a (fractional) coordinate in the source image
//   m_pImageIn.
// - That coordinate is encoded with GMath::digitalToAnalog, the same way the
//   constructor encoded the training rows.
// - The prediction is read back from sample[2] after Eval and clamped to a 0..255
//   gray level.
// The output-to-source scale is derived from the two image sizes rather than
// hard-coded as 8. For the 8x output images created by the constructor it is exactly
// 8.0, so results are unchanged.
void InterpolateView::UpdateImage(GSupervisedLearner* pNN, GImage* pImage)
{
	// An empty source image has no coordinate space to sample (and would divide by zero below).
	if(m_pImageIn->GetWidth() == 0 || m_pImageIn->GetHeight() == 0)
		return;
	int width = pImage->GetWidth();
	int height = pImage->GetHeight();
	double dScaleX = (double)width / m_pImageIn->GetWidth();
	double dScaleY = (double)height / m_pImageIn->GetHeight();
	double sample[3] = { 0, 0, 0 };
	int x, y;
	for(y = 0; y < height; y++)
	{
		for(x = 0; x < width; x++)
		{
			sample[0] = GMath::digitalToAnalog((double)x / dScaleX, m_pImageIn->GetWidth());
			sample[1] = GMath::digitalToAnalog((double)y / dScaleY, m_pImageIn->GetHeight());
			pNN->Eval(sample);
			int n = GMath::analogToDigital(sample[2], 256);
			if(n < 0)
				n = 0;
			if(n > 255)
				n = 255;
			pImage->SetPixel(x, y, gRGB(n, n, n));
		}
	}
}
// Advances the round-robin time-sharing between training algorithms by one step.
// Each round, m_dNextStopTime grows by 1 (in GTime::GetTime() units) and every
// algorithm runs until its cumulative running time reaches that target. While a
// slice is active, each call performs one unit of training (a backprop epoch for
// slot 0, one search iteration otherwise). When the slice's target is reached, the
// algorithm's current result is rendered into its output image.
// Returns true when an output image was updated (RunModal then redraws), false otherwise.
bool InterpolateView::DoSomeTraining()
{
// Cycle the algorithm
double dTime;
// m_dThisStartTime == 0 means no slice is in progress: start one for the next algorithm.
if(m_dThisStartTime == 0)
{
m_nCurrentAlgorithm++;
if(m_nCurrentAlgorithm >= SEARCH_ALGORITHM_COUNT)
{
// Wrapped around: begin a new round with a target 1 time unit further out.
m_nCurrentAlgorithm = 0;
m_dNextStopTime += 1;
}
m_dThisStartTime = GTime::GetTime();
dTime = 0;
}
else
// Time spent so far in the current slice.
dTime = GTime::GetTime() - m_dThisStartTime;
// Convert to this algorithm's cumulative running time.
dTime += m_dRunningTime[m_nCurrentAlgorithm];
if(dTime >= m_dNextStopTime)
{
// Slice finished: bank the time, mark no slice active, and render the result.
m_dRunningTime[m_nCurrentAlgorithm] = dTime;
m_dThisStartTime = 0;
if(m_nCurrentAlgorithm == 0)
UpdateImage(m_pNN1, m_pImageOut[m_nCurrentAlgorithm]);
else
{
// m_pNN2 is shared and its weights are overwritten by every critic evaluation,
// so load this algorithm's best weights before rendering.
m_pNN2->SetWeights(m_pCritics[m_nCurrentAlgorithm]->GetBestYet());
UpdateImage(m_pNN2, m_pImageOut[m_nCurrentAlgorithm]);
}
return true;
}
// Train
if(m_nCurrentAlgorithm == 0)
{
m_pNN1->TrainEpoch();
//printf("Epochs: %d, Error %f\n", m_pNN1->GetEpochs(), m_pNN1->GetError());
}
else
m_pSearchAlgs[m_nCurrentAlgorithm]->Iterate();
return false;
}
// -------------------------------------------------------------------------------
// Creates the demo's view. The view is deleted again in the destructor.
InterpolateController::InterpolateController()
: ControllerBase()
{
	m_pView = new InterpolateView(this);
}
// Releases the view created in the constructor.
InterpolateController::~InterpolateController()
{
	delete m_pView;
}
void InterpolateController::RunModal()
{
double timeOld = GTime::GetTime();
double time;
m_pView->Update();
while(m_bKeepRunning)
{
time = GTime::GetTime();
if(HandleEvents(time - timeOld) ||
((InterpolateView*)m_pView)->DoSomeTraining())
{
m_pView->Update();
}
/* else
GThread::sleep(10);
*/
timeOld = time;
}
}
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?