⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 predacc.cpp

📁 一个非常有用的开源代码
💻 CPP
📖 第 1 页 / 共 2 页
字号:

	// Draw the dialog
	GRect r;
	BlitImage(pScreen, m_screenRect.x, m_screenRect.y, m_pDialog->GetImage(&r));
}

// Passes a character event straight through to the dialog's HandleChar.
void PredAccView::OnChar(char c)
{
	m_pDialog->HandleChar(c);
}

void PredAccView::OnMouseDown(int x, int y)
{
	x -= m_screenRect.x;
	y -= m_screenRect.y;
	GWidgetAtomic* pNewWidget = m_pDialog->FindAtomicWidget(x, y);
	m_pDialog->GrabWidget(pNewWidget, x, y);
}

// Handles a mouse-button release by telling the dialog to release its widget.
// The coordinates are not used.
void PredAccView::OnMouseUp(int x, int y)
{
	m_pDialog->ReleaseWidget();
}

// Reports a mouse position to the dialog. The coordinates are shifted by the
// x/y of m_screenRect first, and the dialog's HandleMousePos result is
// returned to the caller unchanged.
bool PredAccView::OnMousePos(int x, int y)
{
	const int nLocalX = x - m_screenRect.x;
	const int nLocalY = y - m_screenRect.y;
	return m_pDialog->HandleMousePos(nLocalX, nLocalY);
}

// Hands the relation and training set to the dialog by forwarding both
// pointers to its SetTrainingSet method.
void PredAccView::SetTrainingSet(GArffRelation* pRelation, GArffData* pTrainingSet)
{
	m_pDialog->SetTrainingSet(pRelation, pTrainingSet);
}

// Hands the relation and test set to the dialog by forwarding both pointers
// to its SetTestSet method.
void PredAccView::SetTestSet(GArffRelation* pRelation, GArffData* pTestSet)
{
	m_pDialog->SetTestSet(pRelation, pTestSet);
}





// -------------------------------------------------------------------------------



// Constructs the controller. A PredAccView is created for it, and the
// controller starts out with no relation, no data sets and no learner.
PredAccController::PredAccController()
: ControllerBase()
{
	// One view object, reachable through both the typed pointer and m_pView
	PredAccView* pView = new PredAccView(this);
	m_pPredAccView = pView;
	m_pView = pView;

	// Nothing has been loaded yet
	m_pRelation = NULL;
	m_pTrainingSet = NULL;
	m_pTestSet = NULL;

	// Nothing has been trained yet
	m_pLearner = NULL;
}

// Destroys everything the controller holds: the learner, the relation, both
// data sets and finally the view.
PredAccController::~PredAccController()
{
	// The learner goes first, before the relation and data it was built from
	delete m_pLearner;
	delete m_pRelation;
	delete m_pTrainingSet;
	delete m_pTestSet;

	// m_pView is not deleted separately; the constructor points it at this
	// same object
	delete m_pPredAccView;
}

// Runs the controller's event loop until m_bKeepRunning becomes false.
//
// The view is updated once up front. Each iteration then passes the time
// elapsed since the previous iteration to HandleEvents; the view is redrawn
// when HandleEvents asks for it, otherwise the loop idles for 10 milliseconds
// so it does not spin.
void PredAccController::RunModal()
{
	double timeOld = GTime::GetTime();
	double time;
	m_pView->Update();
	while(m_bKeepRunning)
	{
		time = GTime::GetTime();
		if(HandleEvents(time - timeOld)) // HandleEvents returns true if it thinks the view needs to be updated
		{
			m_pView->Update();
		}
		else
		{
			// Sleep() takes milliseconds but usleep() takes microseconds, so
			// the two calls need different arguments to idle for the same 10 ms
#ifdef WIN32
			Sleep(10);
#else // WIN32
			usleep(10000);
#endif // !WIN32
		}
		timeOld = time;
	}
}

// Loads a file, parses it with GArffRelation::ParseFile and makes the parsed
// rows the current training set.
//
// Any previously loaded training set is deleted first. If no relation is
// stored yet, the parsed relation becomes m_pRelation. Otherwise the parsed
// relation is only used to check that its attribute count matches
// m_pRelation, and is then discarded. On success the view is given the
// relation and the new training set.
//
// szFilename - path of the file to load
//
// Throws "File not found" (const char*) if the file could not be loaded, and
// "mismatch relations" (const char*) if the attribute counts differ. In the
// mismatch case nothing parsed from the file is kept.
void PredAccController::LoadTrainingSet(const char* szFilename)
{
	delete(m_pTrainingSet);
	m_pTrainingSet = NULL;
	int nLen;
	Holder<char*> hFile = GFile::LoadFileToBuffer(szFilename, &nLen);
	if(!hFile.Get())
		throw "File not found";
	GArffRelation* pRelation = GArffRelation::ParseFile(&m_pTrainingSet, hFile.Get(), nLen);
//m_pTrainingSet->RandomlyReplaceMissingData(pRelation);
	if(m_pRelation)
	{
		// Only the attribute counts are compared. The freshly parsed relation
		// is never kept on this path, so free it before a possible throw
		// instead of leaking it
		bool bMismatch = (pRelation->GetAttributeCount() != m_pRelation->GetAttributeCount());
		delete(pRelation);
		if(bMismatch)
		{
			// Do not hold on to rows that do not fit m_pRelation
			delete(m_pTrainingSet);
			m_pTrainingSet = NULL;
			throw "mismatch relations";
		}
	}
	else
		m_pRelation = pRelation;
	m_pPredAccView->SetTrainingSet(m_pRelation, m_pTrainingSet);
}

// Loads a file, parses it with GArffRelation::ParseFile and makes the parsed
// rows the current test set.
//
// Any previously loaded test set is deleted first. If no relation is stored
// yet, the parsed relation becomes m_pRelation. Otherwise the parsed relation
// is only used to check that its attribute count matches m_pRelation, and is
// then discarded. On success the view is given the relation and the new test
// set.
//
// szFilename - path of the file to load
//
// Throws "File not found" (const char*) if the file could not be loaded, and
// "mismatch relations" (const char*) if the attribute counts differ. In the
// mismatch case nothing parsed from the file is kept.
void PredAccController::LoadTestSet(const char* szFilename)
{
	delete(m_pTestSet);
	m_pTestSet = NULL;
	int nLen;
	Holder<char*> hFile = GFile::LoadFileToBuffer(szFilename, &nLen);
	if(!hFile.Get())
		throw "File not found";
	GArffRelation* pRelation = GArffRelation::ParseFile(&m_pTestSet, hFile.Get(), nLen);
//m_pTestSet->RandomlyReplaceMissingData(pRelation);
	if(m_pRelation)
	{
		// Only the attribute counts are compared. The freshly parsed relation
		// is never kept on this path, so free it before a possible throw
		// instead of leaking it
		bool bMismatch = (pRelation->GetAttributeCount() != m_pRelation->GetAttributeCount());
		delete(pRelation);
		if(bMismatch)
		{
			// Do not hold on to rows that do not fit m_pRelation
			delete(m_pTestSet);
			m_pTestSet = NULL;
			throw "mismatch relations";
		}
	}
	else
		m_pRelation = pRelation;
	m_pPredAccView->SetTestSet(m_pRelation, m_pTestSet);
}

// Loads a single file, shuffles its rows and splits them into a training set
// and a test set, then hands both to the view.
//
// szFilename   - path of the file to load
// dTestPercent - percentage (0-100 scale) of the rows that is passed, as a row
//                count, to SplitBySize to form the test set
//
// Throws "data already loaded" (const char*) if a relation is already stored,
// and "File not found" (const char*) if the file could not be loaded. Both
// checks happen before any existing state is touched, so a rejected call
// leaves previously loaded data intact.
void PredAccController::LoadAndSplitTrainingSet(const char* szFilename, double dTestPercent)
{
	// Validate first: deleting the current sets and then throwing would leave
	// the controller emptied by a call that did nothing useful
	if(m_pRelation)
		throw "data already loaded";
	int nLen;
	Holder<char*> hFile = GFile::LoadFileToBuffer(szFilename, &nLen);
	if(!hFile.Get())
		throw "File not found";

	// From here on the old sets are going to be replaced
	delete(m_pTrainingSet);
	m_pTrainingSet = NULL;
	delete(m_pTestSet);
	m_pTestSet = NULL;

	m_pRelation = GArffRelation::ParseFile(&m_pTrainingSet, hFile.Get(), nLen);
	m_pTrainingSet->Shuffle();
//m_pTrainingSet->RandomlyReplaceMissingData(m_pRelation);
	m_pTestSet = m_pTrainingSet->SplitBySize((int)(dTestPercent * m_pTrainingSet->GetSize() / 100));
	m_pPredAccView->SetTrainingSet(m_pRelation, m_pTrainingSet);
	m_pPredAccView->SetTestSet(m_pRelation, m_pTestSet);
}

// Shuffles the training set when one is loaded, then passes the relation and
// training set pointers to the view. The view is called even when no
// training set is loaded.
void PredAccController::ShuffleTrainingSet()
{
	if(m_pTrainingSet != NULL)
	{
		m_pTrainingSet->Shuffle();
	}
	m_pPredAccView->SetTrainingSet(m_pRelation, m_pTrainingSet);
}

// Shuffles the test set when one is loaded, then passes the relation and test
// set pointers to the view. The view is called even when no test set is
// loaded.
void PredAccController::ShuffleTestSet()
{
	if(m_pTestSet != NULL)
	{
		m_pTestSet->Shuffle();
	}
	m_pPredAccView->SetTestSet(m_pRelation, m_pTestSet);
}

void PredAccController::TrainAndTestSingleSet(int nAlgorithm)
{
	GAssert(m_pTrainingSet, "no training set loaded");
	Train(nAlgorithm, m_pRelation, m_pTrainingSet);
	Test(nAlgorithm, m_pTrainingSet);
}

void PredAccController::TrainAndTest(int nAlgorithm)
{
	GAssert(m_pTrainingSet, "no training set loaded");
	GAssert(m_pTestSet, "no test set loaded");
	Train(nAlgorithm, m_pRelation, m_pTrainingSet);
	Test(nAlgorithm, m_pTestSet);
}

void PredAccController::DoNFoldCrossValidation(int nAlgorithm, int nParts)
{
	// Make the last attribute an output attribute--todo: design a better way
	m_pRelation->GetAttribute(m_pRelation->GetAttributeCount() - 1)->SetIsInput(false);
//	m_pRelation->GetAttribute(0)->SetIsInput(false);

	// Determine if it's a regression or classification problem
	bool bRegression = true;
	int i;
	for(i = 0; i < m_pRelation->GetOutputCount(); i++)
	{
		if(!m_pRelation->GetAttribute(m_pRelation->GetOutputIndex(i))->IsContinuous())
		{
			bRegression = false;
			break;
		}
	}

	// Split the data into parts
	GArffData** pSets = (GArffData**)alloca(sizeof(GArffData*) * nParts);
	int nSize = m_pTrainingSet->GetSize() / nParts + nParts;
	int n, j;
	for(n = 0; n < nParts; n++)
		pSets[n] = new GArffData(nSize);
	int nRowCount = m_pTrainingSet->GetSize();
	double* pRow;
	for(n = 0; n < nRowCount; n++)
	{
		pRow = m_pTrainingSet->GetVector(n);
		pSets[n % nParts]->AddVector(pRow);
	}

	// Do the training and testing
	double d;
	double dScore = 0;
	int nCorrect = 0;
	for(n = 0; n < nParts; n++)
	{
		// Merge all sets but one
		GArffData* pTrainer = new GArffData(m_pTrainingSet->GetSize());
		for(i = 0; i < nParts; i++)
		{
			if(i == n)
				continue;
			int nCount = pSets[i]->GetSize();
			for(j = 0; j < nCount; j++)
			{
				pRow = pSets[i]->GetVector(j);
				pTrainer->AddVector(pRow);
			}
		}

		// Make the learner ant train it
		GSupervisedLearner* pLearner = MakeLearner(nAlgorithm, m_pRelation);
		pLearner->Train(pTrainer);

		// Test it
		if(bRegression)
			d = pLearner->MeasureMeanSquaredError(pSets[n]);
		else
			d = pLearner->MeasurePredictiveAccuracy(pSets[n]);
		printf("Cross Validation Set %d/%d = %f\n", n, nParts, d);
		dScore += d;

		// Clean up
		delete(pLearner);
		pTrainer->DropAllVectors();
		delete(pTrainer);
	}
	dScore /= nParts;

	// Show results
	printf("\n\nFinal Cross Validation Results...\n");
	if(bRegression)
		printf("Average Mean Squared Error: %f\n", dScore);
	else
		printf("Average Predictive Accuracy: %f\n", dScore);

	// Clean up
	for(n = 0; n < nParts; n++)
	{
		pSets[n]->DropAllVectors();
		delete(pSets[n]);
	}
}

GSupervisedLearner* PredAccController::MakeLearner(int nAlgorithm, GArffRelation* pRelation)
{
	if(nAlgorithm == 0)
	{
		printf("Decision Tree...\n");
		GDecisionTree* pDecisionTree = new GDecisionTree(pRelation);
		return pDecisionTree;
	}
	else if(nAlgorithm == 1)
	{
		printf("Neural Net...\n");
		GNeuralNet* pNN = new GNeuralNet(pRelation);
		pNN->AddLayer(8);
		pNN->SetRunEpochs(1000);
		pNN->SetMaximumEpochs(5000);
		pNN->SetLearningRate(.3);
		pNN->SetMomentum(.9);
		return pNN;
	}
	else if(nAlgorithm == 2)
	{
		printf("Naive Bayes...\n");
		GNaiveBayes* pNaiveBayes = new GNaiveBayes(pRelation);
		return pNaiveBayes;
	}
	else if(nAlgorithm == 3)
	{
		printf("K-Nearest Neighbor...\n");
		GKNN* pKNN = new GKNN(pRelation, 5);
		return pKNN;
	}
	else if(nAlgorithm == 4)
	{
		printf("Pumped Neural Net...\n");
		GManifoldPumper* pPumper = new GManifoldPumper(pRelation, 1, 6, 18);
		GNeuralNet* pNN = new GNeuralNet(pPumper->GetRelation());
		pNN->AddLayer(8);
		pNN->SetRunEpochs(1000);
		pNN->SetMaximumEpochs(5000);
		pNN->SetLearningRate(.3);
		pNN->SetMomentum(.9);
		pPumper->SetLearner(pNN, true);
		return pPumper;
	}
	else if(nAlgorithm == 5)
	{
		printf("Pumped KNN...\n");
		GManifoldPumper* pPumper = new GManifoldPumper(pRelation, 1, 6, 18);
		GKNN* pKNN = new GKNN(pPumper->GetRelation(), 5);
		pPumper->SetLearner(pKNN, true);
		return pPumper;
	}
	else
	{
		GAssert(false, "unexpected algorithm");
		return NULL;
	}
}

// Creates a learner for the requested algorithm, stores it in m_pLearner and
// trains it on the given data, printing the data set name, its size and the
// training time.
//
// A learner left in m_pLearner by an earlier call is deleted first, so
// repeated calls do not leak it.
//
// nAlgorithm   - algorithm id understood by MakeLearner
// pRelation    - relation describing the data; its last attribute is marked
//                as an output (non-input) attribute as a side effect
// pTrainingSet - rows to train on
void PredAccController::Train(int nAlgorithm, GArffRelation* pRelation, GArffData* pTrainingSet)
{
	// Make the last attribute an output attribute--todo: design a better way
	pRelation->GetAttribute(pRelation->GetAttributeCount() - 1)->SetIsInput(false);
//	pRelation->GetAttribute(0)->SetIsInput(false);

	// Release the previous learner. The pointer is cleared before the new one
	// is made so it never dangles if MakeLearner fails
	delete(m_pLearner);
	m_pLearner = NULL;

	// Make the learner
	m_pLearner = MakeLearner(nAlgorithm, pRelation);

	// Train it
	printf("Dataset name: %s\n", pRelation->GetName());
	printf("Training set size: %d\n", pTrainingSet->GetSize());
	printf("Training...\n");
	double dTimeStart = GTime::GetTime();
	m_pLearner->Train(pTrainingSet);
	printf("training time=%lf seconds\n", GTime::GetTime() - dTimeStart);
}

// Measures the current learner (m_pLearner) against a data set and prints the
// results. Predictive accuracy is printed when m_pRelation has at least one
// non-continuous output attribute, and mean squared error when it has at
// least one continuous output attribute; both are printed if it has both.
//
// nAlgorithm - not used by this method
// pTestSet   - rows to measure the learner against
void PredAccController::Test(int nAlgorithm, GArffData* pTestSet)
{
	printf("\nTesting...\n");

	// Find out which kinds of output attributes the relation has
	bool bAnyContinuous = false;
	bool bAnyDiscrete = false;
	for(int nOutput = 0; nOutput < m_pRelation->GetOutputCount(); nOutput++)
	{
		bool bContinuous = m_pRelation->GetAttribute(m_pRelation->GetOutputIndex(nOutput))->IsContinuous();
		bAnyContinuous = bAnyContinuous || bContinuous;
		bAnyDiscrete = bAnyDiscrete || !bContinuous;
	}

	if(bAnyDiscrete)
	{
		double dAccuracy = m_pLearner->MeasurePredictiveAccuracy(pTestSet);
		printf("Predictive Accuracy = %f\n", dAccuracy);
	}
	if(bAnyContinuous)
	{
		double dError = m_pLearner->MeasureMeanSquaredError(pTestSet);
		printf("Mean Squared Error = %f\n", dError);
	}
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -