Skip to content

Commit 6e031d6

Browse files
author
Aryan Tikarya
committed
refactor: Shared to use internal mutability
1 parent 3f9e8c3 commit 6e031d6

File tree

2 files changed

+223
-136
lines changed

2 files changed

+223
-136
lines changed

yamux/src/connection.rs

Lines changed: 80 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ struct Active<T> {
282282
socket: Fuse<frame::Io<T>>,
283283
next_id: u32,
284284

285-
streams: IntMap<StreamId, Arc<Mutex<stream::Shared>>>,
285+
streams: IntMap<StreamId, stream::Shared>,
286286
stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
287287
no_streams_waker: Option<Waker>,
288288

@@ -507,9 +507,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
507507
let s = self.streams.remove(&stream_id).expect("stream not found");
508508

509509
log::trace!("{}: removing dropped stream {}", self.id, stream_id);
510-
let frame = {
511-
let mut shared = s.lock();
512-
let frame = match shared.update_state(self.id, stream_id, State::Closed) {
510+
let frame = s.with_mut(|inner| {
511+
let frame = match inner.update_state(self.id, stream_id, State::Closed) {
513512
// The stream was dropped without calling `poll_close`.
514513
// We reset the stream to inform the remote of the closure.
515514
State::Open { .. } => {
@@ -541,14 +540,15 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
541540
// remote end has already done so in the past.
542541
State::Closed => None,
543542
};
544-
if let Some(w) = shared.reader.take() {
543+
if let Some(w) = inner.reader.take() {
545544
w.wake()
546545
}
547-
if let Some(w) = shared.writer.take() {
546+
if let Some(w) = inner.writer.take() {
548547
w.wake()
549548
}
549+
550550
frame
551-
};
551+
});
552552
frame.map(Into::into)
553553
}
554554

@@ -565,10 +565,8 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
565565
&& matches!(frame.header().tag(), Tag::Data | Tag::WindowUpdate)
566566
{
567567
let id = frame.header().stream_id();
568-
if let Some(stream) = self.streams.get(&id) {
569-
stream
570-
.lock()
571-
.update_state(self.id, id, State::Open { acknowledged: true });
568+
if let Some(shared) = self.streams.get(&id) {
569+
shared.update_state(self.id, id, State::Open { acknowledged: true });
572570
}
573571
if let Some(waker) = self.new_outbound_stream_waker.take() {
574572
waker.wake();
@@ -590,14 +588,15 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
590588
if frame.header().flags().contains(header::RST) {
591589
// stream reset
592590
if let Some(s) = self.streams.get_mut(&stream_id) {
593-
let mut shared = s.lock();
594-
shared.update_state(self.id, stream_id, State::Closed);
595-
if let Some(w) = shared.reader.take() {
596-
w.wake()
597-
}
598-
if let Some(w) = shared.writer.take() {
599-
w.wake()
600-
}
591+
s.with_mut(|inner| {
592+
inner.update_state(self.id, stream_id, State::Closed);
593+
if let Some(w) = inner.reader.take() {
594+
w.wake()
595+
}
596+
if let Some(w) = inner.writer.take() {
597+
w.wake()
598+
}
599+
});
601600
}
602601
return Action::None;
603602
}
@@ -628,35 +627,40 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
628627
}
629628
let stream = self.make_new_inbound_stream(stream_id, DEFAULT_CREDIT);
630629
{
631-
let mut shared = stream.shared();
632-
if is_finish {
633-
shared.update_state(self.id, stream_id, State::RecvClosed);
634-
}
635-
shared.consume_receive_window(frame.body_len());
636-
shared.buffer.push(frame.into_body());
630+
stream.shared().with_mut(|inner| {
631+
if is_finish {
632+
inner.update_state(self.id, stream_id, State::RecvClosed);
633+
}
634+
inner.consume_receive_window(frame.body_len());
635+
inner.buffer.push(frame.into_body());
636+
})
637637
}
638638
self.streams.insert(stream_id, stream.clone_shared());
639639
return Action::New(stream);
640640
}
641641

642-
if let Some(s) = self.streams.get_mut(&stream_id) {
643-
let mut shared = s.lock();
644-
if frame.body_len() > shared.receive_window() {
645-
log::error!(
646-
"{}/{}: frame body larger than window of stream",
647-
self.id,
648-
stream_id
649-
);
650-
return Action::Terminate(Frame::protocol_error());
651-
}
652-
if is_finish {
653-
shared.update_state(self.id, stream_id, State::RecvClosed);
654-
}
655-
shared.consume_receive_window(frame.body_len());
656-
shared.buffer.push(frame.into_body());
657-
if let Some(w) = shared.reader.take() {
658-
w.wake()
659-
}
642+
if let Some(shared) = self.streams.get_mut(&stream_id) {
643+
let action = shared.with_mut(|inner| {
644+
if frame.body_len() > inner.receive_window() {
645+
log::error!(
646+
"{}/{}: frame body larger than window of stream",
647+
self.id,
648+
stream_id
649+
);
650+
Action::Terminate(Frame::protocol_error())
651+
} else {
652+
if is_finish {
653+
inner.update_state(self.id, stream_id, State::RecvClosed);
654+
}
655+
inner.consume_receive_window(frame.body_len());
656+
inner.buffer.push(frame.into_body());
657+
if let Some(w) = inner.reader.take() {
658+
w.wake()
659+
}
660+
Action::None
661+
}
662+
});
663+
return action;
660664
} else {
661665
log::trace!(
662666
"{}/{}: data frame for unknown stream, possibly dropped earlier: {:?}",
@@ -681,15 +685,16 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
681685

682686
if frame.header().flags().contains(header::RST) {
683687
// stream reset
684-
if let Some(s) = self.streams.get_mut(&stream_id) {
685-
let mut shared = s.lock();
686-
shared.update_state(self.id, stream_id, State::Closed);
687-
if let Some(w) = shared.reader.take() {
688-
w.wake()
689-
}
690-
if let Some(w) = shared.writer.take() {
691-
w.wake()
692-
}
688+
if let Some(shared) = self.streams.get_mut(&stream_id) {
689+
shared.with_mut(|inner| {
690+
inner.update_state(self.id, stream_id, State::Closed);
691+
if let Some(w) = inner.reader.take() {
692+
w.wake()
693+
}
694+
if let Some(w) = inner.writer.take() {
695+
w.wake()
696+
}
697+
});
693698
}
694699
return Action::None;
695700
}
@@ -723,19 +728,21 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
723728
return Action::New(stream);
724729
}
725730

726-
if let Some(s) = self.streams.get_mut(&stream_id) {
727-
let mut shared = s.lock();
728-
shared.increase_send_window_by(frame.header().credit());
729-
if is_finish {
730-
shared.update_state(self.id, stream_id, State::RecvClosed);
731+
if let Some(shared) = self.streams.get_mut(&stream_id) {
732+
shared.with_mut(|inner| {
733+
inner.increase_send_window_by(frame.header().credit());
734+
if is_finish {
735+
inner.update_state(self.id, stream_id, State::RecvClosed);
736+
737+
if let Some(w) = inner.reader.take() {
738+
w.wake()
739+
}
740+
}
731741

732-
if let Some(w) = shared.reader.take() {
742+
if let Some(w) = inner.writer.take() {
733743
w.wake()
734744
}
735-
}
736-
if let Some(w) = shared.writer.take() {
737-
w.wake()
738-
}
745+
});
739746
} else {
740747
log::trace!(
741748
"{}/{}: window update for unknown stream, possibly dropped earlier: {:?}",
@@ -848,7 +855,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
848855
Mode::Client => id.is_client(),
849856
Mode::Server => id.is_server(),
850857
})
851-
.filter(|(_, s)| s.lock().is_pending_ack())
858+
.filter(|(_, s)| s.is_pending_ack())
852859
.count()
853860
}
854861

@@ -867,15 +874,16 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
867874
impl<T> Active<T> {
868875
/// Close and drop all `Stream`s and wake any pending `Waker`s.
869876
fn drop_all_streams(&mut self) {
870-
for (id, s) in self.streams.drain() {
871-
let mut shared = s.lock();
872-
shared.update_state(self.id, id, State::Closed);
873-
if let Some(w) = shared.reader.take() {
874-
w.wake()
875-
}
876-
if let Some(w) = shared.writer.take() {
877-
w.wake()
878-
}
877+
for (id, shared) in self.streams.drain() {
878+
shared.with_mut(|inner| {
879+
inner.update_state(self.id, id, State::Closed);
880+
if let Some(w) = inner.reader.take() {
881+
w.wake()
882+
}
883+
if let Some(w) = inner.writer.take() {
884+
w.wake()
885+
}
886+
});
879887
}
880888
}
881889
}

0 commit comments

Comments
 (0)