📄 distiterativesolvertest.java
字号:
public DistIterativeSolver(int rank, Vector.Norm norm) {
    this.rank = rank;
    this.norm = norm;
}

/**
 * Creates the (still empty) distributed system matrix for this rank.
 *
 * @param comm communicator created for this rank in {@link #run()}
 * @return the distributed matrix that {@link #populateMatrix(Matrix)} will fill
 */
protected abstract Matrix createMatrix(Communicator comm);

/**
 * Returns the row ownership table of {@code A}: {@link #run()} treats rank {@code r} as
 * owning rows {@code n[r]} (inclusive) to {@code n[r + 1]} (exclusive).
 */
protected abstract int[] getRowOwnerships(Matrix A);

/**
 * Returns the column ownership table of {@code A}, indexed the same way as
 * {@link #getRowOwnerships(Matrix)}.
 */
protected abstract int[] getColumnOwnerships(Matrix A);

/** Writes the entries of {@code A} that this rank is responsible for. */
protected abstract void populateMatrix(Matrix A);

/** Returns entry {@code i} of the right-hand side vector. */
protected abstract double getVectorEntry(int i);

/**
 * Creates the iterative solver under test.
 *
 * @param x template vector handed to the solver's constructor ({@link #run()} passes the
 *     distributed right-hand side)
 */
protected abstract IterativeSolver createSolver(Vector x);

/**
 * Builds this rank's part of the distributed system, solves it, and copies the locally
 * owned slice of the solution into the shared {@code output} array.
 */
public void run() {
    Communicator comm = coll.createCommunicator(rank);

    Matrix A = createMatrix(comm);
    populateMatrix(A);

    DenseVector bl = new DenseVector(localLength[rank]);
    DistVector b_dist = new DistVector(x.size(), comm, bl);
    // Copied before b_dist is filled in below, so the initial guess does not contain the
    // right-hand-side values (presumably all zeros; assumes DistVector.copy() is deep).
    DistVector x_dist = b_dist.copy();

    int[] n = getRowOwnerships(A);
    for (int i = n[rank]; i < n[rank + 1]; ++i) {
        b_dist.set(i, getVectorEntry(i));
    }

    IterativeSolver solver = createSolver(b_dist);

    // NOTE(review): names follow MTJ's DefaultIterationMonitor(maxIter, rtol, atol, dtol)
    // constructor; confirm against the MTJ version in use.
    final int maxIterations = 1000;
    final double relativeTolerance = 1e-50;
    final double absoluteTolerance = 1e-12;
    final double divergenceTolerance = 1e+5;
    IterationMonitor monitor = new DefaultIterationMonitor(
            maxIterations, relativeTolerance, absoluteTolerance, divergenceTolerance);
    monitor.setNormType(norm);
    solver.setIterationMonitor(monitor);

    try {
        solver.solve(A, b_dist, x_dist);
    } catch (IterativeSolverNotConvergedException ignored) {
        // Deliberately tolerated: the unconverged x_dist is still copied into output
        // below, and the wrong result will just lead to an error later on.
    }

    for (int i = n[rank]; i < n[rank + 1]; ++i) {
        output[i] = x_dist.get(i);
    }
}
}

/** Base for solvers whose system matrix is a row-distributed {@code DistRowMatrix}. */
private abstract class RowDistIterativeSolver extends DistIterativeSolver {

    public RowDistIterativeSolver(int rank, Vector.Norm norm) {
        super(rank, norm);
    }

    /**
     * Creates an n-by-n {@code DistRowMatrix} (n = {@code x.size()}) from a square local
     * block of size {@code localLength[rank]} and a {@code localLength[rank]}-by-n block.
     */
    @Override
    protected Matrix createMatrix(Communicator comm) {
        int n = x.size();
        Matrix Al = new DenseMatrix(localLength[rank], localLength[rank]);
        Matrix Bl = new DenseMatrix(localLength[rank], n);
        return new DistRowMatrix(n, n, comm, Al, Bl);
    }

    @Override
    protected int[] getColumnOwnerships(Matrix A) {
        return ((DistRowMatrix) A).getColumnOwnerships();
    }

    @Override
    protected int[] getRowOwnerships(Matrix A) {
        return ((DistRowMatrix) A).getRowOwnerships();
    }
}

/** Base for solvers whose system matrix is a column-distributed {@code DistColMatrix}. */
private abstract class ColumnDistIterativeSolver extends DistIterativeSolver {

    public ColumnDistIterativeSolver(int rank, Vector.Norm norm) {
        super(rank, norm);
    }

    /**
     * Creates an n-by-n {@code DistColMatrix} (n = {@code x.size()}) from a square local
     * block of size {@code localLength[rank]} and an n-by-{@code localLength[rank]} block.
     */
    @Override
    protected Matrix createMatrix(Communicator comm) {
        int n = x.size();
        Matrix Al = new DenseMatrix(localLength[rank], localLength[rank]);
        Matrix Bl = new DenseMatrix(n, localLength[rank]);
        return new DistColMatrix(n, n, comm, Al, Bl);
    }

    @Override
    protected int[] getColumnOwnerships(Matrix A) {
        return ((DistColMatrix) A).getColumnOwnerships();
    }

    @Override
    protected int[] getRowOwnerships(Matrix A) {
        return ((DistColMatrix) A).getRowOwnerships();
    }
}

/**
 * Row-distributed solver that copies its owned rows from {@code A_symm} and reads the
 * right-hand side from {@code b_symm} (presumably the symmetric test system; it is the
 * one paired with CG below).
 */
private abstract class SymmRowDistIterativeSolver extends RowDistIterativeSolver {

    public SymmRowDistIterativeSolver(int rank, Vector.Norm norm) {
        super(rank, norm);
    }

    /** Copies every column of the rows this rank owns from {@code A_symm}. */
    @Override
    protected void populateMatrix(Matrix A) {
        int[] n = getRowOwnerships(A);
        for (int i = n[rank]; i < n[rank + 1]; ++i) {
            for (int j = 0; j < A.numColumns(); ++j) {
                A.set(i, j, A_symm.get(i, j));
            }
        }
    }

    @Override
    protected double getVectorEntry(int i) {
        return b_symm.get(i);
    }
}

/**
 * Row-distributed solver that copies its owned rows from {@code A_unsymm} and reads the
 * right-hand side from {@code b_unsymm}.
 */
private abstract class UnSymmRowDistIterativeSolver extends RowDistIterativeSolver {

    public UnSymmRowDistIterativeSolver(int rank, Vector.Norm norm) {
        super(rank, norm);
    }

    /** Copies every column of the rows this rank owns from {@code A_unsymm}. */
    @Override
    protected void populateMatrix(Matrix A) {
        int[] n = getRowOwnerships(A);
        for (int i = n[rank]; i < n[rank + 1]; ++i) {
            for (int j = 0; j < A.numColumns(); ++j) {
                A.set(i, j, A_unsymm.get(i, j));
            }
        }
    }

    @Override
    protected double getVectorEntry(int i) {
        return b_unsymm.get(i);
    }
}

/**
 * Column-distributed solver that copies its owned columns from {@code A_symm} and reads
 * the right-hand side from {@code b_symm}.
 */
private abstract class SymmColumnDistIterativeSolver extends ColumnDistIterativeSolver {

    public SymmColumnDistIterativeSolver(int rank, Vector.Norm norm) {
        super(rank, norm);
    }

    /** Copies every row of the columns this rank owns from {@code A_symm}. */
    @Override
    protected void populateMatrix(Matrix A) {
        int[] m = getColumnOwnerships(A);
        for (int i = 0; i < A.numRows(); ++i) {
            for (int j = m[rank]; j < m[rank + 1]; ++j) {
                A.set(i, j, A_symm.get(i, j));
            }
        }
    }

    @Override
    protected double getVectorEntry(int i) {
        return b_symm.get(i);
    }
}

/**
 * Column-distributed solver that copies its owned columns from {@code A_unsymm} and reads
 * the right-hand side from {@code b_unsymm}.
 */
private abstract class UnSymmColumnDistIterativeSolver extends ColumnDistIterativeSolver {

    public UnSymmColumnDistIterativeSolver(int rank, Vector.Norm norm) {
        super(rank, norm);
    }

    /** Copies every row of the columns this rank owns from {@code A_unsymm}. */
    @Override
    protected void populateMatrix(Matrix A) {
        int[] m = getColumnOwnerships(A);
        for (int i = 0; i < A.numRows(); ++i) {
            for (int j = m[rank]; j < m[rank + 1]; ++j) {
                A.set(i, j, A_unsymm.get(i, j));
            }
        }
    }

    @Override
    protected double getVectorEntry(int i) {
        return b_unsymm.get(i);
    }
}

/** GMRES on the row-distributed unsymmetric system. */
private class GMRESRowDistIterativeSolver extends UnSymmRowDistIterativeSolver {

    public GMRESRowDistIterativeSolver(int rank, Vector.Norm norm) {
        super(rank, norm);
    }

    @Override
    protected IterativeSolver createSolver(Vector x) {
        return new GMRES(x);
    }
}

/** BiCGstab on the row-distributed unsymmetric system. */
private class BiCGstabRowDistIterativeSolver extends UnSymmRowDistIterativeSolver {

    public BiCGstabRowDistIterativeSolver(int rank, Vector.Norm norm) {
        super(rank, norm);
    }

    @Override
    protected IterativeSolver createSolver(Vector x) {
        return new BiCGstab(x);
    }
}

/** GMRES on the column-distributed unsymmetric system. */
private class GMRESColumnDistIterativeSolver extends UnSymmColumnDistIterativeSolver {

    public GMRESColumnDistIterativeSolver(int rank, Vector.Norm norm) {
        super(rank, norm);
    }

    @Override
    protected IterativeSolver createSolver(Vector x) {
        return new GMRES(x);
    }
}

/** BiCGstab on the column-distributed unsymmetric system. */
private class BiCGstabColumnDistIterativeSolver extends UnSymmColumnDistIterativeSolver {

    public BiCGstabColumnDistIterativeSolver(int rank, Vector.Norm norm) {
        super(rank, norm);
    }

    @Override
    protected IterativeSolver createSolver(Vector x) {
        return new BiCGstab(x);
    }
}

/** CG on the row-distributed {@code A_symm} system. */
private class CGRowDistIterativeSolver extends SymmRowDistIterativeSolver {

    public CGRowDistIterativeSolver(int rank, Vector.Norm norm) {
        super(rank, norm);
    }

    @Override
    protected IterativeSolver createSolver(Vector x) {
        return new CG(x);
    }
}

/** CG on the column-distributed {@code A_symm} system. */
private class CGColumnDistIterativeSolver extends SymmColumnDistIterativeSolver {

    public CGColumnDistIterativeSolver(int rank, Vector.Norm norm) {
        super(rank, norm);
    }

    @Override
    protected IterativeSolver createSolver(Vector x) {
        return new CG(x);
    }
}
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -