Skip to content

Commit 3adb23c

Browse files
committed
Walker::expand returns self on complete wires
1 parent 3b92fdc commit 3adb23c

File tree

1 file changed

+44
-12
lines changed

1 file changed

+44
-12
lines changed

hugr-persistent/src/walker.rs

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ use std::{
6060
collections::{BTreeMap, BTreeSet},
6161
};
6262

63-
use itertools::Itertools;
63+
use itertools::{Either, Itertools};
6464
use thiserror::Error;
6565

6666
use hugr_core::{
@@ -203,10 +203,14 @@ impl<'a, R: Resolver> Walker<'a, R> {
203203
/// walkers, which together cover the same space of possible HUGRs, each
204204
/// having a different additional node pinned.
205205
///
206-
/// Return an iterator over all possible [`Walker`]s that can be created by
207-
/// pinning exactly one additional node (or one additonal commit with an
208-
/// empty wire) connected to `wire`. Each returned [`Walker`] represents
209-
/// a different alternative Hugr in the exploration space.
206+
/// If the wire is not complete yet, return an iterator over all possible
207+
/// [`Walker`]s that can be created by pinning exactly one additional
208+
/// node (or one additonal commit with an empty wire) connected to
209+
/// `wire`. Each returned [`Walker`] represents a different alternative
210+
/// Hugr in the exploration space.
211+
///
212+
/// If the wire is already complete, return an iterator containing one
213+
/// walker: the current walker unchanged.
210214
///
211215
/// Optionally, the expansion can be restricted to only ports with the given
212216
/// direction (incoming or outgoing).
@@ -223,6 +227,10 @@ impl<'a, R: Resolver> Walker<'a, R> {
223227
) -> impl Iterator<Item = Walker<'a, R>> + 'b {
224228
let dir = dir.into();
225229

230+
if self.is_complete(wire, dir) {
231+
return Either::Left(std::iter::once(self.clone()));
232+
}
233+
226234
// Find unpinned ports on the wire (satisfying the direction constraint)
227235
let unpinned_ports = self.wire_unpinned_ports(wire, dir);
228236

@@ -233,7 +241,7 @@ impl<'a, R: Resolver> Walker<'a, R> {
233241
.map(|(n, _, commits)| (n, commits))
234242
.unique();
235243

236-
pinnable_nodes.filter_map(|(pinnable_node, new_commits)| {
244+
let new_walkers = pinnable_nodes.filter_map(|(pinnable_node, new_commits)| {
237245
let contains_new_commit = || {
238246
new_commits
239247
.iter()
@@ -268,7 +276,9 @@ impl<'a, R: Resolver> Walker<'a, R> {
268276
};
269277
new_walker.try_pin_node(pinnable_node).ok()?;
270278
Some(new_walker)
271-
})
279+
});
280+
281+
Either::Right(new_walkers)
272282
}
273283

274284
/// Create a new commit from a set of complete pinned wires and a
@@ -436,6 +446,28 @@ impl<R: Clone> Walker<'_, R> {
436446
}
437447
}
438448

449+
#[cfg(test)]
450+
impl<R: Resolver> Walker<'_, R> {
451+
// Check walker equality by comparing pointers to the state space and
452+
// other fields. Only for testing purposes.
453+
fn ptr_eq(&self, other: &Self) -> bool {
454+
self.state_space.as_ref() as *const CommitStateSpace<R>
455+
== other.state_space.as_ref() as *const CommitStateSpace<R>
456+
&& self.pinned_nodes == other.pinned_nodes
457+
&& BTreeSet::from_iter(self.selected_commits.all_commit_ids())
458+
== BTreeSet::from_iter(other.selected_commits.all_commit_ids())
459+
}
460+
461+
/// Check if the Walker cannot be expanded further, i.e. expanding it
462+
/// returns the same Walker.
463+
fn no_more_expansion(&self, wire: &PersistentWire, dir: impl Into<Option<Direction>>) -> bool {
464+
let Some([new_walker]) = self.expand(wire, dir).collect_array() else {
465+
return false;
466+
};
467+
new_walker.ptr_eq(self)
468+
}
469+
}
470+
439471
impl<R> CommitStateSpace<R> {
440472
/// Given a node and port, return all child commits of the current `node`
441473
/// that delete `node` but keep at least one port linked to `(node, port)`.
@@ -546,7 +578,8 @@ mod tests {
546578
let in0 = walker.get_wire(base_and_node, IncomingPort::from(0));
547579

548580
// a single incoming port (already pinned) => no more expansion
549-
assert!(walker.expand(&in0, Direction::Incoming).next().is_none());
581+
assert!(walker.no_more_expansion(&in0, Direction::Incoming));
582+
550583
// commit 2 cannot be applied, because AND is pinned
551584
// => only base commit, or commit1
552585
let out_walkers = walker.expand(&in0, Direction::Outgoing).collect_vec();
@@ -555,7 +588,7 @@ mod tests {
555588
// new wire is complete (and thus cannot be expanded)
556589
let in0 = new_walker.get_wire(base_and_node, IncomingPort::from(0));
557590
assert!(new_walker.is_complete(&in0, None));
558-
assert!(new_walker.expand(&in0, None).next().is_none());
591+
assert!(new_walker.no_more_expansion(&in0, None));
559592

560593
// all nodes on wire are pinned
561594
let (not_node, _) = in0.single_outgoing_port(new_walker.as_hugr_view()).unwrap();
@@ -612,9 +645,8 @@ mod tests {
612645
assert!(walker.is_pinned(not4_node));
613646

614647
let not4_out = walker.get_wire(not4_node, OutgoingPort::from(0));
615-
let expanded_out = walker.expand(&not4_out, Direction::Outgoing).collect_vec();
616648
// a single outgoing port (already pinned) => no more expansion
617-
assert!(expanded_out.is_empty());
649+
assert!(walker.no_more_expansion(&not4_out, Direction::Outgoing));
618650

619651
// Three options:
620652
// - AND gate from base
@@ -639,7 +671,7 @@ mod tests {
639671
// new wire is complete (and thus cannot be expanded)
640672
let not4_out = new_walker.get_wire(not4_node, OutgoingPort::from(0));
641673
assert!(new_walker.is_complete(&not4_out, None));
642-
assert!(new_walker.expand(&not4_out, None).next().is_none());
674+
assert!(new_walker.no_more_expansion(&not4_out, None));
643675

644676
// all nodes on wire are pinned
645677
let (next_node, _) = not4_out

0 commit comments

Comments
 (0)