Skip to content

Commit 6bd094f

Browse files
authored
feat: Emulate TypeBounds on parameters via constraints. (#1624)
This PR translates some `TypeBound`s in `hugr-core` to the `nonlinear` constraint in `hugr-model`. This translation only occurs on parameters that take a runtime type directly. As a driveby change before the model stabilises, this PR also moves constraints out of the parameter lists into their own list. In its previous form this could have led to confusions about which parameter a local variable index refers to when a constraint is situated between two parameters in the list. We also remove constraints from aliases for now. Closes #1637.
1 parent 9a43956 commit 6bd094f

File tree

13 files changed

+349
-196
lines changed

13 files changed

+349
-196
lines changed

hugr-core/src/export.rs

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use crate::{
77
type_param::{TypeArgVariable, TypeParam},
88
type_row::TypeRowBase,
99
CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg,
10-
TypeBase, TypeEnum,
10+
TypeBase, TypeBound, TypeEnum,
1111
},
1212
Direction, Hugr, HugrView, IncomingPort, Node, Port,
1313
};
@@ -44,8 +44,21 @@ struct Context<'a> {
4444
bump: &'a Bump,
4545
/// Stores the terms that we have already seen to avoid duplicates.
4646
term_map: FxHashMap<model::Term<'a>, model::TermId>,
47+
4748
/// The current scope for local variables.
49+
///
50+
/// This is set to the id of the smallest enclosing node that defines a polymorphic type.
51+
/// We use this when exporting local variables in terms.
4852
local_scope: Option<model::NodeId>,
53+
54+
/// Constraints to be added to the local scope.
55+
///
56+
/// When exporting a node that defines a polymorphic type, we use this field
57+
/// to collect the constraints that need to be added to that polymorphic
58+
/// type. Currently this is used to record `nonlinear` constraints on uses
59+
/// of `TypeParam::Type` with a `TypeBound::Copyable` bound.
60+
local_constraints: Vec<model::TermId>,
61+
4962
/// Mapping from extension operations to their declarations.
5063
decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>,
5164
}
@@ -63,6 +76,7 @@ impl<'a> Context<'a> {
6376
term_map: FxHashMap::default(),
6477
local_scope: None,
6578
decl_operations: FxHashMap::default(),
79+
local_constraints: Vec::new(),
6680
}
6781
}
6882

@@ -173,9 +187,11 @@ impl<'a> Context<'a> {
173187
}
174188

175189
fn with_local_scope<T>(&mut self, node: model::NodeId, f: impl FnOnce(&mut Self) -> T) -> T {
176-
let old_scope = self.local_scope.replace(node);
190+
let prev_local_scope = self.local_scope.replace(node);
191+
let prev_local_constraints = std::mem::take(&mut self.local_constraints);
177192
let result = f(self);
178-
self.local_scope = old_scope;
193+
self.local_scope = prev_local_scope;
194+
self.local_constraints = prev_local_constraints;
179195
result
180196
}
181197

@@ -232,10 +248,11 @@ impl<'a> Context<'a> {
232248

233249
OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| {
234250
let name = this.get_func_name(node).unwrap();
235-
let (params, signature) = this.export_poly_func_type(&func.signature);
251+
let (params, constraints, signature) = this.export_poly_func_type(&func.signature);
236252
let decl = this.bump.alloc(model::FuncDecl {
237253
name,
238254
params,
255+
constraints,
239256
signature,
240257
});
241258
let extensions = this.export_ext_set(&func.signature.body().extension_reqs);
@@ -247,10 +264,11 @@ impl<'a> Context<'a> {
247264

248265
OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| {
249266
let name = this.get_func_name(node).unwrap();
250-
let (params, func) = this.export_poly_func_type(&func.signature);
267+
let (params, constraints, func) = this.export_poly_func_type(&func.signature);
251268
let decl = this.bump.alloc(model::FuncDecl {
252269
name,
253270
params,
271+
constraints,
254272
signature: func,
255273
});
256274
model::Operation::DeclareFunc { decl }
@@ -450,10 +468,11 @@ impl<'a> Context<'a> {
450468

451469
let decl = self.with_local_scope(node, |this| {
452470
let name = this.make_qualified_name(opdef.extension(), opdef.name());
453-
let (params, r#type) = this.export_poly_func_type(poly_func_type);
471+
let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type);
454472
let decl = this.bump.alloc(model::OperationDecl {
455473
name,
456474
params,
475+
constraints,
457476
r#type,
458477
});
459478
decl
@@ -671,22 +690,36 @@ impl<'a> Context<'a> {
671690
regions.into_bump_slice()
672691
}
673692

693+
/// Exports a polymorphic function type.
694+
///
695+
/// The returned triple consists of:
696+
/// - The static parameters of the polymorphic function type.
697+
/// - The constraints of the polymorphic function type.
698+
/// - The function type itself.
674699
pub fn export_poly_func_type<RV: MaybeRV>(
675700
&mut self,
676701
t: &PolyFuncTypeBase<RV>,
677-
) -> (&'a [model::Param<'a>], model::TermId) {
702+
) -> (&'a [model::Param<'a>], &'a [model::TermId], model::TermId) {
678703
let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump);
704+
let scope = self
705+
.local_scope
706+
.expect("exporting poly func type outside of local scope");
679707

680708
for (i, param) in t.params().iter().enumerate() {
681709
let name = self.bump.alloc_str(&i.to_string());
682-
let r#type = self.export_type_param(param);
683-
let param = model::Param::Implicit { name, r#type };
710+
let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _)));
711+
let param = model::Param {
712+
name,
713+
r#type,
714+
sort: model::ParamSort::Implicit,
715+
};
684716
params.push(param)
685717
}
686718

719+
let constraints = self.bump.alloc_slice_copy(&self.local_constraints);
687720
let body = self.export_func_type(t.body());
688721

689-
(params.into_bump_slice(), body)
722+
(params.into_bump_slice(), constraints, body)
690723
}
691724

692725
pub fn export_type<RV: MaybeRV>(&mut self, t: &TypeBase<RV>) -> model::TermId {
@@ -703,7 +736,6 @@ impl<'a> Context<'a> {
703736
}
704737
TypeEnum::Function(func) => self.export_func_type(func),
705738
TypeEnum::Variable(index, _) => {
706-
// This ignores the type bound for now
707739
let node = self.local_scope.expect("local variable out of scope");
708740
self.make_term(model::Term::Var(model::LocalRef::Index(node, *index as _)))
709741
}
@@ -794,20 +826,39 @@ impl<'a> Context<'a> {
794826
self.make_term(model::Term::List { items, tail: None })
795827
}
796828

797-
pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId {
829+
/// Exports a `TypeParam` to a term.
830+
///
831+
/// The `var` argument is set when the type parameter being exported is the
832+
/// type of a parameter to a polymorphic definition. In that case we can
833+
/// generate a `nonlinear` constraint for the type of runtime types marked as
834+
/// `TypeBound::Copyable`.
835+
pub fn export_type_param(
836+
&mut self,
837+
t: &TypeParam,
838+
var: Option<model::LocalRef<'static>>,
839+
) -> model::TermId {
798840
match t {
799-
// This ignores the type bound for now.
800-
TypeParam::Type { .. } => self.make_term(model::Term::Type),
801-
// This ignores the type bound for now.
841+
TypeParam::Type { b } => {
842+
if let (Some(var), TypeBound::Copyable) = (var, b) {
843+
let term = self.make_term(model::Term::Var(var));
844+
let non_linear = self.make_term(model::Term::NonLinearConstraint { term });
845+
self.local_constraints.push(non_linear);
846+
}
847+
848+
self.make_term(model::Term::Type)
849+
}
850+
// This ignores the bound on the natural for now.
802851
TypeParam::BoundedNat { .. } => self.make_term(model::Term::NatType),
803852
TypeParam::String => self.make_term(model::Term::StrType),
804853
TypeParam::List { param } => {
805-
let item_type = self.export_type_param(param);
854+
let item_type = self.export_type_param(param, None);
806855
self.make_term(model::Term::ListType { item_type })
807856
}
808857
TypeParam::Tuple { params } => {
809858
let items = self.bump.alloc_slice_fill_iter(
810-
params.iter().map(|param| self.export_type_param(param)),
859+
params
860+
.iter()
861+
.map(|param| self.export_type_param(param, None)),
811862
);
812863
let types = self.make_term(model::Term::List { items, tail: None });
813864
self.make_term(model::Term::ApplyFull {

hugr-core/src/import.rs

Lines changed: 72 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ struct Context<'a> {
115115
/// A map from `NodeId` to the imported `Node`.
116116
nodes: FxHashMap<model::NodeId, Node>,
117117

118-
/// The types of the local variables that are currently in scope.
119-
local_variables: FxIndexMap<&'a str, model::TermId>,
118+
/// The local variables that are currently in scope.
119+
local_variables: FxIndexMap<&'a str, LocalVar>,
120120

121121
custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>,
122122
}
@@ -155,20 +155,20 @@ impl<'a> Context<'a> {
155155
.ok_or_else(|| model::ModelError::RegionNotFound(region_id).into())
156156
}
157157

158-
/// Looks up a [`LocalRef`] within the current scope and returns its index and type.
158+
/// Looks up a [`LocalRef`] within the current scope.
159159
fn resolve_local_ref(
160160
&self,
161161
local_ref: &model::LocalRef,
162-
) -> Result<(usize, model::TermId), ImportError> {
162+
) -> Result<(usize, LocalVar), ImportError> {
163163
let term = match local_ref {
164164
model::LocalRef::Index(_, index) => self
165165
.local_variables
166166
.get_index(*index as usize)
167-
.map(|(_, term)| (*index as usize, *term)),
167+
.map(|(_, v)| (*index as usize, *v)),
168168
model::LocalRef::Named(name) => self
169169
.local_variables
170170
.get_full(name)
171-
.map(|(index, _, term)| (index, *term)),
171+
.map(|(index, _, v)| (index, *v)),
172172
};
173173

174174
term.ok_or_else(|| model::ModelError::InvalidLocal(local_ref.to_string()).into())
@@ -898,41 +898,49 @@ impl<'a> Context<'a> {
898898
self.with_local_socpe(|ctx| {
899899
let mut imported_params = Vec::with_capacity(decl.params.len());
900900

901-
for param in decl.params {
902-
// TODO: `PolyFuncType` should be able to handle constraints
903-
// and distinguish between implicit and explicit parameters.
904-
match param {
905-
model::Param::Implicit { name, r#type } => {
906-
imported_params.push(ctx.import_type_param(*r#type)?);
907-
ctx.local_variables.insert(name, *r#type);
908-
}
909-
model::Param::Explicit { name, r#type } => {
910-
imported_params.push(ctx.import_type_param(*r#type)?);
911-
ctx.local_variables.insert(name, *r#type);
912-
}
913-
model::Param::Constraint { constraint: _ } => {
914-
return Err(error_unsupported!("constraints"));
901+
ctx.local_variables.extend(
902+
decl.params
903+
.iter()
904+
.map(|param| (param.name, LocalVar::new(param.r#type))),
905+
);
906+
907+
for constraint in decl.constraints {
908+
match ctx.get_term(*constraint)? {
909+
model::Term::NonLinearConstraint { term } => {
910+
let model::Term::Var(var) = ctx.get_term(*term)? else {
911+
return Err(error_unsupported!(
912+
"constraint on term that is not a variable"
913+
));
914+
};
915+
916+
let var = ctx.resolve_local_ref(var)?.0;
917+
ctx.local_variables[var].bound = TypeBound::Copyable;
915918
}
919+
_ => return Err(error_unsupported!("constraint other than copy or discard")),
916920
}
917921
}
918922

923+
for (index, param) in decl.params.iter().enumerate() {
924+
// NOTE: `PolyFuncType` only has explicit type parameters at present.
925+
let bound = ctx.local_variables[index].bound;
926+
imported_params.push(ctx.import_type_param(param.r#type, bound)?);
927+
}
928+
919929
let body = ctx.import_func_type::<RV>(decl.signature)?;
920930
in_scope(ctx, PolyFuncTypeBase::new(imported_params, body))
921931
})
922932
}
923933

924934
/// Import a [`TypeParam`] from a term that represents a static type.
925-
fn import_type_param(&mut self, term_id: model::TermId) -> Result<TypeParam, ImportError> {
935+
fn import_type_param(
936+
&mut self,
937+
term_id: model::TermId,
938+
bound: TypeBound,
939+
) -> Result<TypeParam, ImportError> {
926940
match self.get_term(term_id)? {
927941
model::Term::Wildcard => Err(error_uninferred!("wildcard")),
928942

929-
model::Term::Type => {
930-
// As part of the migration from `TypeBound`s to constraints, we pretend that all
931-
// `TypeBound`s are copyable.
932-
Ok(TypeParam::Type {
933-
b: TypeBound::Copyable,
934-
})
935-
}
943+
model::Term::Type => Ok(TypeParam::Type { b: bound }),
936944

937945
model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")),
938946
model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")),
@@ -944,7 +952,9 @@ impl<'a> Context<'a> {
944952
model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")),
945953

946954
model::Term::ListType { item_type } => {
947-
let param = Box::new(self.import_type_param(*item_type)?);
955+
// At present `hugr-model` has no way to express that the item
956+
// type of a list must be copyable. Therefore we import it as `Any`.
957+
let param = Box::new(self.import_type_param(*item_type, TypeBound::Any)?);
948958
Ok(TypeParam::List { param })
949959
}
950960

@@ -958,15 +968,18 @@ impl<'a> Context<'a> {
958968
| model::Term::List { .. }
959969
| model::Term::ExtSet { .. }
960970
| model::Term::Adt { .. }
961-
| model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()),
971+
| model::Term::Control { .. }
972+
| model::Term::NonLinearConstraint { .. } => {
973+
Err(model::ModelError::TypeError(term_id).into())
974+
}
962975

963976
model::Term::ControlType => {
964977
Err(error_unsupported!("type of control types as `TypeParam`"))
965978
}
966979
}
967980
}
968981

969-
/// Import a `TypeArg` froma term that represents a static type or value.
982+
/// Import a `TypeArg` from a term that represents a static type or value.
970983
fn import_type_arg(&mut self, term_id: model::TermId) -> Result<TypeArg, ImportError> {
971984
match self.get_term(term_id)? {
972985
model::Term::Wildcard => Err(error_uninferred!("wildcard")),
@@ -975,8 +988,8 @@ impl<'a> Context<'a> {
975988
}
976989

977990
model::Term::Var(var) => {
978-
let (index, var_type) = self.resolve_local_ref(var)?;
979-
let decl = self.import_type_param(var_type)?;
991+
let (index, var) = self.resolve_local_ref(var)?;
992+
let decl = self.import_type_param(var.r#type, var.bound)?;
980993
Ok(TypeArg::new_var_use(index, decl))
981994
}
982995

@@ -1014,7 +1027,10 @@ impl<'a> Context<'a> {
10141027

10151028
model::Term::FuncType { .. }
10161029
| model::Term::Adt { .. }
1017-
| model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()),
1030+
| model::Term::Control { .. }
1031+
| model::Term::NonLinearConstraint { .. } => {
1032+
Err(model::ModelError::TypeError(term_id).into())
1033+
}
10181034
}
10191035
}
10201036

@@ -1115,7 +1131,10 @@ impl<'a> Context<'a> {
11151131
| model::Term::List { .. }
11161132
| model::Term::Control { .. }
11171133
| model::Term::ControlType
1118-
| model::Term::Nat(_) => Err(model::ModelError::TypeError(term_id).into()),
1134+
| model::Term::Nat(_)
1135+
| model::Term::NonLinearConstraint { .. } => {
1136+
Err(model::ModelError::TypeError(term_id).into())
1137+
}
11191138
}
11201139
}
11211140

@@ -1291,3 +1310,21 @@ impl<'a> Names<'a> {
12911310
Ok(Self { items })
12921311
}
12931312
}
1313+
1314+
/// Information about a local variable.
1315+
#[derive(Debug, Clone, Copy)]
1316+
struct LocalVar {
1317+
/// The type of the variable.
1318+
r#type: model::TermId,
1319+
/// The type bound of the variable.
1320+
bound: TypeBound,
1321+
}
1322+
1323+
impl LocalVar {
1324+
pub fn new(r#type: model::TermId) -> Self {
1325+
Self {
1326+
r#type,
1327+
bound: TypeBound::Any,
1328+
}
1329+
}
1330+
}

hugr-core/tests/model.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,10 @@ pub fn test_roundtrip_params() {
5858
"../../hugr-model/tests/fixtures/model-params.edn"
5959
)));
6060
}
61+
62+
#[test]
63+
pub fn test_roundtrip_constraints() {
64+
insta::assert_snapshot!(roundtrip(include_str!(
65+
"../../hugr-model/tests/fixtures/model-constraints.edn"
66+
)));
67+
}

0 commit comments

Comments
 (0)