Skip to content

Commit c37fe20

Browse files
authored
feat: Add PersistentWire type (#2361)
Upon @acl-cqc's comment in #2349, I realised that that PR would be simpler if we tracked the wires traversed as we walk the data structure. For this reason, this PR introduces a `PersistentWire` type. As drive-bys, some method names are also homogenised. The code looks a lot clearer now imo.
1 parent b71b3f4 commit c37fe20

File tree

10 files changed

+446
-356
lines changed

10 files changed

+446
-356
lines changed

hugr-core/src/core.rs

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ pub use itertools::Either;
77
use derive_more::From;
88
use itertools::Either::{Left, Right};
99

10-
use crate::hugr::HugrError;
10+
use crate::{HugrView, hugr::HugrError};
1111

1212
/// A handle to a node in the HUGR.
1313
#[derive(
@@ -219,17 +219,55 @@ impl<N: HugrNode> Wire<N> {
219219
Self(node, port.into())
220220
}
221221

222-
/// The node that this wire is connected to.
222+
/// Create a new wire from a node and a port that is connected to the wire.
223+
///
224+
/// If `port` is an incoming port, the wire is traversed to find the unique
225+
/// outgoing port that is connected to the wire. Otherwise, this is
226+
/// equivalent to constructing a wire using [`Wire::new`].
227+
///
228+
/// ## Panics
229+
///
230+
/// This will panic if the wire is not connected to a unique outgoing port.
231+
#[inline]
232+
pub fn from_connected_port(
233+
node: N,
234+
port: impl Into<Port>,
235+
hugr: &impl HugrView<Node = N>,
236+
) -> Self {
237+
let (node, outgoing) = match port.into().as_directed() {
238+
Either::Left(incoming) => hugr
239+
.single_linked_output(node, incoming)
240+
.expect("invalid dfg port"),
241+
Either::Right(outgoing) => (node, outgoing),
242+
};
243+
Self::new(node, outgoing)
244+
}
245+
246+
/// The node of the unique outgoing port that the wire is connected to.
223247
#[inline]
224248
pub fn node(&self) -> N {
225249
self.0
226250
}
227251

228-
/// The output port that this wire is connected to.
252+
/// The unique outgoing port that the wire is connected to.
229253
#[inline]
230254
pub fn source(&self) -> OutgoingPort {
231255
self.1
232256
}
257+
258+
/// Get all ports connected to the wire.
259+
///
260+
/// Return a chained iterator of the unique outgoing port, followed by all
261+
/// incoming ports connected to the wire.
262+
pub fn all_connected_ports<'h, H: HugrView<Node = N>>(
263+
&self,
264+
hugr: &'h H,
265+
) -> impl Iterator<Item = (N, Port)> + use<'h, N, H> {
266+
let node = self.node();
267+
let out_port = self.source();
268+
269+
std::iter::once((node, out_port.into())).chain(hugr.linked_ports(node, out_port))
270+
}
233271
}
234272

235273
impl<N: HugrNode> std::fmt::Display for Wire<N> {

hugr-core/src/hugr/views/root_checked/dfg.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ macro_rules! impl_dataflow_parent_methods {
2828
.expect("valid DFG graph")
2929
}
3030

31-
/// Rewire the inputs and outputs of the DFG to modify its signature.
31+
/// Rewire the inputs and outputs of the nested DFG to modify its signature.
3232
///
3333
/// Reorder the outgoing resp. incoming wires at the input resp. output
3434
/// node of the DFG to modify the signature of the DFG HUGR. This will

hugr-persistent/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,13 @@ mod resolver;
7272
pub mod state_space;
7373
mod trait_impls;
7474
pub mod walker;
75+
mod wire;
7576

7677
pub use persistent_hugr::{Commit, PersistentHugr};
7778
pub use resolver::{PointerEqResolver, Resolver, SerdeHashResolver};
7879
pub use state_space::{CommitId, CommitStateSpace, InvalidCommit, PatchNode};
79-
pub use walker::{PinnedWire, Walker};
80+
pub use walker::Walker;
81+
pub use wire::PersistentWire;
8082

8183
/// A replacement operation that can be applied to a [`PersistentHugr`].
8284
pub type PersistentReplacement = hugr_core::SimpleReplacement<PatchNode>;

hugr-persistent/src/persistent_hugr.rs

Lines changed: 22 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
use std::{
2-
collections::{BTreeSet, HashMap, VecDeque},
2+
collections::{BTreeSet, HashMap},
33
mem, vec,
44
};
55

66
use delegate::delegate;
77
use derive_more::derive::From;
88
use hugr_core::{
9-
Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement,
9+
Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement,
1010
hugr::patch::{Patch, simple_replace},
1111
};
1212
use itertools::{Either, Itertools};
@@ -394,68 +394,14 @@ impl<R> PersistentHugr<R> {
394394
///
395395
/// Panics if `node` is not in `self` (in particular if it is deleted) or if
396396
/// `port` is not a value port in `node`.
397-
pub(crate) fn get_single_outgoing_port(
397+
pub(crate) fn single_outgoing_port(
398398
&self,
399399
node: PatchNode,
400400
port: impl Into<IncomingPort>,
401401
) -> (PatchNode, OutgoingPort) {
402-
let mut in_port = port.into();
403-
let PatchNode(commit_id, mut in_node) = node;
404-
405-
assert!(self.is_value_port(node, in_port), "not a dataflow wire");
406-
assert!(self.contains_node(node), "node not in self");
407-
408-
let hugr = self.commit_hugr(commit_id);
409-
let (mut out_node, mut out_port) = hugr
410-
.single_linked_output(in_node, in_port)
411-
.map(|(n, p)| (PatchNode(commit_id, n), p))
412-
.expect("invalid HUGR");
413-
414-
// invariant: (out_node, out_port) -> (in_node, in_port) is a boundary
415-
// edge, i.e. it never is the case that both are deleted by the same
416-
// child commit
417-
loop {
418-
let commit_id = out_node.0;
419-
420-
if let Some(deleted_by) = self.find_deleting_commit(out_node) {
421-
(out_node, out_port) = self
422-
.state_space
423-
.linked_child_output(PatchNode(commit_id, in_node), in_port, deleted_by)
424-
.expect("valid boundary edge");
425-
// update (in_node, in_port)
426-
(in_node, in_port) = {
427-
let new_commit_id = out_node.0;
428-
let hugr = self.commit_hugr(new_commit_id);
429-
hugr.linked_inputs(out_node.1, out_port)
430-
.find(|&(n, _)| {
431-
self.find_deleting_commit(PatchNode(commit_id, n)).is_none()
432-
})
433-
.expect("out_node is connected to output node (which is never deleted)")
434-
};
435-
} else if self
436-
.replacement(commit_id)
437-
.is_some_and(|repl| repl.get_replacement_io()[0] == out_node.1)
438-
{
439-
// out_node is an input node
440-
(out_node, out_port) = self
441-
.as_state_space()
442-
.linked_parent_input(PatchNode(commit_id, in_node), in_port);
443-
// update (in_node, in_port)
444-
(in_node, in_port) = {
445-
let new_commit_id = out_node.0;
446-
let hugr = self.commit_hugr(new_commit_id);
447-
hugr.linked_inputs(out_node.1, out_port)
448-
.find(|&(n, _)| {
449-
self.find_deleting_commit(PatchNode(new_commit_id, n))
450-
== Some(commit_id)
451-
})
452-
.expect("boundary edge must connect out_node to deleted node")
453-
};
454-
} else {
455-
// valid outgoing node!
456-
return (out_node, out_port);
457-
}
458-
}
402+
let w = self.get_wire(node, port.into());
403+
w.single_outgoing_port(self)
404+
.expect("found invalid dfg wire")
459405
}
460406

461407
/// All incoming ports that the given outgoing port is attached to.
@@ -464,99 +410,14 @@ impl<R> PersistentHugr<R> {
464410
///
465411
/// Panics if `out_node` is not in `self` (in particular if it is deleted)
466412
/// or if `out_port` is not a value port in `out_node`.
467-
pub(crate) fn get_all_incoming_ports(
413+
pub(crate) fn all_incoming_ports(
468414
&self,
469415
out_node: PatchNode,
470416
out_port: OutgoingPort,
471417
) -> impl Iterator<Item = (PatchNode, IncomingPort)> {
472-
assert!(
473-
self.is_value_port(out_node, out_port),
474-
"not a dataflow wire"
475-
);
476-
assert!(self.contains_node(out_node), "node not in self");
477-
478-
let mut visited = BTreeSet::new();
479-
// enqueue the outport and initialise the set of valid incoming ports
480-
// to the valid incoming ports in this commit
481-
let mut queue = VecDeque::from([(out_node, out_port)]);
482-
let mut valid_incoming_ports = BTreeSet::from_iter(
483-
self.commit_hugr(out_node.0)
484-
.linked_inputs(out_node.1, out_port)
485-
.map(|(in_node, in_port)| (PatchNode(out_node.0, in_node), in_port))
486-
.filter(|(in_node, _)| self.contains_node(*in_node)),
487-
);
488-
489-
// A simple BFS across the commit history to find all equivalent incoming ports.
490-
while let Some((out_node, out_port)) = queue.pop_front() {
491-
if !visited.insert((out_node, out_port)) {
492-
continue;
493-
}
494-
let commit_id = out_node.0;
495-
let hugr = self.commit_hugr(commit_id);
496-
let out_deleted_by = self.find_deleting_commit(out_node);
497-
let curr_repl_out = {
498-
let repl = self.replacement(commit_id);
499-
repl.map(|r| r.get_replacement_io()[1])
500-
};
501-
// incoming ports are of interest to us if
502-
// (i) they are connected to the output of a replacement (then there will be a
503-
// linked port in a parent commit), or
504-
// (ii) they are deleted by a child commit and are not equal to the out_node
505-
// (then there will be a linked port in a child commit)
506-
let is_linked_to_output = curr_repl_out.is_some_and(|curr_repl_out| {
507-
hugr.linked_inputs(out_node.1, out_port)
508-
.any(|(in_node, _)| in_node == curr_repl_out)
509-
});
510-
511-
let deleted_by_child: BTreeSet<_> = hugr
512-
.linked_inputs(out_node.1, out_port)
513-
.filter(|(in_node, _)| Some(in_node) != curr_repl_out.as_ref())
514-
.filter_map(|(in_node, _)| {
515-
self.find_deleting_commit(PatchNode(commit_id, in_node))
516-
.filter(|other_deleted_by|
517-
// (out_node, out_port) -> (in_node, in_port) is a boundary edge
518-
// into the child commit `other_deleted_by`
519-
(Some(other_deleted_by) != out_deleted_by.as_ref()))
520-
})
521-
.collect();
522-
523-
// Convert an incoming port to the unique outgoing port that it is linked to
524-
let to_outgoing_port = |(PatchNode(commit_id, in_node), in_port)| {
525-
let hugr = self.commit_hugr(commit_id);
526-
let (out_node, out_port) = hugr
527-
.single_linked_output(in_node, in_port)
528-
.expect("valid dfg wire");
529-
(PatchNode(commit_id, out_node), out_port)
530-
};
531-
532-
if is_linked_to_output {
533-
// Traverse boundary to parent(s)
534-
let new_ins = self
535-
.as_state_space()
536-
.linked_parent_outputs(out_node, out_port);
537-
for (in_node, in_port) in new_ins {
538-
if self.contains_node(in_node) {
539-
valid_incoming_ports.insert((in_node, in_port));
540-
}
541-
queue.push_back(to_outgoing_port((in_node, in_port)));
542-
}
543-
}
544-
545-
for child in deleted_by_child {
546-
// Traverse boundary to `child`
547-
let new_ins = self
548-
.as_state_space()
549-
.linked_child_inputs(out_node, out_port, child);
550-
for (in_node, in_port) in new_ins {
551-
if self.contains_node(in_node) {
552-
valid_incoming_ports.insert((in_node, in_port));
553-
}
554-
queue.push_back(to_outgoing_port((in_node, in_port)));
555-
}
556-
}
557-
}
558-
559-
valid_incoming_ports.into_iter()
418+
let w = self.get_wire(out_node, out_port);
419+
w.into_all_ports(self, Direction::Incoming)
420+
.map(|(node, port)| (node, port.as_incoming().unwrap()))
560421
}
561422

562423
delegate! {
@@ -578,7 +439,7 @@ impl<R> PersistentHugr<R> {
578439
/// All nodes will be PatchNodes with commit ID `commit_id`.
579440
pub fn inserted_nodes(&self, commit_id: CommitId) -> impl Iterator<Item = PatchNode> + '_;
580441
/// Get the replacement for `commit_id`.
581-
fn replacement(&self, commit_id: CommitId) -> Option<&SimpleReplacement<PatchNode>>;
442+
pub(crate) fn replacement(&self, commit_id: CommitId) -> Option<&SimpleReplacement<PatchNode>>;
582443
/// Get the Hugr inserted by `commit_id`.
583444
///
584445
/// This is either the replacement Hugr of a [`CommitData::Replacement`] or
@@ -628,14 +489,24 @@ impl<R> PersistentHugr<R> {
628489
.unique()
629490
}
630491

631-
fn find_deleting_commit(&self, node @ PatchNode(commit_id, _): PatchNode) -> Option<CommitId> {
492+
/// Get the child commit that deletes `node`.
493+
pub(crate) fn find_deleting_commit(
494+
&self,
495+
node @ PatchNode(commit_id, _): PatchNode,
496+
) -> Option<CommitId> {
632497
let mut children = self.state_space.children(commit_id);
633498
children.find(move |&child_id| {
634499
let child = self.get_commit(child_id);
635500
child.deleted_nodes().contains(&node)
636501
})
637502
}
638503

504+
/// Convert a node ID specific to a commit HUGR into a patch node in the
505+
/// [`PersistentHugr`].
506+
pub(crate) fn to_persistent_node(&self, node: Node, commit_id: CommitId) -> PatchNode {
507+
PatchNode(commit_id, node)
508+
}
509+
639510
/// Check if a patch node is in the PersistentHugr, that is, it belongs to
640511
/// a commit in the state space and is not deleted by any child commit.
641512
pub fn contains_node(&self, PatchNode(commit_id, node): PatchNode) -> bool {

0 commit comments

Comments
 (0)