Skip to content

Commit ebfc98e

Browse files
authored
feat!: Merge TypeParam and TypeArg into one Term type in Rust (#2309)
This PR merges the `TypeParam` and `TypeArg` types into one `Term` type in Rust and adds a `StaticType` variant. The JSON encoding remains unchanged for now by the help of a small compatibility layer, which allows us to leave the Python side alone in this PR. BREAKING CHANGE: `TypeParam` and `TypeArg` are now merged in Rust. Variant names in `TypeParam` have been suffixed with `Type` to avoid conflict. Closes #2295.
1 parent cbaf9bf commit ebfc98e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+953
-1025
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

hugr-core/src/builder/dataflow.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ pub(crate) mod test {
627627
FunctionBuilder::new(
628628
"bad_eval",
629629
PolyFuncType::new(
630-
[TypeParam::new_list(TypeBound::Copyable)],
630+
[TypeParam::new_list_type(TypeBound::Copyable)],
631631
Signature::new(
632632
Type::new_function(FuncValueType::new(usize_t(), tv.clone())),
633633
vec![],

hugr-core/src/export.rs

Lines changed: 47 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Exporting HUGR graphs to their `hugr-model` representation.
22
use crate::extension::ExtensionRegistry;
33
use crate::hugr::internal::HugrInternals;
4+
use crate::types::type_param::Term;
45
use crate::{
56
Direction, Hugr, HugrView, IncomingPort, Node, NodeIndex as _, Port,
67
extension::{ExtensionId, OpDef, SignatureFunc},
@@ -14,9 +15,7 @@ use crate::{
1415
},
1516
types::{
1617
CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType,
17-
TypeArg, TypeBase, TypeBound, TypeEnum,
18-
type_param::{TypeArgVariable, TypeParam},
19-
type_row::TypeRowBase,
18+
TypeBase, TypeBound, TypeEnum, type_param::TermVar, type_row::TypeRowBase,
2019
},
2120
};
2221

@@ -385,7 +384,7 @@ impl<'a> Context<'a> {
385384
let node = self.connected_function(node).unwrap();
386385
let symbol = self.node_to_id[&node];
387386
let mut args = BumpVec::new_in(self.bump);
388-
args.extend(call.type_args.iter().map(|arg| self.export_type_arg(arg)));
387+
args.extend(call.type_args.iter().map(|arg| self.export_term(arg, None)));
389388
let args = args.into_bump_slice();
390389
let func = self.make_term(table::Term::Apply(symbol, args));
391390

@@ -401,7 +400,7 @@ impl<'a> Context<'a> {
401400
let node = self.connected_function(node).unwrap();
402401
let symbol = self.node_to_id[&node];
403402
let mut args = BumpVec::new_in(self.bump);
404-
args.extend(load.type_args.iter().map(|arg| self.export_type_arg(arg)));
403+
args.extend(load.type_args.iter().map(|arg| self.export_term(arg, None)));
405404
let args = args.into_bump_slice();
406405
let func = self.make_term(table::Term::Apply(symbol, args));
407406
let runtime_type = self.make_term(table::Term::Wildcard);
@@ -464,7 +463,7 @@ impl<'a> Context<'a> {
464463
let node = self.export_opdef(op.def());
465464
let params = self
466465
.bump
467-
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg)));
466+
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_term(arg, None)));
468467
let operation = self.make_term(table::Term::Apply(node, params));
469468
table::Operation::Custom(operation)
470469
}
@@ -473,7 +472,7 @@ impl<'a> Context<'a> {
473472
let node = self.make_named_global_ref(op.extension(), op.unqualified_id());
474473
let params = self
475474
.bump
476-
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_type_arg(arg)));
475+
.alloc_slice_fill_iter(op.args().iter().map(|arg| self.export_term(arg, None)));
477476
let operation = self.make_term(table::Term::Apply(node, params));
478477
table::Operation::Custom(operation)
479478
}
@@ -806,7 +805,7 @@ impl<'a> Context<'a> {
806805

807806
for (i, param) in t.params().iter().enumerate() {
808807
let name = self.bump.alloc_str(&i.to_string());
809-
let r#type = self.export_type_param(param, Some((scope, i as _)));
808+
let r#type = self.export_term(param, Some((scope, i as _)));
810809
let param = table::Param { name, r#type };
811810
params.push(param);
812811
}
@@ -854,40 +853,12 @@ impl<'a> Context<'a> {
854853

855854
let args = self
856855
.bump
857-
.alloc_slice_fill_iter(t.args().iter().map(|p| self.export_type_arg(p)));
856+
.alloc_slice_fill_iter(t.args().iter().map(|p| self.export_term(p, None)));
858857
let term = table::Term::Apply(symbol, args);
859858
self.make_term(term)
860859
}
861860

862-
pub fn export_type_arg(&mut self, t: &TypeArg) -> table::TermId {
863-
match t {
864-
TypeArg::Type { ty } => self.export_type(ty),
865-
TypeArg::BoundedNat { n } => self.make_term(model::Literal::Nat(*n).into()),
866-
TypeArg::String { arg } => self.make_term(model::Literal::Str(arg.into()).into()),
867-
TypeArg::Float { value } => self.make_term(model::Literal::Float(*value).into()),
868-
TypeArg::Bytes { value } => self.make_term(model::Literal::Bytes(value.clone()).into()),
869-
TypeArg::List { elems } => {
870-
// For now we assume that the sequence is meant to be a list.
871-
let parts = self.bump.alloc_slice_fill_iter(
872-
elems
873-
.iter()
874-
.map(|elem| table::SeqPart::Item(self.export_type_arg(elem))),
875-
);
876-
self.make_term(table::Term::List(parts))
877-
}
878-
TypeArg::Tuple { elems } => {
879-
let parts = self.bump.alloc_slice_fill_iter(
880-
elems
881-
.iter()
882-
.map(|elem| table::SeqPart::Item(self.export_type_arg(elem))),
883-
);
884-
self.make_term(table::Term::Tuple(parts))
885-
}
886-
TypeArg::Variable { v } => self.export_type_arg_var(v),
887-
}
888-
}
889-
890-
pub fn export_type_arg_var(&mut self, var: &TypeArgVariable) -> table::TermId {
861+
pub fn export_type_arg_var(&mut self, var: &TermVar) -> table::TermId {
891862
let node = self.local_scope.expect("local variable out of scope");
892863
self.make_term(table::Term::Var(table::VarId(node, var.index() as _)))
893864
}
@@ -953,19 +924,19 @@ impl<'a> Context<'a> {
953924
self.make_term(table::Term::List(parts))
954925
}
955926

956-
/// Exports a `TypeParam` to a term.
927+
/// Exports a term.
957928
///
958-
/// The `var` argument is set when the type parameter being exported is the
929+
/// The `var` argument is set when the term being exported is the
959930
/// type of a parameter to a polymorphic definition. In that case we can
960931
/// generate a `nonlinear` constraint for the type of runtime types marked as
961932
/// `TypeBound::Copyable`.
962-
pub fn export_type_param(
933+
pub fn export_term(
963934
&mut self,
964-
t: &TypeParam,
935+
t: &Term,
965936
var: Option<(table::NodeId, table::VarIndex)>,
966937
) -> table::TermId {
967938
match t {
968-
TypeParam::Type { b } => {
939+
Term::RuntimeType(b) => {
969940
if let (Some((node, index)), TypeBound::Copyable) = (var, b) {
970941
let term = self.make_term(table::Term::Var(table::VarId(node, index)));
971942
let non_linear = self.make_term_apply(model::CORE_NON_LINEAR, &[term]);
@@ -974,24 +945,46 @@ impl<'a> Context<'a> {
974945

975946
self.make_term_apply(model::CORE_TYPE, &[])
976947
}
977-
// This ignores the bound on the natural for now.
978-
TypeParam::BoundedNat { .. } => self.make_term_apply(model::CORE_NAT_TYPE, &[]),
979-
TypeParam::String => self.make_term_apply(model::CORE_STR_TYPE, &[]),
980-
TypeParam::Bytes => self.make_term_apply(model::CORE_BYTES_TYPE, &[]),
981-
TypeParam::Float => self.make_term_apply(model::CORE_FLOAT_TYPE, &[]),
982-
TypeParam::List { param } => {
983-
let item_type = self.export_type_param(param, None);
948+
Term::BoundedNatType { .. } => self.make_term_apply(model::CORE_NAT_TYPE, &[]),
949+
Term::StringType => self.make_term_apply(model::CORE_STR_TYPE, &[]),
950+
Term::BytesType => self.make_term_apply(model::CORE_BYTES_TYPE, &[]),
951+
Term::FloatType => self.make_term_apply(model::CORE_FLOAT_TYPE, &[]),
952+
Term::ListType(item_type) => {
953+
let item_type = self.export_term(item_type, None);
984954
self.make_term_apply(model::CORE_LIST_TYPE, &[item_type])
985955
}
986-
TypeParam::Tuple { params } => {
987-
let parts = self.bump.alloc_slice_fill_iter(
956+
Term::TupleType(params) => {
957+
let item_types = self.bump.alloc_slice_fill_iter(
988958
params
989959
.iter()
990-
.map(|param| table::SeqPart::Item(self.export_type_param(param, None))),
960+
.map(|param| table::SeqPart::Item(self.export_term(param, None))),
991961
);
992-
let types = self.make_term(table::Term::List(parts));
962+
let types = self.make_term(table::Term::List(item_types));
993963
self.make_term_apply(model::CORE_TUPLE_TYPE, &[types])
994964
}
965+
Term::Runtime(ty) => self.export_type(ty),
966+
Term::BoundedNat(value) => self.make_term(model::Literal::Nat(*value).into()),
967+
Term::String(value) => self.make_term(model::Literal::Str(value.into()).into()),
968+
Term::Float(value) => self.make_term(model::Literal::Float(*value).into()),
969+
Term::Bytes(value) => self.make_term(model::Literal::Bytes(value.clone()).into()),
970+
Term::List(elems) => {
971+
let parts = self.bump.alloc_slice_fill_iter(
972+
elems
973+
.iter()
974+
.map(|elem| table::SeqPart::Item(self.export_term(elem, None))),
975+
);
976+
self.make_term(table::Term::List(parts))
977+
}
978+
Term::Tuple(elems) => {
979+
let parts = self.bump.alloc_slice_fill_iter(
980+
elems
981+
.iter()
982+
.map(|elem| table::SeqPart::Item(self.export_term(elem, None))),
983+
);
984+
self.make_term(table::Term::Tuple(parts))
985+
}
986+
Term::Variable(v) => self.export_type_arg_var(v),
987+
Term::StaticType => self.make_term_apply(model::CORE_STATIC, &[]),
995988
}
996989
}
997990

hugr-core/src/extension.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use crate::hugr::IdentList;
2222
use crate::ops::custom::{ExtensionOp, OpaqueOp};
2323
use crate::ops::{OpName, OpNameRef};
2424
use crate::types::RowVariable;
25-
use crate::types::type_param::{TypeArg, TypeArgError, TypeParam};
25+
use crate::types::type_param::{TermTypeError, TypeArg, TypeParam};
2626
use crate::types::{CustomType, TypeBound, TypeName};
2727
use crate::types::{Signature, TypeNameRef};
2828

@@ -387,7 +387,7 @@ pub enum SignatureError {
387387
ExtensionMismatch(ExtensionId, ExtensionId),
388388
/// When the type arguments of the node did not match the params declared by the `OpDef`
389389
#[error("Type arguments of node did not match params declared by definition: {0}")]
390-
TypeArgMismatch(#[from] TypeArgError),
390+
TypeArgMismatch(#[from] TermTypeError),
391391
/// Invalid type arguments
392392
#[error("Invalid type arguments for operation")]
393393
InvalidTypeArgs,

hugr-core/src/extension/declarative/types.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,6 @@ impl TypeParamDeclaration {
129129
_extension: &Extension,
130130
_ctx: DeclarationContext<'_>,
131131
) -> Result<TypeParam, ExtensionDeclarationError> {
132-
Ok(TypeParam::String)
132+
Ok(TypeParam::StringType)
133133
}
134134
}

hugr-core/src/extension/op_def.rs

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use super::{
1414
use crate::Hugr;
1515
use crate::envelope::serde_with::AsStringEnvelope;
1616
use crate::ops::{OpName, OpNameRef};
17-
use crate::types::type_param::{TypeArg, TypeParam, check_type_args};
17+
use crate::types::type_param::{TypeArg, TypeParam, check_term_types};
1818
use crate::types::{FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature};
1919
mod serialize_signature_func;
2020

@@ -239,7 +239,7 @@ impl SignatureFunc {
239239
let static_params = func.static_params();
240240
let (static_args, other_args) = args.split_at(min(static_params.len(), args.len()));
241241

242-
check_type_args(static_args, static_params)?;
242+
check_term_types(static_args, static_params)?;
243243
temp = func.compute_signature(static_args, def)?;
244244
(&temp, other_args)
245245
}
@@ -347,7 +347,7 @@ impl OpDef {
347347
let (static_args, other_args) =
348348
args.split_at(min(custom.static_params().len(), args.len()));
349349
static_args.iter().try_for_each(|ta| ta.validate(&[]))?;
350-
check_type_args(static_args, custom.static_params())?;
350+
check_term_types(static_args, custom.static_params())?;
351351
temp = custom.compute_signature(static_args, self)?;
352352
(&temp, other_args)
353353
}
@@ -357,7 +357,7 @@ impl OpDef {
357357
}
358358
};
359359
args.iter().try_for_each(|ta| ta.validate(var_decls))?;
360-
check_type_args(args, pf.params())?;
360+
check_term_types(args, pf.params())?;
361361
Ok(())
362362
}
363363

@@ -553,7 +553,7 @@ pub(super) mod test {
553553
use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
554554
use crate::ops::OpName;
555555
use crate::std_extensions::collections::list;
556-
use crate::types::type_param::{TypeArgError, TypeParam};
556+
use crate::types::type_param::{TermTypeError, TypeParam};
557557
use crate::types::{PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV};
558558
use crate::{Extension, const_extension_ids};
559559

@@ -656,7 +656,7 @@ pub(super) mod test {
656656
const OP_NAME: OpName = OpName::new_inline("Reverse");
657657

658658
let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| {
659-
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };
659+
const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Any);
660660
let list_of_var =
661661
Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?);
662662
let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var]));
@@ -678,11 +678,10 @@ pub(super) mod test {
678678
reg.validate()?;
679679
let e = reg.get(&EXT_ID).unwrap();
680680

681-
let list_usize =
682-
Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: usize_t() }])?);
681+
let list_usize = Type::new_extension(list_def.instantiate(vec![usize_t().into()])?);
683682
let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?;
684683
let rev = dfg.add_dataflow_op(
685-
e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: usize_t() }])
684+
e.instantiate_extension_op(&OP_NAME, vec![usize_t().into()])
686685
.unwrap(),
687686
dfg.input_wires(),
688687
)?;
@@ -703,8 +702,8 @@ pub(super) mod test {
703702
&self,
704703
arg_values: &[TypeArg],
705704
) -> Result<PolyFuncTypeRV, SignatureError> {
706-
const TP: TypeParam = TypeParam::Type { b: TypeBound::Any };
707-
let [TypeArg::BoundedNat { n }] = arg_values else {
705+
const TP: TypeParam = TypeParam::RuntimeType(TypeBound::Any);
706+
let [TypeArg::BoundedNat(n)] = arg_values else {
708707
return Err(SignatureError::InvalidTypeArgs);
709708
};
710709
let n = *n as usize;
@@ -718,7 +717,7 @@ pub(super) mod test {
718717
}
719718

720719
fn static_params(&self) -> &[TypeParam] {
721-
const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat()];
720+
const MAX_NAT: &[TypeParam] = &[TypeParam::max_nat_type()];
722721
MAX_NAT
723722
}
724723
}
@@ -727,7 +726,7 @@ pub(super) mod test {
727726
ext.add_op("MyOp".into(), String::new(), SigFun(), extension_ref)?;
728727

729728
// Base case, no type variables:
730-
let args = [TypeArg::BoundedNat { n: 3 }, usize_t().into()];
729+
let args = [TypeArg::BoundedNat(3), usize_t().into()];
731730
assert_eq!(
732731
def.compute_signature(&args),
733732
Ok(Signature::new(
@@ -740,7 +739,7 @@ pub(super) mod test {
740739
// Second arg may be a variable (substitutable)
741740
let tyvar = Type::new_var_use(0, TypeBound::Copyable);
742741
let tyvars: Vec<Type> = vec![tyvar.clone(); 3];
743-
let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()];
742+
let args = [TypeArg::BoundedNat(3), tyvar.clone().into()];
744743
assert_eq!(
745744
def.compute_signature(&args),
746745
Ok(Signature::new(
@@ -761,7 +760,7 @@ pub(super) mod test {
761760
);
762761

763762
// First arg must be concrete, not a variable
764-
let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap());
763+
let kind = TypeParam::bounded_nat_type(NonZeroU64::new(5).unwrap());
765764
let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()];
766765
// We can't prevent this from getting into our compute_signature implementation:
767766
assert_eq!(
@@ -798,7 +797,7 @@ pub(super) mod test {
798797
extension_ref,
799798
)?;
800799
let tv = Type::new_var_use(0, TypeBound::Copyable);
801-
let args = [TypeArg::Type { ty: tv.clone() }];
800+
let args = [tv.clone().into()];
802801
let decls = [TypeBound::Copyable.into()];
803802
def.validate_args(&args, &decls).unwrap();
804803
assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo(tv)));
@@ -807,9 +806,9 @@ pub(super) mod test {
807806
assert_eq!(
808807
def.compute_signature(&[arg.clone()]),
809808
Err(SignatureError::TypeArgMismatch(
810-
TypeArgError::TypeMismatch {
811-
param: TypeBound::Any.into(),
812-
arg
809+
TermTypeError::TypeMismatch {
810+
type_: TypeBound::Any.into(),
811+
term: arg,
813812
}
814813
))
815814
);

0 commit comments

Comments
 (0)