📄 lugd.cpp
字号:
#pragma warning(disable: 4786)
#include "LUGD.h"
#include "..\\Core\\Exception.h"
#include "..\\Utils\\xs.h"
#include "..\\Input\\File\\ArffStream.h"
#include <cmath>
#include <iostream>
#include <iomanip>
#include <fstream>
//: for test
#include <ctime>
//: Builds a LUGD network from the given parameters.
//: train_set - samples used by train(); pass NULL to create a network in
//:             working mode (train() then throws), as load() does. Only the
//:             pointer is stored, so the vector must outlive this object.
//: params    - layer sizes, weight init range, epoch limit and learning rate;
//:             validated (and printed) by check() inside initialize().
LUGD::LUGD(vector<Data>* train_set,LUGD_PARAM params)
{
	// Network shape.
	_ni = params.input;
	_no = params.output;
	// Training hyper-parameters.
	_enta = params.learning_rate;
	_range = params.random_range;
	_max_epoch = params.max_epoch;
	// Borrowed (non-owning) training data.
	_train_set = train_set;
	// Validate parameters and allocate/randomize the weight buffers.
	initialize();
}
//: Releases the three buffers allocated by initialize().
//: NOTE(review): the class owns raw new[] arrays; unless LUGD.h disables
//: copying, a copied LUGD would delete[] the same buffers twice - confirm.
LUGD::~LUGD()
{
	delete[] _o;
	delete[] _dwio;
	delete[] _wio;
}
//: saves the network to a specified file. The file has three lines:
//:   INPUT=<number of input nodes>
//:   OUTPUT=<number of output nodes>
//:   WEIGHT=<w0>,<w1>,...,<w(ni*no-1)>
//: Weights are written in the flat order of _wio (index i*_no+k links input
//: i to output k), which is the order load() reads them back in.
//: Throws Exception* if the file cannot be opened or writing fails.
void LUGD::save(string file)
{
	ofstream stream(file.c_str());
	if(!stream.is_open())
		throw new Exception("LUGD::save file: "+file + " cannot open to write");
	// The default stream precision (6 significant digits) truncates the
	// weights, so a save/load round trip would silently change the network.
	// 17 significant digits are enough to round-trip any IEEE-754 double.
	stream<<setprecision(17);
	stream<<"INPUT="<<_ni<<endl;
	stream<<"OUTPUT="<<_no<<endl;
	stream<<"WEIGHT=";
	int len = _ni*_no;
	for(int i=0;i<len;i++)
	{
		if(i != 0) stream<<",";
		stream<<_wio[i];
	}
	// Terminate the weight line even when there are no weights.
	stream<<endl;
	if(!stream)
		throw new Exception("LUGD::save file: "+file + " write failed");
}
//: load the network from a file in the format written by save() and build a
//: LUGD in working mode (no training set, so train() is forbidden).
//: Returns a heap-allocated LUGD; the caller owns it and must delete it.
//: Throws Exception* if the file cannot be opened, a line is missing, or a
//: section has the wrong number of fields.
//: NOTE(review): only input/output are set on LUGD_PARAM here; the other
//: fields rely on LUGD_PARAM's own initialisation (check() validates the
//: learning rate) - confirm LUGD_PARAM gives them defaults.
LUGD* LUGD::load(string file)
{
	ifstream in(file.c_str());
	if(!in.is_open())
		throw new Exception("Cannot open file: "+file);

	// All three lines are read before any parsing, so a missing line is
	// reported ahead of a format error in an earlier line.
	string inputLine, outputLine, weightLine;
	if(!getline(in,inputLine))
		throw new Exception("File missing input-node number line,check for no blank line!");
	if(!getline(in,outputLine))
		throw new Exception("File missing output-number data line,check for no blank line!");
	if(!getline(in,weightLine))
		throw new Exception("File missing weight line,check for no blank line!");

	// Each header line must split into exactly two tokens (label and value);
	// presumably xs::split breaks on any of the given delimiter characters.
	vector<string> inputTokens = xs::split(inputLine," =;");
	if(inputTokens.size() != 2)
		throw new Exception("input section format error!");
	int nIn = atoi(inputTokens[1].c_str());

	vector<string> outputTokens = xs::split(outputLine," =;");
	if(outputTokens.size() != 2)
		throw new Exception("output section format error!");
	int nOut = atoi(outputTokens[1].c_str());

	// The weight line must hold one label token followed by nIn*nOut values.
	vector<string> weightTokens = xs::split(weightLine," =,;");
	if(weightTokens.size() != nIn*nOut+1)
		throw new Exception("weight section format error!");

	LUGD_PARAM params;
	params.input = nIn;
	params.output = nOut;
	LUGD* net = new LUGD(NULL,params);
	// Replace the randomly initialised weights with the stored ones
	// (token 0 is the label, so weight w comes from token w+1).
	for(size_t w=0; w+1<weightTokens.size(); ++w)
		net->_wio[w] = atof(weightTokens[w+1].c_str());
	return net;
}
void LUGD::train()
{
if(_train_set == NULL)
throw new Exception("LUGD::train, current is working mode, training is forbidden!");
int i = 0;
while(i++<_max_epoch)
{
vector<Data>::iterator it;
for(it=_train_set->begin();it!=_train_set->end();it++)
{
forward(*it);
}
update_weight();
}
}
//: Runs the network on one sample. Inputs are read from data[0 .. _ni-1]
//: and the computed outputs are written to data[_ni .. _ni+_no-1]; _o also
//: holds the outputs afterwards.
//: This deliberately does not call forward(): forward() also adds a weight
//: delta to _dwio, based on whatever values currently sit in the output
//: slots, and the next train() epoch would apply that spurious delta to the
//: weights.
void LUGD::work(Data& data)
{
	int i, k;
	for(k=0;k<_no;k++)
	{
		// Same linear output as forward(): sum over i of w(i,k) * x_i.
		_o[k] = 0;
		for(i=0;i<_ni;i++)
			_o[k] += _wio[i*_no+k] * data[i];
	}
	for(k=0;k<_no;k++)
		data[_ni+k] = _o[k];
}
//: Validates the configuration and prints it to cout.
//: Throws Exception* for a negative input/output layer count or a negative
//: learning rate; other questionable values only print a WARNING line.
void LUGD::check()
{
	if(_ni < 0)
		throw new Exception("LUGD::check,input layer count < 0");
	cout<<"LUGD: using input layer count: "<<_ni<<endl;
	if(_no < 0)
		throw new Exception("LUGD::check,output layer count < 0");
	cout<<"LUGD: using output layer count: "<<_no<<endl;
	if(_enta < 0)
		throw new Exception("LUGD::check,learning rate < 0");
	if(_enta >= 1)
		cout<<"WARNING: learning rate>=1.0"<<endl;
	cout<<"LUGD: using learning rate: "<<_enta<<endl;
	// fabs rather than abs: _range is compared with 0.5, so it is presumably
	// floating-point. Unqualified abs() can resolve to int abs(int),
	// depending on which headers/namespaces are visible, which would
	// truncate e.g. 0.7 to 0 and suppress this warning.
	if(fabs(_range)>0.5)
		cout<<"WARNING: random range>0.5"<<endl;
	cout<<"LUGD: using random range: "<<_range<<endl;
	if(_max_epoch < 100)
		cout<<"WARNING: maximum epoch count < 100"<<endl;
	if(_max_epoch > 10000)
		cout<<"WARNING: maximum epoch count > 10000"<<endl;
	cout<<"LUGD: using maximum epoch: "<<_max_epoch<<endl;
}
//: Returns a pseudo-random value in [-_range, _range] (from rand()); used
//: by initialize() to seed the weights.
double LUGD::random()
{
	// Map rand() in [0, RAND_MAX] linearly onto [-1, 1] in floating point.
	// Subtracting the integer RAND_MAX/2 instead made the interval
	// asymmetric and let the result slightly exceed +_range, and it fell
	// back to integer arithmetic if _range were an integral type.
	return _range * (2.0 * rand() / RAND_MAX - 1.0);
}
void LUGD::initialize()
{
this->check();
int len = _ni*_no;
_wio = new double[len];
_dwio = new double[len];
_o = new double[_no];
for(int i=0;i<len;i++) _wio[i] = random();
memset(_dwio,0,len*sizeof(double));
}
//: forward computes the network output for one sample and accumulates the
//: matching weight deltas (delta rule) into _dwio; they are applied later
//: by update_weight().
//: Indexing used below: data[0 .. _ni-1] are the inputs, data[_ni+k] is the
//: target for output k, and _wio[i*_no+k] is the weight from input i to
//: output k.
void LUGD::forward(Data data)
{
	int in, out;
	for(out=0; out<_no; ++out)
	{
		// Linear output: weighted sum of the inputs (no activation).
		_o[out] = 0;
		for(in=0; in<_ni; ++in)
			_o[out] += _wio[in*_no+out] * data[in];
		// Delta rule: eta * (target - output) * input, summed over the batch.
		for(in=0; in<_ni; ++in)
			_dwio[in*_no+out] += _enta*(data[_ni+out] -_o[out])*data[in];
	}
}
//: updates the weights between the input and output layers: adds the
//: deltas accumulated by forward() to _wio and clears them for the next
//: batch.
void LUGD::update_weight()
{
	const int count = _ni*_no;
	for(int idx=0; idx<count; ++idx)
	{
		_wio[idx] += _dwio[idx];
		_dwio[idx] = 0.0;
	}
}
/************************************************************************
void main()
{
clock_t start,end;
start = clock();
LUGD_PARAM params;
ArffStream arff("arff//838.arff");
try
{
arff.open();
vector<Data> data;
while(arff.next())
{
data.push_back(Data(arff.data()));
}
params.input = 8;
params.output = 8;
LUGD lugd(&data,params);
lugd.train();
lugd.work(data[1]);
cout<<data[1].to_string()<<endl;
// lugd.save("arff//lugd.ini");
//
// LUGD* l = LUGD::load("arff//lugd.ini");
// cout<<data[0].to_string()<<endl;
// l->work(data[0]);
// cout<<data[0].to_string()<<endl;
}
catch(Exception* e)
{
cout<<e->message()<<endl;
delete e;
}
end = clock();
cout<<"Time eclipsed[s]:"<<(double)(end-start)/CLOCKS_PER_SEC<<endl;
}
/************************************************************************/
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -