diff --git a/Cargo.toml b/Cargo.toml index 1ce1eba..f403f86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,3 +14,4 @@ embedded-hal = "0.2.4" nb = "1" usb-device = "0.3" embedded-io = "0.6" +embedded-io-async = "0.6.1" diff --git a/src/io.rs b/src/io.rs index 0b965e4..dc083dd 100644 --- a/src/io.rs +++ b/src/io.rs @@ -1,5 +1,10 @@ use super::SerialPort; -use core::borrow::BorrowMut; +use core::{ + borrow::BorrowMut, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; use usb_device::bus::UsbBus; #[derive(Debug)] @@ -85,3 +90,111 @@ impl, WS: BorrowMut<[u8]>> embedded_io::WriteRe Ok(self.write_buf.available_write() != 0) } } + +impl embedded_io_async::Write for SerialPort<'_, B, RS, WS> +where + B: UsbBus, + RS: BorrowMut<[u8]>, + WS: BorrowMut<[u8]>, +{ + async fn write(&mut self, buffer: &[u8]) -> core::result::Result { + if buffer.is_empty() { + return Ok(0); + } + AsyncWrite { + serial_port: self, + buffer, + } + .await + } + + // async fn flush(&mut self) -> core::result::Result<(), Self::Error> { + // todo!() + // } +} +struct AsyncWrite<'a, 'b, 'c, B, RS, WS> +where + B: UsbBus, + RS: BorrowMut<[u8]>, + WS: BorrowMut<[u8]>, +{ + serial_port: &'a mut SerialPort<'b, B, RS, WS>, + buffer: &'c [u8], +} + +impl<'a, 'b, 'c, B, RS, WS> Future for AsyncWrite<'a, 'b, 'c, B, RS, WS> +where + B: UsbBus, + RS: BorrowMut<[u8]>, + WS: BorrowMut<[u8]>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let s = self.get_mut(); + match s.serial_port.write(&s.buffer) { + Ok(n) => Poll::Ready(Ok(n)), + Err(usb_device::UsbError::WouldBlock) => { + // No need to worry about overriding. + // The ownership is borrowed though the mutable reference, + // so it's impossable to run twice at the same time. + s.serial_port.write_waker = Some(cx.waker().clone()); + Poll::Pending + } + Err(err) => Poll::Ready(Err(Error(err))), + } + } +} + +impl embedded_io_async::Read for SerialPort<'_, B, RS, WS> +where + B: UsbBus, + RS: BorrowMut<[u8]>, + WS: BorrowMut<[u8]>, +{ + async fn read(&mut self, buffer: &mut [u8]) -> Result { + AsyncRead { + serial_port: self, + buffer, + } + .await + } +} + +struct AsyncRead<'a, 'b, 'c, B, RS, WS> +where + B: UsbBus, + RS: BorrowMut<[u8]>, + WS: BorrowMut<[u8]>, +{ + serial_port: &'a mut SerialPort<'b, B, RS, WS>, + buffer: &'c mut [u8], +} + +impl<'a, 'b, 'c, B, RS, WS> Future for AsyncRead<'a, 'b, 'c, B, RS, WS> +where + B: UsbBus, + RS: BorrowMut<[u8]>, + WS: BorrowMut<[u8]>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let s = self.get_mut(); + match s.serial_port.read(&mut s.buffer) { + Ok(n) => Poll::Ready(Ok(n)), + Err(usb_device::UsbError::WouldBlock) => { + if s.buffer.len() == 0 { + Poll::Ready(Ok(0)) + } else { + // No need to worry about overriding. + // The ownership is borrowed though the mutable reference, + // so it's impossable to run twice at the same time. + s.serial_port.read_waker = Some(cx.waker().clone()); + Poll::Pending + } + } + Err(err) => Poll::Ready(Err(Error(err))), + } + } +} diff --git a/src/serial_port.rs b/src/serial_port.rs index 1e05e89..57e8d18 100644 --- a/src/serial_port.rs +++ b/src/serial_port.rs @@ -2,6 +2,7 @@ use crate::buffer::{Buffer, DefaultBufferStore}; use crate::cdc_acm::*; use core::borrow::BorrowMut; use core::slice; +use core::task::Waker; use usb_device::class_prelude::*; use usb_device::descriptor::lang_id::LangID; use usb_device::Result; @@ -20,6 +21,9 @@ where pub(crate) read_buf: Buffer, pub(crate) write_buf: Buffer, write_state: WriteState, + + pub(crate) read_waker: Option, + pub(crate) write_waker: Option, } /// If this many full size packets have been sent in a row, a short packet will be sent so that the @@ -95,6 +99,8 @@ where read_buf: Buffer::new(read_store), write_buf: Buffer::new(write_store), write_state: WriteState::Idle, + read_waker: None, + write_waker: None, } } @@ -151,6 +157,9 @@ where Err(err) => Err(err), } })?; + if let Some(read_waker) = self.read_waker.take() { + read_waker.wake(); + } Ok(()) } @@ -258,6 +267,9 @@ where fn endpoint_in_complete(&mut self, addr: EndpointAddress) { if addr == self.inner.write_ep().address() { self.flush().ok(); + if let Some(write_waker) = self.write_waker.take() { + write_waker.wake(); + } } }