Skip to content

Commit 511e7a7

Browse files
committed
Add support for skippable frames
1 parent 1bed85b commit 511e7a7

File tree

6 files changed

+453
-1
lines changed

6 files changed

+453
-1
lines changed

src/stream/raw.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ pub struct Status {
127127
pub bytes_written: usize,
128128
}
129129

130+
/// The magic variant encoded in a skippable frame.
///
/// This is a thin wrapper around the raw `u8` variant value so that it cannot be confused with
/// other integers (sizes, offsets) returned alongside it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MagicVariant(pub u8);
132+
130133
/// An in-memory decoder for streams of data.
131134
pub struct Decoder<'a> {
132135
context: zstd_safe::DCtx<'a>,
@@ -171,6 +174,29 @@ impl<'a> Decoder<'a> {
171174
.map_err(map_error_code)?;
172175
Ok(())
173176
}
177+
178+
/// Read a skippable frame.
///
/// Hands `input` to `DCtx::read_skippable_frame`, using `dest` as the output buffer. On
/// success, returns the value reported by zstd (the number of bytes written) together with
/// the frame's magic variant. zstd error codes are converted with `map_error_code`.
// TODO: remove self?
#[cfg(feature = "experimental")]
pub fn read_skippable_frame(&self, dest: &mut Vec<u8>, input: &[u8]) -> io::Result<(usize, MagicVariant)> {
    let mut variant = 0;
    let mut output = OutBuffer::around(dest);
    match zstd_safe::DCtx::read_skippable_frame(&mut output, &mut variant, input) {
        Ok(written) => Ok((written, MagicVariant(variant as u8))),
        Err(code) => Err(map_error_code(code)),
    }
}
189+
190+
/// Check if a frame is skippable.
///
/// Forwards `input` to `DCtx::is_skippable_frame` and reports any non-zero answer as `true`.
/// zstd error codes are converted with `map_error_code`.
// TODO: remove self?
#[cfg(feature = "experimental")]
pub fn is_skippable_frame(&self, input: &[u8]) -> io::Result<bool> {
    match zstd_safe::DCtx::is_skippable_frame(input) {
        Ok(flag) => Ok(flag != 0),
        Err(code) => Err(map_error_code(code)),
    }
}
174200
}
175201

176202
impl Operation for Decoder<'_> {

src/stream/read/mod.rs

Lines changed: 218 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,29 @@
11
//! Implement pull-based [`Read`] trait for both compressing and decompressing.
2-
use std::io::{self, BufRead, BufReader, Read};
2+
#[cfg(feature = "experimental")]
3+
use std::cmp::min;
4+
#[cfg(feature = "experimental")]
5+
use std::io::SeekFrom;
6+
use std::io::{self, BufRead, BufReader, Read, Seek};
7+
#[cfg(feature = "experimental")]
8+
use std::mem::size_of;
39

410
use crate::dict::{DecoderDictionary, EncoderDictionary};
511
use crate::stream::{raw, zio};
612
use zstd_safe;
713

14+
#[cfg(feature = "experimental")]
15+
use zstd_safe::{frame_header_size, MAGIC_SKIPPABLE_MASK, MAGIC_SKIPPABLE_START, SKIPPABLEHEADERSIZE};
16+
#[cfg(feature = "experimental")]
17+
use super::raw::MagicVariant;
18+
819
#[cfg(test)]
920
mod tests;
1021

22+
#[cfg(feature = "experimental")]
23+
// Size in bytes of a 24-bit integer, i.e. the length of the buffer read for a block header.
const U24_SIZE: usize = size_of::<[u8; 3]>();
24+
#[cfg(feature = "experimental")]
25+
// Size in bytes of the byte array accepted by `u32::from_le_bytes`.
const U32_SIZE: usize = size_of::<[u8; 4]>();
26+
1127
/// A decoder that decompresses input data from another `Read`.
1228
///
1329
/// This allows to read a stream of compressed data
@@ -45,6 +61,207 @@ impl<R: BufRead> Decoder<'static, R> {
4561
Ok(Decoder { reader })
4662
}
4763
}
64+
65+
/// Read and discard `bytes_count` bytes in the reader.
66+
#[cfg(feature = "experimental")]
67+
fn consume<R: Read + ?Sized>(this: &mut R, mut bytes_count: usize) -> io::Result<()> {
    // Small scratch buffer; the bytes read into it are thrown away.
    let mut scratch = [0; 100];
    loop {
        if bytes_count == 0 {
            return Ok(());
        }
        // Never ask for more than what is left to discard.
        let chunk = min(scratch.len(), bytes_count);
        match this.read(&mut scratch[..chunk]) {
            // End of stream before `bytes_count` bytes could be discarded.
            Ok(0) => {
                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "failed to fill whole buffer"));
            }
            Ok(discarded) => bytes_count -= discarded,
            // Interrupted reads are simply retried.
            Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        }
    }
}
84+
85+
/// Like Read::read_exact(), but seek back to the starting position of the reader in case of an
86+
/// error.
87+
#[cfg(feature = "experimental")]
88+
fn read_exact_or_seek_back<R: Read + Seek + ?Sized>(this: &mut R, mut buf: &mut [u8]) -> io::Result<()> {
    // Number of bytes consumed so far; kept as i64 because it is negated for the rewind.
    let mut bytes_read: i64 = 0;

    // The loop either fills `buf` (early return) or yields the error that stopped it.
    let failure = loop {
        if buf.is_empty() {
            return Ok(());
        }
        match this.read(buf) {
            Ok(0) => {
                break io::Error::new(io::ErrorKind::UnexpectedEof, "failed to fill whole buffer");
            }
            Ok(n) => {
                bytes_read += n as i64;
                let rest = buf;
                buf = &mut rest[n..];
            }
            // Interrupted reads are simply retried.
            Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => break e,
        }
    };

    // Undo the partial read. A failing rewind is reported to the caller instead of panicking:
    // this is an I/O condition, not a broken invariant.
    if bytes_read > 0 {
        if let Err(seek_error) = this.seek(SeekFrom::Current(-bytes_read)) {
            return Err(io::Error::new(
                seek_error.kind(),
                format!("failed to seek back after read error ({}): {}", failure, seek_error),
            ));
        }
    }
    Err(failure)
}
116+
117+
impl<'a, R: Read + Seek> Decoder<'a, BufReader<R>> {
118+
/// Returns the total size (header plus content) of the skippable frame at the current
/// position, leaving the reader where it started.
///
/// The magic number and the content size are read in a single call, so a short read cannot
/// leave the reader stranded between the two fields.
#[cfg(feature = "experimental")]
fn read_skippable_frame_size(&mut self) -> io::Result<usize> {
    // Magic number (u32) followed by the content size (u32).
    const HEADER_SIZE: usize = U32_SIZE * 2;

    let mut header = [0u8; HEADER_SIZE];
    read_exact_or_seek_back(self.reader.reader_mut(), &mut header)?;
    self.seek_back(HEADER_SIZE);

    // The content size is the little-endian u32 that follows the magic number.
    let content_size = u32::from_le_bytes([
        header[U32_SIZE],
        header[U32_SIZE + 1],
        header[U32_SIZE + 2],
        header[U32_SIZE + 3],
    ]) as usize;

    // Guard against wrap-around where `usize` is only 32 bits wide.
    content_size
        .checked_add(SKIPPABLEHEADERSIZE as usize)
        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Skippable frame size overflows usize"))
}
132+
133+
/// Moves the underlying reader `bytes_count` bytes backwards from its current position.
///
/// # Panics
///
/// Panics if the underlying reader fails to seek.
#[cfg(feature = "experimental")]
fn seek_back(&mut self, bytes_count: usize) {
    let offset = SeekFrom::Current(-(bytes_count as i64));
    match self.reader.reader_mut().seek(offset) {
        Ok(_) => {}
        Err(error) => panic!("Error while seeking back to the start: {}", error),
    }
}
139+
140+
/// Attempt to read a skippable frame and write its content to `dest`.
///
/// On success, returns the content size in bytes together with the frame's magic variant.
/// `dest` must already have a capacity of at least the content size; it is resized to that
/// size before being filled.
///
/// If it cannot read a skippable frame, the reader will be back to its starting position.
///
/// # Errors
///
/// Returns an error if the header or the content cannot be read in full, if the data at the
/// current position is not a skippable frame, or if `dest` does not have enough capacity.
///
/// # Panics
///
/// Panics if seeking back over the frame header fails.
#[cfg(feature = "experimental")]
pub fn read_skippable_frame(&mut self, dest: &mut Vec<u8>) -> io::Result<(usize, MagicVariant)> {
    // Magic number (u32) followed by the content size (u32).
    const HEADER_SIZE: usize = U32_SIZE * 2;

    // Read both header fields at once: if this fails, the helper has already restored the
    // position, so there is nothing left to undo.
    let mut header = [0u8; HEADER_SIZE];
    read_exact_or_seek_back(self.reader.reader_mut(), &mut header)?;

    let magic_number = u32::from_le_bytes([header[0], header[1], header[2], header[3]]);
    let content_size = u32::from_le_bytes([
        header[U32_SIZE],
        header[U32_SIZE + 1],
        header[U32_SIZE + 2],
        header[U32_SIZE + 3],
    ]) as usize;

    // FIXME: I feel like we should do that check right after reading the magic number, but
    // ZSTD does it after reading the content size.
    let validation = match self.reader.operation().is_skippable_frame(&header[..U32_SIZE]) {
        Err(err) => Err(err),
        Ok(false) => Err(io::Error::new(io::ErrorKind::Other, "Unsupported frame parameter")),
        Ok(true) if content_size > dest.capacity() => {
            Err(io::Error::new(io::ErrorKind::Other, "Destination buffer is too small"))
        }
        Ok(true) => Ok(()),
    };
    if let Err(err) = validation {
        // Every rejection path, including an error from zstd itself, rewinds the header.
        self.seek_back(HEADER_SIZE);
        return Err(err);
    }

    if content_size > 0 {
        dest.resize(content_size, 0);
        if let Err(err) = read_exact_or_seek_back(self.reader.reader_mut(), &mut dest[..]) {
            // The helper only rewinds the content bytes it consumed; the header is ours.
            self.seek_back(HEADER_SIZE);
            return Err(err);
        }
    }

    let magic_variant = magic_number - MAGIC_SKIPPABLE_START;

    Ok((content_size, MagicVariant(magic_variant as u8)))
}
191+
192+
/// Peeks at the 3-byte block header at the current position, without consuming it.
///
/// Returns the number of bytes the block content occupies in the stream and whether this is
/// the last block of the frame.
///
/// # Errors
///
/// Returns an error if the header cannot be read in full (the position is restored in that
/// case) or if the header uses the reserved block type.
#[cfg(feature = "experimental")]
fn get_block_size(&mut self) -> io::Result<(usize, bool)> {
    // Block types as laid out in the zstd format (bits 1-2 of the block header).
    const BLOCK_TYPE_RLE: u32 = 1;
    const BLOCK_TYPE_RESERVED: u32 = 3;

    let mut buffer = [0u8; U24_SIZE];
    read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?;
    self.seek_back(U24_SIZE);

    // The header is a little-endian 24-bit integer: widen it to a u32.
    let block_header = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], 0]);
    let last_block = block_header & 1 != 0;
    let block_type = (block_header >> 1) & 0b11;
    let block_size = (block_header >> 3) as usize;

    let compressed_size = match block_type {
        BLOCK_TYPE_RESERVED => {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "Reserved block type"));
        }
        // An RLE block stores a single byte; its size field is the regenerated size.
        BLOCK_TYPE_RLE => 1,
        _ => block_size,
    };

    Ok((compressed_size, last_block))
}
203+
204+
/// Computes the compressed size of the frame starting at the current position.
///
/// The reader is moved back to where it started before returning, whether the size could be
/// computed or not.
///
/// # Errors
///
/// Returns an error if the frame header or any block header cannot be read or parsed, if the
/// stream ends in the middle of the frame, or if the starting position cannot be queried or
/// restored.
#[cfg(feature = "experimental")]
fn find_frame_compressed_size(&mut self) -> io::Result<usize> {
    const ZSTD_BLOCK_HEADER_SIZE: usize = 3;
    const ZSTD_CHECKSUM_SIZE: usize = 4;

    // Remember where the frame starts so that every exit path can rewind to it.
    let start = self.reader.reader_mut().seek(SeekFrom::Current(0))?;

    let measured = (|| -> io::Result<usize> {
        // TODO: should we support legacy format?
        let mut magic_buffer = [0u8; U32_SIZE];
        read_exact_or_seek_back(self.reader.reader_mut(), &mut magic_buffer)?;
        self.seek_back(U32_SIZE);

        let magic_number = u32::from_le_bytes(magic_buffer);
        if (magic_number & MAGIC_SKIPPABLE_MASK) == MAGIC_SKIPPABLE_START {
            return self.read_skippable_frame_size();
        }

        let (header_size, checksum_flag) = self.frame_header_size()?;
        consume(self.reader.reader_mut(), header_size)?;
        let mut frame_size = header_size;

        // Walk the blocks until the one flagged as last.
        loop {
            let (compressed_size, last_block) = self.get_block_size()?;
            let block_size = ZSTD_BLOCK_HEADER_SIZE + compressed_size;
            consume(self.reader.reader_mut(), block_size)?;
            frame_size += block_size;
            if last_block {
                break;
            }
        }

        // The optional checksum trails the last block; it is counted but not read here.
        if checksum_flag {
            frame_size += ZSTD_CHECKSUM_SIZE;
        }

        Ok(frame_size)
    })();

    // Rewind unconditionally; a measuring error takes precedence over a rewind error.
    let restored = self.reader.reader_mut().seek(SeekFrom::Start(start));
    let frame_size = measured?;
    restored?;

    Ok(frame_size)
}
241+
242+
/// Peeks at the start of a frame and returns its header size together with the content
/// checksum flag, leaving the reader where it started.
///
/// # Errors
///
/// Returns an error if the first bytes of the frame cannot be read in full, or if zstd cannot
/// compute a header size from them.
#[cfg(feature = "experimental")]
fn frame_header_size(&mut self) -> io::Result<(usize, bool)> {
    use crate::map_error_code;
    const MAX_FRAME_HEADER_SIZE_PREFIX: usize = 5;
    // The checksum flag is bit 2 of the last prefix byte.
    const CONTENT_CHECKSUM_BIT: u8 = 1 << 2;

    let mut buffer = [0u8; MAX_FRAME_HEADER_SIZE_PREFIX];
    read_exact_or_seek_back(self.reader.reader_mut(), &mut buffer)?;
    // Rewind before parsing, so that a header rejected by zstd does not leave the reader
    // advanced past the bytes that were only meant to be peeked at.
    self.seek_back(MAX_FRAME_HEADER_SIZE_PREFIX);

    let size = frame_header_size(&buffer).map_err(map_error_code)?;
    let descriptor = buffer[MAX_FRAME_HEADER_SIZE_PREFIX - 1];

    Ok((size, descriptor & CONTENT_CHECKSUM_BIT != 0))
}
255+
256+
/// Skip over a frame, without decompressing it.
///
/// The frame's compressed size is computed first, then that many bytes are read from the
/// underlying reader and discarded.
#[cfg(feature = "experimental")]
pub fn skip_frame(&mut self) -> io::Result<()> {
    self.find_frame_compressed_size()
        .and_then(|frame_size| consume(self.reader.reader_mut(), frame_size))
}
263+
}
264+
48265
impl<'a, R: BufRead> Decoder<'a, R> {
49266
/// Sets this `Decoder` to stop after the first frame.
50267
///

src/stream/write/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,12 @@ impl<W: Write> Encoder<'static, W> {
190190
let writer = zio::Writer::new(writer, encoder);
191191
Ok(Encoder { writer })
192192
}
193+
194+
/// Write a skippable frame.
///
/// `buf` and `magic_variant` are forwarded unchanged to the underlying writer's
/// `write_skippable_frame`.
#[cfg(feature = "experimental")]
pub fn write_skippable_frame(&mut self, buf: &[u8], magic_variant: u32) -> io::Result<()> {
    self.writer.write_skippable_frame(buf, magic_variant)
}
193199
}
194200

195201
impl<'a, W: Write> Encoder<'a, W> {
@@ -259,6 +265,12 @@ impl<'a, W: Write> Encoder<'a, W> {
259265
self.try_finish().map_err(|(_, err)| err)
260266
}
261267

268+
/// Useful to get back the writer after calling write_skippable_frame. You don't want to call
269+
/// finish because this will create yet another frame.
270+
pub fn into_inner(self) -> W {
271+
self.writer.into_inner().0
272+
}
273+
262274
/// **Required**: Attempts to finish the stream.
263275
///
264276
/// You *need* to finish the stream when you're done writing, either with

src/stream/zio/reader.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ impl<R, D> Reader<R, D> {
4747
self.single_frame = true;
4848
}
4949

50+
/// Returns a reference to the underlying operation.
51+
pub fn operation(&self) -> &D {
52+
&self.operation
53+
}
54+
5055
/// Returns a mutable reference to the underlying operation.
5156
pub fn operation_mut(&mut self) -> &mut D {
5257
&mut self.operation

0 commit comments

Comments
 (0)