Skip to content

Commit 54d8b67

Browse files
Diggseyjbr
authored andcommitted
Fix race condition in test-utils CloseableCursor implementation
1 parent 4d8bb97 commit 54d8b67

File tree

1 file changed

+44
-34
lines changed

1 file changed

+44
-34
lines changed

tests/test_utils.rs

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -102,35 +102,47 @@ pub struct TestIO {
102102
}
103103

104104
#[derive(Default)]
105-
pub struct CloseableCursor {
106-
data: RwLock<Vec<u8>>,
107-
cursor: RwLock<usize>,
108-
waker: RwLock<Option<Waker>>,
109-
closed: RwLock<bool>,
105+
struct CloseableCursorInner {
106+
data: Vec<u8>,
107+
cursor: usize,
108+
waker: Option<Waker>,
109+
closed: bool,
110110
}
111111

112+
#[derive(Default)]
113+
pub struct CloseableCursor(RwLock<CloseableCursorInner>);
114+
112115
impl CloseableCursor {
113-
fn len(&self) -> usize {
114-
self.data.read().unwrap().len()
116+
pub fn len(&self) -> usize {
117+
self.0.read().unwrap().data.len()
118+
}
119+
120+
pub fn cursor(&self) -> usize {
121+
self.0.read().unwrap().cursor
115122
}
116123

117-
fn cursor(&self) -> usize {
118-
*self.cursor.read().unwrap()
124+
pub fn is_empty(&self) -> bool {
125+
self.len() == 0
119126
}
120127

121-
fn current(&self) -> bool {
122-
self.len() == self.cursor()
128+
pub fn current(&self) -> bool {
129+
let inner = self.0.read().unwrap();
130+
inner.data.len() == inner.cursor
123131
}
124132

125-
fn close(&self) {
126-
*self.closed.write().unwrap() = true;
133+
pub fn close(&self) {
134+
let mut inner = self.0.write().unwrap();
135+
inner.closed = true;
136+
if let Some(waker) = inner.waker.take() {
137+
waker.wake();
138+
}
127139
}
128140
}
129141

130142
impl Display for CloseableCursor {
131143
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132-
let data = &*self.data.read().unwrap();
133-
let s = std::str::from_utf8(data).unwrap_or("not utf8");
144+
let inner = self.0.read().unwrap();
145+
let s = std::str::from_utf8(&inner.data).unwrap_or("not utf8");
134146
write!(f, "{}", s)
135147
}
136148
}
@@ -163,13 +175,14 @@ impl TestIO {
163175

164176
impl Debug for CloseableCursor {
165177
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178+
let inner = self.0.read().unwrap();
166179
f.debug_struct("CloseableCursor")
167180
.field(
168181
"data",
169-
&std::str::from_utf8(&self.data.read().unwrap()).unwrap_or("not utf8"),
182+
&std::str::from_utf8(&inner.data).unwrap_or("not utf8"),
170183
)
171-
.field("closed", &*self.closed.read().unwrap())
172-
.field("cursor", &*self.cursor.read().unwrap())
184+
.field("closed", &inner.closed)
185+
.field("cursor", &inner.cursor)
173186
.finish()
174187
}
175188
}
@@ -180,18 +193,17 @@ impl Read for &CloseableCursor {
180193
cx: &mut Context<'_>,
181194
buf: &mut [u8],
182195
) -> Poll<io::Result<usize>> {
183-
let len = self.len();
184-
let cursor = self.cursor();
185-
if cursor < len {
186-
let data = &*self.data.read().unwrap();
187-
let bytes_to_copy = buf.len().min(len - cursor);
188-
buf[..bytes_to_copy].copy_from_slice(&data[cursor..cursor + bytes_to_copy]);
189-
*self.cursor.write().unwrap() += bytes_to_copy;
196+
let mut inner = self.0.write().unwrap();
197+
if inner.cursor < inner.data.len() {
198+
let bytes_to_copy = buf.len().min(inner.data.len() - inner.cursor);
199+
buf[..bytes_to_copy]
200+
.copy_from_slice(&inner.data[inner.cursor..inner.cursor + bytes_to_copy]);
201+
inner.cursor += bytes_to_copy;
190202
Poll::Ready(Ok(bytes_to_copy))
191-
} else if *self.closed.read().unwrap() {
203+
} else if inner.closed {
192204
Poll::Ready(Ok(0))
193205
} else {
194-
*self.waker.write().unwrap() = Some(cx.waker().clone());
206+
inner.waker = Some(cx.waker().clone());
195207
Poll::Pending
196208
}
197209
}
@@ -203,11 +215,12 @@ impl Write for &CloseableCursor {
203215
_cx: &mut Context<'_>,
204216
buf: &[u8],
205217
) -> Poll<io::Result<usize>> {
206-
if *self.closed.read().unwrap() {
218+
let mut inner = self.0.write().unwrap();
219+
if inner.closed {
207220
Poll::Ready(Ok(0))
208221
} else {
209-
self.data.write().unwrap().extend_from_slice(buf);
210-
if let Some(waker) = self.waker.write().unwrap().take() {
222+
inner.data.extend_from_slice(buf);
223+
if let Some(waker) = inner.waker.take() {
211224
waker.wake();
212225
}
213226
Poll::Ready(Ok(buf.len()))
@@ -219,10 +232,7 @@ impl Write for &CloseableCursor {
219232
}
220233

221234
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
222-
if let Some(waker) = self.waker.write().unwrap().take() {
223-
waker.wake();
224-
}
225-
*self.closed.write().unwrap() = true;
235+
self.close();
226236
Poll::Ready(Ok(()))
227237
}
228238
}

0 commit comments

Comments
 (0)