Skip to content

Commit dd003d8

Browse files
authored
feat: Define text-model envelope formats (#2188)
Adds experimental model s-expressions to the supported envelope formats
1 parent ed2fde5 commit dd003d8

File tree

3 files changed

+143
-43
lines changed

3 files changed

+143
-43
lines changed

hugr-core/src/envelope.rs

Lines changed: 103 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,13 @@ use crate::{extension::ExtensionRegistry, package::Package};
5252
use header::EnvelopeHeader;
5353
use std::io::BufRead;
5454
use std::io::Write;
55+
use std::str::FromStr;
5556

5657
#[allow(unused_imports)]
5758
use itertools::Itertools as _;
5859

5960
use crate::import::ImportError;
61+
use crate::{import::import_package, Extension};
6062

6163
/// Read a HUGR envelope from a reader.
6264
///
@@ -219,6 +221,16 @@ pub enum EnvelopeError {
219221
/// The source error.
220222
source: hugr_model::v0::binary::WriteError,
221223
},
224+
/// Error parsing a HUGR model text payload.
225+
ModelTextRead {
226+
/// The source error.
227+
source: hugr_model::v0::ast::ParseError,
228+
},
229+
/// Error resolving the parsed HUGR model text payload.
230+
ModelTextResolve {
231+
/// The source error.
232+
source: hugr_model::v0::ast::ResolveError,
233+
},
222234
}
223235

224236
/// Internal implementation of [`read_envelope`] to call with/without the zstd decompression wrapper.
@@ -233,6 +245,9 @@ fn read_impl(
233245
EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
234246
decode_model(payload, registry, header.format)
235247
}
248+
EnvelopeFormat::ModelText | EnvelopeFormat::ModelTextWithExtensions => {
249+
decode_model_ast(payload, registry, header.format)
250+
}
236251
}
237252
}
238253

@@ -248,7 +263,6 @@ fn decode_model(
248263
extension_registry: &ExtensionRegistry,
249264
format: EnvelopeFormat,
250265
) -> Result<Package, EnvelopeError> {
251-
use crate::{import::import_package, Extension};
252266
use hugr_model::v0::bumpalo::Bump;
253267

254268
if format.model_version() != Some(0) {
@@ -262,7 +276,7 @@ fn decode_model(
262276
let model_package = hugr_model::v0::binary::read_from_reader(&mut stream, &bump)?;
263277

264278
let mut extension_registry = extension_registry.clone();
265-
if format.append_extensions() {
279+
if format == EnvelopeFormat::ModelWithExtensions {
266280
let extra_extensions: Vec<Extension> =
267281
serde_json::from_reader::<_, Vec<Extension>>(stream)?;
268282
for ext in extra_extensions {
@@ -273,6 +287,54 @@ fn decode_model(
273287
Ok(import_package(&model_package, &extension_registry)?)
274288
}
275289

290+
/// Read a HUGR model text payload from a reader.
291+
///
292+
/// Parameters:
293+
/// - `stream`: The reader to read the text payload from.
294+
/// - `extension_registry`: An extension registry with additional extensions to use when
295+
/// decoding the HUGR, if they are not already included in the package.
296+
/// - `format`: The format of the payload.
297+
fn decode_model_ast(
298+
mut stream: impl BufRead,
299+
extension_registry: &ExtensionRegistry,
300+
format: EnvelopeFormat,
301+
) -> Result<Package, EnvelopeError> {
302+
use crate::import::import_package;
303+
use hugr_model::v0::bumpalo::Bump;
304+
305+
if format.model_version() != Some(0) {
306+
return Err(EnvelopeError::FormatUnsupported {
307+
format,
308+
feature: None,
309+
});
310+
}
311+
312+
let mut extension_registry = extension_registry.clone();
313+
if format == EnvelopeFormat::ModelTextWithExtensions {
314+
let deserializer = serde_json::Deserializer::from_reader(&mut stream);
315+
// Deserialize the first json value (the list of extensions), leaving the rest of the reader unconsumed.
316+
let extra_extensions = deserializer
317+
.into_iter::<Vec<Extension>>()
318+
.next()
319+
.unwrap_or(Ok(vec![]))?;
320+
for ext in extra_extensions {
321+
extension_registry.register_updated(ext);
322+
}
323+
}
324+
325+
// Read the package into a string, then parse it.
326+
//
327+
// `read_to_string` consumes the rest of the stream, so extensions cannot be appended after the package.
328+
let mut buffer = String::new();
329+
stream.read_to_string(&mut buffer)?;
330+
let ast_package = hugr_model::v0::ast::Package::from_str(&buffer)?;
331+
332+
let bump = Bump::default();
333+
let model_package = ast_package.resolve(&bump)?;
334+
335+
Ok(import_package(&model_package, &extension_registry)?)
336+
}
337+
276338
/// Internal implementation of [`write_envelope`] to call with/without the zstd compression wrapper.
277339
fn write_impl<'h>(
278340
writer: impl Write,
@@ -283,7 +345,10 @@ fn write_impl<'h>(
283345
match config.format {
284346
#[allow(deprecated)]
285347
EnvelopeFormat::PackageJson => package_json::to_json_writer(hugrs, extensions, writer)?,
286-
EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
348+
EnvelopeFormat::Model
349+
| EnvelopeFormat::ModelWithExtensions
350+
| EnvelopeFormat::ModelText
351+
| EnvelopeFormat::ModelTextWithExtensions => {
287352
encode_model(writer, hugrs, extensions, config.format)?
288353
}
289354
}
@@ -307,11 +372,27 @@ fn encode_model<'h>(
307372
});
308373
}
309374

375+
// Prepend extensions for text model.
376+
if format == EnvelopeFormat::ModelTextWithExtensions {
377+
serde_json::to_writer(&mut writer, &extensions.iter().collect_vec())?;
378+
}
379+
310380
let bump = Bump::default();
311381
let model_package = export_package(hugrs, extensions, &bump);
312-
write_to_writer(&model_package, &mut writer)?;
313382

314-
if format.append_extensions() {
383+
match format {
384+
EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
385+
write_to_writer(&model_package, &mut writer)?;
386+
}
387+
EnvelopeFormat::ModelText | EnvelopeFormat::ModelTextWithExtensions => {
388+
let model_package = model_package.as_ast().unwrap();
389+
writeln!(writer, "{model_package}")?;
390+
}
391+
_ => unreachable!(),
392+
}
393+
394+
// Append extensions for binary model.
395+
if format == EnvelopeFormat::ModelWithExtensions {
315396
serde_json::to_writer(writer, &extensions.iter().collect_vec())?;
316397
}
317398

@@ -418,34 +499,24 @@ pub(crate) mod test {
418499
}
419500

420501
#[rstest]
421-
//#[case::empty(Package::default())] // Not currently supported
422-
#[case::simple(simple_package())]
423-
//#[case::multi(multi_module_package())] // Not currently supported
424-
fn module_exts_roundtrip(#[case] package: Package) {
502+
// Empty packages
503+
#[case::empty_model(Package::default(), EnvelopeFormat::Model)]
504+
#[case::empty_model_exts(Package::default(), EnvelopeFormat::ModelWithExtensions)]
505+
#[case::empty_text(Package::default(), EnvelopeFormat::ModelText)]
506+
#[case::empty_text_exts(Package::default(), EnvelopeFormat::ModelTextWithExtensions)]
507+
// Single hugrs
508+
#[case::simple_bin(simple_package(), EnvelopeFormat::Model)]
509+
#[case::simple_bin_exts(simple_package(), EnvelopeFormat::ModelWithExtensions)]
510+
#[case::simple_text(simple_package(), EnvelopeFormat::ModelText)]
511+
#[case::simple_text_exts(simple_package(), EnvelopeFormat::ModelTextWithExtensions)]
512+
// Multiple hugrs
513+
#[case::multi_bin(multi_module_package(), EnvelopeFormat::Model)]
514+
#[case::multi_bin_exts(multi_module_package(), EnvelopeFormat::ModelWithExtensions)]
515+
#[case::multi_text(multi_module_package(), EnvelopeFormat::ModelText)]
516+
#[case::multi_text_exts(multi_module_package(), EnvelopeFormat::ModelTextWithExtensions)]
517+
fn model_roundtrip(#[case] package: Package, #[case] format: EnvelopeFormat) {
425518
let mut buffer = Vec::new();
426-
let config = EnvelopeConfig {
427-
format: EnvelopeFormat::ModelWithExtensions,
428-
zstd: None,
429-
};
430-
package.store(&mut buffer, config).unwrap();
431-
let (decoded_config, new_package) =
432-
read_envelope(BufReader::new(buffer.as_slice()), &PRELUDE_REGISTRY).unwrap();
433-
434-
assert_eq!(config.format, decoded_config.format);
435-
assert_eq!(config.zstd.is_some(), decoded_config.zstd.is_some());
436-
assert_eq!(package, new_package);
437-
}
438-
439-
#[rstest]
440-
//#[case::empty(Package::default())] // Not currently supported
441-
#[case::simple(simple_package())]
442-
//#[case::multi(multi_module_package())] // Not currently supported
443-
fn module_roundtrip(#[case] package: Package) {
444-
let mut buffer = Vec::new();
445-
let config = EnvelopeConfig {
446-
format: EnvelopeFormat::Model,
447-
zstd: None,
448-
};
519+
let config = EnvelopeConfig { format, zstd: None };
449520
package.store(&mut buffer, config).unwrap();
450521

451522
let (decoded_config, new_package) =

hugr-core/src/envelope/header.rs

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,30 @@ pub(super) struct EnvelopeHeader {
3333
pub enum EnvelopeFormat {
3434
/// `hugr-model` v0 binary capnproto message.
3535
Model = 1,
36-
/// `hugr-model` v0 binary capnproto message followed by a json-encoded [crate::extension::ExtensionRegistry].
37-
//
38-
// This is a temporary format required until the model adds support for extensions.
36+
/// `hugr-model` v0 binary capnproto message followed by a json-encoded
37+
/// [crate::extension::ExtensionRegistry].
38+
///
39+
/// This is a temporary format required until the model adds support for
40+
/// extensions.
3941
ModelWithExtensions = 2,
42+
/// Human-readable S-expression encoding using [`hugr_model::v0`].
43+
///
44+
/// Uses a printable ascii value as the discriminant so the envelope can be
45+
/// read as text.
46+
///
47+
/// :caution: This format does not yet support extension encoding, so it should
48+
/// be avoided.
49+
//
50+
// TODO: Update comment once extension encoding is supported.
51+
ModelText = 40, // '(' in ascii
52+
/// Json-encoded list of extensions followed by a human-readable S-expression encoding using [`hugr_model::v0`].
53+
///
54+
/// Uses a printable ascii value as the discriminant so the envelope can be
55+
/// read as text.
56+
///
57+
/// This is a temporary format required until the model adds support for
58+
/// extensions.
59+
ModelTextWithExtensions = 41, // ')' in ascii
4060
/// Json-encoded [crate::package::Package]
4161
///
4262
/// Uses a printable ascii value as the discriminant so the envelope can be
@@ -50,15 +70,13 @@ pub enum EnvelopeFormat {
5070
static_assertions::assert_eq_size!(EnvelopeFormat, u8);
5171

5272
impl EnvelopeFormat {
53-
/// Returns whether to encode the extensions as json after the hugr payload.
54-
pub fn append_extensions(self) -> bool {
55-
matches!(self, Self::ModelWithExtensions)
56-
}
57-
5873
/// If the format is a model format, returns its version number.
5974
pub fn model_version(self) -> Option<u32> {
6075
match self {
61-
Self::Model | Self::ModelWithExtensions => Some(0),
76+
Self::Model
77+
| Self::ModelWithExtensions
78+
| Self::ModelText
79+
| Self::ModelTextWithExtensions => Some(0),
6280
_ => None,
6381
}
6482
}
@@ -67,7 +85,10 @@ impl EnvelopeFormat {
6785
///
6886
/// If true, the encoded envelope can be read as text.
6987
pub fn ascii_printable(self) -> bool {
70-
matches!(self, Self::PackageJson)
88+
matches!(
89+
self,
90+
Self::PackageJson | Self::ModelText | Self::ModelTextWithExtensions
91+
)
7192
}
7293
}
7394

@@ -117,7 +138,7 @@ impl EnvelopeConfig {
117138
pub const fn binary() -> Self {
118139
Self {
119140
format: EnvelopeFormat::ModelWithExtensions,
120-
zstd: None,
141+
zstd: Some(ZstdConfig::default_level()),
121142
}
122143
}
123144
}
@@ -137,6 +158,11 @@ pub struct ZstdConfig {
137158
}
138159

139160
impl ZstdConfig {
161+
/// Create a new zstd configuration with default compression level.
162+
pub const fn default_level() -> Self {
163+
Self { level: None }
164+
}
165+
140166
/// Returns the zstd compression level to pass to the zstd library.
141167
///
142168
/// Uses [zstd::DEFAULT_COMPRESSION_LEVEL] if the level is not set.
@@ -224,6 +250,8 @@ mod tests {
224250
#[rstest]
225251
#[case(EnvelopeFormat::Model)]
226252
#[case(EnvelopeFormat::ModelWithExtensions)]
253+
#[case(EnvelopeFormat::ModelText)]
254+
#[case(EnvelopeFormat::ModelTextWithExtensions)]
227255
#[case(EnvelopeFormat::PackageJson)]
228256
fn header_round_trip(#[case] format: EnvelopeFormat) {
229257
// With zstd compression

hugr-model/src/v0/ast/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ mod python;
3434
mod resolve;
3535
mod view;
3636

37+
pub use parse::ParseError;
3738
pub use resolve::ResolveError;
3839

3940
/// A package in the hugr AST.

0 commit comments

Comments
 (0)