Skip to content

feat: validate Terms used as parameter types are appropriate (into #2309) #2345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4430a3e
`...Type` suffix for `TypeParam` variants.
zrho Jun 6, 2025
e270162
Merge `TypeArg` and `TypeParam`.
zrho Jun 6, 2025
593bb6a
Implement functions.
zrho Jun 6, 2025
a972d2a
Docs.
zrho Jun 6, 2025
b0b8507
Adjust `Arbitrary` instance to avoid infinite recursion.
zrho Jun 6, 2025
e17e918
Import no longer needs to distinguish between type params and args.
zrho Jun 6, 2025
1f62ea4
JSON compatibility layer in Rust.
zrho Jun 6, 2025
d1404b9
Added proptest for `Term::contains`.
zrho Jun 6, 2025
ea05c57
Fix `Term::contains`.
zrho Jun 6, 2025
e4cde0b
Taking the opportunity to have nicer field names.
zrho Jun 6, 2025
377eb27
Formatting.
zrho Jun 6, 2025
ae48d57
Update feature gated code.
zrho Jun 7, 2025
3ff1d5c
Hash change in snapshot test.
zrho Jun 9, 2025
5bfcb55
Another hash in a snapshot.
zrho Jun 9, 2025
af19410
Tuple fields.
zrho Jun 13, 2025
79066ce
Instantiate `Term`s with `From` and `new_*`.
zrho Jun 13, 2025
4a9171a
Lints and link to const issue.
zrho Jun 13, 2025
b2e7b6c
`TermVar`.
zrho Jun 13, 2025
be931d9
Lints.
zrho Jun 14, 2025
ba71b19
Extension resolution on `Term`s.
zrho Jun 14, 2025
7c521f3
Formatting.
zrho Jun 14, 2025
d52193a
Rename `check_type_arg` to `check_term_type`.
zrho Jun 16, 2025
903da0c
Use `.into`
zrho Jun 16, 2025
bba7e72
Formatting
zrho Jun 16, 2025
4008375
Lints.
zrho Jun 16, 2025
c401c30
Typo, `is_supertype`, `Term::Runtime`.
zrho Jun 16, 2025
b0d58b7
Add validate_param
acl-cqc Jun 16, 2025
9d19ba5
is_supertype need only deal with types
acl-cqc Jun 16, 2025
67bcc63
update proptest
acl-cqc Jun 16, 2025
b6dfde2
Allow variables as parameter types, so validate cached_decls against …
acl-cqc Jun 17, 2025
a0996c7
Allow invariant case for variables
acl-cqc Jun 17, 2025
dfc9444
fix corner case of non-type variable
acl-cqc Jun 17, 2025
7d08db6
Missing RuntimeType: StaticType case
acl-cqc Jun 17, 2025
1bf35d8
Handle variables by is_cached_static
acl-cqc Jun 17, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion hugr-core/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ pub(crate) mod test {
FunctionBuilder::new(
"bad_eval",
PolyFuncType::new(
[TypeParam::new_list(TypeBound::Copyable)],
[TypeParam::new_list_type(TypeBound::Copyable)],
Signature::new(
Type::new_function(FuncValueType::new(usize_t(), tv.clone())),
vec![],
Expand Down
101 changes: 47 additions & 54 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Exporting HUGR graphs to their `hugr-model` representation.
use crate::extension::ExtensionRegistry;
use crate::hugr::internal::HugrInternals;
use crate::types::type_param::Term;
use crate::{
Direction, Hugr, HugrView, IncomingPort, Node, NodeIndex as _, Port,
extension::{ExtensionId, OpDef, SignatureFunc},
Expand All @@ -14,9 +15,7 @@ use crate::{
},
types::{
CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType,
TypeArg, TypeBase, TypeBound, TypeEnum,
type_param::{TypeArgVariable, TypeParam},
type_row::TypeRowBase,
TypeBase, TypeBound, TypeEnum, type_param::TermVar, type_row::TypeRowBase,
},
};

Expand Down Expand Up @@ -385,7 +384,7 @@ impl<'a> Context<'a> {
let node = self.connected_function(node).unwrap();
let symbol = self.node_to_id[&node];
let mut args = BumpVec::new_in(self.bump);
args.extend(call.type_args.iter().map(|arg| self.export_type_arg(arg)));
args.extend(call.type_args.iter().map(|arg| self.export_term(arg, None)));
let args = args.into_bump_slice();
let func = self.make_term(table::Term::Apply(symbol, args));

Expand All @@ -401,7 +400,7 @@ impl<'a> Context<'a> {
let node = self.connected_function(node).unwrap();
let symbol = self.node_to_id[&node];
let mut args = BumpVec::new_in(self.bump);
args.extend(load.type_args.iter().map(|arg| self.export_type_arg(arg)));
args.extend(load.type_args.iter().map(|arg| self.export_term(arg, None)));
let args = args.into_bump_slice();
let func = self.make_term(table::Term::Apply(symbol, args));
let runtime_type = self.make_term(table::Term::Wildcard);
Expand Down Expand Up @@ -464,7 +463,7 @@ impl<'a> Context<'a> {
let node = self.export_opdef(op.def());
let params = self
.bump
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg)));
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_term(arg, None)));
let operation = self.make_term(table::Term::Apply(node, params));
table::Operation::Custom(operation)
}
Expand All @@ -473,7 +472,7 @@ impl<'a> Context<'a> {
let node = self.make_named_global_ref(op.extension(), op.unqualified_id());
let params = self
.bump
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg)));
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_term(arg, None)));
let operation = self.make_term(table::Term::Apply(node, params));
table::Operation::Custom(operation)
}
Expand Down Expand Up @@ -806,7 +805,7 @@ impl<'a> Context<'a> {

for (i, param) in t.params().iter().enumerate() {
let name = self.bump.alloc_str(&i.to_string());
let r#type = self.export_type_param(param, Some((scope, i as _)));
let r#type = self.export_term(param, Some((scope, i as _)));
let param = table::Param { name, r#type };
params.push(param);
}
Expand Down Expand Up @@ -854,40 +853,12 @@ impl<'a> Context<'a> {

let args = self
.bump
.alloc_slice_fill_iter(t.args().iter().map(|p| self.export_type_arg(p)));
.alloc_slice_fill_iter(t.args().iter().map(|p| self.export_term(p, None)));
let term = table::Term::Apply(symbol, args);
self.make_term(term)
}

pub fn export_type_arg(&mut self, t: &TypeArg) -> table::TermId {
match t {
TypeArg::Type { ty } => self.export_type(ty),
TypeArg::BoundedNat { n } => self.make_term(model::Literal::Nat(*n).into()),
TypeArg::String { arg } => self.make_term(model::Literal::Str(arg.into()).into()),
TypeArg::Float { value } => self.make_term(model::Literal::Float(*value).into()),
TypeArg::Bytes { value } => self.make_term(model::Literal::Bytes(value.clone()).into()),
TypeArg::List { elems } => {
// For now we assume that the sequence is meant to be a list.
let parts = self.bump.alloc_slice_fill_iter(
elems
.iter()
.map(|elem| table::SeqPart::Item(self.export_type_arg(elem))),
);
self.make_term(table::Term::List(parts))
}
TypeArg::Tuple { elems } => {
let parts = self.bump.alloc_slice_fill_iter(
elems
.iter()
.map(|elem| table::SeqPart::Item(self.export_type_arg(elem))),
);
self.make_term(table::Term::Tuple(parts))
}
TypeArg::Variable { v } => self.export_type_arg_var(v),
}
}

pub fn export_type_arg_var(&mut self, var: &TypeArgVariable) -> table::TermId {
pub fn export_type_arg_var(&mut self, var: &TermVar) -> table::TermId {
let node = self.local_scope.expect("local variable out of scope");
self.make_term(table::Term::Var(table::VarId(node, var.index() as _)))
}
Expand Down Expand Up @@ -953,19 +924,19 @@ impl<'a> Context<'a> {
self.make_term(table::Term::List(parts))
}

/// Exports a `TypeParam` to a term.
/// Exports a term.
///
/// The `var` argument is set when the type parameter being exported is the
/// The `var` argument is set when the term being exported is the
/// type of a parameter to a polymorphic definition. In that case we can
/// generate a `nonlinear` constraint for the type of runtime types marked as
/// `TypeBound::Copyable`.
pub fn export_type_param(
pub fn export_term(
&mut self,
t: &TypeParam,
t: &Term,
var: Option<(table::NodeId, table::VarIndex)>,
) -> table::TermId {
match t {
TypeParam::Type { b } => {
Term::RuntimeType(b) => {
if let (Some((node, index)), TypeBound::Copyable) = (var, b) {
let term = self.make_term(table::Term::Var(table::VarId(node, index)));
let non_linear = self.make_term_apply(model::CORE_NON_LINEAR, &[term]);
Expand All @@ -974,24 +945,46 @@ impl<'a> Context<'a> {

self.make_term_apply(model::CORE_TYPE, &[])
}
// This ignores the bound on the natural for now.
TypeParam::BoundedNat { .. } => self.make_term_apply(model::CORE_NAT_TYPE, &[]),
TypeParam::String => self.make_term_apply(model::CORE_STR_TYPE, &[]),
TypeParam::Bytes => self.make_term_apply(model::CORE_BYTES_TYPE, &[]),
TypeParam::Float => self.make_term_apply(model::CORE_FLOAT_TYPE, &[]),
TypeParam::List { param } => {
let item_type = self.export_type_param(param, None);
Term::BoundedNatType { .. } => self.make_term_apply(model::CORE_NAT_TYPE, &[]),
Term::StringType => self.make_term_apply(model::CORE_STR_TYPE, &[]),
Term::BytesType => self.make_term_apply(model::CORE_BYTES_TYPE, &[]),
Term::FloatType => self.make_term_apply(model::CORE_FLOAT_TYPE, &[]),
Term::ListType(item_type) => {
let item_type = self.export_term(item_type, None);
self.make_term_apply(model::CORE_LIST_TYPE, &[item_type])
}
TypeParam::Tuple { params } => {
let parts = self.bump.alloc_slice_fill_iter(
Term::TupleType(params) => {
let item_types = self.bump.alloc_slice_fill_iter(
params
.iter()
.map(|param| table::SeqPart::Item(self.export_type_param(param, None))),
.map(|param| table::SeqPart::Item(self.export_term(param, None))),
);
let types = self.make_term(table::Term::List(parts));
let types = self.make_term(table::Term::List(item_types));
self.make_term_apply(model::CORE_TUPLE_TYPE, &[types])
}
Term::Runtime(ty) => self.export_type(ty),
Term::BoundedNat(value) => self.make_term(model::Literal::Nat(*value).into()),
Term::String(value) => self.make_term(model::Literal::Str(value.into()).into()),
Term::Float(value) => self.make_term(model::Literal::Float(*value).into()),
Term::Bytes(value) => self.make_term(model::Literal::Bytes(value.clone()).into()),
Term::List(elems) => {
let parts = self.bump.alloc_slice_fill_iter(
elems
.iter()
.map(|elem| table::SeqPart::Item(self.export_term(elem, None))),
);
self.make_term(table::Term::List(parts))
}
Term::Tuple(elems) => {
let parts = self.bump.alloc_slice_fill_iter(
elems
.iter()
.map(|elem| table::SeqPart::Item(self.export_term(elem, None))),
);
self.make_term(table::Term::Tuple(parts))
}
Term::Variable(v) => self.export_type_arg_var(v),
Term::StaticType => self.make_term_apply(model::CORE_STATIC, &[]),
}
}

Expand Down
10 changes: 6 additions & 4 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ use crate::hugr::IdentList;
use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{OpName, OpNameRef};
use crate::types::RowVariable;
use crate::types::type_param::{TypeArg, TypeArgError, TypeParam};
use crate::types::{CustomType, TypeBound, TypeName};
use crate::types::{Signature, TypeNameRef};
use crate::types::type_param::{Term, TermTypeError, TypeArg, TypeParam};
use crate::types::{CustomType, Signature, TypeBound, TypeName, TypeNameRef};

mod const_fold;
mod op_def;
Expand Down Expand Up @@ -387,7 +386,10 @@ pub enum SignatureError {
ExtensionMismatch(ExtensionId, ExtensionId),
/// When the type arguments of the node did not match the params declared by the `OpDef`
#[error("Type arguments of node did not match params declared by definition: {0}")]
TypeArgMismatch(#[from] TypeArgError),
TypeArgMismatch(#[from] TermTypeError),
/// A [Term] was not a valid type parameter
#[error("Term {0} is not a valid parameter type")]
InvalidTypeParam(Term),
/// Invalid type arguments
#[error("Invalid type arguments for operation")]
InvalidTypeArgs,
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/src/extension/declarative/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,6 @@ impl TypeParamDeclaration {
_extension: &Extension,
_ctx: DeclarationContext<'_>,
) -> Result<TypeParam, ExtensionDeclarationError> {
Ok(TypeParam::String)
Ok(TypeParam::StringType)
}
}
37 changes: 18 additions & 19 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use super::{
use crate::Hugr;
use crate::envelope::serde_with::AsStringEnvelope;
use crate::ops::{OpName, OpNameRef};
use crate::types::type_param::{TypeArg, TypeParam, check_type_args};
use crate::types::type_param::{TypeArg, TypeParam, check_term_types};
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
mod serialize_signature_func;

Expand Down Expand Up @@ -239,7 +239,7 @@ impl SignatureFunc {
let static_params = func.static_params();
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));

check_type_args(static_args, static_params)?;
check_term_types(static_args, static_params)?;
temp = func.compute_signature(static_args, def)?;
(&temp, other_args)
}
Expand Down Expand Up @@ -347,7 +347,7 @@ impl OpDef {
let (static_args, other_args) =
args.split_at(min(custom.static_params().len(), args.len()));
static_args.iter().try_for_each(|ta| ta.validate(&[]))?;
check_type_args(static_args, custom.static_params())?;
check_term_types(static_args, custom.static_params())?;
temp = custom.compute_signature(static_args, self)?;
(&temp, other_args)
}
Expand All @@ -357,7 +357,7 @@ impl OpDef {
}
};
args.iter().try_for_each(|ta| ta.validate(var_decls))?;
check_type_args(args, pf.params())?;
check_term_types(args, pf.params())?;
Ok(())
}

Expand Down Expand Up @@ -553,7 +553,7 @@ pub(super) mod test {
use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
use crate::ops::OpName;
use crate::std_extensions::collections::list;
use crate::types::type_param::{TypeArgError, TypeParam};
use crate::types::type_param::{TermTypeError, TypeParam};
use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV};
use crate::{Extension, const_extension_ids};

Expand Down Expand Up @@ -656,7 +656,7 @@ pub(super) mod test {
const OP_NAME: OpName = OpName::new_inline("Reverse");

let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };
const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Any);
let list_of_var =
Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var]));
Expand All @@ -678,11 +678,10 @@ pub(super) mod test {
reg.validate()?;
let e = reg.get(&EXT_ID).unwrap();

let list_usize =
Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: usize_t() }])?);
let list_usize = Type::new_extension(list_def.instantiate(vec![usize_t().into()])?);
let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?;
let rev = dfg.add_dataflow_op(
e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: usize_t() }])
e.instantiate_extension_op(&OP_NAME, vec![usize_t().into()])
.unwrap(),
dfg.input_wires(),
)?;
Expand All @@ -703,8 +702,8 @@ pub(super) mod test {
&self,
arg_values: &[TypeArg],
) -> Result<PolyFuncTypeRV, SignatureError> {
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };
let [TypeArg::BoundedNat { n }] = arg_values else {
const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Any);
let [TypeArg::BoundedNat(n)] = arg_values else {
return Err(SignatureError::InvalidTypeArgs);
};
let n = *n as usize;
Expand All @@ -718,7 +717,7 @@ pub(super) mod test {
}

fn static_params(&self) -> &[TypeParam] {
const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat()];
const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat_type()];
MAX_NAT
}
}
Expand All @@ -727,7 +726,7 @@ pub(super) mod test {
ext.add_op("MyOp".into(), String::new(), SigFun(), extension_ref)?;

// Base case, no type variables:
let args = [TypeArg::BoundedNat { n: 3 }, usize_t().into()];
let args = [TypeArg::BoundedNat(3), usize_t().into()];
assert_eq!(
def.compute_signature(&args),
Ok(Signature::new(
Expand All @@ -740,7 +739,7 @@ pub(super) mod test {
// Second arg may be a variable (substitutable)
let tyvar = Type::new_var_use(0, TypeBound::Copyable);
let tyvars: Vec<Type> = vec![tyvar.clone(); 3];
let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()];
let args = [TypeArg::BoundedNat(3), tyvar.clone().into()];
assert_eq!(
def.compute_signature(&args),
Ok(Signature::new(
Expand All @@ -761,7 +760,7 @@ pub(super) mod test {
);

// First arg must be concrete, not a variable
let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap());
let kind = TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap());
let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()];
// We can't prevent this from getting into our compute_signature implementation:
assert_eq!(
Expand Down Expand Up @@ -798,7 +797,7 @@ pub(super) mod test {
extension_ref,
)?;
let tv = Type::new_var_use(0, TypeBound::Copyable);
let args = [TypeArg::Type { ty: tv.clone() }];
let args = [tv.clone().into()];
let decls = [TypeBound::Copyable.into()];
def.validate_args(&args, &decls).unwrap();
assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo(tv)));
Expand All @@ -807,9 +806,9 @@ pub(super) mod test {
assert_eq!(
def.compute_signature(&[arg.clone()]),
Err(SignatureError::TypeArgMismatch(
TypeArgError::TypeMismatch {
param: TypeBound::Any.into(),
arg
TermTypeError::TypeMismatch {
type_: TypeBound::Any.into(),
term: arg,
}
))
);
Expand Down
Loading
Loading