Skip to content

Commit 1ba9ef6

Browse files
authored
feat: Add PinnedSubgraph (#2402)
I've factored out the logic that computes the subgraph of a `PersistentHugr` when constructing a `SimpleReplacement` into a new type `PinnedSubgraph`. This is required so that commit factories in `tket2` are able to compute the matched subgraph and extract the matched subHUGRs. This is very similar to `SiblingSubgraph`, with the nuance that the same subgraph can be applied to multiple `PersistentHugr`s, as long as they all share the same pinned nodes (in particular when the different `PersistentHugr`s stem from the different expansions of the same `Walker`). Conversions to and from `SiblingSubgraph`s are provided.
1 parent 01c20d8 commit 1ba9ef6

File tree

5 files changed

+296
-106
lines changed

5 files changed

+296
-106
lines changed

hugr-persistent/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,13 +70,15 @@ mod parents_view;
7070
mod persistent_hugr;
7171
mod resolver;
7272
pub mod state_space;
73+
pub mod subgraph;
7374
mod trait_impls;
7475
pub mod walker;
7576
mod wire;
7677

7778
pub use persistent_hugr::{Commit, PersistentHugr};
7879
pub use resolver::{PointerEqResolver, Resolver, SerdeHashResolver};
7980
pub use state_space::{CommitId, CommitStateSpace, InvalidCommit, PatchNode};
81+
pub use subgraph::PinnedSubgraph;
8082
pub use walker::Walker;
8183
pub use wire::PersistentWire;
8284

hugr-persistent/src/state_space.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use hugr_core::{
1313
BoundaryPort,
1414
simple_replace::{BoundaryMode, InvalidReplacement},
1515
},
16-
views::{InvalidSignature, sibling_subgraph::InvalidSubgraph},
16+
views::InvalidSignature,
1717
},
1818
ops::OpType,
1919
};
@@ -23,7 +23,7 @@ use thiserror::Error;
2323

2424
use crate::{
2525
Commit, PersistentHugr, PersistentReplacement, PointerEqResolver, Resolver,
26-
find_conflicting_node, parents_view::ParentsView,
26+
find_conflicting_node, parents_view::ParentsView, subgraph::InvalidPinnedSubgraph,
2727
};
2828

2929
pub mod serial;
@@ -618,7 +618,7 @@ pub enum InvalidCommit {
618618

619619
#[error("Invalid subgraph: {0}")]
620620
/// The subgraph of the replacement is not convex.
621-
InvalidSubgraph(#[from] InvalidSubgraph<PatchNode>),
621+
InvalidSubgraph(#[from] InvalidPinnedSubgraph),
622622

623623
/// The replacement of the commit is invalid.
624624
#[error("Invalid replacement: {0}")]

hugr-persistent/src/subgraph.rs

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
use std::collections::BTreeSet;
2+
3+
use hugr_core::{
4+
IncomingPort, OutgoingPort,
5+
hugr::views::{
6+
SiblingSubgraph,
7+
sibling_subgraph::{IncomingPorts, InvalidSubgraph, OutgoingPorts},
8+
},
9+
};
10+
use itertools::Itertools;
11+
use thiserror::Error;
12+
13+
use crate::{CommitId, PatchNode, PersistentHugr, PersistentWire, Resolver, Walker};
14+
15+
/// A set of pinned nodes and wires between them, along with a fixed input
16+
/// and output boundary, simmilar to [`SiblingSubgraph`].
17+
///
18+
/// Unlike [`SiblingSubgraph`], subgraph validity (in particular convexity) is
19+
/// not checked (and cannot be checked), as the same [`PinnedSubgraph`] may
20+
/// represent [`SiblingSubgraph`]s in different HUGRs.
21+
///
22+
/// Obtain a valid [`SiblingSubgraph`] for a specific [`PersistentHugr`] by
23+
/// calling [`PinnedSubgraph::to_sibling_subgraph`].
24+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
25+
pub struct PinnedSubgraph {
26+
/// The nodes of the induced subgraph.
27+
nodes: BTreeSet<PatchNode>,
28+
/// The input ports of the subgraph.
29+
///
30+
/// Grouped by input parameter. Each port must be unique and belong to a
31+
/// node in `nodes`.
32+
inputs: Vec<Vec<(PatchNode, IncomingPort)>>,
33+
/// The output ports of the subgraph.
34+
///
35+
/// Repeated ports are allowed and correspond to copying the output. Every
36+
/// port must belong to a node in `nodes`.
37+
outputs: Vec<(PatchNode, OutgoingPort)>,
38+
/// The commits that must be selected in the host for the subgraph to be
39+
/// valid.
40+
selected_commits: BTreeSet<CommitId>,
41+
}
42+
43+
impl From<SiblingSubgraph<PatchNode>> for PinnedSubgraph {
44+
fn from(subgraph: SiblingSubgraph<PatchNode>) -> Self {
45+
Self {
46+
inputs: subgraph.incoming_ports().clone(),
47+
outputs: subgraph.outgoing_ports().clone(),
48+
nodes: BTreeSet::from_iter(subgraph.nodes().iter().copied()),
49+
selected_commits: BTreeSet::new(),
50+
}
51+
}
52+
}
53+
54+
impl PinnedSubgraph {
55+
/// Create a new subgraph from a set of pinned nodes and wires.
56+
///
57+
/// All nodes must be pinned and all wires must be complete in the given
58+
/// `walker`.
59+
///
60+
/// Nodes that are not isolated, i.e. are attached to at least one wire in
61+
/// `wires` will be added implicitly to the graph and do not need to be
62+
/// explicitly listed in `nodes`.
63+
pub fn try_from_pinned<R: Resolver>(
64+
nodes: impl IntoIterator<Item = PatchNode>,
65+
wires: impl IntoIterator<Item = PersistentWire>,
66+
walker: &Walker<R>,
67+
) -> Result<Self, InvalidPinnedSubgraph> {
68+
let mut selected_commits = BTreeSet::new();
69+
let host = walker.as_hugr_view();
70+
let wires = wires.into_iter().collect_vec();
71+
let nodes = nodes.into_iter().collect_vec();
72+
73+
for w in wires.iter() {
74+
if !walker.is_complete(w, None) {
75+
return Err(InvalidPinnedSubgraph::IncompleteWire(w.clone()));
76+
}
77+
for id in w.owners() {
78+
if host.contains_id(id) {
79+
selected_commits.insert(id);
80+
} else {
81+
return Err(InvalidPinnedSubgraph::InvalidCommit(id));
82+
}
83+
}
84+
}
85+
86+
if let Some(&unpinned) = nodes.iter().find(|&&n| !walker.is_pinned(n)) {
87+
return Err(InvalidPinnedSubgraph::UnpinnedNode(unpinned));
88+
}
89+
90+
let (inputs, outputs, all_nodes) = Self::compute_io_ports(nodes, wires, host);
91+
92+
Ok(Self {
93+
selected_commits,
94+
nodes: all_nodes,
95+
inputs,
96+
outputs,
97+
})
98+
}
99+
100+
/// Create a new subgraph from a set of complete wires in `walker`.
101+
pub fn try_from_wires<R: Resolver>(
102+
wires: impl IntoIterator<Item = PersistentWire>,
103+
walker: &Walker<R>,
104+
) -> Result<Self, InvalidPinnedSubgraph> {
105+
Self::try_from_pinned(std::iter::empty(), wires, walker)
106+
}
107+
108+
/// Compute the input and output ports for the given pinned nodes and wires.
109+
///
110+
/// Return the input boundary ports, output boundary ports as well as the
111+
/// set of all nodes in the subgraph.
112+
pub fn compute_io_ports<R: Resolver>(
113+
nodes: impl IntoIterator<Item = PatchNode>,
114+
wires: impl IntoIterator<Item = PersistentWire>,
115+
host: &PersistentHugr<R>,
116+
) -> (
117+
IncomingPorts<PatchNode>,
118+
OutgoingPorts<PatchNode>,
119+
BTreeSet<PatchNode>,
120+
) {
121+
let mut wire_ports_incoming = BTreeSet::new();
122+
let mut wire_ports_outgoing = BTreeSet::new();
123+
124+
for w in wires {
125+
wire_ports_incoming.extend(w.all_incoming_ports(host));
126+
wire_ports_outgoing.extend(w.single_outgoing_port(host));
127+
}
128+
129+
let mut all_nodes = BTreeSet::from_iter(nodes);
130+
all_nodes.extend(wire_ports_incoming.iter().map(|&(n, _)| n));
131+
all_nodes.extend(wire_ports_outgoing.iter().map(|&(n, _)| n));
132+
133+
// (in/out) boundary: all in/out ports on the nodes of the wire, minus ports
134+
// that are part of the wires
135+
let inputs = all_nodes
136+
.iter()
137+
.flat_map(|&n| host.input_value_ports(n))
138+
.filter(|node_port| !wire_ports_incoming.contains(node_port))
139+
.map(|np| vec![np])
140+
.collect_vec();
141+
let outputs = all_nodes
142+
.iter()
143+
.flat_map(|&n| host.output_value_ports(n))
144+
.filter(|node_port| !wire_ports_outgoing.contains(node_port))
145+
.collect_vec();
146+
147+
(inputs, outputs, all_nodes)
148+
}
149+
150+
/// Convert the pinned subgraph to a [`SiblingSubgraph`] for the given
151+
/// `host`.
152+
///
153+
/// This will fail if any of the required selected commits are not in the
154+
/// host, if any of the nodes are invalid in the host (e.g. deleted by
155+
/// another commit in host), or if the subgraph is not convex.
156+
pub fn to_sibling_subgraph<R>(
157+
&self,
158+
host: &PersistentHugr<R>,
159+
) -> Result<SiblingSubgraph<PatchNode>, InvalidPinnedSubgraph> {
160+
if let Some(&unselected) = self
161+
.selected_commits
162+
.iter()
163+
.find(|&&id| !host.contains_id(id))
164+
{
165+
return Err(InvalidPinnedSubgraph::InvalidCommit(unselected));
166+
}
167+
168+
if let Some(invalid) = self.nodes.iter().find(|&&n| !host.contains_node(n)) {
169+
return Err(InvalidPinnedSubgraph::InvalidNode(*invalid));
170+
}
171+
172+
Ok(SiblingSubgraph::try_new(
173+
self.inputs.clone(),
174+
self.outputs.clone(),
175+
host,
176+
)?)
177+
}
178+
179+
/// Iterate over all the commits required by this pinned subgraph.
180+
pub fn selected_commits(&self) -> impl Iterator<Item = CommitId> + '_ {
181+
self.selected_commits.iter().copied()
182+
}
183+
184+
/// Iterate over all the nodes in this pinned subgraph.
185+
pub fn nodes(&self) -> impl Iterator<Item = PatchNode> + '_ {
186+
self.nodes.iter().copied()
187+
}
188+
189+
/// Returns the computed [`IncomingPorts`] of the subgraph.
190+
#[must_use]
191+
pub fn incoming_ports(&self) -> &IncomingPorts<PatchNode> {
192+
&self.inputs
193+
}
194+
195+
/// Returns the computed [`OutgoingPorts`] of the subgraph.
196+
#[must_use]
197+
pub fn outgoing_ports(&self) -> &OutgoingPorts<PatchNode> {
198+
&self.outputs
199+
}
200+
}
201+
202+
#[derive(Debug, Clone, Error)]
203+
#[non_exhaustive]
204+
pub enum InvalidPinnedSubgraph {
205+
#[error("Invalid subgraph: {0}")]
206+
InvalidSubgraph(#[from] InvalidSubgraph<PatchNode>),
207+
#[error("Invalid commit in host: {0}")]
208+
InvalidCommit(CommitId),
209+
#[error("Wire is not complete: {0:?}")]
210+
IncompleteWire(PersistentWire),
211+
#[error("Node is not pinned: {0}")]
212+
UnpinnedNode(PatchNode),
213+
#[error("Invalid node in host: {0}")]
214+
InvalidNode(PatchNode),
215+
}

hugr-persistent/src/walker.rs

Lines changed: 29 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -55,24 +55,16 @@
5555
//! versions of the graph simultaneously, without having to materialize
5656
//! each version separately.
5757
58-
use std::{
59-
borrow::Cow,
60-
collections::{BTreeMap, BTreeSet},
61-
};
58+
use std::{borrow::Cow, collections::BTreeSet};
6259

60+
use hugr_core::hugr::patch::simple_replace::BoundaryMode;
61+
use hugr_core::ops::handle::DataflowParentID;
6362
use itertools::{Either, Itertools};
6463
use thiserror::Error;
6564

66-
use hugr_core::{
67-
Direction, Hugr, HugrView, Port, PortIndex,
68-
hugr::{
69-
patch::simple_replace::BoundaryMode,
70-
views::{RootCheckable, SiblingSubgraph},
71-
},
72-
ops::handle::DfgID,
73-
};
65+
use hugr_core::{Direction, Hugr, HugrView, Port, PortIndex, hugr::views::RootCheckable};
7466

75-
use crate::{Commit, PersistentReplacement};
67+
use crate::{Commit, PersistentReplacement, PinnedSubgraph};
7668

7769
use crate::{PersistentWire, PointerEqResolver, resolver::Resolver};
7870

@@ -298,51 +290,20 @@ impl<'a, R: Resolver> Walker<'a, R> {
298290
/// This will panic if repl is not a DFG graph.
299291
pub fn try_create_commit(
300292
&self,
301-
wires: impl IntoIterator<Item = PersistentWire>,
302-
repl: impl RootCheckable<Hugr, DfgID>,
293+
subgraph: impl Into<PinnedSubgraph>,
294+
repl: impl RootCheckable<Hugr, DataflowParentID>,
303295
map_boundary: impl Fn(PatchNode, Port) -> Port,
304296
) -> Result<Commit, InvalidCommit> {
305-
let mut wire_ports_incoming = BTreeSet::new();
306-
let mut wire_ports_outgoing = BTreeSet::new();
307-
let mut additional_parents = BTreeMap::new();
308-
309-
for w in wires {
310-
if let Some((n, p)) = self.wire_unpinned_ports(&w, None).next() {
311-
return Err(InvalidCommit::IncompleteWire(n, p));
312-
}
313-
wire_ports_incoming.extend(w.all_incoming_ports(self.as_hugr_view()));
314-
wire_ports_outgoing.extend(w.single_outgoing_port(self.as_hugr_view()));
315-
for id in w.owners() {
316-
let commit = self
317-
.state_space
318-
.try_get_commit(id)
319-
.ok_or(InvalidCommit::UnknownParent(id))?
320-
.clone();
321-
additional_parents.insert(id, commit);
322-
}
323-
}
324-
325-
let mut all_nodes = BTreeSet::new();
326-
all_nodes.extend(wire_ports_incoming.iter().map(|&(n, _)| n));
327-
all_nodes.extend(wire_ports_outgoing.iter().map(|&(n, _)| n));
328-
329-
// (in/out) boundary: all in/out ports on the nodes of the wire, minus ports
330-
// that are part of the wires
331-
let incoming = all_nodes
332-
.iter()
333-
.flat_map(|&n| self.as_hugr_view().input_value_ports(n))
334-
.filter(|node_port| !wire_ports_incoming.contains(node_port))
335-
.map(|np| vec![np])
336-
.collect_vec();
337-
let outgoing = all_nodes
338-
.iter()
339-
.flat_map(|&n| self.as_hugr_view().output_value_ports(n))
340-
.filter(|node_port| !wire_ports_outgoing.contains(node_port))
341-
.collect_vec();
297+
let pinned_subgraph = subgraph.into();
298+
let subgraph = pinned_subgraph.to_sibling_subgraph(self.as_hugr_view())?;
299+
let selected_commits = pinned_subgraph
300+
.selected_commits()
301+
.map(|id| self.state_space.get_commit(id).clone());
342302

343303
let repl = {
344304
let mut repl = repl.try_into_checked().expect("replacement is not DFG");
345-
let new_inputs = incoming
305+
let new_inputs = subgraph
306+
.incoming_ports()
346307
.iter()
347308
.flatten() // because of singleton-vec wrapping above
348309
.map(|&(n, p)| {
@@ -352,7 +313,8 @@ impl<'a, R: Resolver> Walker<'a, R> {
352313
.index()
353314
})
354315
.collect_vec();
355-
let new_outputs = outgoing
316+
let new_outputs = subgraph
317+
.outgoing_ports()
356318
.iter()
357319
.map(|&(n, p)| {
358320
map_boundary(n, p.into())
@@ -362,11 +324,10 @@ impl<'a, R: Resolver> Walker<'a, R> {
362324
})
363325
.collect_vec();
364326
repl.map_function_type(&new_inputs, &new_outputs)?;
365-
let subgraph = SiblingSubgraph::try_new(incoming, outgoing, self.as_hugr_view())?;
366327
PersistentReplacement::try_new(subgraph, self.as_hugr_view(), repl.into_hugr())?
367328
};
368329

369-
Commit::try_new(repl, additional_parents.into_values(), &self.state_space)
330+
Commit::try_new(repl, selected_commits, &self.state_space)
370331
}
371332
}
372333

@@ -807,14 +768,18 @@ mod tests {
807768
dfg_builder.finish_hugr_with_outputs(inputs).unwrap()
808769
};
809770
let commit = walker
810-
.try_create_commit(vec![wire], empty_hugr, |node, port| {
811-
assert_eq!(port.index(), 0);
812-
assert!([not0, not2].contains(&node));
813-
match port.direction() {
814-
Direction::Incoming => OutgoingPort::from(0).into(),
815-
Direction::Outgoing => IncomingPort::from(0).into(),
816-
}
817-
})
771+
.try_create_commit(
772+
PinnedSubgraph::try_from_pinned(std::iter::empty(), [wire], &walker).unwrap(),
773+
empty_hugr,
774+
|node, port| {
775+
assert_eq!(port.index(), 0);
776+
assert!([not0, not2].contains(&node));
777+
match port.direction() {
778+
Direction::Incoming => OutgoingPort::from(0).into(),
779+
Direction::Outgoing => IncomingPort::from(0).into(),
780+
}
781+
},
782+
)
818783
.unwrap();
819784

820785
let mut new_state_space = hugr.as_state_space().to_owned();

0 commit comments

Comments
 (0)