predacc.cpp

来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C++ 代码 · 共 752 行 · 第 1/2 页

CPP
752
字号
// --------------------------------------------------------
// This demo file is dedicated to the Public Domain. See:
// http://creativecommons.org/licenses/publicdomain
// --------------------------------------------------------

#include "PredAcc.h"
#ifdef WIN32
//#include <windows.h>
#else // WIN32
#include <unistd.h>
#endif // !WIN32
#include "../GClasses/GArff.h"
#include "../GClasses/GArray.h"
#include "../GClasses/GBag.h"
#include "../GClasses/GFile.h"
#include "../GClasses/GDecisionTree.h"
#include "../GClasses/GKNN.h"
#include "../GClasses/GMacros.h"
#include "../GClasses/GNeuralNet.h"
#include "../GClasses/GNaiveBayes.h"
#include "../GClasses/GPCTree.h"
#include "../GClasses/GTime.h"
#include "../GClasses/GThread.h"
#include "../GClasses/GManifold.h"



// Background color handed to the GWidgetDialog base class by the
// PredAccDialog constructor below.
#define BACKGROUND_COLOR 0

// Builds the initial screen: the manual button, the learning-algorithm list,
// the test-technique list, and a file browser for picking the first .arff
// file. The data tables, Shuffle/Begin buttons and technique-specific widgets
// are created later by SetTrainingSet, SetTestSet and OnChangeListSelection.
PredAccDialog::PredAccDialog(PredAccController* pController, int w, int h)
 : GWidgetDialog(w, h, BACKGROUND_COLOR)
{
	// Put every member into a known state *before* any child widget is
	// created. The event handlers (e.g. OnChangeListSelection, which reads
	// m_nTestTechnique and deletes m_pFileSystemBrowser2 / m_pTextBox /
	// m_pTextBoxLabel) must never see uninitialized values, even if a widget
	// call such as SetSelection below ends up notifying this dialog.
	m_pController = pController;
	m_pButtonManual = NULL;
	m_pAlgorithmList = NULL;
	m_pTestTechniqueList = NULL;
	m_pFileSystemBrowser1 = NULL;
	m_pFileSystemBrowser2 = NULL;
	m_pTitle = NULL;
	m_pRelation = NULL;
	m_pTrainingData = NULL;
	m_pTestData = NULL;
	m_pShuffleButton1 = NULL;
	m_pShuffleButton2 = NULL;
	m_pBeginButton = NULL;
	m_pTextBox = NULL;
	m_pTextBoxLabel = NULL;

	// -1 means "no test technique locked in yet"; SetTrainingSet/SetTestSet
	// copy the list selection into it once data has been loaded.
	m_nTestTechnique = -1;

	// Same default as the "70" placed in the percentage text box.
	m_dTrainingPercent = 70;

	// Make the manual button
	GImage* pManualImage = ControllerBase::GetManualImage();
	m_pButtonManual = new GWidgetImageButton(this, w - pManualImage->GetWidth() / 2 - 5, 5, pManualImage);

	// Learning algorithms. Item order matters: the selected index is passed
	// straight to the controller when "Begin" is pressed.
	m_pAlgorithmList = new GWidgetListBox(this, 5, 5, 200, 190);
	new GWidgetListBoxItem(m_pAlgorithmList, L"C3 Decision Tree");
	new GWidgetListBoxItem(m_pAlgorithmList, L"Neural Net (o-8-i)");
	new GWidgetListBoxItem(m_pAlgorithmList, L"Neural Net (o-4-4-i)");
	new GWidgetListBoxItem(m_pAlgorithmList, L"Neural Net (o-10-10-i)");
	new GWidgetListBoxItem(m_pAlgorithmList, L"Naive Bayes");
	new GWidgetListBoxItem(m_pAlgorithmList, L"k-Nearest Neighbor (k=2)");
	new GWidgetListBoxItem(m_pAlgorithmList, L"k-Nearest Neighbor (k=5)");
	new GWidgetListBoxItem(m_pAlgorithmList, L"k-Nearest Neighbor (k=13)");
	new GWidgetListBoxItem(m_pAlgorithmList, L"Axis Aligned Forest (100 trees)");
	new GWidgetListBoxItem(m_pAlgorithmList, L"Arbitrary Arboretum (100 trees)");
	new GWidgetListBoxItem(m_pAlgorithmList, L"PC Forest (100 trees)");
//	new GWidgetListBoxItem(m_pAlgorithmList, L"Pumped Neural Net");
//	new GWidgetListBoxItem(m_pAlgorithmList, L"Pumped KNN");
	m_pAlgorithmList->SetSelection(0);

	// Test techniques. Indices 0-3 are the technique numbers tested in
	// OnChangeListSelection, OnSelectFilename and OnReleaseTextButton.
	m_pTestTechniqueList = new GWidgetListBox(this, 220, 25, 350, 150);
	new GWidgetListBoxItem(m_pTestTechniqueList, L"Train and test on same data set");
	new GWidgetListBoxItem(m_pTestTechniqueList, L"Load training and test set from separate files");
	new GWidgetListBoxItem(m_pTestTechniqueList, L"Split one data set into training and test sets");
	new GWidgetListBoxItem(m_pTestTechniqueList, L"N-fold cross validation");
	m_pTestTechniqueList->SetSelection(0);

	// Browser for the (training) data file
	m_pFileSystemBrowser1 = new GWidgetFileSystemBrowser(this, 25, 200, 350, 150, ".arff");
}

/*virtual*/ PredAccDialog::~PredAccDialog()
{
	// Nothing is released explicitly here. The child widgets were all created
	// with this dialog as their parent; NOTE(review): this presumably relies on
	// the GWidgetDialog base destructor to clean them up — confirm ownership.
}

// Image-button handler. The manual button is the only image button this
// dialog creates; releasing it opens the PredAcc HTML manual.
/*virtual*/ void PredAccDialog::OnReleaseImageButton(GWidgetImageButton* pButton)
{
	if(pButton != m_pButtonManual)
	{
		GAssert(false, "unrecognized image button");
		return;
	}
	OpenAppFile("../doc/Waffles/PredAcc.html");
}

void PredAccDialog::SetTrainingSet(GArffRelation* pRelation, GArffData* pTrainingSet)
{
	if(m_pFileSystemBrowser1)
	{
		delete(m_pFileSystemBrowser1);
		m_pFileSystemBrowser1 = NULL;
	}
	if(!m_pRelation)
		m_pRelation = pRelation;

	// Title String
	GString s;
	s.Copy(L"Title: ");
	s.Add(pRelation->GetName());
	s.Add(L"          Attributes: ");
	s.Add(pRelation->GetAttributeCount());
	if(!m_pTitle)
		m_pTitle = new GWidgetTextLabel(this, 25, 200, 350, 16, &s, 0xff8888ff);
	else
		m_pTitle->SetText(&s);

	// Training Set
	delete(m_pTrainingData);
	m_pTrainingData = new GWidgetRelation(pRelation, this, 25, 225, 350, 300);
	int n, i;
	if(pTrainingSet)
	{
		int nCount = pTrainingSet->GetSize();
		if(nCount > 100)
			nCount = 100;
		for(n = 0; n < nCount; n++)
		{
			m_pTrainingData->AddBlankRow();
			double* pRow = pTrainingSet->GetVector(n);
			GAssert(pRelation->GetAttributeCount() == m_pTrainingData->GetColumnCount(), "column count mismatch");
			for(i = 0; i < m_pTrainingData->GetColumnCount(); i++)
			{
				GArffAttribute* pAttr = pRelation->GetAttribute(i);
				if(pAttr->IsContinuous())
					s.Copy(pRow[i]);
				else
					s.Copy(pAttr->GetValue((int)pRow[i]));
				GWidgetTextLabel* pLabel = new GWidgetTextLabel(m_pTrainingData, 0, 0, 80, 20, &s, 0xff8888ff, 0xff003300);
				m_pTrainingData->SetWidget(i, n, pLabel);
			}
		}
	}

	// Shuffle Buttons
	if(!m_pShuffleButton1)
	{
		s.Copy(L"Shuffle");
		m_pShuffleButton1 = new GWidgetTextButton(this, 25, 530, 80, 20, &s);
	}

	// Begin Button
	if(!m_pBeginButton)
	{
		s.Copy(L"Begin");
		m_pBeginButton = new GWidgetTextButton(this, 300, 560, 80, 20, &s);
	}

	if(m_nTestTechnique < 0)
		m_nTestTechnique = m_pTestTechniqueList->GetSelection();
}

void PredAccDialog::SetTestSet(GArffRelation* pRelation, GArffData* pTestSet)
{
	if(m_pFileSystemBrowser2)
	{
		delete(m_pFileSystemBrowser2);
		m_pFileSystemBrowser2 = NULL;
	}
	if(!m_pRelation)
		m_pRelation = pRelation;

	// Test Set
	delete(m_pTestData);
	m_pTestData = new GWidgetRelation(m_pRelation, this, 400, 225, 350, 300);
	GString s;
	int n, i;
	if(pTestSet)
	{
		int nCount = pTestSet->GetSize();
		if(nCount > 100)
			nCount = 100;
		for(n = 0; n < nCount; n++)
		{
			m_pTestData->AddBlankRow();
			double* pRow = pTestSet->GetVector(n);
			for(i = 0; i < m_pRelation->GetAttributeCount(); i++)
			{
				GArffAttribute* pAttr = m_pRelation->GetAttribute(i);
				if(pAttr->IsContinuous())
					s.Copy(pRow[i]);
				else
					s.Copy(pAttr->GetValue((int)pRow[i]));
				GWidgetTextLabel* pLabel = new GWidgetTextLabel(m_pTestData, 0, 0, 80, 20, &s, 0xff8888ff, 0xff003300);
				m_pTestData->SetWidget(i, n, pLabel);
			}
		}
	}

	// Shuffle Buttons
	if(!m_pShuffleButton2)
	{
		s.Copy(L"Shuffle");
		m_pShuffleButton2 = new GWidgetTextButton(this, 400, 530, 80, 20, &s);
	}

	if(m_nTestTechnique < 0)
		m_nTestTechnique = m_pTestTechniqueList->GetSelection();
}

// File-browser handler: forwards the chosen .arff file to the controller.
// Browser 1 supplies the training data. With technique 2 ("split one data
// set") the file is split using the percentage typed into the text box.
// Browser 2 supplies the separate test file.
/*virtual*/ void PredAccDialog::OnSelectFilename(GWidgetFileSystemBrowser* pBrowser, const char* szFilename)
{
	if(pBrowser == m_pFileSystemBrowser1)
	{
		if(m_pTestTechniqueList->GetSelection() == 2)
		{
			// The percentage text box is created by OnChangeListSelection and
			// deleted just below after its first read, so it may be missing if
			// another file is picked later. In that case keep the last value in
			// m_dTrainingPercent instead of dereferencing a NULL pointer.
			if(m_pTextBox)
			{
				GString* pString = m_pTextBox->GetText();
				char* szTmp = (char*)alloca(pString->GetLength() + 1);
				pString->GetAnsi(szTmp);
				m_dTrainingPercent = atof(szTmp);
				if(m_dTrainingPercent <= 0)
					m_dTrainingPercent = 50; // non-numeric or non-positive input
			}
			delete(m_pTextBox);
			m_pTextBox = NULL;
			delete(m_pTextBoxLabel);
			m_pTextBoxLabel = NULL;
			m_pController->LoadAndSplitTrainingSet(szFilename, m_dTrainingPercent);
		}
		else
			m_pController->LoadTrainingSet(szFilename);
	}
	else if(pBrowser == m_pFileSystemBrowser2)
		m_pController->LoadTestSet(szFilename);
}

// List-box handler. Changing the learning algorithm needs no layout work.
// Changing the test technique rebuilds the technique-specific widgets:
// - technique 1 gets a second file browser (for the test file);
// - techniques 2 and 3 get a numeric text box with a label.
// Once a data set is loaded (m_nTestTechnique >= 0), the technique is locked
// and the list is reverted.
/*virtual*/ void PredAccDialog::OnChangeListSelection(GWidgetListBox* pListBox)
{
	if(pListBox == m_pAlgorithmList || pListBox != m_pTestTechniqueList)
		return;

	if(m_nTestTechnique >= 0)
	{
		// Locked: put the list back to the technique already in use.
		pListBox->SetSelection(m_nTestTechnique);
		return;
	}

	const int nTechnique = pListBox->GetSelection();

	// Only "separate files" (technique 1) needs a second browser.
	delete(m_pFileSystemBrowser2);
	if(nTechnique == 1)
		m_pFileSystemBrowser2 = new GWidgetFileSystemBrowser(this, 400, 200, 350, 150, ".arff");
	else
		m_pFileSystemBrowser2 = NULL;

	// Techniques 0 and 1 take no numeric parameter: drop the text box.
	if(nTechnique < 2)
	{
		delete(m_pTextBox);
		m_pTextBox = NULL;
		delete(m_pTextBoxLabel);
		m_pTextBoxLabel = NULL;
		return;
	}

	// Techniques 2 and 3: (re)create the text box, pre-filled with a default.
	delete(m_pTextBox);
	m_pTextBox = new GWidgetTextBox(this, 400, 215, 75, 20);
	delete(m_pTextBoxLabel);
	GString sLabel;
	switch(nTechnique)
	{
		case 2:
			m_pTextBox->SetText("70");
			sLabel.Copy(L"Percent for training set:");
			break;
		case 3:
			m_pTextBox->SetText("10");
			sLabel.Copy(L"Number of folds:");
			break;
		default:
			GAssert(false, "unexpected selection");
			break;
	}
	m_pTextBoxLabel = new GWidgetTextLabel(this, 400, 200, 200, 15, &sLabel, 0xff8888ff);
}

// Text-button handler.
// - The Shuffle buttons ask the controller to shuffle the training or test set.
// - "Begin" runs the selected algorithm (its list index) with the selected test
//   technique: 0 = train/test on one set, 1 = separate files, 2 = split one
//   set, 3 = n-fold cross validation with the fold count from the text box.
/*virtual*/ void PredAccDialog::OnReleaseTextButton(GWidgetTextButton* pButton)
{
	if(pButton == m_pShuffleButton1)
		m_pController->ShuffleTrainingSet();
	else if(pButton == m_pShuffleButton2)
		m_pController->ShuffleTestSet();
	else if(pButton == m_pBeginButton)
	{
		int nSel = m_pTestTechniqueList->GetSelection();
		if(nSel == 0)
			m_pController->TrainAndTestSingleSet(m_pAlgorithmList->GetSelection());
		else if(nSel == 1)
		{
			// The separate test file has not been loaded yet.
			if(!m_pTestData)
				return;
			m_pController->TrainAndTest(m_pAlgorithmList->GetSelection());
		}
		else if(nSel == 2)
		{
			// The split should have produced a test set. Do not rely on the
			// assertion alone: bail out like technique 1 instead of running
			// TrainAndTest without test data.
			GAssert(m_pTestData, "no test data");
			if(!m_pTestData)
				return;
			m_pController->TrainAndTest(m_pAlgorithmList->GetSelection());
		}
		else if(nSel == 3)
		{
			// Number of folds from the text box. Fall back to 10 when the box
			// is missing or holds a non-numeric or non-positive value.
			int n = 10;
			if(m_pTextBox)
			{
				GString* pString = m_pTextBox->GetText();
				char* szTmp = (char*)alloca(pString->GetLength() + 1);
				pString->GetAnsi(szTmp);
				n = atoi(szTmp);
				if(n <= 0)
					n = 10;
			}
			m_pController->DoNFoldCrossValidation(m_pAlgorithmList->GetSelection(), n);
		}
	}
}






// -------------------------------------------------------------------------------



// Creates the view's single dialog, sized to the view's screen rectangle.
PredAccView::PredAccView(PredAccController* pController)
: ViewBase()
{
	const int nWidth = m_screenRect.w;
	const int nHeight = m_screenRect.h;
	m_pDialog = new PredAccDialog(pController, nWidth, nHeight);
}

// The view owns the dialog allocated in its constructor.
PredAccView::~PredAccView()
{
	delete m_pDialog;
}

// Repaints the frame: clears the entire surface (a NULL rectangle makes
// SDL_FillRect fill all of it) with color value 0, then blits the dialog's
// image at the view's on-screen origin.
/*virtual*/ void PredAccView::Draw(SDL_Surface *pScreen)
{
	SDL_FillRect(pScreen, NULL, 0x000000);
	BlitImage(pScreen,
		m_screenRect.x,
		m_screenRect.y,
		m_pDialog->GetImage());
}

// Keyboard input is passed through to the dialog unchanged.
void PredAccView::OnChar(char c)
{
	m_pDialog->HandleChar(c);
}

// Converts the click from screen to dialog-local coordinates, finds the
// widget under it, and lets the dialog grab that widget.
void PredAccView::OnMouseDown(int nButton, int x, int y)
{
	const int nLocalX = x - m_screenRect.x;
	const int nLocalY = y - m_screenRect.y;
	m_pDialog->GrabWidget(m_pDialog->FindAtomicWidget(nLocalX, nLocalY), nButton, nLocalX, nLocalY);
}

// Forwards a button release to the dialog. The release position is not needed.
void PredAccView::OnMouseUp(int nButton, int x, int y)
{
	(void)x;
	(void)y;
	m_pDialog->ReleaseWidget(nButton);
}

bool PredAccView::OnMousePos(int x, int y)
{
	return m_pDialog->HandleMousePos(x - m_screenRect.x, y - m_screenRect.y);

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?