📄 distiterativesolvertest.java
字号:
package no.uib.cipr.matrix.distributed;import java.util.Arrays;import java.util.concurrent.Executor;import java.util.concurrent.ExecutorService;import java.util.concurrent.Executors;import java.util.concurrent.TimeUnit;import junit.framework.TestCase;import no.uib.cipr.matrix.DenseMatrix;import no.uib.cipr.matrix.DenseVector;import no.uib.cipr.matrix.Matrix;import no.uib.cipr.matrix.UpperSymmDenseMatrix;import no.uib.cipr.matrix.Vector;import no.uib.cipr.matrix.distributed.CollectiveCommunications;import no.uib.cipr.matrix.distributed.Communicator;import no.uib.cipr.matrix.distributed.DistColMatrix;import no.uib.cipr.matrix.distributed.DistRowMatrix;import no.uib.cipr.matrix.distributed.DistVector;import no.uib.cipr.matrix.sparse.BiCGstab;import no.uib.cipr.matrix.sparse.CG;import no.uib.cipr.matrix.sparse.DefaultIterationMonitor;import no.uib.cipr.matrix.sparse.GMRES;import no.uib.cipr.matrix.sparse.IterationMonitor;import no.uib.cipr.matrix.sparse.IterativeSolver;import no.uib.cipr.matrix.sparse.IterativeSolverNotConvergedException;import no.uib.cipr.matrix.Utilities;public class DistIterativeSolverTest extends TestCase { volatile CollectiveCommunications coll; volatile DenseMatrix A_unsymm; volatile UpperSymmDenseMatrix A_symm; volatile DenseVector x, b_unsymm, b_symm; /** * Partitioning */ volatile int[] localLength; volatile double[] output; @Override protected void setUp() throws Exception { int size = Utilities.getInt(1, 8); coll = new CollectiveCommunications(size); int n = Utilities.getInt(size, 250); A_unsymm = new DenseMatrix(n, n); A_symm = new UpperSymmDenseMatrix(n); Utilities.populate(A_unsymm); Utilities.upperPopulate(A_unsymm); double shift = 10; do { Utilities.addDiagonal(A_unsymm, shift); } while (Utilities.singular(A_unsymm)); do { Utilities.addDiagonal(A_symm, shift); } while (!Utilities.spd(A_symm)); x = new DenseVector(n); b_unsymm = x.copy(); b_symm = x.copy(); Utilities.populate(x); A_unsymm.mult(x, b_unsymm); A_symm.mult(x, b_symm); output = new double[n]; // Set local lengths localLength = new int[size]; Arrays.fill(localLength, n / size); // Adjust the last length to ensure the whole vector is covered int sum = n; for (int l : localLength) sum -= l; localLength[size - 1] += sum; } public void testRowGMRES_1() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new GMRESRowDistIterativeSolver(i, Vector.Norm.One)); compare(t); } public void testRowGMRES_2() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new GMRESRowDistIterativeSolver(i, Vector.Norm.Two)); compare(t); } public void testRowGMRES_inf() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new GMRESRowDistIterativeSolver(i, Vector.Norm.Infinity)); compare(t); } public void testRowBiCGstab_1() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new BiCGstabRowDistIterativeSolver(i, Vector.Norm.One)); compare(t); } public void testRowBiCGstab_2() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new BiCGstabRowDistIterativeSolver(i, Vector.Norm.Two)); compare(t); } public void testRowBiCGstab_inf() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new BiCGstabRowDistIterativeSolver(i, Vector.Norm.Infinity)); compare(t); } public void testColumnGMRES_1() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new GMRESColumnDistIterativeSolver(i, Vector.Norm.One)); compare(t); } public void testColumnGMRES_2() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new GMRESColumnDistIterativeSolver(i, Vector.Norm.Two)); compare(t); } public void testColumnGMRES_inf() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new GMRESColumnDistIterativeSolver(i, Vector.Norm.Infinity)); compare(t); } public void testColumnBiCGstab_1() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new BiCGstabColumnDistIterativeSolver(i, Vector.Norm.One)); compare(t); } public void testColumnBiCGstab_2() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new BiCGstabColumnDistIterativeSolver(i, Vector.Norm.Two)); compare(t); } public void testColumnBiCGstab_inf() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new BiCGstabColumnDistIterativeSolver(i, Vector.Norm.Infinity)); compare(t); } public void testRowCG_1() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new CGRowDistIterativeSolver(i, Vector.Norm.One)); compare(t); } public void testRowCG_2() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new CGRowDistIterativeSolver(i, Vector.Norm.Two)); compare(t); } public void testRowCG_inf() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new CGRowDistIterativeSolver(i, Vector.Norm.Infinity)); compare(t); } public void testColumnCG_1() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new CGColumnDistIterativeSolver(i, Vector.Norm.One)); compare(t); } public void testColumnCG_2() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new CGColumnDistIterativeSolver(i, Vector.Norm.Two)); compare(t); } public void testColumnCG_inf() throws InterruptedException { Thread[] t = new Thread[coll.size()]; for (int i = 0; i < t.length; ++i) t[i] = new Thread(new CGColumnDistIterativeSolver(i, Vector.Norm.Infinity)); compare(t); } private void compare(Thread[] t) throws InterruptedException { ExecutorService pool = Executors.newFixedThreadPool(t.length); for (Thread ti : t) pool.execute(ti); pool.shutdown(); pool.awaitTermination(20, TimeUnit.SECONDS); for (int i = 0; i < x.size(); ++i) assertEquals(x.get(i), output[i], 1e-10); } private abstract class DistIterativeSolver implements Runnable { final protected int rank; final protected Vector.Norm norm;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -