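// rank.cpp
//
// From "a machine-learning library written by Mike Gashler (includes neural ...)" · C++ code · 320 lines total
//
// (Code-viewer page residue, not part of the program: "CPP", "320", "font size".)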
// --------------------------------------------------------
// This demo file is dedicated to the Public Domain. See:
// http://creativecommons.org/licenses/publicdomain
// --------------------------------------------------------

#include "Rank.h"
#include "../GClasses/GLearner.h"
#include "../GClasses/GArff.h"
#include "../GClasses/GArray.h"
#include "../GClasses/GBag.h"
#include "../GClasses/GDirList.h"
#include "../GClasses/GFile.h"
#include "../GClasses/GDecisionTree.h"
#include "../GClasses/GKNN.h"
#include "../GClasses/GMacros.h"
#include "../GClasses/GNeuralNet.h"
#include "../GClasses/GNaiveInstance.h"
#include "../GClasses/GPCTree.h"
#include "../GClasses/GTime.h"
#include "../GClasses/GManifold.h"
#include "../GClasses/GTime.h"
#ifdef WIN32
#include <direct.h>
#endif // WIN32

// The view for the ranking demo. It owns a single pre-rendered image (built
// in the constructor) that Draw() blits onto the screen surface.
class RankView : public ViewBase
{
protected:
	// Owned by this view: allocated in the constructor, deleted in the destructor.
	GImage* m_pImage;

public:
	// explicit, so that a RankController* can never be implicitly converted into a view.
	explicit RankView(RankController* pController);
	virtual ~RankView();

protected:
	// Clears pScreen and blits m_pImage onto it.
	virtual void Draw(SDL_Surface *pScreen);

private:
	// Declared but intentionally not implemented: copying a view would leave
	// two objects that both delete the same m_pImage.
	RankView(const RankView&);
	RankView& operator=(const RankView&);
};



// Allocates the 400x200 image shown by this view and renders the fixed
// message text into it. pController is accepted but not used here.
RankView::RankView(RankController* pController)
: ViewBase(), m_pImage(new GImage())
{
	m_pImage->SetSize(400, 200);

	// Rectangle handed to DrawHardText for the message
	GRect textArea;
	textArea.Set(10, 85, 380, 30);
	m_pImage->DrawHardText(&textArea, "See console window.", 0xff0000ff, 1);
}

// Releases the image that the constructor allocated.
RankView::~RankView()
{
	delete m_pImage;
}

// Repaints the view: fills pScreen with black, then blits the message image
// at an offset of (200, 200) from m_screenRect's x/y.
/*virtual*/ void RankView::Draw(SDL_Surface *pScreen)
{
	// A NULL rectangle makes SDL_FillRect fill the whole surface
	SDL_FillRect(pScreen, NULL, 0x000000);

	// Blit the pre-rendered message image
	BlitImage(
		pScreen,
		m_screenRect.x + 200,
		m_screenRect.y + 200,
		m_pImage);
}


// -------------------------------------------------------------------------------

// Signature shared by all of the Make* factory functions below: takes the
// relation and returns a newly allocated learner.
typedef GSupervisedLearner* (*LearningAlgMaker)(GArffRelation* pRelation);

// Factory: returns a newly allocated GBaselineLearner for pRelation.
// The caller is responsible for deleting the returned object.
GSupervisedLearner* MakeBaseline(GArffRelation* pRelation)
{
	GSupervisedLearner* pLearner = new GBaselineLearner(pRelation);
	return pLearner;
}

// Factory: returns a newly allocated GKNN constructed with (pRelation, 5, false).
// NOTE(review): 5 is presumably the neighbor count -- confirm against GKNN.h.
GSupervisedLearner* MakeKNN5(GArffRelation* pRelation)
{
	GSupervisedLearner* pLearner = new GKNN(pRelation, 5, false);
	return pLearner;
}

// Factory: returns a newly allocated GKNN constructed with (pRelation, 16, false).
// NOTE(review): 16 is presumably the neighbor count -- confirm against GKNN.h.
GSupervisedLearner* MakeKNN16(GArffRelation* pRelation)
{
	GSupervisedLearner* pLearner = new GKNN(pRelation, 16, false);
	return pLearner;
}

// Factory: returns a GBag holding 100 GArbitraryTree learners, each
// constructed with (pRelation, true).
// NOTE(review): the 'true' flag presumably selects axis-aligned splits -- confirm.
GSupervisedLearner* MakeAxisAlignedForest100(GArffRelation* pRelation)
{
	GBag* pForest = new GBag(pRelation, 100);
	for(int nTree = 0; nTree < 100; nTree++)
	{
		GArbitraryTree* pTree = new GArbitraryTree(pRelation, true);
		pForest->AddLearner(pTree);
	}
	return pForest;
}

// Factory: returns a GBag holding 100 GArbitraryTree learners, each
// constructed with (pRelation, false).
// NOTE(review): the 'false' flag presumably disables axis-aligned splits -- confirm.
GSupervisedLearner* MakeArbitraryArboretum100(GArffRelation* pRelation)
{
	GBag* pForest = new GBag(pRelation, 100);
	for(int nTree = 0; nTree < 100; nTree++)
	{
		GArbitraryTree* pTree = new GArbitraryTree(pRelation, false);
		pForest->AddLearner(pTree);
	}
	return pForest;
}

// Factory: returns a GBag holding 100 GPCTree learners built for pRelation.
GSupervisedLearner* MakePCForest100(GArffRelation* pRelation)
{
	GBag* pForest = new GBag(pRelation, 100);
	for(int nTree = 0; nTree < 100; nTree++)
	{
		GPCTree* pTree = new GPCTree(pRelation);
		pForest->AddLearner(pTree);
	}
	return pForest;
}

// Factory: returns a newly allocated GSmartForest constructed with (pRelation, 100).
// NOTE(review): 100 is presumably the tree count -- confirm.
GSupervisedLearner* MakeEntForest100(GArffRelation* pRelation)
{
	GSupervisedLearner* pLearner = new GSmartForest(pRelation, 100);
	return pLearner;
}

// Factory: returns a newly allocated GSmartForest constructed with (pRelation, 10).
// NOTE(review): 10 is presumably the tree count -- confirm.
GSupervisedLearner* MakeEntForest10(GArffRelation* pRelation)
{
	GSupervisedLearner* pLearner = new GSmartForest(pRelation, 10);
	return pLearner;
}

// Factory: returns a newly allocated GNaiveInstance constructed with (pRelation, 10).
GSupervisedLearner* MakeNaiveInstance10(GArffRelation* pRelation)
{
	GSupervisedLearner* pLearner = new GNaiveInstance(pRelation, 10);
	return pLearner;
}

// Factory: returns a GNeuralNet for pRelation with one AddLayer(5) call applied.
GSupervisedLearner* MakeNeuralNet_O_5_I(GArffRelation* pRelation)
{
	GNeuralNet* pNet = new GNeuralNet(pRelation);
	pNet->AddLayer(5);
	return pNet;
}

// Factory: returns a GNeuralNet for pRelation with AddLayer(6) followed by
// AddLayer(10) applied, in that order.
GSupervisedLearner* MakeNeuralNet_O_6_10_I(GArffRelation* pRelation)
{
	GNeuralNet* pNet = new GNeuralNet(pRelation);
	pNet->AddLayer(6);
	pNet->AddLayer(10);
	return pNet;
}

// -------------------------------------------------------------------------------

// One row of the algorithm table: a factory function paired with the
// human-readable title printed in reports. Kept as a plain aggregate so that
// rows can be written with brace initialization.
struct LearnerEntry
{
	LearningAlgMaker m_pMakerFunc; // creates the learner
	const char* m_szTitle;         // display name of the learner

	// Invokes the factory for pRelation and returns the new learner.
	GSupervisedLearner* Make(GArffRelation* pRelation) const
	{
		GSupervisedLearner* pLearner = (*m_pMakerFunc)(pRelation);
		return pLearner;
	}
};

// The learners that get ranked, in the order their columns are printed.
// Each row pairs a factory function with the title used in the output.
// Rows that are commented out are excluded from the run; un-comment a row
// to include that learner.
const LearnerEntry g_pAlgorithmTable[] =
{
	{MakeBaseline, "Baseline"},
	{MakeKNN5, "KNN-5"},
//	{MakeKNN16, "KNN-16"},
//	{MakeAxisAlignedForest100, "Axis Aligned Forest 100"},
//	{MakeArbitraryArboretum100, "Arbitrary Arboretum 100"},
//	{MakePCForest100, "PC Forest 100"},
//	{MakeEntForest100, "Ent Forest 100"},
	{MakeEntForest10, "Ent Forest 10"},
//	{MakeNaiveInstance10, "Naive Instance"},
//	{MakeNeuralNet_O_5_I, "Neural Net o-5-i"}, // warning: really slow
//	{MakeNeuralNet_O_6_10_I, "Neural Net o-6-10-i"}, // warning: really really slow
};

// Number of rows in g_pAlgorithmTable (an unsigned size_t value).
// The replacement list is fully parenthesized so the macro behaves as a single
// value inside larger expressions such as "x / LEARNER_COUNT" or
// "x % LEARNER_COUNT". The element size is taken from the table itself so the
// count stays correct if the element type ever changes.
#define LEARNER_COUNT (sizeof(g_pAlgorithmTable) / sizeof(g_pAlgorithmTable[0]))

// Creates the controller together with the RankView it drives.
RankController::RankController()
: ControllerBase()
{
	RankView* pView = new RankView(this);
	m_pView = pView;
}

// Destroys the view that the constructor allocated.
RankController::~RankController()
{
	delete m_pView;
}

void RankController::RankLearners(bool bRegression, const char* szFilename, GArffRelation* pRelation, GArffData* pData, double* pScores)
{
/*
	// Print some stats to stderr
	fprintf(stderr, "Dataset: %s, %d Instances, %d Inputs, ", szFilename, pRelation->GetInputCount(), pData->GetSize());
	if(bRegression)
		fprintf(stderr, "Regression, ");
	else
		fprintf(stderr, "Classification (%d values), ", pRelation->GetAttribute(pRelation->GetOutputIndex(0))->GetValueCount());
	char szBuf[64];
	printf("Start time=%s\n", GTime::GetAsciiTime(szBuf, 64));
*/
	// Print the title
	int i;
	char szTitle[32];
	PathData pd;
	GFile::ParsePath(szFilename, &pd);
	int len = MIN(16, pd.extStart - pd.fileStart);
	memcpy(szTitle, &szFilename[pd.fileStart], len);
	for(i = len; i < 16; i++)
		szTitle[i] = ' ';
	szTitle[16] = '\0';
	printf("%s,", szTitle);

	// Measure and Print the scores
	double dScore;
	for(i = 0; i < LEARNER_COUNT; i++)
	{
		HandleEvents(0);
		if(!m_bKeepRunning)
		{
			printf("### Interrupted!\n");
			break;
		}
		const LearnerEntry* pEntry = &g_pAlgorithmTable[i];
		GSupervisedLearner* pLearner = pEntry->Make(pRelation);
		Holder<GSupervisedLearner*> hLearner(pLearner);
		dScore = pLearner->CrossValidate(pData, 10, bRegression);
		pScores[2 * i + (bRegression ? 0 : 1)] += dScore;
		if(i > 0)
			printf(",");
		printf("%.5f", dScore);
	}
	printf("\n");
}

// Ranks every learner in g_pAlgorithmTable against every ".arff" file found
// (recursively) under the "arff" folder, which is looked up after changing to
// the path returned by ControllerBase::GetAppPath(). Results are printed to
// stdout as an ARFF-style table (one row per dataset), followed by a
// "%"-prefixed summary of the per-learner averages.
// Throws a const char* message if the "arff" folder cannot be entered.
void RankController::RankLearnersAgainstAllDatasets()
{
	// Best effort: if this chdir fails, "arff" is resolved relative to the
	// current working directory instead.
	chdir(ControllerBase::GetAppPath());
	if(chdir("arff") != 0)
		throw "Expected a folder named 'arff' containing a bunch of arff files";

	// Two running sums per learner: [2 * i] accumulates regression scores and
	// [2 * i + 1] accumulates classification scores.
	double* pScores = new double[LEARNER_COUNT * 2];
	int i, j;
	for(i = 0; i < LEARNER_COUNT; i++)
	{
		pScores[2 * i] = 0;
		pScores[2 * i + 1] = 0;
	}
	// NOTE(review): pScores comes from new[]; confirm that Holder releases it
	// with delete[] (or switch to an array-aware holder if the library has one).
	Holder<double*> hScores(pScores);
	int nDataSetCount = 0;
	int nRegressions = 0;
	int nClassifications = 0;
	GDirList dl(/*bRecurseSubDirs*/true, /*bReportFiles*/true, /*bReportDirs*/false, /*bReportPaths*/true);

	// Print the ARFF header: one REAL attribute per learner
	char szBuf[256];
	printf("%% Comparison of Learning Algorithms (Start Time = %s)\n", GTime::GetAsciiTime(szBuf, 256));
	printf("\n@RELATION untitiled\n\n");
	printf("@ATTRIBUTE database\tCOMMENT\n");
	for(i = 0; i < LEARNER_COUNT; i++)
	{
		// Replace spaces and control characters in the title with underscores
		strcpy(szBuf, g_pAlgorithmTable[i].m_szTitle);
		for(j = 0; szBuf[j] != '\0'; j++)
		{
			if(szBuf[j] <= ' ')
				szBuf[j] = '_';
		}
		printf("@ATTRIBUTE %s\tREAL\n", szBuf);
	}
	printf("\n@DATA\n");

	// Process each file reported by the directory listing
	while(true)
	{
		const char* szFilename = dl.GetNext();
		if(!szFilename)
			break;
		PathData pd;
		GFile::ParsePath(szFilename, &pd);

		// Skip non-ARFF files
		if(stricmp(&szFilename[pd.extStart], ".arff") != 0)
			continue;

		GArffRelation* pRelation = NULL;
		GArffData* pData = NULL;
		try
		{
			GArffRelation::LoadArffFile(&pRelation, &pData, szFilename);
		}
		catch(const char* szError)
		{
			// Terminate the line so that the next dataset's row is not
			// appended to the end of the error message.
			printf("### Error: %s\n", szError);
		}
		Holder<GArffData*> hData(pData);
		Holder<GArffRelation*> hRelation(pRelation);
		if(!pRelation)
			continue;

		// Use the last attribute for the output
		pRelation->GetAttribute(pRelation->GetAttributeCount() - 1)->SetIsInput(false);

		// Run tests
		bool bRegression = pRelation->GetAttribute(pRelation->GetOutputIndex(0))->IsContinuous();
//pData->AddGaussianNoiseDimensions(pRelation, 100); // todo: remove this line
		RankLearners(bRegression, &szFilename[pd.fileStart], pRelation, pData, pScores);

		// NOTE(review): a dataset that was interrupted part-way is still
		// counted here, so the averages below may be skewed after an interruption.
		if(bRegression)
			nRegressions++;
		else
			nClassifications++;
		nDataSetCount++;
		if(!m_bKeepRunning)
			break;
	}

	// Show summary: each sum divided by the number of datasets of that kind
	printf("%% Summary with %d datasets: (Finish Time = %s)\n", nDataSetCount, GTime::GetAsciiTime(szBuf, 256));
	printf("%% -----\n");
	for(i = 0; i < LEARNER_COUNT; i++)
		printf("%% %s: Regression Sum Squared Error=%f, Classification Accuracy=%f\n", g_pAlgorithmTable[i].m_szTitle, (nRegressions > 0 ? pScores[2 * i] / nRegressions : 0) , (nClassifications > 0 ? pScores[2 * i + 1] / nClassifications : 0));
}

// Entry point for the demo: opens the accompanying documentation page,
// refreshes the view, then runs the full ranking.
void RankController::RunModal()
{
	// Documentation page passed to OpenAppFile
	OpenAppFile("../doc/Waffles/Rank.html");

	// Refresh the view before the ranking starts
	m_pView->Update();

	RankLearnersAgainstAllDatasets();
}

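// Keyboard shortcuts (code-viewer page residue, not part of the program):
//
//   Copy code: Ctrl + C
//   Search code: Ctrl + F
//   Full-screen mode: F11
//   Increase font size: Ctrl + =
//   Decrease font size: Ctrl + -
//   Show shortcuts: ?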