@@ -8,138 +8,149 @@ use thiserror::Error;
8
8
use crate :: {
9
9
IncomingPort , OutgoingPort , PortIndex ,
10
10
hugr:: HugrMut ,
11
- ops:: { DFG , FuncDefn , Input , OpTrait , OpType , Output , dataflow:: IOTrait , handle:: DfgID } ,
11
+ ops:: {
12
+ DFG , FuncDefn , Input , OpTrait , OpType , Output ,
13
+ dataflow:: IOTrait ,
14
+ handle:: { DataflowParentID , DfgID } ,
15
+ } ,
12
16
types:: { NoRV , Signature , TypeBase } ,
13
17
} ;
14
18
15
19
use super :: RootChecked ;
16
20
17
- impl < H : HugrMut > RootChecked < H , DfgID < H :: Node > > {
18
- /// Get the input and output nodes of the DFG at the entrypoint node.
19
- pub fn get_io ( & self ) -> [ H :: Node ; 2 ] {
20
- self . hugr ( )
21
- . get_io ( self . hugr ( ) . entrypoint ( ) )
22
- . expect ( "valid DFG graph" )
23
- }
24
-
25
- /// Rewire the inputs and outputs of the DFG to modify its signature.
26
- ///
27
- /// Reorder the outgoing resp. incoming wires at the input resp. output
28
- /// node of the DFG to modify the signature of the DFG HUGR. This will
29
- /// recursively update the signatures of all ancestors of the entrypoint.
30
- ///
31
- /// ### Arguments
32
- ///
33
- /// * `new_inputs`: The new input signature. After the map, the i-th input
34
- /// wire will be connected to the ports connected to the
35
- /// `new_inputs[i]`-th input of the old DFG.
36
- /// * `new_outputs`: The new output signature. After the map, the i-th
37
- /// output wire will be connected to the ports connected to the
38
- /// `new_outputs[i]`-th output of the old DFG.
39
- ///
40
- /// Returns an `InvalidSignature` error if the new_inputs and new_outputs
41
- /// map are not valid signatures.
42
- ///
43
- /// ### Panics
44
- ///
45
- /// Panics if the DFG is not trivially nested, i.e. if there is an ancestor
46
- /// DFG of the entrypoint that has more than one inner DFG.
47
- pub fn map_function_type (
48
- & mut self ,
49
- new_inputs : & [ usize ] ,
50
- new_outputs : & [ usize ] ,
51
- ) -> Result < ( ) , InvalidSignature > {
52
- let [ inp, out] = self . get_io ( ) ;
53
- let Self ( hugr, _) = self ;
54
-
55
- // Record the old connections from and to the input and output nodes
56
- let old_inputs_incoming = hugr
57
- . node_outputs ( inp)
58
- . map ( |p| hugr. linked_inputs ( inp, p) . collect_vec ( ) )
59
- . collect_vec ( ) ;
60
- let old_outputs_outgoing = hugr
61
- . node_inputs ( out)
62
- . map ( |p| hugr. linked_outputs ( out, p) . collect_vec ( ) )
63
- . collect_vec ( ) ;
64
-
65
- // The old signature types
66
- let old_inp_sig = hugr
67
- . get_optype ( inp)
68
- . dataflow_signature ( )
69
- . expect ( "input has signature" ) ;
70
- let old_inp_sig = old_inp_sig. output_types ( ) ;
71
- let old_out_sig = hugr
72
- . get_optype ( out)
73
- . dataflow_signature ( )
74
- . expect ( "output has signature" ) ;
75
- let old_out_sig = old_out_sig. input_types ( ) ;
76
-
77
- // Check if the signature map is valid
78
- check_valid_inputs ( & old_inputs_incoming, old_inp_sig, new_inputs) ?;
79
- check_valid_outputs ( old_out_sig, new_outputs) ?;
80
-
81
- // The new signature types
82
- let new_inp_sig = new_inputs
83
- . iter ( )
84
- . map ( |& i| old_inp_sig[ i] . clone ( ) )
85
- . collect_vec ( ) ;
86
- let new_out_sig = new_outputs
87
- . iter ( )
88
- . map ( |& i| old_out_sig[ i] . clone ( ) )
89
- . collect_vec ( ) ;
90
- let new_sig = Signature :: new ( new_inp_sig, new_out_sig) ;
91
-
92
- // Remove all edges of the input and output nodes
93
- disconnect_all ( hugr, inp) ;
94
- disconnect_all ( hugr, out) ;
95
-
96
- // Update the signatures of the IO and their ancestors
97
- let mut is_ancestor = false ;
98
- let mut node = hugr. entrypoint ( ) ;
99
- while matches ! ( hugr. get_optype( node) , OpType :: FuncDefn ( _) | OpType :: DFG ( _) ) {
100
- let [ inner_inp, inner_out] = hugr. get_io ( node) . expect ( "valid DFG graph" ) ;
101
- for node in [ node, inner_inp, inner_out] {
102
- update_signature ( hugr, node, & new_sig) ;
21
+ macro_rules! impl_dataflow_parent_methods {
22
+ ( $handle_type: ident) => {
23
+ impl <H : HugrMut > RootChecked <H , $handle_type<H :: Node >> {
24
+ /// Get the input and output nodes of the DFG at the entrypoint node.
25
+ pub fn get_io( & self ) -> [ H :: Node ; 2 ] {
26
+ self . hugr( )
27
+ . get_io( self . hugr( ) . entrypoint( ) )
28
+ . expect( "valid DFG graph" )
103
29
}
104
- if is_ancestor {
105
- update_inner_dfg_links ( hugr, node) ;
106
- }
107
- if let Some ( parent) = hugr. get_parent ( node) {
108
- node = parent;
109
- is_ancestor = true ;
110
- } else {
111
- break ;
112
- }
113
- }
114
30
115
- // Insert the new edges at the input
116
- let mut old_output_to_new_input = BTreeMap :: < IncomingPort , OutgoingPort > :: new ( ) ;
117
- for ( inp_pos, & old_pos) in new_inputs. iter ( ) . enumerate ( ) {
118
- for & ( node, port) in & old_inputs_incoming[ old_pos] {
119
- if node != out {
120
- hugr. connect ( inp, inp_pos, node, port) ;
121
- } else {
122
- old_output_to_new_input. insert ( port, inp_pos. into ( ) ) ;
31
+ /// Rewire the inputs and outputs of the DFG to modify its signature.
32
+ ///
33
+ /// Reorder the outgoing resp. incoming wires at the input resp. output
34
+ /// node of the DFG to modify the signature of the DFG HUGR. This will
35
+ /// recursively update the signatures of all ancestors of the entrypoint.
36
+ ///
37
+ /// ### Arguments
38
+ ///
39
+ /// * `new_inputs`: The new input signature. After the map, the i-th input
40
+ /// wire will be connected to the ports connected to the
41
+ /// `new_inputs[i]`-th input of the old DFG.
42
+ /// * `new_outputs`: The new output signature. After the map, the i-th
43
+ /// output wire will be connected to the ports connected to the
44
+ /// `new_outputs[i]`-th output of the old DFG.
45
+ ///
46
+ /// Returns an `InvalidSignature` error if the new_inputs and new_outputs
47
+ /// map are not valid signatures.
48
+ ///
49
+ /// ### Panics
50
+ ///
51
+ /// Panics if the DFG is not trivially nested, i.e. if there is an ancestor
52
+ /// DFG of the entrypoint that has more than one inner DFG.
53
+ pub fn map_function_type(
54
+ & mut self ,
55
+ new_inputs: & [ usize ] ,
56
+ new_outputs: & [ usize ] ,
57
+ ) -> Result <( ) , InvalidSignature > {
58
+ let [ inp, out] = self . get_io( ) ;
59
+ let Self ( hugr, _) = self ;
60
+
61
+ // Record the old connections from and to the input and output nodes
62
+ let old_inputs_incoming = hugr
63
+ . node_outputs( inp)
64
+ . map( |p| hugr. linked_inputs( inp, p) . collect_vec( ) )
65
+ . collect_vec( ) ;
66
+ let old_outputs_outgoing = hugr
67
+ . node_inputs( out)
68
+ . map( |p| hugr. linked_outputs( out, p) . collect_vec( ) )
69
+ . collect_vec( ) ;
70
+
71
+ // The old signature types
72
+ let old_inp_sig = hugr
73
+ . get_optype( inp)
74
+ . dataflow_signature( )
75
+ . expect( "input has signature" ) ;
76
+ let old_inp_sig = old_inp_sig. output_types( ) ;
77
+ let old_out_sig = hugr
78
+ . get_optype( out)
79
+ . dataflow_signature( )
80
+ . expect( "output has signature" ) ;
81
+ let old_out_sig = old_out_sig. input_types( ) ;
82
+
83
+ // Check if the signature map is valid
84
+ check_valid_inputs( & old_inputs_incoming, old_inp_sig, new_inputs) ?;
85
+ check_valid_outputs( old_out_sig, new_outputs) ?;
86
+
87
+ // The new signature types
88
+ let new_inp_sig = new_inputs
89
+ . iter( )
90
+ . map( |& i| old_inp_sig[ i] . clone( ) )
91
+ . collect_vec( ) ;
92
+ let new_out_sig = new_outputs
93
+ . iter( )
94
+ . map( |& i| old_out_sig[ i] . clone( ) )
95
+ . collect_vec( ) ;
96
+ let new_sig = Signature :: new( new_inp_sig, new_out_sig) ;
97
+
98
+ // Remove all edges of the input and output nodes
99
+ disconnect_all( hugr, inp) ;
100
+ disconnect_all( hugr, out) ;
101
+
102
+ // Update the signatures of the IO and their ancestors
103
+ let mut is_ancestor = false ;
104
+ let mut node = hugr. entrypoint( ) ;
105
+ while matches!( hugr. get_optype( node) , OpType :: FuncDefn ( _) | OpType :: DFG ( _) ) {
106
+ let [ inner_inp, inner_out] = hugr. get_io( node) . expect( "valid DFG graph" ) ;
107
+ for node in [ node, inner_inp, inner_out] {
108
+ update_signature( hugr, node, & new_sig) ;
109
+ }
110
+ if is_ancestor {
111
+ update_inner_dfg_links( hugr, node) ;
112
+ }
113
+ if let Some ( parent) = hugr. get_parent( node) {
114
+ node = parent;
115
+ is_ancestor = true ;
116
+ } else {
117
+ break ;
118
+ }
119
+ }
120
+
121
+ // Insert the new edges at the input
122
+ let mut old_output_to_new_input = BTreeMap :: <IncomingPort , OutgoingPort >:: new( ) ;
123
+ for ( inp_pos, & old_pos) in new_inputs. iter( ) . enumerate( ) {
124
+ for & ( node, port) in & old_inputs_incoming[ old_pos] {
125
+ if node != out {
126
+ hugr. connect( inp, inp_pos, node, port) ;
127
+ } else {
128
+ old_output_to_new_input. insert( port, inp_pos. into( ) ) ;
129
+ }
130
+ }
123
131
}
124
- }
125
- }
126
132
127
- // Insert the new edges at the output
128
- for ( out_pos, & old_pos) in new_outputs. iter ( ) . enumerate ( ) {
129
- for & ( node, port) in & old_outputs_outgoing[ old_pos] {
130
- if node != inp {
131
- hugr. connect ( node, port, out, out_pos) ;
132
- } else {
133
- let & inp_pos = old_output_to_new_input. get ( & old_pos. into ( ) ) . unwrap ( ) ;
134
- hugr. connect ( inp, inp_pos, out, out_pos) ;
133
+ // Insert the new edges at the output
134
+ for ( out_pos, & old_pos) in new_outputs. iter( ) . enumerate( ) {
135
+ for & ( node, port) in & old_outputs_outgoing[ old_pos] {
136
+ if node != inp {
137
+ hugr. connect( node, port, out, out_pos) ;
138
+ } else {
139
+ let & inp_pos = old_output_to_new_input. get( & old_pos. into( ) ) . unwrap( ) ;
140
+ hugr. connect( inp, inp_pos, out, out_pos) ;
141
+ }
142
+ }
135
143
}
144
+
145
+ Ok ( ( ) )
136
146
}
137
147
}
138
-
139
- Ok ( ( ) )
140
- }
148
+ } ;
141
149
}
142
150
151
+ impl_dataflow_parent_methods ! ( DataflowParentID ) ;
152
+ impl_dataflow_parent_methods ! ( DfgID ) ;
153
+
143
154
/// Panics if the DFG within `node` is not a single inner DFG.
144
155
fn update_inner_dfg_links < H : HugrMut > ( hugr : & mut H , node : H :: Node ) {
145
156
// connect all edges of the inner DFG to the input and output nodes
@@ -272,7 +283,7 @@ mod test {
272
283
} ;
273
284
use crate :: extension:: prelude:: { bool_t, qb_t} ;
274
285
use crate :: hugr:: views:: root_checked:: RootChecked ;
275
- use crate :: ops:: handle:: { DfgID , NodeHandle } ;
286
+ use crate :: ops:: handle:: NodeHandle ;
276
287
use crate :: ops:: { NamedOp , OpParent } ;
277
288
use crate :: types:: Signature ;
278
289
use crate :: utils:: test_quantum_extension:: cx_gate;
@@ -290,6 +301,51 @@ mod test {
290
301
let sig = Signature :: new_endo ( vec ! [ qb_t( ) , qb_t( ) ] ) ;
291
302
let mut hugr = new_empty_dfg ( sig) ;
292
303
304
+ // Wrap in RootChecked
305
+ let mut dfg_view = RootChecked :: < & mut Hugr , DataflowParentID > :: try_new ( & mut hugr) . unwrap ( ) ;
306
+
307
+ // Test mapping inputs: [0,1] -> [1,0]
308
+ let input_map = vec ! [ 1 , 0 ] ;
309
+ let output_map = vec ! [ 0 , 1 ] ;
310
+
311
+ // Map the I/O
312
+ dfg_view. map_function_type ( & input_map, & output_map) . unwrap ( ) ;
313
+
314
+ // Verify the new signature
315
+ let dfg_hugr = dfg_view. hugr ( ) ;
316
+ let new_sig = dfg_hugr
317
+ . get_optype ( dfg_hugr. entrypoint ( ) )
318
+ . dataflow_signature ( )
319
+ . unwrap ( ) ;
320
+ assert_eq ! ( new_sig. input_count( ) , 2 ) ;
321
+ assert_eq ! ( new_sig. output_count( ) , 2 ) ;
322
+
323
+ // Test invalid mapping - missing input
324
+ let invalid_input_map = vec ! [ 0 , 0 ] ;
325
+ let err = dfg_view. map_function_type ( & invalid_input_map, & output_map) ;
326
+ assert ! ( matches!( err, Err ( InvalidSignature :: MissingIO ( 1 , "input" ) ) ) ) ;
327
+
328
+ // Test invalid mapping - duplicate input
329
+ let invalid_input_map = vec ! [ 0 , 0 , 1 ] ;
330
+ assert ! ( matches!(
331
+ dfg_view. map_function_type( & invalid_input_map, & output_map) ,
332
+ Err ( InvalidSignature :: DuplicateInput ( 0 ) )
333
+ ) ) ;
334
+
335
+ // Test invalid mapping - unknown output
336
+ let invalid_output_map = vec ! [ 0 , 2 ] ;
337
+ assert ! ( matches!(
338
+ dfg_view. map_function_type( & input_map, & invalid_output_map) ,
339
+ Err ( InvalidSignature :: UnknownIO ( 2 , "output" ) )
340
+ ) ) ;
341
+ }
342
+
343
+ #[ test]
344
+ fn test_map_io_dfg_id ( ) {
345
+ // Create a DFG with 2 inputs and 2 outputs
346
+ let sig = Signature :: new_endo ( vec ! [ qb_t( ) , qb_t( ) ] ) ;
347
+ let mut hugr = new_empty_dfg ( sig) ;
348
+
293
349
// Wrap in RootChecked
294
350
let mut dfg_view = RootChecked :: < & mut Hugr , DfgID > :: try_new ( & mut hugr) . unwrap ( ) ;
295
351
@@ -337,7 +393,7 @@ mod test {
337
393
let mut hugr = new_empty_dfg ( sig) ;
338
394
339
395
// Wrap in RootChecked
340
- let mut dfg_view = RootChecked :: < & mut Hugr , DfgID > :: try_new ( & mut hugr) . unwrap ( ) ;
396
+ let mut dfg_view = RootChecked :: < & mut Hugr , DataflowParentID > :: try_new ( & mut hugr) . unwrap ( ) ;
341
397
342
398
// Test mapping outputs: [0] -> [0,0] (duplicating the output)
343
399
let input_map = vec ! [ 0 ] ;
@@ -377,7 +433,7 @@ mod test {
377
433
. unwrap ( ) ;
378
434
379
435
// Wrap in RootChecked
380
- let mut dfg_view = RootChecked :: < & mut Hugr , DfgID > :: try_new ( & mut hugr) . unwrap ( ) ;
436
+ let mut dfg_view = RootChecked :: < & mut Hugr , DataflowParentID > :: try_new ( & mut hugr) . unwrap ( ) ;
381
437
382
438
// Test mapping inputs: [0,1] -> [1,0] (swapping inputs)
383
439
let input_map = vec ! [ 1 , 0 ] ;
0 commit comments