gnaiveinstance.cpp

From "A machine-learning library by Mike Gashler (includes neural …)" · C++ code · 302 lines

CPP
302
字号
/*
	Copyright (C) 2006, Mike Gashler

	This library is free software; you can redistribute it and/or
	modify it under the terms of the GNU Lesser General Public
	License as published by the Free Software Foundation; either
	version 2.1 of the License, or (at your option) any later version.

	see http://www.gnu.org/copyleft/lesser.html
*/

#include "GNaiveInstance.h"
#include "GAVLTree.h"
#include "GArff.h"
#include <math.h>

class GNaiveInstanceNodeBase : public GAVLNode
{
protected:
	double m_dInput;

public:
	GNaiveInstanceNodeBase(double dInput)
		: GAVLNode()
	{
		m_dInput = dInput;
	}

	virtual ~GNaiveInstanceNodeBase()
	{
	}

	double GetInputValue()
	{
		return m_dInput;
	}

	virtual int Compare(GAVLNode* pThat)
	{
		GNaiveInstanceNodeBase* pOther = (GNaiveInstanceNodeBase*)pThat;
		if(m_dInput < pOther->m_dInput)
			return -1;
		else if(m_dInput > pOther->m_dInput)
			return 1;
		else
			return 0;
	}
};



// A stored training instance inside one attribute's tree.
//
// The output array is shared by every node created for the same training
// instance (one node per input attribute, see GNaiveInstance::AddInstance)
// and is reference counted in-band:
//   m_pOutputs[0]      number of nodes currently referring to the array
//   m_pOutputs[1 .. ]  the output values
// Each node increments the count on construction; the last node destroyed
// frees the array.
class GNaiveInstanceNode : public GNaiveInstanceNodeBase
{
protected:
	double* m_pOutputs;

public:
	// Shares ownership of pOutputVector, which must use the layout described
	// above (slot 0 is the reference count).
	GNaiveInstanceNode(double dInput, double* pOutputVector)
		: GNaiveInstanceNodeBase(dInput)
	{
		m_pOutputs = pOutputVector;
		m_pOutputs[0]++;
	}

	// Drops this node's reference; frees the shared array when it was the last.
	virtual ~GNaiveInstanceNode()
	{
		if(--m_pOutputs[0] <= 0)
			delete[] m_pOutputs;
	}

	// Returns output value i (0-based), skipping the reference count in slot 0.
	double GetOutput(int i)
	{
		return m_pOutputs[i + 1];
	}

private:
	// A copy would share m_pOutputs without incrementing the reference count,
	// so destroying both would decrement twice and delete[] the array twice.
	// Copying is therefore disabled (declared but never defined).
	GNaiveInstanceNode(const GNaiveInstanceNode&);
	GNaiveInstanceNode& operator=(const GNaiveInstanceNode&);
};

// -----------------------------------------------------------

// Index over one input attribute. It is an AVL tree of GNaiveInstanceNode
// ordered by that attribute's value, plus running per-output sums from
// which the global output variance can be computed.
class GNaiveInstanceAttr : public GAVLTree
{
protected:
	double* m_pSumOfValues;        // [nOutputs] running sum of each output (owns the allocation)
	double* m_pSumOfSquaredValues; // [nOutputs] running sum of squares; second half of m_pSumOfValues' allocation

public:
	GNaiveInstanceAttr(int nOutputs)
		: GAVLTree()
	{
		// One allocation holds both the sums and the sums of squares.
		m_pSumOfValues = new double[nOutputs * 2];
		m_pSumOfSquaredValues = &m_pSumOfValues[nOutputs];
		int i;
		for(i = 0; i < nOutputs; i++)
		{
			m_pSumOfValues[i] = 0;
			m_pSumOfSquaredValues[i] = 0;
		}
	}

	virtual ~GNaiveInstanceAttr()
	{
		delete[] m_pSumOfValues;
	}

	// Population variance of output nOutput over all instances in this tree.
	// Returns 0 when the tree is empty.
	double GetVariance(int nOutput)
	{
		int nInstances = GetSize();
		if(nInstances <= 0)
			return 0;

		// Variance = the mean of the squares minus the square of the mean, with
		// the terms rearranged to need a single division. The count is kept in
		// floating point: n * n in int overflows beyond 46340 instances.
		double dCount = (double)nInstances;
		double dVariance = (m_pSumOfSquaredValues[nOutput] * dCount - m_pSumOfValues[nOutput] * m_pSumOfValues[nOutput]) / (dCount * dCount);

		// Cancellation can make the result slightly negative; a variance cannot be.
		return dVariance < 0 ? 0 : dVariance;
	}

	// Adds one training instance whose value for this attribute is dInput.
	// pOutputs must use GNaiveInstanceNode's reference-counted layout: slot 0
	// is the reference count and slots 1..nOutputs hold the output values.
	// The new node takes a shared reference to pOutputs.
	void AddInstance(double dInput, int nOutputs, double* pOutputs)
	{
		int i;
		for(i = 0; i < nOutputs; i++)
		{
			// Skip slot 0 (the reference count), matching GNaiveInstanceNode::GetOutput.
			double d = pOutputs[i + 1];
			m_pSumOfValues[i] += d;
			m_pSumOfSquaredValues[i] += (d * d);
		}
		Insert(new GNaiveInstanceNode(dInput, pOutputs));
	}

private:
	// Owns a raw array; a copy would delete[] it twice. Copying is disabled
	// (declared but never defined).
	GNaiveInstanceAttr(const GNaiveInstanceAttr&);
	GNaiveInstanceAttr& operator=(const GNaiveInstanceAttr&);
};

// -----------------------------------------------------------

// Cap passed to every GArffRelation vector-mode routine used in this file
// (CountVectorModeInputs/Outputs, InputsToVectorMode, OutputsToVectorMode,
// VectorModeToOutputs).
// NOTE(review): presumably limits how many vector-mode values a single
// attribute may expand into — confirm against GArffRelation.
#define VECTOR_MODE_CAP 100

// Builds an untrained learner for pRelation. Predictions are formed from the
// nNeighbors training values nearest to the query along each input attribute
// independently (see EvalInput).
GNaiveInstance::GNaiveInstance(GArffRelation* pRelation, int nNeighbors)
 : GSupervisedLearner(pRelation)
{
	m_nNeighbors = nNeighbors;
	m_nVectorModeInputs = pRelation->CountVectorModeInputs(VECTOR_MODE_CAP);
	m_nVectorModeOutputs = pRelation->CountVectorModeOutputs(VECTOR_MODE_CAP);

	// Start with empty slots so that Reset() below has nothing stale to delete.
	m_pAttrs = new GNaiveInstanceAttr*[m_nVectorModeInputs];
	for(int nAttr = 0; nAttr < m_nVectorModeInputs; ++nAttr)
		m_pAttrs[nAttr] = NULL;

	// A single allocation carved into five consecutive regions
	// (O = number of outputs, I = number of inputs):
	//   [0, O)      value sums
	//   [O, 2O)     weight sums
	//   [2O, 3O)    per-dimension sum buffer
	//   [3O, 4O)    per-dimension sum-of-squares buffer
	//   [4O, 4O+I)  input buffer
	const int nOut = m_nVectorModeOutputs;
	double* pBlock = new double[4 * nOut + m_nVectorModeInputs];
	m_pValueSums = pBlock;
	m_pWeightSums = pBlock + nOut;
	m_pSumBuffer = pBlock + 2 * nOut;
	m_pSumOfSquaresBuffer = pBlock + 3 * nOut;
	m_pInputBuffer = pBlock + 4 * nOut;

	Reset();
}

// virtual
// Destroys the per-attribute trees and the accumulator/scratch block that
// was allocated in the constructor.
// NOTE(review): the shared output arrays are freed by the node destructors,
// presumably when GAVLTree's destructor deletes its nodes — confirm.
GNaiveInstance::~GNaiveInstance()
{
	for(int nAttr = 0; nAttr < m_nVectorModeInputs; ++nAttr)
		delete m_pAttrs[nAttr];
	delete[] m_pAttrs;

	// m_pValueSums heads the single block that also backs m_pWeightSums,
	// m_pSumBuffer, m_pSumOfSquaresBuffer and m_pInputBuffer.
	delete[] m_pValueSums;
}

// virtual
void GNaiveInstance::Reset()
{
	int i;
	for(i = 0; i < m_nVectorModeInputs; i++)
	{
		delete(m_pAttrs[i]);
		m_pAttrs[i] = new GNaiveInstanceAttr(m_nVectorModeOutputs);
	}
}

// Adds one training row to the model.
//
// The row's outputs are converted once into a shared, reference-counted
// array (slot 0 = reference count, slots 1.. = output values). Every input
// attribute's tree then receives a node referring to that array, and the
// last such node to be destroyed frees it.
void GNaiveInstance::AddInstance(double* pVector)
{
	// With no input attributes, no node would ever take a reference to the
	// output array, so it could never be freed. There is also nothing to
	// index, so skip the row entirely.
	if(m_nVectorModeInputs <= 0)
		return;

	double* pOutputs = new double[m_nVectorModeOutputs + 1];
	pOutputs[0] = 0; // reference count; each node constructor increments it
	m_pRelation->OutputsToVectorMode(pVector, &pOutputs[1], VECTOR_MODE_CAP);
	m_pRelation->InputsToVectorMode(pVector, m_pInputBuffer, VECTOR_MODE_CAP);
	int i;
	for(i = 0; i < m_nVectorModeInputs; i++)
		m_pAttrs[i]->AddInstance(m_pInputBuffer[i], m_nVectorModeOutputs, pOutputs);
}

// virtual
// Adds every row of pData to the model in shuffled order. Existing training
// data is kept; call Reset() first to start over.
// NOTE(review): Shuffle() presumably reorders pData in place — confirm.
void GNaiveInstance::Train(GArffData* pData)
{
	pData->Shuffle();
	const int nRows = pData->GetSize();
	for(int nRow = 0; nRow < nRows; ++nRow)
		AddInstance(pData->GetVector(nRow));
}

void GNaiveInstance::EvalInput(int nInputDim, double dInput)
{
	// Init the accumulators
	int j;
	for(j = 0; j < m_nVectorModeOutputs; j++)
	{
		m_pSumBuffer[j] = 0;
		m_pSumOfSquaresBuffer[j] = 0;
	}

	// Find the nodes on either side of dInput
	GNaiveInstanceAttr* pAttr = m_pAttrs[nInputDim];
	GNaiveInstanceNodeBase base(dInput);
	int nLeftIndex, nRightIndex;
	GNaiveInstanceNode* pLeft = NULL;
	GNaiveInstanceNode* pRight = (GNaiveInstanceNode*)pAttr->FindCloseNode(&base, &nRightIndex);
	if(pRight->GetInputValue() >= dInput)
	{
		nLeftIndex = nRightIndex - 1;
		if(nLeftIndex >= 0)
			pLeft = (GNaiveInstanceNode*)pAttr->GetNode(nLeftIndex);
	}
	else
	{
		pLeft = pRight;
		pRight = NULL;
		nLeftIndex = nRightIndex;
		nRightIndex = nLeftIndex + 1;
		if(nRightIndex < pAttr->GetSize() - 1)
			pRight = (GNaiveInstanceNode*)pAttr->GetNode(nRightIndex);
	}

	// Find the k-nearest neighbors
	GNaiveInstanceNode* pNode;
	int nNeighbors = 0;
	bool bAdvanceRight;
	double d;
	while(true)
	{
		// Pick the closer of the two nodes
		GAssert(!pRight || pRight->GetInputValue() >= dInput, "not on the right");
		GAssert(!pLeft || pLeft->GetInputValue() <= dInput, "not on the left");
		if(!pLeft || (pRight && pRight->GetInputValue() - dInput < dInput - pLeft->GetInputValue()))
		{
			pNode = pRight;
			bAdvanceRight = true;
		}
		else
		{
			pNode = pLeft;
			bAdvanceRight = false;
		}

		// Accumulate values
		for(j = 0; j < m_nVectorModeOutputs; j++)
		{
			d = pNode->GetOutput(j);
			m_pSumBuffer[j] += d;
			m_pSumOfSquaresBuffer[j] += (d * d);
		}

		// See if we're done
		if(++nNeighbors >= m_nNeighbors)
			break;
		if(nLeftIndex <= 0 && nRightIndex >= pAttr->GetSize() - 1)
			break;

		// Decide which way to advance
		if(nLeftIndex <= 0)
			bAdvanceRight = true;
		else if(nRightIndex >= pAttr->GetSize() - 1)
			bAdvanceRight = false;

		// Advance
		if(bAdvanceRight)
			pRight = (GNaiveInstanceNode*)pAttr->GetNode(++nRightIndex);
		else
			pLeft = (GNaiveInstanceNode*)pAttr->GetNode(--nLeftIndex);
	}

	// Compute the ratio of local variance to global variance
	for(j = 0; j < m_nVectorModeOutputs; j++)
	{
		// d = local variance
		d = (m_pSumOfSquaresBuffer[j] * nNeighbors - m_pSumBuffer[j] * m_pSumBuffer[j]) / (nNeighbors * nNeighbors);

		// d = confidence weight for these values
		d = 1.0 / (d + 1e-6);

		// Add to the sums
		m_pWeightSums[j] += d;
		m_pValueSums[j] += (d * m_pSumBuffer[j] / nNeighbors); // todo: it would be better to do linear interpolation here
	}
}

// virtual
// Predicts the outputs for pVector. Each input dimension casts a
// variance-weighted vote (EvalInput). Each output is the weighted average of
// those votes, which is then converted back via VectorModeToOutputs,
// presumably into pVector's output slots.
// An output that received no weight (no inputs, or no training data) is
// reported as 0 instead of the NaN a 0/0 division would produce.
void GNaiveInstance::Eval(double* pVector)
{
	m_pRelation->InputsToVectorMode(pVector, m_pInputBuffer, VECTOR_MODE_CAP);
	int i;
	for(i = 0; i < m_nVectorModeOutputs; i++)
	{
		m_pValueSums[i] = 0;
		m_pWeightSums[i] = 0;
	}
	for(i = 0; i < m_nVectorModeInputs; i++)
		EvalInput(i, m_pInputBuffer[i]);
	for(i = 0; i < m_nVectorModeOutputs; i++)
	{
		// With zero total weight, nothing was added to the value sum either,
		// so leaving it at 0 avoids a 0/0 NaN.
		if(m_pWeightSums[i] != 0)
			m_pValueSums[i] /= m_pWeightSums[i];
	}
	m_pRelation->VectorModeToOutputs(m_pValueSums, pVector, VECTOR_MODE_CAP);
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?