chart.cpp

来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C++ 代码 · 共 453 行

CPP
453
字号
// --------------------------------------------------------
// This demo file is dedicated to the Public Domain. See:
// http://creativecommons.org/licenses/publicdomain
// --------------------------------------------------------

#include "Chart.h"
#ifdef WIN32
#else // WIN32
#include <unistd.h>
#endif // !WIN32
#include "../GClasses/GArff.h"
#include "../GClasses/GBits.h"
#include "../GClasses/GFile.h"
#include "../GClasses/GTime.h"
#include "../GClasses/GMacros.h"
#include "../GClasses/GThread.h"
#include "../GClasses/GGreedySearch.h"
#include <math.h>

// First screen of the Chart demo: a file-system browser filtered to ".arff"
// files plus a manual button. The selected filename is handed to the
// ChartController via OnSelectFile.
class GetFileDialog : public GWidgetDialog
{
protected:
	ChartController* m_pController;
	GWidgetImageButton* m_pButtonManual;
	GWidgetFileSystemBrowser* m_pFileSystemBrowser;

public:
	GetFileDialog(ChartController* pController, int w, int h)
	: GWidgetDialog(w, h, 0xff338822), m_pController(pController)
	{
		// Manual button, placed near the top-right corner
		GImage* pImg = ControllerBase::GetManualImage();
		int nButtonX = w - pImg->GetWidth() / 2 - 5;
		m_pButtonManual = new GWidgetImageButton(this, nButtonX, 5, pImg);

		// Prompt, followed by the browser that fills the rest of the dialog
		GString sPrompt;
		sPrompt.Copy(L"Please select a data file to chart:");
		new GWidgetTextLabel(this, 5, 50, w - 10, 20, &sPrompt, 0xffffffff, 0);
		m_pFileSystemBrowser = new GWidgetFileSystemBrowser(this, 5, 70, w - 10, h - 75, ".arff");
	}

	virtual ~GetFileDialog()
	{
	}

	// Only the manual button is expected here; anything else is a programming error.
	virtual void OnReleaseImageButton(GWidgetImageButton* pButton)
	{
		if(pButton != m_pButtonManual)
		{
			GAssert(false, "unrecognized image button");
			return;
		}
		// NOTE(review): ChartDialog opens "doc/Waffles/Chart.html" (no "../");
		// confirm which relative path OpenAppFile actually expects.
		OpenAppFile("../doc/Waffles/Chart.html");
	}

	// Forwards the chosen file to the controller, which loads it.
	virtual void OnSelectFilename(GWidgetFileSystemBrowser* pBrowser, const char* szFilename)
	{
		m_pController->OnSelectFile(szFilename);
	}
};


// Second screen of the Chart demo. It shows the loaded relation in a
// GWidgetRelation so the user can choose inputs and outputs. It then offers
// three buttons that ask the controller for a line graph, a logarithmic line
// graph, or a Big-O fit.
class ChartDialog : public GWidgetDialog
{
protected:
	ChartController* m_pController;
	GWidgetImageButton* m_pButtonManual;
	GWidgetRelation* m_pRelation;
	GWidgetTextButton* m_pLineGraphButton;
	GWidgetTextButton* m_pLogLineGraphButton;
	GWidgetTextButton* m_pComputeBigOButton;

public:
	ChartDialog(ChartController* pController, GArffRelation* pRelation, int w, int h)
	: GWidgetDialog(w, h, 0xff338822), m_pController(pController)
	{
		// Manual button, placed near the top-right corner
		GImage* pImg = ControllerBase::GetManualImage();
		m_pButtonManual = new GWidgetImageButton(this, w - pImg->GetWidth() / 2 - 5, 5, pImg);

		// A single GString is reused for every caption below
		GString sText;
		sText.Copy(L"Select inputs and outputs:");
		new GWidgetTextLabel(this, 5, 35, w - 10, 20, &sText, 0xffffffff, 0);
		m_pRelation = new GWidgetRelation(pRelation, this, 5, 55, w - 10, 150);

		sText.Copy(L"Then choose a graph type:");
		new GWidgetTextLabel(this, 5, 230, w - 10, 20, &sText, 0xffffffff, 0);
		sText.Copy(L"Line Graph");
		m_pLineGraphButton = new GWidgetTextButton(this, 100, 400, 150, 24, &sText);
		sText.Copy(L"Log Line Graph");
		m_pLogLineGraphButton = new GWidgetTextButton(this, 100, 450, 150, 24, &sText);
		sText.Copy(L"Compute Big-O");
		m_pComputeBigOButton = new GWidgetTextButton(this, 300, 400, 150, 24, &sText);
	}

	virtual ~ChartDialog()
	{
	}

	// Only the manual button is expected here; anything else is a programming error.
	virtual void OnReleaseImageButton(GWidgetImageButton* pButton)
	{
		if(pButton != m_pButtonManual)
		{
			GAssert(false, "unrecognized image button");
			return;
		}
		// NOTE(review): GetFileDialog opens "../doc/Waffles/Chart.html";
		// confirm which relative path OpenAppFile actually expects.
		OpenAppFile("doc/Waffles/Chart.html");
	}

	// Dispatches the three graph buttons to the controller; other buttons are ignored.
	virtual void OnReleaseTextButton(GWidgetTextButton* pButton)
	{
		if(pButton == m_pComputeBigOButton)
		{
			m_pController->ComputeBigO();
			return;
		}
		if(pButton == m_pLineGraphButton || pButton == m_pLogLineGraphButton)
			m_pController->MakeLineGraphChart(pButton == m_pLogLineGraphButton);
	}
};


// ----------------------------------------------------------------------------

class ChartView : public ViewBase
{
friend class ChartController;
protected:
	bool m_bGotFile;
	GetFileDialog* m_pGetFileDialog;
	ChartDialog* m_pDialog;
	GImage* m_pImage;

public:
	ChartView(ChartController* pController)
	: ViewBase()
	{
		m_bGotFile = false;
		m_pGetFileDialog = new GetFileDialog(pController, m_screenRect.w, m_screenRect.h);
		m_pDialog = NULL;
		m_pImage = NULL;
	}

	~ChartView()
	{
		delete(m_pDialog);
		delete(m_pGetFileDialog);
	}

	void OnOpenFile(ChartController* pController, GArffRelation* pRelation)
	{
		delete(m_pDialog);
		m_pDialog = new ChartDialog(pController, pRelation, m_screenRect.w, m_screenRect.h);
		m_bGotFile = true;
	}

	void OnChar(char c)
	{
		GWidgetDialog* pDialog = (m_bGotFile ? (GWidgetDialog*)m_pDialog : (GWidgetDialog*)m_pGetFileDialog);
		pDialog->HandleChar(c);
	}

	void OnMouseDown(int nButton, int x, int y)
	{
		x -= m_screenRect.x;
		y -= m_screenRect.y;
		GWidgetDialog* pDialog = (m_bGotFile ? (GWidgetDialog*)m_pDialog : (GWidgetDialog*)m_pGetFileDialog);
		GWidgetAtomic* pNewWidget = pDialog->FindAtomicWidget(x, y);
		pDialog->GrabWidget(pNewWidget, nButton, x, y);
	}

	void OnMouseUp(int nButton, int x, int y)
	{
		GWidgetDialog* pDialog = (m_bGotFile ? (GWidgetDialog*)m_pDialog : (GWidgetDialog*)m_pGetFileDialog);
		pDialog->ReleaseWidget(nButton);
	}
	
	bool OnMousePos(int x, int y)
	{
		GWidgetDialog* pDialog = (m_bGotFile ? (GWidgetDialog*)m_pDialog : (GWidgetDialog*)m_pGetFileDialog);
		return pDialog->HandleMousePos(x - m_screenRect.x, y - m_screenRect.y);
	}
	
	void SetImage(GImage* pImage)
	{
		m_pImage = pImage;
	}

protected:
	virtual void Draw(SDL_Surface *pScreen)
	{
		GImage* pCanvas;
		if(m_bGotFile)
			pCanvas = m_pDialog->GetImage();
		else
			pCanvas = m_pGetFileDialog->GetImage();
		BlitImage(pScreen, m_screenRect.x, m_screenRect.y, pCanvas);
	}
};








// -------------------------------------------------------------------------------


// Starts with no data loaded. The view opens on the file-selection dialog.
// The data pointers are cleared before `this` is handed to the view.
ChartController::ChartController()
: ControllerBase()
{
	m_pRelation = NULL;
	m_pData = NULL;
	m_pView = new ChartView(this);
}

// Tears down the view first. Its chart dialog was built from m_pRelation and
// presumably still refers to it. The relation and the data are freed afterwards.
ChartController::~ChartController()
{
	delete m_pView;
	delete m_pRelation;
	delete m_pData;
}

void ChartController::OnSelectFile(const char* szFilename)
{
	GArffRelation::LoadArffFile(&m_pRelation, &m_pData, szFilename);
	m_pRelation->GetAttribute(m_pRelation->GetAttributeCount() - 1)->SetIsInput(false);
	((ChartView*)m_pView)->OnOpenFile(this, m_pRelation);
}

// Main loop. Runs until m_bKeepRunning is cleared. On each pass it calls
// HandleEvents with the time elapsed since the previous pass. It redraws the
// view only when HandleEvents reports a change; otherwise it calls
// GThread::sleep(10) so the loop does not spin.
void ChartController::RunModal()
{
	double dPrevTime = GTime::GetTime();
	m_pView->Update();
	while(m_bKeepRunning)
	{
		double dNow = GTime::GetTime();
		// HandleEvents returns true if it thinks the view needs to be updated
		if(HandleEvents(dNow - dPrevTime))
			m_pView->Update();
		else
			GThread::sleep(10);
		dPrevTime = dNow;
	}
}

/*
void ChartController::Make3DBarGraph()
{
	
}
*/

// Renders a line graph of the loaded data and opens it in an external viewer.
//
// One nChartWidth x nChartHeight panel is drawn per input attribute, stacked
// vertically. Every output attribute is plotted against that input, one colour
// per output, with points joined in row order. All outputs share one vertical
// scale. When bLogarithmic is true, both axes show log(value) and grid lines are
// drawn at 1..9 times each power of ten. The linear mode draws no grid yet.
// The image is saved as chart.bmp on WIN32 and as chart.png elsewhere.
//
// Throws a const char* in two cases:
//  - there are no outputs;
//  - a logarithmic chart is requested and a charted attribute has a
//    non-positive minimum, whose log is undefined.
void ChartController::MakeLineGraphChart(bool bLogarithmic)
{
	// Compute the size of the image
	int nBorderSize = 20;
	int nChartWidth = 800;
	int nChartHeight = 800;
	int nOutputCount = m_pRelation->GetOutputCount();
	if(nOutputCount <= 0)
	{
		GAssert(nOutputCount > 0, "There are no output values to chart");
		throw "There are no output values to chart";
	}

	// All outputs share one vertical scale, and that scale does not depend on
	// which input is being charted, so compute it once. It is the union of every
	// output's [min, min + range] interval.
	double dOutputMin, dOutputRange;
	m_pData->GetMinAndRange(m_pRelation->GetOutputIndex(0), &dOutputMin, &dOutputRange);
	double dOutputMax = dOutputMin + dOutputRange;
	for(int nOutput = 1; nOutput < nOutputCount; nOutput++)
	{
		double dTmpMin, dTmpRange;
		m_pData->GetMinAndRange(m_pRelation->GetOutputIndex(nOutput), &dTmpMin, &dTmpRange);
		if(dTmpMin < dOutputMin)
			dOutputMin = dTmpMin;
		// Compare upper ends, not range widths. A narrow series that sits above
		// the others must still fit on the chart.
		if(dTmpMin + dTmpRange > dOutputMax)
			dOutputMax = dTmpMin + dTmpRange;
	}
	if(bLogarithmic)
	{
		if(dOutputMin <= 0)
			throw "A logarithmic chart requires all output values to be positive";
		dOutputMin = log(dOutputMin);
		dOutputMax = log(dOutputMax);
	}
	dOutputRange = dOutputMax - dOutputMin;
	if(dOutputRange <= 0)
		dOutputRange = 1; // constant outputs: avoid dividing by zero when scaling

	GImage image;
	image.SetSize(2 * nBorderSize + nChartWidth, nBorderSize + m_pRelation->GetInputCount() * (nChartHeight + nBorderSize));
	image.Clear(0xffaaaaaa);

	// Make a chart for each input attribute
	for(int nInput = 0; nInput < m_pRelation->GetInputCount(); nInput++)
	{
		// Compute the horizontal range
		int nInputAttr = m_pRelation->GetInputIndex(nInput);
		double dInputMin, dInputRange;
		m_pData->GetMinAndRange(nInputAttr, &dInputMin, &dInputRange);
		if(bLogarithmic)
		{
			if(dInputMin <= 0)
				throw "A logarithmic chart requires all input values to be positive";
			double dInputMax = dInputMin + dInputRange;
			dInputMin = log(dInputMin);
			dInputRange = log(dInputMax) - dInputMin;
		}
		if(dInputRange <= 0)
			dInputRange = 1; // constant input: avoid dividing by zero when scaling

		// Clear the chart background. The panel spans rows top..bottom inclusive,
		// the same rows that the grid lines and data points are mapped into.
		int left = nBorderSize;
		int bottom = nBorderSize + nInput * (nChartHeight + nBorderSize) + nChartHeight - 1;
		int top = bottom - nChartHeight + 1;
		image.FillBox(left, top, nChartWidth, nChartHeight, 0xffffffff);

		// Plot the grid lines
		double x, y;
		int xx, yy;
		if(bLogarithmic)
		{
			// Vertical lines at 1..9 x 10^n, starting from the decade at or below the minimum
			int n = GBits::RoundDown((float)(dInputMin / log((double)10)));
			bool bDone = false;
			while(!bDone)
			{
				double gridbase = pow((double)10, n);
				for(int i = 1; i < 10; i++)
				{
					x = log(gridbase * i);
					xx = left + (int)((x - dInputMin) * .9999999 * nChartWidth / dInputRange);
					if(xx < left)
						continue;
					if(xx >= left + nChartWidth)
					{
						bDone = true;
						break;
					}
					image.DrawLine(xx, bottom, xx, top, 0xff888888);
				}
				n++;
			}

			// Horizontal lines at 1..9 x 10^n, starting from the decade at or below the minimum
			n = GBits::RoundDown((float)(dOutputMin / (double)log((double)10)));
			bDone = false;
			while(!bDone)
			{
				double gridbase = pow((double)10, n);
				for(int i = 1; i < 10; i++)
				{
					y = log(gridbase * i);
					yy = bottom - (int)((y - dOutputMin) * .9999999 * nChartHeight / dOutputRange);
					if(yy > bottom)
						continue;
					if(yy < top)
					{
						bDone = true;
						break;
					}
					image.DrawLine(left, yy, left + nChartWidth - 1, yy, 0xff888888);
				}
				n++;
			}
		}
		else
		{
			// todo: draw grid lines for linear charts
		}

		// todo: Label the grid lines

		// Plot the data. Each output is its own series and gets its own colour.
		int xxPrev = 0, yyPrev = 0;
		for(int nOutput = 0; nOutput < nOutputCount; nOutput++)
		{
			int nOutputAttr = m_pRelation->GetOutputIndex(nOutput);
			GColor color = GetSpectrumColor((float)nOutput / nOutputCount + (float).75);
			for(int i = 0; i < m_pData->GetSize(); i++)
			{
				// Compute chart coordinates. The .9999999 factor keeps the maximum
				// value inside the panel instead of one pixel past it.
				double* pVector = m_pData->GetVector(i);
				x = pVector[nInputAttr];
				y = pVector[nOutputAttr];
				if(bLogarithmic)
				{
					x = log(x);
					y = log(y);
				}
				xx = left + (int)((x - dInputMin) * .9999999 * nChartWidth / dInputRange);
				yy = bottom - (int)((y - dOutputMin) * .9999999 * nChartHeight / dOutputRange);

				// Plot the point and join it to the previous point of the same series
				image.DrawCircle(xx, yy, 4, color);
				if(i > 0)
					image.DrawLine(xxPrev, yyPrev, xx, yy, color);
				xxPrev = xx;
				yyPrev = yy;
			}
		}
	}

#ifdef WIN32
	image.SaveBMPFile("chart.bmp");
	OpenFile("chart.bmp");
#else // WIN32
	image.SavePNGFile("chart.png");
	OpenFile("chart.png");
#endif // !WIN32
}

// Regression critic for the model t = A * n^B + C.
// pVariables[0] is A, pVariables[1] is B and pVariables[2] is C.
class MyRegressCritic : public GArffDataRegressCritic
{
public:
	MyRegressCritic(GArffData* pData, int nVariables, int nAttrX, int nAttrY)
	: GArffDataRegressCritic(pData, nVariables, nAttrX, nAttrY)
	{
	}

	virtual ~MyRegressCritic() {}

protected:
	// Evaluates the model at dX using the candidate variables.
	virtual double ApplyVariables(double dX, double* pVariables)
	{
		double dScale = pVariables[0];
		double dExponent = pVariables[1];
		double dOffset = pVariables[2];
		return dScale * pow(dX, dExponent) + dOffset;
	}
};

// defined in Test.cpp
bool IsPrettyClose(double a, double b);

void ChartController::ComputeBigO()
{
int j;
for(j = 1; j < m_pRelation->GetAttributeCount(); j++)
	m_pRelation->GetAttribute(j)->SetIsInput(false);

	// Regress it
	MyRegressCritic critic(m_pData, /*numVars*/3, /*attrX*/m_pRelation->GetInputIndex(0), /*attrY*/m_pRelation->GetOutputIndex(0));
	GMomentumGreedySearch search(&critic);
	double* pVariables;
	double a = -1;
	double b = -1;
	double c = -1;
	int i;
	while(true)
	{
		for(i = 0; i < 50000; i++)
			search.Iterate();
		pVariables = critic.GetBestYet();
		printf("Crunching... (t=A(n^B)+c   A=%f  B=%f  C=%f)\n", a, b, c);
		if(ABS(pVariables[0] - a) < .0000001 &&
			ABS(pVariables[1] - b) < .0000001 &&
			ABS(pVariables[2] - c) < .0000001)
			break;
		a = pVariables[0];
		b = pVariables[1];
		c = pVariables[2];
	}
	printf("t=A(n^B)+c   A=%f  B=%f  C=%f\n", a, b, c);
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?