@@ -126,11 +126,16 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
126
126
/// of `self`.
127
127
///
128
128
/// The returned port will be in `replacement`, unless the wire in the
129
- /// replacement is empty, in which case it will another `host` port.
129
+ /// replacement is empty and `return_invalid` is
130
+ /// [`IncludeReplacementNodes::Valid`] (the default), in which case it
131
+ /// will be another `host` port. If [`IncludeReplacementNodes::All`] is
132
+ /// passed, the returned port will always be in `replacement` even if it
133
+ /// is invalid (i.e. it is an IO node in the replacement).
130
134
pub fn linked_replacement_output (
131
135
& self ,
132
136
port : impl Into < HostPort < HostNode , IncomingPort > > ,
133
137
host : & impl HugrView < Node = HostNode > ,
138
+ return_invalid : IncludeReplacementNodes ,
134
139
) -> Option < BoundaryPort < HostNode , OutgoingPort > > {
135
140
let HostPort ( node, port) = port. into ( ) ;
136
141
let pos = self
@@ -139,7 +144,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
139
144
. iter ( )
140
145
. position ( move |& ( n, p) | host. linked_inputs ( n, p) . contains ( & ( node, port) ) ) ?;
141
146
142
- Some ( self . linked_replacement_output_by_position ( pos, host) )
147
+ Some ( self . linked_replacement_output_by_position ( pos, host, return_invalid ) )
143
148
}
144
149
145
150
/// The outgoing port linked to the i-th output boundary edge of `subgraph`.
@@ -150,6 +155,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
150
155
& self ,
151
156
pos : usize ,
152
157
host : & impl HugrView < Node = HostNode > ,
158
+ return_invalid : IncludeReplacementNodes ,
153
159
) -> BoundaryPort < HostNode , OutgoingPort > {
154
160
debug_assert ! ( pos < self . subgraph( ) . signature( host) . output_count( ) ) ;
155
161
@@ -160,7 +166,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
160
166
. single_linked_output ( repl_out, pos)
161
167
. expect ( "valid dfg wire" ) ;
162
168
163
- if out_node != repl_inp {
169
+ if out_node != repl_inp || return_invalid == IncludeReplacementNodes :: All {
164
170
BoundaryPort :: Replacement ( out_node, out_port)
165
171
} else {
166
172
let ( in_node, in_port) = * self . subgraph . incoming_ports ( ) [ out_port. index ( ) ]
@@ -207,11 +213,17 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
207
213
/// of `self`.
208
214
///
209
215
/// The returned ports will be in `replacement`, unless the wires in the
210
- /// replacement are empty, in which case they are other `host` ports.
216
+ /// replacement are empty and `return_invalid` is
217
+ /// [`IncludeReplacementNodes::Valid`] (the default), in which case they
218
+ /// will be other `host` ports. If [`IncludeReplacementNodes::All`] is
219
+ /// passed, the returned ports will always be in
220
+ /// `replacement` even if they are invalid (i.e. they are an IO node in
221
+ /// the replacement).
211
222
pub fn linked_replacement_inputs < ' a > (
212
223
& ' a self ,
213
224
port : impl Into < HostPort < HostNode , OutgoingPort > > ,
214
225
host : & ' a impl HugrView < Node = HostNode > ,
226
+ return_invalid : IncludeReplacementNodes ,
215
227
) -> impl Iterator < Item = BoundaryPort < HostNode , IncomingPort > > + ' a {
216
228
let HostPort ( node, port) = port. into ( ) ;
217
229
let positions = self
@@ -223,26 +235,25 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
223
235
host. single_linked_output ( n, p) . expect ( "valid dfg wire" ) == ( node, port)
224
236
} ) ;
225
237
226
- positions. flat_map ( |pos| self . linked_replacement_inputs_by_position ( pos, host) )
238
+ positions. flat_map ( move |pos| {
239
+ self . linked_replacement_inputs_by_position ( pos, host, return_invalid)
240
+ } )
227
241
}
228
242
229
243
/// The incoming ports linked to the i-th input boundary edge of `subgraph`.
230
- ///
231
- /// The ports will be in `replacement` for all endpoints of the i-th input
232
- /// wire that are not the output node of `replacement` and be in `host`
233
- /// otherwise.
234
244
fn linked_replacement_inputs_by_position (
235
245
& self ,
236
246
pos : usize ,
237
247
host : & impl HugrView < Node = HostNode > ,
248
+ return_invalid : IncludeReplacementNodes ,
238
249
) -> impl Iterator < Item = BoundaryPort < HostNode , IncomingPort > > {
239
250
debug_assert ! ( pos < self . subgraph( ) . signature( host) . input_count( ) ) ;
240
251
241
252
let [ repl_inp, repl_out] = self . get_replacement_io ( ) ;
242
253
self . replacement
243
254
. linked_inputs ( repl_inp, pos)
244
255
. flat_map ( move |( in_node, in_port) | {
245
- if in_node != repl_out {
256
+ if in_node != repl_out || return_invalid == IncludeReplacementNodes :: All {
246
257
Either :: Left ( std:: iter:: once ( BoundaryPort :: Replacement ( in_node, in_port) ) )
247
258
} else {
248
259
let ( out_node, out_port) = self . subgraph . outgoing_ports ( ) [ in_port. index ( ) ] ;
@@ -316,8 +327,12 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
316
327
subgraph_outgoing_ports
317
328
. enumerate ( )
318
329
. flat_map ( |( pos, subg_np) | {
319
- self . linked_replacement_inputs_by_position ( pos, host)
320
- . filter_map ( move |np| Some ( ( np. as_replacement ( ) ?, subg_np) ) )
330
+ self . linked_replacement_inputs_by_position (
331
+ pos,
332
+ host,
333
+ IncludeReplacementNodes :: Valid ,
334
+ )
335
+ . filter_map ( move |np| Some ( ( np. as_replacement ( ) ?, subg_np) ) )
321
336
} )
322
337
. map ( |( ( repl_node, repl_port) , ( subgraph_node, subgraph_port) ) | {
323
338
(
@@ -359,7 +374,11 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
359
374
. enumerate ( )
360
375
. filter_map ( |( pos, subg_all) | {
361
376
let np = self
362
- . linked_replacement_output_by_position ( pos, host)
377
+ . linked_replacement_output_by_position (
378
+ pos,
379
+ host,
380
+ IncludeReplacementNodes :: Valid ,
381
+ )
363
382
. as_replacement ( ) ?;
364
383
Some ( ( np, subg_all) )
365
384
} )
@@ -406,8 +425,12 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
406
425
. enumerate ( )
407
426
. filter_map ( |( pos, subg_all) | {
408
427
Some ( (
409
- self . linked_replacement_output_by_position ( pos, host)
410
- . as_host ( ) ?,
428
+ self . linked_replacement_output_by_position (
429
+ pos,
430
+ host,
431
+ IncludeReplacementNodes :: Valid ,
432
+ )
433
+ . as_host ( ) ?,
411
434
subg_all,
412
435
) )
413
436
} )
@@ -517,7 +540,8 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
517
540
SimpleReplacement :: try_new ( subgraph, new_host, replacement. clone ( ) )
518
541
}
519
542
520
- /// Allows to get the [Self::invalidated_nodes] without requiring a [HugrView].
543
+ /// Allows to get the [Self::invalidated_nodes] without requiring a
544
+ /// [HugrView].
521
545
pub fn invalidation_set ( & self ) -> impl Iterator < Item = HostNode > {
522
546
self . subgraph . nodes ( ) . iter ( ) . copied ( )
523
547
}
@@ -540,6 +564,20 @@ impl<HostNode: HugrNode> PatchVerification for SimpleReplacement<HostNode> {
540
564
}
541
565
}
542
566
567
+ /// In [`SimpleReplacement`], some nodes in the replacement may not be valid
568
+ /// after the replacement is applied.
569
+ ///
570
+ /// This enum allows to filter out such nodes.
571
+ #[ derive( Debug , Clone , Copy , PartialEq , Eq , Hash , Default ) ]
572
+ #[ non_exhaustive]
573
+ pub enum IncludeReplacementNodes {
574
+ /// Only consider nodes that are valid after the replacement is applied.
575
+ #[ default]
576
+ Valid ,
577
+ /// Include all nodes, including potentially invalid ones.
578
+ All ,
579
+ }
580
+
543
581
/// Result of applying a [`SimpleReplacement`].
544
582
pub struct Outcome < HostNode = Node > {
545
583
/// Map from Node in replacement to corresponding Node in the result Hugr
@@ -652,7 +690,7 @@ pub(in crate::hugr::patch) mod test {
652
690
ModuleBuilder , endo_sig, inout_sig,
653
691
} ;
654
692
use crate :: extension:: prelude:: { bool_t, qb_t} ;
655
- use crate :: hugr:: patch:: simple_replace:: Outcome ;
693
+ use crate :: hugr:: patch:: simple_replace:: { IncludeReplacementNodes , Outcome } ;
656
694
use crate :: hugr:: patch:: { BoundaryPort , HostPort , PatchVerification , ReplacementPort } ;
657
695
use crate :: hugr:: views:: { HugrView , SiblingSubgraph } ;
658
696
use crate :: hugr:: { Hugr , HugrMut , Patch } ;
@@ -1145,7 +1183,11 @@ pub(in crate::hugr::patch) mod test {
1145
1183
1146
1184
// Test linked_replacement_inputs with empty replacement
1147
1185
let replacement_inputs: Vec < _ > = repl
1148
- . linked_replacement_inputs ( ( inp, OutgoingPort :: from ( 0 ) ) , & hugr)
1186
+ . linked_replacement_inputs (
1187
+ ( inp, OutgoingPort :: from ( 0 ) ) ,
1188
+ & hugr,
1189
+ IncludeReplacementNodes :: Valid ,
1190
+ )
1149
1191
. collect ( ) ;
1150
1192
1151
1193
assert_eq ! (
@@ -1158,8 +1200,12 @@ pub(in crate::hugr::patch) mod test {
1158
1200
// Test linked_replacement_output with empty replacement
1159
1201
let replacement_output = ( 0 ..4 )
1160
1202
. map ( |i| {
1161
- repl. linked_replacement_output ( ( out, IncomingPort :: from ( i) ) , & hugr)
1162
- . unwrap ( )
1203
+ repl. linked_replacement_output (
1204
+ ( out, IncomingPort :: from ( i) ) ,
1205
+ & hugr,
1206
+ IncludeReplacementNodes :: Valid ,
1207
+ )
1208
+ . unwrap ( )
1163
1209
} )
1164
1210
. collect_vec ( ) ;
1165
1211
@@ -1191,7 +1237,11 @@ pub(in crate::hugr::patch) mod test {
1191
1237
} ;
1192
1238
1193
1239
let replacement_inputs: Vec < _ > = repl
1194
- . linked_replacement_inputs ( ( inp, OutgoingPort :: from ( 0 ) ) , & hugr)
1240
+ . linked_replacement_inputs (
1241
+ ( inp, OutgoingPort :: from ( 0 ) ) ,
1242
+ & hugr,
1243
+ IncludeReplacementNodes :: Valid ,
1244
+ )
1195
1245
. collect ( ) ;
1196
1246
1197
1247
assert_eq ! (
@@ -1203,8 +1253,12 @@ pub(in crate::hugr::patch) mod test {
1203
1253
1204
1254
let replacement_output = ( 0 ..4 )
1205
1255
. map ( |i| {
1206
- repl. linked_replacement_output ( ( out, IncomingPort :: from ( i) ) , & hugr)
1207
- . unwrap ( )
1256
+ repl. linked_replacement_output (
1257
+ ( out, IncomingPort :: from ( i) ) ,
1258
+ & hugr,
1259
+ IncludeReplacementNodes :: Valid ,
1260
+ )
1261
+ . unwrap ( )
1208
1262
} )
1209
1263
. collect_vec ( ) ;
1210
1264
@@ -1241,7 +1295,11 @@ pub(in crate::hugr::patch) mod test {
1241
1295
} ;
1242
1296
1243
1297
let replacement_inputs: Vec < _ > = repl
1244
- . linked_replacement_inputs ( ( inp, OutgoingPort :: from ( 0 ) ) , & hugr)
1298
+ . linked_replacement_inputs (
1299
+ ( inp, OutgoingPort :: from ( 0 ) ) ,
1300
+ & hugr,
1301
+ IncludeReplacementNodes :: Valid ,
1302
+ )
1245
1303
. collect ( ) ;
1246
1304
1247
1305
assert_eq ! (
@@ -1257,8 +1315,12 @@ pub(in crate::hugr::patch) mod test {
1257
1315
1258
1316
let replacement_output = ( 0 ..4 )
1259
1317
. map ( |i| {
1260
- repl. linked_replacement_output ( ( out, IncomingPort :: from ( i) ) , & hugr)
1261
- . unwrap ( )
1318
+ repl. linked_replacement_output (
1319
+ ( out, IncomingPort :: from ( i) ) ,
1320
+ & hugr,
1321
+ IncludeReplacementNodes :: Valid ,
1322
+ )
1323
+ . unwrap ( )
1262
1324
} )
1263
1325
. collect_vec ( ) ;
1264
1326
0 commit comments