Skip to content

Commit 5da112b

Browse files
authored
feat: Add Root checked methods to DataflowParentID (#2382)
Previously, methods such as `get_io` and `map_function_type` were only defined on root checked Hugrs with root tag `DfgId`. I've turned this into a macro and implemented the types for `DataflowParentId` as well.
1 parent eaa7dfe commit 5da112b

File tree

1 file changed

+177
-121
lines changed
  • hugr-core/src/hugr/views/root_checked

1 file changed

+177
-121
lines changed

hugr-core/src/hugr/views/root_checked/dfg.rs

Lines changed: 177 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -8,138 +8,149 @@ use thiserror::Error;
88
use crate::{
99
IncomingPort, OutgoingPort, PortIndex,
1010
hugr::HugrMut,
11-
ops::{DFG, FuncDefn, Input, OpTrait, OpType, Output, dataflow::IOTrait, handle::DfgID},
11+
ops::{
12+
DFG, FuncDefn, Input, OpTrait, OpType, Output,
13+
dataflow::IOTrait,
14+
handle::{DataflowParentID, DfgID},
15+
},
1216
types::{NoRV, Signature, TypeBase},
1317
};
1418

1519
use super::RootChecked;
1620

17-
impl<H: HugrMut> RootChecked<H, DfgID<H::Node>> {
18-
/// Get the input and output nodes of the DFG at the entrypoint node.
19-
pub fn get_io(&self) -> [H::Node; 2] {
20-
self.hugr()
21-
.get_io(self.hugr().entrypoint())
22-
.expect("valid DFG graph")
23-
}
24-
25-
/// Rewire the inputs and outputs of the DFG to modify its signature.
26-
///
27-
/// Reorder the outgoing resp. incoming wires at the input resp. output
28-
/// node of the DFG to modify the signature of the DFG HUGR. This will
29-
/// recursively update the signatures of all ancestors of the entrypoint.
30-
///
31-
/// ### Arguments
32-
///
33-
/// * `new_inputs`: The new input signature. After the map, the i-th input
34-
/// wire will be connected to the ports connected to the
35-
/// `new_inputs[i]`-th input of the old DFG.
36-
/// * `new_outputs`: The new output signature. After the map, the i-th
37-
/// output wire will be connected to the ports connected to the
38-
/// `new_outputs[i]`-th output of the old DFG.
39-
///
40-
/// Returns an `InvalidSignature` error if the new_inputs and new_outputs
41-
/// map are not valid signatures.
42-
///
43-
/// ### Panics
44-
///
45-
/// Panics if the DFG is not trivially nested, i.e. if there is an ancestor
46-
/// DFG of the entrypoint that has more than one inner DFG.
47-
pub fn map_function_type(
48-
&mut self,
49-
new_inputs: &[usize],
50-
new_outputs: &[usize],
51-
) -> Result<(), InvalidSignature> {
52-
let [inp, out] = self.get_io();
53-
let Self(hugr, _) = self;
54-
55-
// Record the old connections from and to the input and output nodes
56-
let old_inputs_incoming = hugr
57-
.node_outputs(inp)
58-
.map(|p| hugr.linked_inputs(inp, p).collect_vec())
59-
.collect_vec();
60-
let old_outputs_outgoing = hugr
61-
.node_inputs(out)
62-
.map(|p| hugr.linked_outputs(out, p).collect_vec())
63-
.collect_vec();
64-
65-
// The old signature types
66-
let old_inp_sig = hugr
67-
.get_optype(inp)
68-
.dataflow_signature()
69-
.expect("input has signature");
70-
let old_inp_sig = old_inp_sig.output_types();
71-
let old_out_sig = hugr
72-
.get_optype(out)
73-
.dataflow_signature()
74-
.expect("output has signature");
75-
let old_out_sig = old_out_sig.input_types();
76-
77-
// Check if the signature map is valid
78-
check_valid_inputs(&old_inputs_incoming, old_inp_sig, new_inputs)?;
79-
check_valid_outputs(old_out_sig, new_outputs)?;
80-
81-
// The new signature types
82-
let new_inp_sig = new_inputs
83-
.iter()
84-
.map(|&i| old_inp_sig[i].clone())
85-
.collect_vec();
86-
let new_out_sig = new_outputs
87-
.iter()
88-
.map(|&i| old_out_sig[i].clone())
89-
.collect_vec();
90-
let new_sig = Signature::new(new_inp_sig, new_out_sig);
91-
92-
// Remove all edges of the input and output nodes
93-
disconnect_all(hugr, inp);
94-
disconnect_all(hugr, out);
95-
96-
// Update the signatures of the IO and their ancestors
97-
let mut is_ancestor = false;
98-
let mut node = hugr.entrypoint();
99-
while matches!(hugr.get_optype(node), OpType::FuncDefn(_) | OpType::DFG(_)) {
100-
let [inner_inp, inner_out] = hugr.get_io(node).expect("valid DFG graph");
101-
for node in [node, inner_inp, inner_out] {
102-
update_signature(hugr, node, &new_sig);
21+
macro_rules! impl_dataflow_parent_methods {
22+
($handle_type:ident) => {
23+
impl<H: HugrMut> RootChecked<H, $handle_type<H::Node>> {
24+
/// Get the input and output nodes of the DFG at the entrypoint node.
25+
pub fn get_io(&self) -> [H::Node; 2] {
26+
self.hugr()
27+
.get_io(self.hugr().entrypoint())
28+
.expect("valid DFG graph")
10329
}
104-
if is_ancestor {
105-
update_inner_dfg_links(hugr, node);
106-
}
107-
if let Some(parent) = hugr.get_parent(node) {
108-
node = parent;
109-
is_ancestor = true;
110-
} else {
111-
break;
112-
}
113-
}
11430

115-
// Insert the new edges at the input
116-
let mut old_output_to_new_input = BTreeMap::<IncomingPort, OutgoingPort>::new();
117-
for (inp_pos, &old_pos) in new_inputs.iter().enumerate() {
118-
for &(node, port) in &old_inputs_incoming[old_pos] {
119-
if node != out {
120-
hugr.connect(inp, inp_pos, node, port);
121-
} else {
122-
old_output_to_new_input.insert(port, inp_pos.into());
31+
/// Rewire the inputs and outputs of the DFG to modify its signature.
32+
///
33+
/// Reorder the outgoing resp. incoming wires at the input resp. output
34+
/// node of the DFG to modify the signature of the DFG HUGR. This will
35+
/// recursively update the signatures of all ancestors of the entrypoint.
36+
///
37+
/// ### Arguments
38+
///
39+
/// * `new_inputs`: The new input signature. After the map, the i-th input
40+
/// wire will be connected to the ports connected to the
41+
/// `new_inputs[i]`-th input of the old DFG.
42+
/// * `new_outputs`: The new output signature. After the map, the i-th
43+
/// output wire will be connected to the ports connected to the
44+
/// `new_outputs[i]`-th output of the old DFG.
45+
///
46+
/// Returns an `InvalidSignature` error if the new_inputs and new_outputs
47+
/// map are not valid signatures.
48+
///
49+
/// ### Panics
50+
///
51+
/// Panics if the DFG is not trivially nested, i.e. if there is an ancestor
52+
/// DFG of the entrypoint that has more than one inner DFG.
53+
pub fn map_function_type(
54+
&mut self,
55+
new_inputs: &[usize],
56+
new_outputs: &[usize],
57+
) -> Result<(), InvalidSignature> {
58+
let [inp, out] = self.get_io();
59+
let Self(hugr, _) = self;
60+
61+
// Record the old connections from and to the input and output nodes
62+
let old_inputs_incoming = hugr
63+
.node_outputs(inp)
64+
.map(|p| hugr.linked_inputs(inp, p).collect_vec())
65+
.collect_vec();
66+
let old_outputs_outgoing = hugr
67+
.node_inputs(out)
68+
.map(|p| hugr.linked_outputs(out, p).collect_vec())
69+
.collect_vec();
70+
71+
// The old signature types
72+
let old_inp_sig = hugr
73+
.get_optype(inp)
74+
.dataflow_signature()
75+
.expect("input has signature");
76+
let old_inp_sig = old_inp_sig.output_types();
77+
let old_out_sig = hugr
78+
.get_optype(out)
79+
.dataflow_signature()
80+
.expect("output has signature");
81+
let old_out_sig = old_out_sig.input_types();
82+
83+
// Check if the signature map is valid
84+
check_valid_inputs(&old_inputs_incoming, old_inp_sig, new_inputs)?;
85+
check_valid_outputs(old_out_sig, new_outputs)?;
86+
87+
// The new signature types
88+
let new_inp_sig = new_inputs
89+
.iter()
90+
.map(|&i| old_inp_sig[i].clone())
91+
.collect_vec();
92+
let new_out_sig = new_outputs
93+
.iter()
94+
.map(|&i| old_out_sig[i].clone())
95+
.collect_vec();
96+
let new_sig = Signature::new(new_inp_sig, new_out_sig);
97+
98+
// Remove all edges of the input and output nodes
99+
disconnect_all(hugr, inp);
100+
disconnect_all(hugr, out);
101+
102+
// Update the signatures of the IO and their ancestors
103+
let mut is_ancestor = false;
104+
let mut node = hugr.entrypoint();
105+
while matches!(hugr.get_optype(node), OpType::FuncDefn(_) | OpType::DFG(_)) {
106+
let [inner_inp, inner_out] = hugr.get_io(node).expect("valid DFG graph");
107+
for node in [node, inner_inp, inner_out] {
108+
update_signature(hugr, node, &new_sig);
109+
}
110+
if is_ancestor {
111+
update_inner_dfg_links(hugr, node);
112+
}
113+
if let Some(parent) = hugr.get_parent(node) {
114+
node = parent;
115+
is_ancestor = true;
116+
} else {
117+
break;
118+
}
119+
}
120+
121+
// Insert the new edges at the input
122+
let mut old_output_to_new_input = BTreeMap::<IncomingPort, OutgoingPort>::new();
123+
for (inp_pos, &old_pos) in new_inputs.iter().enumerate() {
124+
for &(node, port) in &old_inputs_incoming[old_pos] {
125+
if node != out {
126+
hugr.connect(inp, inp_pos, node, port);
127+
} else {
128+
old_output_to_new_input.insert(port, inp_pos.into());
129+
}
130+
}
123131
}
124-
}
125-
}
126132

127-
// Insert the new edges at the output
128-
for (out_pos, &old_pos) in new_outputs.iter().enumerate() {
129-
for &(node, port) in &old_outputs_outgoing[old_pos] {
130-
if node != inp {
131-
hugr.connect(node, port, out, out_pos);
132-
} else {
133-
let &inp_pos = old_output_to_new_input.get(&old_pos.into()).unwrap();
134-
hugr.connect(inp, inp_pos, out, out_pos);
133+
// Insert the new edges at the output
134+
for (out_pos, &old_pos) in new_outputs.iter().enumerate() {
135+
for &(node, port) in &old_outputs_outgoing[old_pos] {
136+
if node != inp {
137+
hugr.connect(node, port, out, out_pos);
138+
} else {
139+
let &inp_pos = old_output_to_new_input.get(&old_pos.into()).unwrap();
140+
hugr.connect(inp, inp_pos, out, out_pos);
141+
}
142+
}
135143
}
144+
145+
Ok(())
136146
}
137147
}
138-
139-
Ok(())
140-
}
148+
};
141149
}
142150

151+
impl_dataflow_parent_methods!(DataflowParentID);
152+
impl_dataflow_parent_methods!(DfgID);
153+
143154
/// Panics if the DFG within `node` is not a single inner DFG.
144155
fn update_inner_dfg_links<H: HugrMut>(hugr: &mut H, node: H::Node) {
145156
// connect all edges of the inner DFG to the input and output nodes
@@ -272,7 +283,7 @@ mod test {
272283
};
273284
use crate::extension::prelude::{bool_t, qb_t};
274285
use crate::hugr::views::root_checked::RootChecked;
275-
use crate::ops::handle::{DfgID, NodeHandle};
286+
use crate::ops::handle::NodeHandle;
276287
use crate::ops::{NamedOp, OpParent};
277288
use crate::types::Signature;
278289
use crate::utils::test_quantum_extension::cx_gate;
@@ -290,6 +301,51 @@ mod test {
290301
let sig = Signature::new_endo(vec![qb_t(), qb_t()]);
291302
let mut hugr = new_empty_dfg(sig);
292303

304+
// Wrap in RootChecked
305+
let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();
306+
307+
// Test mapping inputs: [0,1] -> [1,0]
308+
let input_map = vec![1, 0];
309+
let output_map = vec![0, 1];
310+
311+
// Map the I/O
312+
dfg_view.map_function_type(&input_map, &output_map).unwrap();
313+
314+
// Verify the new signature
315+
let dfg_hugr = dfg_view.hugr();
316+
let new_sig = dfg_hugr
317+
.get_optype(dfg_hugr.entrypoint())
318+
.dataflow_signature()
319+
.unwrap();
320+
assert_eq!(new_sig.input_count(), 2);
321+
assert_eq!(new_sig.output_count(), 2);
322+
323+
// Test invalid mapping - missing input
324+
let invalid_input_map = vec![0, 0];
325+
let err = dfg_view.map_function_type(&invalid_input_map, &output_map);
326+
assert!(matches!(err, Err(InvalidSignature::MissingIO(1, "input"))));
327+
328+
// Test invalid mapping - duplicate input
329+
let invalid_input_map = vec![0, 0, 1];
330+
assert!(matches!(
331+
dfg_view.map_function_type(&invalid_input_map, &output_map),
332+
Err(InvalidSignature::DuplicateInput(0))
333+
));
334+
335+
// Test invalid mapping - unknown output
336+
let invalid_output_map = vec![0, 2];
337+
assert!(matches!(
338+
dfg_view.map_function_type(&input_map, &invalid_output_map),
339+
Err(InvalidSignature::UnknownIO(2, "output"))
340+
));
341+
}
342+
343+
#[test]
344+
fn test_map_io_dfg_id() {
345+
// Create a DFG with 2 inputs and 2 outputs
346+
let sig = Signature::new_endo(vec![qb_t(), qb_t()]);
347+
let mut hugr = new_empty_dfg(sig);
348+
293349
// Wrap in RootChecked
294350
let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();
295351

@@ -337,7 +393,7 @@ mod test {
337393
let mut hugr = new_empty_dfg(sig);
338394

339395
// Wrap in RootChecked
340-
let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();
396+
let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();
341397

342398
// Test mapping outputs: [0] -> [0,0] (duplicating the output)
343399
let input_map = vec![0];
@@ -377,7 +433,7 @@ mod test {
377433
.unwrap();
378434

379435
// Wrap in RootChecked
380-
let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();
436+
let mut dfg_view = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();
381437

382438
// Test mapping inputs: [0,1] -> [1,0] (swapping inputs)
383439
let input_map = vec![1, 0];

0 commit comments

Comments
 (0)