|
1 | 1 | //! Compat wrappers for interop with other crates.
|
2 | 2 |
|
3 |
| -use std::io::{self, BufRead, Read, Write}; |
| 3 | +use std::{ |
| 4 | + future::Future, |
| 5 | + io::{self, BufRead, Read, Write}, |
| 6 | + pin::Pin, |
| 7 | + task::{Context, Poll}, |
| 8 | +}; |
4 | 9 |
|
5 | 10 | use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut, SetBufInit};
|
| 11 | +use pin_project_lite::pin_project; |
6 | 12 |
|
7 | 13 | use crate::{buffer::Buffer, util::DEFAULT_BUF_SIZE};
|
8 | 14 |
|
@@ -158,3 +164,181 @@ impl<S: crate::AsyncWrite> SyncStream<S> {
|
158 | 164 | Ok(len)
|
159 | 165 | }
|
160 | 166 | }
|
| 167 | + |
| 168 | +type PinBoxFuture<T> = Pin<Box<dyn Future<Output = T>>>; |
| 169 | + |
| 170 | +pin_project! { |
| 171 | + /// A stream wrapper for [`futures_util::io`] traits. |
| 172 | + pub struct AsyncStream<S> { |
| 173 | + #[pin] |
| 174 | + inner: SyncStream<S>, |
| 175 | + read_future: Option<PinBoxFuture<io::Result<usize>>>, |
| 176 | + write_future: Option<PinBoxFuture<io::Result<usize>>>, |
| 177 | + shutdown_future: Option<PinBoxFuture<io::Result<()>>>, |
| 178 | + } |
| 179 | +} |
| 180 | + |
| 181 | +impl<S> AsyncStream<S> { |
| 182 | + /// Create [`AsyncStream`] with the stream and default buffer size. |
| 183 | + pub fn new(stream: S) -> Self { |
| 184 | + Self::new_impl(SyncStream::new(stream)) |
| 185 | + } |
| 186 | + |
| 187 | + /// Create [`AsyncStream`] with the stream and buffer size. |
| 188 | + pub fn with_capacity(cap: usize, stream: S) -> Self { |
| 189 | + Self::new_impl(SyncStream::with_capacity(cap, stream)) |
| 190 | + } |
| 191 | + |
| 192 | + fn new_impl(inner: SyncStream<S>) -> Self { |
| 193 | + Self { |
| 194 | + inner, |
| 195 | + read_future: None, |
| 196 | + write_future: None, |
| 197 | + shutdown_future: None, |
| 198 | + } |
| 199 | + } |
| 200 | + |
| 201 | + /// Get the reference of the inner stream. |
| 202 | + pub fn get_ref(&self) -> &S { |
| 203 | + self.inner.get_ref() |
| 204 | + } |
| 205 | +} |
| 206 | + |
| 207 | +macro_rules! poll_future { |
| 208 | + ($f:expr, $cx:expr, $e:expr) => {{ |
| 209 | + let mut future = match $f.take() { |
| 210 | + Some(f) => f, |
| 211 | + None => Box::pin($e), |
| 212 | + }; |
| 213 | + let f = future.as_mut(); |
| 214 | + match f.poll($cx) { |
| 215 | + Poll::Pending => { |
| 216 | + $f.replace(future); |
| 217 | + return Poll::Pending; |
| 218 | + } |
| 219 | + Poll::Ready(res) => res, |
| 220 | + } |
| 221 | + }}; |
| 222 | +} |
| 223 | + |
| 224 | +macro_rules! poll_future_would_block { |
| 225 | + ($f:expr, $cx:expr, $e:expr, $io:expr) => {{ |
| 226 | + if let Some(mut f) = $f.take() { |
| 227 | + if f.as_mut().poll($cx).is_pending() { |
| 228 | + $f.replace(f); |
| 229 | + return Poll::Pending; |
| 230 | + } |
| 231 | + } |
| 232 | + |
| 233 | + match $io { |
| 234 | + Ok(len) => Poll::Ready(Ok(len)), |
| 235 | + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { |
| 236 | + $f.replace(Box::pin($e)); |
| 237 | + $cx.waker().wake_by_ref(); |
| 238 | + Poll::Pending |
| 239 | + } |
| 240 | + Err(e) => Poll::Ready(Err(e)), |
| 241 | + } |
| 242 | + }}; |
| 243 | +} |
| 244 | + |
| 245 | +impl<S: crate::AsyncRead + 'static> futures_util::AsyncRead for AsyncStream<S> { |
| 246 | + fn poll_read( |
| 247 | + self: Pin<&mut Self>, |
| 248 | + cx: &mut Context<'_>, |
| 249 | + buf: &mut [u8], |
| 250 | + ) -> Poll<io::Result<usize>> { |
| 251 | + let this = self.project(); |
| 252 | + // Safety: |
| 253 | + // - The futures won't live longer than the stream. |
| 254 | + // - `self` is pinned. |
| 255 | + // - The inner stream won't be moved. |
| 256 | + let inner: &'static mut SyncStream<S> = |
| 257 | + unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; |
| 258 | + |
| 259 | + poll_future_would_block!( |
| 260 | + this.read_future, |
| 261 | + cx, |
| 262 | + inner.fill_read_buf(), |
| 263 | + io::Read::read(inner, buf) |
| 264 | + ) |
| 265 | + } |
| 266 | +} |
| 267 | + |
| 268 | +impl<S: crate::AsyncRead + 'static> futures_util::AsyncBufRead for AsyncStream<S> { |
| 269 | + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { |
| 270 | + let this = self.project(); |
| 271 | + |
| 272 | + let inner: &'static mut SyncStream<S> = |
| 273 | + unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; |
| 274 | + poll_future_would_block!( |
| 275 | + this.read_future, |
| 276 | + cx, |
| 277 | + inner.fill_read_buf(), |
| 278 | + // Safety: anyway the slice won't be used after free. |
| 279 | + io::BufRead::fill_buf(inner).map(|slice| unsafe { &*(slice as *const _) }) |
| 280 | + ) |
| 281 | + } |
| 282 | + |
| 283 | + fn consume(self: Pin<&mut Self>, amt: usize) { |
| 284 | + let this = self.project(); |
| 285 | + |
| 286 | + let inner: &'static mut SyncStream<S> = |
| 287 | + unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; |
| 288 | + inner.consume(amt) |
| 289 | + } |
| 290 | +} |
| 291 | + |
| 292 | +impl<S: crate::AsyncWrite + 'static> futures_util::AsyncWrite for AsyncStream<S> { |
| 293 | + fn poll_write( |
| 294 | + self: Pin<&mut Self>, |
| 295 | + cx: &mut Context<'_>, |
| 296 | + buf: &[u8], |
| 297 | + ) -> Poll<io::Result<usize>> { |
| 298 | + let this = self.project(); |
| 299 | + |
| 300 | + if this.shutdown_future.is_some() { |
| 301 | + debug_assert!(this.write_future.is_none()); |
| 302 | + return Poll::Pending; |
| 303 | + } |
| 304 | + |
| 305 | + let inner: &'static mut SyncStream<S> = |
| 306 | + unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; |
| 307 | + poll_future_would_block!( |
| 308 | + this.write_future, |
| 309 | + cx, |
| 310 | + inner.flush_write_buf(), |
| 311 | + io::Write::write(inner, buf) |
| 312 | + ) |
| 313 | + } |
| 314 | + |
| 315 | + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| 316 | + let this = self.project(); |
| 317 | + |
| 318 | + if this.shutdown_future.is_some() { |
| 319 | + debug_assert!(this.write_future.is_none()); |
| 320 | + return Poll::Pending; |
| 321 | + } |
| 322 | + |
| 323 | + let inner: &'static mut SyncStream<S> = |
| 324 | + unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; |
| 325 | + let res = poll_future!(this.write_future, cx, inner.flush_write_buf()); |
| 326 | + Poll::Ready(res.map(|_| ())) |
| 327 | + } |
| 328 | + |
| 329 | + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { |
| 330 | + let this = self.project(); |
| 331 | + |
| 332 | + // Avoid shutdown on flush because the inner buffer might be passed to the |
| 333 | + // driver. |
| 334 | + if this.write_future.is_some() { |
| 335 | + debug_assert!(this.shutdown_future.is_none()); |
| 336 | + return Poll::Pending; |
| 337 | + } |
| 338 | + |
| 339 | + let inner: &'static mut SyncStream<S> = |
| 340 | + unsafe { &mut *(this.inner.get_unchecked_mut() as *mut _) }; |
| 341 | + let res = poll_future!(this.shutdown_future, cx, inner.get_mut().shutdown()); |
| 342 | + Poll::Ready(res) |
| 343 | + } |
| 344 | +} |
0 commit comments