⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 spex.cpp

📁 Gaussian Mixture Algorithm
💻 CPP
字号:
/* * spex.cpp * *      Author: cyril Poulet */#include "spex.h"extern intg trainsize;void load(const char *fname, spIdx<double> &xp, Idx<ubyte> &yp){	cout << "Loading " << fname << "." << endl;	igzstream f;	f.open(fname);	if (! f.good()){		cerr << "ERROR: cannot open " << fname << "." << endl;		cerr << "Have you preprocessed the data ? \n";		exit(10);	}	int pcount = 0;	int ncount = 0;	string suffix = fname;	if (suffix.size() >= 7) suffix = suffix.substr(suffix.size() - 7);	if (suffix != ".dat.gz"){		cerr << "ERROR: filename should end with .dat.gz" << endl;		exit(10);	}	intg currentdoc = 0;	intg imax = 0, dim = 0;	intg *xpind = xp.index()->idx_ptr();	double *xpval = xp.values()->idx_ptr();	intg *xpmodind = (intg*)(xp.index()->mods());	intg xpmodval = xp.values()->mod(0);	while (f.good()){		double y;		f >> y;		if(yp.dim(0) <= currentdoc) yp.resize(yp.dim(0)+500);		yp.set((y <= 0)? 0 : 1, currentdoc);		int dim1 = 0;		for(;;){			int c = f.get();			if (!f.good() || (c=='\n' || c=='\r'))				break;			if (::isspace(c))				continue;			int i;			f.unget();			f >> std::skipws >> i >> std::ws;			if (f.get() != ':')			{				f.unget();				f.setstate(std::ios::badbit);				break;			}			double x;			f >> std::skipws >> x;			if (!f.good())				break;			if(xp.dim(0) <= currentdoc) xp.resize(xp.dim(0)+10000, max((intg)i, xp.dim(1)));			if(i >= xp.dim(1)) xp.resize(xp.dim(0), i+500);			if(xp.nelements() + 5 >= xp.values()->nelements()){				xp.index()->resize(xp.nelements()+10000, xp.order());				xp.values()->resize( xp.nelements()+10000);				xpmodind = (intg*)(xp.index()->mods());				xpmodval = xp.values()->mod(0);				xpind = xp.index()->idx_ptr() + xp.nelements()*xpmodind[0];				xpval = xp.values()->idx_ptr()+ xp.nelements()*xpmodval;			}			*xpind = currentdoc;			*(xpind + xpmodind[1]) = i;			*xpval = x;			xpind += xpmodind[0];			xpval += xpmodval;			xp.set_nelements(xp.nelements() + 1);			imax = max(imax, (intg)i);	        dim1++;		}		if (dim1 > dim) dim = dim1;		currentdoc++;		if(y <= 0) ncount++;		else pcount++;		if (trainsize > 0 && currentdoc >= (intg)trainsize) break;		if( currentdoc % 1000 == 0) cout << "loaded doc #" << currentdoc << "\n";	}	cout << "Read " << pcount << "+" << ncount << "=" << pcount + ncount << " examples." << endl;	cout << "Number of features " << dim << "." << endl;	const intg bladims[2] = { ncount + pcount, imax + 1 };	xp.set_dims(2, bladims);}spnet::spnet(parameter *p, Idx<intg>* connection_table, intg in, intg out, double beta, Idx<ubyte> *classes):	mylinmodule(p, connection_table, in, out),	mysoftmodule(beta, classes)	{		inter = new state_spidx(out);	}spnet::~spnet(){	delete inter;}void spnet::fprop(state_spidx *in, state_spidx *out, Idx<ubyte> *label, class_state *output, state_idx *energy){	inter->clear();	mylinmodule.fprop(in, inter);	//cout << "inter :";	//inter->x.printElems();	mysoftmodule.fprop(inter, out);	//cout << "out :";	//out->x.printElems();	//cout << "label :";	//label->printElems();	mysoftmodule.calc_max(out, output);	//cout << "output class : " << (int)output->output_class << "\n";	mysoftmodule.calc_energy(out, label, energy);	//cout << "energy : ";	//energy->x.printElems();}void spnet::bprop(state_spidx *in, state_spidx *out){	inter->clear_dx();	mysoftmodule.bprop(inter, out);	//cout << "inter dx :";	//inter->dx.printElems();	mylinmodule.bprop(in, inter);}void spnet::forget(forget_param_linear forgetparam){	mylinmodule.forget(forgetparam);}/////////////////////////////////////////////////sptrainer::sptrainer(const string fname):	doclabels(trainsize),	docs(0, trainsize, 50){	nclasses = 2;	labels = new Idx<ubyte>(nclasses);	for(int i = 0; i < nclasses; i++){	     labels->set(i, i);	}	//! load training examples	load(fname.c_str(), docs, doclabels);	docs.index()->resize(docs.nelements(), docs.index()->dim(1));	docs.values()->resize(docs.nelements());	mydatasource = new spLabeledDataSource<double, ubyte>(&docs, &doclabels);	//! create elements of the trainer	// probleme de table...	myconnections = new Idx<intg>(docs.dim(1), nclasses);	*myconnections = full_table(docs.dim(1), nclasses);	myparam = new parameter(10000);	gdp = new gd_param(0.05, 0, 0, 0, 0, 0, 0, 0, 0);	mynet = new spnet(myparam, myconnections, docs.dim(1), nclasses, 1, labels);	trainmeter = new classifier_meter();	in = new state_spidx(docs.dim(1));	label = new Idx<ubyte>();	out = new state_spidx(nclasses);	output = new class_state(nclasses);	energy = new state_idx();}sptrainer::~sptrainer(){	delete labels;	delete label;	delete myconnections;	delete mydatasource;	delete myparam;	delete gdp;	delete mynet;	delete trainmeter;	delete in;	delete out;	delete output;	delete energy;}void sptrainer::train(int npass){	forget_param_linear forgetparam(1, 1/2);	init_drand(1);	mynet->forget(forgetparam);	myparam->set_epsilons(1);	age = 0;	for(int i = 0; i < npass; i++){		mydatasource->seek_begin();		int j = 0;		cout << "training :\n";		while(j <= (intg)trainsize){			train_online();			myparam->update(*gdp);			mydatasource->next();			j++;			if( ++age % 1000 == 0) cout << "processed doc #" << age << "\n";		}		age--;		test();	}}void sptrainer::train_online(){	out->clear();	out->clear_dx();	mydatasource->fprop(in, label);	mynet->fprop(in, out, label, output, energy);	myparam->clear_dx();	out->dx.set(1, label->get());	mynet->bprop(in, out);}void sptrainer::test(){	mydatasource->seek_begin();	trainmeter->clear();	cout << "testing :\n";	for (int i = 0; i < mydatasource->size(); ++i) {		mydatasource->fprop(in, label);		mynet->fprop(in, out, label, output, energy);		trainmeter->update(age, output, label->get(), energy);		mydatasource->next();	}	trainmeter->display();}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -