Skip to content

Commit a31ccbc

Browse files
zrhoaborgna-q
andauthored
fix: Order hints on input and output nodes. (#2422)
Allows order hints involving `core` input and output nodes. Closes #2399. --------- Co-authored-by: Agustín Borgna <121866228+aborgna-q@users.noreply.github.com>
1 parent b4b9433 commit a31ccbc

File tree

6 files changed

+157
-85
lines changed

6 files changed

+157
-85
lines changed

hugr-core/src/export.rs

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -628,41 +628,47 @@ impl<'a> Context<'a> {
628628
let children = self.hugr.children(node);
629629
let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 - 2, self.bump);
630630

631-
let mut output_node = None;
632-
633631
for child in children {
634632
match self.hugr.get_optype(child) {
635633
OpType::Input(input) => {
636634
sources = self.make_ports(child, Direction::Outgoing, input.types.len());
637635
input_types = Some(&input.types);
636+
637+
if has_order_edges(self.hugr, child) {
638+
let key = self.make_term(model::Literal::Nat(child.index() as u64).into());
639+
meta.push(self.make_term_apply(model::ORDER_HINT_INPUT_KEY, &[key]));
640+
}
638641
}
639642
OpType::Output(output) => {
640643
targets = self.make_ports(child, Direction::Incoming, output.types.len());
641644
output_types = Some(&output.types);
642-
output_node = Some(child);
645+
646+
if has_order_edges(self.hugr, child) {
647+
let key = self.make_term(model::Literal::Nat(child.index() as u64).into());
648+
meta.push(self.make_term_apply(model::ORDER_HINT_OUTPUT_KEY, &[key]));
649+
}
643650
}
644-
child_optype => {
651+
_ => {
645652
if let Some(child_id) = self.export_node_shallow(child) {
646653
region_children.push(child_id);
647-
648-
// Record all order edges that originate from this node in metadata.
649-
let successors = child_optype
650-
.other_output_port()
651-
.into_iter()
652-
.flat_map(|port| self.hugr.linked_inputs(child, port))
653-
.map(|(successor, _)| successor)
654-
.filter(|successor| Some(*successor) != output_node);
655-
656-
for successor in successors {
657-
let a =
658-
self.make_term(model::Literal::Nat(child.index() as u64).into());
659-
let b = self
660-
.make_term(model::Literal::Nat(successor.index() as u64).into());
661-
meta.push(self.make_term_apply(model::ORDER_HINT_ORDER, &[a, b]));
662-
}
663654
}
664655
}
665656
}
657+
658+
// Record all order edges that originate from this node in metadata.
659+
let successors = self
660+
.hugr
661+
.get_optype(child)
662+
.other_output_port()
663+
.into_iter()
664+
.flat_map(|port| self.hugr.linked_inputs(child, port))
665+
.map(|(successor, _)| successor);
666+
667+
for successor in successors {
668+
let a = self.make_term(model::Literal::Nat(child.index() as u64).into());
669+
let b = self.make_term(model::Literal::Nat(successor.index() as u64).into());
670+
meta.push(self.make_term_apply(model::ORDER_HINT_ORDER, &[a, b]));
671+
}
666672
}
667673

668674
for child_id in &region_children {
@@ -1103,21 +1109,7 @@ impl<'a> Context<'a> {
11031109
}
11041110

11051111
fn export_node_order_metadata(&mut self, node: Node, meta: &mut Vec<table::TermId>) {
1106-
fn is_relevant_node(hugr: &Hugr, node: Node) -> bool {
1107-
let optype = hugr.get_optype(node);
1108-
!optype.is_input() && !optype.is_output()
1109-
}
1110-
1111-
let optype = self.hugr.get_optype(node);
1112-
1113-
let has_order_edges = Direction::BOTH
1114-
.iter()
1115-
.filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder))
1116-
.filter_map(|dir| optype.other_port(*dir))
1117-
.flat_map(|port| self.hugr.linked_ports(node, port))
1118-
.any(|(other, _)| is_relevant_node(self.hugr, other));
1119-
1120-
if has_order_edges {
1112+
if has_order_edges(self.hugr, node) {
11211113
let key = self.make_term(model::Literal::Nat(node.index() as u64).into());
11221114
meta.push(self.make_term_apply(model::ORDER_HINT_KEY, &[key]));
11231115
}
@@ -1232,6 +1224,18 @@ impl Links {
12321224
}
12331225
}
12341226

1227+
/// Returns `true` if a node has any incident order edges.
1228+
fn has_order_edges(hugr: &Hugr, node: Node) -> bool {
1229+
let optype = hugr.get_optype(node);
1230+
Direction::BOTH
1231+
.iter()
1232+
.filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder))
1233+
.filter_map(|dir| optype.other_port(*dir))
1234+
.flat_map(|port| hugr.linked_ports(node, port))
1235+
.next()
1236+
.is_some()
1237+
}
1238+
12351239
#[cfg(test)]
12361240
mod test {
12371241
use rstest::{fixture, rstest};

hugr-core/src/import.rs

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ impl From<ExtensionError> for ImportError {
114114
enum OrderHintError {
115115
/// Duplicate order hint key in the same region.
116116
#[error("duplicate order hint key {0}")]
117-
DuplicateKey(table::NodeId, u64),
117+
DuplicateKey(table::RegionId, u64),
118118
/// Order hint including a key not defined in the region.
119119
#[error("order hint with unknown key {0}")]
120120
UnknownKey(u64),
@@ -608,7 +608,7 @@ impl<'a> Context<'a> {
608608
self.import_node(*child, node)?;
609609
}
610610

611-
self.create_order_edges(region)?;
611+
self.create_order_edges(region, input, output)?;
612612

613613
for meta_item in region_data.meta {
614614
self.import_node_metadata(node, *meta_item)?;
@@ -622,13 +622,18 @@ impl<'a> Context<'a> {
622622
/// Create order edges between nodes of a dataflow region based on order hint metadata.
623623
///
624624
/// This method assumes that the nodes for the children of the region have already been imported.
625-
fn create_order_edges(&mut self, region_id: table::RegionId) -> Result<(), ImportError> {
625+
fn create_order_edges(
626+
&mut self,
627+
region_id: table::RegionId,
628+
input: Node,
629+
output: Node,
630+
) -> Result<(), ImportError> {
626631
let region_data = self.get_region(region_id)?;
627632
debug_assert_eq!(region_data.kind, model::RegionKind::DataFlow);
628633

629634
// Collect order hint keys
630635
// PERFORMANCE: It might be worthwhile to reuse the map to avoid allocations.
631-
let mut order_keys = FxHashMap::<u64, table::NodeId>::default();
636+
let mut order_keys = FxHashMap::<u64, Node>::default();
632637

633638
for child_id in region_data.children {
634639
let child_data = self.get_node(*child_id)?;
@@ -642,8 +647,42 @@ impl<'a> Context<'a> {
642647
continue;
643648
};
644649

645-
if order_keys.insert(*key, *child_id).is_some() {
646-
return Err(OrderHintError::DuplicateKey(*child_id, *key).into());
650+
// NOTE: The lookups here are expected to succeed since we only
651+
// process the order metadata after we have imported the nodes.
652+
let child_node = self.nodes[child_id];
653+
let child_optype = self.hugr.get_optype(child_node);
654+
655+
// Check that the node has order ports.
656+
// NOTE: This assumes that a node has an input order port iff it has an output one.
657+
if child_optype.other_output_port().is_none() {
658+
return Err(OrderHintError::NoOrderPort(*child_id).into());
659+
}
660+
661+
if order_keys.insert(*key, child_node).is_some() {
662+
return Err(OrderHintError::DuplicateKey(region_id, *key).into());
663+
}
664+
}
665+
}
666+
667+
// Collect the order hint keys for the input and output nodes
668+
for meta_id in region_data.meta {
669+
if let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_INPUT_KEY)? {
670+
let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else {
671+
continue;
672+
};
673+
674+
if order_keys.insert(*key, input).is_some() {
675+
return Err(OrderHintError::DuplicateKey(region_id, *key).into());
676+
}
677+
}
678+
679+
if let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_OUTPUT_KEY)? {
680+
let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else {
681+
continue;
682+
};
683+
684+
if order_keys.insert(*key, output).is_some() {
685+
return Err(OrderHintError::DuplicateKey(region_id, *key).into());
647686
}
648687
}
649688
}
@@ -665,24 +704,13 @@ impl<'a> Context<'a> {
665704
let a = order_keys.get(a).ok_or(OrderHintError::UnknownKey(*a))?;
666705
let b = order_keys.get(b).ok_or(OrderHintError::UnknownKey(*b))?;
667706

668-
// NOTE: The lookups here are expected to succeed since we only
669-
// process the order metadata after we have imported the nodes.
670-
let a_node = self.nodes[a];
671-
let b_node = self.nodes[b];
672-
673-
let a_port = self
674-
.hugr
675-
.get_optype(a_node)
676-
.other_output_port()
677-
.ok_or(OrderHintError::NoOrderPort(*a))?;
678-
679-
let b_port = self
680-
.hugr
681-
.get_optype(b_node)
682-
.other_input_port()
683-
.ok_or(OrderHintError::NoOrderPort(*b))?;
707+
// NOTE: The unwrap here must succeed:
708+
// - For all ordinary nodes we checked that they have an order port.
709+
// - Input and output nodes always have an order port.
710+
let a_port = self.hugr.get_optype(*a).other_output_port().unwrap();
711+
let b_port = self.hugr.get_optype(*b).other_input_port().unwrap();
684712

685-
self.hugr.connect(a_node, a_port, b_node, b_port);
713+
self.hugr.connect(*a, a_port, *b, b_port);
686714
}
687715

688716
Ok(())

hugr-core/tests/snapshots/model__roundtrip_order.snap

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,19 @@ expression: ast
88

99
(import core.meta.description)
1010

11+
(import core.order_hint.input_key)
12+
13+
(import core.order_hint.order)
14+
15+
(import arithmetic.int.types.int)
16+
1117
(import core.nat)
1218

1319
(import core.order_hint.key)
1420

15-
(import core.fn)
16-
17-
(import core.order_hint.order)
21+
(import core.order_hint.output_key)
1822

19-
(import arithmetic.int.types.int)
23+
(import core.fn)
2024

2125
(declare-operation
2226
arithmetic.int.ineg
@@ -49,9 +53,14 @@ expression: ast
4953
(arithmetic.int.types.int 6)
5054
(arithmetic.int.types.int 6)
5155
(arithmetic.int.types.int 6)]))
56+
(meta (core.order_hint.input_key 2))
57+
(meta (core.order_hint.order 2 4))
58+
(meta (core.order_hint.order 2 3))
59+
(meta (core.order_hint.output_key 3))
5260
(meta (core.order_hint.order 4 7))
5361
(meta (core.order_hint.order 5 6))
5462
(meta (core.order_hint.order 5 4))
63+
(meta (core.order_hint.order 5 3))
5564
(meta (core.order_hint.order 6 7))
5665
((arithmetic.int.ineg 6) [%0] [%4]
5766
(signature

hugr-model/src/v0/mod.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,26 @@ pub const COMPAT_CONST_JSON: &str = "compat.const_json";
287287
/// - **Result:** `core.meta`
288288
pub const ORDER_HINT_KEY: &str = "core.order_hint.key";
289289

290+
/// Metadata constructor for order hint keys on input nodes.
291+
///
292+
/// When the sources of a dataflow region are represented by an input operation
293+
/// within the region, this metadata can be attached the region to give the
294+
/// input node an order hint key.
295+
///
296+
/// - **Parameter:** `?key : core.nat`
297+
/// - **Result:** `core.meta`
298+
pub const ORDER_HINT_INPUT_KEY: &str = "core.order_hint.input_key";
299+
300+
/// Metadata constructor for order hint keys on output nodes.
301+
///
302+
/// When the targets of a dataflow region are represented by an output operation
303+
/// within the region, this metadata can be attached the region to give the
304+
/// output node an order hint key.
305+
///
306+
/// - **Parameter:** `?key : core.nat`
307+
/// - **Result:** `core.meta`
308+
pub const ORDER_HINT_OUTPUT_KEY: &str = "core.order_hint.output_key";
309+
290310
/// Metadata constructor for order hints.
291311
///
292312
/// When this metadata is attached to a dataflow region, it can indicate a

hugr-model/tests/fixtures/model-order.edn

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
(meta (core.order_hint.order 1 0))
2929
(meta (core.order_hint.order 2 3))
3030
(meta (core.order_hint.order 0 3))
31+
(meta (core.order_hint.input_key 4))
32+
(meta (core.order_hint.order 4 0))
33+
(meta (core.order_hint.order 4 5))
34+
(meta (core.order_hint.order 1 5))
35+
(meta (core.order_hint.output_key 5))
3136

3237
((arithmetic.int.ineg 6)
3338
[%0] [%4]

hugr-py/src/hugr/model/export.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def export_node(
7979
meta = self.export_json_meta(node)
8080

8181
# Add an order hint key to the node if necessary
82-
if _needs_order_key(self.hugr, node):
82+
if _has_order_links(self.hugr, node):
8383
meta.append(model.Apply("core.order_hint.key", [model.Literal(node.idx)]))
8484

8585
match node_data.op:
@@ -427,13 +427,27 @@ def export_region_dfg(self, node: Node) -> model.Region:
427427
for i in range(child_data._num_outs)
428428
]
429429

430+
if _has_order_links(self.hugr, child):
431+
meta.append(
432+
model.Apply(
433+
"core.order_hint.input_key", [model.Literal(child.idx)]
434+
)
435+
)
436+
430437
case Output() as op:
431438
target_types = model.List([type.to_model() for type in op.types])
432439
targets = [
433440
self.link_name(InPort(child, i))
434441
for i in range(child_data._num_inps)
435442
]
436443

444+
if _has_order_links(self.hugr, child):
445+
meta.append(
446+
model.Apply(
447+
"core.order_hint.output_key", [model.Literal(child.idx)]
448+
)
449+
)
450+
437451
case _:
438452
child_node = self.export_node(child)
439453

@@ -442,14 +456,13 @@ def export_region_dfg(self, node: Node) -> model.Region:
442456

443457
children.append(child_node)
444458

445-
meta += [
446-
model.Apply(
447-
"core.order_hint.order",
448-
[model.Literal(child.idx), model.Literal(successor.idx)],
449-
)
450-
for successor in self.hugr.outgoing_order_links(child)
451-
if not isinstance(self.hugr[successor].op, Output)
452-
]
459+
meta += [
460+
model.Apply(
461+
"core.order_hint.order",
462+
[model.Literal(child.idx), model.Literal(successor.idx)],
463+
)
464+
for successor in self.hugr.outgoing_order_links(child)
465+
]
453466

454467
signature = model.Apply("core.fn", [source_types, target_types])
455468

@@ -639,19 +652,12 @@ def union(self, a: T, b: T):
639652
self.sizes[a] += self.sizes[b]
640653

641654

642-
def _needs_order_key(hugr: Hugr, node: Node) -> bool:
643-
"""Checks whether the node has any order links for the purposes of
644-
exporting order hint metadata. Order links to `Input` or `Output`
645-
operations are ignored, since they are not present in the model format.
646-
"""
647-
for succ in hugr.outgoing_order_links(node):
648-
succ_op = hugr[succ].op
649-
if not isinstance(succ_op, Output):
650-
return True
651-
652-
for pred in hugr.incoming_order_links(node):
653-
pred_op = hugr[pred].op
654-
if not isinstance(pred_op, Input):
655-
return True
655+
def _has_order_links(hugr: Hugr, node: Node) -> bool:
656+
"""Checks whether the node has any order links."""
657+
for _succ in hugr.outgoing_order_links(node):
658+
return True
659+
660+
for _pred in hugr.incoming_order_links(node):
661+
return True
656662

657663
return False

0 commit comments

Comments
 (0)