@@ -60,7 +60,7 @@ use std::{
60
60
collections:: { BTreeMap , BTreeSet } ,
61
61
} ;
62
62
63
- use itertools:: Itertools ;
63
+ use itertools:: { Either , Itertools } ;
64
64
use thiserror:: Error ;
65
65
66
66
use hugr_core:: {
@@ -203,10 +203,14 @@ impl<'a, R: Resolver> Walker<'a, R> {
203
203
/// walkers, which together cover the same space of possible HUGRs, each
204
204
/// having a different additional node pinned.
205
205
///
206
- /// Return an iterator over all possible [`Walker`]s that can be created by
207
- /// pinning exactly one additional node (or one additonal commit with an
208
- /// empty wire) connected to `wire`. Each returned [`Walker`] represents
209
- /// a different alternative Hugr in the exploration space.
206
+ /// If the wire is not complete yet, return an iterator over all possible
207
+ /// [`Walker`]s that can be created by pinning exactly one additional
208
+ /// node (or one additonal commit with an empty wire) connected to
209
+ /// `wire`. Each returned [`Walker`] represents a different alternative
210
+ /// Hugr in the exploration space.
211
+ ///
212
+ /// If the wire is already complete, return an iterator containing one
213
+ /// walker: the current walker unchanged.
210
214
///
211
215
/// Optionally, the expansion can be restricted to only ports with the given
212
216
/// direction (incoming or outgoing).
@@ -223,6 +227,10 @@ impl<'a, R: Resolver> Walker<'a, R> {
223
227
) -> impl Iterator < Item = Walker < ' a , R > > + ' b {
224
228
let dir = dir. into ( ) ;
225
229
230
+ if self . is_complete ( wire, dir) {
231
+ return Either :: Left ( std:: iter:: once ( self . clone ( ) ) ) ;
232
+ }
233
+
226
234
// Find unpinned ports on the wire (satisfying the direction constraint)
227
235
let unpinned_ports = self . wire_unpinned_ports ( wire, dir) ;
228
236
@@ -233,7 +241,7 @@ impl<'a, R: Resolver> Walker<'a, R> {
233
241
. map ( |( n, _, commits) | ( n, commits) )
234
242
. unique ( ) ;
235
243
236
- pinnable_nodes. filter_map ( |( pinnable_node, new_commits) | {
244
+ let new_walkers = pinnable_nodes. filter_map ( |( pinnable_node, new_commits) | {
237
245
let contains_new_commit = || {
238
246
new_commits
239
247
. iter ( )
@@ -268,7 +276,9 @@ impl<'a, R: Resolver> Walker<'a, R> {
268
276
} ;
269
277
new_walker. try_pin_node ( pinnable_node) . ok ( ) ?;
270
278
Some ( new_walker)
271
- } )
279
+ } ) ;
280
+
281
+ Either :: Right ( new_walkers)
272
282
}
273
283
274
284
/// Create a new commit from a set of complete pinned wires and a
@@ -436,6 +446,28 @@ impl<R: Clone> Walker<'_, R> {
436
446
}
437
447
}
438
448
449
+ #[ cfg( test) ]
450
+ impl < R : Resolver > Walker < ' _ , R > {
451
+ // Check walker equality by comparing pointers to the state space and
452
+ // other fields. Only for testing purposes.
453
+ fn ptr_eq ( & self , other : & Self ) -> bool {
454
+ self . state_space . as_ref ( ) as * const CommitStateSpace < R >
455
+ == other. state_space . as_ref ( ) as * const CommitStateSpace < R >
456
+ && self . pinned_nodes == other. pinned_nodes
457
+ && BTreeSet :: from_iter ( self . selected_commits . all_commit_ids ( ) )
458
+ == BTreeSet :: from_iter ( other. selected_commits . all_commit_ids ( ) )
459
+ }
460
+
461
+ /// Check if the Walker cannot be expanded further, i.e. expanding it
462
+ /// returns the same Walker.
463
+ fn no_more_expansion ( & self , wire : & PersistentWire , dir : impl Into < Option < Direction > > ) -> bool {
464
+ let Some ( [ new_walker] ) = self . expand ( wire, dir) . collect_array ( ) else {
465
+ return false ;
466
+ } ;
467
+ new_walker. ptr_eq ( self )
468
+ }
469
+ }
470
+
439
471
impl < R > CommitStateSpace < R > {
440
472
/// Given a node and port, return all child commits of the current `node`
441
473
/// that delete `node` but keep at least one port linked to `(node, port)`.
@@ -546,7 +578,8 @@ mod tests {
546
578
let in0 = walker. get_wire ( base_and_node, IncomingPort :: from ( 0 ) ) ;
547
579
548
580
// a single incoming port (already pinned) => no more expansion
549
- assert ! ( walker. expand( & in0, Direction :: Incoming ) . next( ) . is_none( ) ) ;
581
+ assert ! ( walker. no_more_expansion( & in0, Direction :: Incoming ) ) ;
582
+
550
583
// commit 2 cannot be applied, because AND is pinned
551
584
// => only base commit, or commit1
552
585
let out_walkers = walker. expand ( & in0, Direction :: Outgoing ) . collect_vec ( ) ;
@@ -555,7 +588,7 @@ mod tests {
555
588
// new wire is complete (and thus cannot be expanded)
556
589
let in0 = new_walker. get_wire ( base_and_node, IncomingPort :: from ( 0 ) ) ;
557
590
assert ! ( new_walker. is_complete( & in0, None ) ) ;
558
- assert ! ( new_walker. expand ( & in0, None ) . next ( ) . is_none ( ) ) ;
591
+ assert ! ( new_walker. no_more_expansion ( & in0, None ) ) ;
559
592
560
593
// all nodes on wire are pinned
561
594
let ( not_node, _) = in0. single_outgoing_port ( new_walker. as_hugr_view ( ) ) . unwrap ( ) ;
@@ -612,9 +645,8 @@ mod tests {
612
645
assert ! ( walker. is_pinned( not4_node) ) ;
613
646
614
647
let not4_out = walker. get_wire ( not4_node, OutgoingPort :: from ( 0 ) ) ;
615
- let expanded_out = walker. expand ( & not4_out, Direction :: Outgoing ) . collect_vec ( ) ;
616
648
// a single outgoing port (already pinned) => no more expansion
617
- assert ! ( expanded_out . is_empty ( ) ) ;
649
+ assert ! ( walker . no_more_expansion ( & not4_out , Direction :: Outgoing ) ) ;
618
650
619
651
// Three options:
620
652
// - AND gate from base
@@ -639,7 +671,7 @@ mod tests {
639
671
// new wire is complete (and thus cannot be expanded)
640
672
let not4_out = new_walker. get_wire ( not4_node, OutgoingPort :: from ( 0 ) ) ;
641
673
assert ! ( new_walker. is_complete( & not4_out, None ) ) ;
642
- assert ! ( new_walker. expand ( & not4_out, None ) . next ( ) . is_none ( ) ) ;
674
+ assert ! ( new_walker. no_more_expansion ( & not4_out, None ) ) ;
643
675
644
676
// all nodes on wire are pinned
645
677
let ( next_node, _) = not4_out
0 commit comments