Skip to content

Commit f7ad5e9

Browse files
author
Aryan Tikarya
committed
address comments
1 parent 6e031d6 commit f7ad5e9

File tree

2 files changed

+58
-91
lines changed

2 files changed

+58
-91
lines changed

yamux/src/connection.rs

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
504504
}
505505

506506
fn on_drop_stream(&mut self, stream_id: StreamId) -> Option<Frame<()>> {
507-
let s = self.streams.remove(&stream_id).expect("stream not found");
507+
let mut s = self.streams.remove(&stream_id).expect("stream not found");
508508

509509
log::trace!("{}: removing dropped stream {}", self.id, stream_id);
510510
let frame = s.with_mut(|inner| {
@@ -565,7 +565,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
565565
&& matches!(frame.header().tag(), Tag::Data | Tag::WindowUpdate)
566566
{
567567
let id = frame.header().stream_id();
568-
if let Some(shared) = self.streams.get(&id) {
568+
if let Some(shared) = self.streams.get_mut(&id) {
569569
shared.update_state(self.id, id, State::Open { acknowledged: true });
570570
}
571571
if let Some(waker) = self.new_outbound_stream_waker.take() {
@@ -625,16 +625,14 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
625625
log::error!("{}: maximum number of streams reached", self.id);
626626
return Action::Terminate(Frame::internal_error());
627627
}
628-
let stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
629-
{
630-
stream.shared().with_mut(|inner| {
631-
if is_finish {
632-
inner.update_state(self.id, stream_id, State::RecvClosed);
633-
}
634-
inner.consume_receive_window(frame.body_len());
635-
inner.buffer.push(frame.into_body());
636-
})
637-
}
628+
let mut stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
629+
stream.shared_mut().with_mut(|inner| {
630+
if is_finish {
631+
inner.update_state(self.id, stream_id, State::RecvClosed);
632+
}
633+
inner.consume_receive_window(frame.body_len());
634+
inner.buffer.push(frame.into_body());
635+
});
638636
self.streams.insert(stream_id, stream.clone_shared());
639637
return Action::New(stream);
640638
}
@@ -660,7 +658,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
660658
Action::None
661659
}
662660
});
663-
return action;
661+
action
664662
} else {
665663
log::trace!(
666664
"{}/{}: data frame for unknown stream, possibly dropped earlier: {:?}",
@@ -675,9 +673,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
675673
// termination for the remote.
676674
//
677675
// See https://github.com/paritytech/yamux/issues/110 for details.
676+
Action::None
678677
}
679-
680-
Action::None
681678
}
682679

683680
fn on_window_update(&mut self, frame: &Frame<WindowUpdate>) -> Action {
@@ -717,11 +714,11 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
717714
}
718715

719716
let credit = frame.header().credit() + DEFAULT_CREDIT;
720-
let stream = self.make_new_inbound_stream(stream_id, credit);
717+
let mut stream = self.make_new_inbound_stream(stream_id, credit);
721718

722719
if is_finish {
723720
stream
724-
.shared()
721+
.shared_mut()
725722
.update_state(self.id, stream_id, State::RecvClosed);
726723
}
727724
self.streams.insert(stream_id, stream.clone_shared());
@@ -874,7 +871,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
874871
impl<T> Active<T> {
875872
/// Close and drop all `Stream`s and wake any pending `Waker`s.
876873
fn drop_all_streams(&mut self) {
877-
for (id, shared) in self.streams.drain() {
874+
for (id, mut shared) in self.streams.drain() {
878875
shared.with_mut(|inner| {
879876
inner.update_state(self.id, id, State::Closed);
880877
if let Some(w) = inner.reader.take() {

yamux/src/connection/stream.rs

Lines changed: 43 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
// as LICENSE-MIT. You may also obtain a copy of the Apache License, Version 2.0
88
// at https://www.apache.org/licenses/LICENSE-2.0 and a copy of the MIT license
99
// at https://opensource.org/licenses/MIT.
10-
11-
use crate::chunks::Chunk;
1210
use crate::connection::rtt::Rtt;
1311
use crate::frame::header::ACK;
1412
use crate::{
@@ -28,7 +26,7 @@ use futures::{
2826
ready, SinkExt,
2927
};
3028

31-
use parking_lot::Mutex;
29+
use parking_lot::{Mutex, MutexGuard};
3230
use std::{
3331
fmt, io,
3432
pin::Pin,
@@ -179,14 +177,12 @@ impl Stream {
179177
matches!(self.shared.state(), State::Closed)
180178
}
181179

182-
/// Whether we are still waiting for the remote to acknowledge this stream.
183180
pub fn is_pending_ack(&self) -> bool {
184181
self.shared.is_pending_ack()
185182
}
186183

187-
/// Returns a reference to the `Shared` concurrency wrapper.
188-
pub(crate) fn shared(&self) -> &Shared {
189-
&self.shared
184+
pub(crate) fn shared_mut(&mut self) -> &mut Shared {
185+
&mut self.shared
190186
}
191187

192188
pub(crate) fn clone_shared(&self) -> Shared {
@@ -265,7 +261,8 @@ impl futures::stream::Stream for Stream {
265261
Poll::Pending => {}
266262
}
267263

268-
if let Some(bytes) = self.shared.pop_buffer() {
264+
let mut shared = self.shared.lock();
265+
if let Some(bytes) = shared.buffer.pop() {
269266
let off = bytes.offset();
270267
let mut vec = bytes.into_vec();
271268
if off != 0 {
@@ -276,23 +273,21 @@ impl futures::stream::Stream for Stream {
276273
log::debug!(
277274
"{}/{}: chunk has been partially consumed",
278275
self.conn,
279-
self.id
276+
self.id,
280277
);
281278
vec = vec.split_off(off)
282279
}
283280
return Poll::Ready(Some(Ok(Packet(vec))));
284281
}
285-
286282
// Buffer is empty, let's check if we can expect to read more data.
287-
if !self.shared.state().can_read() {
283+
if !shared.state.can_read() {
288284
log::debug!("{}/{}: eof", self.conn, self.id);
289285
return Poll::Ready(None); // stream has been reset
290286
}
291287

292288
// Since we have no more data at this point, we want to be woken up
293289
// by the connection when more becomes available for us.
294-
self.shared.set_reader_waker(Some(cx.waker().clone()));
295-
290+
shared.reader = Some(cx.waker().clone());
296291
Poll::Pending
297292
}
298293
}
@@ -317,47 +312,36 @@ impl AsyncRead for Stream {
317312
}
318313

319314
// Copy data from stream buffer.
315+
let mut shared = self.shared.lock();
320316
let mut n = 0;
321-
let can_read = self.shared.with_mut(|inner| {
322-
while let Some(chunk) = inner.buffer.front_mut() {
323-
if chunk.is_empty() {
324-
inner.buffer.pop();
325-
continue;
326-
}
327-
let k = std::cmp::min(chunk.len(), buf.len() - n);
328-
buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
329-
n += k;
330-
chunk.advance(k);
331-
if n == buf.len() {
332-
break;
333-
}
317+
while let Some(chunk) = shared.buffer.front_mut() {
318+
if chunk.is_empty() {
319+
shared.buffer.pop();
320+
continue;
334321
}
335-
336-
if n > 0 {
337-
return true;
338-
}
339-
340-
// Buffer is empty, let's check if we can expect to read more data.
341-
if !inner.state.can_read() {
342-
return false; // No more data available
322+
let k = std::cmp::min(chunk.len(), buf.len() - n);
323+
buf[n..n + k].copy_from_slice(&chunk.as_ref()[..k]);
324+
n += k;
325+
chunk.advance(k);
326+
if n == buf.len() {
327+
break;
343328
}
344-
345-
// Since we have no more data at this point, we want to be woken up
346-
// by the connection when more becomes available for us.
347-
inner.reader = Some(cx.waker().clone());
348-
true
349-
});
329+
}
350330

351331
if n > 0 {
352332
log::trace!("{}/{}: read {} bytes", self.conn, self.id, n);
353333
return Poll::Ready(Ok(n));
354334
}
355335

356-
if !can_read {
336+
// Buffer is empty, let's check if we can expect to read more data.
337+
if !shared.state.can_read() {
357338
log::debug!("{}/{}: eof", self.conn, self.id);
358339
return Poll::Ready(Ok(0)); // stream has been reset
359340
}
360341

342+
// Since we have no more data at this point, we want to be woken up
343+
// by the connection when more becomes available for us.
344+
shared.reader = Some(cx.waker().clone());
361345
Poll::Pending
362346
}
363347
}
@@ -373,18 +357,19 @@ impl AsyncWrite for Stream {
373357
.poll_ready(cx)
374358
.map_err(|_| self.write_zero_err())?);
375359

376-
let result = self.shared.with_mut(|inner| {
377-
if !inner.state.can_write() {
360+
let body = {
361+
let mut shared = self.shared.lock();
362+
if !shared.state.can_write() {
378363
log::debug!("{}/{}: can no longer write", self.conn, self.id);
379364
// Return an error
380-
return Err(self.write_zero_err());
365+
return Poll::Ready(Err(self.write_zero_err()));
381366
}
382367

383-
let window = inner.send_window();
368+
let window = shared.send_window();
384369
if window == 0 {
385370
log::trace!("{}/{}: no more credit left", self.conn, self.id);
386-
inner.writer = Some(cx.waker().clone());
387-
return Ok(None); // means we are Pending
371+
shared.writer = Some(cx.waker().clone());
372+
return Poll::Pending;
388373
}
389374

390375
let k = std::cmp::min(window, buf.len().try_into().unwrap_or(u32::MAX));
@@ -394,15 +379,8 @@ impl AsyncWrite for Stream {
394379
self.config.split_send_size.try_into().unwrap_or(u32::MAX),
395380
);
396381

397-
inner.consume_send_window(k);
398-
let body = Some(Vec::from(&buf[..k as usize]));
399-
Ok(body)
400-
});
401-
402-
let body = match result {
403-
Err(e) => return Poll::Ready(Err(e)), // can't write
404-
Ok(None) => return Poll::Pending, // no credit => Pending
405-
Ok(Some(b)) => b, // we have a body
382+
shared.consume_send_window(k);
383+
Vec::from(&buf[..k as usize])
406384
};
407385

408386
let n = body.len();
@@ -415,9 +393,8 @@ impl AsyncWrite for Stream {
415393
// a) to be consistent with outbound streams
416394
// b) to correctly test our behaviour around timing of when ACKs are sent. See `ack_timing.rs` test.
417395
if frame.header().flags().contains(ACK) {
418-
self.shared.with_mut(|inner| {
419-
inner.update_state(self.conn, self.id, State::Open { acknowledged: true });
420-
});
396+
self.shared
397+
.update_state(self.conn, self.id, State::Open { acknowledged: true });
421398
}
422399

423400
let cmd = StreamCommand::SendFrame(frame);
@@ -452,9 +429,8 @@ impl AsyncWrite for Stream {
452429
self.sender
453430
.start_send(cmd)
454431
.map_err(|_| self.write_zero_err())?;
455-
self.shared.with_mut(|inner| {
456-
inner.update_state(self.conn, self.id, State::SendClosed);
457-
});
432+
self.shared
433+
.update_state(self.conn, self.id, State::SendClosed);
458434
Poll::Ready(Ok(()))
459435
}
460436
}
@@ -487,10 +463,6 @@ impl Shared {
487463
self.inner.lock().state
488464
}
489465

490-
pub fn pop_buffer(&self) -> Option<Chunk> {
491-
self.with_mut(|inner| inner.buffer.pop())
492-
}
493-
494466
pub fn is_pending_ack(&self) -> bool {
495467
self.inner.lock().is_pending_ack()
496468
}
@@ -499,17 +471,15 @@ impl Shared {
499471
self.inner.lock().next_window_update()
500472
}
501473

502-
pub fn set_reader_waker(&self, waker: Option<Waker>) {
503-
self.with_mut(|inner| {
504-
inner.reader = waker;
505-
});
474+
pub fn update_state(&self, cid: connection::Id, sid: StreamId, next: State) -> State {
475+
self.inner.lock().update_state(cid, sid, next)
506476
}
507477

508-
pub fn update_state(&self, cid: connection::Id, sid: StreamId, next: State) -> State {
509-
self.with_mut(|inner| inner.update_state(cid, sid, next))
478+
pub fn lock(&self) -> MutexGuard<'_, SharedInner> {
479+
self.inner.lock()
510480
}
511481

512-
pub fn with_mut<F, R>(&self, f: F) -> R
482+
pub fn with_mut<F, R>(&mut self, f: F) -> R
513483
where
514484
F: FnOnce(&mut SharedInner) -> R,
515485
{

0 commit comments

Comments
 (0)