predacc.cpp
From "A machine-learning project by Mike Gashler (includes neural…)" · C++ code · 752 lines total · page 1 of 2
CPP
752 lines
// --------------------------------------------------------
// This demo file is dedicated to the Public Domain. See:
// http://creativecommons.org/licenses/publicdomain
// --------------------------------------------------------
#include "PredAcc.h"
#ifdef WIN32
//#include <windows.h>
#else // WIN32
#include <unistd.h>
#endif // !WIN32
#include "../GClasses/GArff.h"
#include "../GClasses/GArray.h"
#include "../GClasses/GBag.h"
#include "../GClasses/GFile.h"
#include "../GClasses/GDecisionTree.h"
#include "../GClasses/GKNN.h"
#include "../GClasses/GMacros.h"
#include "../GClasses/GNeuralNet.h"
#include "../GClasses/GNaiveBayes.h"
#include "../GClasses/GPCTree.h"
#include "../GClasses/GTime.h"
#include "../GClasses/GThread.h"
#include "../GClasses/GManifold.h"
// Background colour value handed to the GWidgetDialog base constructor by
// PredAccDialog's constructor.
#define BACKGROUND_COLOR 0
// Builds the prediction-accuracy dialog: the manual button, the list of
// learning algorithms, the list of test techniques and a file browser for
// picking the (first) .arff data set. All other widgets are created later,
// once the user has chosen a technique or loaded data.
//
// Every member is given its default BEFORE any widget is created. Filling the
// list boxes and calling SetSelection() happens in the middle of this
// constructor, and OnChangeListSelection reads and deletes m_nTestTechnique,
// m_pFileSystemBrowser2, m_pTextBox and m_pTextBoxLabel. Initializing them
// first means such a callback never sees garbage.
// NOTE(review): this only matters if the widget layer dispatches a selection
// callback from SetSelection(); that code is not visible here.
PredAccDialog::PredAccDialog(PredAccController* pController, int w, int h)
: GWidgetDialog(w, h, BACKGROUND_COLOR)
{
	m_pController = pController;
	m_pButtonManual = NULL;
	m_pAlgorithmList = NULL;
	m_pTestTechniqueList = NULL;
	m_pFileSystemBrowser1 = NULL;
	m_pFileSystemBrowser2 = NULL;
	m_pTitle = NULL;
	m_pRelation = NULL;
	m_pTrainingData = NULL;
	m_pTestData = NULL;
	m_pShuffleButton1 = NULL;
	m_pShuffleButton2 = NULL;
	m_pBeginButton = NULL;
	m_pTextBox = NULL;
	m_pTextBoxLabel = NULL;
	m_nTestTechnique = -1; // -1: the test technique has not been locked in yet
	m_dTrainingPercent = 70; // training share used by the "split" technique

	// Manual button near the top-right corner
	GImage* pManualImage = ControllerBase::GetManualImage();
	m_pButtonManual = new GWidgetImageButton(this, w - pManualImage->GetWidth() / 2 - 5, 5, pManualImage);

	// Learning algorithms. The selected index is what OnReleaseTextButton hands
	// to the controller; presumably the controller numbers its algorithms in
	// this same order, so keep the two in sync.
	static const wchar_t* const s_algorithmNames[] =
	{
		L"C3 Decision Tree",
		L"Neural Net (o-8-i)",
		L"Neural Net (o-4-4-i)",
		L"Neural Net (o-10-10-i)",
		L"Naive Bayes",
		L"k-Nearest Neighbor (k=2)",
		L"k-Nearest Neighbor (k=5)",
		L"k-Nearest Neighbor (k=13)",
		L"Axis Aligned Forest (100 trees)",
		L"Arbitrary Arboretum (100 trees)",
		L"PC Forest (100 trees)",
		// L"Pumped Neural Net",
		// L"Pumped KNN",
	};
	m_pAlgorithmList = new GWidgetListBox(this, 5, 5, 200, 190);
	const int nAlgorithms = (int)(sizeof(s_algorithmNames) / sizeof(s_algorithmNames[0]));
	for(int n = 0; n < nAlgorithms; n++)
		new GWidgetListBoxItem(m_pAlgorithmList, s_algorithmNames[n]);
	m_pAlgorithmList->SetSelection(0);

	// Test techniques. Indices 0-3 are compared by number in
	// OnChangeListSelection, OnSelectFilename and OnReleaseTextButton.
	static const wchar_t* const s_techniqueNames[] =
	{
		L"Train and test on same data set",
		L"Load training and test set from separate files",
		L"Split one data set into training and test sets",
		L"N-fold cross validation",
	};
	m_pTestTechniqueList = new GWidgetListBox(this, 220, 25, 350, 150);
	const int nTechniques = (int)(sizeof(s_techniqueNames) / sizeof(s_techniqueNames[0]));
	for(int n = 0; n < nTechniques; n++)
		new GWidgetListBoxItem(m_pTestTechniqueList, s_techniqueNames[n]);
	m_pTestTechniqueList->SetSelection(0);

	// File browser for the training (or only) data set
	m_pFileSystemBrowser1 = new GWidgetFileSystemBrowser(this, 25, 200, 350, 150, ".arff");
}
// Nothing is released explicitly here. Every widget of this dialog is created
// with the dialog as its parent; presumably the GWidgetDialog base destroys
// its child widgets. TODO: confirm against GWidgets.
/*virtual*/ PredAccDialog::~PredAccDialog()
{
	// intentionally empty
}
// Opens the HTML manual when the manual button is released. No other image
// button is expected in this dialog.
/*virtual*/ void PredAccDialog::OnReleaseImageButton(GWidgetImageButton* pButton)
{
	if(pButton != m_pButtonManual)
	{
		GAssert(false, "unrecognized image button");
		return;
	}
	OpenAppFile("../doc/Waffles/PredAcc.html");
}
// Shows a freshly loaded training set. It does the following:
// - removes the file browser the set was picked from;
// - remembers the relation if none has been stored yet;
// - shows a title line with the relation name and attribute count;
// - renders at most the first 100 rows as text labels.
// On first use it also creates the "Shuffle" and "Begin" buttons. Finally it
// locks in the currently selected test technique.
void PredAccDialog::SetTrainingSet(GArffRelation* pRelation, GArffData* pTrainingSet)
{
	// The browser is not needed once a training file has been chosen
	if(m_pFileSystemBrowser1)
	{
		delete(m_pFileSystemBrowser1);
		m_pFileSystemBrowser1 = NULL;
	}
	if(m_pRelation == NULL)
		m_pRelation = pRelation;

	// Title line: "Title: <name> Attributes: <count>"
	GString text;
	text.Copy(L"Title: ");
	text.Add(pRelation->GetName());
	text.Add(L" Attributes: ");
	text.Add(pRelation->GetAttributeCount());
	if(m_pTitle)
		m_pTitle->SetText(&text);
	else
		m_pTitle = new GWidgetTextLabel(this, 25, 200, 350, 16, &text, 0xff8888ff);

	// Rebuild the grid that displays the training rows
	delete(m_pTrainingData);
	m_pTrainingData = new GWidgetRelation(pRelation, this, 25, 225, 350, 300);
	if(pTrainingSet != NULL)
	{
		int nRows = pTrainingSet->GetSize();
		if(nRows > 100)
			nRows = 100; // only a preview is displayed
		for(int nRow = 0; nRow < nRows; nRow++)
		{
			m_pTrainingData->AddBlankRow();
			double* pValues = pTrainingSet->GetVector(nRow);
			GAssert(pRelation->GetAttributeCount() == m_pTrainingData->GetColumnCount(), "column count mismatch");
			for(int nCol = 0; nCol < m_pTrainingData->GetColumnCount(); nCol++)
			{
				GArffAttribute* pAttr = pRelation->GetAttribute(nCol);
				// Continuous values are printed as numbers. Nominal values are
				// stored as an index into the attribute's value list.
				if(pAttr->IsContinuous())
					text.Copy(pValues[nCol]);
				else
					text.Copy(pAttr->GetValue((int)pValues[nCol]));
				GWidgetTextLabel* pCell = new GWidgetTextLabel(m_pTrainingData, 0, 0, 80, 20, &text, 0xff8888ff, 0xff003300);
				m_pTrainingData->SetWidget(nCol, nRow, pCell);
			}
		}
	}

	// Buttons are created once and then reused
	if(m_pShuffleButton1 == NULL)
	{
		text.Copy(L"Shuffle");
		m_pShuffleButton1 = new GWidgetTextButton(this, 25, 530, 80, 20, &text);
	}
	if(m_pBeginButton == NULL)
	{
		text.Copy(L"Begin");
		m_pBeginButton = new GWidgetTextButton(this, 300, 560, 80, 20, &text);
	}

	// From now on OnChangeListSelection forces the list back to this technique
	if(m_nTestTechnique < 0)
		m_nTestTechnique = m_pTestTechniqueList->GetSelection();
}
// Shows a freshly loaded test set. It removes the second file browser,
// remembers the relation if none has been stored yet, and renders at most the
// first 100 test rows. The rows are laid out with m_pRelation: the relation
// stored by whichever of SetTrainingSet/SetTestSet ran first, which is not
// necessarily pRelation. On first use it also creates the test-side "Shuffle"
// button, and it locks in the current test technique.
// NOTE(review): nothing here checks that each test row holds as many values as
// m_pRelation has attributes. Confirm the controller guarantees this.
void PredAccDialog::SetTestSet(GArffRelation* pRelation, GArffData* pTestSet)
{
	if(m_pFileSystemBrowser2)
	{
		delete(m_pFileSystemBrowser2);
		m_pFileSystemBrowser2 = NULL;
	}
	if(m_pRelation == NULL)
		m_pRelation = pRelation;

	// Rebuild the grid that displays the test rows
	delete(m_pTestData);
	m_pTestData = new GWidgetRelation(m_pRelation, this, 400, 225, 350, 300);
	GString text;
	if(pTestSet != NULL)
	{
		int nRows = pTestSet->GetSize();
		if(nRows > 100)
			nRows = 100; // only a preview is displayed
		for(int nRow = 0; nRow < nRows; nRow++)
		{
			m_pTestData->AddBlankRow();
			double* pValues = pTestSet->GetVector(nRow);
			for(int nCol = 0; nCol < m_pRelation->GetAttributeCount(); nCol++)
			{
				GArffAttribute* pAttr = m_pRelation->GetAttribute(nCol);
				// Nominal values are stored as an index into the value list
				if(pAttr->IsContinuous())
					text.Copy(pValues[nCol]);
				else
					text.Copy(pAttr->GetValue((int)pValues[nCol]));
				GWidgetTextLabel* pCell = new GWidgetTextLabel(m_pTestData, 0, 0, 80, 20, &text, 0xff8888ff, 0xff003300);
				m_pTestData->SetWidget(nCol, nRow, pCell);
			}
		}
	}

	// The test-side shuffle button is created once and then reused
	if(m_pShuffleButton2 == NULL)
	{
		text.Copy(L"Shuffle");
		m_pShuffleButton2 = new GWidgetTextButton(this, 400, 530, 80, 20, &text);
	}

	// From now on OnChangeListSelection forces the list back to this technique
	if(m_nTestTechnique < 0)
		m_nTestTechnique = m_pTestTechniqueList->GetSelection();
}
// Called when the user picks a file in one of the two file browsers.
// Browser 1 supplies the training data:
// - With the "split one data set" technique (list index 2), the percentage in
//   the text box is read first, and the controller loads and splits the file.
// - Otherwise the file is loaded as the training set as-is.
// Browser 2 supplies the separate test set.
/*virtual*/ void PredAccDialog::OnSelectFilename(GWidgetFileSystemBrowser* pBrowser, const char* szFilename)
{
	if(pBrowser == m_pFileSystemBrowser1)
	{
		if(m_pTestTechniqueList->GetSelection() == 2)
		{
			// The text box is deleted once its value has been read. If this
			// browser fires again (e.g. the first load did not replace the
			// browser), the box is gone. In that case reuse the last percentage
			// rather than dereferencing NULL.
			if(m_pTextBox)
			{
				GString* pString = m_pTextBox->GetText();
				char* szTmp = (char*)alloca(pString->GetLength() + 1);
				pString->GetAnsi(szTmp);
				m_dTrainingPercent = atof(szTmp);
				// Fall back to 50% for non-positive or unparsable input (atof
				// yields 0). The negated comparison also rejects NaN, which a
				// plain "<= 0" test lets through.
				if(!(m_dTrainingPercent > 0))
					m_dTrainingPercent = 50;
				delete(m_pTextBox);
				m_pTextBox = NULL;
				delete(m_pTextBoxLabel);
				m_pTextBoxLabel = NULL;
			}
			m_pController->LoadAndSplitTrainingSet(szFilename, m_dTrainingPercent);
		}
		else
			m_pController->LoadTrainingSet(szFilename);
	}
	else if(pBrowser == m_pFileSystemBrowser2)
		m_pController->LoadTestSet(szFilename);
}
// Reacts to a selection change in one of the dialog's list boxes.
// The algorithm list needs no immediate handling; its selection is read when
// "Begin" is pressed. For the test-technique list there are two cases:
// - Once a technique has been locked in (m_nTestTechnique >= 0), the
//   selection is forced back to it.
// - Otherwise the technique-specific widgets are rebuilt:
//     1 = separate test file      -> second file browser
//     2 = split one data set      -> text box for the training percentage
//     3 = N-fold cross validation -> text box for the number of folds
/*virtual*/ void PredAccDialog::OnChangeListSelection(GWidgetListBox* pListBox)
{
	if(pListBox == m_pAlgorithmList)
		return;
	if(pListBox != m_pTestTechniqueList)
		return;

	// The technique cannot change after data has been loaded
	if(m_nTestTechnique >= 0)
	{
		pListBox->SetSelection(m_nTestTechnique);
		return;
	}
	int nTechnique = pListBox->GetSelection();

	// Second file browser, only for "separate training and test files"
	delete(m_pFileSystemBrowser2);
	m_pFileSystemBrowser2 = NULL;
	if(nTechnique == 1)
		m_pFileSystemBrowser2 = new GWidgetFileSystemBrowser(this, 400, 200, 350, 150, ".arff");

	// Techniques below 2 take no numeric parameter, so no text box is shown
	if(nTechnique < 2)
	{
		delete(m_pTextBox);
		m_pTextBox = NULL;
		delete(m_pTextBoxLabel);
		m_pTextBoxLabel = NULL;
		return;
	}

	// Techniques 2 and 3 get a fresh text box with a default value and a caption
	delete(m_pTextBox);
	m_pTextBox = new GWidgetTextBox(this, 400, 215, 75, 20);
	delete(m_pTextBoxLabel);
	GString caption;
	switch(nTechnique)
	{
		case 2:
			m_pTextBox->SetText("70");
			caption.Copy(L"Percent for training set:");
			break;
		case 3:
			m_pTextBox->SetText("10");
			caption.Copy(L"Number of folds:");
			break;
		default:
			GAssert(false, "unexpected selection");
			break;
	}
	m_pTextBoxLabel = new GWidgetTextLabel(this, 400, 200, 200, 15, &caption, 0xff8888ff);
}
// Handles the dialog's text buttons. The two "Shuffle" buttons shuffle the
// training and test sets. "Begin" runs the evaluation chosen in the
// test-technique list, using the selected algorithm:
//   0 = train and test on the same data
//   1 = separate test file     (needs a loaded test set)
//   2 = split one data set     (needs the test set produced by the split)
//   3 = N-fold cross validation; folds come from the text box (default 10)
/*virtual*/ void PredAccDialog::OnReleaseTextButton(GWidgetTextButton* pButton)
{
	if(pButton == m_pShuffleButton1)
		m_pController->ShuffleTrainingSet();
	else if(pButton == m_pShuffleButton2)
		m_pController->ShuffleTestSet();
	else if(pButton == m_pBeginButton)
	{
		int nTechnique = m_pTestTechniqueList->GetSelection();
		int nAlgorithm = m_pAlgorithmList->GetSelection();
		if(nTechnique == 0)
			m_pController->TrainAndTestSingleSet(nAlgorithm);
		else if(nTechnique == 1)
		{
			if(!m_pTestData)
				return; // the separate test file has not been loaded yet
			m_pController->TrainAndTest(nAlgorithm);
		}
		else if(nTechnique == 2)
		{
			// The split should always have produced test data. GAssert is
			// presumably compiled out of release builds, though, so guard
			// explicitly (like technique 1) rather than evaluating without
			// a test set.
			GAssert(m_pTestData, "no test data");
			if(!m_pTestData)
				return;
			m_pController->TrainAndTest(nAlgorithm);
		}
		else if(nTechnique == 3)
		{
			int nFolds = 10;
			if(m_pTextBox) // defensive: the box should exist for technique 3
			{
				GString* pString = m_pTextBox->GetText();
				char* szTmp = (char*)alloca(pString->GetLength() + 1);
				pString->GetAnsi(szTmp);
				int n = atoi(szTmp);
				if(n > 0) // non-positive or unparsable input keeps the default
					nFolds = n;
			}
			m_pController->DoNFoldCrossValidation(nAlgorithm, nFolds);
		}
	}
}
// -------------------------------------------------------------------------------
// Creates the view's dialog, sized to the view's screen rectangle. The view
// owns the dialog and deletes it in its destructor.
PredAccView::PredAccView(PredAccController* pController)
: ViewBase()
{
	int nWidth = m_screenRect.w;
	int nHeight = m_screenRect.h;
	m_pDialog = new PredAccDialog(pController, nWidth, nHeight);
}
// Releases the dialog created in the constructor.
PredAccView::~PredAccView()
{
	delete m_pDialog;
}
// Clears the whole screen surface to pixel value 0, then blits the dialog's
// rendered image at the view's screen-rectangle origin.
/*virtual*/ void PredAccView::Draw(SDL_Surface *pScreen)
{
	SDL_FillRect(pScreen, NULL, 0x000000); // a NULL rect means "entire surface"
	BlitImage(pScreen, m_screenRect.x, m_screenRect.y, m_pDialog->GetImage());
}
// Keyboard input is forwarded unchanged to the dialog.
void PredAccView::OnChar(char ch)
{
	m_pDialog->HandleChar(ch);
}
// Translates screen coordinates into dialog coordinates. Then it asks the
// dialog to grab the atomic widget that FindAtomicWidget reports at that
// point.
void PredAccView::OnMouseDown(int nButton, int x, int y)
{
	int nDialogX = x - m_screenRect.x;
	int nDialogY = y - m_screenRect.y;
	m_pDialog->GrabWidget(m_pDialog->FindAtomicWidget(nDialogX, nDialogY), nButton, nDialogX, nDialogY);
}
// Lets the dialog release whatever it grabbed with this mouse button. The
// cursor position is not needed for that.
void PredAccView::OnMouseUp(int nButton, int /*x*/, int /*y*/)
{
	m_pDialog->ReleaseWidget(nButton);
}
bool PredAccView::OnMousePos(int x, int y)
{
return m_pDialog->HandleMousePos(x - m_screenRect.x, y - m_screenRect.y);
⌨️ Keyboard shortcuts
Copy code: Ctrl + C
Search code: Ctrl + F
Full-screen mode: F11
Increase font size: Ctrl + =
Decrease font size: Ctrl + -
Show shortcuts: ?