⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 spebm.cpp

📁 Gaussian Mixture Algorithm
💻 CPP
字号:
/* * spEbm.cpp * *      Author: cyril Poulet */#include "spEbm.h"namespace ebl {state_spidx::state_spidx():	x(), dx(), ddx()	{}state_spidx::state_spidx(intg s0, intg s1, intg s2, intg s3, intg s4, intg s5, intg s6, intg s7):	x(0,s0,s1,s2,s3,s4,s5,s6,s7),	dx(0,s0,s1,s2,s3,s4,s5,s6,s7),	ddx(0,s0,s1,s2,s3,s4,s5,s6,s7)	{		clear();		clear_dx();		clear_ddx();	}state_spidx::state_spidx(parameter *st, intg Nelem, intg s0, intg s1, intg s2, intg s3, intg s4, intg s5, intg s6, intg s7) :	x(Nelem, st->x.footprint(), st->x.getstorage(), s0, s1, s2, s3, s4, s5, s6, s7),	dx(Nelem, st->dx.footprint(), st->dx.getstorage(), s0, s1, s2, s3, s4, s5, s6, s7),	ddx(Nelem, st->ddx.footprint(), st->ddx.getstorage(), s0, s1, s2, s3, s4, s5, s6, s7){	st->resize(st->footprint() + nelements());	clear();	clear_dx();	clear_ddx();}state_spidx::~state_spidx(){}//! clear xvoid state_spidx::clear(){	idx_clear(x);}//! clear gradients dxvoid state_spidx::clear_dx(){	idx_clear(dx);}//! clear diag hessians ddxvoid state_spidx::clear_ddx(){	idx_clear(ddx);}//! return number of elementsintg state_spidx::nelements(){	return x.nelements();}//! update with gradient descentvoid state_spidx::update_gd(gd_param &arg){	if (arg.decay_l2 > 0)	{		idx_dotcacc(x, arg.decay_l2, dx);	}	if (arg.decay_l1 > 0)	{		idx_signdotcacc(x, arg.decay_l1, dx);	}	idx_dotcacc(dx, -arg.eta, x);}//! resize. The order cannot be changed with this.void state_spidx::resize(intg s0, intg s1, intg s2, intg s3, intg s4, intg s5, intg s6, intg s7){	x.resize(s0, s1, s2, s3, s4, s5, s6, s7);	dx.resize(s0, s1, s2, s3, s4, s5, s6, s7);	ddx.resize(s0, s1, s2, s3, s4, s5, s6, s7);}void state_spidx::resizeAs(state_spidx &s){	if (x.order() != s.x.order())		ylerror("State_spIdx::resizeAs accepts states with same number of dimensions");	intg dims[8] ={-1,-1,-1,-1,-1,-1,-1,-1};	for (int i=0; i<x.order(); i++){		dims[i] = s.x.dim(i);	}	resize(dims[0],dims[1],dims[2],dims[3],dims[4],dims[5],dims[6],dims[7]);}//! make a new copy of selfstate_spidx state_spidx::make_spcopy(){	intg dims[8] ={-1,-1,-1,-1,-1,-1,-1,-1};	for (int i=0; i<x.order(); i++){		dims[i] = x.dim(i);	}	state_spidx other(dims[0],dims[1],dims[2],dims[3],dims[4],dims[5],dims[6],dims[7]);	idx_copy(x,other.x);	idx_copy(dx,other.dx);	idx_copy(ddx,other.ddx);	return other;}////////////////////////////////////////////////////////////////sp_linear_module::sp_linear_module(parameter *p, Idx<intg>* connection_table, intg in, intg out){	table = connection_table;	w = new state_idx(p, out, in);}sp_linear_module::~sp_linear_module(){	delete w;}void sp_linear_module::fprop(state_spidx *in, state_spidx *out){	out->resize(w->x.dim(0));	idx_m2dotm1(w->x, in->x, out->x);}void sp_linear_module::bprop(state_spidx *in, state_spidx *out){	Idx<double> twx = w->x.transpose(0, 1);	idx_m1extm1(w->dx, out->dx, in->x);	//idx_m2dotm1(twx, out->dx, in->dx);}void sp_linear_module::bbprop(state_spidx *in, state_spidx *out){	Idx<double> twx = w->x.transpose(0, 1);	idx_m1squextm1(out->ddx, in->x, w->ddx);	idx_m2squdotm1(twx, out->ddx, in->ddx);}void sp_linear_module::forget(forget_param_linear &fp){	int N = w->x.dim(1);	int bla[N];	for(int i = 0; i<N; i++) bla[i] = 0;	{ idx_bloop1(ind, *table, intg){		bla[ind.get(1)]++;	}}	double fanin = 0;	for(int i = 0; i<N; i++) if(bla[i]>0) fanin++;	//double fanin = w->x.dim(1);	double z = fp.value / pow(fanin, fp.exponent);	if(!drand_ini) printf("You have not initialized random sequence. Please call init_drand() before using this function !\n");	{ idx_bloop1(ind, *table, intg){		w->x.set(drand(z), ind.get(1), ind.get(0));	}}}void sp_linear_module::normalize(){	norm_columns(w->x);}//////////////////////////////////////////////////////////////////sp_logsoftmax_module::sp_logsoftmax_module(double b, Idx<ubyte> *classes){	intg imax = idx_max(*classes);	intg imin = idx_min(*classes);	if (imin < 0) ylerror("labels must be positive");	if (imax > 100000) printf("warning: [edist-cost] largest label is huuuge\n");	classindex2label = Idx<ubyte>(1 + imax);	{ idx_bloop1(v, classindex2label, ubyte) {		v->set(0);	}}	for (intg i = 0; i < classes->dim(0); ++i)		classindex2label.set(i, classes->get(i));	beta = b;}void sp_logsoftmax_module::fprop(state_spidx *in, state_spidx *out){	double sum1 = 0;	double *in_val = in->x.values()->idx_ptr();	const intg in_valmod = in->x.values()->mod(0);	for(int i = 0; i < in->x.nelements(); i++){		sum1 += exp(beta * (*in_val));		in_val += in_valmod;	}	sum1 = log(sum1);	idx_copy(in->x, out->x);	idx_dotc(out->x, beta, out->x);	idx_minus(out->x, out->x);	idx_addc(out->x, sum1, out->x);}void sp_logsoftmax_module::bprop(state_spidx *in, state_spidx *out){	/*	spIdx<double> wupdate1(0, out->x.dim(0), in->x.dim(0));	idx_m1extm1(wupdate1, out->x, in->x);	idx_minus(wupdate1, wupdate1);	spIdx<double> w2(0, out->x.dim(0), in->x.dim(0));	spIdx<double> wupdate2(0, out->x.dim(0), in->x.dim(0));	double sum = 0;	for(int i = 0; i < out->x.dim(0); i++){		spIdx<double> z(0,out->x.dim(0));		z.set(1, i);		idx_clear(w2);		idx_m1extm1(w2, z, in->x);		double bla = exp(beta * out->x.get(i));		idx_dotc(w2, beta * bla, w2);		idx_add(w2, wupdate2, wupdate2);		sum += bla;	}	idx_dotc(wupdate2, 1/sum, wupdate2);	idx_add(wupdate2, wupdate1, wupdate1);	spIdx<double> w = wupdate1.transpose(0,1);	idx_m2dotm1(w, out->dx, in->dx);	*/	idx_copy(out->dx, in->dx);	{idx_bloop1(i, *(in->dx.values()), double){		i.set(1/(-beta *(1 - exp(- i.get()))));	}}	idx_mul(out->dx, in->dx, in->dx);}void sp_logsoftmax_module::calc_energy(state_spidx *out, Idx<ubyte> *desired, state_idx *energy){	spIdx<double> proto(0, out->x.dim(0));	idx_clear(proto);	for(int i = 0; i < classindex2label.dim(0); i++){		if(classindex2label.get(i) == desired->get()){			proto.set(1, i);			break;		}	}	double en = idx_dot(out->x, proto);	energy->x.set(en);}void sp_logsoftmax_module::calc_max(state_spidx *out, class_state *output){	intg n = out->x.dim(0);	output->resize(n);	idx_copy(out->x, *(output->sorted_scores));	idx_minus(*(output->sorted_scores), *(output->sorted_scores));	idx_copy(classindex2label, *(output->sorted_classes));	idx_sortdown(*(output->sorted_scores), *(output->sorted_classes));	idx_minus(*(output->sorted_scores), *(output->sorted_scores));	output->output_class = output->sorted_classes->get(0);	output->confidence = output->sorted_scores->get(0);}}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -