Skip to content

Commit 68415f4

Browse files
authored
fix: Make JSON Schema checks actually work again (#2412)
fixes #2401 * Update generate_schema.py to add an entry to the generated schema describing what members of `#/$defs` the actual value (i.e. root) should conform to. * Add a test that schema validation fails when expected (by adding/removing fields). * Some changes to `check_schemas` and related (crate) methods to expose errors etc. * Fix various failures in proptests stemming from * #2309 (the `Arbitrary` instance can generate "TypeArg"s where "TypeParam"s are expected or vice versa, and these do not confirm to the JSON schema). * #2360 - I'm not clear what *well-typed* ArrayOrTermSer::Term can fit the JSON schema, but our `Arbitrary`s generate plenty that are not well-typed. ...I've "fixed" the proptests by filtering out non-schema-compliant instances using `prop_filter`, this is not great, the new situation is detailed in #2415 for further cleaning up.
1 parent 0789ed4 commit 68415f4

File tree

13 files changed

+241
-33
lines changed

13 files changed

+241
-33
lines changed

hugr-core/src/hugr/serialize/test.rs

Lines changed: 95 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -62,34 +62,37 @@ impl NamedSchema {
6262
Self { name, schema }
6363
}
6464

65-
pub fn check(&self, val: &serde_json::Value) {
65+
pub fn check(&self, val: &serde_json::Value) -> Result<(), String> {
6666
let mut errors = self.schema.iter_errors(val).peekable();
67-
if errors.peek().is_some() {
68-
// errors don't necessarily implement Debug
69-
eprintln!("Schema failed to validate: {}", self.name);
70-
for error in errors {
71-
eprintln!("Validation error: {error}");
72-
eprintln!("Instance path: {}", error.instance_path);
73-
}
74-
panic!("Serialization test failed.");
67+
if errors.peek().is_none() {
68+
return Ok(());
7569
}
70+
71+
// errors don't necessarily implement Debug
72+
let mut strs = vec![format!("Schema failed to validate: {}", self.name)];
73+
strs.extend(errors.flat_map(|error| {
74+
[
75+
format!("Validation error: {error}"),
76+
format!("Instance path: {}", error.instance_path),
77+
]
78+
}));
79+
strs.push("Serialization test failed.".to_string());
80+
Err(strs.join("\n"))
7681
}
7782

7883
pub fn check_schemas(
7984
val: &serde_json::Value,
8085
schemas: impl IntoIterator<Item = &'static Self>,
81-
) {
82-
for schema in schemas {
83-
schema.check(val);
84-
}
86+
) -> Result<(), String> {
87+
schemas.into_iter().try_for_each(|schema| schema.check(val))
8588
}
8689
}
8790

8891
macro_rules! include_schema {
8992
($name:ident, $path:literal) => {
9093
lazy_static! {
9194
static ref $name: NamedSchema =
92-
NamedSchema::new("$name", {
95+
NamedSchema::new(stringify!($name), {
9396
let schema_val: serde_json::Value = serde_json::from_str(include_str!(
9497
concat!("../../../../specification/schema/", $path, "_live.json")
9598
))
@@ -161,7 +164,7 @@ fn ser_deserialize_check_schema<T: serde::de::DeserializeOwned>(
161164
val: serde_json::Value,
162165
schemas: impl IntoIterator<Item = &'static NamedSchema>,
163166
) -> T {
164-
NamedSchema::check_schemas(&val, schemas);
167+
NamedSchema::check_schemas(&val, schemas).unwrap();
165168
serde_json::from_value(val).unwrap()
166169
}
167170

@@ -171,8 +174,10 @@ fn ser_roundtrip_check_schema<TSer: Serialize, TDeser: serde::de::DeserializeOwn
171174
schemas: impl IntoIterator<Item = &'static NamedSchema>,
172175
) -> TDeser {
173176
let val = serde_json::to_value(g).unwrap();
174-
NamedSchema::check_schemas(&val, schemas);
175-
serde_json::from_value(val).unwrap()
177+
match NamedSchema::check_schemas(&val, schemas) {
178+
Ok(()) => serde_json::from_value(val).unwrap(),
179+
Err(msg) => panic!("ser_roundtrip_check_schema failed with {msg}, input was {val}"),
180+
}
176181
}
177182

178183
/// Serialize a Hugr and check that it is valid against the schema.
@@ -187,7 +192,7 @@ pub(crate) fn check_hugr_serialization_schema(hugr: &Hugr) {
187192
let schemas = get_schemas(true);
188193
let hugr_ser = HugrSer(hugr);
189194
let val = serde_json::to_value(hugr_ser).unwrap();
190-
NamedSchema::check_schemas(&val, schemas);
195+
NamedSchema::check_schemas(&val, schemas).unwrap();
191196
}
192197

193198
/// Serialize and deserialize a HUGR, and check that the result is the same as the original.
@@ -225,6 +230,77 @@ fn check_testing_roundtrip(t: impl Into<SerTestingLatest>) {
225230
assert_eq!(before, after);
226231
}
227232

233+
fn test_schema_val() -> serde_json::Value {
234+
serde_json::json!({
235+
"op_def":null,
236+
"optype":{
237+
"name":"polyfunc1",
238+
"op":"FuncDefn",
239+
"parent":0,
240+
"signature":{
241+
"body":{
242+
"input":[],
243+
"output":[]
244+
},
245+
"params":[
246+
{"bound":null,"tp":"BoundedNat"}
247+
]
248+
}
249+
},
250+
"poly_func_type":null,
251+
"sum_type":null,
252+
"typ":null,
253+
"value":null,
254+
"version":"live"
255+
})
256+
}
257+
258+
fn schema_val() -> serde_json::Value {
259+
serde_json::json!({"nodes": [], "edges": [], "version": "live"})
260+
}
261+
262+
#[rstest]
263+
#[case(&TESTING_SCHEMA, &TESTING_SCHEMA_STRICT, test_schema_val(), Some("optype"))]
264+
#[case(&SCHEMA, &SCHEMA_STRICT, schema_val(), None)]
265+
fn wrong_fields(
266+
#[case] lax_schema: &'static NamedSchema,
267+
#[case] strict_schema: &'static NamedSchema,
268+
#[case] mut val: serde_json::Value,
269+
#[case] target_loc: impl IntoIterator<Item = &'static str> + Clone,
270+
) {
271+
use serde_json::Value;
272+
fn get_fields(
273+
val: &mut Value,
274+
mut path: impl Iterator<Item = &'static str>,
275+
) -> &mut serde_json::Map<String, Value> {
276+
let Value::Object(fields) = val else { panic!() };
277+
match path.next() {
278+
Some(n) => get_fields(fields.get_mut(n).unwrap(), path),
279+
None => fields,
280+
}
281+
}
282+
// First, some "known good" JSON
283+
NamedSchema::check_schemas(&val, [lax_schema, strict_schema]).unwrap();
284+
285+
// Now try adding an extra field
286+
let fields = get_fields(&mut val, target_loc.clone().into_iter());
287+
fields.insert(
288+
"extra_field".to_string(),
289+
Value::String("not in schema".to_string()),
290+
);
291+
strict_schema.check(&val).unwrap_err();
292+
lax_schema.check(&val).unwrap();
293+
294+
// And removing one
295+
let fields = get_fields(&mut val, target_loc.into_iter());
296+
fields.remove("extra_field").unwrap();
297+
let key = fields.keys().next().unwrap().clone();
298+
fields.remove(&key).unwrap();
299+
300+
lax_schema.check(&val).unwrap_err();
301+
strict_schema.check(&val).unwrap_err();
302+
}
303+
228304
/// Generate an optype for a node with a matching amount of inputs and outputs.
229305
fn gen_optype(g: &MultiPortGraph, node: portgraph::NodeIndex) -> OpType {
230306
let inputs = g.num_inputs(node);
@@ -544,7 +620,7 @@ fn std_extensions_valid() {
544620
let std_reg = crate::std_extensions::std_reg();
545621
for ext in std_reg {
546622
let val = serde_json::to_value(ext).unwrap();
547-
NamedSchema::check_schemas(&val, get_schemas(true));
623+
NamedSchema::check_schemas(&val, get_schemas(true)).unwrap();
548624
// check deserialises correctly, can't check equality because of custom binaries.
549625
let deser: crate::extension::Extension = serde_json::from_value(val.clone()).unwrap();
550626
assert_eq!(serde_json::to_value(deser).unwrap(), val);

hugr-core/src/ops/custom.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ use thiserror::Error;
77
#[cfg(test)]
88
use {
99
crate::extension::test::SimpleOpDef, crate::proptest::any_nonempty_smolstr,
10-
::proptest::prelude::*, ::proptest_derive::Arbitrary,
10+
crate::types::proptest_utils::any_serde_type_arg_vec, ::proptest::prelude::*,
11+
::proptest_derive::Arbitrary,
1112
};
1213

1314
use crate::core::HugrNode;
@@ -35,6 +36,7 @@ pub struct ExtensionOp {
3536
proptest(strategy = "any::<SimpleOpDef>().prop_map(|x| Arc::new(x.into()))")
3637
)]
3738
def: Arc<OpDef>,
39+
#[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))]
3840
args: Vec<TypeArg>,
3941
signature: Signature, // Cache
4042
}
@@ -235,6 +237,7 @@ pub struct OpaqueOp {
235237
extension: ExtensionId,
236238
#[cfg_attr(test, proptest(strategy = "any_nonempty_smolstr()"))]
237239
name: OpName,
240+
#[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))]
238241
args: Vec<TypeArg>,
239242
// note that the `signature` field might not include `extension`. Thus this must
240243
// remain private, and should be accessed through

hugr-core/src/ops/dataflow.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeAr
1010
use crate::{IncomingPort, type_row};
1111

1212
#[cfg(test)]
13-
use proptest_derive::Arbitrary;
13+
use {crate::types::proptest_utils::any_serde_type_arg_vec, proptest_derive::Arbitrary};
1414

1515
/// Trait implemented by all dataflow operations.
1616
pub trait DataflowOpTrait: Sized {
@@ -191,6 +191,7 @@ pub struct Call {
191191
/// Signature of function being called.
192192
pub func_sig: PolyFuncType,
193193
/// The type arguments that instantiate `func_sig`.
194+
#[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))]
194195
pub type_args: Vec<TypeArg>,
195196
/// The instantiation of `func_sig`.
196197
pub instantiation: Signature, // Cache, so we can fail in try_new() not in signature()
@@ -391,6 +392,7 @@ pub struct LoadFunction {
391392
/// Signature of the function
392393
pub func_sig: PolyFuncType,
393394
/// The type arguments that instantiate `func_sig`.
395+
#[cfg_attr(test, proptest(strategy = "any_serde_type_arg_vec()"))]
394396
pub type_args: Vec<TypeArg>,
395397
/// The instantiation of `func_sig`.
396398
pub instantiation: Signature, // Cache, so we can fail in try_new() not in signature()

hugr-core/src/types.rs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,3 +1107,82 @@ pub(crate) mod test {
11071107
}
11081108
}
11091109
}
1110+
1111+
#[cfg(test)]
1112+
pub(super) mod proptest_utils {
1113+
use proptest::collection::vec;
1114+
use proptest::prelude::{Strategy, any_with};
1115+
1116+
use super::serialize::{TermSer, TypeArgSer, TypeParamSer};
1117+
use super::type_param::Term;
1118+
1119+
use crate::proptest::RecursionDepth;
1120+
use crate::types::serialize::ArrayOrTermSer;
1121+
1122+
fn term_is_serde_type_arg(t: &Term) -> bool {
1123+
let TermSer::TypeArg(arg) = TermSer::from(t.clone()) else {
1124+
return false;
1125+
};
1126+
match arg {
1127+
TypeArgSer::List { elems: terms }
1128+
| TypeArgSer::ListConcat { lists: terms }
1129+
| TypeArgSer::Tuple { elems: terms }
1130+
| TypeArgSer::TupleConcat { tuples: terms } => terms.iter().all(term_is_serde_type_arg),
1131+
TypeArgSer::Variable { v } => term_is_serde_type_param(&v.cached_decl),
1132+
TypeArgSer::Type { ty } => {
1133+
if let Some(cty) = ty.as_extension() {
1134+
cty.args().iter().all(term_is_serde_type_arg)
1135+
} else {
1136+
true
1137+
}
1138+
} // Do we need to inspect inside function types? sum types?
1139+
TypeArgSer::BoundedNat { .. }
1140+
| TypeArgSer::String { .. }
1141+
| TypeArgSer::Bytes { .. }
1142+
| TypeArgSer::Float { .. } => true,
1143+
}
1144+
}
1145+
1146+
fn term_is_serde_type_param(t: &Term) -> bool {
1147+
let TermSer::TypeParam(parm) = TermSer::from(t.clone()) else {
1148+
return false;
1149+
};
1150+
match parm {
1151+
TypeParamSer::Type { .. }
1152+
| TypeParamSer::BoundedNat { .. }
1153+
| TypeParamSer::String
1154+
| TypeParamSer::Bytes
1155+
| TypeParamSer::Float
1156+
| TypeParamSer::StaticType => true,
1157+
TypeParamSer::List { param } => term_is_serde_type_param(&param),
1158+
TypeParamSer::Tuple { params } => {
1159+
match &params {
1160+
ArrayOrTermSer::Array(terms) => terms.iter().all(term_is_serde_type_param),
1161+
ArrayOrTermSer::Term(b) => match &**b {
1162+
Term::List(_) => panic!("Should be represented as ArrayOrTermSer::Array"),
1163+
// This might be well-typed, but does not fit the (TODO: update) JSON schema
1164+
Term::Variable(_) => false,
1165+
// Similarly, but not produced by our `impl Arbitrary`:
1166+
Term::ListConcat(_) => todo!("Update schema"),
1167+
1168+
// The others do not fit the JSON schema, and are not well-typed,
1169+
// but can be produced by our impl of Arbitrary, so we must filter out:
1170+
_ => false,
1171+
},
1172+
}
1173+
}
1174+
}
1175+
}
1176+
1177+
pub fn any_serde_type_arg(depth: RecursionDepth) -> impl Strategy<Value = Term> {
1178+
any_with::<Term>(depth).prop_filter("Term was not a TypeArg", term_is_serde_type_arg)
1179+
}
1180+
1181+
pub fn any_serde_type_arg_vec() -> impl Strategy<Value = Vec<Term>> {
1182+
vec(any_serde_type_arg(RecursionDepth::default()), 1..3)
1183+
}
1184+
1185+
pub fn any_serde_type_param(depth: RecursionDepth) -> impl Strategy<Value = Term> {
1186+
any_with::<Term>(depth).prop_filter("Term was not a TypeParam", term_is_serde_type_param)
1187+
}
1188+
}

hugr-core/src/types/custom.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ mod test {
188188
use crate::extension::ExtensionId;
189189
use crate::proptest::RecursionDepth;
190190
use crate::proptest::any_nonempty_string;
191-
use crate::types::type_param::TypeArg;
191+
use crate::types::proptest_utils::any_serde_type_arg;
192192
use crate::types::{CustomType, TypeBound};
193193
use ::proptest::collection::vec;
194194
use ::proptest::prelude::*;
@@ -224,7 +224,7 @@ mod test {
224224
Just(vec![]).boxed()
225225
} else {
226226
// a TypeArg may contain a CustomType, so we descend here
227-
vec(any_with::<TypeArg>(depth.descend()), 0..3).boxed()
227+
vec(any_serde_type_arg(depth.descend()), 0..3).boxed()
228228
};
229229
(any_nonempty_string(), args, any::<ExtensionId>(), bound)
230230
.prop_map(|(id, args, extension, bound)| {

hugr-core/src/types/poly_func.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use itertools::Itertools;
77
use crate::extension::SignatureError;
88
#[cfg(test)]
99
use {
10+
super::proptest_utils::any_serde_type_param,
1011
crate::proptest::RecursionDepth,
1112
::proptest::{collection::vec, prelude::*},
1213
proptest_derive::Arbitrary,
@@ -31,7 +32,7 @@ pub struct PolyFuncTypeBase<RV: MaybeRV> {
3132
/// The declared type parameters, i.e., these must be instantiated with
3233
/// the same number of [`TypeArg`]s before the function can be called. This
3334
/// defines the indices used by variables inside the body.
34-
#[cfg_attr(test, proptest(strategy = "vec(any_with::<TypeParam>(params), 0..3)"))]
35+
#[cfg_attr(test, proptest(strategy = "vec(any_serde_type_param(params), 0..3)"))]
3536
params: Vec<TypeParam>,
3637
/// Template for the function. May contain variables up to length of [`Self::params`]
3738
#[cfg_attr(test, proptest(strategy = "any_with::<FuncTypeBase<RV>>(params)"))]

hugr-core/src/types/serialize.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ impl From<TermSer> for Term {
187187
#[serde(untagged)]
188188
pub(super) enum ArrayOrTermSer {
189189
Array(Vec<Term>),
190-
Term(Box<Term>),
190+
Term(Box<Term>), // TODO JSON Schema does not really support this yet
191191
}
192192

193193
impl From<ArrayOrTermSer> for Term {

hugr-core/src/types/type_param.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ impl<const N: usize> From<[Term; N]> for Term {
261261
#[display("#{idx}")]
262262
pub struct TermVar {
263263
idx: usize,
264-
cached_decl: Box<Term>,
264+
pub(in crate::types) cached_decl: Box<Term>,
265265
}
266266

267267
impl Term {
@@ -1046,13 +1046,13 @@ mod test {
10461046

10471047
use super::super::{TermVar, UpperBound};
10481048
use crate::proptest::RecursionDepth;
1049-
use crate::types::{Term, Type, TypeBound};
1049+
use crate::types::{Term, Type, TypeBound, proptest_utils::any_serde_type_param};
10501050

10511051
impl Arbitrary for TermVar {
10521052
type Parameters = RecursionDepth;
10531053
type Strategy = BoxedStrategy<Self>;
10541054
fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy {
1055-
(any::<usize>(), any_with::<Term>(depth))
1055+
(any::<usize>(), any_serde_type_param(depth))
10561056
.prop_map(|(idx, cached_decl)| Self {
10571057
idx,
10581058
cached_decl: Box::new(cached_decl),

scripts/generate_schema.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pathlib import Path
1515

1616
from pydantic import ConfigDict
17-
from pydantic.json_schema import models_json_schema
17+
from pydantic.json_schema import DEFAULT_REF_TEMPLATE, models_json_schema
1818

1919
from hugr._serialization.extension import Extension, Package
2020
from hugr._serialization.serial_hugr import SerialHugr
@@ -38,6 +38,9 @@ def write_schema(
3838
_, top_level_schema = models_json_schema(
3939
[(s, "validation") for s in schemas], title="HUGR schema"
4040
)
41+
top_level_schema["oneOf"] = [
42+
{"$ref": DEFAULT_REF_TEMPLATE.format(model=s.__name__)} for s in schemas
43+
]
4144
with path.open("w") as f:
4245
json.dump(top_level_schema, f, indent=4)
4346

0 commit comments

Comments
 (0)