Skip to content

Commit f75bcf8

Browse files
committed
Do not seek back on error
1 parent e3e0431 commit f75bcf8

File tree

3 files changed, with 77 additions and 75 deletions.

src/stream/read/mod.rs

Lines changed: 28 additions & 74 deletions
Original file line number | Diff line number | Diff line change
@@ -82,47 +82,15 @@ fn consume<R: Read + ?Sized>(this: &mut R, mut bytes_count: usize) -> io::Result
8282
}
8383
}
8484

85-
/// Like Read::read_exact(), but seek back to the starting position of the reader in case of an
86-
/// error.
87-
#[cfg(feature = "experimental")]
88-
fn read_exact_or_seek_back<R: Read + Seek + ?Sized>(this: &mut R, mut buf: &mut [u8]) -> io::Result<()> {
89-
let mut bytes_read = 0;
90-
while !buf.is_empty() {
91-
match this.read(buf) {
92-
Ok(0) => break,
93-
Ok(n) => {
94-
bytes_read += n as i64;
95-
let tmp = buf;
96-
buf = &mut tmp[n..];
97-
}
98-
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
99-
Err(e) => {
100-
if let Err(error) = this.seek(SeekFrom::Current(-bytes_read)) {
101-
panic!("Error while seeking back to the start: {}", error);
102-
}
103-
return Err(e)
104-
},
105-
}
106-
}
107-
if !buf.is_empty() {
108-
if let Err(error) = this.seek(SeekFrom::Current(-bytes_read)) {
109-
panic!("Error while seeking back to the start: {}", error);
110-
}
111-
Err(io::Error::new(io::ErrorKind::UnexpectedEof, "failed to fill whole buffer"))
112-
} else {
113-
Ok(())
114-
}
115-
}
116-
11785
#[cfg(feature = "experimental")]
11886
impl<'a, R: Read + Seek> Decoder<'a, BufReader<R>> {
11987
fn read_skippable_frame_size(&mut self) -> io::Result<usize> {
12088
let mut magic_buffer = [0u8; U32_SIZE];
121-
read_exact_or_seek_back(self.reader.reader_mut(), &mut magic_buffer)?;
89+
self.reader.reader_mut().read_exact(&mut magic_buffer)?;
12290

12391
// Read skippable frame size.
12492
let mut buffer = [0u8; U32_SIZE];
125-
read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?;
93+
self.reader.reader_mut().read_exact(&mut buffer)?;
12694
let content_size = u32::from_le_bytes(buffer) as usize;
12795

12896
self.seek_back(U32_SIZE * 2);
@@ -139,47 +107,27 @@ impl<'a, R: Read + Seek> Decoder<'a, BufReader<R>> {
139107
/// Attempt to read a skippable frame and write its content to `dest`.
140108
/// If it cannot read a skippable frame, the reader will be back to its starting position.
141109
pub fn read_skippable_frame(&mut self, dest: &mut [u8]) -> io::Result<(usize, MagicVariant)> {
142-
let mut bytes_to_seek = 0;
143-
144-
let res = (|| {
145-
let mut magic_buffer = [0u8; U32_SIZE];
146-
read_exact_or_seek_back(self.reader.reader_mut(), &mut magic_buffer)?;
147-
let magic_number = u32::from_le_bytes(magic_buffer);
148-
149-
// Read skippable frame size.
150-
let mut buffer = [0u8; U32_SIZE];
151-
read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?;
152-
let content_size = u32::from_le_bytes(buffer) as usize;
153-
154-
let op = self.reader.operation();
155-
// FIXME: I feel like we should do that check right after reading the magic number, but
156-
// ZSTD does it after reading the content size.
157-
if !op.is_skippable_frame(&magic_buffer) {
158-
bytes_to_seek = U32_SIZE * 2;
159-
return Err(io::Error::new(io::ErrorKind::Other, "Unsupported frame parameter"));
160-
}
161-
if content_size > dest.len() {
162-
bytes_to_seek = U32_SIZE * 2;
163-
return Err(io::Error::new(io::ErrorKind::Other, "Destination buffer is too small"));
164-
}
110+
let magic_buffer = self.reader.peek_4bytes()?;
111+
let op = self.reader.operation();
112+
if !op.is_skippable_frame(&magic_buffer) {
113+
return Err(io::Error::new(io::ErrorKind::Other, "Unsupported frame parameter"));
114+
}
115+
self.reader.clear_peeked_data();
165116

166-
if content_size > 0 {
167-
read_exact_or_seek_back(self.reader.reader_mut(), &mut dest[..content_size])?;
168-
}
117+
let magic_number = u32::from_le_bytes(magic_buffer);
118+
119+
// Read skippable frame size.
120+
let mut buffer = [0u8; U32_SIZE];
121+
self.reader.reader_mut().read_exact(&mut buffer)?;
122+
let content_size = u32::from_le_bytes(buffer) as usize;
169123

170-
Ok((magic_number, content_size))
171-
})();
124+
if content_size > dest.len() {
125+
return Err(io::Error::new(io::ErrorKind::Other, "Destination buffer is too small"));
126+
}
172127

173-
let (magic_number, content_size) =
174-
match res {
175-
Ok(data) => data,
176-
Err(err) => {
177-
if bytes_to_seek != 0 {
178-
self.seek_back(bytes_to_seek);
179-
}
180-
return Err(err);
181-
},
182-
};
128+
if content_size > 0 {
129+
self.reader.reader_mut().read_exact(&mut dest[..content_size])?;
130+
}
183131

184132
let magic_variant = magic_number - MAGIC_SKIPPABLE_START;
185133

@@ -202,7 +150,13 @@ impl<'a, R: Read + Seek> Decoder<'a, BufReader<R>> {
202150

203151
// TODO: should we support legacy format?
204152
let mut magic_buffer = [0u8; U32_SIZE];
205-
self.reader.reader_mut().read_exact(&mut magic_buffer)?;
153+
if self.reader.peeking() {
154+
magic_buffer = self.reader.peeked_data();
155+
self.reader.clear_peeked_data();
156+
}
157+
else {
158+
self.reader.reader_mut().read_exact(&mut magic_buffer)?;
159+
}
206160
let magic_number = u32::from_le_bytes(magic_buffer);
207161
self.seek_back(U32_SIZE);
208162
if magic_number & MAGIC_SKIPPABLE_MASK == MAGIC_SKIPPABLE_START {
@@ -240,7 +194,7 @@ impl<'a, R: Read + Seek> Decoder<'a, BufReader<R>> {
240194
use crate::map_error_code;
241195
const MAX_FRAME_HEADER_SIZE_PREFIX: usize = 5;
242196
let mut buffer = [0u8; MAX_FRAME_HEADER_SIZE_PREFIX];
243-
read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?;
197+
self.reader.reader_mut().read_exact(&mut buffer)?;
244198
let size = frame_header_size(&buffer)
245199
.map_err(map_error_code)?;
246200
let byte = buffer[MAX_FRAME_HEADER_SIZE_PREFIX - 1];

src/stream/zio/reader.rs

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ pub struct Reader<R, D> {
1717

1818
single_frame: bool,
1919
finished_frame: bool,
20+
21+
peeking: bool,
22+
peeked_data: [u8; 4],
2023
}
2124

2225
enum State {
@@ -39,6 +42,8 @@ impl<R, D> Reader<R, D> {
3942
state: State::Reading,
4043
single_frame: false,
4144
finished_frame: false,
45+
peeking: false,
46+
peeked_data: [0; 4],
4247
}
4348
}
4449

@@ -81,7 +86,37 @@ impl<R, D> Reader<R, D> {
8186
{
8287
self.operation.flush(&mut OutBuffer::around(output))
8388
}
89+
90+
/// Read some data, but do not consume it.
91+
pub fn peek_4bytes(&mut self) -> io::Result<[u8; 4]>
92+
where
93+
R: BufRead,
94+
D: Operation,
95+
{
96+
if !self.peeking {
97+
self.reader.read_exact(&mut self.peeked_data)?;
98+
self.peeking = true;
99+
}
100+
101+
Ok(self.peeked_data)
102+
}
103+
104+
/// Clear the peeked data.
105+
pub fn clear_peeked_data(&mut self) {
106+
self.peeking = false;
107+
}
108+
109+
/// Check if there is currently any peeked data.
110+
pub fn peeking(&self) -> bool {
111+
self.peeking
112+
}
113+
114+
/// Get the peeked data.
115+
pub fn peeked_data(&self) -> [u8; 4] {
116+
self.peeked_data
117+
}
84118
}
119+
85120
// Read and retry on Interrupted errors.
86121
fn fill_buf<R>(reader: &mut R) -> io::Result<&[u8]>
87122
where
@@ -118,12 +153,17 @@ where
118153
loop {
119154
match self.state {
120155
State::Reading => {
156+
let is_peeking = self.peeking;
157+
121158
let (bytes_read, bytes_written) = {
122159
// Start with a fresh pool of un-processed data.
123160
// This is the only line that can return an interruption error.
124161
let input = if first {
125162
// eprintln!("First run, no input coming.");
126163
b""
164+
} else if self.peeking {
165+
self.clear_peeked_data();
166+
&self.peeked_data
127167
} else {
128168
fill_buf(&mut self.reader)?
129169
};
@@ -170,7 +210,9 @@ where
170210
(src.pos(), dst.pos())
171211
};
172212

173-
self.reader.consume(bytes_read);
213+
if !is_peeking {
214+
self.reader.consume(bytes_read);
215+
}
174216

175217
if bytes_written > 0 {
176218
return Ok(bytes_written);

src/stream/zio/writer.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,8 @@ mod tests {
358358
};
359359

360360
let mut target = vec![];
361+
assert!(decoder.read_skippable_frame(&mut frame).is_err());
362+
assert!(decoder.read_skippable_frame(&mut frame).is_err());
361363
io::copy(&mut decoder, &mut target).unwrap();
362364
assert_eq!("compressed frame 1", String::from_utf8(target).unwrap());
363365

@@ -371,6 +373,8 @@ mod tests {
371373
let (size, _) = decoder.read_skippable_frame(&mut frame).unwrap();
372374
assert_eq!("SKIP", String::from_utf8_lossy(&frame[..size]));
373375

376+
assert!(decoder.read_skippable_frame(&mut frame).is_err());
377+
assert!(decoder.read_skippable_frame(&mut frame).is_err());
374378
decoder.skip_frame().unwrap();
375379

376380
let inner = decoder.finish();
@@ -391,6 +395,8 @@ mod tests {
391395
};
392396

393397
let mut target = vec![];
398+
assert!(decoder.read_skippable_frame(&mut frame).is_err());
399+
assert!(decoder.read_skippable_frame(&mut frame).is_err());
394400
io::copy(&mut decoder, &mut target).unwrap();
395401
assert_eq!("compressed frame 3", String::from_utf8(target).unwrap());
396402
}

0 commit comments

Comments
 (0)