1
1
#![ allow( clippy:: type_complexity) ]
2
2
#![ warn( missing_docs) ]
3
- //! Replace types with other types across the Hugr.
3
+ //! Replace types with other types across the Hugr. See [ReplaceTypes] and [Linearizer].
4
4
//!
5
- //! Parametrized types and ops will be reparametrized taking into account the replacements,
6
- //! but any ops taking/returning the replaced types *not* as a result of parametrization,
7
- //! will also need to be replaced - see [ReplaceTypes::replace_op]. (Similarly [Const]s.)
5
+ use std:: borrow:: Cow ;
8
6
use std:: collections:: HashMap ;
9
7
use std:: sync:: Arc ;
10
8
11
9
use thiserror:: Error ;
12
10
11
+ use hugr_core:: builder:: { BuildError , BuildHandle , Dataflow } ;
13
12
use hugr_core:: extension:: { ExtensionId , OpDef , SignatureError , TypeDef } ;
14
13
use hugr_core:: hugr:: hugrmut:: HugrMut ;
15
14
use hugr_core:: ops:: constant:: { OpaqueValue , Sum } ;
15
+ use hugr_core:: ops:: handle:: DataflowOpID ;
16
16
use hugr_core:: ops:: {
17
17
AliasDefn , Call , CallIndirect , Case , Conditional , Const , DataflowBlock , ExitBlock , ExtensionOp ,
18
- FuncDecl , FuncDefn , Input , LoadConstant , LoadFunction , OpType , Output , Tag , TailLoop , Value ,
19
- CFG , DFG ,
18
+ FuncDecl , FuncDefn , Input , LoadConstant , LoadFunction , OpTrait , OpType , Output , Tag , TailLoop ,
19
+ Value , CFG , DFG ,
20
20
} ;
21
- use hugr_core:: types:: { CustomType , Transformable , Type , TypeArg , TypeEnum , TypeTransformer } ;
22
- use hugr_core:: { Hugr , Node } ;
21
+ use hugr_core:: types:: {
22
+ CustomType , Signature , Transformable , Type , TypeArg , TypeEnum , TypeTransformer ,
23
+ } ;
24
+ use hugr_core:: { Hugr , HugrView , Node , Wire } ;
23
25
24
26
use crate :: validation:: { ValidatePassError , ValidationLevel } ;
25
27
26
- /// A thing with which an Op (i.e. node) can be replaced
28
+ mod linearize;
29
+ pub use linearize:: { CallbackHandler , DelegatingLinearizer , LinearizeError , Linearizer } ;
30
+
31
+ /// A recipe for creating a dataflow Node - as a new child of a [DataflowParent]
32
+ /// or in order to replace an existing node.
33
+ ///
34
+ /// [DataflowParent]: hugr_core::ops::OpTag::DataflowParent
27
35
#[ derive( Clone , Debug , PartialEq ) ]
28
- pub enum OpReplacement {
29
- /// Keep the same node, change only the op (updating types of inputs/outputs)
36
+ pub enum NodeTemplate {
37
+ /// A single node - so if replacing an existing node, change only the op
30
38
SingleOp ( OpType ) ,
31
- /// Defines a sub-Hugr to splice in place of the op - a [CFG], [Conditional], [DFG]
32
- /// or [TailLoop], which must have the same inputs and outputs as the original op,
33
- /// modulo replacement.
39
+ /// Defines a sub-Hugr to insert, whose root becomes (or replaces) the desired Node.
40
+ /// The root must be a [CFG], [Conditional], [DFG] or [TailLoop].
34
41
// Not a FuncDefn, nor Case/DataflowBlock
35
- /// Note this will be of limited use before [monomorphization](super::monomorphize()) because
36
- /// the sub-Hugr will not be able to use type variables present in the op.
42
+ /// Note this will be of limited use before [monomorphization](super::monomorphize())
43
+ /// because the new subtree will not be able to use type variables present in the
44
+ /// parent Hugr or previous op.
37
45
// TODO: store also a vec<TypeParam>, and update Hugr::validate to take &[TypeParam]s
38
46
// (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709
39
47
CompoundOp ( Box < Hugr > ) ,
@@ -42,12 +50,33 @@ pub enum OpReplacement {
42
50
// So client should add the functions before replacement, then remove unused ones afterwards.)
43
51
}
44
52
45
- impl OpReplacement {
53
+ impl NodeTemplate {
54
+ /// Adds this instance to the specified [HugrMut] as a new node or subtree under a
55
+ /// given parent, returning the unique new child (of that parent) thus created
56
+ pub fn add_hugr ( self , hugr : & mut impl HugrMut , parent : Node ) -> Node {
57
+ match self {
58
+ NodeTemplate :: SingleOp ( op_type) => hugr. add_node_with_parent ( parent, op_type) ,
59
+ NodeTemplate :: CompoundOp ( new_h) => hugr. insert_hugr ( parent, * new_h) . new_root ,
60
+ }
61
+ }
62
+
63
+ /// Adds this instance to the specified [Dataflow] builder as a new node or subtree
64
+ pub fn add (
65
+ self ,
66
+ dfb : & mut impl Dataflow ,
67
+ inputs : impl IntoIterator < Item = Wire > ,
68
+ ) -> Result < BuildHandle < DataflowOpID > , BuildError > {
69
+ match self {
70
+ NodeTemplate :: SingleOp ( opty) => dfb. add_dataflow_op ( opty, inputs) ,
71
+ NodeTemplate :: CompoundOp ( h) => dfb. add_hugr_with_wires ( * h, inputs) ,
72
+ }
73
+ }
74
+
46
75
fn replace ( & self , hugr : & mut impl HugrMut , n : Node ) {
47
76
assert_eq ! ( hugr. children( n) . count( ) , 0 ) ;
48
77
let new_optype = match self . clone ( ) {
49
- OpReplacement :: SingleOp ( op_type) => op_type,
50
- OpReplacement :: CompoundOp ( new_h) => {
78
+ NodeTemplate :: SingleOp ( op_type) => op_type,
79
+ NodeTemplate :: CompoundOp ( new_h) => {
51
80
let new_root = hugr. insert_hugr ( n, * new_h) . new_root ;
52
81
let children = hugr. children ( new_root) . collect :: < Vec < _ > > ( ) ;
53
82
let root_opty = hugr. remove_node ( new_root) ;
@@ -59,16 +88,50 @@ impl OpReplacement {
59
88
} ;
60
89
* hugr. optype_mut ( n) = new_optype;
61
90
}
91
+
92
+ fn signature ( & self ) -> Option < Cow < ' _ , Signature > > {
93
+ match self {
94
+ NodeTemplate :: SingleOp ( op_type) => op_type,
95
+ NodeTemplate :: CompoundOp ( hugr) => hugr. root_type ( ) ,
96
+ }
97
+ . dataflow_signature ( )
98
+ }
62
99
}
63
100
64
101
/// A configuration of what types, ops, and constants should be replaced with what.
65
102
/// May be applied to a Hugr via [Self::run].
103
+ ///
104
+ /// Parametrized types and ops will be reparametrized taking into account the
105
+ /// replacements, but any ops taking/returning the replaced types *not* as a result of
106
+ /// parametrization, will also need to be replaced - see [Self::replace_op].
107
+ /// Similarly [Const]s.
108
+ ///
109
+ /// Types that are [Copyable](hugr_core::types::TypeBound::Copyable) may also be replaced
110
+ /// with types that are not, see [Linearizer].
111
+ ///
112
+ /// Note that although this pass may be used before [monomorphization], there are some
113
+ /// limitations (that do not apply if done after [monomorphization]):
114
+ /// * [NodeTemplate::CompoundOp] only works for operations that do not use type variables
115
+ /// * "Overrides" of specific instantiations of polymorphic types will not be detected if
116
+ /// the instantiations are created inside polymorphic functions. For example, suppose
117
+ /// we [Self::replace_type] type `A` with `X`, [Self::replace_parametrized_type]
118
+ /// container `MyList` with `List`, and [Self::replace_type] `MyList<A>` with
119
+ /// `SpecialListOfXs`. If a function `foo` polymorphic over a type variable `T` dealing
120
+ /// with `MyList<T>`s, that is called with type argument `A`, then `foo<T>` will be
121
+ /// updated to deal with `List<T>`s and the call `foo<A>` updated to `foo<X>`, but this
122
+ /// will still result in using `List<X>` rather than `SpecialListOfXs`. (However this
123
+ /// would be fine *after* [monomorphization]: the monomorphic definition of `foo_A`
124
+ /// would use `SpecialListOfXs`.)
125
+ /// * See also limitations noted for [Linearizer].
126
+ ///
127
+ /// [monomorphization]: super::monomorphize()
66
128
#[ derive( Clone , Default ) ]
67
129
pub struct ReplaceTypes {
68
130
type_map : HashMap < CustomType , Type > ,
69
131
param_types : HashMap < ParametricType , Arc < dyn Fn ( & [ TypeArg ] ) -> Option < Type > > > ,
70
- op_map : HashMap < OpHashWrapper , OpReplacement > ,
71
- param_ops : HashMap < ParametricOp , Arc < dyn Fn ( & [ TypeArg ] ) -> Option < OpReplacement > > > ,
132
+ linearize : DelegatingLinearizer ,
133
+ op_map : HashMap < OpHashWrapper , NodeTemplate > ,
134
+ param_ops : HashMap < ParametricOp , Arc < dyn Fn ( & [ TypeArg ] ) -> Option < NodeTemplate > > > ,
72
135
consts : HashMap <
73
136
CustomType ,
74
137
Arc < dyn Fn ( & OpaqueValue , & ReplaceTypes ) -> Result < Value , ReplaceTypesError > > ,
@@ -109,6 +172,8 @@ pub enum ReplaceTypesError {
109
172
SignatureError ( #[ from] SignatureError ) ,
110
173
#[ error( transparent) ]
111
174
ValidationError ( #[ from] ValidatePassError ) ,
175
+ #[ error( transparent) ]
176
+ LinearizeError ( #[ from] LinearizeError ) ,
112
177
}
113
178
114
179
impl ReplaceTypes {
@@ -157,16 +222,33 @@ impl ReplaceTypes {
157
222
// (depending on arguments - i.e. if src's TypeDefBound is anything other than
158
223
// `TypeDefBound::Explicit(TypeBound::Copyable)`) but that seems an annoying
159
224
// overapproximation. Moreover, these depend upon the *return type* of the Fn.
225
+ // It would be too awkward to require:
226
+ // dest_fn: impl Fn(&TypeArg) -> (Type,
227
+ // Fn(&Linearizer) -> NodeTemplate, // copy
228
+ // Fn(&Linearizer) -> NodeTemplate)` // discard
160
229
self . param_types . insert ( src. into ( ) , Arc :: new ( dest_fn) ) ;
161
230
}
162
231
232
+ /// Allows to configure how to deal with types/wires that were [Copyable]
233
+ /// but have become linear as a result of type-changing. Specifically,
234
+ /// the [Linearizer] is used whenever lowering produces an outport which both
235
+ /// * has a non-[Copyable] type - perhaps a direct substitution, or perhaps e.g.
236
+ /// as a result of changing the element type of a collection such as an [`array`]
237
+ /// * has other than one connected inport,
238
+ ///
239
+ /// [Copyable]: hugr_core::types::TypeBound::Copyable
240
+ /// [`array`]: hugr_core::std_extensions::collections::array::array_type
241
+ pub fn linearizer ( & mut self ) -> & mut DelegatingLinearizer {
242
+ & mut self . linearize
243
+ }
244
+
163
245
/// Configures this instance to change occurrences of `src` to `dest`.
164
246
/// Note that if `src` is an instance of a *parametrized* [OpDef], this takes
165
247
/// precedence over [Self::replace_parametrized_op] where the `src`s overlap. Thus,
166
248
/// this should only be used on already-*[monomorphize](super::monomorphize())d*
167
249
/// Hugrs, as substitution (parametric polymorphism) happening later will not respect
168
250
/// this replacement.
169
- pub fn replace_op ( & mut self , src : & ExtensionOp , dest : OpReplacement ) {
251
+ pub fn replace_op ( & mut self , src : & ExtensionOp , dest : NodeTemplate ) {
170
252
self . op_map . insert ( OpHashWrapper :: from ( src) , dest) ;
171
253
}
172
254
@@ -179,7 +261,7 @@ impl ReplaceTypes {
179
261
pub fn replace_parametrized_op (
180
262
& mut self ,
181
263
src : & OpDef ,
182
- dest_fn : impl Fn ( & [ TypeArg ] ) -> Option < OpReplacement > + ' static ,
264
+ dest_fn : impl Fn ( & [ TypeArg ] ) -> Option < NodeTemplate > + ' static ,
183
265
) {
184
266
self . param_ops . insert ( src. into ( ) , Arc :: new ( dest_fn) ) ;
185
267
}
@@ -221,6 +303,22 @@ impl ReplaceTypes {
221
303
let mut changed = false ;
222
304
for n in hugr. nodes ( ) . collect :: < Vec < _ > > ( ) {
223
305
changed |= self . change_node ( hugr, n) ?;
306
+ let new_dfsig = hugr. get_optype ( n) . dataflow_signature ( ) ;
307
+ if let Some ( new_sig) = new_dfsig
308
+ . filter ( |_| changed && n != hugr. root ( ) )
309
+ . map ( Cow :: into_owned)
310
+ {
311
+ for outp in new_sig. output_ports ( ) {
312
+ if !new_sig. out_port_type ( outp) . unwrap ( ) . copyable ( ) {
313
+ let targets = hugr. linked_inputs ( n, outp) . collect :: < Vec < _ > > ( ) ;
314
+ if targets. len ( ) != 1 {
315
+ hugr. disconnect ( n, outp) ;
316
+ let src = Wire :: new ( n, outp) ;
317
+ self . linearize . insert_copy_discard ( hugr, src, & targets) ?;
318
+ }
319
+ }
320
+ }
321
+ }
224
322
}
225
323
Ok ( changed)
226
324
}
@@ -452,7 +550,7 @@ mod test {
452
550
use hugr_core:: { hugr:: IdentList , type_row, Extension , HugrView } ;
453
551
use itertools:: Itertools ;
454
552
455
- use super :: { handlers:: list_const, OpReplacement , ReplaceTypes } ;
553
+ use super :: { handlers:: list_const, NodeTemplate , ReplaceTypes } ;
456
554
457
555
const PACKED_VEC : & str = "PackedVec" ;
458
556
const READ : & str = "read" ;
@@ -513,7 +611,7 @@ mod test {
513
611
}
514
612
515
613
fn lowerer ( ext : & Arc < Extension > ) -> ReplaceTypes {
516
- fn lowered_read ( args : & [ TypeArg ] ) -> Option < OpReplacement > {
614
+ fn lowered_read ( args : & [ TypeArg ] ) -> Option < NodeTemplate > {
517
615
let ty = just_elem_type ( args) ;
518
616
let mut dfb = DFGBuilder :: new ( inout_sig (
519
617
vec ! [ array_type( 64 , ty. clone( ) ) , i64_t( ) ] ,
@@ -532,7 +630,7 @@ mod test {
532
630
let [ res] = dfb
533
631
. build_unwrap_sum ( 1 , option_type ( Type :: from ( ty. clone ( ) ) ) , opt)
534
632
. unwrap ( ) ;
535
- Some ( OpReplacement :: CompoundOp ( Box :: new (
633
+ Some ( NodeTemplate :: CompoundOp ( Box :: new (
536
634
dfb. finish_hugr_with_outputs ( [ res] ) . unwrap ( ) ,
537
635
) ) )
538
636
}
@@ -545,7 +643,7 @@ mod test {
545
643
) ;
546
644
lw. replace_op (
547
645
& read_op ( ext, bool_t ( ) ) ,
548
- OpReplacement :: SingleOp (
646
+ NodeTemplate :: SingleOp (
549
647
ExtensionOp :: new ( ext. get_op ( "lowered_read_bool" ) . unwrap ( ) . clone ( ) , [ ] )
550
648
. unwrap ( )
551
649
. into ( ) ,
@@ -824,7 +922,7 @@ mod test {
824
922
e. get_op ( READ ) . unwrap ( ) . as_ref ( ) ,
825
923
Box :: new ( |args : & [ TypeArg ] | {
826
924
option_contents ( just_elem_type ( args) ) . map ( |elem| {
827
- OpReplacement :: SingleOp (
925
+ NodeTemplate :: SingleOp (
828
926
ListOp :: get
829
927
. with_type ( elem)
830
928
. to_extension_op ( )
0 commit comments