Rust issue with send/recv in rsmpi

I am trying to port my simple C++ MPI code to Rust using rsmpi, but I am having trouble with the send and receive logic. My C++ code is:

#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <cmath>
#include <chrono>
#include <mpi/mpi.h>

#define N_x_total 12
#define N_ghost 1
#define N_t 20

int main(int argc, char** argv){

    // Initial setup
    MPI_Init(&argc, &argv);
    int rank, size;
    MPI_Comm_size(MPI_COMM_WORLD, &size);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    int dstL   = rank - 1;
    int dstR   = rank + 1;
    if (dstL < 0)     dstL = MPI_PROC_NULL;
    if (dstR >= size) dstR = MPI_PROC_NULL;

    double alpha = 0.05;

    // Setup grid (local)
    double x_start = 0.0;
    double x_end = 1.0;
    double x_global[N_x_total + 2*N_ghost];
    int N_x = N_x_total/size;
    int start = N_ghost;
    int end = start + N_x -1;
    double dx = (x_end-x_start)/(N_x_total+1);
    for (auto i = 0; i < N_x_total + 2*N_ghost; i++)
    {
        x_global[i] = x_start + (i-start+1)*dx;
    }
    double* x_local = x_global + N_x*rank;
    double dt = 0.0059;
    double F = alpha*dt/(dx*dx);

    // Allocate memory to local grids
    double* u_old = (double*) malloc((N_x + 2*N_ghost)*sizeof(double));
    double* u_new = (double*) malloc((N_x + 2*N_ghost)*sizeof(double));
    double* u_global = (double*) malloc((N_x_total + 2*N_ghost)*sizeof(double));

    for (auto i = 0; i < size; i++)
    {
        MPI_Bcast(&i, 1, MPI_INT, 0, MPI_COMM_WORLD);
        if (rank == i)
        {
            std::cout << "Rank is: " << rank << std::endl;
            std::cout << "dstL is: " << dstL << " dstR is: " << dstR << std::endl;
            std::cout << "start is: " << start << " end is: " << end << std::endl;
            std::cout << "x_start is: " << x_local[start] << "x_end is: " << x_local[end] << std::endl;
        }
    }
    MPI_Barrier(MPI_COMM_WORLD);

    // Set initial conditions
    for (auto i = start; i <= end; i++)
    {
        u_old[i] = 10;
    }

    // Simulate timesteps
    // double t = 0.0;
    // double t_end = 0.1;
    for (auto t = 0; t < N_t; t++)
    {
        if (dstL == MPI_PROC_NULL)
        {
            u_old[start-1] = 0.0;
        }
        if (dstR == MPI_PROC_NULL)
        {
            u_old[end+1] = 0.0;
        }

        double* send_buffer = u_old + end - (N_ghost - 1);
        double* recv_buffer = u_old + 0;
        MPI_Sendrecv(send_buffer, N_ghost, MPI_DOUBLE, dstR, 0, recv_buffer, N_ghost, MPI_DOUBLE, dstL, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        send_buffer = u_old + start;
        recv_buffer = u_old + end + 1;
        MPI_Sendrecv(send_buffer, N_ghost, MPI_DOUBLE, dstL, 0, recv_buffer, N_ghost, MPI_DOUBLE, dstR, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);

        for (auto i = start; i <= end; i++)
        {
            u_new[i] = u_old[i] + F*(u_old[i-1] - 2.0*u_old[i] + u_old[i+1]);
            // std::cout << "rank: " << rank << " u: " << u_old[i] << " for u[i-1]:" << u_old[i-1] <<  "for u[i+1]:" << u_old[i+1] <<  "for u[i]:" << u_old[i] << std::endl;
        }
        // t += dt;

        for (auto i = start; i <= end; i++)
        {
            u_old[i] = u_new[i];
        }

    }
    MPI_Barrier(MPI_COMM_WORLD);


    // Print output after gathering data
    MPI_Gather(u_old+start, N_x, MPI_DOUBLE, u_global+start, N_x, MPI_DOUBLE, 0, MPI_COMM_WORLD);
    if (rank == 0)
    {
        for (auto i = start; i < start + N_x_total; i++)
        {
            std::cout << "Solution is: " << u_global[i] << std::endl;
        }
        
    }

    free(u_global);
    free(u_old);
    free(u_new);
    
    
    MPI_Finalize();
    return 0;
}
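
In the C++ version the two MPI_Sendrecv calls do the halo exchange, and a send to or receive from MPI_PROC_NULL is a no-op, so the boundary ranks need no special partner handling inside the loop. I could not find an obvious MPI_PROC_NULL equivalent in rsmpi, which is why my Rust port below uses sentinel rank values (-2 and size+1). What I actually mean by them is something like this Option-based bookkeeping (just a sketch; the names left and right are mine):

use mpi::topology::Rank;

// "No neighbour" expressed with Option instead of MPI_PROC_NULL.
fn neighbours(rank: Rank, size: Rank) -> (Option<Rank>, Option<Rank>) {
    let left = if rank > 0 { Some(rank - 1) } else { None };         // rank 0 has no left neighbour
    let right = if rank + 1 < size { Some(rank + 1) } else { None }; // last rank has no right neighbour
    (left, right)
}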

My Rust code for the same is:

extern crate mpi;
use mpi::request::{CancelGuard, WaitGuard};
use mpi::topology::CommunicatorRelation;
use mpi::point_to_point as p2p;
use mpi::topology::Rank;
use mpi::traits::*;

const N_x_total:usize = 12;
const N_ghost:usize = 1;
const N_t:usize = 5;


fn main() {
    // Run with: mpiexec -n 4 ./../target/debug/heat_equation_mpi

    // Initial setup
    let universe = mpi::initialize().unwrap();
    let world = universe.world();
    let size = world.size();
    let rank = world.rank();
    let root_rank = 0;
    let root_process = world.process_at_rank(root_rank);

    let mut distL = rank - 1;
    let mut distR = rank + 1;
    if distL < 0 {
        distL = -2;
    }
    if distR >= size {
        distR = size+1;
    }

    let alpha = 0.05;

    // Setup grid (local)
    let x_start = 0.0;
    let x_end = 1.0;
    let mut x_global = vec!(0.0; N_x_total + 2*N_ghost);
    let N_x = N_x_total/(size as usize);
    let start = N_ghost;
    let end = start + N_x - 1;
    let dx = (x_end-x_start)/((N_x_total+1) as f32);
    for i in 0..(N_x_total + 2*N_ghost)
    {
        x_global[i] = x_start + (i as f32 - start as f32 + 1.0) * dx;
    }
    let mut x_local = vec!(0.0; N_x + 2*N_ghost);
    for i in 0..(N_x + 2*N_ghost)
    {
        x_local[i] = x_global[i + N_x as usize * rank as usize];
    }

    let dt = 0.0059;
    let F = alpha*dt/(dx*dx);

    // Allocate memory to local grid
    let mut u_old = vec!(10.0; N_x + 2*N_ghost);
    let mut u_new = vec!(0.0; N_x + 2*N_ghost);
    let mut u_global = vec!(0.0; N_x_total + 2*N_ghost);

    for i in 0..size {
        if rank == i {
            println!("Rank is: {}", rank);
            println!("distL is: {} and distR is: {}", distL, distR);
            println!("start is: {} and end is: {}", start, end);
            println!("x_start is: {:?} and x_end is: {:?}", x_local[start], x_local[end]);
        }
    }

    world.barrier();

    // Simulate timesteps
    let next_rank = if rank + 1 < size { rank + 1 } else { 0 };
    let previous_rank = if rank > 0 { rank - 1 } else { size - 1 };

    for t in 0..N_t {
        if distL == -2 {
            u_old[start-1] = 0.0;
        }
        if distR == size+1 {
            u_old[end+1] = 0.0;
        }

        // TODO: Send and receive operation

        // // Trial 1
        // let mut msg: f32 = u_old[0];
        // mpi::request::scope(|scope|
        // {
        //     let _sreq = WaitGuard::from(
        //         world.this_process().immediate_ready_send(scope, &u_old[0 + end - (N_ghost - 1)]),
        //     );
        //     let _rreq = WaitGuard::from(
        //         world.any_process().immediate_receive_into(scope, &mut msg),
        //     );
        // });
        // u_old[0] = msg;

        // let mut msg: f32 = u_old[0 + end + 1];
        // mpi::request::scope(|scope|
        // {
        //     let _sreq = WaitGuard::from(
        //         world.this_process().immediate_ready_send(scope, &u_old[0 + start]),
        //     );
        //     let _rreq = WaitGuard::from(
        //         world.any_process().immediate_receive_into(scope, &mut msg),
        //     );
        // });
        // u_old[0 + end + 1] = msg;

        // // Trial2
        // let (msg, status) : (Rank, _) = p2p::send_receive(&u_old[0 + end - (N_ghost-1)], &world.process_at_rank(next_rank), &world.process_at_rank(rank));
        // println!(
        //     "Process {} got message {}.\nStatus is: {:?}",
        //     rank, msg, status
        // );

        // let (msg, status) : (Rank, _) = p2p::send_receive(&u_old[0 + start], &world.process_at_rank(previous_rank), &world.process_at_rank(rank));
        // println!(
        //     "Process {} got message {}.\nStatus is: {:?}",
        //     rank, msg, status
        // );

        // // Trial3:
        // let mut msg: f32 = u_old[0];
        // mpi::request::scope(|scope|
        // {
        //     let _sreq = WaitGuard::from(
        //         world.this_process().immediate_ready_send(scope, &u_old[0 + end - (N_ghost - 1)]),
        //     );
        //     let _rreq = WaitGuard::from(
        //         world.process_at_rank(next_rank).immediate_receive_into(scope, &mut msg),
        //     );
        // });
        // // u_old[0] = msg;

        // let mut msg: f32 = u_old[0 + end + 1];
        // mpi::request::scope(|scope|
        // {
        //     let _sreq = WaitGuard::from(
        //         world.this_process().immediate_ready_send(scope, &u_old[0 + start]),
        //     );
        //     let _rreq = WaitGuard::from(
        //         world.process_at_rank(previous_rank).immediate_receive_into(scope, &mut msg),
        //     );
        // });
        // // u_old[0 + end + 1] = msg;


        // // Trial 4
        let mut msg: f32 = u_old[0];
        mpi::request::scope(|scope|
        {
            let _sreq = WaitGuard::from(
                world.process_at_rank(rank).immediate_ready_send(scope, &u_old[0 + end - (N_ghost - 1)]),
            );
            let _rreq = WaitGuard::from(
                world.process_at_rank(next_rank).immediate_receive_into(scope, &mut msg),
            );
        });
        u_old[0] = msg;

        let mut msg: f32 = u_old[0 + end + 1];
        mpi::request::scope(|scope|
        {
            let _sreq = WaitGuard::from(
                world.process_at_rank(rank).immediate_ready_send(scope, &u_old[0 + start]),
            );
            let _rreq = WaitGuard::from(
                world.process_at_rank(previous_rank).immediate_receive_into(scope, &mut msg),
            );
        });
        u_old[0 + end + 1] = msg;

        println!("Rank: {:?}, Value u_old[0 + end - (N_ghost-1)] is: {:?}", rank, u_old[0 + end - (N_ghost-1)]);
        println!("Rank: {:?}, Value u_old[0 + start] is: {:?}", rank, u_old[0 + start]);

        world.barrier();

        // update values
        for i in start..=end {
            u_new[i] = u_old[i] + F*(u_old[i-1] - 2.0*u_old[i] + u_old[i+1]);
        }

        for i in start..=end {
            u_old[i] = u_new[i];
        }


    }
    world.barrier();

    println!("Solution u_old is: {:?}", u_old);


    // Print output after gathering data
    if world.rank() == root_rank {
        root_process.gather_into_root(&u_old[start..=end], &mut u_global[start..=(start+N_x_total-1)]);
    }
    else {
        root_process.gather_into(&u_old[start..=end]);
    }

    if rank == 0 {
        println!("Solution u_global is: {:?}", u_global);
    }
}

The commented-out trials in the code are the various things I have tried, but none of them work: the program just freezes, and I suspect a deadlock, but I am unable to pinpoint it. Any help, please?
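
For reference, this is roughly what I think the first of the two MPI_Sendrecv calls should turn into. It uses immediate_send instead of immediate_ready_send (my understanding is that a ready send requires the matching receive to already be posted, which is probably not the case here), together with the Option-based left/right neighbours sketched above so that boundary ranks simply skip the missing partner. This is only an untested sketch of what I am aiming for; the second pass would mirror it in the opposite direction:

// One halo-exchange pass (sketch, not tested): send my rightmost interior value
// to the right neighbour and receive my left ghost cell from the left neighbour.
// `left`/`right` are the Option<Rank> neighbours from the sketch above;
// `world`, `u_old`, `start` and `end` are as in my code.
let send_val = u_old[end];       // rightmost interior value to pass on
let mut recv_val: f32 = 0.0;     // incoming value for the left ghost cell
mpi::request::scope(|scope| {
    // Post both non-blocking operations before waiting; the WaitGuards wait on drop.
    let _sreq = if let Some(r) = right {
        Some(WaitGuard::from(world.process_at_rank(r).immediate_send(scope, &send_val)))
    } else {
        None
    };
    let _rreq = if let Some(l) = left {
        Some(WaitGuard::from(world.process_at_rank(l).immediate_receive_into(scope, &mut recv_val)))
    } else {
        None
    };
});
u_old[start - 1] = if left.is_some() { recv_val } else { 0.0 }; // 0.0 is the physical boundary
// Second pass (mirror): send u_old[start] to `left`, receive into u_old[end + 1] from `right`.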



Source: Stack Overflow, licensed under CC BY-SA 3.0.