⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 distcolmatrix.java

📁 另一个功能更强大的矩阵运算软件开源代码
💻 JAVA
字号:
/*
 * Copyright (C) 2003-2006 Bjørn-Ove Heimsund
 * 
 * This file is part of MTJ.
 * 
 * This library is free software; you can redistribute it and/or modify it
 * under the terms of the GNU Lesser General Public License as published by the
 * Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 * 
 * This library is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License
 * for more details.
 * 
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; if not, write to the Free Software Foundation,
 * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */

package no.uib.cipr.matrix.distributed;

import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;

import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.MatrixEntry;
import no.uib.cipr.matrix.Vector;

/**
 * Distributed matrix with column major blocks.
 * <p>
 * Entries are addressed with global indices. An entry is treated as local
 * when its column lies in <code>[m[rank], m[rank + 1])</code>; local entries
 * are stored either in the block diagonal part <code>A</code> (row and column
 * shifted by <code>n[rank]</code> and <code>m[rank]</code>) or in the
 * off-diagonal part <code>B</code> (global row, column shifted by
 * <code>m[rank]</code>).
 */
public class DistColMatrix extends DistMatrix {

    private static final long serialVersionUID = 3618421523053359673L;

    /**
     * Constructor for DistColMatrix
     * 
     * @param numRows
     *            Global number of rows
     * @param numColumns
     *            Global number of columns
     * @param comm
     *            Communicator to use
     * @param A
     *            Block diagonal matrix. The sum of the local row sizes of
     *            <code>A</code> must equal the global number, and likewise
     *            with the column sizes.
     * @param B
     *            Off-diagonal matrix part. Its number of rows must equal the
     *            global number of rows, and its number of columns must equal
     *            that of <code>A</code>
     * @throws IllegalArgumentException
     *             if <code>A</code> and <code>B</code> have a different
     *             number of columns, or if <code>B</code> does not have
     *             <code>numRows</code> rows
     */
    public DistColMatrix(int numRows, int numColumns, Communicator comm,
            Matrix A, Matrix B) {
        super(numRows, numColumns, comm, A, B);

        if (A.numColumns() != B.numColumns())
            throw new IllegalArgumentException(
                    "A.numColumns() != B.numColumns()");
        if (B.numRows() != numRows)
            throw new IllegalArgumentException("B.numRows() != numRows");
    }

    /**
     * Adds <code>value</code> to the entry at the given global position.
     * 
     * @param row
     *            Global row index
     * @param column
     *            Global column index
     * @param value
     *            Value to add
     * @throws IllegalArgumentException
     *             if the column is not local to this rank
     */
    @Override
    public void add(int row, int column, double value) {
        check(row, column);

        if (inA(row, column))
            // Block diagonal part uses indices relative to the local block
            A.add(row - n[rank], column - m[rank], value);
        else if (local(row, column))
            // Off-diagonal part keeps the global row, local column
            B.add(row, column - m[rank], value);
        else
            throw new IllegalArgumentException("Column index " + column
                    + " is not local");
    }

    /**
     * Sets the entry at the given global position to <code>value</code>.
     * 
     * @param row
     *            Global row index
     * @param column
     *            Global column index
     * @param value
     *            Value to store
     * @throws IllegalArgumentException
     *             if the column is not local to this rank
     */
    @Override
    public void set(int row, int column, double value) {
        check(row, column);

        if (inA(row, column))
            A.set(row - n[rank], column - m[rank], value);
        else if (local(row, column))
            B.set(row, column - m[rank], value);
        else
            throw new IllegalArgumentException("Column index " + column
                    + " is not local");
    }

    /**
     * Returns the entry at the given global position.
     * 
     * @param row
     *            Global row index
     * @param column
     *            Global column index
     * @return The stored value, read from <code>A</code> or <code>B</code>
     * @throws IndexOutOfBoundsException
     *             if the column is not local to this rank
     */
    @Override
    public double get(int row, int column) {
        check(row, column);

        if (inA(row, column))
            return A.get(row - n[rank], column - m[rank]);
        else if (local(row, column))
            return B.get(row, column - m[rank]);
        else
            throw new IndexOutOfBoundsException("Entry not available locally");
    }

    /**
     * Creates a new DistColMatrix with the same global size and
     * communicator, built from copies of <code>A</code> and <code>B</code>.
     * 
     * @return The new matrix
     */
    @Override
    public DistColMatrix copy() {
        return new DistColMatrix(numRows, numColumns, comm, A.copy(), B.copy());
    }

    /**
     * Returns an iterator over the matrix entries.
     * 
     * @return A <code>DistMatrixIterator</code> constructed with the offsets
     *         <code>(n[rank], m[rank], 0, m[rank])</code>
     */
    @Override
    public Iterator<MatrixEntry> iterator() {
        // NOTE(review): presumably (row, column) offsets for A followed by
        // (row, column) offsets for B - confirm against DistMatrixIterator
        return new DistMatrixIterator(n[rank], m[rank], 0, m[rank]);
    }

    /**
     * Computes <code>y = alpha*A*x + y</code>.
     * <p>
     * For <code>alpha == 0</code> the vector <code>y</code> is returned
     * unchanged. Otherwise <code>y</code> is first scaled by
     * <code>1/alpha</code>, the unscaled product is accumulated into it, and
     * the whole vector is finally scaled by <code>alpha</code>.
     * 
     * @param alpha
     *            Scalar multiplying the matrix/vector product
     * @param x
     *            Vector to multiply with, must be a <code>DistVector</code>
     * @param y
     *            Vector to add into, must be a <code>DistVector</code>
     * @return <code>y</code>
     * @throws IllegalArgumentException
     *             if either vector is not a <code>DistVector</code>
     */
    @Override
    public Vector multAdd(double alpha, Vector x, Vector y) {
        if (!(x instanceof DistVector && y instanceof DistVector))
            throw new IllegalArgumentException("Vectors must be DistVectors");
        checkMultAdd(x, y);

        // With alpha == 0 the result is just y. The scaling trick below
        // would divide by zero and leave y filled with NaN, so stop here.
        // NOTE(review): assumes every rank is called with the same alpha, so
        // that all ranks skip the communication together - confirm
        if (alpha == 0.)
            return y;

        // y = 1/alpha * y
        y.scale(1. / alpha);

        // y = A*x + y = A*x + 1/alpha * y

        DistVector xd = (DistVector) x, yd = (DistVector) y;

        // Non-local part
        B.mult(xd.getLocal(), locR);

        // Send it to the others
        scatter.startGather(locR, yd);

        // Local part
        A.multAdd(xd.getLocal(), yd.getLocal());

        // Finish communications, concluding the matrix product
        scatter.endAddGather(locR, yd);

        // y = alpha*y = alpha * A*x + y
        return y.scale(alpha);
    }

    /**
     * Computes <code>y = alpha*A<sup>T</sup>*x + y</code>.
     * 
     * @param alpha
     *            Scalar multiplying the matrix/vector product
     * @param x
     *            Vector to multiply with, must be a <code>DistVector</code>
     * @param y
     *            Vector to add into, must be a <code>DistVector</code>
     * @return <code>y</code>
     * @throws IllegalArgumentException
     *             if either vector is not a <code>DistVector</code>
     */
    @Override
    public Vector transMultAdd(double alpha, Vector x, Vector y) {
        if (!(x instanceof DistVector && y instanceof DistVector))
            throw new IllegalArgumentException("Vectors must be DistVectors");
        checkTransMultAdd(x, y);

        DistVector xd = (DistVector) x, yd = (DistVector) y;

        // Receive the needed components of the global x into the local vector
        scatter.startScatter(xd, locC);

        // Local part
        A.transMultAdd(alpha, xd.getLocal(), yd.getLocal());

        // Finish communications
        scatter.endSetScatter(xd, locC);

        // Non-local part
        B.transMultAdd(alpha, locC, yd.getLocal());

        return y;
    }

    /**
     * Tells whether an entry is local to this rank. Only the column is
     * inspected; the row argument is ignored.
     * 
     * @param row
     *            Global row index (unused)
     * @param column
     *            Global column index
     * @return True if <code>m[rank] &lt;= column &lt; m[rank + 1]</code>
     */
    @Override
    public boolean local(int row, int column) {
        return column >= m[rank] && column < m[rank + 1];
    }

    /**
     * Finds the position of a column within the column delimiters
     * <code>m</code>.
     * 
     * @param column
     *            Global column index
     * @return <code>i - 1</code> for the first <code>i &gt;= 1</code> with
     *         <code>column &lt; m[i]</code>, or <code>m.length - 1</code> if
     *         there is no such <code>i</code>
     */
    int getRank(int column) {
        int i = 1;
        for (; i < m.length; ++i)
            if (column < m[i])
                break;
        return i - 1;
    }

    /**
     * Returns the column delimiters <code>m</code>.
     */
    @Override
    int[] getDelimiter() {
        return m;
    }

    /**
     * Collects the distinct row indices of the entries of <code>B</code>
     * which fail the locality test.
     * 
     * @return The indices, in the iteration order of a <code>HashSet</code>
     */
    @Override
    int[] getCommIndices() {
        // Get the unique row indices from B
        Collection<Integer> set = new HashSet<Integer>();
        for (MatrixEntry e : B)
            // NOTE(review): local() only looks at its column argument, so
            // this checks the row index against the local column range
            // [m[rank], m[rank + 1]) - confirm this is intended
            if (!local(e.row(), e.row()))
                set.add(e.row());

        // Get an array representation
        int[] indices = new int[set.size()];
        int j = 0;
        for (Integer i : set)
            indices[j++] = i;

        return indices;
    }

    /**
     * Sums the absolute values of the locally iterated entries per row, adds
     * up the sums of all the ranks, and returns the largest total.
     * <p>
     * NOTE(review): accumulating per row is conventionally the infinity
     * norm rather than the 1-norm - confirm against the convention used by
     * the base class.
     */
    @Override
    protected double norm1() {
        // Compute as much locally as possible
        double[] rowSum = new double[numRows];
        for (MatrixEntry e : this)
            rowSum[e.row()] += Math.abs(e.get());

        // Sum in the rest from the other ranks
        double[] recv = new double[numRows];
        comm.allReduce(rowSum, recv, Reductions.sum());

        // The global maximum
        return max(recv);
    }

    /**
     * Computes the norm of the base class locally, and returns the largest
     * such value over all the ranks.
     */
    @Override
    protected double normInf() {
        // Compute local norm
        double norm = super.normInf();

        // Find the maximum
        double[] recv = new double[1];
        comm.allReduce(new double[] { norm }, recv, Reductions.max());

        return recv[0];
    }

    /**
     * Zeros the matrix by delegating to the base class.
     * 
     * @return This matrix
     */
    @Override
    public DistColMatrix zero() {
        super.zero();
        return this;
    }

}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -