Timeout doesn't time out in AsyncRead

I'm trying to implement an async read wrapper that adds read timeout functionality. The goal is for the API to remain plain AsyncRead; in other words, I don't want to add io.read(buf).timeout(t) everywhere in the code. Instead, the reader itself should return the appropriate io::ErrorKind::TimedOut once the given timeout expires.

I can't get the delay to poll as Ready, though; it's always Pending. I've tried async-std, futures, and smol-timeout, all with the same result. The timeout does trigger when awaited, but it doesn't when polled manually. I know timeouts aren't easy: something needs to wake the task up. What am I doing wrong, and how can I make this work?

use async_std::{
    future::Future,
    io,
    pin::Pin,
    task::{sleep, Context, Poll},
};
use std::time::Duration;

/// Wrapper around an I/O object `IO` that carries a read timeout.
///
/// The read implementation for this type reports `io::ErrorKind::TimedOut`
/// itself, so callers can keep using the plain read API.
pub struct PrudentIo<IO> {
    /// Timer for the read that is currently waiting; `None` when no timer is armed.
    expired: Option<Pin<Box<dyn Future<Output = ()> + Sync + Send>>>,
    /// Duration handed to `delay` whenever a timer is armed.
    timeout: Duration,
    /// The wrapped I/O object that reads and writes are forwarded to.
    io: IO,
}

impl<IO> PrudentIo<IO> {
    pub fn new(timeout: Duration, io: IO) -> Self {
        PrudentIo {
            expired: None,
            timeout,
            io,
        }
    }
}

/// Builds the boxed timer future used for read timeouts.
///
/// Returns `None` for a zero duration; otherwise returns a pinned, boxed
/// `sleep(t)` future.
fn delay(t: Duration) -> Option<Pin<Box<dyn Future<Output = ()> + Sync + Send + 'static>>> {
    if t.is_zero() {
        None
    } else {
        let timer: Pin<Box<dyn Future<Output = ()> + Sync + Send + 'static>> =
            Box::pin(sleep(t));
        Some(timer)
    }
}

impl<IO: io::Read + Unpin> io::Read for PrudentIo<IO> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        if let Some(ref mut expired) = self.expired {
            match expired.as_mut().poll(cx) {
                Poll::Ready(_) => {
                    println!("expired ready");
                    // too much time passed since last read/write
                    return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
                }
                Poll::Pending => {
                    println!("expired pending");
                    // in good time
                }
            }
        }

        let res = Pin::new(&mut self.io).poll_read(cx, buf);
        println!("read {:?}", res);

        match res {
            Poll::Pending => {
                if self.expired.is_none() {
                    // No data, start checking for a timeout
                    self.expired = delay(self.timeout);
                }
            }
            Poll::Ready(_) => self.expired = None,
        }

        res
    }
}
/// Write half of the wrapper: every call is forwarded unchanged to the wrapped `io`.
impl<IO: io::Write + Unpin> io::Write for PrudentIo<IO> {
    /// Forwards the write to the wrapped `io`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let inner = Pin::new(&mut self.get_mut().io);
        inner.poll_write(cx, buf)
    }

    /// Forwards the flush to the wrapped `io`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = Pin::new(&mut self.get_mut().io);
        inner.poll_flush(cx)
    }

    /// Forwards the close to the wrapped `io`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = Pin::new(&mut self.get_mut().io);
        inner.poll_close(cx)
    }
}

// Tests for `PrudentIo` and `delay`, driven by a mock reader that never
// produces data.
#[cfg(test)]
mod io_tests {
    use super::*;
    use async_std::io::ReadExt;
    use async_std::prelude::FutureExt;
    // NOTE(review): `copy`, `Cursor` and `TcpStream` are not referenced by the
    // tests visible in this module.
    use async_std::{
        io::{copy, Cursor},
        net::TcpStream,
    };
    use std::time::Duration;

    /// Reads from an always-pending mock through a `PrudentIo` with a 5 ms
    /// timeout and snapshots the outcome.
    ///
    /// The read is additionally wrapped in a 1 s `.timeout(..)`; presumably
    /// this is a guard so the test finishes even if the wrapper never reports
    /// a timeout. The snapshot expects the outer layer to be `Ok`, i.e. the
    /// wrapper itself must produce the result before the guard fires.
    #[async_std::test]
    async fn fail_read_after_timeout() -> io::Result<()> {
        let mut output = b"______".to_vec();
        let io = PendIo;
        let mut io = PrudentIo::new(Duration::from_millis(5), io);
        let mut io = Pin::new(&mut io);
        // NOTE(review): the inline snapshot text looks like a hand-written
        // placeholder -- confirm it against the real Debug output.
        insta::assert_debug_snapshot!(io.read(&mut output[..]).timeout(Duration::from_secs(1)).await,@"Ok(io::Err(timeou))");
        Ok(())
    }
    /// Awaits a 1 ms `delay` directly and snapshots its `()` output.
    #[async_std::test]
    async fn timeout_expires() {
        let later = delay(Duration::from_millis(1)).expect("some").await;
        insta::assert_debug_snapshot!(later,@r"()");
    }
    /// Mock I/O object whose every poll method returns `Poll::Pending`.
    ///
    /// It ignores the `Context` it is given, so it never arranges a wake-up
    /// itself.
    struct PendIo;
    impl io::Read for PendIo {
        // Always pending; `_cx` and `_buf` are unused.
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut [u8],
        ) -> Poll<futures_io::Result<usize>> {
            Poll::Pending
        }
    }
    impl io::Write for PendIo {
        // Always pending; nothing is ever written.
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &[u8],
        ) -> Poll<futures_io::Result<usize>> {
            Poll::Pending
        }

        // Always pending.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
            Poll::Pending
        }

        // Always pending.
        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<futures_io::Result<()>> {
            Poll::Pending
        }
    }
}



Solution 1:[1]

// This is another solution. I think it is better.

/// Read half of [`PrudentIo`]: forwards to the wrapped reader and fails with
/// `io::ErrorKind::TimedOut` when a pending read outlasts `timeout`.
impl<IO: io::AsyncRead + Unpin> io::AsyncRead for PrudentIo<IO> {
    /// Polls the wrapped reader first and the timeout timer second.
    ///
    /// Returns the inner result as soon as the reader is ready, a `TimedOut`
    /// error once the timer has fired, and `Poll::Pending` otherwise.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();

        if let Poll::Ready(res) = Pin::new(&mut this.io).poll_read(cx, buf) {
            // The wait is over: drop any timer armed by an earlier pending
            // poll. Keeping it would let a later read inherit a deadline that
            // is already partly (or fully) used up and time out too early.
            this.expired = None;
            return Poll::Ready(res);
        }

        // Arm the timer on the first pending poll and reuse it afterwards.
        // It is polled in the same call that creates it, which is what
        // registers the waker that will wake this task when it fires.
        let timeout = this.timeout;
        let expired = this.expired.get_or_insert_with(|| Timer::after(timeout));
        ready!(expired.poll(cx));

        // The timer fired: clear it so the next read starts from scratch.
        this.expired = None;
        Poll::Ready(Err(io::ErrorKind::TimedOut.into()))
    }
}

Solution 2:[2]

// 1. Uses smol, not async_std.
// 2. IO should be 'static.
// 3. On timeout, poll_read returns Poll::Ready(Err(io::ErrorKind::TimedOut.into())).

use {
    smol::{future::FutureExt, io, ready, Timer},
    std::{
        future::Future,
        pin::Pin,
        task::{Context, Poll},
        time::Duration,
    },
};

// --

/// Wrapper around an I/O object `IO` whose reads fail with a timeout error
/// when they take longer than `timeout`.
pub struct PrudentIo<IO> {
    /// In-flight boxed future polled by `poll_read`; `None` when nothing is pending.
    expired: Option<Pin<Box<dyn Future<Output = io::Result<usize>>>>>,
    /// Duration given to the timer that bounds a read.
    timeout: Duration,
    /// The wrapped I/O object that reads and writes are forwarded to.
    io: IO,
}

impl<IO> PrudentIo<IO> {
    pub fn new(timeout: Duration, io: IO) -> Self {
        PrudentIo {
            expired: None,
            timeout,
            io,
        }
    }
}

/// Read half of [`PrudentIo`]: forwards to the wrapped reader and fails with
/// `io::ErrorKind::TimedOut` when a pending read outlasts `timeout`.
///
/// The `'static` bound on `IO` is kept so the impl's interface is unchanged.
impl<IO: io::AsyncRead + Unpin + 'static> io::AsyncRead for PrudentIo<IO> {
    /// Polls the wrapped reader first and the stored timeout future second.
    ///
    /// Returns the inner result as soon as the reader is ready, a `TimedOut`
    /// error once the timer has fired, and `Poll::Pending` otherwise.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();

        // Poll the reader directly with the buffer of *this* call. Only the
        // timer lives in the stored future: a stored read future would have to
        // hold references to `this.io` and `buf` beyond this call, and neither
        // is guaranteed to stay valid (the caller may pass a different buffer
        // next time, and the wrapper is `Unpin`, so it may be moved).
        if let Poll::Ready(res) = Pin::new(&mut this.io).poll_read(cx, buf) {
            // The wait is over; the next pending read gets a fresh timer.
            this.expired = None;
            return Poll::Ready(res);
        }

        // Arm the timeout future on the first pending poll and reuse it
        // afterwards. It owns everything it needs (`timeout` is `Copy`).
        let timeout = this.timeout;
        let expired = this.expired.get_or_insert_with(|| {
            async move {
                Timer::after(timeout).await;
                io::Result::<usize>::Err(io::ErrorKind::TimedOut.into())
            }
            .boxed_local()
        });

        // Polling here registers the waker, so the task is woken when the
        // timer fires even if the reader never becomes ready.
        let res = ready!(expired.poll(cx));

        // The timer completed: clear it so it is never polled again.
        this.expired = None;
        Poll::Ready(res)
    }
}
/// Write half of the wrapper: every call is forwarded unchanged to the wrapped `io`.
impl<IO: io::AsyncWrite + Unpin> io::AsyncWrite for PrudentIo<IO> {
    /// Forwards the write to the wrapped `io`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let inner = Pin::new(&mut self.get_mut().io);
        inner.poll_write(cx, buf)
    }

    /// Forwards the flush to the wrapped `io`.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = Pin::new(&mut self.get_mut().io);
        inner.poll_flush(cx)
    }

    /// Forwards the close to the wrapped `io`.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let inner = Pin::new(&mut self.get_mut().io);
        inner.poll_close(cx)
    }
}

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 moto
Solution 2