📄 cannon.c~
字号:
#include <stdio.h>#include <mpi.h>int multiply(int *a, int *b, int *c, int n){ int i,j,k; for(i=0;i<n;i++) for(j=0;j<n;j++) for(k=0;k<n;k++) c[i*n+j] += a[i*n+k]*b[k*n+j]; }void cannon(int *a, int *b, int *c, int n){ int i; int size, rank, t_rank; MPI_Status status; MPI_Comm com_topology; int dims[2], periods[2], coords[2]; int rrank, lrank, urank, drank; int source, dest; MPI_Comm_size(MPI_COMM_WORLD, &size); MPI_Comm_rank(MPI_COMM_WORLD, &rank); printf("process %d's info:\n", rank); dims[0] = dims[1] = 4; periods[0] = periods[1] = 1; MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods,\ 1, &com_topology); //printf("2"); MPI_Comm_rank(com_topology, &t_rank); MPI_Cart_coords(com_topology, t_rank, 2, coords); //printf("\n(%d, %d)-->%d\n", coords[0], coords[1],t_rank); MPI_Cart_shift(com_topology, 1, -1, &rrank, &lrank); MPI_Cart_shift(com_topology, 0, -1, &drank, &urank); MPI_Cart_shift(com_topology, 1, -coords[0], &source, &dest); //printf("(%d, %d), %d-->%d-->%d\n", coords[0], coords[1],source, t_rank, dest); printf("(%d, %d), %d-->%d-->%d\n", coords[0], coords[1],source, t_rank, dest); printf("(%d, %d), %d-->%d-->%d\n", coords[0], coords[1],source, t_rank, dest); MPI_Sendrecv_replace(a, n*n, MPI_INT, dest, 1, source, 1,\ com_topology, &status); MPI_Cart_shift(com_topology, 0, -coords[1], &source, &dest); MPI_Sendrecv_replace(a, n*n, MPI_INT, dest, 1, source, 1,\ com_topology, &status); MPI_Barrier(com_topology); // printf("3"); for(i =0;i<dims[0];i++) { multiply(a, b, c, n); printf("(%d, %d), %d-->%d-->%d\n", coords[0], coords[1],rrank, t_rank, lrank); printf("(%d, %d), %d-->%d-->%d\n", coords[0], coords[1],drank, t_rank, urank); MPI_Sendrecv_replace(a, n*n, MPI_INT, lrank, 1, rrank, 1,\ com_topology, &status); MPI_Sendrecv_replace(b, n*n, MPI_INT, urank, 1, drank, 1,\ com_topology, &status); } //printf("%d %d %d",a[10],b[10],c[10]); MPI_Comm_free(&com_topology);}void print(int *a, int *b, int *c, int n){ int i,j,k,base; printf("A's block is -----------------\n"); for(i =0;i<4;i++) for(j=0;j<4;j++) { printf("\nA[%d][%d] is:", i,j); base = (i*4+j)*16; for(k=0;k<16;k++) { printf("%d ", a[base+k]); if(k%4 ==3) printf("\n "); } } printf("\n-----------------------------"); printf("B's block is -----------------\n"); for(i =0;i<4;i++) for(j=0;j<4;j++) { printf("\nB[%d][%d] is:", i,j); base = (i*4+j)*16; for(k=0;k<16;k++) { printf("%d ", b[base+k]); if(k%4 ==3) printf("\n "); } } printf("\n-----------------------------"); printf("C's block is -----------------\n"); for(i =0;i<4;i++) for(j=0;j<4;j++) { printf("\nC[%d][%d] is:", i,j); base = (i*4+j)*16; for(k=0;k<16;k++) { printf("%d ", c[base+k]); if(k%4 ==3) printf("\n "); } } printf("\n-----------------------------"); }int main( int argc, char **argv){ int process_num = 16; int matrix_dim = 16; int every_matrix = 4; int i; int root =0; int *a, *b, *c; int *ea, *eb, *ec; int rank, size; int err; err = MPI_Init(&argc, &argv); err =MPI_Comm_rank(MPI_COMM_WORLD, &rank); err = MPI_Comm_size(MPI_COMM_WORLD, &size); printf("%d ",rank); if(rank == root){\ /* initialize matrix */ a = (int *)malloc(matrix_dim*matrix_dim*sizeof(int)); b = (int *)malloc(matrix_dim*matrix_dim*sizeof(int)); c = (int *)calloc(matrix_dim*matrix_dim,sizeof(int)); if( a <=0 || b<=0 || c<=0) printf("hello"); for(i =0;i<matrix_dim*matrix_dim;i++) { a[i] = i; b[i] = i+2; } // printf("i am here"); } ea = (int *)calloc(every_matrix*every_matrix,sizeof(int)); eb = (int *)calloc(every_matrix*every_matrix,sizeof(int)); ec = (int *)calloc(every_matrix*every_matrix,sizeof(int)); // printf("i am here"); //err=MPI_Barrier(MPI_COMM_WORLD); /* scatter matrix every row*/ /* for(i =0;i<matrix_dim;i++){ */ /* MPI_Scatter(a+i*matrix_dim, every_matrix,\ */ /* MPI_INT, ea+i*every_matrix, every_matrix,\ */ /* MPI_INT, root, MPI_COMM_WORLD); */ /* } */ MPI_Scatter(a, matrix_dim,\ MPI_INT, ea, matrix_dim,\ MPI_INT, root, MPI_COMM_WORLD); MPI_Scatter(b, matrix_dim,\ MPI_INT, eb, matrix_dim,\ MPI_INT, root, MPI_COMM_WORLD); err = MPI_Barrier(MPI_COMM_WORLD); /* cannon calculation*/ cannon(ea, eb, ec, every_matrix); err = MPI_Barrier(MPI_COMM_WORLD); MPI_Gather(ec, matrix_dim, MPI_INT,\ c, matrix_dim, MPI_INT,\ root, MPI_COMM_WORLD); err = MPI_Barrier(MPI_COMM_WORLD); /* print out result */ if( rank == root) print(a, b,c, matrix_dim); MPI_Finalize(); // return 0;}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -