interpolate.cpp

来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C++ 代码 · 共 452 行

CPP
452
字号
// --------------------------------------------------------
// This demo file is dedicated to the Public Domain. See:
// http://creativecommons.org/licenses/publicdomain
// --------------------------------------------------------

#ifndef WIN32
#include <unistd.h>
#endif // !WIN32
#include "Interpolate.h"
#include "../GClasses/GMacros.h"
#include "../GClasses/GTime.h"
#include "../GClasses/GArff.h"
#include "../GClasses/GNeuralNet.h"
#include "../GClasses/GMath.h"
#include "../GClasses/GTime.h"
#include "../GClasses/GSearch.h"
#include "../GClasses/GGreedySearch.h"
#include "../GClasses/GParticleSwarm.h"
#include "../GClasses/GConfSearch.h"
#include "../GClasses/GGenetic.h"
#include "../GClasses/GStabSearch.h"
#include "../GClasses/GKNN.h"
#include "../GClasses/GBag.h"
#include "../GClasses/GPCTree.h"

// Background color handed to GWidgetDialog for the demo's dialog
// (presumably ARGB like the other 0xff...... colors in this file — confirm).
#define BACKGROUND_COLOR 0xffaaaa33
// Number of neural-network training methods compared side by side. Slot 0 is
// on-line backpropagation; slots 1..SEARCH_ALGORITHM_COUNT-1 are the batch
// search algorithms created in InterpolateView's constructor.
#define SEARCH_ALGORITHM_COUNT 6

// Dialog providing the static chrome of the Interpolate demo: a button that
// opens the manual and one caption per output image. The images themselves
// are blitted over the dialog by InterpolateView::Draw.
class InterpolateDialog : public GWidgetDialog
{
protected:
	InterpolateController* m_pController; // controller passed in by the view; not deleted here
	GWidgetImageButton* m_pButtonManual; // opens the HTML manual when released
	int m_nValidationTechnique; // NOTE(review): never initialized or read in this file
	double m_dTrainingPercent; // NOTE(review): never initialized or read in this file

public:
	// w, h: size of the dialog surface.
	InterpolateDialog(InterpolateController* pController, int w, int h);
	virtual ~InterpolateDialog();

	virtual void OnReleaseImageButton(GWidgetImageButton* pButton);
	virtual void OnReleaseTextButton(GWidgetTextButton* pButton);
};







// Builds the dialog: remembers the controller, places the manual button near
// the top-right corner, and adds one caption above each output image. The
// caption positions line up with the blit offsets used by InterpolateView::Draw.
InterpolateDialog::InterpolateDialog(InterpolateController* pController, int w, int h)
 : GWidgetDialog(w, h, BACKGROUND_COLOR)
{
	m_pController = pController;

	// Manual button. NOTE(review): only half the image width is subtracted;
	// presumably the image holds two button states side by side — confirm.
	GImage* pManualImage = ControllerBase::GetManualImage();
	int nManualX = w - pManualImage->GetWidth() / 2 - 5;
	m_pButtonManual = new GWidgetImageButton(this, nManualX, 5, pManualImage);

	// Captions, created in the same order as before (top row, middle row, bottom row).
	struct Caption
	{
		int x;
		int y;
		const char* szText;
	};
	static const Caption s_captions[] =
	{
		{ 200, 10, "Backpropagation (on-line)" },
		{ 400, 10, "Stochastic Greedy (Batch)" },
		{ 600, 10, "Momentum Greedy (Batch)" },
		{ 200, 210, "Particle Swarm (Batch)" },
		{ 400, 210, "Evolutionary Search (Batch)" },
		{ 600, 210, "Stab Search (Batch)" },
		{ 5, 410, "KNN (k=4)" },
		{ 200, 410, "Axis Aligned Forest (30 trees)" },
		{ 400, 410, "Arbitrary Arboretum (30 trees)" },
		{ 600, 410, "PC Forest (30 trees)" },
	};
	const unsigned int nCaptionColor = 0xff442266;
	const unsigned int nCaptionCount = sizeof(s_captions) / sizeof(s_captions[0]);
	for(unsigned int i = 0; i < nCaptionCount; i++)
		new GWidgetTextLabel(this, s_captions[i].x, s_captions[i].y, 200, 20, s_captions[i].szText, nCaptionColor);
}

// Nothing is released explicitly. NOTE(review): the labels and the manual
// button are allocated with `new` in the constructor and not deleted here;
// presumably GWidgetDialog owns and frees its child widgets — confirm.
/*virtual*/ InterpolateDialog::~InterpolateDialog()
{
}

// Image-button release handler. The manual button opens
// ../doc/Waffles/Interpolate.html via OpenAppFile. This dialog creates no
// other image buttons, so any other button trips an assertion.
/*virtual*/ void InterpolateDialog::OnReleaseImageButton(GWidgetImageButton* pButton)
{
	if(pButton != m_pButtonManual)
	{
		GAssert(false, "unrecognized image button");
		return;
	}
	OpenAppFile("../doc/Waffles/Interpolate.html");
}

// The constructor creates no text buttons, so there is nothing to handle.
/*virtual*/ void InterpolateDialog::OnReleaseTextButton(GWidgetTextButton* pButton)
{
}



// -------------------------------------------------------------------------------

//#define BATCH_SIZE 5

class NeuralNetTrainingCritic : public GRealVectorCritic
{
protected:
	GNeuralNet* m_pNN;
	GArffRelation* m_pRel;
	GArffData* m_pData;
//	GArffData* m_pDataBatch;

public:
	NeuralNetTrainingCritic(GNeuralNet* pNN, GArffData* pData)
	: GRealVectorCritic(pNN->GetWeightCount())
	{
		m_pNN = pNN;
		m_pRel = pNN->GetRelation();
		m_pData = pData;
//		m_pDataBatch = new GArffData(BATCH_SIZE);
	}

	virtual ~NeuralNetTrainingCritic()
	{
//		m_pDataBatch->DropAllVectors();
//		delete(m_pDataBatch);
	}

protected:
	virtual double ComputeError(double* pVector)
	{
		m_pNN->SetWeights(pVector);

//		m_pDataBatch->DropAllVectors();
//		int i;
//		int nVectors = m_pData->GetSize();
//		for(i = 0; i < BATCH_SIZE; i++)
//			m_pDataBatch->AddVector(m_pData->GetVector(rand() % nVectors));
//		return m_pNN->MeasureMeanSquaredError(m_pDataBatch);

		return m_pNN->MeasureMeanSquaredError(m_pData);
	}
};

// -------------------------------------------------------------------------------


// Main view of the Interpolate demo. It converts a small input image into
// training rows (x, y, green value), trains several models on them, and shows
// each model's prediction rendered at 8x the input size. Slot 0 of the
// per-algorithm arrays is on-line backpropagation (m_pNN1); slots
// 1..SEARCH_ALGORITHM_COUNT-1 are batch search algorithms sharing m_pNN2. The
// last four m_pImageOut entries hold the KNN and forest results, which are
// computed once in the constructor.
class InterpolateView : public ViewBase
{
protected:
	InterpolateDialog* m_pDialog; // background chrome (captions, manual button)
	GImage* m_pImageIn; // the small source image (w.png)
	GImage* m_pImageOut[SEARCH_ALGORITHM_COUNT + 4]; // one 8x-sized prediction image per model
	GArffRelation* m_pRelation; // attributes: X position, Y position, pixel value
	GArffData* m_pTrainingData; // one row per source pixel
	GNeuralNet* m_pNN1; // trained by backpropagation (slot 0)
	GNeuralNet* m_pNN2; // shared by the search algorithms; weights reloaded before each redraw
	NeuralNetTrainingCritic* m_pCritics[SEARCH_ALGORITHM_COUNT]; // index 0 unused (slot 0 is backprop)
	GRealVectorSearch* m_pSearchAlgs[SEARCH_ALGORITHM_COUNT]; // index 0 unused (slot 0 is backprop)
	int m_nCurrentAlgorithm; // slot currently being trained
	double m_dRunningTime[SEARCH_ALGORITHM_COUNT]; // accumulated training time per slot (GTime units)
	double m_dThisStartTime; // start of the current time slice, or 0 when no slice is active
	double m_dNextStopTime; // running-time budget per slot; grows by 1 each full cycle

public:
	InterpolateView(InterpolateController* pController);
	virtual ~InterpolateView();

	virtual void OnChar(char c);
	virtual void OnMouseDown(int nButton, int x, int y);
	virtual void OnMouseUp(int nButton, int x, int y);
	virtual bool OnMousePos(int x, int y);
	// Does one small unit of training; returns true when an image was updated.
	bool DoSomeTraining();

protected:
	virtual void Draw(SDL_Surface *pScreen);
	// Renders pNN's predictions into pImage.
	void UpdateImage(GSupervisedLearner* pNN, GImage* pImage);
};





// Loads the source image and builds the training set from it. Creates the two
// neural networks and the critics/search algorithms that train m_pNN2, then
// trains and renders the four fast (non-iterative) learners once.
InterpolateView::InterpolateView(InterpolateController* pController)
: ViewBase()
{
	m_pDialog = new InterpolateDialog(pController, m_screenRect.w, m_screenRect.h);
	m_pImageIn = new GImage();

	// Build "<app path>w.png". The length check replaces an unchecked
	// strcpy/strcat that overflowed szPath when the application path was long.
	char szPath[256];
	const char* szAppPath = ControllerBase::GetAppPath();
	const char* szImageName = "w.png";
	if(strlen(szAppPath) + strlen(szImageName) < sizeof(szPath))
	{
		strcpy(szPath, szAppPath);
		strcat(szPath, szImageName);
	}
	else
	{
		GAssert(false, "application path is too long");
		szPath[0] = '\0';
	}
	if(!m_pImageIn->LoadPNGFile(szPath))
		GAssert(false, "failed to load input image");
	int n;
	for(n = 0; n < SEARCH_ALGORITHM_COUNT + 4; n++)
	{
		m_pImageOut[n] = new GImage();
		m_pImageOut[n]->SetSize(m_pImageIn->GetWidth() * 8, m_pImageIn->GetHeight() * 8);
	}

	// Create the relation
	m_pRelation = new GArffRelation();
	m_pRelation->AddAttribute(new GArffAttribute(true, 0, NULL)); // X position
	m_pRelation->AddAttribute(new GArffAttribute(true, 0, NULL)); // Y position
	m_pRelation->AddAttribute(new GArffAttribute(false, 0, NULL)); // Pixel value

	// Extract the training data from the image: one row per pixel holding its
	// x, y and green channel, each passed through GMath::digitalToAnalog.
	m_pTrainingData = new GArffData(m_pImageIn->GetHeight() * m_pImageIn->GetWidth());
	int x, y;
	GColor c;
	for(y = 0; y < (int)m_pImageIn->GetHeight(); y++)
	{
		for(x = 0; x < (int)m_pImageIn->GetWidth(); x++)
		{
			c = m_pImageIn->GetPixel(x, y);
			double* pRow = new double[3];
			pRow[0] = GMath::digitalToAnalog(x, m_pImageIn->GetWidth());
			pRow[1] = GMath::digitalToAnalog(y, m_pImageIn->GetHeight());
			pRow[2] = GMath::digitalToAnalog(gGreen(c), 256);
			m_pTrainingData->AddVector(pRow);
		}
	}
	m_pTrainingData->Shuffle();

	// Backpropagation gets its own neural net
	m_pNN1 = new GNeuralNet(m_pRelation);
	m_pNN1->AddLayer(10);
	m_pNN1->AddLayer(10);
	m_pNN1->TrainInit(m_pTrainingData, m_pTrainingData);

	// All the other search algorithms share this one
	m_pNN2 = new GNeuralNet(m_pRelation);
	m_pNN2->AddLayer(10);
	m_pNN2->AddLayer(10);
	m_pNN2->TrainInit(m_pTrainingData, m_pTrainingData);

	// Slot 0 is trained by backpropagation and has no critic or search
	// algorithm; null it so the arrays never hold indeterminate pointers.
	m_pCritics[0] = NULL;
	m_pSearchAlgs[0] = NULL;

	// Make the critics
	for(n = 1; n < SEARCH_ALGORITHM_COUNT; n++)
		m_pCritics[n] = new NeuralNetTrainingCritic(m_pNN2, m_pTrainingData);

	// Make the search algorithms
	if(SEARCH_ALGORITHM_COUNT > 1)
		m_pSearchAlgs[1] = new GStochasticGreedySearch(m_pCritics[1], -.5, 1);
	if(SEARCH_ALGORITHM_COUNT > 2)
		m_pSearchAlgs[2] = new GMomentumGreedySearch(m_pCritics[2]);
	if(SEARCH_ALGORITHM_COUNT > 3)
		m_pSearchAlgs[3] = new GParticleSwarm(m_pCritics[3], 80, -1, 2);
	if(SEARCH_ALGORITHM_COUNT > 4)
		m_pSearchAlgs[4] = new GEvolutionarySearch(m_pCritics[4], 30, 12);
	if(SEARCH_ALGORITHM_COUNT > 5)
		m_pSearchAlgs[5] = new GStabSearch(m_pCritics[5], -5, 10);

	// Init the algorithm cycle system. Starting on the last slot makes the
	// first DoSomeTraining call wrap to slot 0 and set the first budget.
	for(n = 0; n < SEARCH_ALGORITHM_COUNT; n++)
		m_dRunningTime[n] = 0;
	m_nCurrentAlgorithm = SEARCH_ALGORITHM_COUNT - 1;
	m_dNextStopTime = 0;
	m_dThisStartTime = 0;

	// Train the extra 4 fast algorithm images
	{
		GKNN learner(m_pRelation, 4, false);
		learner.Train(m_pTrainingData);
		UpdateImage(&learner, m_pImageOut[SEARCH_ALGORITHM_COUNT]);
	}
	{
		GBag bag(m_pRelation, 30);
		int i;
		for(i = 0; i < 30; i++)
			bag.AddLearner(new GArbitraryTree(m_pRelation, true));
		bag.Train(m_pTrainingData);
		UpdateImage(&bag, m_pImageOut[SEARCH_ALGORITHM_COUNT + 1]);
	}
	{
		GBag bag(m_pRelation, 30);
		int i;
		for(i = 0; i < 30; i++)
			bag.AddLearner(new GArbitraryTree(m_pRelation, false));
		bag.Train(m_pTrainingData);
		UpdateImage(&bag, m_pImageOut[SEARCH_ALGORITHM_COUNT + 2]);
	}
	{
		GBag bag(m_pRelation, 30);
		int i;
		for(i = 0; i < 30; i++)
			bag.AddLearner(new GPCTree(m_pRelation));
		bag.Train(m_pTrainingData);
		UpdateImage(&bag, m_pImageOut[SEARCH_ALGORITHM_COUNT + 3]);
	}
}

// Releases everything the constructor allocated, in reverse dependency order:
// - each search algorithm before the critic it was constructed with;
// - the critics before the network and data they point at;
// - the networks before the training data and relation they were initialized with;
// - the images and the dialog last.
// The original order freed the relation and data while the networks were
// still alive, and each critic before the search algorithm holding it. If any
// of those destructors dereference their inputs, that order is a use-after-free.
// Slot 0 has no critic or search algorithm, so both loops start at 1.
InterpolateView::~InterpolateView()
{
	int n;
	for(n = 1; n < SEARCH_ALGORITHM_COUNT; n++)
	{
		delete(m_pSearchAlgs[n]);
		delete(m_pCritics[n]);
	}
	delete(m_pNN1);
	delete(m_pNN2);
	delete(m_pTrainingData);
	delete(m_pRelation);
	for(n = 0; n < SEARCH_ALGORITHM_COUNT + 4; n++)
		delete(m_pImageOut[n]);
	delete(m_pImageIn);
	delete(m_pDialog);
}

// Paints the frame: a black background, the dialog chrome, the small source
// image, and then every output image at its fixed offset.
/*virtual*/ void InterpolateView::Draw(SDL_Surface *pScreen)
{
	const int nLeft = m_screenRect.x;
	const int nTop = m_screenRect.y;

	// Background, dialog, and the source image.
	SDL_FillRect(pScreen, NULL, 0x000000);
	BlitImage(pScreen, nLeft, nTop, m_pDialog->GetImage());
	BlitImage(pScreen, nLeft + 150, nTop + 25, m_pImageIn);

	// Offsets of the output images, indexed like m_pImageOut. They sit just
	// below the captions created by InterpolateDialog.
	static const int s_outputOffsets[10][2] =
	{
		{ 200, 25 }, { 400, 25 }, { 600, 25 },
		{ 200, 225 }, { 400, 225 }, { 600, 225 },
		{ 10, 425 }, { 200, 425 }, { 400, 425 }, { 600, 425 },
	};

	// Image 0 (backpropagation) is always drawn; the other nine only when
	// batch algorithms are configured.
	const int nLastImage = (SEARCH_ALGORITHM_COUNT > 1) ? 9 : 0;
	for(int i = 0; i <= nLastImage; i++)
		BlitImage(pScreen, nLeft + s_outputOffsets[i][0], nTop + s_outputOffsets[i][1], m_pImageOut[i]);
}

// Forwards keystrokes to the dialog.
void InterpolateView::OnChar(char c)
{
	m_pDialog->HandleChar(c);
}

void InterpolateView::OnMouseDown(int nButton, int x, int y)
{
	x -= m_screenRect.x;
	y -= m_screenRect.y;
	GWidgetAtomic* pNewWidget = m_pDialog->FindAtomicWidget(x, y);
	m_pDialog->GrabWidget(pNewWidget, nButton, x, y);
}

// Forwards a mouse release to the dialog; the release position is not used.
void InterpolateView::OnMouseUp(int nButton, int x, int y)
{
	m_pDialog->ReleaseWidget(nButton);
}

// Forwards mouse movement to the dialog in dialog-local coordinates and
// returns the dialog's result (presumably "needs redraw" — confirm).
bool InterpolateView::OnMousePos(int x, int y)
{
	return m_pDialog->HandleMousePos(x - m_screenRect.x, y - m_screenRect.y);
}

// Renders pNN's prediction for every pixel of pImage. Output pixel (x, y)
// queries the model at source-image coordinate (x / 8, y / 8), matching the
// 8x size the constructor gives the output images. The prediction is mapped
// back to 0..255, clamped, and written as a gray pixel.
void InterpolateView::UpdateImage(GSupervisedLearner* pNN, GImage* pImage)
{
	const double dScale = 8;
	const int nOutWidth = pImage->GetWidth();
	const int nOutHeight = pImage->GetHeight();
	double pRow[3];
	for(int nY = 0; nY < nOutHeight; nY++)
	{
		for(int nX = 0; nX < nOutWidth; nX++)
		{
			// Columns 0 and 1 are the inputs; Eval fills column 2, read back below.
			pRow[0] = GMath::digitalToAnalog(nX / dScale, m_pImageIn->GetWidth());
			pRow[1] = GMath::digitalToAnalog(nY / dScale, m_pImageIn->GetHeight());
			pNN->Eval(pRow);
			int nGray = GMath::analogToDigital(pRow[2], 256);
			nGray = (nGray < 0) ? 0 : ((nGray > 255) ? 255 : nGray);
			pImage->SetPixel(nX, nY, gRGB(nGray, nGray, nGray));
		}
	}
}

// Runs one small unit of training and reports whether the screen needs a
// redraw. The algorithms share time round-robin:
// - each slot trains until its accumulated running time reaches m_dNextStopTime;
// - its output image is then refreshed and the next slot takes over;
// - once every slot has had a turn, the budget grows by 1 (GTime units).
// As a result, all slots end up with roughly equal total training time.
// Returns true when an output image was just updated.
bool InterpolateView::DoSomeTraining()
{
	// Cycle the algorithm. m_dThisStartTime == 0 means no slice is in
	// progress: advance to the next slot, wrapping to 0 and raising the budget.
	double dTime;
	if(m_dThisStartTime == 0)
	{
		m_nCurrentAlgorithm++;
		if(m_nCurrentAlgorithm >= SEARCH_ALGORITHM_COUNT)
		{
			m_nCurrentAlgorithm = 0;
			m_dNextStopTime += 1;
		}
		m_dThisStartTime = GTime::GetTime();
		dTime = 0;
	}
	else
		dTime = GTime::GetTime() - m_dThisStartTime;
	// Total training time of this slot, including earlier slices.
	dTime += m_dRunningTime[m_nCurrentAlgorithm];
	if(dTime >= m_dNextStopTime)
	{
		// Budget reached: bank the time, end the slice and redraw this slot.
		m_dRunningTime[m_nCurrentAlgorithm] = dTime;
		m_dThisStartTime = 0;
		if(m_nCurrentAlgorithm == 0)
			UpdateImage(m_pNN1, m_pImageOut[m_nCurrentAlgorithm]);
		else
		{
			// m_pNN2 is shared by all search algorithms, so load this
			// algorithm's best weights before rendering it.
			m_pNN2->SetWeights(m_pCritics[m_nCurrentAlgorithm]->GetBestYet());
			UpdateImage(m_pNN2, m_pImageOut[m_nCurrentAlgorithm]);
		}
		return true;
	}

	// Train: one backpropagation epoch for slot 0, one search iteration otherwise.
	if(m_nCurrentAlgorithm == 0)
	{
		m_pNN1->TrainEpoch();
		//printf("Epochs: %d, Error %f\n", m_pNN1->GetEpochs(), m_pNN1->GetError());
	}
	else
		m_pSearchAlgs[m_nCurrentAlgorithm]->Iterate();

	return false;
}


// -------------------------------------------------------------------------------



// Creates the demo's view; the destructor below deletes it.
InterpolateController::InterpolateController()
: ControllerBase()
{
	m_pView = new InterpolateView(this);
}

// Destroys the view created by the constructor.
InterpolateController::~InterpolateController()
{
	delete(m_pView);
}

void InterpolateController::RunModal()
{
	double timeOld = GTime::GetTime();
	double time;
	m_pView->Update();
	while(m_bKeepRunning)
	{
		time = GTime::GetTime();
		if(HandleEvents(time - timeOld) ||
			((InterpolateView*)m_pView)->DoSomeTraining())
		{
			m_pView->Update();
		}
/*		else
			GThread::sleep(10);
*/
		timeOld = time;
	}
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?