garff.cpp

来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C++ 代码 · 共 1,909 行 · 第 1/3 页

CPP
1,909
字号
/*	Copyright (C) 2006, Mike Gashler

	This library is free software; you can redistribute it and/or
	modify it under the terms of the GNU Lesser General Public
	License as published by the Free Software Foundation; either
	version 2.1 of the License, or (at your option) any later version.

	see http://www.gnu.org/copyleft/lesser.html
*/

#include "GArff.h"
#include "../GClasses/GArray.h"
#include "../GClasses/GMacros.h"
#include "../GClasses/GMath.h"
#include "../GClasses/GVector.h"
#include "../GClasses/GFile.h"
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "GBits.h"
#include "GMatrix.h"

// Creates an empty relation. The input/output index tables are built lazily
// (see CountInputs), so they start out invalid (-1 counts, NULL arrays).
GArffRelation::GArffRelation()
{
	m_szName = NULL;
	m_pAttributes = new GPointerArray(32);
	m_nInputCount = -1;
	m_pInputIndexes = NULL;
	m_nOutputCount = -1;
	m_pOutputIndexes = NULL;
}

// Deletes every attribute held by the relation, the attribute array, the
// name, and the cached input/output index tables.
GArffRelation::~GArffRelation()
{
	int n;
	int nCount = m_pAttributes->GetSize();
	for(n = 0; n < nCount; n++)
		delete((GArffAttribute*)m_pAttributes->GetPointer(n));
	delete(m_pAttributes);
	delete[] m_szName;
	delete[] m_pInputIndexes;
	delete[] m_pOutputIndexes;
}

// Appends pAttr (the relation deletes it in its destructor) and invalidates
// the cached input/output index tables so they are rebuilt on next use.
void GArffRelation::AddAttribute(GArffAttribute* pAttr)
{
	m_pAttributes->AddPointer(pAttr);
	m_nInputCount = -1;
	m_nOutputCount = -1;
	delete[] m_pInputIndexes;
	m_pInputIndexes = NULL;
	delete[] m_pOutputIndexes;
	m_pOutputIndexes = NULL;
}

// Returns the number of attributes for which IsContinuous() is true.
int GArffRelation::CountContinuousAttributes()
{
	int n;
	int nAttributes = GetAttributeCount();
	int nCount = 0;
	for(n = 0; n < nAttributes; n++)
	{
		GArffAttribute* pAttr = GetAttribute(n);
		if(pAttr->IsContinuous())
			nCount++;
	}
	return nCount;
}

// Writes this relation and the rows in pData to szFilename as ARFF text.
// Unnamed attributes are written as "a<index>" and unnamed nominal values as
// "v<valueIndex>". Nominal cells holding a negative value (ParseDataRow stores
// -1 for "?") are written back as "?". Reports failure to open the file via
// ThrowError.
void GArffRelation::SaveArffFile(GArffData* pData, const char* szFilename)
{
	// Open the file for writing
	FILE* pFile = fopen(szFilename, "w");
	FileHolder hFile(pFile);
	if(!pFile)
		ThrowError("Failed to open file: %s", szFilename);

	// Write the relation title
	fputs("@RELATION ", pFile);
	const char* szName = GetName();
	if(!szName)
		szName = "Untitled";
	fputs(szName, pFile);
	fputs("\n\n", pFile);

	// Write the attributes
	char szTmp[64];
	int nAttrCount = GetAttributeCount();
	int i, j;
	for(i = 0; i < nAttrCount; i++)
	{
		GArffAttribute* pAttr = GetAttribute(i);
		fputs("@ATTRIBUTE ", pFile);
		szName = pAttr->GetName();
		if(!szName)
		{
			// sprintf keeps the "a" prefix (itoa overwrote it from offset 0)
			sprintf(szTmp, "a%d", i);
			szName = szTmp;
		}
		fputs(szName, pFile);
		fputs("\t", pFile);
		if(pAttr->IsContinuous())
			fputs("CONTINUOUS", pFile);
		else
		{
			fputs("{", pFile);
			for(j = 0; j < pAttr->GetValueCount(); j++)
			{
				// Separator goes between values; no trailing comma before '}'
				if(j > 0)
					fputs(",", pFile);
				szName = pAttr->GetValue(j);
				if(!szName)
				{
					sprintf(szTmp, "v%d", j);
					szName = szTmp;
				}
				fputs(szName, pFile);
			}
			fputs("}", pFile);
		}
		fputs("\n", pFile);
	}

	// Write the data
	fputs("\n@DATA\n", pFile);
	for(i = 0; i < pData->GetSize(); i++)
	{
		double* pVector = pData->GetVector(i);
		for(j = 0; j < nAttrCount; j++)
		{
			if(j > 0)
				fputs(",", pFile);
			GArffAttribute* pAttr = GetAttribute(j);
			if(pAttr->IsContinuous())
			{
				GBits::DoubleToString(szTmp, pVector[j]);
				fputs(szTmp, pFile);
			}
			else
			{
				int nVal = (int)pVector[j];
				if(nVal < 0)
				{
					// Unknown value: mirror ParseDataRow, which maps "?" to -1
					fputs("?", pFile);
				}
				else
				{
					szName = pAttr->GetValue(nVal);
					if(!szName)
					{
						// Synthesize from the value index (matches the header)
						sprintf(szTmp, "v%d", nVal);
						szName = szTmp;
					}
					fputs(szName, pFile);
				}
			}
		}
		fputs("\n", pFile);
	}
}

// Loads szFilename into memory and parses it with ParseArffFile. On success
// *ppOutRelation and *ppOutData receive newly allocated objects owned by the
// caller.
/*static*/ void GArffRelation::LoadArffFile(GArffRelation** ppOutRelation, GArffData** ppOutData, const char* szFilename)
{
	// Load the ARFF file
	int nLen;
	char* szFile = GFile::LoadFileToBuffer(szFilename, &nLen);
	if(!szFile)
		ThrowError("Failed to load the file: %s\n", szFilename);
	ArrayHolder<char*> hFile(szFile);
	ParseArffFile(ppOutRelation, ppOutData, szFile, nLen);
}

// Advances nPos past every character <= ' ' (spaces, tabs, CR, LF and other
// control bytes), incrementing nLine for each '\n'. Never reads at or beyond
// nLen.
// NOTE(review): szFile is plain char; where char is signed, bytes >= 0x80
// (e.g. UTF-8) compare as negative and are treated as whitespace here.
static void ArffSkipWhitespace(const char* szFile, int nLen, int& nPos, int& nLine)
{
	while(nPos < nLen && szFile[nPos] <= ' ')
	{
		if(szFile[nPos] == '\n')
			nLine++;
		nPos++;
	}
}

// Advances nPos from nPos + 1 to the next '\n' (left unconsumed so the
// whitespace skipper counts the line) or to nLen. The bound is tested before
// the character is read.
static void ArffSkipRestOfLine(const char* szFile, int nLen, int& nPos)
{
	for(nPos++; nPos < nLen && szFile[nPos] != '\n'; nPos++)
	{
	}
}

// Parses ARFF text (szFile, nLen bytes; need not be null-terminated) into a new
// relation and data set. The sections are @RELATION <name>, then any number of
// @ATTRIBUTE lines, then @DATA followed by one row per line. Lines starting
// with '%' are skipped. Attributes whose GetValueCount() is negative ("comment
// attributes") are parsed, skipped in each row, and removed from the relation
// before returning. Reports syntax errors (with line numbers) via ThrowError.
/*static*/ void GArffRelation::ParseArffFile(GArffRelation** ppOutRelation, GArffData** ppOutData, const char* szFile, int nLen)
{
	int nPos = 0;
	int nLine = 1;
	GArffRelation* pRelation = new GArffRelation();
	Holder<GArffRelation*> hRelation(pRelation);

	// Parse the relation name
	while(true)
	{
		ArffSkipWhitespace(szFile, nLen, nPos, nLine);
		if(nPos >= nLen)
			ThrowError("Expected @RELATION at line %d", nLine);

		// Check for comments
		if(szFile[nPos] == '%')
		{
			ArffSkipRestOfLine(szFile, nLen, nPos);
			continue;
		}

		// Parse Relation
		if(nLen - nPos < 9 || strnicmp(&szFile[nPos], "@RELATION", 9) != 0)
			ThrowError("Expected @RELATION at line %d", nLine);
		nPos += 9;

		// Skip whitespace between the keyword and the name
		ArffSkipWhitespace(szFile, nLen, nPos, nLine);
		if(nPos >= nLen)
			ThrowError("Expected relation name at line %d", nLine);

		// Parse Name (runs up to the next whitespace character)
		int nNameStart = nPos;
		while(nPos < nLen && szFile[nPos] > ' ')
			nPos++;
		int nNameLen = nPos - nNameStart;
		pRelation->m_szName = new char[nNameLen + 1];
		memcpy(pRelation->m_szName, &szFile[nNameStart], nNameLen);
		pRelation->m_szName[nNameLen] = '\0';
		break;
	}

	// Parse the attribute section
	int nCommentAttributes = 0;
	while(true)
	{
		ArffSkipWhitespace(szFile, nLen, nPos, nLine);
		if(nPos >= nLen)
			ThrowError("Expected @ATTRIBUTE or @DATA at line %d", nLine);

		// Check for comments
		if(szFile[nPos] == '%')
		{
			ArffSkipRestOfLine(szFile, nLen, nPos);
			continue;
		}

		// Check for @DATA
		if(nLen - nPos < 5) // 5 = strlen("@DATA")
			ThrowError("Expected @DATA at line %d", nLine);
		if(strnicmp(&szFile[nPos], "@DATA", 5) == 0)
		{
			nPos += 5;
			break;
		}

		// Parse @ATTRIBUTE
		if(nLen - nPos < 10) // 10 = strlen("@ATTRIBUTE")
			ThrowError("Expected @ATTRIBUTE at line %d", nLine);
		if(strnicmp(&szFile[nPos], "@ATTRIBUTE", 10) != 0)
			ThrowError("Expected @ATTRIBUTE or @DATA at line %d", nLine);
		nPos += 10;
		GArffAttribute* pAttr = GArffAttribute::Parse(&szFile[nPos], nLen - nPos, nLine);
		if(pAttr->GetValueCount() < 0)
			nCommentAttributes++;
		pRelation->m_pAttributes->AddPointer(pAttr);

		// Move to next line
		ArffSkipRestOfLine(szFile, nLen, nPos);
	}

	// Parse the data section
	Holder<GArffData*> hData(new GArffData(256));
	GArffData* pData = hData.Get();
	while(true)
	{
		ArffSkipWhitespace(szFile, nLen, nPos, nLine);
		if(nPos >= nLen)
			break;

		// Check for comments
		if(szFile[nPos] == '%')
		{
			ArffSkipRestOfLine(szFile, nLen, nPos);
			continue;
		}

		// Parse the data line
		double* pRow = pRelation->ParseDataRow(&szFile[nPos], nLen - nPos, nLine, nCommentAttributes);
		GAssert(pRow, "expected data");
		pData->AddVector(pRow);

		// Move to next line
		ArffSkipRestOfLine(szFile, nLen, nPos);
	}

	// Throw out comment attributes (rows already omit their columns)
	if(nCommentAttributes > 0)
	{
		int i;
		for(i = pRelation->GetAttributeCount() - 1; i >= 0; i--)
		{
			GArffAttribute* pAttr = pRelation->GetAttribute(i);
			if(pAttr->GetValueCount() < 0) // if it's a comment attribute
			{
				delete(pAttr);
				pRelation->m_pAttributes->DeleteCell(i);
			}
		}
	}

	*ppOutRelation = hRelation.Drop();
	*ppOutData = hData.Drop();
}

// Parses one data row starting at szFile (nLen bytes remain) and returns a
// new[]-allocated array with one double per non-comment attribute. "?" becomes
// -1; continuous values go through atof; nominal values become the index
// returned by FindEnumeratedValue. Values are separated by ',' or by
// whitespace outside quotes. A row with too few values is reported via
// ThrowError ("Expected more data") rather than continuing onto the next line.
// nLine is used only for error messages.
double* GArffRelation::ParseDataRow(const char* szFile, int nLen, int nLine, int nCommentAttributes)
{
	char szBuf[512];
	int nAttributeCount = GetAttributeCount();
	// Array allocation, so it must be owned by ArrayHolder (delete[]), not Holder
	double* pData = new double[nAttributeCount - nCommentAttributes];
	ArrayHolder<double*> hData(pData);
	int col = 0;
	int n;
	for(n = 0; n < nAttributeCount; n++)
	{
		// Eat whitespace, but never cross into the next line
		while(nLen > 0 && *szFile <= ' ')
		{
			if(*szFile == '\n')
				ThrowError("Expected more data at line %d", nLine);
			szFile++;
			nLen--;
		}
		if(nLen < 1)
			ThrowError("Expected more data at line %d", nLine);

		// Find the end of the next value
		GArffAttribute* pAttr = GetAttribute(n);
		int nPos;
		bool bQuoting = false;
		for(nPos = 0; nPos < nLen; nPos++)
		{
			if(szFile[nPos] == ',')
				break;
			if(szFile[nPos] == '\n')
				break;
			if(szFile[nPos] == '\'' || szFile[nPos] == '"')
				bQuoting = !bQuoting;
			// Whitespace followed by a new token (outside quotes) ends the value
			if(nPos > 0 && !bQuoting && szFile[nPos] > ' ' && szFile[nPos - 1] <= ' ')
			{
				nPos--;
				break;
			}
		}

		if(pAttr->GetValueCount() >= 0) // if it's not a comment attribute
		{
			// Trim trailing whitespace
			int nEnd;
			for(nEnd = nPos; nEnd > 0 && szFile[nEnd - 1] <= ' '; nEnd--)
			{
			}
			// Guard the fixed-size buffer (room for the terminator)
			if(nEnd >= (int)sizeof(szBuf))
				ThrowError("Value too long at line %d attribute %d", nLine, n + 1);
			memcpy(szBuf, szFile, nEnd);
			szBuf[nEnd] = '\0';
			if(strcmp(szBuf, "?") == 0)
				pData[col++] = -1;
			else if(pAttr->IsContinuous())
			{
				// Parse a continuous value
				if(szBuf[0] == '.' || szBuf[0] == '-' || (szBuf[0] >= '0' && szBuf[0] <= '9'))
					pData[col++] = atof(szBuf);
				else
					ThrowError("Expected a continuous value at line %d attribute %d", nLine, n + 1);
			}
			else
			{
				// Parse an enumerated value
				int nVal = pAttr->FindEnumeratedValue(szBuf);
				if(nVal < 0)
					ThrowError("Unrecognized enumeration value at line %d attribute %d", nLine, n + 1);
				pData[col++] = nVal;
			}
		}

		// Advance past the value and its delimiter. A '\n' is left in place so
		// a short row is reported by the whitespace loop above instead of
		// silently reading values from the following line.
		if(nPos < nLen && szFile[nPos] != '\n')
			nPos++;
		szFile += nPos;
		nLen -= nPos;
	}
	GAssert(col + nCommentAttributes == nAttributeCount, "something got off count");
	return hData.Drop();
}

// Returns the number of attributes in the relation.
int GArffRelation::GetAttributeCount()
{
	return m_pAttributes->GetSize();
}

// Returns attribute n (still owned by the relation).
GArffAttribute* GArffRelation::GetAttribute(int n)
{
	return (GArffAttribute*)m_pAttributes->GetPointer(n);
}

// Rebuilds the input/output index tables: attributes for which IsInput() is
// true are inputs, all others are outputs, each listed in attribute order.
void GArffRelation::CountInputs()
{
	m_nInputCount = 0;
	m_nOutputCount = 0;
	int n;
	int nCount = GetAttributeCount();
	GArffAttribute* pAttr;
	for(n = 0; n < nCount; n++)
	{
		pAttr = GetAttribute(n);
		if(pAttr->IsInput())
			m_nInputCount++;
		else
			m_nOutputCount++;
	}
	delete[] m_pInputIndexes;
	delete[] m_pOutputIndexes;
	m_pInputIndexes = new int[m_nInputCount];
	m_pOutputIndexes = new int[m_nOutputCount];
	int nIn = 0;
	int nOut = 0;
	for(n = 0; n < nCount; n++)
	{
		pAttr = GetAttribute(n);
		if(pAttr->IsInput())
			m_pInputIndexes[nIn++] = n;
		else
			m_pOutputIndexes[nOut++] = n;
	}
}

// Returns the number of input attributes, building the index tables if needed.
int GArffRelation::GetInputCount()
{
	if(m_nInputCount < 0)
		CountInputs();
	return m_nInputCount;
}

// Returns the number of output attributes, building the index tables if needed.
int GArffRelation::GetOutputCount()
{
	if(m_nOutputCount < 0)
		CountInputs();
	return m_nOutputCount;
}

// Returns the attribute index of the n-th input attribute.
int GArffRelation::GetInputIndex(int n)
{
	if(!m_pInputIndexes)
		CountInputs();
	GAssert(n >= 0 && n < m_nInputCount, "out of range");
	return m_pInputIndexes[n];
}

// Returns the attribute index of the n-th output attribute.
int GArffRelation::GetOutputIndex(int n)
{
	if(!m_pOutputIndexes)
		CountInputs();
	GAssert(n >= 0 && n < m_nOutputCount, "out of range");
	return m_pOutputIndexes[n];
}

// Sums, over all output attributes, the variance (continuous outputs) or the
// entropy as computed by GArffData::MeasureEntropy (nominal outputs) of pData.
double GArffRelation::MeasureTotalOutputInfo(GArffData* pData)
{
	double dInfo = 0;
	int nOutputs = GetOutputCount();
	int n, nIndex;
	GArffAttribute* pAttr;
	for(n = 0; n < nOutputs; n++)
	{
		nIndex = GetOutputIndex(n);
		pAttr = GetAttribute(nIndex);
		if(pAttr->IsContinuous())
			dInfo += pData->ComputeVariance(pData->ComputeMean(nIndex), nIndex);
		else
			dInfo += pData->MeasureEntropy(this, nIndex);
	}
	return dInfo;
}

// Squared distance over the input attributes: (difference)^2 for continuous
// inputs plus 1 for each nominal input whose values differ.
double GArffRelation::ComputeInputDistanceSquared(double* pRow1, double* pRow2)
{
	double dSum = 0;
	double d;
	int n, nIndex;
	int nInputs = GetInputCount();
	for(n = 0; n < nInputs; n++)
	{
		nIndex = GetInputIndex(n);
		if(GetAttribute(nIndex)->IsContinuous())
		{
			d = pRow2[nIndex] - pRow1[nIndex];
			dSum += (d * d);
		}
		else
		{
			if(pRow2[nIndex] != pRow1[nIndex])
				dSum += 1;
		}
	}
	return dSum;
}

// Same measure as ComputeInputDistanceSquared, over the output attributes.
double GArffRelation::ComputeOutputDistanceSquared(double* pRow1, double* pRow2)
{
	double dSum = 0;
	double d;
	int n, nIndex;
	int nOutputs = GetOutputCount();
	for(n = 0; n < nOutputs; n++)
	{
		nIndex = GetOutputIndex(n);
		if(GetAttribute(nIndex)->IsContinuous())
		{
			d = pRow2[nIndex] - pRow1[nIndex];
			dSum += (d * d);
		}
		else
		{
			if(pRow2[nIndex] != pRow1[nIndex])
				dSum += 1;
		}
	}
	return dSum;
}

// Like ComputeInputDistanceSquared, but each input n is weighted by
// pInputScales[n] (indexed by input number, not attribute index).
// NOTE(review): continuous inputs are scaled before squaring (contributing
// scale^2 * diff^2) while a nominal mismatch adds the scale unsquared —
// confirm this asymmetry is intended.
double GArffRelation::ComputeScaledInputDistanceSquared(double* pRow1, double* pRow2, double* pInputScales)
{
	double dSum = 0;
	double d;
	int n, nIndex;
	int nInputs = GetInputCount();
	for(n = 0; n < nInputs; n++)
	{
		nIndex = GetInputIndex(n);
		if(GetAttribute(nIndex)->IsContinuous())
		{
			d = pRow2[nIndex] * pInputScales[n] - pRow1[nIndex] * pInputScales[n];
			dSum += (d * d);
		}
		else
		{
			if(pRow2[nIndex] != pRow1[nIndex])
				dSum += pInputScales[n];
		}
	}
	return dSum;
}

// Returns how many doubles InputsToVectorMode writes: one per continuous input
// or nominal input with at least nCap values, and one per value (one-hot) for
// other nominal inputs.
int GArffRelation::CountVectorModeInputs(int nCap)
{
	int nVectorModeInputs = 0;
	int nInputs = GetInputCount(); // lazily count inputs
	int i;
	for(i = 0; i < nInputs; i++)
	{
		GArffAttribute* pAttr = GetAttribute(GetInputIndex(i));
		if(pAttr->IsContinuous() || pAttr->GetValueCount() >= nCap)
			nVectorModeInputs++;
		else
			nVectorModeInputs += pAttr->GetValueCount();
	}
	return nVectorModeInputs;
}

// Output-attribute counterpart of CountVectorModeInputs.
int GArffRelation::CountVectorModeOutputs(int nCap)
{
	int nVectorModeOutputs = 0;
	int nOutputs = GetOutputCount(); // lazily count outputs
	int i;
	for(i = 0; i < nOutputs; i++)
	{
		GArffAttribute* pAttr = GetAttribute(GetOutputIndex(i));
		if(pAttr->IsContinuous() || pAttr->GetValueCount() >= nCap)
			nVectorModeOutputs++;
		else
			nVectorModeOutputs += pAttr->GetValueCount();
	}
	return nVectorModeOutputs;
}

// Writes the input values of row pIn to pOut. Continuous inputs and nominal
// inputs with at least nCap values are copied as-is; other nominal inputs are
// expanded one-hot (all zeros when the value is -1, i.e. unknown). pOut must
// hold CountVectorModeInputs(nCap) doubles.
void GArffRelation::InputsToVectorMode(double* pIn, double* pOut, int nCap)
{
	int nPos = 0;
	int nInputs = GetInputCount(); // lazily count inputs
	int i, j, k, index;
	for(i = 0; i < nInputs; i++)
	{
		index = GetInputIndex(i);
		GArffAttribute* pAttr = GetAttribute(index);
		if(pAttr->IsContinuous() || pAttr->GetValueCount() >= nCap)
			pOut[nPos++] = pIn[index];
		else
		{
			k = nPos;
			for(j = 0; j < pAttr->GetValueCount(); j++)
				pOut[nPos++] = 0;
			if(pIn[index] != -1) // -1 = unknown value. Set to all zeros.
			{
				GAssert(pIn[index] >= 0 && pIn[index] < pAttr->GetValueCount(), "out of range");
				pOut[k + (int)pIn[index]] = 1;
			}
		}
	}
}

// Output-attribute counterpart of InputsToVectorMode.
void GArffRelation::OutputsToVectorMode(double* pIn, double* pOut, int nCap)
{
	int nPos = 0;
	GetOutputCount(); // lazily count outputs
	int i, j, k, index;
	for(i = 0; i < m_nOutputCount; i++)
	{
		index = GetOutputIndex(i);
		GArffAttribute* pAttr = GetAttribute(index);
		if(pAttr->IsContinuous() || pAttr->GetValueCount() >= nCap)
			pOut[nPos++] = pIn[index];
		else
		{
			k = nPos;
			for(j = 0; j < pAttr->GetValueCount(); j++)
				pOut[nPos++] = 0;
			if(pIn[index] != -1) // -1 = unknown value. Set to all zeros.
			{

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?