@@ -7,7 +7,7 @@ use crate::{
7
7
type_param:: { TypeArgVariable , TypeParam } ,
8
8
type_row:: TypeRowBase ,
9
9
CustomType , FuncTypeBase , MaybeRV , PolyFuncTypeBase , RowVariable , SumType , TypeArg ,
10
- TypeBase , TypeEnum ,
10
+ TypeBase , TypeBound , TypeEnum ,
11
11
} ,
12
12
Direction , Hugr , HugrView , IncomingPort , Node , Port ,
13
13
} ;
@@ -44,8 +44,21 @@ struct Context<'a> {
44
44
bump : & ' a Bump ,
45
45
/// Stores the terms that we have already seen to avoid duplicates.
46
46
term_map : FxHashMap < model:: Term < ' a > , model:: TermId > ,
47
+
47
48
/// The current scope for local variables.
49
+ ///
50
+ /// This is set to the id of the smallest enclosing node that defines a polymorphic type.
51
+ /// We use this when exporting local variables in terms.
48
52
local_scope : Option < model:: NodeId > ,
53
+
54
+ /// Constraints to be added to the local scope.
55
+ ///
56
+ /// When exporting a node that defines a polymorphic type, we use this field
57
+ /// to collect the constraints that need to be added to that polymorphic
58
+ /// type. Currently this is used to record `nonlinear` constraints on uses
59
+ /// of `TypeParam::Type` with a `TypeBound::Copyable` bound.
60
+ local_constraints : Vec < model:: TermId > ,
61
+
49
62
/// Mapping from extension operations to their declarations.
50
63
decl_operations : FxHashMap < ( ExtensionId , OpName ) , model:: NodeId > ,
51
64
}
@@ -63,6 +76,7 @@ impl<'a> Context<'a> {
63
76
term_map : FxHashMap :: default ( ) ,
64
77
local_scope : None ,
65
78
decl_operations : FxHashMap :: default ( ) ,
79
+ local_constraints : Vec :: new ( ) ,
66
80
}
67
81
}
68
82
@@ -173,9 +187,11 @@ impl<'a> Context<'a> {
173
187
}
174
188
175
189
fn with_local_scope < T > ( & mut self , node : model:: NodeId , f : impl FnOnce ( & mut Self ) -> T ) -> T {
176
- let old_scope = self . local_scope . replace ( node) ;
190
+ let prev_local_scope = self . local_scope . replace ( node) ;
191
+ let prev_local_constraints = std:: mem:: take ( & mut self . local_constraints ) ;
177
192
let result = f ( self ) ;
178
- self . local_scope = old_scope;
193
+ self . local_scope = prev_local_scope;
194
+ self . local_constraints = prev_local_constraints;
179
195
result
180
196
}
181
197
@@ -232,10 +248,11 @@ impl<'a> Context<'a> {
232
248
233
249
OpType :: FuncDefn ( func) => self . with_local_scope ( node_id, |this| {
234
250
let name = this. get_func_name ( node) . unwrap ( ) ;
235
- let ( params, signature) = this. export_poly_func_type ( & func. signature ) ;
251
+ let ( params, constraints , signature) = this. export_poly_func_type ( & func. signature ) ;
236
252
let decl = this. bump . alloc ( model:: FuncDecl {
237
253
name,
238
254
params,
255
+ constraints,
239
256
signature,
240
257
} ) ;
241
258
let extensions = this. export_ext_set ( & func. signature . body ( ) . extension_reqs ) ;
@@ -247,10 +264,11 @@ impl<'a> Context<'a> {
247
264
248
265
OpType :: FuncDecl ( func) => self . with_local_scope ( node_id, |this| {
249
266
let name = this. get_func_name ( node) . unwrap ( ) ;
250
- let ( params, func) = this. export_poly_func_type ( & func. signature ) ;
267
+ let ( params, constraints , func) = this. export_poly_func_type ( & func. signature ) ;
251
268
let decl = this. bump . alloc ( model:: FuncDecl {
252
269
name,
253
270
params,
271
+ constraints,
254
272
signature : func,
255
273
} ) ;
256
274
model:: Operation :: DeclareFunc { decl }
@@ -450,10 +468,11 @@ impl<'a> Context<'a> {
450
468
451
469
let decl = self . with_local_scope ( node, |this| {
452
470
let name = this. make_qualified_name ( opdef. extension ( ) , opdef. name ( ) ) ;
453
- let ( params, r#type) = this. export_poly_func_type ( poly_func_type) ;
471
+ let ( params, constraints , r#type) = this. export_poly_func_type ( poly_func_type) ;
454
472
let decl = this. bump . alloc ( model:: OperationDecl {
455
473
name,
456
474
params,
475
+ constraints,
457
476
r#type,
458
477
} ) ;
459
478
decl
@@ -671,22 +690,36 @@ impl<'a> Context<'a> {
671
690
regions. into_bump_slice ( )
672
691
}
673
692
693
+ /// Exports a polymorphic function type.
694
+ ///
695
+ /// The returned triple consists of:
696
+ /// - The static parameters of the polymorphic function type.
697
+ /// - The constraints of the polymorphic function type.
698
+ /// - The function type itself.
674
699
pub fn export_poly_func_type < RV : MaybeRV > (
675
700
& mut self ,
676
701
t : & PolyFuncTypeBase < RV > ,
677
- ) -> ( & ' a [ model:: Param < ' a > ] , model:: TermId ) {
702
+ ) -> ( & ' a [ model:: Param < ' a > ] , & ' a [ model :: TermId ] , model:: TermId ) {
678
703
let mut params = BumpVec :: with_capacity_in ( t. params ( ) . len ( ) , self . bump ) ;
704
+ let scope = self
705
+ . local_scope
706
+ . expect ( "exporting poly func type outside of local scope" ) ;
679
707
680
708
for ( i, param) in t. params ( ) . iter ( ) . enumerate ( ) {
681
709
let name = self . bump . alloc_str ( & i. to_string ( ) ) ;
682
- let r#type = self . export_type_param ( param) ;
683
- let param = model:: Param :: Implicit { name, r#type } ;
710
+ let r#type = self . export_type_param ( param, Some ( model:: LocalRef :: Index ( scope, i as _ ) ) ) ;
711
+ let param = model:: Param {
712
+ name,
713
+ r#type,
714
+ sort : model:: ParamSort :: Implicit ,
715
+ } ;
684
716
params. push ( param)
685
717
}
686
718
719
+ let constraints = self . bump . alloc_slice_copy ( & self . local_constraints ) ;
687
720
let body = self . export_func_type ( t. body ( ) ) ;
688
721
689
- ( params. into_bump_slice ( ) , body)
722
+ ( params. into_bump_slice ( ) , constraints , body)
690
723
}
691
724
692
725
pub fn export_type < RV : MaybeRV > ( & mut self , t : & TypeBase < RV > ) -> model:: TermId {
@@ -703,7 +736,6 @@ impl<'a> Context<'a> {
703
736
}
704
737
TypeEnum :: Function ( func) => self . export_func_type ( func) ,
705
738
TypeEnum :: Variable ( index, _) => {
706
- // This ignores the type bound for now
707
739
let node = self . local_scope . expect ( "local variable out of scope" ) ;
708
740
self . make_term ( model:: Term :: Var ( model:: LocalRef :: Index ( node, * index as _ ) ) )
709
741
}
@@ -794,20 +826,39 @@ impl<'a> Context<'a> {
794
826
self . make_term ( model:: Term :: List { items, tail : None } )
795
827
}
796
828
797
- pub fn export_type_param ( & mut self , t : & TypeParam ) -> model:: TermId {
829
+ /// Exports a `TypeParam` to a term.
830
+ ///
831
+ /// The `var` argument is set when the type parameter being exported is the
832
+ /// type of a parameter to a polymorphic definition. In that case we can
833
+ /// generate a `nonlinear` constraint for the type of runtime types marked as
834
+ /// `TypeBound::Copyable`.
835
+ pub fn export_type_param (
836
+ & mut self ,
837
+ t : & TypeParam ,
838
+ var : Option < model:: LocalRef < ' static > > ,
839
+ ) -> model:: TermId {
798
840
match t {
799
- // This ignores the type bound for now.
800
- TypeParam :: Type { .. } => self . make_term ( model:: Term :: Type ) ,
801
- // This ignores the type bound for now.
841
+ TypeParam :: Type { b } => {
842
+ if let ( Some ( var) , TypeBound :: Copyable ) = ( var, b) {
843
+ let term = self . make_term ( model:: Term :: Var ( var) ) ;
844
+ let non_linear = self . make_term ( model:: Term :: NonLinearConstraint { term } ) ;
845
+ self . local_constraints . push ( non_linear) ;
846
+ }
847
+
848
+ self . make_term ( model:: Term :: Type )
849
+ }
850
+ // This ignores the bound on the natural for now.
802
851
TypeParam :: BoundedNat { .. } => self . make_term ( model:: Term :: NatType ) ,
803
852
TypeParam :: String => self . make_term ( model:: Term :: StrType ) ,
804
853
TypeParam :: List { param } => {
805
- let item_type = self . export_type_param ( param) ;
854
+ let item_type = self . export_type_param ( param, None ) ;
806
855
self . make_term ( model:: Term :: ListType { item_type } )
807
856
}
808
857
TypeParam :: Tuple { params } => {
809
858
let items = self . bump . alloc_slice_fill_iter (
810
- params. iter ( ) . map ( |param| self . export_type_param ( param) ) ,
859
+ params
860
+ . iter ( )
861
+ . map ( |param| self . export_type_param ( param, None ) ) ,
811
862
) ;
812
863
let types = self . make_term ( model:: Term :: List { items, tail : None } ) ;
813
864
self . make_term ( model:: Term :: ApplyFull {
0 commit comments