spblas.hpp

来自「Gaussian Mixture Algorithm」· HPP 代码 · 共 616 行 · 第 1/2 页

HPP
616
字号
/* * spBlas.hpp * *      Author: cyril Poulet */#ifndef SPBLAS_HPP_#define SPBLAS_HPP_namespace ebl {#define spidx_checkdims(src0,src1) \	if((src0).order() != (src1).order()) {ylerror("not the same number of dimensions"); return;} \	for(int i = 0; i< (src0).order(); i++) \		if((src0).dim(i) != (src1).dim(i)) {ylerror("not the same size of dimensions"); return;}#define spidx_checkdims_rint(src0,src1) \	if((src0).order() != (src1).order()) {ylerror("not the same number of dimensions"); return 0;} \	for(int i = 0; i< (src0).order(); i++) \		if((src0).dim(i) != (src1).dim(i)) {ylerror("not the same size of dimensions"); return 0;}template<class T1, class T2> void idx_copy(spIdx<T1> &src, spIdx<T2> &dst){	spidx_checkdims(src, dst);	dst.set_nelements(src.nelements());	dst.values()->resize(src.values()->dim(0));	idx_copy(*(src.values()), *(dst.values()));	dst.index()->resize(src.index()->dim(0), src.index()->dim(1));	idx_copy(*(src.index()), *(dst.index()));}template<class T1, class T2> void idx_copy(spIdx<T1> &src, Idx<T2> &dst){	spidx_checkdims(src, dst);	int N = src.order();	idx_fill(dst, (T2)BACKGROUND);	for(intg i = 0; i < src.nelements(); i++){		switch(N){			case 1:				dst.set((T2)src.values()->get(i), src.index()->get(i,0));			break;			case 2:				dst.set((T2)src.values()->get(i), src.index()->get(i, 0), src.index()->get(i, 1));			break;			case 3 :				dst.set((T2)src.values()->get(i), src.index()->get(i, 0), src.index()->get(i, 1), src.index()->get(i, 2));			break;			default :				dst.set((T2)src.values()->get(i),						src.index()->get(i, 0),						src.index()->get(i, 1),						src.index()->get(i, 2),						(N >=4 )? src.index()->get(i, 3) : -1,								(N >=5 )? src.index()->get(i, 4) : -1,										(N >=6 )? src.index()->get(i, 5) : -1,												(N >=7 )? src.index()->get(i, 6) : -1,														(N >=8 )? src.index()->get(i, 7) : -1);				break;		}	}}template<class T1, class T2> void idx_copy(Idx<T1> &src, spIdx<T2> &dst){	spidx_checkdims(src, dst);	int N = src.order();	dst.set_nelements(0);	dst.set_values(new Idx<T2>(1));	dst.set_index(new Idx<intg>(1, N));	T1 *data = src.idx_ptr();	intg mod[8] = {(N >= 1)? src.mod(0) : 0,			(N >= 2)? src.mod(1) : 0,					(N >= 3)? src.mod(2) : 0,							(N >= 4)? src.mod(3) : 0,									(N >= 5)? src.mod(4) : 0,											(N >= 6)? src.mod(5) : 0,													(N >= 7)? src.mod(6) : 0,															(N >= 8)? src.mod(7) : 0};	intg ind[8] = {(N >= 1)? 0 : -1, (N >= 2)? 0 : -1, (N >= 3)? 0 : -1, (N >= 4)? 0 : -1,			(N >= 5)? 0 : -1, (N >= 6)? 0 : -1, (N >= 7)? 0 : -1, (N >= 8)? 0 : -1};	intg ind2[8];	for(int i = 0; i<8; i++) ind2[i] = ind[i];	for(intg i = 0; i< src.nelements(); i++){		if(data[i] != BACKGROUND) dst.set((T2)data[i], ind[0], ind[1],ind[2],ind[3],ind[4],ind[5],ind[6],ind[7]);		ind[7]++;		for(int j = 7; j > 0; j--)			if(ind[j] > ((N >= j+1)? (mod[j-1] - 1) : -1)) { ind[j] = ind2[j]; ind[j-1]++;}	}}////////////////////////////////////////////////////////////////template<class T> void idx_clear(spIdx<T> &inp){	inp.set_nelements(0);	inp.set_values(new Idx<T>(1));	inp.set_index(new Idx<intg>(1, inp.order()));}//! negate all elementstemplate<class T> void idx_minus(spIdx<T> &inp, spIdx<T> &out){	spidx_checkdims(inp, out);	out.set_nelements(inp.nelements());	out.values()->resize(inp.values()->dim(0));	idx_minus(*(inp.values()), *(out.values()));	out.index()->resize(inp.index()->dim(0), inp.index()->dim(1));	idx_copy(*(inp.index()), *(out.index()));	out.clean();}//! invert all elementstemplate<class T> void idx_inv(spIdx<T> &inp, spIdx<T> &out){	spidx_checkdims(inp, out);	out.set_nelements(inp.nelements());	out.values()->resize(inp.values()->dim(0));	idx_inv(*(inp.values()), *(out.values()));	out.index()->resize(inp.index()->dim(0), inp.index()->dim(1));	idx_copy(*(inp.index()), *(out.index()));	out.clean();}//! add two spIdxtemplate<class T> void idx_add(spIdx<T> &i1, spIdx<T> &i2, spIdx<T> &out){	spidx_checkdims(i1, i2);	spidx_checkdims(i1, out);	spIdx<T> out2(0, out.order(), out.dims());	T *i1ptr = i1.values()->idx_ptr(), *i2ptr =  i2.values()->idx_ptr();	intg i1mod = i1.values()->mod(0), i2mod = i2.values()->mod(0);	int bla[i2.nelements()];	for(intg i = 0; i<i2.nelements(); i++) bla[i] = 0;	for(intg i = 0; i<i1.nelements(); i++){		intg s0, s1, s2, s3, s4, s5, s6, s7;		i1.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7);		intg pos = i2.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7);		if( pos != -1){			out2.set(*(i1ptr + i * i1mod) + *(i2ptr + pos * i2mod), s0, s1, s2, s3, s4, s5, s6, s7);			bla[pos] = 1;		} else {			out2.set(*(i1ptr + i * i1mod), s0, s1, s2, s3, s4, s5, s6, s7);		}	}	for(intg i = 0; i<i2.nelements(); i++){		if(bla[i] == 0){			intg s0, s1, s2, s3, s4, s5, s6, s7;			i2.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7);			out2.set(*(i2ptr + i * i2mod), s0, s1, s2, s3, s4, s5, s6, s7);		}	}	idx_copy(out2, out);}//! subtract two spIdxtemplate<class T> void idx_sub(spIdx<T> &i1, spIdx<T> &i2, spIdx<T> &out){	spidx_checkdims(i1, i2);	spidx_checkdims(i1, out);	spIdx<T> out2(0, out.order(), out.dims());	T *i1ptr = i1.values()->idx_ptr(), *i2ptr =  i2.values()->idx_ptr();	intg i1mod = i1.values()->mod(0), i2mod = i2.values()->mod(0);	int bla[i2.nelements()];	for(intg i = 0; i<i2.nelements(); i++) bla[i] = 0;	for(intg i = 0; i<i1.nelements(); i++){		intg s0, s1, s2, s3, s4, s5, s6, s7;		i1.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7);		intg pos = i2.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7);		if( pos != -1){			out2.set(*(i1ptr + i * i1mod)-*(i2ptr + pos * i2mod), s0, s1, s2, s3, s4, s5, s6, s7);			bla[pos] = 1;		} else {			out2.set(*(i1ptr + i * i1mod), s0, s1, s2, s3, s4, s5, s6, s7);		}	}	for(intg i = 0; i<i2.nelements(); i++){		if(bla[i] == 0){			intg s0, s1, s2, s3, s4, s5, s6, s7;			i2.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7);			out2.set(- *(i2ptr + i * i2mod), s0, s1, s2, s3, s4, s5, s6, s7);		}	}	idx_copy(out2, out);}//! multiply two spIdxtemplate<class T> void idx_mul(spIdx<T> &i1, spIdx<T> &i2, spIdx<T> &out){	spidx_checkdims(i1, i2);	spidx_checkdims(i1, out);	spIdx<T> out2(0, out.order(), out.dims());	T *i1ptr = i1.values()->idx_ptr(), *i2ptr =  i2.values()->idx_ptr();	intg i1mod = i1.values()->mod(0), i2mod = i2.values()->mod(0);	for(intg i = 0; i<i1.nelements(); i++){		intg s0, s1, s2, s3, s4, s5, s6, s7;		i1.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7);		intg pos = i2.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7);		if( pos != -1) out2.set(*(i1ptr + i * i1mod)* *(i2ptr + pos * i2mod), s0, s1, s2, s3, s4, s5, s6, s7);	}	idx_copy(out2, out);}//! add a constant to each element:  o1 <- i1+c;template<class T> void idx_addc(spIdx<T> &inp, T c, spIdx<T> &out){	idx_copy(inp, out);	idx_addc(*(inp.values()), c, *(out.values()));	out.clean();}//! add a constant to each element and accumulate//! result: o1 <- o1 + i1+c;template<class T> void idx_addcacc(spIdx<T> &inp, T c, spIdx<T> &out){	spidx_checkdims(inp, out);	for(intg i = 0; i<inp.nelements(); i++){		intg s0, s1, s2, s3, s4, s5, s6, s7;		inp.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7);		intg pos = out.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7);		if( pos != -1){			out.values()->set(out.values()->get(pos) + c + inp.values()->get(i), pos);		} else {			out.set(c + inp.values()->get(i), s0, s1, s2, s3, s4, s5, s6, s7);		}	}}//! multiply all elements by a constant:  o1 <- i1*c;template<class T> void idx_dotc(spIdx<T> &inp, T c, spIdx<T> &out){	idx_copy(inp, out);	idx_dotc(*(inp.values()), c, *(out.values()));	out.clean();}//! multiply all elements by a constant and accumulate//! result: o1 <- o1 + i1*c;template<class T> void idx_dotcacc(spIdx<T> &inp, T c, spIdx<T> &out){	spidx_checkdims(inp, out);	for(intg i = 0; i<inp.nelements(); i++){		intg s0, s1, s2, s3, s4, s5, s6, s7;		inp.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7);		intg pos = out.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7);		if( pos != -1){			out.values()->set(out.values()->get(pos) + c*inp.values()->get(i), pos);		} else {			out.set(c * inp.values()->get(i), s0, s1, s2, s3, s4, s5, s6, s7);		}	}}//! square of difference of each term:  out <- (i1-i2)^2template<class T> void idx_subsquare(spIdx<T> &i1, spIdx<T> &i2, spIdx<T> &out){	spidx_checkdims(i1, i2);	spidx_checkdims(i1, out);	spIdx<T> out2(0, out.order(), out.dims());	int bla[i2.nelements()];	for(intg i = 0; i<i2.nelements(); i++) bla[i] = 0;	for(intg i = 0; i<i1.nelements(); i++){		intg s0, s1, s2, s3, s4, s5, s6, s7;		i1.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7);		intg pos = i2.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7);		if( pos != -1){			T blabla= i1.values()->get(i)-i2.values()->get(pos);			out2.set(blabla*blabla, s0, s1, s2, s3, s4, s5, s6, s7);			bla[pos] = 1;		} else {			T blabla = i1.values()->get(i);			out2.set(blabla*blabla, s0, s1, s2, s3, s4, s5, s6, s7);		}	}	for(intg i = 0; i<i2.nelements(); i++){		if(bla[i] == 0){			intg s0, s1, s2, s3, s4, s5, s6, s7;			i2.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7);			T blabla = i2.values()->get(i);			out2.set(blabla*blabla, s0, s1, s2, s3, s4, s5, s6, s7);		}	}	idx_copy(out2, out);}//! compute linear combination of two Idxtemplate<class T> void idx_lincomb(spIdx<T> &i1, T k1, spIdx<T> &i2, T k2, spIdx<T> &out){	spidx_checkdims(i1, i2);	spidx_checkdims(i1, out);	spIdx<T> out2(0, out.order(), out.dims());	int bla[i2.nelements()];	for(intg i = 0; i<i2.nelements(); i++) bla[i] = 0;	if(k1 != BACKGROUND){		for(intg i = 0; i<i1.nelements(); i++){			intg s0, s1, s2, s3, s4, s5, s6, s7;			i1.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7);			intg pos = i2.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7);			if( pos != -1){				out2.set(k1 * i1.values()->get(i) + k2 * i2.values()->get(pos), s0, s1, s2, s3, s4, s5, s6, s7);				bla[pos] = 1;			} else {				out2.set(k1 * i1.values()->get(i), s0, s1, s2, s3, s4, s5, s6, s7);			}		}	}	if(k2 != BACKGROUND){		for(intg i = 0; i<i2.nelements(); i++){			if(bla[i] == 0){				intg s0, s1, s2, s3, s4, s5, s6, s7;				i2.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7);				out2.set(k2 * i2.values()->get(i), s0, s1, s2, s3, s4, s5, s6, s7);			}		}	}

⌨️ 快捷键说明

复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?