Skip to content

Commit 3b92fdc

Browse files
committed
address comments
1 parent f20f10d commit 3b92fdc

File tree

5 files changed

+72
-66
lines changed

5 files changed

+72
-66
lines changed

hugr-core/src/hugr/patch/simple_replace.rs

Lines changed: 32 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -127,15 +127,15 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
127127
///
128128
/// The returned port will be in `replacement`, unless the wire in the
129129
/// replacement is empty and `return_invalid` is
130-
/// [`IncludeReplacementNodes::Valid`] (the default), in which case it
131-
/// will be another `host` port. If [`IncludeReplacementNodes::All`] is
130+
/// [`BoundaryMode::SnapToHost`] (the default), in which case it
131+
/// will be another `host` port. If [`BoundaryMode::IncludeIO`] is
132132
/// passed, the returned port will always be in `replacement` even if it
133133
/// is invalid (i.e. it is an IO node in the replacement).
134134
pub fn linked_replacement_output(
135135
&self,
136136
port: impl Into<HostPort<HostNode, IncomingPort>>,
137137
host: &impl HugrView<Node = HostNode>,
138-
return_invalid: IncludeReplacementNodes,
138+
return_invalid: BoundaryMode,
139139
) -> Option<BoundaryPort<HostNode, OutgoingPort>> {
140140
let HostPort(node, port) = port.into();
141141
let pos = self
@@ -155,7 +155,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
155155
&self,
156156
pos: usize,
157157
host: &impl HugrView<Node = HostNode>,
158-
return_invalid: IncludeReplacementNodes,
158+
return_invalid: BoundaryMode,
159159
) -> BoundaryPort<HostNode, OutgoingPort> {
160160
debug_assert!(pos < self.subgraph().signature(host).output_count());
161161

@@ -166,7 +166,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
166166
.single_linked_output(repl_out, pos)
167167
.expect("valid dfg wire");
168168

169-
if out_node != repl_inp || return_invalid == IncludeReplacementNodes::All {
169+
if out_node != repl_inp || return_invalid == BoundaryMode::IncludeIO {
170170
BoundaryPort::Replacement(out_node, out_port)
171171
} else {
172172
let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()]
@@ -214,16 +214,16 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
214214
///
215215
/// The returned ports will be in `replacement`, unless the wires in the
216216
/// replacement are empty and `return_invalid` is
217-
/// [`IncludeReplacementNodes::Valid`] (the default), in which case they
218-
/// will be other `host` ports. If [`IncludeReplacementNodes::All`] is
217+
/// [`BoundaryMode::SnapToHost`] (the default), in which case they
218+
/// will be other `host` ports. If [`BoundaryMode::IncludeIO`] is
219219
/// passed, the returned ports will always be in
220220
/// `replacement` even if they are invalid (i.e. they are an IO node in
221221
/// the replacement).
222222
pub fn linked_replacement_inputs<'a>(
223223
&'a self,
224224
port: impl Into<HostPort<HostNode, OutgoingPort>>,
225225
host: &'a impl HugrView<Node = HostNode>,
226-
return_invalid: IncludeReplacementNodes,
226+
return_invalid: BoundaryMode,
227227
) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> + 'a {
228228
let HostPort(node, port) = port.into();
229229
let positions = self
@@ -245,15 +245,15 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
245245
&self,
246246
pos: usize,
247247
host: &impl HugrView<Node = HostNode>,
248-
return_invalid: IncludeReplacementNodes,
248+
return_invalid: BoundaryMode,
249249
) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> {
250250
debug_assert!(pos < self.subgraph().signature(host).input_count());
251251

252252
let [repl_inp, repl_out] = self.get_replacement_io();
253253
self.replacement
254254
.linked_inputs(repl_inp, pos)
255255
.flat_map(move |(in_node, in_port)| {
256-
if in_node != repl_out || return_invalid == IncludeReplacementNodes::All {
256+
if in_node != repl_out || return_invalid == BoundaryMode::IncludeIO {
257257
Either::Left(std::iter::once(BoundaryPort::Replacement(in_node, in_port)))
258258
} else {
259259
let (out_node, out_port) = self.subgraph.outgoing_ports()[in_port.index()];
@@ -327,12 +327,8 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
327327
subgraph_outgoing_ports
328328
.enumerate()
329329
.flat_map(|(pos, subg_np)| {
330-
self.linked_replacement_inputs_by_position(
331-
pos,
332-
host,
333-
IncludeReplacementNodes::Valid,
334-
)
335-
.filter_map(move |np| Some((np.as_replacement()?, subg_np)))
330+
self.linked_replacement_inputs_by_position(pos, host, BoundaryMode::SnapToHost)
331+
.filter_map(move |np| Some((np.as_replacement()?, subg_np)))
336332
})
337333
.map(|((repl_node, repl_port), (subgraph_node, subgraph_port))| {
338334
(
@@ -374,11 +370,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
374370
.enumerate()
375371
.filter_map(|(pos, subg_all)| {
376372
let np = self
377-
.linked_replacement_output_by_position(
378-
pos,
379-
host,
380-
IncludeReplacementNodes::Valid,
381-
)
373+
.linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost)
382374
.as_replacement()?;
383375
Some((np, subg_all))
384376
})
@@ -425,12 +417,8 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
425417
.enumerate()
426418
.filter_map(|(pos, subg_all)| {
427419
Some((
428-
self.linked_replacement_output_by_position(
429-
pos,
430-
host,
431-
IncludeReplacementNodes::Valid,
432-
)
433-
.as_host()?,
420+
self.linked_replacement_output_by_position(pos, host, BoundaryMode::SnapToHost)
421+
.as_host()?,
434422
subg_all,
435423
))
436424
})
@@ -567,15 +555,19 @@ impl<HostNode: HugrNode> PatchVerification for SimpleReplacement<HostNode> {
567555
/// In [`SimpleReplacement`], some nodes in the replacement may not be valid
568556
/// after the replacement is applied.
569557
///
570-
/// This enum allows to filter out such nodes.
558+
/// This enum allows specifying whether these invalid nodes on the boundary
559+
/// should be returned or should be resolved to valid nodes in the host.
571560
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
572-
#[non_exhaustive]
573-
pub enum IncludeReplacementNodes {
561+
pub enum BoundaryMode {
574562
/// Only consider nodes that are valid after the replacement is applied.
563+
///
564+
/// This means that nodes in hosts may be returned in places where nodes in
565+
/// the replacement would be typically expected.
575566
#[default]
576-
Valid,
577-
/// Include all nodes, including potentially invalid ones.
578-
All,
567+
SnapToHost,
568+
/// Include all nodes, including potentially invalid ones (inputs and
569+
/// outputs of replacements).
570+
IncludeIO,
579571
}
580572

581573
/// Result of applying a [`SimpleReplacement`].
@@ -690,7 +682,7 @@ pub(in crate::hugr::patch) mod test {
690682
ModuleBuilder, endo_sig, inout_sig,
691683
};
692684
use crate::extension::prelude::{bool_t, qb_t};
693-
use crate::hugr::patch::simple_replace::{IncludeReplacementNodes, Outcome};
685+
use crate::hugr::patch::simple_replace::{BoundaryMode, Outcome};
694686
use crate::hugr::patch::{BoundaryPort, HostPort, PatchVerification, ReplacementPort};
695687
use crate::hugr::views::{HugrView, SiblingSubgraph};
696688
use crate::hugr::{Hugr, HugrMut, Patch};
@@ -1186,7 +1178,7 @@ pub(in crate::hugr::patch) mod test {
11861178
.linked_replacement_inputs(
11871179
(inp, OutgoingPort::from(0)),
11881180
&hugr,
1189-
IncludeReplacementNodes::Valid,
1181+
BoundaryMode::SnapToHost,
11901182
)
11911183
.collect();
11921184

@@ -1203,7 +1195,7 @@ pub(in crate::hugr::patch) mod test {
12031195
repl.linked_replacement_output(
12041196
(out, IncomingPort::from(i)),
12051197
&hugr,
1206-
IncludeReplacementNodes::Valid,
1198+
BoundaryMode::SnapToHost,
12071199
)
12081200
.unwrap()
12091201
})
@@ -1240,7 +1232,7 @@ pub(in crate::hugr::patch) mod test {
12401232
.linked_replacement_inputs(
12411233
(inp, OutgoingPort::from(0)),
12421234
&hugr,
1243-
IncludeReplacementNodes::Valid,
1235+
BoundaryMode::SnapToHost,
12441236
)
12451237
.collect();
12461238

@@ -1256,7 +1248,7 @@ pub(in crate::hugr::patch) mod test {
12561248
repl.linked_replacement_output(
12571249
(out, IncomingPort::from(i)),
12581250
&hugr,
1259-
IncludeReplacementNodes::Valid,
1251+
BoundaryMode::SnapToHost,
12601252
)
12611253
.unwrap()
12621254
})
@@ -1298,7 +1290,7 @@ pub(in crate::hugr::patch) mod test {
12981290
.linked_replacement_inputs(
12991291
(inp, OutgoingPort::from(0)),
13001292
&hugr,
1301-
IncludeReplacementNodes::Valid,
1293+
BoundaryMode::SnapToHost,
13021294
)
13031295
.collect();
13041296

@@ -1318,7 +1310,7 @@ pub(in crate::hugr::patch) mod test {
13181310
repl.linked_replacement_output(
13191311
(out, IncomingPort::from(i)),
13201312
&hugr,
1321-
IncludeReplacementNodes::Valid,
1313+
BoundaryMode::SnapToHost,
13221314
)
13231315
.unwrap()
13241316
})

hugr-persistent/src/persistent_hugr.rs

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -618,14 +618,13 @@ fn get_parent_commits<R>(
618618
replacement: &PersistentReplacement,
619619
graph: &CommitStateSpace<R>,
620620
) -> Result<Vec<Commit>, InvalidCommit> {
621-
let parent_ids = replacement.invalidation_set().map(|n| n.0).unique();
621+
let parent_ids = replacement.invalidation_set().map(|n| n.owner()).unique();
622622
parent_ids
623623
.map(|id| {
624-
if graph.contains_id(id) {
625-
Ok(graph.get_commit(id).clone())
626-
} else {
627-
Err(InvalidCommit::UnknownParent(id))
628-
}
624+
graph
625+
.try_get_commit(id)
626+
.cloned()
627+
.ok_or(InvalidCommit::UnknownParent(id))
629628
})
630629
.collect()
631630
}

hugr-persistent/src/state_space.rs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use hugr_core::{
1111
internal::HugrInternals,
1212
patch::{
1313
BoundaryPort,
14-
simple_replace::{IncludeReplacementNodes, InvalidReplacement},
14+
simple_replace::{BoundaryMode, InvalidReplacement},
1515
},
1616
views::{InvalidSignature, sibling_subgraph::InvalidSubgraph},
1717
},
@@ -253,6 +253,12 @@ impl<R> CommitStateSpace<R> {
253253
self.graph.get_node(commit_id).into()
254254
}
255255

256+
/// Check whether `commit_id` exists and return it.
257+
pub fn try_get_commit(&self, commit_id: CommitId) -> Option<&Commit> {
258+
self.contains_id(commit_id)
259+
.then(|| self.graph.get_node(commit_id).into())
260+
}
261+
256262
/// Get an iterator over all commit IDs in the state space.
257263
pub fn all_commit_ids(&self) -> impl Iterator<Item = CommitId> + Clone + '_ {
258264
let vec = self.graph.all_node_ids().collect_vec();
@@ -330,8 +336,10 @@ impl<R> CommitStateSpace<R> {
330336

331337
/// Get the boundary inputs linked to `(node, port)` in `child`.
332338
///
333-
/// The returned ports will be in the `child` commit unless the child commit
334-
/// is empty, in which case they will be in one of the parents of `child`.
339+
/// The returned ports will be ports on successors of the input node in the
340+
/// `child` commit, unless (node, port) is connected to an empty wire in
341+
/// `child` (i.e. a wire from input node to output node), in which case
342+
/// they will be in one of the parents of `child`.
335343
///
336344
/// `child` should be a child commit of the owner of `node`.
337345
///
@@ -344,7 +352,7 @@ impl<R> CommitStateSpace<R> {
344352
node: PatchNode,
345353
port: OutgoingPort,
346354
child: CommitId,
347-
return_invalid: IncludeReplacementNodes,
355+
return_invalid: BoundaryMode,
348356
) -> impl Iterator<Item = (PatchNode, IncomingPort)> + '_ {
349357
assert!(
350358
self.is_boundary_edge(node, port, child),
@@ -364,8 +372,10 @@ impl<R> CommitStateSpace<R> {
364372

365373
/// Get the single boundary output linked to `(node, port)` in `child`.
366374
///
367-
/// The returned port will be in the `child` commit unless the child commit
368-
/// is empty, in which case it will be in one of the parents of `child`.
375+
/// The returned port will be ports on predecessors of the output node in
376+
/// the `child` commit, unless (node, port) is connected to an empty wire
377+
/// in `child` (i.e. a wire from input node to output node), in which
378+
/// case it will be in one of the parents of `child`.
369379
///
370380
/// `child` should be a child commit of the owner of `node` (or `None` will
371381
/// be returned).
@@ -378,7 +388,7 @@ impl<R> CommitStateSpace<R> {
378388
node: PatchNode,
379389
port: IncomingPort,
380390
child: CommitId,
381-
return_invalid: IncludeReplacementNodes,
391+
return_invalid: BoundaryMode,
382392
) -> Option<(PatchNode, OutgoingPort)> {
383393
let parent_hugrs = ParentsView::from_commit(child, self);
384394
let repl = self.replacement(child)?;
@@ -400,7 +410,7 @@ impl<R> CommitStateSpace<R> {
400410
node: PatchNode,
401411
port: impl Into<Port>,
402412
child: CommitId,
403-
return_invalid: IncludeReplacementNodes,
413+
return_invalid: BoundaryMode,
404414
) -> impl Iterator<Item = (PatchNode, Port)> + '_ {
405415
match port.into().as_directed() {
406416
Either::Left(incoming) => Either::Left(

hugr-persistent/src/walker.rs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ use thiserror::Error;
6666
use hugr_core::{
6767
Direction, Hugr, HugrView, Port, PortIndex,
6868
hugr::{
69-
patch::simple_replace::IncludeReplacementNodes,
69+
patch::simple_replace::BoundaryMode,
7070
views::{RootCheckable, SiblingSubgraph},
7171
},
7272
ops::handle::DfgID,
@@ -303,11 +303,12 @@ impl<'a, R: Resolver> Walker<'a, R> {
303303
wire_ports_incoming.extend(w.all_incoming_ports(self.as_hugr_view()));
304304
wire_ports_outgoing.extend(w.single_outgoing_port(self.as_hugr_view()));
305305
for id in w.owners() {
306-
if self.state_space.contains_id(id) {
307-
additional_parents.insert(id, self.state_space.get_commit(id).clone());
308-
} else {
309-
return Err(InvalidCommit::UnknownParent(id));
310-
}
306+
let commit = self
307+
.state_space
308+
.try_get_commit(id)
309+
.ok_or(InvalidCommit::UnknownParent(id))?
310+
.clone();
311+
additional_parents.insert(id, commit);
311312
}
312313
}
313314

@@ -421,7 +422,7 @@ impl<R: Clone> Walker<'_, R> {
421422
opp_node,
422423
opp_port,
423424
child_id,
424-
IncludeReplacementNodes::Valid,
425+
BoundaryMode::SnapToHost,
425426
) {
426427
let mut empty_commits = empty_commits.clone();
427428
if node.0 != child_id {
@@ -836,7 +837,5 @@ mod tests {
836837
BTreeSet::from_iter([base_commit, empty_commit])
837838
])
838839
);
839-
840-
panic!("explicit")
841840
}
842841
}

hugr-persistent/src/wire.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ use std::collections::{BTreeSet, VecDeque};
22

33
use hugr_core::{
44
Direction, HugrView, IncomingPort, OutgoingPort, Port, Wire,
5-
hugr::patch::simple_replace::IncludeReplacementNodes,
5+
hugr::patch::simple_replace::BoundaryMode,
66
};
7-
use itertools::{Either, Itertools};
7+
use itertools::Itertools;
88

99
use crate::{CommitId, PatchNode, PersistentHugr, Resolver, Walker};
1010

@@ -52,6 +52,12 @@ impl CommitWire {
5252
fn commit_id(&self) -> CommitId {
5353
self.0.node().0
5454
}
55+
56+
delegate::delegate! {
57+
to self.0 {
58+
fn node(&self) -> PatchNode;
59+
}
60+
}
5561
}
5662

5763
/// A node in a commit of a [`PersistentHugr`] is either a valid node of the
@@ -120,7 +126,7 @@ impl PersistentWire {
120126
opp_node,
121127
opp_port,
122128
deleted_by,
123-
IncludeReplacementNodes::All,
129+
BoundaryMode::IncludeIO,
124130
)
125131
{
126132
debug_assert_eq!(child_node.owner(), deleted_by);

0 commit comments

Comments
 (0)