⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 lugd.cpp

📁 一个简单灵活的数据挖掘实验平台
💻 CPP
字号:
#pragma warning(disable: 4786)
 
#include "LUGD.h"
#include "..\\Core\\Exception.h"
#include "..\\Utils\\xs.h"
#include "..\\Input\\File\\ArffStream.h"
#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <iomanip>
#include <fstream>
//: for test
#include <ctime>

//: builds a network from the given parameters and allocates its buffers.
//: train_set may be NULL (working mode, as load() does); train() then refuses
//: to run. The training set pointer is only stored: it is neither copied nor
//: released by the destructor, so it must outlive any call to train().
LUGD::LUGD(vector<Data>* train_set,LUGD_PARAM params)
	: _train_set(train_set),
	  _ni(params.input),
	  _no(params.output),
	  _range(params.random_range),
	  _max_epoch(params.max_epoch),
	  _enta(params.learning_rate)
{
	// validates the parameters, then allocates and seeds the weights
	initialize();
}
//: releases the output, delta-weight and weight buffers allocated by
//: initialize(). The training set is not owned and is left untouched.
//: NOTE(review): the class owns raw arrays; unless the header disables
//: copying, a copied LUGD would delete the same buffers twice.
LUGD::~LUGD()
{
	delete[] _o;
	delete[] _dwio;
	delete[] _wio;
}
//: saves the network to a specified file, in three lines read back by LUGD::load:
//:   INPUT=<number of input nodes>
//:   OUTPUT=<number of output nodes>
//:   WEIGHT=<w0>,<w1>,...,<w(ni*no-1)>
//: Throws Exception* if the file cannot be opened or the write fails.
void LUGD::save(string file)
{
	ofstream stream(file.c_str());
	if(!stream.is_open())
		throw new Exception("LUGD::save file: "+file + " cannot open to write");

	stream<<"INPUT="<<_ni<<endl;
	stream<<"OUTPUT="<<_no<<endl;

	// The stream default is 6 significant digits, which silently truncates the
	// trained weights; 17 digits are enough for a double to survive the
	// text round trip through load()/atof.
	stream<<std::setprecision(17);
	stream<<"WEIGHT=";
	const int len = _ni*_no;
	for(int i=0;i<len;i++)
	{
		if(i != 0) stream<<",";
		stream<<_wio[i];
	}
	// always terminate the weight line, even when there are no weights
	stream<<endl;

	if(!stream)
		throw new Exception("LUGD::save file: "+file + " write failed");
}
//: loads a network previously written by LUGD::save and returns a new
//: heap-allocated LUGD in working mode. It has no training set, so train()
//: throws and only work() is useful. The caller owns the returned pointer.
//: Expected format (one section per line, no blank lines):
//:   INPUT=<ni>
//:   OUTPUT=<no>
//:   WEIGHT=<w0>,<w1>,...   (exactly ni*no values)
//: Only the token counts are validated, not the key names.
//: Throws Exception* if the file cannot be opened or a section is malformed.
LUGD* LUGD::load(string file)
{
	ifstream stream(file.c_str());
	if(!stream.is_open())
		throw new Exception("Cannot open file: "+file);

	//: reads input, output and weight section strings
	string input,output,weight;
	if(!getline(stream,input))
		throw new Exception("File missing input-node number line,check for no blank line!");
	if(!getline(stream,output))
		throw new Exception("File missing output-number data line,check for no blank line!");
	if(!getline(stream,weight))
		throw new Exception("File missing weight line,check for no blank line!");

	//: parses the three sections; token 0 is the key, the rest are values
	//: (presumably xs::split splits on any of the given delimiter characters)
	vector<string> vi = xs::split(input," =;");
	if(vi.size() != 2)
		throw new Exception("input section format error!");
	int ni = atoi(vi[1].c_str());
	vector<string> vo = xs::split(output," =;");
	if(vo.size() != 2)
		throw new Exception("output section format error!");
	int no = atoi(vo[1].c_str());
	vector<string> vw = xs::split(weight," =,;");
	// compare as int: ni*no+1 can be negative for a corrupt file, which must
	// simply mismatch rather than rely on unsigned wrap-around
	if(static_cast<int>(vw.size()) != ni*no+1)
		throw new Exception("weight section format error!");

	// Value-initialize: load() sets only input/output, but the constructor's
	// check() also reads the learning rate, random range and epoch count.
	// A plain `LUGD_PARAM params;` leaves those fields uninitialized when
	// LUGD_PARAM is a plain struct. This form zero-initializes them, or runs
	// a user-provided default constructor if there is one.
	LUGD_PARAM params = LUGD_PARAM();
	params.input = ni;
	params.output = no;
	LUGD* lugd = new LUGD(NULL,params);
	// weights follow the key token, in the order save() wrote them
	for(size_t i=1;i<vw.size();i++)
		lugd->_wio[i-1] = atof(vw[i].c_str());

	return lugd;
}

//: batch training: for each of _max_epoch epochs, forward() accumulates the
//: weight deltas of every training sample, then update_weight() applies the
//: accumulated deltas once.
//: Throws Exception* when there is no training set (working mode, e.g. a
//: network created by load()).
void LUGD::train()
{
	if(_train_set == NULL)
		throw new Exception("LUGD::train, current is working mode, training is forbidden!");

	for(int epoch=0;epoch<_max_epoch;++epoch)
	{
		const size_t count = _train_set->size();
		for(size_t j=0;j<count;++j)
			forward((*_train_set)[j]);
		update_weight();
	}
}

//: predicts one sample. Reads the _ni inputs data[0.._ni-1] and writes the
//: _no computed outputs into data[_ni.._ni+_no-1]; the outputs are also left
//: in _o.
//: The output computation matches forward(), but unlike the previous
//: implementation (which called forward()) it does not add to the
//: delta-weight accumulator _dwio. Those spurious deltas, computed from
//: whatever was in the output slots, would otherwise be applied by the next
//: update_weight() if training is resumed.
void LUGD::work(Data& data)
{
	for(int k=0;k<_no;k++)
	{
		_o[k] = 0;
		for(int i=0;i<_ni;i++)
			_o[k] += _wio[i*_no+k] * data[i];
	}
	// write outputs only after all of them are computed (inputs stay intact)
	for(int k=0;k<_no;k++)
		data[_ni+k] = _o[k];
}
//: validates the configured parameters and prints each value in use.
//: Throws Exception* for a negative input/output count or learning rate.
//: Only warns for a learning rate >= 1, a random range whose magnitude
//: exceeds 0.5, or a maximum epoch count below 100 or above 10000.
void LUGD::check()
{
	if(_ni < 0)
		throw new Exception("LUGD::check,input layer count < 0");
	cout<<"LUGD: using input layer count: "<<_ni<<endl;
	if(_no < 0)
		throw new Exception("LUGD::check,output layer count < 0");
	cout<<"LUGD: using output layer count: "<<_no<<endl;
	if(_enta < 0)
		throw new Exception("LUGD::check,learning rate < 0");

	if(_enta >= 1)
		cout<<"WARNING: learning rate>=1.0"<<endl;
	cout<<"LUGD: using learning rate: "<<_enta<<endl;
	// std::fabs always works in floating point; an unqualified abs() may
	// resolve to the C `int abs(int)` and truncate e.g. 0.7 to 0, which
	// would silently suppress this warning.
	if(std::fabs(_range)>0.5)
		cout<<"WARNING: random range>0.5"<<endl;
	cout<<"LUGD: using random range: "<<_range<<endl;
	if(_max_epoch < 100)
		cout<<"WARNING: maximum epoch count < 100"<<endl;
	if(_max_epoch > 10000)
		cout<<"WARNING: maximum epoch count > 10000"<<endl;
	cout<<"LUGD: using maximum epoch: "<<_max_epoch<<endl;
}
//: returns a pseudo-random value from rand(), scaled to lie between -_range
//: and _range inclusive; used to initialize the weights.
double LUGD::random()
{
	// rand()/RAND_MAX lies in [0,1]; map it to [-1,1] entirely in floating
	// point. The former integer term RAND_MAX/2 truncated, giving the
	// asymmetric interval [-0.99994, 1.00003]*_range for RAND_MAX == 32767,
	// so values could exceed the configured range.
	return _range * (2.0 * rand() / RAND_MAX - 1.0);
}
void LUGD::initialize()
{
	this->check();
	int len = _ni*_no;
	_wio  = new double[len];
	_dwio = new double[len];
	_o    = new double[_no];
	
	for(int i=0;i<len;i++) _wio[i] = random();
	memset(_dwio,0,len*sizeof(double));
}

//: forward computes the linear output of the network for one sample and
//: accumulates the delta-rule weight changes into _dwio. The sample holds
//: _ni inputs followed by _no target values. The changes are
//: learning rate * (target - output) * input, and only update_weight()
//: applies them to the weights.
void LUGD::forward(Data data)
{
	for(int k=0;k<_no;++k)
	{
		// weighted sum of the inputs feeding output k (no activation)
		_o[k] = 0;
		for(int i=0;i<_ni;++i)
			_o[k] += _wio[i*_no+k] * data[i];

		// error of output k, scaled per input, added to the pending deltas
		for(int i=0;i<_ni;++i)
			_dwio[i*_no+k] += _enta*(data[_ni+k] -_o[k])*data[i];
	}
}
//: updates the weights between the input and output layers by adding the
//: deltas accumulated by forward(), then clears the accumulator so the next
//: epoch starts from zero.
void LUGD::update_weight()
{
	const int count = _ni*_no;
	for(int j=0;j<count;++j)
	{
		_wio[j] += _dwio[j];
		_dwio[j] = 0;
	}
}
/************************************************************************
void main()
{
	clock_t start,end;
	
	start = clock();
	LUGD_PARAM params;
	ArffStream arff("arff//838.arff");
	try
	{
		arff.open();
		vector<Data> data;
		while(arff.next())
		{
			data.push_back(Data(arff.data()));
		}
		params.input = 8;
		params.output = 8;

		LUGD lugd(&data,params);

	
		lugd.train();

		lugd.work(data[1]);
		cout<<data[1].to_string()<<endl;
//		lugd.save("arff//lugd.ini");
//		
//		LUGD* l = LUGD::load("arff//lugd.ini");
//		cout<<data[0].to_string()<<endl;
//		l->work(data[0]);
//		cout<<data[0].to_string()<<endl;
	}
	catch(Exception* e)
	{
		cout<<e->message()<<endl;
		delete e;
	}
	end = clock();
	cout<<"Time eclipsed[s]:"<<(double)(end-start)/CLOCKS_PER_SEC<<endl;
}
/************************************************************************/

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -