⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 distvector.java

📁 另一个功能更强大的矩阵运算软件开源代码
💻 JAVA
字号:
/* * Copyright (C) 2003-2006 Bjørn-Ove Heimsund *  * This file is part of MTJ. *  * This library is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as published by the * Free Software Foundation; either version 2.1 of the License, or (at your * option) any later version. *  * This library is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License * for more details. *  * You should have received a copy of the GNU Lesser General Public License * along with this library; if not, write to the Free Software Foundation, * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA */package no.uib.cipr.matrix.distributed;import java.util.Iterator;import no.uib.cipr.matrix.AbstractVector;import no.uib.cipr.matrix.Vector;import no.uib.cipr.matrix.VectorEntry;/** * Distributed memory vector */public class DistVector extends AbstractVector {    private static final long serialVersionUID = 4048795671438309432L;    /**     * Communicator in use     */    private Communicator comm;    /**     * Local part of the vector     */    Vector x;    /**     * The subdivisions of the global vector     */    int[] n;    int rank;    /**     * Rank and size of the communicator     */    private int commSize;    /**     * Constructor for DistVector     *      * @param size     *            Global vector size     * @param comm     *            Communicator to use     * @param x     *            Local vector, its size cannot exceed the global size, and the     *            sum of the local vector sizes must equal the global vector     *            size (this is checked)     */    public DistVector(int size, Communicator comm, Vector x) {        super(size);        this.comm = comm;        this.x = x;        rank = comm.rank();        commSize = comm.size();        n = new int[commSize + 1];        // Find out the sizes of all the parts of the distributed vector        int[] send = new int[] { x.size() };        int[][] recv = new int[comm.size()][1];        comm.allGather(send, recv);        for (int i = 0; i < commSize; ++i)            n[i + 1] = n[i] + recv[i][0];        if (n[commSize] != size)            throw new IllegalArgumentException("Sum of local vector sizes ("                    + n[commSize] + ") do not match the global vector size ("                    + size + ")");    }    @Override    public void set(int index, double value) {        check(index);        if (local(index))            x.set(index - n[rank], value);        else            throw new IllegalArgumentException("Index " + index                    + " is not local");    }    @Override    public void add(int index, double value) {        check(index);        if (local(index))            x.add(index - n[rank], value);        else            throw new IllegalArgumentException("Index " + index                    + " is not local");    }    @Override    public double get(int index) {        check(index);        if (local(index))            return x.get(index - n[rank]);        else            throw new IllegalArgumentException("Entry not available locally");    }    @Override    public DistVector copy() {        return new DistVector(size, comm, x.copy());    }    @Override    public DistVector zero() {        x.zero();        return this;    }    @Override    public DistVector scale(double alpha) {        x.scale(alpha);        return this;    }    @Override    public DistVector set(double alpha, Vector y) {        if (!(y instanceof DistVector))            throw new IllegalArgumentException("Vector must be DistVector");        checkSize(y);        Vector yb = ((DistVector) y).getLocal();        x.set(alpha, yb);        return this;    }    @Override    public DistVector add(double alpha, Vector y) {        if (!(y instanceof DistVector))            throw new IllegalArgumentException("Vector must be DistVector");        checkSize(y);        Vector yb = ((DistVector) y).getLocal();        x.add(alpha, yb);        return this;    }    @Override    public double dot(Vector y) {        if (!(y instanceof DistVector))            throw new IllegalArgumentException("Vector must be a DistVector");        checkSize(y);        // Compute local part        Vector yb = ((DistVector) y).getLocal();        double ldot = x.dot(yb);        // Sum all local parts        double[] recv = new double[1];        comm.allReduce(new double[] { ldot }, recv, Reductions.sum());        return recv[0];    }    @Override    protected double norm1() {        double norm = x.norm(Norm.One);        double[] recv = new double[1];        comm.allReduce(new double[] { norm }, recv, Reductions.sum());        return norm;    }    @Override    protected double norm2_robust() {        // We'll just call the fast version, as we have to square the norm        // anyways during communications        return norm2();    }    @Override    protected double norm2() {        // Compute local norm squared        double norm = x.norm(Norm.Two);        norm *= norm;        double[] recv = new double[1];        comm.allReduce(new double[] { norm }, recv, Reductions.sum());        return Math.sqrt(recv[0]);    }    @Override    protected double normInf() {        double norm = x.norm(Norm.Infinity);        double[] recv = new double[1];        comm.allReduce(new double[] { norm }, recv, Reductions.max());        return norm;    }    /**     * Returns the local part of the vector     */    public Vector getLocal() {        return x;    }    /**     * Returns which indices are owned by which ranks. The current rank owns the     * indices <code>n[comm.rank()]</code> (inclusive) to     * <code>n[comm.rank()+1]</code> (exclusive)     */    public int[] getOwnerships() {        return n;    }    /**     * Returns true if the insertion index is local to this rank, and no     * communication is needed afterwards.     */    public boolean local(int index) {        return index >= n[rank] && index < n[rank + 1];    }    @Override    public Iterator<VectorEntry> iterator() {        return new DistVectorIterator();    }    /**     * Gets the communicator associated with this vector     */    public Communicator getCommunicator() {        return comm;    }    /**     * Iterator over a distributed memory vector     */    private class DistVectorIterator implements Iterator<VectorEntry> {        /**         * Entry of local iterator         */        private DistVectorEntry entry;        /**         * Iterator of local vector         */        private Iterator<VectorEntry> i;        /**         * Constructor for DistVectorIterator         */        public DistVectorIterator() {            i = x.iterator();            entry = new DistVectorEntry();        }        public void remove() {            i.remove();        }        public boolean hasNext() {            return i.hasNext();        }        public VectorEntry next() {            entry.update(i.next());            return entry;        }    }    /**     * Entry returned by the iterator     */    private class DistVectorEntry implements VectorEntry {        private int offset;        private VectorEntry e;        public DistVectorEntry() {            offset = n[rank];        }        public void update(VectorEntry e) {            this.e = e;        }        public int index() {            return e.index() + offset;        }        public double get() {            return e.get();        }        public void set(double value) {            e.set(value);        }    }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -