Skip to content

Commit 84cfb5f

Browse files
authored
test: check envelope roundtrips rather than json in HugrView::verify (#2186)
`verify` was calling an internal serialization test helpers to check that roundtrip encoding was correct. With the latest changes, this wasn't checking the actual user-facing envelope encoding. ~This PR currently fails some checks due to #2185~
1 parent cca66d9 commit 84cfb5f

File tree

11 files changed

+144
-101
lines changed

11 files changed

+144
-101
lines changed

.github/workflows/ci-rs.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ env:
1919
CI: true # insta snapshots behave differently on ci
2020
SCCACHE_GHA_ENABLED: "true"
2121
RUSTC_WRAPPER: "sccache"
22-
HUGR_TEST_SCHEMA: "1"
2322
# different strings for install action and feature name
2423
# adapted from https://github.com/TheDan64/inkwell/blob/master/.github/workflows/test.yml
2524
LLVM_VERSION: "14.0"

hugr-core/src/envelope.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,14 +319,55 @@ fn encode_model<'h>(
319319
}
320320

321321
#[cfg(test)]
322-
mod tests {
322+
pub(crate) mod test {
323323
use super::*;
324324
use cool_asserts::assert_matches;
325325
use rstest::rstest;
326+
use std::borrow::Cow;
326327
use std::io::BufReader;
327328

328329
use crate::builder::test::{multi_module_package, simple_package};
329330
use crate::extension::PRELUDE_REGISTRY;
331+
use crate::hugr::test::check_hugr_equality;
332+
use crate::std_extensions::STD_REG;
333+
use crate::HugrView;
334+
335+
/// Returns an `ExtensionRegistry` with the extensions from both
336+
/// sets. Avoids cloning if the first one already contains all
337+
/// extensions from the second one.
338+
fn join_extensions<'a>(
339+
extensions: &'a ExtensionRegistry,
340+
other: &ExtensionRegistry,
341+
) -> Cow<'a, ExtensionRegistry> {
342+
if other.iter().all(|e| extensions.contains(e.name())) {
343+
Cow::Borrowed(extensions)
344+
} else {
345+
let mut extensions = extensions.clone();
346+
extensions.extend(other);
347+
Cow::Owned(extensions)
348+
}
349+
}
350+
351+
/// Serialize and deserialize a HUGR into an envelope with the given config,
352+
/// and check that the result is the same as the original.
353+
///
354+
/// We do not compare the before and after `Hugr`s for equality directly,
355+
/// because impls of `CustomConst` are not required to implement equality
356+
/// checking.
357+
///
358+
/// Returns the deserialized HUGR.
359+
pub(crate) fn check_hugr_roundtrip(hugr: &Hugr, config: EnvelopeConfig) -> Hugr {
360+
let mut buffer = Vec::new();
361+
hugr.store(&mut buffer, config).unwrap();
362+
363+
let extensions = join_extensions(&STD_REG, hugr.extensions());
364+
365+
let reader = BufReader::new(buffer.as_slice());
366+
let extracted = Hugr::load(reader, Some(&extensions)).unwrap();
367+
368+
check_hugr_equality(&extracted, hugr);
369+
extracted
370+
}
330371

331372
#[rstest]
332373
fn errors() {

hugr-core/src/envelope/header.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ impl EnvelopeConfig {
116116
/// If the `zstd` feature is enabled, this will use zstd compression.
117117
pub const fn binary() -> Self {
118118
Self {
119-
format: EnvelopeFormat::Model,
119+
format: EnvelopeFormat::ModelWithExtensions,
120120
zstd: None,
121121
}
122122
}

hugr-core/src/envelope/package_json.rs

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ use itertools::Itertools;
44
use std::io;
55

66
use crate::extension::resolution::ExtensionResolutionError;
7-
use crate::extension::{ExtensionRegistry, PRELUDE_REGISTRY};
7+
use crate::extension::ExtensionRegistry;
88
use crate::hugr::ExtensionError;
99
use crate::package::Package;
10-
use crate::{Extension, Hugr, HugrView};
10+
use crate::{Extension, Hugr};
1111

1212
/// Read a Package in json format from an io reader.
1313
pub(super) fn from_json_reader(
@@ -22,31 +22,9 @@ pub(super) fn from_json_reader(
2222
} = serde_json::from_value::<PackageDeser>(val.clone())?;
2323
let mut modules = modules.into_iter().map(|h| h.0).collect_vec();
2424

25-
// TODO: We don't currently store transitive extension dependencies in the
26-
// package's extensions. For example, if we use a `collections.list` const
27-
// value but don't use anything in `prelude` we would not include `prelude` in
28-
// the package's extensions. But this would then fail when loading the
29-
// extensions, as we _need_ the prelude to load the `collections.list` op
30-
// definitions here.
31-
//
32-
// The current fix is to always include the prelude when decoding, but this
33-
// only works for transitive `prelude` dependencies.
34-
//
35-
// Chains of custom extensions will cause this to fail.
36-
let extension_registry = if PRELUDE_REGISTRY
37-
.iter()
38-
.any(|e| !extension_registry.contains(&e.name))
39-
{
40-
let mut reg_with_prelude = extension_registry.clone();
41-
reg_with_prelude.extend(PRELUDE_REGISTRY.iter().cloned());
42-
reg_with_prelude
43-
} else {
44-
extension_registry.clone()
45-
};
46-
47-
let mut pkg_extensions = ExtensionRegistry::new_with_extension_resolution(
25+
let pkg_extensions = ExtensionRegistry::new_with_extension_resolution(
4826
pkg_extensions,
49-
&(&extension_registry).into(),
27+
&extension_registry.into(),
5028
)?;
5129

5230
// Resolve the operations in the modules using the defined registries.
@@ -55,7 +33,6 @@ pub(super) fn from_json_reader(
5533

5634
for module in &mut modules {
5735
module.resolve_extension_defs(&combined_registry)?;
58-
pkg_extensions.extend(module.extensions());
5936
}
6037

6138
Ok(Package {

hugr-core/src/extension/resolution/test.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use crate::std_extensions::arithmetic::float_types::{self, float64_type, ConstF6
2323
use crate::std_extensions::arithmetic::int_ops;
2424
use crate::std_extensions::arithmetic::int_types::{self, int_type};
2525
use crate::std_extensions::collections::list::ListValue;
26+
use crate::std_extensions::std_reg;
2627
use crate::types::type_param::TypeParam;
2728
use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound};
2829
use crate::{type_row, Extension, Hugr, HugrView};
@@ -117,8 +118,13 @@ fn make_extension_self_referencing(name: &str, op_name: &str, type_name: &str) -
117118
/// Check that the extensions added during building coincide with read-only collected extensions
118119
/// and that they survive a serialization roundtrip.
119120
fn check_extension_resolution(mut hugr: Hugr) {
121+
// Extensions used by the hugr, used to check that the roundtrip preserves them.
120122
let build_extensions = hugr.extensions().clone();
121123

124+
// Extensions used for resolution.
125+
let mut resolution_extensions = std_reg();
126+
resolution_extensions.extend(&build_extensions);
127+
122128
// Check that the read-only methods collect the same extensions.
123129
let collected_exts = ExtensionRegistry::new(hugr.nodes().flat_map(|node| {
124130
hugr.get_optype(node)
@@ -132,7 +138,7 @@ fn check_extension_resolution(mut hugr: Hugr) {
132138
);
133139

134140
// Check that the mutable methods collect the same extensions.
135-
hugr.resolve_extension_defs(&build_extensions).unwrap();
141+
hugr.resolve_extension_defs(&resolution_extensions).unwrap();
136142
assert_eq!(
137143
hugr.extensions(),
138144
&build_extensions,
@@ -143,7 +149,7 @@ fn check_extension_resolution(mut hugr: Hugr) {
143149
// Roundtrip serialize so all weak references are dropped.
144150
let ser = hugr.store_str(EnvelopeConfig::text()).unwrap();
145151

146-
let deser_hugr = Hugr::load_str(&ser, Some(&build_extensions)).unwrap();
152+
let deser_hugr = Hugr::load_str(&ser, Some(&resolution_extensions)).unwrap();
147153

148154
assert_eq!(
149155
deser_hugr.extensions(),

hugr-core/src/hugr.rs

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,14 +489,63 @@ fn make_module_hugr(root_op: OpType, nodes: usize, ports: usize) -> Option<Hugr>
489489
}
490490

491491
#[cfg(test)]
492-
mod test {
492+
pub(crate) mod test {
493493
use std::{fs::File, io::BufReader};
494494

495-
use super::{Hugr, HugrView};
495+
use super::*;
496496

497497
use crate::envelope::{EnvelopeError, PackageEncodingError};
498+
use crate::ops::OpaqueOp;
498499
use crate::test_file;
499500
use cool_asserts::assert_matches;
501+
use portgraph::LinkView;
502+
503+
/// Check that two HUGRs are equivalent, up to node renumbering.
504+
pub(crate) fn check_hugr_equality(lhs: &Hugr, rhs: &Hugr) {
505+
// Original HUGR, with canonicalized node indices
506+
//
507+
// The internal port indices may still be different.
508+
let mut lhs = lhs.clone();
509+
lhs.canonicalize_nodes(|_, _| {});
510+
let mut rhs = rhs.clone();
511+
rhs.canonicalize_nodes(|_, _| {});
512+
513+
assert_eq!(rhs.module_root(), lhs.module_root());
514+
assert_eq!(rhs.entrypoint(), lhs.entrypoint());
515+
assert_eq!(rhs.hierarchy, lhs.hierarchy);
516+
assert_eq!(rhs.metadata, lhs.metadata);
517+
518+
// Extension operations may have been downgraded to opaque operations.
519+
for node in rhs.nodes() {
520+
let new_op = rhs.get_optype(node);
521+
let old_op = lhs.get_optype(node);
522+
if !new_op.is_const() {
523+
match (new_op, old_op) {
524+
(OpType::ExtensionOp(ext), OpType::OpaqueOp(opaque))
525+
| (OpType::OpaqueOp(opaque), OpType::ExtensionOp(ext)) => {
526+
let ext_opaque: OpaqueOp = ext.clone().into();
527+
assert_eq!(ext_opaque, opaque.clone());
528+
}
529+
_ => assert_eq!(new_op, old_op),
530+
}
531+
}
532+
}
533+
534+
// Check that the graphs are equivalent up to port renumbering.
535+
let new_graph = &rhs.graph;
536+
let old_graph = &lhs.graph;
537+
assert_eq!(new_graph.node_count(), old_graph.node_count());
538+
assert_eq!(new_graph.port_count(), old_graph.port_count());
539+
assert_eq!(new_graph.link_count(), old_graph.link_count());
540+
for n in old_graph.nodes_iter() {
541+
assert_eq!(new_graph.num_inputs(n), old_graph.num_inputs(n));
542+
assert_eq!(new_graph.num_outputs(n), old_graph.num_outputs(n));
543+
assert_eq!(
544+
new_graph.output_neighbours(n).collect_vec(),
545+
old_graph.output_neighbours(n).collect_vec()
546+
);
547+
}
548+
}
500549

501550
#[test]
502551
fn impls_send_and_sync() {

hugr-core/src/hugr/serialize/test.rs

Lines changed: 11 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use crate::extension::simple_op::MakeRegisteredOp;
1111
use crate::extension::test::SimpleOpDef;
1212
use crate::extension::ExtensionRegistry;
1313
use crate::hugr::internal::HugrMutInternals;
14+
use crate::hugr::test::check_hugr_equality;
1415
use crate::hugr::validate::ValidationError;
1516
use crate::hugr::views::ExtractionResult;
1617
use crate::ops::custom::{ExtensionOp, OpaqueOp, OpaqueOpError};
@@ -28,7 +29,6 @@ use crate::{type_row, OutgoingPort};
2829
use itertools::Itertools;
2930
use jsonschema::{Draft, Validator};
3031
use lazy_static::lazy_static;
31-
use portgraph::LinkView;
3232
use portgraph::{multiportgraph::MultiPortGraph, Hierarchy, LinkMut, PortMut, UnmanagedDenseMap};
3333
use rstest::rstest;
3434

@@ -154,7 +154,7 @@ impl From<Type> for SerTestingLatest {
154154

155155
#[test]
156156
fn empty_hugr_serialize() {
157-
check_hugr_roundtrip(&Hugr::default(), true);
157+
check_hugr_json_roundtrip(&Hugr::default(), true);
158158
}
159159

160160
fn ser_deserialize_check_schema<T: serde::de::DeserializeOwned>(
@@ -184,71 +184,26 @@ fn ser_roundtrip_check_schema<TSer: Serialize, TDeser: serde::de::DeserializeOwn
184184
/// equality checking.
185185
///
186186
/// Returns the deserialized HUGR.
187-
pub fn check_hugr_roundtrip(hugr: &impl HugrView, check_schema: bool) -> Hugr {
187+
fn check_hugr_json_roundtrip(hugr: &impl HugrView, check_schema: bool) -> Hugr {
188188
// Transform the whole view into a HUGR.
189189
let (mut base, extract_map) = hugr.extract_hugr(hugr.module_root());
190190
base.set_entrypoint(extract_map.extracted_node(hugr.entrypoint()));
191191

192192
let new_hugr: HugrDeser =
193193
ser_roundtrip_check_schema(&HugrSer(&base), get_schemas(check_schema));
194194

195-
check_hugr(&base, &new_hugr.0);
195+
check_hugr_equality(&base, &new_hugr.0);
196196
new_hugr.0
197197
}
198198

199199
/// Deserialize a HUGR json, ensuring that it is valid against the schema.
200200
pub fn check_hugr_deserialize(hugr: &Hugr, value: serde_json::Value, check_schema: bool) -> Hugr {
201201
let new_hugr: HugrDeser = ser_deserialize_check_schema(value, get_schemas(check_schema));
202202

203-
check_hugr(hugr, &new_hugr.0);
203+
check_hugr_equality(hugr, &new_hugr.0);
204204
new_hugr.0
205205
}
206206

207-
/// Check that two HUGRs are equivalent, up to node renumbering.
208-
pub fn check_hugr(lhs: &Hugr, rhs: &Hugr) {
209-
// Original HUGR, with canonicalized node indices
210-
//
211-
// The internal port indices may still be different.
212-
let mut h_canon = lhs.clone();
213-
h_canon.canonicalize_nodes(|_, _| {});
214-
215-
assert_eq!(rhs.module_root(), h_canon.module_root());
216-
assert_eq!(rhs.entrypoint(), h_canon.entrypoint());
217-
assert_eq!(rhs.hierarchy, h_canon.hierarchy);
218-
assert_eq!(rhs.metadata, h_canon.metadata);
219-
220-
// Extension operations may have been downgraded to opaque operations.
221-
for node in rhs.nodes() {
222-
let new_op = rhs.get_optype(node);
223-
let old_op = h_canon.get_optype(node);
224-
if !new_op.is_const() {
225-
match (new_op, old_op) {
226-
(OpType::ExtensionOp(ext), OpType::OpaqueOp(opaque))
227-
| (OpType::OpaqueOp(opaque), OpType::ExtensionOp(ext)) => {
228-
let ext_opaque: OpaqueOp = ext.clone().into();
229-
assert_eq!(ext_opaque, opaque.clone());
230-
}
231-
_ => assert_eq!(new_op, old_op),
232-
}
233-
}
234-
}
235-
236-
// Check that the graphs are equivalent up to port renumbering.
237-
let new_graph = &rhs.graph;
238-
let old_graph = &h_canon.graph;
239-
assert_eq!(new_graph.node_count(), old_graph.node_count());
240-
assert_eq!(new_graph.port_count(), old_graph.port_count());
241-
assert_eq!(new_graph.link_count(), old_graph.link_count());
242-
for n in old_graph.nodes_iter() {
243-
assert_eq!(new_graph.num_inputs(n), old_graph.num_inputs(n));
244-
assert_eq!(new_graph.num_outputs(n), old_graph.num_outputs(n));
245-
assert_eq!(
246-
new_graph.output_neighbours(n).collect_vec(),
247-
old_graph.output_neighbours(n).collect_vec()
248-
);
249-
}
250-
}
251-
252207
fn check_testing_roundtrip(t: impl Into<SerTestingLatest>) {
253208
let before = Versioned::new_latest(t.into());
254209
let after = ser_roundtrip_check_schema(&before, get_testing_schemas(true));
@@ -306,7 +261,7 @@ fn simpleser() {
306261
extensions: ExtensionRegistry::default(),
307262
};
308263

309-
check_hugr_roundtrip(&hugr, true);
264+
check_hugr_json_roundtrip(&hugr, true);
310265
}
311266

312267
#[test]
@@ -335,7 +290,7 @@ fn weighted_hugr_ser() {
335290
module_builder.finish_hugr().unwrap()
336291
};
337292

338-
check_hugr_roundtrip(&hugr, true);
293+
check_hugr_json_roundtrip(&hugr, true);
339294
}
340295

341296
#[test]
@@ -351,7 +306,7 @@ fn dfg_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
351306
}
352307
let hugr = dfg.finish_hugr_with_outputs(params)?;
353308

354-
check_hugr_roundtrip(&hugr, true);
309+
check_hugr_json_roundtrip(&hugr, true);
355310
Ok(())
356311
}
357312

@@ -370,7 +325,7 @@ fn extension_ops() -> Result<(), Box<dyn std::error::Error>> {
370325

371326
let hugr = dfg.finish_hugr_with_outputs([wire])?;
372327

373-
check_hugr_roundtrip(&hugr, true);
328+
check_hugr_json_roundtrip(&hugr, true);
374329
Ok(())
375330
}
376331

@@ -412,7 +367,7 @@ fn function_type() -> Result<(), Box<dyn std::error::Error>> {
412367
let op = bldr.add_dataflow_op(Noop(fn_ty), bldr.input_wires())?;
413368
let h = bldr.finish_hugr_with_outputs(op.outputs())?;
414369

415-
check_hugr_roundtrip(&h, true);
370+
check_hugr_json_roundtrip(&h, true);
416371
Ok(())
417372
}
418373

@@ -430,7 +385,7 @@ fn hierarchy_order() -> Result<(), Box<dyn std::error::Error>> {
430385
hugr.remove_node(old_in);
431386
hugr.validate()?;
432387

433-
let rhs: Hugr = check_hugr_roundtrip(&hugr, true);
388+
let rhs: Hugr = check_hugr_json_roundtrip(&hugr, true);
434389
rhs.validate()?;
435390
Ok(())
436391
}

0 commit comments

Comments
 (0)