📄 interpolate.cpp
字号:
// ********************************************************// This is demo code. You may derive from, use, modify, and// distribute it without limitation for any purpose.// Obviously you don't get a warranty or an assurance of// fitness for a particular purpose with this code. Your// welcome to remove this header and claim original// authorship. I really don't care.// ********************************************************#ifndef WIN32#include <unistd.h>#endif // !WIN32#include "Interpolate.h"#include "../GClasses/GMacros.h"#include "../GClasses/GTime.h"#include "../GClasses/GArff.h"#include "../GClasses/GNeuralNet.h"#include "../GClasses/GMath.h"#include "../GClasses/GTime.h"#include "../GClasses/GSearch.h"#include "../GClasses/GGreedySearch.h"#include "../GClasses/GParticleSwarm.h"#include "../GClasses/GConfSearch.h"#include "../GClasses/GGenetic.h"#include "../GClasses/GStabSearch.h"#define BACKGROUND_COLOR 0xffaaaa33InterpolateDialog::InterpolateDialog(InterpolateController* pController, int w, int h) : GWidgetDialog(w, h, BACKGROUND_COLOR){ m_pController = pController; GString sCancel(L"Cancel"); m_pCancelButton = new GWidgetTextButton(this, 25, 25, 100, 25, &sCancel); GString s; s.Copy(L"Backpropagation"); new GWidgetTextLabel(this, 200, 0, 200, 25, &s, 0xff442266); s.Copy(L"Stochastic Greedy"); new GWidgetTextLabel(this, 400, 0, 200, 25, &s, 0xff442266); s.Copy(L"Momentum Greedy"); new GWidgetTextLabel(this, 600, 0, 200, 25, &s, 0xff442266); s.Copy(L"Particle Swarm"); new GWidgetTextLabel(this, 200, 200, 200, 25, &s, 0xff442266); s.Copy(L"Evolutionary Search"); new GWidgetTextLabel(this, 400, 200, 200, 25, &s, 0xff442266); s.Copy(L"Stab Search"); new GWidgetTextLabel(this, 600, 200, 200, 25, &s, 0xff442266); s.Copy(L"(Be patient, this demo takes several hours.)"); new GWidgetTextLabel(this, 100, 500, 500, 25, &s, 0xff442266);}/*virtual*/ InterpolateDialog::~InterpolateDialog(){}/*virtual*/ void InterpolateDialog::OnReleaseTextButton(GWidgetTextButton* pButton){ if(pButton == m_pCancelButton) { m_pController->Quit(); }}// -------------------------------------------------------------------------------class NeuralNetTrainingCritic : public GRealVectorCritic{protected: GNeuralNet* m_pNN; GArffRelation* m_pRel; GArffData* m_pData;public: NeuralNetTrainingCritic(GNeuralNet* pNN, GArffData* pData) : GRealVectorCritic(pNN->GetWeightCount()) { m_pNN = pNN; m_pRel = pNN->GetRelation(); m_pData = pData; } virtual ~NeuralNetTrainingCritic() { }protected: virtual double ComputeError(double* pVector) { m_pNN->SetWeights(pVector); return m_pNN->MeasureMeanSquaredError(m_pData); }};// -------------------------------------------------------------------------------InterpolateView::InterpolateView(InterpolateController* pController): ViewBase(){ m_pDialog = new InterpolateDialog(pController, 790, 590); m_pImageIn = new GImage(); char szPath[256]; strcpy(szPath, ControllerBase::GetAppPath()); strcat(szPath, "w.png"); if(!m_pImageIn->LoadPNGFile(szPath)) GAssert(false, "failed to load input image"); int n; for(n = 0; n < SEARCH_ALGORITHM_COUNT; n++) { m_pImageOut[n] = new GImage(); m_pImageOut[n]->SetSize(m_pImageIn->GetWidth() * 8, m_pImageIn->GetHeight() * 8); } // Create the relation m_pRelation = new GArffRelation(); m_pRelation->AddAttribute(new GArffAttribute(true, 0, NULL)); // X position m_pRelation->AddAttribute(new GArffAttribute(true, 0, NULL)); // Y position m_pRelation->AddAttribute(new GArffAttribute(false, 0, NULL)); // Pixel value // Extract the training data from the image m_pTrainingData = new GArffData(m_pImageIn->GetHeight() * m_pImageIn->GetWidth()); int x, y; GColor c; for(y = 0; y < (int)m_pImageIn->GetHeight(); y++) { for(x = 0; x < (int)m_pImageIn->GetWidth(); x++) { c = m_pImageIn->GetPixel(x, y); double* pRow = new double[3]; pRow[0] = GMath::digitalToAnalog(x, m_pImageIn->GetWidth()); pRow[1] = GMath::digitalToAnalog(y, m_pImageIn->GetHeight()); pRow[2] = GMath::digitalToAnalog(gGreen(c), 256); m_pTrainingData->AddVector(pRow); } } m_pTrainingData->Shuffle(); // Backpropagation gets its own neural net m_pNN1 = new GNeuralNet(m_pRelation); m_pNN1->AddLayer(10); m_pNN1->AddLayer(10); m_pNN1->SetLearningRate(.3); m_pNN1->SetMomentum(.9); m_pNN1->TrainInit(m_pTrainingData, m_pTrainingData); // All the other search algorithms share this one m_pNN2 = new GNeuralNet(m_pRelation); m_pNN2->AddLayer(10); m_pNN2->AddLayer(10); m_pNN2->TrainInit(m_pTrainingData, m_pTrainingData); // Make the critics for(n = 1; n < SEARCH_ALGORITHM_COUNT; n++) m_pCritics[n] = new NeuralNetTrainingCritic(m_pNN2, m_pTrainingData); // Make the search algorithms m_pSearchAlgs[1] = new GStochasticGreedySearch(m_pCritics[1], -.5, 1); m_pSearchAlgs[2] = new GMomentumGreedySearch(m_pCritics[2]); m_pSearchAlgs[3] = new GParticleSwarm(m_pCritics[3], 80, -1, 2); m_pSearchAlgs[4] = new GEvolutionarySearch(m_pCritics[4], 30, 12); m_pSearchAlgs[5] = new GStabSearch(m_pCritics[5], -5, 10); // Init the algorithm cycle system for(n = 0; n < SEARCH_ALGORITHM_COUNT; n++) m_dRunningTime[n] = 0; m_nCurrentAlgorithm = SEARCH_ALGORITHM_COUNT - 1; m_dNextStopTime = 0; m_dThisStartTime = 0;}InterpolateView::~InterpolateView(){ delete(m_pDialog); delete(m_pImageIn); int n; for(n = 0; n < SEARCH_ALGORITHM_COUNT; n++) delete(m_pImageOut[n]); delete(m_pRelation); delete(m_pTrainingData); delete(m_pNN1); delete(m_pNN2); for(n = 1; n < SEARCH_ALGORITHM_COUNT; n++) { delete(m_pCritics[n]); delete(m_pSearchAlgs[n]); }}/*virtual*/ void InterpolateView::Draw(SDL_Surface *pScreen){ // Clear the screen SDL_FillRect(pScreen, NULL/*&r*/, 0x000000); // Draw the dialog GRect r; BlitImage(pScreen, m_screenRect.x, m_screenRect.y, m_pDialog->GetImage(&r)); // Draw the small image BlitImage(pScreen, m_screenRect.x + 150, m_screenRect.y + 25, m_pImageIn); // Draw the big interpolated images BlitImage(pScreen, m_screenRect.x + 200, m_screenRect.y + 25, m_pImageOut[0]); BlitImage(pScreen, m_screenRect.x + 400, m_screenRect.y + 25, m_pImageOut[1]); BlitImage(pScreen, m_screenRect.x + 600, m_screenRect.y + 25, m_pImageOut[2]); BlitImage(pScreen, m_screenRect.x + 200, m_screenRect.y + 225, m_pImageOut[3]); BlitImage(pScreen, m_screenRect.x + 400, m_screenRect.y + 225, m_pImageOut[4]); BlitImage(pScreen, m_screenRect.x + 600, m_screenRect.y + 225, m_pImageOut[5]);}void InterpolateView::OnChar(char c){ m_pDialog->HandleChar(c);}void InterpolateView::OnMouseDown(int x, int y){ x -= m_screenRect.x; y -= m_screenRect.y; GWidgetAtomic* pNewWidget = m_pDialog->FindAtomicWidget(x, y); m_pDialog->GrabWidget(pNewWidget, x, y);}void InterpolateView::OnMouseUp(int x, int y){ m_pDialog->ReleaseWidget();}bool InterpolateView::OnMousePos(int x, int y){ return m_pDialog->HandleMousePos(x - m_screenRect.x, y - m_screenRect.y);}void InterpolateView::UpdateImage(GNeuralNet* pNN, GImage* pImage){ double sample[3]; int width = (int)pImage->GetWidth(); int height = (int)pImage->GetHeight(); int x, y; for(y = 0; y < height; y++) { for(x = 0; x < width; x++) { sample[0] = GMath::digitalToAnalog((double)x / 8, m_pImageIn->GetWidth()); sample[1] = GMath::digitalToAnalog((double)y / 8, m_pImageIn->GetHeight()); pNN->Eval(sample); int n = GMath::analogToDigital(sample[2], 256); if(n < 0) n = 0; if(n > 255) n = 255; pImage->SetPixel(x, y, gRGB(n, n, n)); } }}bool InterpolateView::DoSomeTraining(){ // Cycle the algorithm double dTime; if(m_dThisStartTime == 0) { m_nCurrentAlgorithm++; if(m_nCurrentAlgorithm >= SEARCH_ALGORITHM_COUNT) { m_nCurrentAlgorithm = 0; m_dNextStopTime += 4; } m_dThisStartTime = GTime::GetTime(); dTime = 0; } else dTime = GTime::GetTime() - m_dThisStartTime; dTime += m_dRunningTime[m_nCurrentAlgorithm]; if(dTime >= m_dNextStopTime) { m_dRunningTime[m_nCurrentAlgorithm] = dTime; m_dThisStartTime = 0; if(m_nCurrentAlgorithm == 0) UpdateImage(m_pNN1, m_pImageOut[m_nCurrentAlgorithm]); else { m_pNN2->SetWeights(m_pCritics[m_nCurrentAlgorithm]->GetBestYet()); UpdateImage(m_pNN2, m_pImageOut[m_nCurrentAlgorithm]); } return true; } // Train if(m_nCurrentAlgorithm == 0) m_pNN1->TrainEpoch(); else m_pSearchAlgs[m_nCurrentAlgorithm]->Iterate(); return false;}// -------------------------------------------------------------------------------InterpolateController::InterpolateController(): ControllerBase(){ m_pView = new InterpolateView(this);}InterpolateController::~InterpolateController(){ delete(m_pView);}void InterpolateController::RunModal(){ double timeOld = GTime::GetTime(); double time; m_pView->Update(); while(m_bKeepRunning) { time = GTime::GetTime(); if(HandleEvents(time - timeOld) || ((InterpolateView*)m_pView)->DoSomeTraining()) { m_pView->Update(); }/* else#ifdef WIN32 Sleep(10);#else // WIN32 usleep(10);#endif // !WIN32*/ timeOld = time; }}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -