MPI_Sendrecv_replace gets blocked
I am working on code that implements Cannon's matrix multiplication algorithm. The algorithm is described by the following pseudocode fragment:
- Executed in parallel:
  - circular movement with i positions to the left of submatrices Ai,x
  - circular movement with j positions upwards of submatrices Bx,j
- for k = 0 to n/p-1, executed in parallel:
  - Ci,j = Ci,j + Ai,j * Bi,j
  - circular movement with 1 position to the left of submatrices Ai,x
  - circular movement with 1 position upwards of submatrices Bx,j
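The steps above can be checked without MPI. The following is a minimal serial sketch, not the code from this question: it treats each block position (bi, bj) of a p x p grid as one process and, instead of physically moving blocks, tracks which A and B block that process would hold after the initial skew and each single shift. The names make_iota, multiply_direct and multiply_cannon are hypothetical helpers introduced only for this illustration, and it assumes the matrix size is divisible by p.

```cpp
#include <cassert>
#include <vector>

using Matrix = std::vector<std::vector<long>>;

// Hypothetical helper: n x n matrix filled with 0, 1, 2, ... row by row,
// the same pattern as the input matrices shown in the question.
Matrix make_iota(int n) {
    Matrix m(n, std::vector<long>(n, 0));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            m[i][j] = static_cast<long>(i) * n + j;
    return m;
}

// Plain triple-loop product, used only as a reference result.
Matrix multiply_direct(const Matrix& a, const Matrix& b) {
    const int n = static_cast<int>(a.size());
    Matrix c(n, std::vector<long>(n, 0));
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            for (int k = 0; k < n; ++k)
                c[i][j] += a[i][k] * b[k][j];
    return c;
}

// Serial simulation of Cannon's algorithm on a p x p grid of blocks.
Matrix multiply_cannon(const Matrix& a, const Matrix& b, int p) {
    const int n = static_cast<int>(a.size());
    assert(p > 0 && n % p == 0);
    const int bs = n / p;  // edge length of one block
    Matrix c(n, std::vector<long>(n, 0));

    for (int step = 0; step < p; ++step) {
        for (int bi = 0; bi < p; ++bi) {
            for (int bj = 0; bj < p; ++bj) {
                // Initial skew: block row bi of A moves bi positions left,
                // block column bj of B moves bj positions up. After `step`
                // further single shifts, process (bi, bj) holds
                // A block (bi, s) and B block (s, bj).
                const int s = (bi + bj + step) % p;
                for (int i = 0; i < bs; ++i)
                    for (int j = 0; j < bs; ++j)
                        for (int k = 0; k < bs; ++k)
                            c[bi * bs + i][bj * bs + j] +=
                                a[bi * bs + i][s * bs + k] *
                                b[s * bs + k][bj * bs + j];
            }
        }
    }
    return c;
}
```

For the 6 x 6 inputs used here, multiply_cannon(a, a, 2) and multiply_cannon(a, a, 3) should both agree with multiply_direct(a, a). Note that the skew for block row bi uses the row index bi as the distance but moves along the column direction, so every process in that row shifts by the same amount.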
However, my code seems to block inside the for loop, after the submatrix B has been sent.
int main(int argc, char* argv[])
{
read_input_files(argc, argv);
int rank, size, i, j, k, shift;
//print_matrix(N, A, 0);
//print_matrix(N, B, 0);
//print_matrix(N, AB);
MPI_Init(&argc, &argv);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
MPI_Comm_size(MPI_COMM_WORLD, &size);
MPI_Comm comm;
MPI_Status status;
int left, right, up, down;
int shiftsource, shiftdest;
int dims[2] = { 0, 0 }, periods[2] = { 1, 1 }, coords[2];
MPI_Dims_create(size, 2, dims);
MPI_Cart_create(MPI_COMM_WORLD, 2, dims, periods, 1, &comm);
MPI_Cart_coords(comm, rank, 2, coords);
MPI_Cart_shift(comm, 1, -1, &right, &left);
MPI_Cart_shift(comm, 0, -1, &down, &up);
//printf("%d --- %d %d %d %d.\n", rank, right, left, up, down);
if (dims[0] != dims[1]) {
// Report the problem once, from rank 0 only.
if (rank == 0)
printf("The number of processors must be a perfect square.\n");
MPI_Finalize();
return 0;
}
int block_size = N / sqrt(size);
cout << rank << " : dims " << dims[0] << "---------------------------------------" << endl;
int* A_sub = make_sub(A, rank, block_size, size);
int* B_sub = make_sub(B, rank, block_size, size);
int* AB_sub = (int*)calloc(block_size * block_size, sizeof(int));
//print_submatrix(block_size, A_sub, rank);
cout << rank << " : coords " << coords[0] << " * " << coords[1] << "---------------------------------------" << endl;
MPI_Cart_shift(comm, 0, -coords[0], &shiftsource, &shiftdest);
MPI_Sendrecv_replace(A_sub, block_size * block_size, MPI_INT, shiftdest, 1, shiftsource, 1, comm, &status);
cout << rank << " : MPI_Sendrecv_replace A_sub " << endl;
//print_submatrix(block_size, A_sub, rank);
MPI_Cart_shift(comm, 1, -coords[1], &shiftsource, &shiftdest);
MPI_Sendrecv_replace(B_sub, block_size * block_size, MPI_INT, shiftdest, 1, shiftsource, 1, comm, &status);
cout << rank << " : MPI_Sendrecv_replace B_sub " << endl;
for (shift = 0;shift < dims[0];shift++) {
for (i = 0;i < block_size;i++) {
for (j = 0;j < block_size;j++)
{
for (k = 0;k < block_size;k++) {
AB_sub[i * block_size + j] += A_sub[i * block_size + k] * B_sub[k * block_size + j];
}
}
}
if(shift == dims[0]-1) print_submatrix(block_size, AB_sub, rank);
MPI_Cart_shift(comm, 1, 1, &left, &right);
MPI_Sendrecv_replace(A_sub, block_size * block_size, MPI_INT, left, 1, right, 1, comm, MPI_STATUS_IGNORE);
cout << rank << " : MPI_Sendrecv_replace A " << endl;
MPI_Cart_shift(comm, 0, 1, &up, &down);
MPI_Sendrecv_replace(B_sub, block_size * block_size, MPI_INT, up, 1, down, 1, comm, MPI_STATUS_IGNORE);
cout << rank << " : MPI_Sendrecv_replace B " <<endl;
}
//print_matrix(N, AB, rank);
//cout << rank << " : coords " << coords[0] << " * " << coords[1] << "---------------------------------------" << endl;
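// Pass the buffer itself, not the address of the pointer; recvcount is the count received from each rank.
MPI_Gather(AB_sub, block_size * block_size, MPI_INT, AB, block_size * block_size, MPI_INT, 0, MPI_COMM_WORLD);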
MPI_Finalize();// MPI_Comm_free(&comm); Free up communicator
//print_matrix(N, AB, 0);
return 0;
}
read_input_files() is a function that reads the two matrices from the files given as command-line arguments.
A matrix after reading file:
0 1 2 3 4 5
6 7 8 9 10 11
12 13 14 15 16 17
18 19 20 21 22 23
24 25 26 27 28 29
30 31 32 33 34 35
B matrix after reading file:
0 1 2 3 4 5
6 7 8 9 10 11
12 13 14 15 16 17
18 19 20 21 22 23
24 25 26 27 28 29
30 31 32 33 34 35
N is the dimension of the (square) matrices; N is 6 in this case.
Solution 1:[1]
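Your call MPI_Cart_shift(comm, 1, -coords[1], ...) has a strange shift parameter: you're shifting by an amount that depends on the process's own coordinate along the shift direction. Processes in the same ring then compute different displacements, so their send and receive partners no longer pair up and MPI_Sendrecv_replace can block. That should probably be MPI_Cart_shift(comm, 1, -1, ...).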
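The partner arithmetic behind this can be reproduced without MPI. The sketch below is an illustration only. It assumes a periodic ring of p processes in which a shift by disp makes a process receive from coord - disp and send to coord + disp (modulo p), which is how MPI_Cart_shift reports source and destination for a periodic dimension. ring_shift and shift_is_consistent are hypothetical helper names, not part of MPI.

```cpp
#include <cassert>

// Source and destination of one process on a periodic ring.
struct Partners {
    int source;  // the process we receive from
    int dest;    // the process we send to
};

// Mirrors the MPI_Cart_shift arithmetic for a periodic dimension of size p.
Partners ring_shift(int coord, int disp, int p) {
    assert(p > 0);
    auto wrap = [p](int x) { return ((x % p) + p) % p; };
    return {wrap(coord - disp), wrap(coord + disp)};
}

// A shift is consistent when every send has a matching receive:
// if process c sends to d, then d must expect its message from c.
// disp_of(c) returns the displacement that process c passes to the shift.
template <typename DispFn>
bool shift_is_consistent(int p, DispFn disp_of) {
    for (int c = 0; c < p; ++c) {
        const int d = ring_shift(c, disp_of(c), p).dest;
        if (ring_shift(d, disp_of(d), p).source != c)
            return false;  // d is waiting for a different sender
    }
    return true;
}
```

With p = 3, a uniform displacement of -1 passes the check, while a displacement of -coord does not: process 1 sends to process 0, but process 0 (displacement 0) only expects a message from itself. Any displacement that is the same for every process in the ring, such as one taken from the coordinate in the other dimension, passes as well.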
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Victor Eijkhout |
