⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 collectivecommunications.java

📁 另一个功能更强大的矩阵运算软件开源代码
💻 JAVA
字号:
/* * Copyright (C) 2003-2006 Bjørn-Ove Heimsund *  * This file is part of MTJ. *  * This library is free software; you can redistribute it and/or modify it * under the terms of the GNU Lesser General Public License as published by the * Free Software Foundation; either version 2.1 of the License, or (at your * option) any later version. *  * This library is distributed in the hope that it will be useful, but WITHOUT * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or * FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License * for more details. *  * You should have received a copy of the GNU Lesser General Public License * along with this library; if not, write to the Free Software Foundation, * Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA */package no.uib.cipr.matrix.distributed;import java.lang.reflect.Array;import java.util.ArrayList;import java.util.List;import java.util.concurrent.BrokenBarrierException;import java.util.concurrent.CyclicBarrier;import java.util.concurrent.Exchanger;import no.uib.cipr.matrix.distributed.Communicator.SendRecv;/** * Collective communications between all the participating threads. */public class CollectiveCommunications {    /**     * Size of the whole collective     */    final int size;    /**     * Exchanges pipes between the threads. Only used for setup of the     * point-to-point communications     */    private final List<List<Exchanger<SendRecv>>> ex;    /**     * Sets up a collective of the given size     *      * @param size     *            Number of members of the collective     */    public CollectiveCommunications(int size) {        if (size < 1)            throw new IllegalArgumentException("size < 1");        this.size = size;        barrier = new CyclicBarrier(size);        broadcast = new Broadcast();        gather = new Gather();        scatter = new Scatter();        allGather = new AllGather();        allToAll = new AllToAll();        reduce = new Reduce();        ex = new ArrayList<List<Exchanger<SendRecv>>>();        for (int i = 0; i < size; ++i) {            List<Exchanger<SendRecv>> iex = new ArrayList<Exchanger<SendRecv>>();            // Get from symmetry            for (int j = 0; j < i; ++j)                iex.add(ex.get(j).get(i));            // Main diagonal (ignored)            iex.add(null);            // New entries            for (int j = i + 1; j < size; ++j)                iex.add(new Exchanger<SendRecv>());            ex.add(iex);        }    }    /**     * Gets the size of the collective     */    public int size() {        return size;    }    /**     * Creates a communicator for point-to-point data-exchange. This method     * should be called by all the threads in the collective     *      * @param rank     *            Rank of the communicator     */    public Communicator createCommunicator(int rank) {        if (rank < 0 || rank >= size)            throw new IllegalArgumentException("rank < 0 || rank >= size");        return new Communicator(rank, ex.get(rank), this);    }    static void await(CyclicBarrier barrier) {        try {            barrier.await();        } catch (InterruptedException e) {            throw new RuntimeException(e);        } catch (BrokenBarrierException e) {            throw new RuntimeException(e);        }    }    /**     * Blocks the caller until all threads have called     */    void barrier() {        await(barrier);    }    private final CyclicBarrier barrier;    void broadcast(Object buffer, int root, int rank) {        broadcast.buffer[rank] = buffer;        if (rank == root)            broadcast.root = root;        await(broadcast.barrier);    }    private final Broadcast broadcast;    private class Broadcast implements Runnable {        final CyclicBarrier barrier = new CyclicBarrier(size, this);        int root;        final Object[] buffer = new Object[size];        public void run() {            int length = Array.getLength(buffer[root]);            for (int i = 0; i < size; ++i)                System.arraycopy(buffer[root], 0, buffer[i], 0, length);        }    }    void gather(Object sendbuf, Object[] recvbuf, int root, int rank) {        gather.setSend(sendbuf, rank);        if (rank == root)            gather.recvbuf = recvbuf;        await(gather.barrier);    }    private final Gather gather;    private class Gather implements Runnable {        CyclicBarrier barrier = new CyclicBarrier(size, this);        Object[] recvbuf = new Object[size], sendbuf = new Object[size];        private final int[] length = new int[size];        public void setSend(Object sendbuf, int rank) {            this.sendbuf[rank] = sendbuf;            length[rank] = Array.getLength(sendbuf);        }        public void run() {            for (int i = 0; i < size; ++i)                System.arraycopy(sendbuf[i], 0, recvbuf[i], 0, length[i]);        }    }    void scatter(Object[] sendbuf, Object recvbuf, int root, int rank) {        scatter.setRecv(recvbuf, rank);        if (rank == root)            scatter.sendbuf = sendbuf;        await(scatter.barrier);    }    private final Scatter scatter;    private class Scatter implements Runnable {        final CyclicBarrier barrier = new CyclicBarrier(size, this);        Object[] sendbuf = new Object[size], recvbuf = new Object[size];        private final int[] length = new int[size];        public void setRecv(Object recvbuf, int rank) {            this.recvbuf[rank] = recvbuf;            length[rank] = Array.getLength(recvbuf);        }        public void run() {            for (int i = 0; i < size; ++i)                System.arraycopy(sendbuf[i], 0, recvbuf[i], 0, length[i]);        }    }    void allGather(Object sendbuf, Object[] recvbuf, int rank) {        allGather.setSendRecv(sendbuf, recvbuf, rank);        await(allGather.barrier);    }    private final AllGather allGather;    private class AllGather implements Runnable {        final CyclicBarrier barrier = new CyclicBarrier(size, this);        private final Object sendbuf[] = new Object[size],                recvbuf[][] = new Object[size][size];        private final int[] length = new int[size];        public void setSendRecv(Object send, Object[] recv, int rank) {            sendbuf[rank] = send;            recvbuf[rank] = recv;            length[rank] = Array.getLength(send);        }        public void run() {            for (int i = 0; i < size; ++i)                for (int j = 0; j < size; ++j)                    System                            .arraycopy(sendbuf[i], 0, recvbuf[j][i], 0,                                    length[i]);        }    }    void allToAll(Object[] sendbuf, Object[] recvbuf, int rank) {        allToAll.setSendRecv(sendbuf, recvbuf, rank);        await(allToAll.barrier);    }    private final AllToAll allToAll;    private class AllToAll implements Runnable {        final CyclicBarrier barrier = new CyclicBarrier(size, this);        private final Object[][] sendbuf = new Object[size][size],                recvbuf = new Object[size][size];        private final int[][] length = new int[size][size];        public void setSendRecv(Object[] send, Object[] recv, int rank) {            sendbuf[rank] = send;            recvbuf[rank] = recv;            for (int i = 0; i < size; ++i)                length[rank][i] = Array.getLength(send[i]);        }        public void run() {            for (int i = 0; i < size; ++i)                for (int j = 0; j < size; ++j)                    System.arraycopy(sendbuf[i][j], 0, recvbuf[j][i], 0,                            length[i][j]);        }    }    void reduce(Object sendbuf, Object recvbuf, Reduction op, int root, int rank) {        reduce.sendbuf[rank] = sendbuf;        if (rank == root) {            reduce.op = op;            reduce.recvbuf = recvbuf;        }        await(reduce.barrier);    }    private final Reduce reduce;    private class Reduce implements Runnable {        CyclicBarrier barrier = new CyclicBarrier(size, this);        Reduction op;        Object sendbuf[] = new Object[size], recvbuf;        public void run() {            op.init(recvbuf);            for (int i = 0; i < size; ++i)                op.op(recvbuf, sendbuf[i]);        }    }    void allReduce(Object sendbuf, Object recvbuf, Reduction op, int rank) {        reduce(sendbuf, recvbuf, op, 0, rank);        broadcast(recvbuf, 0, rank);    }}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -