spblas.hpp
来自「Gaussian Mixture Algorithm」· HPP 代码 · 共 616 行 · 第 1/2 页
HPP
616 行
/* * spBlas.hpp * * Author: cyril Poulet */#ifndef SPBLAS_HPP_#define SPBLAS_HPP_namespace ebl {#define spidx_checkdims(src0,src1) \ if((src0).order() != (src1).order()) {ylerror("not the same number of dimensions"); return;} \ for(int i = 0; i< (src0).order(); i++) \ if((src0).dim(i) != (src1).dim(i)) {ylerror("not the same size of dimensions"); return;}#define spidx_checkdims_rint(src0,src1) \ if((src0).order() != (src1).order()) {ylerror("not the same number of dimensions"); return 0;} \ for(int i = 0; i< (src0).order(); i++) \ if((src0).dim(i) != (src1).dim(i)) {ylerror("not the same size of dimensions"); return 0;}template<class T1, class T2> void idx_copy(spIdx<T1> &src, spIdx<T2> &dst){ spidx_checkdims(src, dst); dst.set_nelements(src.nelements()); dst.values()->resize(src.values()->dim(0)); idx_copy(*(src.values()), *(dst.values())); dst.index()->resize(src.index()->dim(0), src.index()->dim(1)); idx_copy(*(src.index()), *(dst.index()));}template<class T1, class T2> void idx_copy(spIdx<T1> &src, Idx<T2> &dst){ spidx_checkdims(src, dst); int N = src.order(); idx_fill(dst, (T2)BACKGROUND); for(intg i = 0; i < src.nelements(); i++){ switch(N){ case 1: dst.set((T2)src.values()->get(i), src.index()->get(i,0)); break; case 2: dst.set((T2)src.values()->get(i), src.index()->get(i, 0), src.index()->get(i, 1)); break; case 3 : dst.set((T2)src.values()->get(i), src.index()->get(i, 0), src.index()->get(i, 1), src.index()->get(i, 2)); break; default : dst.set((T2)src.values()->get(i), src.index()->get(i, 0), src.index()->get(i, 1), src.index()->get(i, 2), (N >=4 )? src.index()->get(i, 3) : -1, (N >=5 )? src.index()->get(i, 4) : -1, (N >=6 )? src.index()->get(i, 5) : -1, (N >=7 )? src.index()->get(i, 6) : -1, (N >=8 )? src.index()->get(i, 7) : -1); break; } }}template<class T1, class T2> void idx_copy(Idx<T1> &src, spIdx<T2> &dst){ spidx_checkdims(src, dst); int N = src.order(); dst.set_nelements(0); dst.set_values(new Idx<T2>(1)); dst.set_index(new Idx<intg>(1, N)); T1 *data = src.idx_ptr(); intg mod[8] = {(N >= 1)? src.mod(0) : 0, (N >= 2)? src.mod(1) : 0, (N >= 3)? src.mod(2) : 0, (N >= 4)? src.mod(3) : 0, (N >= 5)? src.mod(4) : 0, (N >= 6)? src.mod(5) : 0, (N >= 7)? src.mod(6) : 0, (N >= 8)? src.mod(7) : 0}; intg ind[8] = {(N >= 1)? 0 : -1, (N >= 2)? 0 : -1, (N >= 3)? 0 : -1, (N >= 4)? 0 : -1, (N >= 5)? 0 : -1, (N >= 6)? 0 : -1, (N >= 7)? 0 : -1, (N >= 8)? 0 : -1}; intg ind2[8]; for(int i = 0; i<8; i++) ind2[i] = ind[i]; for(intg i = 0; i< src.nelements(); i++){ if(data[i] != BACKGROUND) dst.set((T2)data[i], ind[0], ind[1],ind[2],ind[3],ind[4],ind[5],ind[6],ind[7]); ind[7]++; for(int j = 7; j > 0; j--) if(ind[j] > ((N >= j+1)? (mod[j-1] - 1) : -1)) { ind[j] = ind2[j]; ind[j-1]++;} }}////////////////////////////////////////////////////////////////template<class T> void idx_clear(spIdx<T> &inp){ inp.set_nelements(0); inp.set_values(new Idx<T>(1)); inp.set_index(new Idx<intg>(1, inp.order()));}//! negate all elementstemplate<class T> void idx_minus(spIdx<T> &inp, spIdx<T> &out){ spidx_checkdims(inp, out); out.set_nelements(inp.nelements()); out.values()->resize(inp.values()->dim(0)); idx_minus(*(inp.values()), *(out.values())); out.index()->resize(inp.index()->dim(0), inp.index()->dim(1)); idx_copy(*(inp.index()), *(out.index())); out.clean();}//! invert all elementstemplate<class T> void idx_inv(spIdx<T> &inp, spIdx<T> &out){ spidx_checkdims(inp, out); out.set_nelements(inp.nelements()); out.values()->resize(inp.values()->dim(0)); idx_inv(*(inp.values()), *(out.values())); out.index()->resize(inp.index()->dim(0), inp.index()->dim(1)); idx_copy(*(inp.index()), *(out.index())); out.clean();}//! add two spIdxtemplate<class T> void idx_add(spIdx<T> &i1, spIdx<T> &i2, spIdx<T> &out){ spidx_checkdims(i1, i2); spidx_checkdims(i1, out); spIdx<T> out2(0, out.order(), out.dims()); T *i1ptr = i1.values()->idx_ptr(), *i2ptr = i2.values()->idx_ptr(); intg i1mod = i1.values()->mod(0), i2mod = i2.values()->mod(0); int bla[i2.nelements()]; for(intg i = 0; i<i2.nelements(); i++) bla[i] = 0; for(intg i = 0; i<i1.nelements(); i++){ intg s0, s1, s2, s3, s4, s5, s6, s7; i1.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7); intg pos = i2.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7); if( pos != -1){ out2.set(*(i1ptr + i * i1mod) + *(i2ptr + pos * i2mod), s0, s1, s2, s3, s4, s5, s6, s7); bla[pos] = 1; } else { out2.set(*(i1ptr + i * i1mod), s0, s1, s2, s3, s4, s5, s6, s7); } } for(intg i = 0; i<i2.nelements(); i++){ if(bla[i] == 0){ intg s0, s1, s2, s3, s4, s5, s6, s7; i2.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7); out2.set(*(i2ptr + i * i2mod), s0, s1, s2, s3, s4, s5, s6, s7); } } idx_copy(out2, out);}//! subtract two spIdxtemplate<class T> void idx_sub(spIdx<T> &i1, spIdx<T> &i2, spIdx<T> &out){ spidx_checkdims(i1, i2); spidx_checkdims(i1, out); spIdx<T> out2(0, out.order(), out.dims()); T *i1ptr = i1.values()->idx_ptr(), *i2ptr = i2.values()->idx_ptr(); intg i1mod = i1.values()->mod(0), i2mod = i2.values()->mod(0); int bla[i2.nelements()]; for(intg i = 0; i<i2.nelements(); i++) bla[i] = 0; for(intg i = 0; i<i1.nelements(); i++){ intg s0, s1, s2, s3, s4, s5, s6, s7; i1.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7); intg pos = i2.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7); if( pos != -1){ out2.set(*(i1ptr + i * i1mod)-*(i2ptr + pos * i2mod), s0, s1, s2, s3, s4, s5, s6, s7); bla[pos] = 1; } else { out2.set(*(i1ptr + i * i1mod), s0, s1, s2, s3, s4, s5, s6, s7); } } for(intg i = 0; i<i2.nelements(); i++){ if(bla[i] == 0){ intg s0, s1, s2, s3, s4, s5, s6, s7; i2.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7); out2.set(- *(i2ptr + i * i2mod), s0, s1, s2, s3, s4, s5, s6, s7); } } idx_copy(out2, out);}//! multiply two spIdxtemplate<class T> void idx_mul(spIdx<T> &i1, spIdx<T> &i2, spIdx<T> &out){ spidx_checkdims(i1, i2); spidx_checkdims(i1, out); spIdx<T> out2(0, out.order(), out.dims()); T *i1ptr = i1.values()->idx_ptr(), *i2ptr = i2.values()->idx_ptr(); intg i1mod = i1.values()->mod(0), i2mod = i2.values()->mod(0); for(intg i = 0; i<i1.nelements(); i++){ intg s0, s1, s2, s3, s4, s5, s6, s7; i1.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7); intg pos = i2.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7); if( pos != -1) out2.set(*(i1ptr + i * i1mod)* *(i2ptr + pos * i2mod), s0, s1, s2, s3, s4, s5, s6, s7); } idx_copy(out2, out);}//! add a constant to each element: o1 <- i1+c;template<class T> void idx_addc(spIdx<T> &inp, T c, spIdx<T> &out){ idx_copy(inp, out); idx_addc(*(inp.values()), c, *(out.values())); out.clean();}//! add a constant to each element and accumulate//! result: o1 <- o1 + i1+c;template<class T> void idx_addcacc(spIdx<T> &inp, T c, spIdx<T> &out){ spidx_checkdims(inp, out); for(intg i = 0; i<inp.nelements(); i++){ intg s0, s1, s2, s3, s4, s5, s6, s7; inp.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7); intg pos = out.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7); if( pos != -1){ out.values()->set(out.values()->get(pos) + c + inp.values()->get(i), pos); } else { out.set(c + inp.values()->get(i), s0, s1, s2, s3, s4, s5, s6, s7); } }}//! multiply all elements by a constant: o1 <- i1*c;template<class T> void idx_dotc(spIdx<T> &inp, T c, spIdx<T> &out){ idx_copy(inp, out); idx_dotc(*(inp.values()), c, *(out.values())); out.clean();}//! multiply all elements by a constant and accumulate//! result: o1 <- o1 + i1*c;template<class T> void idx_dotcacc(spIdx<T> &inp, T c, spIdx<T> &out){ spidx_checkdims(inp, out); for(intg i = 0; i<inp.nelements(); i++){ intg s0, s1, s2, s3, s4, s5, s6, s7; inp.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7); intg pos = out.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7); if( pos != -1){ out.values()->set(out.values()->get(pos) + c*inp.values()->get(i), pos); } else { out.set(c * inp.values()->get(i), s0, s1, s2, s3, s4, s5, s6, s7); } }}//! square of difference of each term: out <- (i1-i2)^2template<class T> void idx_subsquare(spIdx<T> &i1, spIdx<T> &i2, spIdx<T> &out){ spidx_checkdims(i1, i2); spidx_checkdims(i1, out); spIdx<T> out2(0, out.order(), out.dims()); int bla[i2.nelements()]; for(intg i = 0; i<i2.nelements(); i++) bla[i] = 0; for(intg i = 0; i<i1.nelements(); i++){ intg s0, s1, s2, s3, s4, s5, s6, s7; i1.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7); intg pos = i2.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7); if( pos != -1){ T blabla= i1.values()->get(i)-i2.values()->get(pos); out2.set(blabla*blabla, s0, s1, s2, s3, s4, s5, s6, s7); bla[pos] = 1; } else { T blabla = i1.values()->get(i); out2.set(blabla*blabla, s0, s1, s2, s3, s4, s5, s6, s7); } } for(intg i = 0; i<i2.nelements(); i++){ if(bla[i] == 0){ intg s0, s1, s2, s3, s4, s5, s6, s7; i2.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7); T blabla = i2.values()->get(i); out2.set(blabla*blabla, s0, s1, s2, s3, s4, s5, s6, s7); } } idx_copy(out2, out);}//! compute linear combination of two Idxtemplate<class T> void idx_lincomb(spIdx<T> &i1, T k1, spIdx<T> &i2, T k2, spIdx<T> &out){ spidx_checkdims(i1, i2); spidx_checkdims(i1, out); spIdx<T> out2(0, out.order(), out.dims()); int bla[i2.nelements()]; for(intg i = 0; i<i2.nelements(); i++) bla[i] = 0; if(k1 != BACKGROUND){ for(intg i = 0; i<i1.nelements(); i++){ intg s0, s1, s2, s3, s4, s5, s6, s7; i1.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7); intg pos = i2.pos_to_index(s0, s1, s2, s3, s4, s5, s6, s7); if( pos != -1){ out2.set(k1 * i1.values()->get(i) + k2 * i2.values()->get(pos), s0, s1, s2, s3, s4, s5, s6, s7); bla[pos] = 1; } else { out2.set(k1 * i1.values()->get(i), s0, s1, s2, s3, s4, s5, s6, s7); } } } if(k2 != BACKGROUND){ for(intg i = 0; i<i2.nelements(); i++){ if(bla[i] == 0){ intg s0, s1, s2, s3, s4, s5, s6, s7; i2.index_to_pos(i, s0, s1, s2, s3, s4, s5, s6, s7); out2.set(k2 * i2.values()->get(i), s0, s1, s2, s3, s4, s5, s6, s7); } } }
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?