Skip to content

Commit f20f10d

Browse files
committed
Add IncludeReplacementNodes enum
1 parent c37fe20 commit f20f10d

File tree

7 files changed

+730
-200
lines changed

7 files changed

+730
-200
lines changed

hugr-core/src/hugr/patch/simple_replace.rs

Lines changed: 88 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,16 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
126126
/// of `self`.
127127
///
128128
/// The returned port will be in `replacement`, unless the wire in the
129-
/// replacement is empty, in which case it will another `host` port.
129+
/// replacement is empty and `return_invalid` is
130+
/// [`IncludeReplacementNodes::Valid`] (the default), in which case it
131+
/// will be another `host` port. If [`IncludeReplacementNodes::All`] is
132+
/// passed, the returned port will always be in `replacement` even if it
133+
/// is invalid (i.e. it is an IO node in the replacement).
130134
pub fn linked_replacement_output(
131135
&self,
132136
port: impl Into<HostPort<HostNode, IncomingPort>>,
133137
host: &impl HugrView<Node = HostNode>,
138+
return_invalid: IncludeReplacementNodes,
134139
) -> Option<BoundaryPort<HostNode, OutgoingPort>> {
135140
let HostPort(node, port) = port.into();
136141
let pos = self
@@ -139,7 +144,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
139144
.iter()
140145
.position(move |&(n, p)| host.linked_inputs(n, p).contains(&(node, port)))?;
141146

142-
Some(self.linked_replacement_output_by_position(pos, host))
147+
Some(self.linked_replacement_output_by_position(pos, host, return_invalid))
143148
}
144149

145150
/// The outgoing port linked to the i-th output boundary edge of `subgraph`.
@@ -150,6 +155,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
150155
&self,
151156
pos: usize,
152157
host: &impl HugrView<Node = HostNode>,
158+
return_invalid: IncludeReplacementNodes,
153159
) -> BoundaryPort<HostNode, OutgoingPort> {
154160
debug_assert!(pos < self.subgraph().signature(host).output_count());
155161

@@ -160,7 +166,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
160166
.single_linked_output(repl_out, pos)
161167
.expect("valid dfg wire");
162168

163-
if out_node != repl_inp {
169+
if out_node != repl_inp || return_invalid == IncludeReplacementNodes::All {
164170
BoundaryPort::Replacement(out_node, out_port)
165171
} else {
166172
let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()]
@@ -207,11 +213,17 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
207213
/// of `self`.
208214
///
209215
/// The returned ports will be in `replacement`, unless the wires in the
210-
/// replacement are empty, in which case they are other `host` ports.
216+
/// replacement are empty and `return_invalid` is
217+
/// [`IncludeReplacementNodes::Valid`] (the default), in which case they
218+
/// will be other `host` ports. If [`IncludeReplacementNodes::All`] is
219+
/// passed, the returned ports will always be in
220+
/// `replacement` even if they are invalid (i.e. they are an IO node in
221+
/// the replacement).
211222
pub fn linked_replacement_inputs<'a>(
212223
&'a self,
213224
port: impl Into<HostPort<HostNode, OutgoingPort>>,
214225
host: &'a impl HugrView<Node = HostNode>,
226+
return_invalid: IncludeReplacementNodes,
215227
) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> + 'a {
216228
let HostPort(node, port) = port.into();
217229
let positions = self
@@ -223,26 +235,25 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
223235
host.single_linked_output(n, p).expect("valid dfg wire") == (node, port)
224236
});
225237

226-
positions.flat_map(|pos| self.linked_replacement_inputs_by_position(pos, host))
238+
positions.flat_map(move |pos| {
239+
self.linked_replacement_inputs_by_position(pos, host, return_invalid)
240+
})
227241
}
228242

229243
/// The incoming ports linked to the i-th input boundary edge of `subgraph`.
230-
///
231-
/// The ports will be in `replacement` for all endpoints of the i-th input
232-
/// wire that are not the output node of `replacement` and be in `host`
233-
/// otherwise.
234244
fn linked_replacement_inputs_by_position(
235245
&self,
236246
pos: usize,
237247
host: &impl HugrView<Node = HostNode>,
248+
return_invalid: IncludeReplacementNodes,
238249
) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> {
239250
debug_assert!(pos < self.subgraph().signature(host).input_count());
240251

241252
let [repl_inp, repl_out] = self.get_replacement_io();
242253
self.replacement
243254
.linked_inputs(repl_inp, pos)
244255
.flat_map(move |(in_node, in_port)| {
245-
if in_node != repl_out {
256+
if in_node != repl_out || return_invalid == IncludeReplacementNodes::All {
246257
Either::Left(std::iter::once(BoundaryPort::Replacement(in_node, in_port)))
247258
} else {
248259
let (out_node, out_port) = self.subgraph.outgoing_ports()[in_port.index()];
@@ -316,8 +327,12 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
316327
subgraph_outgoing_ports
317328
.enumerate()
318329
.flat_map(|(pos, subg_np)| {
319-
self.linked_replacement_inputs_by_position(pos, host)
320-
.filter_map(move |np| Some((np.as_replacement()?, subg_np)))
330+
self.linked_replacement_inputs_by_position(
331+
pos,
332+
host,
333+
IncludeReplacementNodes::Valid,
334+
)
335+
.filter_map(move |np| Some((np.as_replacement()?, subg_np)))
321336
})
322337
.map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| {
323338
(
@@ -359,7 +374,11 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
359374
.enumerate()
360375
.filter_map(|(pos, subg_all)| {
361376
let np = self
362-
.linked_replacement_output_by_position(pos, host)
377+
.linked_replacement_output_by_position(
378+
pos,
379+
host,
380+
IncludeReplacementNodes::Valid,
381+
)
363382
.as_replacement()?;
364383
Some((np, subg_all))
365384
})
@@ -406,8 +425,12 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
406425
.enumerate()
407426
.filter_map(|(pos, subg_all)| {
408427
Some((
409-
self.linked_replacement_output_by_position(pos, host)
410-
.as_host()?,
428+
self.linked_replacement_output_by_position(
429+
pos,
430+
host,
431+
IncludeReplacementNodes::Valid,
432+
)
433+
.as_host()?,
411434
subg_all,
412435
))
413436
})
@@ -517,7 +540,8 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
517540
SimpleReplacement::try_new(subgraph, new_host, replacement.clone())
518541
}
519542

520-
/// Allows to get the [Self::invalidated_nodes] without requiring a [HugrView].
543+
/// Allows to get the [Self::invalidated_nodes] without requiring a
544+
/// [HugrView].
521545
pub fn invalidation_set(&self) -> impl Iterator<Item = HostNode> {
522546
self.subgraph.nodes().iter().copied()
523547
}
@@ -540,6 +564,20 @@ impl<HostNode: HugrNode> PatchVerification for SimpleReplacement<HostNode> {
540564
}
541565
}
542566

567+
/// In [`SimpleReplacement`], some nodes in the replacement may not be valid
568+
/// after the replacement is applied.
569+
///
570+
/// This enum allows to filter out such nodes.
571+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
572+
#[non_exhaustive]
573+
pub enum IncludeReplacementNodes {
574+
/// Only consider nodes that are valid after the replacement is applied.
575+
#[default]
576+
Valid,
577+
/// Include all nodes, including potentially invalid ones.
578+
All,
579+
}
580+
543581
/// Result of applying a [`SimpleReplacement`].
544582
pub struct Outcome<HostNode = Node> {
545583
/// Map from Node in replacement to corresponding Node in the result Hugr
@@ -652,7 +690,7 @@ pub(in crate::hugr::patch) mod test {
652690
ModuleBuilder, endo_sig, inout_sig,
653691
};
654692
use crate::extension::prelude::{bool_t, qb_t};
655-
use crate::hugr::patch::simple_replace::Outcome;
693+
use crate::hugr::patch::simple_replace::{IncludeReplacementNodes, Outcome};
656694
use crate::hugr::patch::{BoundaryPort, HostPort, PatchVerification, ReplacementPort};
657695
use crate::hugr::views::{HugrView, SiblingSubgraph};
658696
use crate::hugr::{Hugr, HugrMut, Patch};
@@ -1145,7 +1183,11 @@ pub(in crate::hugr::patch) mod test {
11451183

11461184
// Test linked_replacement_inputs with empty replacement
11471185
let replacement_inputs: Vec<_> = repl
1148-
.linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr)
1186+
.linked_replacement_inputs(
1187+
(inp, OutgoingPort::from(0)),
1188+
&hugr,
1189+
IncludeReplacementNodes::Valid,
1190+
)
11491191
.collect();
11501192

11511193
assert_eq!(
@@ -1158,8 +1200,12 @@ pub(in crate::hugr::patch) mod test {
11581200
// Test linked_replacement_output with empty replacement
11591201
let replacement_output = (0..4)
11601202
.map(|i| {
1161-
repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr)
1162-
.unwrap()
1203+
repl.linked_replacement_output(
1204+
(out, IncomingPort::from(i)),
1205+
&hugr,
1206+
IncludeReplacementNodes::Valid,
1207+
)
1208+
.unwrap()
11631209
})
11641210
.collect_vec();
11651211

@@ -1191,7 +1237,11 @@ pub(in crate::hugr::patch) mod test {
11911237
};
11921238

11931239
let replacement_inputs: Vec<_> = repl
1194-
.linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr)
1240+
.linked_replacement_inputs(
1241+
(inp, OutgoingPort::from(0)),
1242+
&hugr,
1243+
IncludeReplacementNodes::Valid,
1244+
)
11951245
.collect();
11961246

11971247
assert_eq!(
@@ -1203,8 +1253,12 @@ pub(in crate::hugr::patch) mod test {
12031253

12041254
let replacement_output = (0..4)
12051255
.map(|i| {
1206-
repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr)
1207-
.unwrap()
1256+
repl.linked_replacement_output(
1257+
(out, IncomingPort::from(i)),
1258+
&hugr,
1259+
IncludeReplacementNodes::Valid,
1260+
)
1261+
.unwrap()
12081262
})
12091263
.collect_vec();
12101264

@@ -1241,7 +1295,11 @@ pub(in crate::hugr::patch) mod test {
12411295
};
12421296

12431297
let replacement_inputs: Vec<_> = repl
1244-
.linked_replacement_inputs((inp, OutgoingPort::from(0)), &hugr)
1298+
.linked_replacement_inputs(
1299+
(inp, OutgoingPort::from(0)),
1300+
&hugr,
1301+
IncludeReplacementNodes::Valid,
1302+
)
12451303
.collect();
12461304

12471305
assert_eq!(
@@ -1257,8 +1315,12 @@ pub(in crate::hugr::patch) mod test {
12571315

12581316
let replacement_output = (0..4)
12591317
.map(|i| {
1260-
repl.linked_replacement_output((out, IncomingPort::from(i)), &hugr)
1261-
.unwrap()
1318+
repl.linked_replacement_output(
1319+
(out, IncomingPort::from(i)),
1320+
&hugr,
1321+
IncludeReplacementNodes::Valid,
1322+
)
1323+
.unwrap()
12621324
})
12631325
.collect_vec();
12641326

hugr-persistent/src/persistent_hugr.rs

Lines changed: 74 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ impl Commit {
3838
/// Requires a reference to the commit state space that the nodes in
3939
/// `replacement` refer to.
4040
///
41+
/// Use [`Self::try_new`] instead if the parents of the commit cannot be
42+
/// inferred from the invalidation set of `replacement` alone.
43+
///
4144
/// The replacement must act on a non-empty subgraph, otherwise this
4245
/// function will return an [`InvalidCommit::EmptyReplacement`] error.
4346
///
@@ -47,20 +50,37 @@ impl Commit {
4750
pub fn try_from_replacement<R>(
4851
replacement: PersistentReplacement,
4952
graph: &CommitStateSpace<R>,
53+
) -> Result<Commit, InvalidCommit> {
54+
Self::try_new(replacement, [], graph)
55+
}
56+
57+
/// Create a new commit
58+
///
59+
/// Requires a reference to the commit state space that the nodes in
60+
/// `replacement` refer to.
61+
///
62+
/// The returned commit will correspond to the application of `replacement`
63+
/// and will be the child of the commits in `parents` as well as of all
64+
/// the commits in the invalidation set of `replacement`.
65+
///
66+
/// The replacement must act on a non-empty subgraph, otherwise this
67+
/// function will return an [`InvalidCommit::EmptyReplacement`] error.
68+
/// If any of the parents of the replacement are not in the commit state
69+
/// space, this function will return an [`InvalidCommit::UnknownParent`]
70+
/// error.
71+
pub fn try_new<R>(
72+
replacement: PersistentReplacement,
73+
parents: impl IntoIterator<Item = Commit>,
74+
graph: &CommitStateSpace<R>,
5075
) -> Result<Commit, InvalidCommit> {
5176
if replacement.subgraph().nodes().is_empty() {
5277
return Err(InvalidCommit::EmptyReplacement);
5378
}
54-
let parent_ids = replacement.invalidation_set().map(|n| n.0).unique();
55-
let parents = parent_ids
56-
.map(|id| {
57-
if graph.contains_id(id) {
58-
Ok(graph.get_commit(id).clone())
59-
} else {
60-
Err(InvalidCommit::UnknownParent(id))
61-
}
62-
})
63-
.collect::<Result<Vec<_>, _>>()?;
79+
let repl_parents = get_parent_commits(&replacement, graph)?;
80+
let parents = parents
81+
.into_iter()
82+
.chain(repl_parents)
83+
.unique_by(|p| p.as_ptr());
6484
let rc = RelRc::with_parents(
6585
replacement.into(),
6686
parents.into_iter().map(|p| (p.into(), ())),
@@ -529,6 +549,32 @@ impl<R> PersistentHugr<R> {
529549
.expect("invalid port")
530550
.is_value()
531551
}
552+
553+
pub(super) fn value_ports(
554+
&self,
555+
patch_node @ PatchNode(commit_id, node): PatchNode,
556+
dir: Direction,
557+
) -> impl Iterator<Item = (PatchNode, Port)> + '_ {
558+
let hugr = self.commit_hugr(commit_id);
559+
let ports = hugr.node_ports(node, dir);
560+
ports.filter_map(move |p| self.is_value_port(patch_node, p).then_some((patch_node, p)))
561+
}
562+
563+
pub(super) fn output_value_ports(
564+
&self,
565+
patch_node: PatchNode,
566+
) -> impl Iterator<Item = (PatchNode, OutgoingPort)> + '_ {
567+
self.value_ports(patch_node, Direction::Outgoing)
568+
.map(|(n, p)| (n, p.as_outgoing().expect("unexpected port direction")))
569+
}
570+
571+
pub(super) fn input_value_ports(
572+
&self,
573+
patch_node: PatchNode,
574+
) -> impl Iterator<Item = (PatchNode, IncomingPort)> + '_ {
575+
self.value_ports(patch_node, Direction::Incoming)
576+
.map(|(n, p)| (n, p.as_incoming().expect("unexpected port direction")))
577+
}
532578
}
533579

534580
impl<R> IntoIterator for PersistentHugr<R> {
@@ -549,11 +595,11 @@ impl<R> IntoIterator for PersistentHugr<R> {
549595
/// among `children`.
550596
pub(crate) fn find_conflicting_node<'a>(
551597
commit_id: CommitId,
552-
mut children: impl Iterator<Item = &'a Commit>,
598+
children: impl IntoIterator<Item = &'a Commit>,
553599
) -> Option<Node> {
554600
let mut all_invalidated = BTreeSet::new();
555601

556-
children.find_map(|child| {
602+
children.into_iter().find_map(|child| {
557603
let mut new_invalidated =
558604
child
559605
.invalidation_set()
@@ -567,3 +613,19 @@ pub(crate) fn find_conflicting_node<'a>(
567613
new_invalidated.find(|&n| !all_invalidated.insert(n))
568614
})
569615
}
616+
617+
fn get_parent_commits<R>(
618+
replacement: &PersistentReplacement,
619+
graph: &CommitStateSpace<R>,
620+
) -> Result<Vec<Commit>, InvalidCommit> {
621+
let parent_ids = replacement.invalidation_set().map(|n| n.0).unique();
622+
parent_ids
623+
.map(|id| {
624+
if graph.contains_id(id) {
625+
Ok(graph.get_commit(id).clone())
626+
} else {
627+
Err(InvalidCommit::UnknownParent(id))
628+
}
629+
})
630+
.collect()
631+
}

0 commit comments

Comments
 (0)