
📄 simpleneuron.cpp

📁 Neural network toolbox developed with VC
#include "../include/Exception.h"
#include "../include/SimpleNeuron.h"

#include <cstdio>
using namespace std;

namespace annie
{

SimpleNeuron::SimpleNeuron(int label, bool hasBias) : Neuron(label)
{
	_activationFunction = sigmoid;
	_dActivationFunction = dsigmoid;
	_hasBias = hasBias;
	_bias = _deltaBias = 0.0;
	_classHeirarchy.push_back(_SIMPLE_NEURON_STRING);
}

void
SimpleNeuron::connect(Neuron *from, real weight)
{
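	// disconnect(from) below is expected to drop any earlier link from this
	// source, so that a pair of neurons is never linked twice.
	Link *link = new Link(this,from,weight);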
	disconnect(from);
	_inputLinks.push_back(link);
	from->_outputLinks.push_back(link);
	invalidateOutputCache();
}

void
SimpleNeuron::connect(Neuron *from)
{
	connect(from,random());
}

void
SimpleNeuron::setActivationFunction(ActivationFunction f, ActivationFunction df)
{
	_activationFunction = f;
	_dActivationFunction = df;
	invalidateOutputCache();
	invalidateErrorCache();
}

void
SimpleNeuron::setDesiredOutput(real desired)
{
	if (getOutputCount()!=0)
	{
		string error(getClassName());
		error = error + "::setDesiredOutput() - Called for a non-output neuron";
		throw Exception(error);
	}
	//real output = getOutput();
	//real PI = 3.1415926535897932384626433832795;
	//if (output > 2*PI)
	//	output -= 2*PI;
	//real temp = (desired - output);
	//if (temp > PI)
	//	temp = 2*PI - temp;
	invalidateErrorCache();
	_errorCache = _dActivationFunction(getActivation())*(desired-getOutput());
	_recacheError();
}

void
SimpleNeuron::_recacheError()
{
	if (getOutputCount()==0) //ie an output neuron
		_errorCacheValid = true;
	if (_errorCacheValid)
		return;
	//Note that the work below is done only for a NON-OUTPUT neuron,
	//i.e. one whose output is taken as input by other neurons
	LINKS::iterator it;
	_errorCache = 0;
	for (it = _outputLinks.begin(); it!=_outputLinks.end(); it++)
	{
		Link *l = (Link *)(*it);
		_errorCache+=(l->getWeight()*l->getDestination()->getError());
	}
	_errorCache = _errorCache * _dActivationFunction(getActivation());
	_errorCacheValid = true;
}

void
SimpleNeuron::_recacheOutput()
{
	if (_outputCacheValid)
		return;
	LINKS::iterator it;

	_activationCache = 0;
	for (it=_inputLinks.begin();it!=_inputLinks.end();it++)
	{
		Link *l = (Link *)(*it);
		_activationCache+=l->getSource()->getOutput()*l->getWeight();
	}
	if (_hasBias)
		_activationCache+=_bias;
	_outputCache = _activationFunction(_activationCache);
	_outputCacheValid = true;
}
	
void
SimpleNeuron::calculateNewWeights(real learningRate, real momentum)
{
	if (getInputCount()==0)
	{
		string error(getClassName());
		error = error + "::calculateNewWeights() - Called for a neuron with no inputs";
		throw Exception(error);
	}
	LINKS::iterator it;
	for (it=_inputLinks.begin();it!=_inputLinks.end();it++)
	{
		Link *link = (Link *)(*it);
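		real input = link->getSource()->getOutput();
		// Delta rule with momentum:
		// deltaWeight = error * input * learningRate + momentum * lastDeltaWeight
		link->setDeltaWeight(getError()*input*learningRate+momentum*link->getLastDeltaWeight());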
	}
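	// The bias is treated like a weight on a constant input of 1.
	// Note that no momentum term is applied to it here.
	if (_hasBias)
		_deltaBias = getError()*learningRate;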
}

void
SimpleNeuron::update()
{
	LINKS::iterator it;
	for (it=_inputLinks.begin();it!=_inputLinks.end();it++)
	{
		Link *link = (Link*)(*it);
		link->updateWeight();
	}
	if (_hasBias)
	{
		_bias+=_deltaBias;
		_deltaBias=0.0;
	}
	invalidateOutputCache();
	invalidateErrorCache();
}

void
SimpleNeuron::removeBias()
{
	_hasBias = false;
	_bias = 0.0;
	invalidateOutputCache();
	invalidateErrorCache();
}

void
SimpleNeuron::setBias(real bias)
{
	if (!_hasBias)
		throw Exception("SimpleNeuron::setBias() - This neuron isn't supposed to have a bias");
	_bias = bias;
	invalidateOutputCache();
}

real
SimpleNeuron::getBias()
{	return _bias;	}

bool
SimpleNeuron::hasBias()
{	return _hasBias;	}

string
SimpleNeuron::toString()
{
	string ans = Neuron::toString();
	if (_hasBias)
	{
		char buffer[1000];
		// Bounded write so the formatted text can never overrun the buffer
		snprintf(buffer,sizeof(buffer),"- Bias = %f\n",getBias());
		ans = ans + buffer;
	}
	return ans;
}

const char *
SimpleNeuron::getClassName()
{	
	return _SIMPLE_NEURON_STRING;
}

real
SimpleNeuron::getWeight(Neuron *from)
{
	LINKS::iterator i;
	for (i=_inputLinks.begin(); i!=_inputLinks.end(); i++)
	{
		Link *l = (Link*)(*i);
		if (l->getSource() == from)
			return l->getWeight();
	}
	// No input link from the given neuron was found
	return 0.0;
}
} //namespace annie

