neuron.cpp

来自「多层神经网络范例 http://www.codeproject.com/cpp」· C++ 代码 · 共 81 行

CPP
81
字号
#include "StdAfx.h"
#include ".\neuron.h"
#include "synapse.h"
#include <math.h>

// Class-wide training hyper-parameters shared by every Neuron.
// momentum:     weight of the previous weight change (synapse->data) carried
//               into the next update in Neuron::computeWeight.
// learningRate: scale factor on the delta * input gradient term in the same update.
double Neuron::momentum(0.9);
double Neuron::learningRate(0.05);

// Destructor: empties both synapse lists.
// NOTE(review): assuming inlinks/outlinks are MFC lists of Synapse*,
// RemoveAll() only drops the stored pointers and does not delete the Synapse
// objects; ownership presumably lives elsewhere (confirm in the network class).
Neuron::~Neuron(void)
{
	outlinks.RemoveAll();
	inlinks.RemoveAll();
}

// Computes this neuron's activation from its incoming synapses.
//
//   sum    = sum over inlinks of (source neuron output * synapse weight)
//   output = logistic sigmoid of sum = 1 / (1 + e^(-sum))
//
// A neuron with no inlinks gets sum = 0 and output = 0.5. The previous version
// called GetAt() on the head POSITION before the loop; for an empty list that
// POSITION is NULL, which asserts in MFC debug builds and is undefined in
// release builds. The value it read was never used, so that read is removed.
void Neuron::computeOutput()
{
	sum = 0.0;
	POSITION pos = inlinks.GetHeadPosition();
	while (pos != NULL)
	{
		// GetNext returns the element at pos and advances pos (NULL after the tail).
		Synapse* synapse = inlinks.GetNext(pos);
		sum += synapse->from->getOutput() * synapse->getWeight();
	}
	output = 1.0 / (1.0 + exp(-sum)); // sigmoid function
}
void Neuron::computeBackpropDelta(double d) // for an output neuron
{
	delta = (d - output) * output * (1.0 - output);
}
void Neuron::computeBackpropDelta() // for a hidden neuron
{
	double errorSum = 0.0;
	POSITION pos = outlinks.GetHeadPosition();
	Synapse* synapse = outlinks.GetAt(pos);
	for (int i = 0; i < outlinks.GetCount(); i++) 
	{
		synapse = outlinks.GetNext(pos);
		errorSum += synapse->to->delta * synapse->getWeight();
	}
	delta = output * ( 1.0 - output) * errorSum;
}
void Neuron::computeWeight()
{
	POSITION pos = inlinks.GetHeadPosition();
	Synapse* synapse = (Synapse*)inlinks.GetAt(pos);
	for (int i = 0; i < inlinks.GetCount(); i++) 
	{
		synapse = (Synapse*)inlinks.GetNext(pos);
		synapse->data = learningRate*delta*synapse->from->getOutput()
			+ momentum*synapse->data;
		synapse->weight += synapse->data;
	}
}
// Builds a one-line description of this neuron: its label and output, then
// "<target label>(<weight>)" for each outgoing synapse. The text is also sent
// to the debugger via OutputDebugString, then returned to the caller.
CString Neuron::print()
{
	CString text;
	text.Format(_T("\n%s = %f : \n"), label, output);

	for (POSITION pos = outlinks.GetHeadPosition(); pos != NULL; )
	{
		Synapse* link = outlinks.GetNext(pos);
		text.AppendFormat(_T("%s(%.4f)  "), link->to->label, link->weight);
	}

	OutputDebugString(text);
	return text;
}

int Neuron::SetWeights(double* pWeights)
{
	POSITION pos = outlinks.GetHeadPosition();
	Synapse* synapse = NULL;
	for (int i = 0; i < outlinks.GetCount(); i++) 
	{
		synapse = outlinks.GetNext(pos);
		synapse->weight = pWeights[i];
	}
	return 0;
}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?