📄 predacc.cpp
字号:
// Draw the dialog
GRect r;
BlitImage(pScreen, m_screenRect.x, m_screenRect.y, m_pDialog->GetImage(&r));
}
// Keyboard input handler: passes the typed character straight through to
// the dialog hosted by this view.
void PredAccView::OnChar(char c)
{
	m_pDialog->HandleChar(c);
}
void PredAccView::OnMouseDown(int x, int y)
{
x -= m_screenRect.x;
y -= m_screenRect.y;
GWidgetAtomic* pNewWidget = m_pDialog->FindAtomicWidget(x, y);
m_pDialog->GrabWidget(pNewWidget, x, y);
}
// Mouse-button-release handler: tells the dialog to release the widget it is
// holding. The pointer coordinates are not used.
void PredAccView::OnMouseUp(int x, int y)
{
	m_pDialog->ReleaseWidget();
}
// Mouse-move handler.
// x, y: pointer position in the same coordinate space as m_screenRect.
// Returns whatever the dialog's HandleMousePos returns for the position
// expressed relative to the dialog's top-left corner.
bool PredAccView::OnMousePos(int x, int y)
{
	const int nDialogX = x - m_screenRect.x;
	const int nDialogY = y - m_screenRect.y;
	return m_pDialog->HandleMousePos(nDialogX, nDialogY);
}
// Hands the relation and the training data to the dialog so it can show them.
// Both pointers are forwarded as-is (they may be NULL; the controller passes
// its members without checking).
void PredAccView::SetTrainingSet(GArffRelation* pRelation, GArffData* pTrainingSet)
{
	m_pDialog->SetTrainingSet(pRelation, pTrainingSet);
}
// Hands the relation and the test data to the dialog so it can show them.
// Both pointers are forwarded as-is (they may be NULL; the controller passes
// its members without checking).
void PredAccView::SetTestSet(GArffRelation* pRelation, GArffData* pTestSet)
{
	m_pDialog->SetTestSet(pRelation, pTestSet);
}
// -------------------------------------------------------------------------------
// Builds the controller together with its view. The view is created first and
// is also published through m_pView; no data or learner is loaded yet, so all
// of those members start out NULL.
PredAccController::PredAccController()
	: ControllerBase()
{
	// The view is created before anything else, exactly as the view's
	// constructor receives `this`.
	m_pPredAccView = new PredAccView(this);
	m_pView = m_pPredAccView;

	// Nothing loaded, nothing trained.
	m_pLearner = NULL;
	m_pTestSet = NULL;
	m_pTrainingSet = NULL;
	m_pRelation = NULL;
}
// Releases everything the controller owns: the learner, the relation, both
// data sets and finally the view. Deleting a NULL member is harmless.
PredAccController::~PredAccController()
{
	delete m_pLearner;
	delete m_pRelation;
	delete m_pTrainingSet;
	delete m_pTestSet;
	delete m_pPredAccView;
}
// Runs the controller's event loop until m_bKeepRunning becomes false.
//
// The view is drawn once up front. Each iteration then passes the time elapsed
// since the previous iteration (a difference of GTime::GetTime() values) to
// HandleEvents and redraws the view when HandleEvents asks for it. When there
// is nothing to redraw, the loop sleeps for 10 milliseconds so it does not
// spin the CPU.
void PredAccController::RunModal()
{
	double timeOld = GTime::GetTime();
	double timeNow;
	m_pView->Update();
	while(m_bKeepRunning)
	{
		timeNow = GTime::GetTime();
		if(HandleEvents(timeNow - timeOld)) // HandleEvents returns true if it thinks the view needs to be updated
		{
			m_pView->Update();
		}
		else
		{
			// Idle for 10 ms on every platform. Sleep() takes milliseconds
			// while usleep() takes microseconds, hence the different arguments.
#ifdef WIN32
			Sleep(10);
#else // WIN32
			usleep(10000);
#endif // !WIN32
		}
		timeOld = timeNow;
	}
}
// Loads an ARFF file and makes its contents the current training set.
//
// szFilename: path of the file to load.
// Throws (as const char*):
//   "File not found"     - the file could not be loaded into a buffer.
//   "mismatch relations" - a relation is already loaded and its attribute
//                          count differs from the one in this file.
//
// Any previously loaded training set is discarded first. If no relation has
// been loaded yet, the relation parsed from this file becomes m_pRelation;
// otherwise the parsed relation is only used for the attribute-count check
// and is then freed. On success the view is given the relation and new data.
void PredAccController::LoadTrainingSet(const char* szFilename)
{
	delete(m_pTrainingSet);
	m_pTrainingSet = NULL;

	int nLen;
	Holder<char*> hFile = GFile::LoadFileToBuffer(szFilename, &nLen);
	if(!hFile.Get())
		throw "File not found";

	// ParseFile stores the parsed data set through its first argument.
	GArffRelation* pRelation = GArffRelation::ParseFile(&m_pTrainingSet, hFile.Get(), nLen);
	//m_pTrainingSet->RandomlyReplaceMissingData(pRelation);
	if(m_pRelation)
	{
		// The temporary relation is no longer needed once it has been compared,
		// so free it before a possible throw to avoid leaking it.
		bool bMismatch = (pRelation->GetAttributeCount() != m_pRelation->GetAttributeCount());
		delete(pRelation);
		if(bMismatch)
			throw "mismatch relations";
	}
	else
		m_pRelation = pRelation;
	m_pPredAccView->SetTrainingSet(m_pRelation, m_pTrainingSet);
}
// Loads an ARFF file and makes its contents the current test set.
//
// szFilename: path of the file to load.
// Throws (as const char*):
//   "File not found"     - the file could not be loaded into a buffer.
//   "mismatch relations" - a relation is already loaded and its attribute
//                          count differs from the one in this file.
//
// Any previously loaded test set is discarded first. If no relation has been
// loaded yet, the relation parsed from this file becomes m_pRelation;
// otherwise the parsed relation is only used for the attribute-count check
// and is then freed. On success the view is given the relation and new data.
void PredAccController::LoadTestSet(const char* szFilename)
{
	delete(m_pTestSet);
	m_pTestSet = NULL;

	int nLen;
	Holder<char*> hFile = GFile::LoadFileToBuffer(szFilename, &nLen);
	if(!hFile.Get())
		throw "File not found";

	// ParseFile stores the parsed data set through its first argument.
	GArffRelation* pRelation = GArffRelation::ParseFile(&m_pTestSet, hFile.Get(), nLen);
	//m_pTestSet->RandomlyReplaceMissingData(pRelation);
	if(m_pRelation)
	{
		// The temporary relation is no longer needed once it has been compared,
		// so free it before a possible throw to avoid leaking it.
		bool bMismatch = (pRelation->GetAttributeCount() != m_pRelation->GetAttributeCount());
		delete(pRelation);
		if(bMismatch)
			throw "mismatch relations";
	}
	else
		m_pRelation = pRelation;
	m_pPredAccView->SetTestSet(m_pRelation, m_pTestSet);
}
// Loads one ARFF file, shuffles it and splits it into a training set and a
// test set.
//
// szFilename:   path of the file to load.
// dTestPercent: percentage (0-100) of the rows that go to the test set; the
//               row count is truncated to an integer.
// Throws (as const char*):
//   "data already loaded" - a relation has already been loaded.
//   "File not found"      - the file could not be loaded into a buffer.
//
// The "data already loaded" check is done before anything is modified, so a
// rejected call leaves the currently loaded data sets untouched.
void PredAccController::LoadAndSplitTrainingSet(const char* szFilename, double dTestPercent)
{
	if(m_pRelation)
		throw "data already loaded";

	delete(m_pTrainingSet);
	m_pTrainingSet = NULL;
	delete(m_pTestSet);
	m_pTestSet = NULL;

	int nLen;
	Holder<char*> hFile = GFile::LoadFileToBuffer(szFilename, &nLen);
	if(!hFile.Get())
		throw "File not found";

	// ParseFile stores the parsed data set through its first argument.
	m_pRelation = GArffRelation::ParseFile(&m_pTrainingSet, hFile.Get(), nLen);
	m_pTrainingSet->Shuffle();
	//m_pTrainingSet->RandomlyReplaceMissingData(m_pRelation);

	// Carve the requested share of rows out of the training set.
	m_pTestSet = m_pTrainingSet->SplitBySize((int)(dTestPercent * m_pTrainingSet->GetSize() / 100));
	m_pPredAccView->SetTrainingSet(m_pRelation, m_pTrainingSet);
	m_pPredAccView->SetTestSet(m_pRelation, m_pTestSet);
}
// Shuffles the training set, if one is loaded, and then hands the relation and
// training set to the view again (this happens even when nothing is loaded).
void PredAccController::ShuffleTrainingSet()
{
	if(m_pTrainingSet != NULL)
	{
		m_pTrainingSet->Shuffle();
	}
	m_pPredAccView->SetTrainingSet(m_pRelation, m_pTrainingSet);
}
// Shuffles the test set, if one is loaded, and then hands the relation and
// test set to the view again (this happens even when nothing is loaded).
void PredAccController::ShuffleTestSet()
{
	if(m_pTestSet != NULL)
	{
		m_pTestSet->Shuffle();
	}
	m_pPredAccView->SetTestSet(m_pRelation, m_pTestSet);
}
// Trains the selected algorithm on the training set and then measures it
// against that very same training set.
// nAlgorithm: algorithm selector understood by MakeLearner.
void PredAccController::TrainAndTestSingleSet(int nAlgorithm)
{
	GAssert(m_pTrainingSet, "no training set loaded");

	// Train, then evaluate on the data just trained on.
	Train(nAlgorithm, m_pRelation, m_pTrainingSet);
	Test(nAlgorithm, m_pTrainingSet);
}
// Trains the selected algorithm on the training set and then measures it
// against the separate test set.
// nAlgorithm: algorithm selector understood by MakeLearner.
void PredAccController::TrainAndTest(int nAlgorithm)
{
	// Both data sets are expected to have been loaded beforehand.
	GAssert(m_pTrainingSet, "no training set loaded");
	GAssert(m_pTestSet, "no test set loaded");

	Train(nAlgorithm, m_pRelation, m_pTrainingSet);
	Test(nAlgorithm, m_pTestSet);
}
void PredAccController::DoNFoldCrossValidation(int nAlgorithm, int nParts)
{
// Make the last attribute an output attribute--todo: design a better way
m_pRelation->GetAttribute(m_pRelation->GetAttributeCount() - 1)->SetIsInput(false);
// m_pRelation->GetAttribute(0)->SetIsInput(false);
// Determine if it's a regression or classification problem
bool bRegression = true;
int i;
for(i = 0; i < m_pRelation->GetOutputCount(); i++)
{
if(!m_pRelation->GetAttribute(m_pRelation->GetOutputIndex(i))->IsContinuous())
{
bRegression = false;
break;
}
}
// Split the data into parts
GArffData** pSets = (GArffData**)alloca(sizeof(GArffData*) * nParts);
int nSize = m_pTrainingSet->GetSize() / nParts + nParts;
int n, j;
for(n = 0; n < nParts; n++)
pSets[n] = new GArffData(nSize);
int nRowCount = m_pTrainingSet->GetSize();
double* pRow;
for(n = 0; n < nRowCount; n++)
{
pRow = m_pTrainingSet->GetVector(n);
pSets[n % nParts]->AddVector(pRow);
}
// Do the training and testing
double d;
double dScore = 0;
int nCorrect = 0;
for(n = 0; n < nParts; n++)
{
// Merge all sets but one
GArffData* pTrainer = new GArffData(m_pTrainingSet->GetSize());
for(i = 0; i < nParts; i++)
{
if(i == n)
continue;
int nCount = pSets[i]->GetSize();
for(j = 0; j < nCount; j++)
{
pRow = pSets[i]->GetVector(j);
pTrainer->AddVector(pRow);
}
}
// Make the learner ant train it
GSupervisedLearner* pLearner = MakeLearner(nAlgorithm, m_pRelation);
pLearner->Train(pTrainer);
// Test it
if(bRegression)
d = pLearner->MeasureMeanSquaredError(pSets[n]);
else
d = pLearner->MeasurePredictiveAccuracy(pSets[n]);
printf("Cross Validation Set %d/%d = %f\n", n, nParts, d);
dScore += d;
// Clean up
delete(pLearner);
pTrainer->DropAllVectors();
delete(pTrainer);
}
dScore /= nParts;
// Show results
printf("\n\nFinal Cross Validation Results...\n");
if(bRegression)
printf("Average Mean Squared Error: %f\n", dScore);
else
printf("Average Predictive Accuracy: %f\n", dScore);
// Clean up
for(n = 0; n < nParts; n++)
{
pSets[n]->DropAllVectors();
delete(pSets[n]);
}
}
GSupervisedLearner* PredAccController::MakeLearner(int nAlgorithm, GArffRelation* pRelation)
{
if(nAlgorithm == 0)
{
printf("Decision Tree...\n");
GDecisionTree* pDecisionTree = new GDecisionTree(pRelation);
return pDecisionTree;
}
else if(nAlgorithm == 1)
{
printf("Neural Net...\n");
GNeuralNet* pNN = new GNeuralNet(pRelation);
pNN->AddLayer(8);
pNN->SetRunEpochs(1000);
pNN->SetMaximumEpochs(5000);
pNN->SetLearningRate(.3);
pNN->SetMomentum(.9);
return pNN;
}
else if(nAlgorithm == 2)
{
printf("Naive Bayes...\n");
GNaiveBayes* pNaiveBayes = new GNaiveBayes(pRelation);
return pNaiveBayes;
}
else if(nAlgorithm == 3)
{
printf("K-Nearest Neighbor...\n");
GKNN* pKNN = new GKNN(pRelation, 5);
return pKNN;
}
else if(nAlgorithm == 4)
{
printf("Pumped Neural Net...\n");
GManifoldPumper* pPumper = new GManifoldPumper(pRelation, 1, 6, 18);
GNeuralNet* pNN = new GNeuralNet(pPumper->GetRelation());
pNN->AddLayer(8);
pNN->SetRunEpochs(1000);
pNN->SetMaximumEpochs(5000);
pNN->SetLearningRate(.3);
pNN->SetMomentum(.9);
pPumper->SetLearner(pNN, true);
return pPumper;
}
else if(nAlgorithm == 5)
{
printf("Pumped KNN...\n");
GManifoldPumper* pPumper = new GManifoldPumper(pRelation, 1, 6, 18);
GKNN* pKNN = new GKNN(pPumper->GetRelation(), 5);
pPumper->SetLearner(pKNN, true);
return pPumper;
}
else
{
GAssert(false, "unexpected algorithm");
return NULL;
}
}
// Builds a learner for the chosen algorithm, trains it on the given data and
// keeps it in m_pLearner for a later call to Test.
//
// nAlgorithm:   algorithm selector understood by MakeLearner.
// pRelation:    relation describing the data; its last attribute is marked as
//               an output attribute as a side effect.
// pTrainingSet: data to train on.
//
// Progress and the elapsed training time (in seconds, from GTime::GetTime)
// are printed to stdout. Any learner left over from a previous call is
// deleted first so that repeated training does not leak it.
void PredAccController::Train(int nAlgorithm, GArffRelation* pRelation, GArffData* pTrainingSet)
{
	// Make the last attribute an output attribute--todo: design a better way
	pRelation->GetAttribute(pRelation->GetAttributeCount() - 1)->SetIsInput(false);
	//	pRelation->GetAttribute(0)->SetIsInput(false);

	// Discard the previous learner. It is reset to NULL before the new one is
	// made so the member never points at freed memory if MakeLearner fails.
	delete(m_pLearner);
	m_pLearner = NULL;

	// Make the learner
	m_pLearner = MakeLearner(nAlgorithm, pRelation);

	// Train it
	printf("Dataset name: %s\n", pRelation->GetName());
	printf("Training set size: %d\n", pTrainingSet->GetSize());
	printf("Training...\n");
	double dTimeStart = GTime::GetTime();
	m_pLearner->Train(pTrainingSet);
	printf("training time=%lf seconds\n", GTime::GetTime() - dTimeStart);
}
// Evaluates the learner held in m_pLearner against a data set and prints the
// result to stdout.
//
// nAlgorithm: not used by this method.
// pTestSet:   data to measure the learner against.
//
// Predictive accuracy is reported when the relation has at least one
// non-continuous output attribute, and mean squared error when it has at least
// one continuous output attribute; both are printed if both kinds are present.
void PredAccController::Test(int nAlgorithm, GArffData* pTestSet)
{
	printf("\nTesting...\n");

	// Find out which kinds of output attributes the relation has.
	bool bHasContinuousOutput = false;
	bool bHasDiscreteOutput = false;
	int nOutput = 0;
	while(nOutput < m_pRelation->GetOutputCount())
	{
		if(m_pRelation->GetAttribute(m_pRelation->GetOutputIndex(nOutput))->IsContinuous())
		{
			bHasContinuousOutput = true;
		}
		else
		{
			bHasDiscreteOutput = true;
		}
		nOutput++;
	}

	// Report the measure that fits each kind of output.
	if(bHasDiscreteOutput)
	{
		printf("Predictive Accuracy = %f\n", m_pLearner->MeasurePredictiveAccuracy(pTestSet));
	}
	if(bHasContinuousOutput)
	{
		printf("Mean Squared Error = %f\n", m_pLearner->MeasureMeanSquaredError(pTestSet));
	}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -