Skip to content

Commit 1bf35d8

Browse files
committed
Handle variables by is_cached_static
1 parent 7d08db6 commit 1bf35d8

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

hugr-core/src/types/type_param.rs

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,7 @@ impl Term {
168168
}
169169
(Term::BytesType, Term::BytesType) => true,
170170
(Term::FloatType, Term::FloatType) => true,
171-
// This is definitely ok (invariance), but there might be other cases(?).
172-
// The recursive call checks the variable is actually a type.
173-
(Term::Variable(v1), Term::Variable(v2)) => {
174-
v1 == v2 && v1.cached_decl.is_supertype(&*v1.cached_decl)
175-
}
171+
(Term::Variable(v1), Term::Variable(v2)) => v1 == v2 && cached_is_static(v1),
176172
(
177173
Term::Runtime(_)
178174
| Term::BoundedNat(_)
@@ -191,6 +187,14 @@ impl Term {
191187
}
192188
}
193189

190+
fn cached_is_static(tv: &TermVar) -> bool {
191+
match &*tv.cached_decl {
192+
Term::Variable(tv) => cached_is_static(&*tv),
193+
Term::StaticType => true,
194+
_ => false,
195+
}
196+
}
197+
194198
impl From<TypeBound> for Term {
195199
fn from(bound: TypeBound) -> Self {
196200
Self::RuntimeType(bound)
@@ -308,8 +312,17 @@ impl Term {
308312
| Term::FloatType => Ok(()),
309313
Term::ListType(term) => term.validate_param(),
310314
Term::TupleType(terms) => terms.iter().try_for_each(Term::validate_param),
311-
// Variables are allowed as long as all legal instantiations are valid parameter types
312-
Term::Variable(TermVar { cached_decl, .. }) => cached_decl.validate_param(),
315+
// Variables are allowed as long as they could be a static type;
316+
// since StaticType is itself a StaticType, we must loop through chains
317+
// like `(param &b &a) (param ?c ?b) ...` arbitrarily: these could be
318+
// legal if enough of the first params are instantiated with `StaticType`
319+
Term::Variable(tv) => {
320+
if cached_is_static(tv) {
321+
Ok(())
322+
} else {
323+
Err(SignatureError::InvalidTypeParam(self.clone()))
324+
}
325+
}
313326
// The remainder are not static types
314327
Term::Runtime(_)
315328
| Term::BoundedNat(_)

0 commit comments

Comments
 (0)