📄 distrect.cc
字号:
delete_vecform(); a->delete_vecform();}voidDistSCMatrix::create_vecform(Form f, int nvectors){ // determine with rows/cols go on this node form = f; int nproc = messagegrp()->n(); int me = messagegrp()->me(); int n1=0, n2=0; if (form == Row) { n1 = nrow(); n2 = ncol(); } if (form == Col) { n1 = ncol(); n2 = nrow(); } if (nvectors == -1) { nvec = n1/nproc; vecoff = nvec*me; int nremain = n1%nproc; if (me < nremain) { vecoff += me; nvec++; } else { vecoff += nremain; } } else { nvec = nvectors; vecoff = 0; } // allocate storage vec = new double*[nvec]; vec[0] = new double[nvec*n2]; int i; for (i=1; i<nvec; i++) { vec[i] = &vec[0][i*n2]; }}voidDistSCMatrix::vecform_zero(){ int n=0; if (form == Row) { n = nvec * ncol(); } if (form == Col) { n = nvec * nrow(); } double *dat = vec[0]; for (int i=0; i<n; i++) dat[i] = 0.0;}voidDistSCMatrix::delete_vecform(){ delete[] vec[0]; delete[] vec; vec = 0; nvec = 0;}voidDistSCMatrix::vecform_op(VecOp op, int *ivec){ Ref<SCMatrixSubblockIter> i; if (op == CopyToVec || op == AccumToVec) { i = all_blocks(SCMatrixSubblockIter::Read); } else { if (op == CopyFromVec) assign(0.0); i = all_blocks(SCMatrixSubblockIter::Accum); } for (i->begin(); i->ready(); i->next()) { Ref<SCMatrixRectBlock> b = dynamic_cast<SCMatrixRectBlock*>(i->block()); if (DEBUG) ExEnv::outn() << messagegrp()->me() << ": " << "got block " << b->blocki << ' ' << b->blockj << endl; int b1start, b2start, b1end, b2end; if (form == Row) { b1start = b->istart; b2start = b->jstart; b1end = b->iend; b2end = b->jend; } else { b1start = b->jstart; b2start = b->istart; b1end = b->jend; b2end = b->iend; } int nbj = b->jend - b->jstart; int start, end; if (ivec) { start = b1start; end = b1end; } else { start = b1start > vecoff ? b1start : vecoff; end = b1end > vecoff+nvec ? vecoff+nvec : b1end; } double *dat = b->data; int off = -b1start; for (int j=start; j<end; j++) { double *vecj; if (ivec) { vecj = 0; for (int ii=0; ii<nvec; ii++) { if (ivec[ii] == j) { vecj = vec[ii]; break; } } if (!vecj) continue; if (DEBUG) ExEnv::outn() << messagegrp()->me() << ": getting [" << j << "," << b2start << "-" << b2end << ")" << endl; } else { vecj = vec[j-vecoff]; } for (int k=b2start; k<b2end; k++) { int blockoffset; if (DEBUG) ExEnv::outn() << messagegrp()->me() << ": " << "using vec[" << j-vecoff << "]" << "[" << k << "]" << endl; if (form == Row) { blockoffset = (j+off)*nbj+k - b2start; if (DEBUG) ExEnv::outn() << messagegrp()->me() << ": " << "Row datum offset is " << "(" << j << "+" << off << ")*" << nbj << "+" << k << "-" << b2start << " = " << blockoffset << "(" << b->ndat() << ") " << " -> " << dat[blockoffset] << endl; } else { blockoffset = (k-b2start)*nbj+j+off; } if (blockoffset >= b->ndat()) { fail("bad offset"); } double *datum = &dat[blockoffset]; if (op == CopyToVec) { if (DEBUG) ExEnv::outn() << messagegrp()->me() << ": " << "copying " << *datum << " " << "to " << j << " " << k << endl; vecj[k] = *datum; } else if (op == CopyFromVec) { *datum = vecj[k]; } else if (op == AccumToVec) { vecj[k] += *datum; } else if (op == AccumFromVec) { *datum += vecj[k]; } } } }}// does the outer product a x b. this must have rowdim() == a->dim() and// coldim() == b->dim()voidDistSCMatrix::accumulate_outer_product(SCVector*a,SCVector*b){ const char* name = "DistSCMatrix::accumulate_outer_product"; // make sure that the arguments are of the correct type DistSCVector* la = require_dynamic_cast<DistSCVector*>(a,name); DistSCVector* lb = require_dynamic_cast<DistSCVector*>(b,name); // make sure that the dimensions match if (!rowdim()->equiv(la->dim()) || !coldim()->equiv(lb->dim())) { ExEnv::errn() << indent << "DistSCMatrix::accumulate_outer_product(SCVector*a,SCVector*b): " << "dimensions don't match\n"; abort(); } Ref<SCMatrixSubblockIter> I = a->all_blocks(SCMatrixSubblockIter::Read); Ref<SCMatrixSubblockIter> J = b->all_blocks(SCMatrixSubblockIter::Read); for (I->begin(); I->ready(); I->next()) { Ref<SCVectorSimpleBlock> vi = dynamic_cast<SCVectorSimpleBlock*>(I->block()); int ni = vi->iend - vi->istart; for (J->begin(); J->ready(); J->next()) { Ref<SCVectorSimpleBlock> vj = dynamic_cast<SCVectorSimpleBlock*>(J->block()); Ref<SCMatrixRectBlock> rij = dynamic_cast<SCMatrixRectBlock*>(block_to_block(vi->blocki, vj->blocki).pointer()); // if the block is held locally sum in the outer prod contrib if (rij.nonnull()) { int nj = vj->iend - vj->istart; double *dat = rij->data; for (int i=0; i<ni; i++) { for (int j=0; j<nj; j++) { *dat += vi->data[i]*vj->data[j]; } } } } }}voidDistSCMatrix::transpose_this(){ RefSCDimension tmp = d1; d1 = d2; d2 = tmp; Ref<SCMatrixBlockList> oldlist = blocklist; init_blocklist(); assign(0.0); Ref<SCMatrixSubblockIter> I = new DistSCMatrixListSubblockIter(SCMatrixSubblockIter::Read, oldlist, messagegrp()); for (I->begin(); I->ready(); I->next()) { Ref<SCMatrixRectBlock> remote = dynamic_cast<SCMatrixRectBlock*>(I->block()); Ref<SCMatrixRectBlock> local = dynamic_cast<SCMatrixRectBlock*>(block_to_block(remote->blockj, remote->blocki).pointer()); if (local.nonnull()) { int ni = local->iend - local->istart; int nj = local->jend - local->jstart; for (int i=0; i<ni; i++) { for (int j=0; j<nj; j++) { local->data[i*nj+j] = remote->data[j*ni+i]; } } } } }doubleDistSCMatrix::invert_this(){ if (nrow() != ncol()) { ExEnv::errn() << indent << "DistSCMatrix::invert_this: matrix is not square\n"; abort(); } RefSymmSCMatrix refs = kit()->symmmatrix(d1); refs->assign(0.0); refs->accumulate_symmetric_product(this); double determ2 = refs->invert_this(); transpose_this(); RefSCMatrix reft = copy(); assign(0.0); ((SCMatrix*)this)->accumulate_product(reft.pointer(), refs.pointer()); return sqrt(fabs(determ2));}voidDistSCMatrix::gen_invert_this(){ invert_this();}doubleDistSCMatrix::determ_this(){ if (nrow() != ncol()) { ExEnv::errn() << indent << "DistSCMatrix::determ_this: matrix is not square\n"; abort(); } return invert_this();}doubleDistSCMatrix::trace(){ if (nrow() != ncol()) { ExEnv::errn() << indent << "DistSCMatrix::trace: matrix is not square\n"; abort(); } double ret=0.0; Ref<SCMatrixSubblockIter> I = local_blocks(SCMatrixSubblockIter::Read); for (I->begin(); I->ready(); I->next()) { Ref<SCMatrixRectBlock> b = dynamic_cast<SCMatrixRectBlock*>(I->block()); if (b->blocki == b->blockj) { int ni = b->iend-b->istart; for (int i=0; i<ni; i++) { ret += b->data[i*ni+i]; } } } messagegrp()->sum(ret); return ret;}doubleDistSCMatrix::solve_this(SCVector*v){ error("no solve_this"); // make sure that the dimensions match if (!rowdim()->equiv(v->dim())) { ExEnv::errn() << indent << "DistSCMatrix::solve_this(SCVector*v): " << "dimensions don't match\n"; abort(); } return 0.0;}voidDistSCMatrix::schmidt_orthog(SymmSCMatrix *S, int nc){ error("no schmidt_orthog");}intDistSCMatrix::schmidt_orthog_tol(SymmSCMatrix *S, double tol, double *res){ error("no schmidt_orthog_tol"); return 0;}voidDistSCMatrix::element_op(const Ref<SCElementOp>& op){ SCMatrixBlockListIter i; for (i = blocklist->begin(); i != blocklist->end(); i++) {// ExEnv::outn() << "rect elemop processing a block of type "// << i.block()->class_name() << endl; op->process_base(i.block()); } if (op->has_collect()) op->collect(messagegrp());}voidDistSCMatrix::element_op(const Ref<SCElementOp2>& op, SCMatrix* m){ DistSCMatrix *lm = require_dynamic_cast<DistSCMatrix*>(m,"DistSCMatrix::element_op"); if (!rowdim()->equiv(lm->rowdim()) || !coldim()->equiv(lm->coldim())) { ExEnv::errn() << indent << "DistSCMatrix: bad element_op\n"; abort(); } SCMatrixBlockListIter i, j; for (i = blocklist->begin(), j = lm->blocklist->begin(); i != blocklist->end(); i++, j++) { op->process_base(i.block(), j.block()); } if (op->has_collect()) op->collect(messagegrp());}voidDistSCMatrix::element_op(const Ref<SCElementOp3>& op, SCMatrix* m,SCMatrix* n){ DistSCMatrix *lm = require_dynamic_cast<DistSCMatrix*>(m,"DistSCMatrix::element_op"); DistSCMatrix *ln = require_dynamic_cast<DistSCMatrix*>(n,"DistSCMatrix::element_op"); if (!rowdim()->equiv(lm->rowdim()) || !coldim()->equiv(lm->coldim()) || !rowdim()->equiv(ln->rowdim()) || !coldim()->equiv(ln->coldim())) { ExEnv::errn() << indent << "DistSCMatrix: bad element_op\n"; abort(); } SCMatrixBlockListIter i, j, k; for (i = blocklist->begin(), j = lm->blocklist->begin(), k = ln->blocklist->begin(); i != blocklist->end(); i++, j++, k++) { op->process_base(i.block(), j.block(), k.block()); } if (op->has_collect()) op->collect(messagegrp());}voidDistSCMatrix::vprint(const char *title, ostream& os, int prec) const{ // cast so the non const vprint member can be called ((DistSCMatrix*)this)->vprint(title,os,prec);}voidDistSCMatrix::vprint(const char *title, ostream& os, int prec){ int i,j; int lwidth; double max=this->maxabs(); int me = messagegrp()->me(); max = (max==0.0) ? 1.0 : log10(max); if (max < 0.0) max=1.0; lwidth = prec+5+(int) max; os.setf(ios::fixed,ios::floatfield); os.precision(prec); os.setf(ios::right,ios::adjustfield); if (messagegrp()->me() == 0) { if (title) os << endl << indent << title << endl; else os << endl; } if (nrow()==0 || ncol()==0) { if (me == 0) os << indent << "empty matrix\n"; return; } create_vecform(Row); vecform_op(CopyToVec); int nc = ncol(); int tmp = 0; if (me != 0) { messagegrp()->recv(me-1, tmp); } else { os << indent; for (i=0; i<nc; i++) os << setw(lwidth) << i; os << endl; } for (i=0; i<nvec; i++) { os << indent << setw(5) << i+vecoff; for (j=0; j<nc; j++) os << setw(lwidth) << vec[i][j]; os << endl; } if (messagegrp()->n() > 1) { // send the go ahead to the next node int dest = me+1; if (dest == messagegrp()->n()) dest = 0; messagegrp()->send(dest, tmp); // make node zero wait on the last node if (me == 0) messagegrp()->recv(messagegrp()->n()-1, tmp); } delete_vecform();}Ref<SCMatrixSubblockIter>DistSCMatrix::local_blocks(SCMatrixSubblockIter::Access access){ return new SCMatrixListSubblockIter(access, blocklist);}Ref<SCMatrixSubblockIter>DistSCMatrix::all_blocks(SCMatrixSubblockIter::Access access){ return new DistSCMatrixListSubblockIter(access, blocklist, messagegrp());}voidDistSCMatrix::error(const char *msg){ ExEnv::errn() << "DistSCMatrix: error: " << msg << endl;}Ref<DistSCMatrixKit>DistSCMatrix::skit(){ return dynamic_cast<DistSCMatrixKit*>(kit().pointer());}/////////////////////////////////////////////////////////////////////////////// Local Variables:// mode: c++// c-file-style: "CLJ"// End:
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -