garff.cpp
来自「一个由Mike Gashler完成的机器学习方面的includes neural」· C++ 代码 · 共 1,909 行 · 第 1/3 页
CPP
1,909 行
/*
	Copyright (C) 2006, Mike Gashler

	This library is free software; you can redistribute it and/or
	modify it under the terms of the GNU Lesser General Public
	License as published by the Free Software Foundation; either
	version 2.1 of the License, or (at your option) any later version.

	see http://www.gnu.org/copyleft/lesser.html
*/

#include "GArff.h"
#include "../GClasses/GArray.h"
#include "../GClasses/GMacros.h"
#include "../GClasses/GMath.h"
#include "../GClasses/GVector.h"
#include "../GClasses/GFile.h"
#include <math.h>
#include "GBits.h"
#include "GMatrix.h"

// Advances nPos past every character <= ' ' (the parser's definition of
// whitespace), incrementing nLine once for each '\n' that is skipped.
// The bounds check comes first so szFile[nLen] is never read.
static void ArffSkipWhitespace(const char* szFile, int nLen, int& nPos, int& nLine)
{
	while(nPos < nLen && szFile[nPos] <= ' ')
	{
		if(szFile[nPos] == '\n')
			nLine++;
		nPos++;
	}
}

// Steps over the current character and then advances nPos until it rests on
// the next '\n' or reaches nLen. The '\n' itself is not consumed, so the
// whitespace skipper still sees it and keeps the line count right.
static void ArffSkipRestOfLine(const char* szFile, int nLen, int& nPos)
{
	for(nPos++; nPos < nLen && szFile[nPos] != '\n'; nPos++)
	{
	}
}

// Creates a relation with no name and an empty attribute list. The input and
// output counts start at -1 and the index tables at NULL, which the lazy
// getters below treat as "not computed yet".
GArffRelation::GArffRelation()
{
	m_szName = NULL;
	m_pAttributes = new GPointerArray(32);
	m_nInputCount = -1;
	m_pInputIndexes = NULL;
	m_nOutputCount = -1;
	m_pOutputIndexes = NULL;
}

// Deletes every attribute held in the attribute list, then the list itself,
// the name and both cached index tables.
GArffRelation::~GArffRelation()
{
	int n;
	int nCount;
	nCount = m_pAttributes->GetSize();
	for(n = 0; n < nCount; n++)
		delete((GArffAttribute*)m_pAttributes->GetPointer(n));
	delete(m_pAttributes);
	delete[] m_szName;
	delete[] m_pInputIndexes;
	delete[] m_pOutputIndexes;
}

// Appends pAttr to the attribute list (the destructor later deletes it) and
// discards the cached input/output counts and index tables so they are
// recomputed on next use.
void GArffRelation::AddAttribute(GArffAttribute* pAttr)
{
	m_pAttributes->AddPointer(pAttr);
	m_nInputCount = -1;
	m_nOutputCount = -1;
	delete[] m_pInputIndexes;
	m_pInputIndexes = NULL;
	delete[] m_pOutputIndexes;
	m_pOutputIndexes = NULL;
}

// Returns the number of attributes for which IsContinuous() is true.
int GArffRelation::CountContinuousAttributes()
{
	int n;
	int nAttributes = GetAttributeCount();
	int nCount = 0;
	for(n = 0; n < nAttributes; n++)
	{
		GArffAttribute* pAttr = GetAttribute(n);
		if(pAttr->IsContinuous())
			nCount++;
	}
	return nCount;
}

// Writes this relation and the rows of pData to szFilename in ARFF form:
// an @RELATION line, one @ATTRIBUTE line per attribute, then @DATA followed
// by one comma-separated line per row.
// Attributes or values without a name are written as "a<index>" / "v<index>".
// Throws (via ThrowError) if the file cannot be opened.
void GArffRelation::SaveArffFile(GArffData* pData, const char* szFilename)
{
	// Open the file for writing
	FILE* pFile = fopen(szFilename, "w");
	FileHolder hFile(pFile);
	if(!pFile)
		ThrowError("Failed to open file: %s", szFilename);

	// Write the relation title
	fputs("@RELATION ", pFile);
	const char* szName = GetName();
	if(!szName)
		szName = "Untitled";
	fputs(szName, pFile);
	fputs("\n\n", pFile);

	// Write the attributes
	char szTmp[64];
	int i, j;
	for(i = 0; i < GetAttributeCount(); i++)
	{
		GArffAttribute* pAttr = GetAttribute(i);
		fputs("@ATTRIBUTE ", pFile);
		szName = pAttr->GetName();
		if(!szName)
		{
			// Write the number after the prefix; writing it at szTmp would
			// overwrite the "a" that was just copied in.
			strcpy(szTmp, "a");
			itoa(i, szTmp + 1, 10);
			szName = szTmp;
		}
		fputs(szName, pFile);
		fputs("\t", pFile);
		if(pAttr->IsContinuous())
			fputs("CONTINUOUS", pFile);
		else
		{
			fputs("{", pFile);
			for(j = 0; j < pAttr->GetValueCount(); j++)
			{
				// Separate values with commas, as the data rows below do,
				// so the list has no trailing comma.
				if(j > 0)
					fputs(",", pFile);
				szName = pAttr->GetValue(j);
				if(!szName)
				{
					strcpy(szTmp, "v");
					itoa(j, szTmp + 1, 10);
					szName = szTmp;
				}
				fputs(szName, pFile);
			}
			fputs("}", pFile);
		}
		fputs("\n", pFile);
	}

	// Write the data
	fputs("\n@DATA\n", pFile);
	for(i = 0; i < pData->GetSize(); i++)
	{
		double* pVector = pData->GetVector(i);
		for(j = 0; j < GetAttributeCount(); j++)
		{
			if(j > 0)
				fputs(",", pFile);
			GArffAttribute* pAttr = GetAttribute(j);
			if(pAttr->IsContinuous())
			{
				// NOTE: ParseDataRow stores "?" as -1 for continuous
				// attributes too; that cannot be told apart from a real
				// -1 here, so the number is written as-is.
				GBits::DoubleToString(szTmp, pVector[j]);
				fputs(szTmp, pFile);
			}
			else
			{
				int nVal = (int)pVector[j];
				if(nVal < 0)
				{
					// ParseDataRow stores "?" as -1, so write it back as
					// "?" instead of looking up value index -1.
					fputs("?", pFile);
				}
				else
				{
					szName = pAttr->GetValue(nVal);
					if(!szName)
					{
						// Use the value index (not the attribute index)
						// so the name matches the one in the header.
						strcpy(szTmp, "v");
						itoa(nVal, szTmp + 1, 10);
						szName = szTmp;
					}
					fputs(szName, pFile);
				}
			}
		}
		fputs("\n", pFile);
	}
}

// Reads the whole file szFilename into memory and hands it to ParseArffFile,
// which fills *ppOutRelation and *ppOutData.
// Throws (via ThrowError) if the file cannot be loaded.
/*static*/ void GArffRelation::LoadArffFile(GArffRelation** ppOutRelation, GArffData** ppOutData, const char* szFilename)
{
	// Load the ARFF file
	int nLen;
	char* szFile = GFile::LoadFileToBuffer(szFilename, &nLen);
	if(!szFile)
		ThrowError("Failed to load the file: %s\n", szFilename);
	ArrayHolder<char*> hFile(szFile);
	ParseArffFile(ppOutRelation, ppOutData, szFile, nLen);
}

// Parses nLen bytes of ARFF text in szFile. The text is read in three
// sections: the @RELATION line, the @ATTRIBUTE lines up to @DATA, and the
// data rows. Lines starting with '%' are skipped as comments in each section.
// Attributes whose GetValueCount() is negative are treated as comment
// attributes: their columns are left out of the data rows and the attributes
// are removed from the relation before returning.
// On success the new relation and data set are stored in *ppOutRelation and
// *ppOutData (presumably owned by the caller from then on -- confirm against
// Holder::Drop). Throws (via ThrowError) on malformed input.
/*static*/ void GArffRelation::ParseArffFile(GArffRelation** ppOutRelation, GArffData** ppOutData, const char* szFile, int nLen)
{
	// Parse the relation name
	int nPos = 0;
	int nLine = 1;
	GArffRelation* pRelation = new GArffRelation();
	Holder<GArffRelation*> hRelation(pRelation);
	while(true)
	{
		// Skip Whitespace
		ArffSkipWhitespace(szFile, nLen, nPos, nLine);
		if(nPos >= nLen)
			ThrowError("Expected @RELATION at line %d", nLine);

		// Check for comments
		if(szFile[nPos] == '%')
		{
			ArffSkipRestOfLine(szFile, nLen, nPos);
			continue;
		}

		// Parse Relation
		if(nLen - nPos < 9 || strnicmp(&szFile[nPos], "@RELATION", 9) != 0) // 9 = strlen("@RELATION")
			ThrowError("Expected @RELATION at line %d", nLine);
		nPos += 9;

		// Skip Whitespace
		ArffSkipWhitespace(szFile, nLen, nPos, nLine);
		if(nPos >= nLen)
			ThrowError("Expected relation name at line %d", nLine);

		// Parse Name (everything up to the next whitespace character)
		int nNameStart = nPos;
		while(nPos < nLen && szFile[nPos] > ' ')
			nPos++;
		pRelation->m_szName = new char[nPos - nNameStart + 1];
		memcpy(pRelation->m_szName, &szFile[nNameStart], nPos - nNameStart);
		pRelation->m_szName[nPos - nNameStart] = '\0';
		break;
	}

	// Parse the attribute section
	int nCommentAttributes = 0;
	while(true)
	{
		// Skip Whitespace
		ArffSkipWhitespace(szFile, nLen, nPos, nLine);
		if(nPos >= nLen)
			ThrowError("Expected @ATTRIBUTE or @DATA at line %d", nLine);

		// Check for comments
		if(szFile[nPos] == '%')
		{
			ArffSkipRestOfLine(szFile, nLen, nPos);
			continue;
		}

		// Check for @DATA
		if(nLen - nPos < 5) // 5 = strlen("@DATA")
			ThrowError("Expected @DATA at line %d", nLine);
		if(strnicmp(&szFile[nPos], "@DATA", 5) == 0)
		{
			nPos += 5;
			break;
		}

		// Parse @ATTRIBUTE
		if(nLen - nPos < 10) // 10 = strlen("@ATTRIBUTE")
			ThrowError("Expected @ATTRIBUTE at line %d", nLine);
		if(strnicmp(&szFile[nPos], "@ATTRIBUTE", 10) != 0)
			ThrowError("Expected @ATTRIBUTE or @DATA at line %d", nLine);
		nPos += 10;
		GArffAttribute* pAttr = GArffAttribute::Parse(&szFile[nPos], nLen - nPos, nLine);
		if(pAttr->GetValueCount() < 0)
			nCommentAttributes++;
		pRelation->m_pAttributes->AddPointer(pAttr);

		// Move to next line
		ArffSkipRestOfLine(szFile, nLen, nPos);
	}

	// Parse the data section
	Holder<GArffData*> hData(new GArffData(256));
	GArffData* pData = hData.Get();
	while(true)
	{
		// Skip Whitespace
		ArffSkipWhitespace(szFile, nLen, nPos, nLine);
		if(nPos >= nLen)
			break;

		// Check for comments
		if(szFile[nPos] == '%')
		{
			ArffSkipRestOfLine(szFile, nLen, nPos);
			continue;
		}

		// Parse the data line
		double* pRow = pRelation->ParseDataRow(&szFile[nPos], nLen - nPos, nLine, nCommentAttributes);
		GAssert(pRow, "expected data");
		pData->AddVector(pRow);

		// Move to next line
		ArffSkipRestOfLine(szFile, nLen, nPos);
	}

	// Throw out comment attributes (iterate backwards so that deleting a
	// cell does not shift the indexes still to be visited)
	if(nCommentAttributes > 0)
	{
		int i;
		for(i = pRelation->GetAttributeCount() - 1; i >= 0; i--)
		{
			GArffAttribute* pAttr = pRelation->GetAttribute(i);
			if(pAttr->GetValueCount() < 0) // if it's a comment attribute
			{
				delete(pAttr);
				pRelation->m_pAttributes->DeleteCell(i);
			}
		}
	}

	*ppOutRelation = hRelation.Drop();
	*ppOutData = hData.Drop();
}

// Parses one data row from szFile (at most nLen bytes) and returns a newly
// allocated array of (attribute count - nCommentAttributes) doubles.
// Values are separated by a comma, a newline, or whitespace outside quotes.
// "?" is stored as -1, continuous values are converted with atof, and
// enumerated values are stored as the index returned by FindEnumeratedValue.
// Columns belonging to comment attributes (GetValueCount() < 0) are skipped.
// nLine is used only for error messages. Throws (via ThrowError) when the
// row ends early, a value is not recognized, or a value does not fit in the
// internal buffer.
double* GArffRelation::ParseDataRow(const char* szFile, int nLen, int nLine, int nCommentAttributes)
{
	char szBuf[512];
	int nAttributeCount = GetAttributeCount();
	Holder<double*> hData(new double[nAttributeCount - nCommentAttributes]);
	double* pData = hData.Get();
	GArffAttribute* pAttr;
	int col = 0;
	int n;
	for(n = 0; n < nAttributeCount; n++)
	{
		// Eat whitespace
		while(nLen > 0 && *szFile <= ' ')
		{
			if(*szFile == '\n')
				ThrowError("Expected more data at line %d", nLine);
			szFile++;
			nLen--;
		}
		if(nLen < 1)
			ThrowError("Expected more data at line %d", nLine);

		// Parse the next value
		pAttr = GetAttribute(n);
		int nPos;
		bool bQuoting = false;
		for(nPos = 0; nPos < nLen; nPos++)
		{
			if(szFile[nPos] == ',')
				break;
			if(szFile[nPos] == '\n')
				break;
			if(szFile[nPos] == '\'' || szFile[nPos] == '"')
				bQuoting = !bQuoting;
			// A non-blank character right after a blank (outside quotes)
			// starts the next value: back up onto the blank and stop.
			if(nPos > 0 && !bQuoting && szFile[nPos] > ' ' && szFile[nPos - 1] <= ' ')
			{
				nPos--;
				break;
			}
		}
		if(pAttr->GetValueCount() >= 0) // if it's not a comment attribute
		{
			// Trim trailing whitespace from the value
			int nEnd;
			for(nEnd = nPos; nEnd > 0 && szFile[nEnd - 1] <= ' '; nEnd--)
			{
			}

			// Refuse values that would overflow szBuf (including the
			// terminating '\0') rather than writing past its end.
			if(nEnd >= (int)sizeof(szBuf))
				ThrowError("Value too long at line %d attribute %d", nLine, n + 1);
			memcpy(szBuf, szFile, nEnd);
			szBuf[nEnd] = '\0';
			if(strcmp(szBuf, "?") == 0)
				pData[col++] = -1;
			else if(pAttr->IsContinuous())
			{
				// Parse a continuous value
				if(szBuf[0] == '.' || szBuf[0] == '-' || (szBuf[0] >= '0' && szBuf[0] <= '9'))
					pData[col++] = atof(szBuf);
				else
					ThrowError("Expected a continuous value at line %d attribute %d", nLine, n + 1);
			}
			else
			{
				// Parse an enumerated value
				int nVal = pAttr->FindEnumeratedValue(szBuf);
				if(nVal < 0)
					ThrowError("Unrecognized enumeration value at line %d attribute %d", nLine, n + 1);
				pData[col++] = nVal;
			}
		}

		// Advance past the attribute (and the separator that ended it)
		if(nPos < nLen)
			nPos++;
		szFile += nPos;
		nLen -= nPos;
	}
	GAssert(col + nCommentAttributes == nAttributeCount, "something got off count");
	return hData.Drop();
}

// Returns the number of attributes currently in the attribute list.
int GArffRelation::GetAttributeCount()
{
	return m_pAttributes->GetSize();
}

// Returns the n-th attribute in the attribute list.
GArffAttribute* GArffRelation::GetAttribute(int n)
{
	return (GArffAttribute*)m_pAttributes->GetPointer(n);
}

// Recomputes the cached input/output counts and rebuilds the two index
// tables. An attribute is an input if IsInput() is true, otherwise it is
// counted as an output.
void GArffRelation::CountInputs()
{
	m_nInputCount = 0;
	m_nOutputCount = 0;
	int n;
	int nCount = GetAttributeCount();
	GArffAttribute* pAttr;
	for(n = 0; n < nCount; n++)
	{
		pAttr = GetAttribute(n);
		if(pAttr->IsInput())
			m_nInputCount++;
		else
			m_nOutputCount++;
	}

	// Null the pointers after freeing (as AddAttribute does) so that the
	// destructor cannot double-delete if one of the allocations below throws.
	delete[] m_pInputIndexes;
	m_pInputIndexes = NULL;
	delete[] m_pOutputIndexes;
	m_pOutputIndexes = NULL;
	m_pInputIndexes = new int[m_nInputCount];
	m_pOutputIndexes = new int[m_nOutputCount];
	int nIn = 0;
	int nOut = 0;
	for(n = 0; n < nCount; n++)
	{
		pAttr = GetAttribute(n);
		if(pAttr->IsInput())
			m_pInputIndexes[nIn++] = n;
		else
			m_pOutputIndexes[nOut++] = n;
	}
}

// Returns the number of input attributes, counting them first if the cached
// count is not valid.
int GArffRelation::GetInputCount()
{
	if(m_nInputCount < 0)
		CountInputs();
	return m_nInputCount;
}

// Returns the number of output attributes, counting them first if the cached
// count is not valid.
int GArffRelation::GetOutputCount()
{
	if(m_nOutputCount < 0)
		CountInputs();
	return m_nOutputCount;
}

// Returns the attribute index of the n-th input attribute, building the
// index table first if it does not exist.
int GArffRelation::GetInputIndex(int n)
{
	if(!m_pInputIndexes)
		CountInputs();
	GAssert(n >= 0 && n < m_nInputCount, "out of range");
	return m_pInputIndexes[n];
}

// Returns the attribute index of the n-th output attribute, building the
// index table first if it does not exist.
int GArffRelation::GetOutputIndex(int n)
{
	if(!m_pOutputIndexes)
		CountInputs();
	GAssert(n >= 0 && n < m_nOutputCount, "out of range");
	return m_pOutputIndexes[n];
}

// Sums a per-output measure over all output attributes of pData: the result
// of ComputeVariance (about the column mean) for continuous outputs, and the
// result of MeasureEntropy for the others.
double GArffRelation::MeasureTotalOutputInfo(GArffData* pData)
{
	double dInfo = 0;
	int nOutputs = GetOutputCount();
	int n, nIndex;
	GArffAttribute* pAttr;
	for(n = 0; n < nOutputs; n++)
	{
		nIndex = GetOutputIndex(n);
		pAttr = GetAttribute(nIndex);
		if(pAttr->IsContinuous())
			dInfo += pData->ComputeVariance(pData->ComputeMean(nIndex), nIndex);
		else
			dInfo += pData->MeasureEntropy(this, nIndex);
	}
	return dInfo;
}

// Returns the squared distance between two rows over the input attributes:
// the squared difference for continuous attributes, and 1 for each
// non-continuous attribute whose values differ.
double GArffRelation::ComputeInputDistanceSquared(double* pRow1, double* pRow2)
{
	double dSum = 0;
	double d;
	int n, nIndex;
	int nInputs = GetInputCount();
	for(n = 0; n < nInputs; n++)
	{
		nIndex = GetInputIndex(n);
		if(GetAttribute(nIndex)->IsContinuous())
		{
			d = pRow2[nIndex] - pRow1[nIndex];
			dSum += (d * d);
		}
		else
		{
			if(pRow2[nIndex] != pRow1[nIndex])
				dSum += 1;
		}
	}
	return dSum;
}

// Same as ComputeInputDistanceSquared, but over the output attributes.
double GArffRelation::ComputeOutputDistanceSquared(double* pRow1, double* pRow2)
{
	double dSum = 0;
	double d;
	int n, nIndex;
	int nOutputs = GetOutputCount();
	for(n = 0; n < nOutputs; n++)
	{
		nIndex = GetOutputIndex(n);
		if(GetAttribute(nIndex)->IsContinuous())
		{
			d = pRow2[nIndex] - pRow1[nIndex];
			dSum += (d * d);
		}
		else
		{
			if(pRow2[nIndex] != pRow1[nIndex])
				dSum += 1;
		}
	}
	return dSum;
}

// Squared distance over the input attributes with a per-input scale factor.
// pInputScales is indexed by input number (not attribute index). Continuous
// values are multiplied by the scale before differencing; a mismatch on a
// non-continuous attribute adds the scale itself.
// NOTE(review): the continuous term therefore grows with the square of the
// scale and the non-continuous term linearly -- confirm this is intended.
double GArffRelation::ComputeScaledInputDistanceSquared(double* pRow1, double* pRow2, double* pInputScales)
{
	double dSum = 0;
	double d;
	int n, nIndex;
	int nInputs = GetInputCount();
	for(n = 0; n < nInputs; n++)
	{
		nIndex = GetInputIndex(n);
		if(GetAttribute(nIndex)->IsContinuous())
		{
			d = pRow2[nIndex] * pInputScales[n] - pRow1[nIndex] * pInputScales[n];
			dSum += (d * d);
		}
		else
		{
			if(pRow2[nIndex] != pRow1[nIndex])
				dSum += pInputScales[n];
		}
	}
	return dSum;
}

// Returns the number of elements InputsToVectorMode writes: one for each
// input that is continuous or has at least nCap values, and one per value
// for every other input.
int GArffRelation::CountVectorModeInputs(int nCap)
{
	int nVectorModeInputs = 0;
	int nInputs = GetInputCount(); // lazily count inputs
	int i;
	for(i = 0; i < nInputs; i++)
	{
		GArffAttribute* pAttr = GetAttribute(GetInputIndex(i));
		if(pAttr->IsContinuous() || pAttr->GetValueCount() >= nCap)
			nVectorModeInputs++;
		else
			nVectorModeInputs += pAttr->GetValueCount();
	}
	return nVectorModeInputs;
}

// Same as CountVectorModeInputs, but over the output attributes.
int GArffRelation::CountVectorModeOutputs(int nCap)
{
	int nVectorModeOutputs = 0;
	GetOutputCount(); // lazily count outputs
	int i;
	for(i = 0; i < m_nOutputCount; i++)
	{
		GArffAttribute* pAttr = GetAttribute(GetOutputIndex(i));
		if(pAttr->IsContinuous() || pAttr->GetValueCount() >= nCap)
			nVectorModeOutputs++;
		else
			nVectorModeOutputs += pAttr->GetValueCount();
	}
	return nVectorModeOutputs;
}

// Expands the input attributes of the row pIn into pOut. An input that is
// continuous or has at least nCap values is copied as a single element.
// Any other input becomes GetValueCount() elements that are all 0 except for
// a 1 at the position of its value; an unknown value (-1) leaves them all 0.
// pOut must have room for CountVectorModeInputs(nCap) elements.
void GArffRelation::InputsToVectorMode(double* pIn, double* pOut, int nCap)
{
	int nPos = 0;
	int nInputs = GetInputCount(); // lazily count inputs
	int i, j, k, index;
	for(i = 0; i < nInputs; i++)
	{
		index = GetInputIndex(i);
		GArffAttribute* pAttr = GetAttribute(index);
		if(pAttr->IsContinuous() || pAttr->GetValueCount() >= nCap)
			pOut[nPos++] = pIn[index];
		else
		{
			k = nPos;
			for(j = 0; j < pAttr->GetValueCount(); j++)
				pOut[nPos++] = 0;
			if(pIn[index] != -1) // -1 = unknown value. Set to all zeros.
			{
				GAssert(pIn[index] >= 0 && pIn[index] < pAttr->GetValueCount(), "out of range");
				pOut[k + (int)pIn[index]] = 1;
			}
		}
	}
}

void GArffRelation::OutputsToVectorMode(double* pIn, double* pOut, int nCap)
{
	int nPos = 0;
	GetOutputCount(); // lazily count outputs
	int i, j, k, index;
	for(i = 0; i < m_nOutputCount; i++)
	{
		index = GetOutputIndex(i);
		GArffAttribute* pAttr = GetAttribute(index);
		if(pAttr->IsContinuous() || pAttr->GetValueCount() >= nCap)
			pOut[nPos++] = pIn[index];
		else
		{
			k = nPos;
			for(j = 0; j < pAttr->GetValueCount(); j++)
				pOut[nPos++] = 0;
			if(pIn[index] != -1) // -1 = unknown value. Set to all zeros.
			{
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?