// predacc.cpp
// ********************************************************
// This is demo code. You may derive from, use, modify, and
// distribute it without limitation for any purpose.
// Obviously you don't get a warranty or an assurance of
// fitness for a particular purpose with this code. You're
// welcome to remove this header and claim original
// authorship. I really don't care.
// ********************************************************
#include "PredAcc.h"
#ifdef WIN32
//#include <windows.h>
#else // WIN32
#include <unistd.h>
#endif // !WIN32
#include "../GClasses/GTime.h"
#include "../GClasses/GMacros.h"
#include "../GClasses/GArff.h"
#include "../GClasses/GFile.h"
#include "../GClasses/GDecisionTree.h"
#include "../GClasses/GNeuralNet.h"
#include "../GClasses/GNaiveBayes.h"
#include "../GClasses/GKNN.h"
#include "../GClasses/GArray.h"
#include "../GClasses/GManifold.h"
#define BACKGROUND_COLOR 0
// Builds the main dialog: the algorithm list, the test-technique list and the
// first .arff file browser. All other widgets are created lazily once data
// has been loaded (see SetTrainingSet / SetTestSet / OnChangeListSelection).
//   pController - receives the load/train/test requests issued by this dialog.
//   w, h        - dialog size in pixels.
PredAccDialog::PredAccDialog(PredAccController* pController, int w, int h)
: GWidgetDialog(w, h, BACKGROUND_COLOR)
{
	// Put every member into a known state BEFORE any child widget is created.
	// The event handlers (OnChangeListSelection, OnSelectFilename, ...) read and
	// delete these members. Previously they were only assigned after the widgets
	// were built and SetSelection() was called, so any callback fired during
	// construction would have seen garbage. NOTE(review): whether GWidgets
	// fires such callbacks is not visible here; initializing first is safe
	// either way.
	m_pController = pController;
	m_pAlgorithmList = NULL;
	m_pTestTechniqueList = NULL;
	m_pFileSystemBrowser1 = NULL;
	m_pFileSystemBrowser2 = NULL;
	m_pTitle = NULL;
	m_pRelation = NULL;
	m_pTrainingData = NULL;
	m_pTestData = NULL;
	m_pShuffleButton1 = NULL;
	m_pShuffleButton2 = NULL;
	m_pBeginButton = NULL;
	m_pTextBox = NULL;
	m_pTextBoxLabel = NULL;
	m_pTrainingRows = new GPointerArray(256);
	m_pTestRows = new GPointerArray(256);
	m_nTestTechnique = -1; // -1: no data loaded yet, technique not locked in
	m_dTrainingPercent = 70;

	// Learning algorithms. The selected index is passed to the controller.
	m_pAlgorithmList = new GWidgetListBox(this, 25, 25, 150, 150);
	new GWidgetListBoxItem(m_pAlgorithmList, L"Decision Tree");
	new GWidgetListBoxItem(m_pAlgorithmList, L"Neural Net");
	new GWidgetListBoxItem(m_pAlgorithmList, L"Naive Bayesian");
	new GWidgetListBoxItem(m_pAlgorithmList, L"K-Nearest Neighbor");
//	new GWidgetListBoxItem(m_pAlgorithmList, L"Pumped Neural Net");
//	new GWidgetListBoxItem(m_pAlgorithmList, L"Pumped KNN");
	m_pAlgorithmList->SetSelection(0);

	// Test techniques. Their indices (0..3) are switched on by the handlers.
	m_pTestTechniqueList = new GWidgetListBox(this, 200, 25, 350, 150);
	new GWidgetListBoxItem(m_pTestTechniqueList, L"Train and test on same data set");
	new GWidgetListBoxItem(m_pTestTechniqueList, L"Load training and test set from separate files");
	new GWidgetListBoxItem(m_pTestTechniqueList, L"Split one data set into training and test sets");
	new GWidgetListBoxItem(m_pTestTechniqueList, L"N-fold cross validation");
	m_pTestTechniqueList->SetSelection(0);

	// Browser used to pick the (training) data file.
	m_pFileSystemBrowser1 = new GWidgetFileSystemBrowser(this, 25, 200, 350, 150, ".arff");
	m_pFileSystemBrowser1->Draw(this);
}
// Releases the row arrays that back the preview grids.
/*virtual*/ PredAccDialog::~PredAccDialog()
{
	// The grids were constructed with pointers to m_pTrainingRows/m_pTestRows
	// (see SetTrainingSet/SetTestSet). Destroy the grids first so none of them
	// can outlive the array it references. Previously the arrays were freed
	// here while any grids would still exist until the base-class destructor
	// ran. Deleting a grid directly is the same thing SetTrainingSet/SetTestSet
	// already do when replacing one. NOTE(review): confirm GWidgetDialog's
	// child bookkeeping tolerates this, as those call sites assume.
	delete(m_pTrainingData);
	m_pTrainingData = NULL;
	delete(m_pTestData);
	m_pTestData = NULL;
	delete(m_pTrainingRows);
	delete(m_pTestRows);
}
// Shows a freshly loaded training set.
// - Removes the first file browser.
// - Records pRelation as the dialog's relation if none is set yet.
// - Updates the title label.
// - Rebuilds the training preview grid from up to the first 100 rows.
// - Creates the training Shuffle button and the Begin button on first call.
// - Locks in the currently selected test technique.
//   pRelation    - schema used for the title, column headers and cell text.
//   pTrainingSet - rows to preview; may be NULL (only headers are shown).
void PredAccDialog::SetTrainingSet(GArffRelation* pRelation, GArffData* pTrainingSet)
{
	// Maximum number of rows copied into the on-screen preview grid.
	const int nMaxPreviewRows = 100;

	// Data is loaded, so the first file browser is no longer needed.
	if(m_pFileSystemBrowser1)
	{
		delete(m_pFileSystemBrowser1);
		m_pFileSystemBrowser1 = NULL;
	}
	if(!m_pRelation)
		m_pRelation = pRelation;
	const int nAttrCount = pRelation->GetAttributeCount();

	// Title String
	GString s;
	s.Copy(L"Title: ");
	s.Add(pRelation->GetName());
	s.Add(L" Attributes: ");
	s.Add(pRelation->GetAttributeCount());
	if(!m_pTitle)
		m_pTitle = new GWidgetTextLabel(this, 25, 200, 350, 16, &s, 0xff8888ff);
	else
		m_pTitle->SetText(&s);
	m_pTitle->Draw(this);

	// Training Set: the grid is rebuilt from scratch on every call.
	delete(m_pTrainingData);
	m_pTrainingRows->Clear();
	m_pTrainingData = new GWidgetGrid(this, m_pTrainingRows, nAttrCount, 25, 225, 350, 300);
	int n, i;
	for(n = 0; n < nAttrCount; n++)
	{
		s.Copy(pRelation->GetAttribute(n)->GetName());
		GWidgetTextButton* pButton = new GWidgetTextButton(m_pTrainingData, 0, 0, 80, 20, &s);
		m_pTrainingData->SetColumnHeader(n, pButton);
	}
	if(pTrainingSet)
	{
		// Loop-invariant sanity check: done once rather than once per row.
		GAssert(nAttrCount == m_pTrainingData->GetColumnCount(), "column count mismatch");
		int nCount = pTrainingSet->GetSize();
		if(nCount > nMaxPreviewRows)
			nCount = nMaxPreviewRows;
		for(n = 0; n < nCount; n++)
		{
			m_pTrainingData->AddBlankRow();
			double* pRow = pTrainingSet->GetVector(n);
			for(i = 0; i < nAttrCount; i++)
			{
				GArffAttribute* pAttr = pRelation->GetAttribute(i);
				if(pAttr->IsContinuous())
					s.Copy(pRow[i]);
				else
				{
					// Nominal cells hold an index into the attribute's value
					// list. A negative index would read outside that list, so
					// show "?" instead. NOTE(review): presumably negatives
					// encode missing values; confirm in GArff.
					int nValue = (int)pRow[i];
					if(nValue >= 0)
						s.Copy(pAttr->GetValue(nValue));
					else
						s.Copy(L"?");
				}
				GWidgetTextLabel* pLabel = new GWidgetTextLabel(m_pTrainingData, 0, 0, 80, 20, &s, 0xff8888ff);
				m_pTrainingData->SetWidget(i, n, pLabel);
			}
		}
	}
	m_pTrainingData->Draw(this);

	// Shuffle Buttons
	if(!m_pShuffleButton1)
	{
		s.Copy(L"Shuffle");
		m_pShuffleButton1 = new GWidgetTextButton(this, 25, 530, 80, 20, &s);
		m_pShuffleButton1->Draw(this);
	}

	// Begin Button
	if(!m_pBeginButton)
	{
		s.Copy(L"Begin");
		m_pBeginButton = new GWidgetTextButton(this, 300, 560, 80, 20, &s);
		m_pBeginButton->Draw(this);
	}

	// From now on OnChangeListSelection refuses to change the technique.
	if(m_nTestTechnique < 0)
		m_nTestTechnique = m_pTestTechniqueList->GetSelection();
}
// Shows a freshly loaded test set.
// - Removes the second file browser.
// - Rebuilds the test preview grid from up to the first 100 rows.
// - Creates the test-set Shuffle button on first call.
// - Locks in the currently selected test technique.
// Columns are described by m_pRelation, which is the training relation when
// one was loaded first.
//   pRelation - schema of the test data; adopted only if no relation is set yet.
//   pTestSet  - rows to preview; may be NULL (only headers are shown).
void PredAccDialog::SetTestSet(GArffRelation* pRelation, GArffData* pTestSet)
{
	// Maximum number of rows copied into the on-screen preview grid.
	const int nMaxPreviewRows = 100;

	if(m_pFileSystemBrowser2)
	{
		delete(m_pFileSystemBrowser2);
		m_pFileSystemBrowser2 = NULL;
	}
	if(!m_pRelation)
		m_pRelation = pRelation;
	const int nAttrCount = m_pRelation->GetAttributeCount();

	// Test rows are indexed with m_pRelation's attribute count below. If the
	// test relation has fewer attributes, that would read past each row.
	GAssert(!pRelation || pRelation->GetAttributeCount() == nAttrCount, "test set attribute count does not match the training set");

	// Test Set: the grid is rebuilt from scratch on every call.
	delete(m_pTestData);
	m_pTestRows->Clear();
	m_pTestData = new GWidgetGrid(this, m_pTestRows, nAttrCount, 400, 225, 350, 300);
	GString s;
	int n, i;
	for(n = 0; n < nAttrCount; n++)
	{
		s.Copy(m_pRelation->GetAttribute(n)->GetName());
		GWidgetTextButton* pButton = new GWidgetTextButton(m_pTestData, 0, 0, 80, 20, &s);
		m_pTestData->SetColumnHeader(n, pButton);
	}
	if(pTestSet)
	{
		int nCount = pTestSet->GetSize();
		if(nCount > nMaxPreviewRows)
			nCount = nMaxPreviewRows;
		for(n = 0; n < nCount; n++)
		{
			m_pTestData->AddBlankRow();
			double* pRow = pTestSet->GetVector(n);
			for(i = 0; i < nAttrCount; i++)
			{
				GArffAttribute* pAttr = m_pRelation->GetAttribute(i);
				if(pAttr->IsContinuous())
					s.Copy(pRow[i]);
				else
				{
					// Nominal cells hold an index into the attribute's value
					// list. A negative index would read outside that list, so
					// show "?" instead. NOTE(review): presumably negatives
					// encode missing values; confirm in GArff.
					int nValue = (int)pRow[i];
					if(nValue >= 0)
						s.Copy(pAttr->GetValue(nValue));
					else
						s.Copy(L"?");
				}
				GWidgetTextLabel* pLabel = new GWidgetTextLabel(m_pTestData, 0, 0, 80, 20, &s, 0xff8888ff);
				m_pTestData->SetWidget(i, n, pLabel);
			}
		}
	}
	m_pTestData->Draw(this);

	// Shuffle Buttons
	if(!m_pShuffleButton2)
	{
		s.Copy(L"Shuffle");
		m_pShuffleButton2 = new GWidgetTextButton(this, 400, 530, 80, 20, &s);
		m_pShuffleButton2->Draw(this);
	}

	// From now on OnChangeListSelection refuses to change the technique.
	if(m_nTestTechnique < 0)
		m_nTestTechnique = m_pTestTechniqueList->GetSelection();
}
// Called when the user picks a file in one of the dialog's file browsers.
// Browser 1 loads the training data. With technique 2 ("split one data set")
// it first reads the training percentage from the text box. Browser 2 loads
// the test data.
/*virtual*/ void PredAccDialog::OnSelectFilename(GWidgetFileSystemBrowser* pBrowser, const char* szFilename)
{
	if(pBrowser == m_pFileSystemBrowser1)
	{
		if(m_pTestTechniqueList->GetSelection() == 2)
		{
			// The text box is destroyed once its value has been consumed. If
			// the load does not end up removing this browser (e.g. it fails)
			// and the user picks another file, the text box is gone. In that
			// case reuse m_dTrainingPercent (the last parsed value, or the
			// constructor's 70) instead of dereferencing NULL.
			if(m_pTextBox)
			{
				GString* pString = m_pTextBox->GetText();
				char* szTmp = (char*)alloca(pString->GetLength() + 1);
				pString->GetAnsi(szTmp);
				m_dTrainingPercent = atof(szTmp);
				// A percentage must lie in (0, 100]. The negated form also
				// rejects NaN, which fails every comparison.
				if(!(m_dTrainingPercent > 0 && m_dTrainingPercent <= 100))
					m_dTrainingPercent = 50;
				delete(m_pTextBox);
				m_pTextBox = NULL;
			}
			delete(m_pTextBoxLabel);
			m_pTextBoxLabel = NULL;
			m_pController->LoadAndSplitTrainingSet(szFilename, m_dTrainingPercent);
		}
		else
			m_pController->LoadTrainingSet(szFilename);
	}
	else if(pBrowser == m_pFileSystemBrowser2)
		m_pController->LoadTestSet(szFilename);
}
// Called when the selection of one of the dialog's list boxes changes.
// Algorithm list: nothing to do; the selection is read when Begin is pressed.
// Test-technique list: before any data is loaded, shows the widgets the
// chosen technique needs:
//   index 1 - a second file browser for the test set,
//   index 2 - a text box for the training percentage,
//   index 3 - a text box for the number of folds.
// After data is loaded (m_nTestTechnique >= 0), the selection is forced back
// to the locked-in technique.
/*virtual*/ void PredAccDialog::OnChangeListSelection(GWidgetListBox* pListBox)
{
	if(pListBox != m_pTestTechniqueList)
		return;

	// Technique is locked in: undo the user's change and redraw.
	if(m_nTestTechnique >= 0)
	{
		pListBox->SetSelection(m_nTestTechnique);
		pListBox->Draw(this);
		return;
	}

	const int nIndex = pListBox->GetSelection();

	// Second file system browser: only "separate files" (index 1) needs one.
	delete(m_pFileSystemBrowser2);
	m_pFileSystemBrowser2 = NULL;
	if(nIndex == 1)
	{
		m_pFileSystemBrowser2 = new GWidgetFileSystemBrowser(this, 400, 200, 350, 150, ".arff");
		m_pFileSystemBrowser2->Draw(this);
	}

	// Text box: only the split (2) and n-fold (3) techniques take a parameter.
	delete(m_pTextBox);
	m_pTextBox = NULL;
	delete(m_pTextBoxLabel);
	m_pTextBoxLabel = NULL;
	if(nIndex < 2)
		return;

	m_pTextBox = new GWidgetTextBox(this, 400, 215, 75, 20);
	GString s;
	if(nIndex == 2)
	{
		m_pTextBox->SetText("70");
		s.Copy(L"Percent for training set:");
	}
	else if(nIndex == 3)
	{
		m_pTextBox->SetText("10");
		s.Copy(L"Number of sections:");
	}
	else
	{
		GAssert(false, "unexpected selection");
	}
	// Draw after the default text is set. The original drew the box first,
	// so its default value might not appear until the next redraw.
	m_pTextBox->Draw(this);
	m_pTextBoxLabel = new GWidgetTextLabel(this, 400, 200, 200, 15, &s, 0xff8888ff);
	m_pTextBoxLabel->Draw(this);
}
// Handles clicks on the dialog's text buttons.
// - Shuffle buttons forward to the controller.
// - Begin runs the selected algorithm with the selected test technique:
//     0 - train and test on the same set,
//     1 - separate training/test files,
//     2 - one file split into training and test portions,
//     3 - n-fold cross validation, with n taken from the text box.
/*virtual*/ void PredAccDialog::OnReleaseTextButton(GWidgetTextButton* pButton)
{
	if(pButton == m_pShuffleButton1)
		m_pController->ShuffleTrainingSet();
	else if(pButton == m_pShuffleButton2)
		m_pController->ShuffleTestSet();
	else if(pButton == m_pBeginButton)
	{
		switch(m_pTestTechniqueList->GetSelection())
		{
			case 0:
				m_pController->TrainAndTestSingleSet(m_pAlgorithmList->GetSelection());
				break;

			case 1:
				// The user may not have loaded the test file yet.
				if(!m_pTestData)
					return;
				m_pController->TrainAndTest(m_pAlgorithmList->GetSelection());
				break;

			case 2:
				// The split should always have produced a test set. Previously
				// this was only GAssert'ed, which may be a no-op in release
				// builds, and then proceeded without test data. Keep the debug
				// diagnostic, but also bail out like technique 1.
				GAssert(m_pTestData, "no test data");
				if(!m_pTestData)
					return;
				m_pController->TrainAndTest(m_pAlgorithmList->GetSelection());
				break;

			case 3:
			{
				// Number of folds. Falls back to 10 if the text box is missing
				// or does not hold a positive integer.
				int n = 10;
				if(m_pTextBox)
				{
					GString* pString = m_pTextBox->GetText();
					char* szTmp = (char*)alloca(pString->GetLength() + 1);
					pString->GetAnsi(szTmp);
					n = atoi(szTmp);
					if(n <= 0)
						n = 10;
				}
				m_pController->DoNFoldCrossValidation(m_pAlgorithmList->GetSelection(), n);
				break;
			}

			default:
				break;
		}
	}
}
// -------------------------------------------------------------------------------
// Creates the 800x600 view and the dialog that fills it, minus a 10-pixel
// margin in each dimension. The controller is handed on to the dialog.
PredAccView::PredAccView(PredAccController* pController)
: ViewBase()
{
	const int nScreenWidth = 800;
	const int nScreenHeight = 600;
	SetScreenSize(nScreenWidth, nScreenHeight);
	m_pDialog = new PredAccDialog(pController, nScreenWidth - 10, nScreenHeight - 10);
}
// Frees the dialog allocated in the constructor.
PredAccView::~PredAccView()
{
	delete m_pDialog;
}
/*virtual*/ void PredAccView::Draw(SDL_Surface *pScreen)
{
// Clear the screen
SDL_FillRect(pScreen, NULL/*&r*/, 0x000000);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -