Skip to content

Commit aabc8a2

Browse files
author
Aryan Tikarya
committed
address comments
1 parent 6e031d6 commit aabc8a2

File tree

2 files changed

+50
-78
lines changed

2 files changed

+50
-78
lines changed

yamux/src/connection.rs

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
504504
}
505505

506506
fn on_drop_stream(&mut self, stream_id: StreamId) -> Option<Frame<()>> {
507-
let s = self.streams.remove(&stream_id).expect("stream not found");
507+
let mut s = self.streams.remove(&stream_id).expect("stream not found");
508508

509509
log::trace!("{}: removing dropped stream {}", self.id, stream_id);
510510
let frame = s.with_mut(|inner| {
@@ -565,7 +565,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
565565
&& matches!(frame.header().tag(), Tag::Data | Tag::WindowUpdate)
566566
{
567567
let id = frame.header().stream_id();
568-
if let Some(shared) = self.streams.get(&id) {
568+
if let Some(shared) = self.streams.get_mut(&id) {
569569
shared.update_state(self.id, id, State::Open { acknowledged: true });
570570
}
571571
if let Some(waker) = self.new_outbound_stream_waker.take() {
@@ -625,16 +625,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
625625
log::error!("{}: maximum number of streams reached", self.id);
626626
return Action::Terminate(Frame::internal_error());
627627
}
628-
let stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
629-
{
630-
stream.shared().with_mut(|inner| {
631-
if is_finish {
632-
inner.update_state(self.id, stream_id, State::RecvClosed);
633-
}
634-
inner.consume_receive_window(frame.body_len());
635-
inner.buffer.push(frame.into_body());
636-
})
637-
}
628+
let mut stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
629+
stream.shared_mut().with_mut(|inner| {
630+
if is_finish {
631+
inner.update_state(self.id, stream_id, State::RecvClosed);
632+
}
633+
inner.consume_receive_window(frame.body_len());
634+
inner.buffer.push(frame.into_body());
635+
});
638636
self.streams.insert(stream_id, stream.clone_shared());
639637
return Action::New(stream);
640638
}
@@ -660,7 +658,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
660658
Action::None
661659
}
662660
});
663-
return action;
661+
action
664662
} else {
665663
log::trace!(
666664
"{}/{}: data frame for unknown stream, possibly dropped earlier: {:?}",
@@ -675,9 +673,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
675673
// termination for the remote.
676674
//
677675
// See https://github.com/paritytech/yamux/issues/110 for details.
676+
Action::None
678677
}
679-
680-
Action::None
681678
}
682679

683680
fn on_window_update(&mut self, frame: &Frame<WindowUpdate>) -> Action {
@@ -717,11 +714,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
717714
}
718715

719716
let credit = frame.header().credit() + DEFAULT_CREDIT;
720-
let stream = self.make_new_inbound_stream(stream_id, credit);
717+
let mut stream = self.make_new_inbound_stream(stream_id, credit);
721718

722719
if is_finish {
723720
stream
724-
.shared()
721+
.shared_mut()
725722
.update_state(self.id, stream_id, State::RecvClosed);
726723
}
727724
self.streams.insert(stream_id, stream.clone_shared());
@@ -874,7 +871,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
874871
impl<T> Active<T> {
875872
/// Close and drop all `Stream`s and wake any pending `Waker`s.
876873
fn drop_all_streams(&mut self) {
877-
for (id, shared) in self.streams.drain() {
874+
for (id, mut shared) in self.streams.drain() {
878875
shared.with_mut(|inner| {
879876
inner.update_state(self.id, id, State::Closed);
880877
if let Some(w) = inner.reader.take() {

yamux/src/connection/stream.rs

Lines changed: 35 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
88
// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
99
// at https://opensource.org/licenses/MIT.
10-
11-
use crate::chunks::Chunk;
1210
use crate::connection::rtt::Rtt;
1311
use crate::frame::header::ACK;
1412
use crate::{
@@ -28,7 +26,7 @@ use futures::{
2826
ready, SinkExt,
2927
};
3028

31-
use parking_lot::Mutex;
29+
use parking_lot::{Mutex, MutexGuard};
3230
use std::{
3331
fmt, io,
3432
pin::Pin,
@@ -179,14 +177,12 @@ impl Stream {
179177
matches!(self.shared.state(), State::Closed)
180178
}
181179

182-
/// Whether we are still waiting for the remote to acknowledge this stream.
183180
pub fn is_pending_ack(&self) -> bool {
184181
self.shared.is_pending_ack()
185182
}
186183

187-
/// Returns a reference to the `Shared` concurrency wrapper.
188-
pub(crate) fn shared(&self) -> &Shared {
189-
&self.shared
184+
pub(crate) fn shared_mut(&mut self) -> &mut Shared {
185+
&mut self.shared
190186
}
191187

192188
pub(crate) fn clone_shared(&self) -> Shared {
@@ -265,7 +261,8 @@ impl futures::stream::Stream for Stream {
265261
Poll::Pending => {}
266262
}
267263

268-
if let Some(bytes) = self.shared.pop_buffer() {
264+
let mut shared = self.shared.lock();
265+
if let Some(bytes) = shared.buffer.pop() {
269266
let off = bytes.offset();
270267
let mut vec = bytes.into_vec();
271268
if off != 0 {
@@ -276,23 +273,21 @@ impl futures::stream::Stream for Stream {
276273
log::debug!(
277274
"{}/{}: chunk has been partially consumed",
278275
self.conn,
279-
self.id
276+
self.id,
280277
);
281278
vec = vec.split_off(off)
282279
}
283280
return Poll::Ready(Some(Ok(Packet(vec))));
284281
}
285-
286282
// Buffer is empty, let's check if we can expect to read more data.
287-
if !self.shared.state().can_read() {
283+
if !shared.state.can_read() {
288284
log::debug!("{}/{}: eof", self.conn, self.id);
289285
return Poll::Ready(None); // stream has been reset
290286
}
291287

292288
// Since we have no more data at this point, we want to be woken up
293289
// by the connection when more becomes available for us.
294-
self.shared.set_reader_waker(Some(cx.waker().clone()));
295-
290+
shared.reader = Some(cx.waker().clone());
296291
Poll::Pending
297292
}
298293
}
@@ -318,7 +313,9 @@ impl AsyncRead for Stream {
318313

319314
// Copy data from stream buffer.
320315
let mut n = 0;
321-
let can_read = self.shared.with_mut(|inner| {
316+
let conn_id = self.conn;
317+
let stream_id = self.id;
318+
let poll_state = self.shared.with_mut(|inner| {
322319
while let Some(chunk) = inner.buffer.front_mut() {
323320
if chunk.is_empty() {
324321
inner.buffer.pop();
@@ -334,31 +331,23 @@ impl AsyncRead for Stream {
334331
}
335332

336333
if n > 0 {
337-
return true;
334+
log::trace!("{}/{}: read {} bytes", conn_id, stream_id, n);
335+
return Poll::Ready(Ok(n));
338336
}
339337

340338
// Buffer is empty, let's check if we can expect to read more data.
341339
if !inner.state.can_read() {
342-
return false; // No more data available
340+
log::debug!("{}/{}: eof", conn_id, stream_id);
341+
return Poll::Ready(Ok(0)); // stream has been reset
343342
}
344343

345344
// Since we have no more data at this point, we want to be woken up
346345
// by the connection when more becomes available for us.
347346
inner.reader = Some(cx.waker().clone());
348-
true
347+
Poll::Pending
349348
});
350349

351-
if n > 0 {
352-
log::trace!("{}/{}: read {} bytes", self.conn, self.id, n);
353-
return Poll::Ready(Ok(n));
354-
}
355-
356-
if !can_read {
357-
log::debug!("{}/{}: eof", self.conn, self.id);
358-
return Poll::Ready(Ok(0)); // stream has been reset
359-
}
360-
361-
Poll::Pending
350+
poll_state
362351
}
363352
}
364353

@@ -373,18 +362,19 @@ impl AsyncWrite for Stream {
373362
.poll_ready(cx)
374363
.map_err(|_| self.write_zero_err())?);
375364

376-
let result = self.shared.with_mut(|inner| {
377-
if !inner.state.can_write() {
365+
let body = {
366+
let mut shared = self.shared.lock();
367+
if !shared.state.can_write() {
378368
log::debug!("{}/{}: can no longer write", self.conn, self.id);
379369
// Return an error
380-
return Err(self.write_zero_err());
370+
return Poll::Ready(Err(self.write_zero_err()));
381371
}
382372

383-
let window = inner.send_window();
373+
let window = shared.send_window();
384374
if window == 0 {
385375
log::trace!("{}/{}: no more credit left", self.conn, self.id);
386-
inner.writer = Some(cx.waker().clone());
387-
return Ok(None); // means we are Pending
376+
shared.writer = Some(cx.waker().clone());
377+
return Poll::Pending;
388378
}
389379

390380
let k = std::cmp::min(window, buf.len().try_into().unwrap_or(u32::MAX));
@@ -394,15 +384,8 @@ impl AsyncWrite for Stream {
394384
self.config.split_send_size.try_into().unwrap_or(u32::MAX),
395385
);
396386

397-
inner.consume_send_window(k);
398-
let body = Some(Vec::from(&buf[..k as usize]));
399-
Ok(body)
400-
});
401-
402-
let body = match result {
403-
Err(e) => return Poll::Ready(Err(e)), // can't write
404-
Ok(None) => return Poll::Pending, // no credit => Pending
405-
Ok(Some(b)) => b, // we have a body
387+
shared.consume_send_window(k);
388+
Vec::from(&buf[..k as usize])
406389
};
407390

408391
let n = body.len();
@@ -415,9 +398,8 @@ impl AsyncWrite for Stream {
415398
// a) to be consistent with outbound streams
416399
// b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test.
417400
if frame.header().flags().contains(ACK) {
418-
self.shared.with_mut(|inner| {
419-
inner.update_state(self.conn, self.id, State::Open { acknowledged: true });
420-
});
401+
self.shared
402+
.update_state(self.conn, self.id, State::Open { acknowledged: true });
421403
}
422404

423405
let cmd = StreamCommand::SendFrame(frame);
@@ -452,9 +434,8 @@ impl AsyncWrite for Stream {
452434
self.sender
453435
.start_send(cmd)
454436
.map_err(|_| self.write_zero_err())?;
455-
self.shared.with_mut(|inner| {
456-
inner.update_state(self.conn, self.id, State::SendClosed);
457-
});
437+
self.shared
438+
.update_state(self.conn, self.id, State::SendClosed);
458439
Poll::Ready(Ok(()))
459440
}
460441
}
@@ -487,10 +468,6 @@ impl Shared {
487468
self.inner.lock().state
488469
}
489470

490-
pub fn pop_buffer(&self) -> Option<Chunk> {
491-
self.with_mut(|inner| inner.buffer.pop())
492-
}
493-
494471
pub fn is_pending_ack(&self) -> bool {
495472
self.inner.lock().is_pending_ack()
496473
}
@@ -499,17 +476,15 @@ impl Shared {
499476
self.inner.lock().next_window_update()
500477
}
501478

502-
pub fn set_reader_waker(&self, waker: Option<Waker>) {
503-
self.with_mut(|inner| {
504-
inner.reader = waker;
505-
});
479+
pub fn update_state(&self, cid: connection::Id, sid: StreamId, next: State) -> State {
480+
self.inner.lock().update_state(cid, sid, next)
506481
}
507482

508-
pub fn update_state(&self, cid: connection::Id, sid: StreamId, next: State) -> State {
509-
self.with_mut(|inner| inner.update_state(cid, sid, next))
483+
pub fn lock(&self) -> MutexGuard<'_, SharedInner> {
484+
self.inner.lock()
510485
}
511486

512-
pub fn with_mut<F, R>(&self, f: F) -> R
487+
pub fn with_mut<F, R>(&mut self, f: F) -> R
513488
where
514489
F: FnOnce(&mut SharedInner) -> R,
515490
{

0 commit comments

Comments
 (0)