Skip to content

Commit e13ab90

Browse files
committed
address comments
1 parent c20176b commit e13ab90

File tree

5 files changed

+70
-69
lines changed

5 files changed

+70
-69
lines changed

hugr-core/src/hugr/patch/simple_replace.rs

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -126,16 +126,16 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
126126
/// of `self`.
127127
///
128128
/// The returned port will be in `replacement`, unless the wire in the
129-
/// replacement is empty and `return_invalid` is
130-
/// [`BoundaryMode::SnapToHost`] (the default), in which case it
131-
/// will be another `host` port. If [`BoundaryMode::IncludeIO`] is
132-
/// passed, the returned port will always be in `replacement` even if it
133-
/// is invalid (i.e. it is an IO node in the replacement).
129+
/// replacement is empty and `boundary` is [`BoundaryMode::SnapToHost`] (the
130+
/// default), in which case it will be another `host` port. If
131+
/// [`BoundaryMode::IncludeIO`] is passed, the returned port will always
132+
/// be in `replacement` even if it is invalid (i.e. it is an IO node in
133+
/// the replacement).
134134
pub fn linked_replacement_output(
135135
&self,
136136
port: impl Into<HostPort<HostNode, IncomingPort>>,
137137
host: &impl HugrView<Node = HostNode>,
138-
return_invalid: BoundaryMode,
138+
boundary: BoundaryMode,
139139
) -> Option<BoundaryPort<HostNode, OutgoingPort>> {
140140
let HostPort(node, port) = port.into();
141141
let pos = self
@@ -144,7 +144,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
144144
.iter()
145145
.position(move |&(n, p)| host.linked_inputs(n, p).contains(&(node, port)))?;
146146

147-
Some(self.linked_replacement_output_by_position(pos, host, return_invalid))
147+
Some(self.linked_replacement_output_by_position(pos, host, boundary))
148148
}
149149

150150
/// The outgoing port linked to the i-th output boundary edge of `subgraph`.
@@ -155,7 +155,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
155155
&self,
156156
pos: usize,
157157
host: &impl HugrView<Node = HostNode>,
158-
return_invalid: BoundaryMode,
158+
boundary: BoundaryMode,
159159
) -> BoundaryPort<HostNode, OutgoingPort> {
160160
debug_assert!(pos < self.subgraph().signature(host).output_count());
161161

@@ -166,7 +166,7 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
166166
.single_linked_output(repl_out, pos)
167167
.expect("valid dfg wire");
168168

169-
if out_node != repl_inp || return_invalid == BoundaryMode::IncludeIO {
169+
if out_node != repl_inp || boundary == BoundaryMode::IncludeIO {
170170
BoundaryPort::Replacement(out_node, out_port)
171171
} else {
172172
let (in_node, in_port) = *self.subgraph.incoming_ports()[out_port.index()]
@@ -213,17 +213,16 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
213213
/// of `self`.
214214
///
215215
/// The returned ports will be in `replacement`, unless the wires in the
216-
/// replacement are empty and `return_invalid` is
217-
/// [`BoundaryMode::SnapToHost`] (the default), in which case they
218-
/// will be other `host` ports. If [`BoundaryMode::IncludeIO`] is
219-
/// passed, the returned ports will always be in
220-
/// `replacement` even if they are invalid (i.e. they are an IO node in
221-
/// the replacement).
216+
/// replacement are empty and `boundary` is [`BoundaryMode::SnapToHost`]
217+
/// (the default), in which case they will be other `host` ports. If
218+
/// [`BoundaryMode::IncludeIO`] is passed, the returned ports will
219+
/// always be in `replacement` even if they are invalid (i.e. they are
220+
/// an IO node in the replacement).
222221
pub fn linked_replacement_inputs<'a>(
223222
&'a self,
224223
port: impl Into<HostPort<HostNode, OutgoingPort>>,
225224
host: &'a impl HugrView<Node = HostNode>,
226-
return_invalid: BoundaryMode,
225+
boundary: BoundaryMode,
227226
) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> + 'a {
228227
let HostPort(node, port) = port.into();
229228
let positions = self
@@ -235,25 +234,24 @@ impl<HostNode: HugrNode> SimpleReplacement<HostNode> {
235234
host.single_linked_output(n, p).expect("valid dfg wire") == (node, port)
236235
});
237236

238-
positions.flat_map(move |pos| {
239-
self.linked_replacement_inputs_by_position(pos, host, return_invalid)
240-
})
237+
positions
238+
.flat_map(move |pos| self.linked_replacement_inputs_by_position(pos, host, boundary))
241239
}
242240

243241
/// The incoming ports linked to the i-th input boundary edge of `subgraph`.
244242
fn linked_replacement_inputs_by_position(
245243
&self,
246244
pos: usize,
247245
host: &impl HugrView<Node = HostNode>,
248-
return_invalid: BoundaryMode,
246+
boundary: BoundaryMode,
249247
) -> impl Iterator<Item = BoundaryPort<HostNode, IncomingPort>> {
250248
debug_assert!(pos < self.subgraph().signature(host).input_count());
251249

252250
let [repl_inp, repl_out] = self.get_replacement_io();
253251
self.replacement
254252
.linked_inputs(repl_inp, pos)
255253
.flat_map(move |(in_node, in_port)| {
256-
if in_node != repl_out || return_invalid == BoundaryMode::IncludeIO {
254+
if in_node != repl_out || boundary == BoundaryMode::IncludeIO {
257255
Either::Left(std::iter::once(BoundaryPort::Replacement(in_node, in_port)))
258256
} else {
259257
let (out_node, out_port) = self.subgraph.outgoing_ports()[in_port.index()];
@@ -552,8 +550,8 @@ impl<HostNode: HugrNode> PatchVerification for SimpleReplacement<HostNode> {
552550
}
553551
}
554552

555-
/// In [`SimpleReplacement`], some nodes in the replacement may not be valid
556-
/// after the replacement is applied.
553+
/// In [`SimpleReplacement::replacement`], IO nodes marking the boundary will
554+
/// not be valid nodes in the host after the replacement is applied.
557555
///
558556
/// This enum allows specifying whether these invalid nodes on the boundary
559557
/// should be returned or should be resolved to valid nodes in the host.

hugr-persistent/src/state_space.rs

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ impl<R> CommitStateSpace<R> {
256256
/// Check whether `commit_id` exists and return it.
257257
pub fn try_get_commit(&self, commit_id: CommitId) -> Option<&Commit> {
258258
self.contains_id(commit_id)
259-
.then(|| self.graph.get_node(commit_id).into())
259+
.then(|| self.get_commit(commit_id))
260260
}
261261

262262
/// Get an iterator over all commit IDs in the state space.
@@ -337,9 +337,9 @@ impl<R> CommitStateSpace<R> {
337337
/// Get the boundary inputs linked to `(node, port)` in `child`.
338338
///
339339
/// The returned ports will be ports on successors of the input node in the
340-
/// `child` commit, unless (node, port) is connected to an empty wire in
341-
/// `child` (i.e. a wire from input node to output node), in which case
342-
/// they will be in one of the parents of `child`.
340+
/// `child` commit, unless (node, port) is connected to a passthrough wire
341+
/// in `child` (i.e. a wire from input node to output node), in which
342+
/// case they will be in one of the parents of `child`.
343343
///
344344
/// `child` should be a child commit of the owner of `node`.
345345
///
@@ -372,10 +372,10 @@ impl<R> CommitStateSpace<R> {
372372

373373
/// Get the single boundary output linked to `(node, port)` in `child`.
374374
///
375-
/// The returned port will be ports on predecessors of the output node in
376-
/// the `child` commit, unless (node, port) is connected to an empty wire
377-
/// in `child` (i.e. a wire from input node to output node), in which
378-
/// case it will be in one of the parents of `child`.
375+
/// The returned port will be a port on a predecessor of the output node in
376+
/// the `child` commit, unless (node, port) is connected to a passthrough
377+
/// wire in `child` (i.e. a wire from input node to output node), in
378+
/// which case it will be in one of the parents of `child`.
379379
///
380380
/// `child` should be a child commit of the owner of `node` (or `None` will
381381
/// be returned).
@@ -425,8 +425,11 @@ impl<R> CommitStateSpace<R> {
425425
}
426426
}
427427

428-
/// Get the single output boundary port linked to `(node, port)` in a
429-
/// parent of the commit of `node`.
428+
/// Get the single output port linked to `(node, port)` in a parent of the
429+
/// commit of `node`.
430+
///
431+
/// The returned port belongs to the input boundary of the subgraph in
432+
/// parent.
430433
///
431434
/// ## Panics
432435
///
@@ -450,8 +453,11 @@ impl<R> CommitStateSpace<R> {
450453
repl.linked_host_input((node, port), &parent_hugrs).into()
451454
}
452455

453-
/// Get the input boundary ports linked to `(node, port)` in a
454-
/// parent of the commit of `node`.
456+
/// Get the input ports linked to `(node, port)` in a parent of the commit
457+
/// of `node`.
458+
///
459+
/// The returned ports belong to the output boundary of the subgraph in
460+
/// parent.
455461
///
456462
/// ## Panics
457463
///

hugr-persistent/src/tests.rs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -300,15 +300,14 @@ pub(super) fn persistent_hugr_empty_child() -> (PersistentHugr, [CommitId; 2], [
300300
let (triple_not_hugr, not_nodes) = {
301301
let mut dfg_builder = DFGBuilder::new(endo_sig(bool_t())).unwrap();
302302
let [mut w] = dfg_builder.input_wires_arr();
303-
let mut not_nodes = Vec::with_capacity(3);
304-
for _ in 0..3 {
303+
let not_nodes = [(); 3].map(|()| {
305304
let handle = dfg_builder.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
306305
[w] = handle.outputs_arr();
307-
not_nodes.push(handle.node());
308-
}
306+
handle.node()
307+
});
309308
(
310309
dfg_builder.finish_hugr_with_outputs([w]).unwrap(),
311-
not_nodes.into_iter().collect_array::<3>().unwrap(),
310+
not_nodes,
312311
)
313312
};
314313
let mut hugr = PersistentHugr::with_base(triple_not_hugr);

hugr-persistent/src/walker.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ impl<'a, R: Resolver> Walker<'a, R> {
344344
let mut repl = repl.try_into_checked().expect("replacement is not DFG");
345345
let new_inputs = incoming
346346
.iter()
347-
.flatten()
347+
.flatten() // because of singleton-vec wrapping above
348348
.map(|&(n, p)| {
349349
map_boundary(n, p.into())
350350
.as_outgoing()
@@ -450,7 +450,7 @@ impl<R: Clone> Walker<'_, R> {
450450
impl<R: Resolver> Walker<'_, R> {
451451
// Check walker equality by comparing pointers to the state space and
452452
// other fields. Only for testing purposes.
453-
fn ptr_eq(&self, other: &Self) -> bool {
453+
fn component_wise_ptr_eq(&self, other: &Self) -> bool {
454454
std::ptr::eq(self.state_space.as_ref(), other.state_space.as_ref())
455455
&& self.pinned_nodes == other.pinned_nodes
456456
&& BTreeSet::from_iter(self.selected_commits.all_commit_ids())
@@ -463,7 +463,7 @@ impl<R: Resolver> Walker<'_, R> {
463463
let Some([new_walker]) = self.expand(wire, dir).collect_array() else {
464464
return false;
465465
};
466-
new_walker.ptr_eq(self)
466+
new_walker.component_wise_ptr_eq(self)
467467
}
468468
}
469469

hugr-persistent/tests/persistent_walker_example.rs

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -271,32 +271,31 @@ fn create_commit(wire: PersistentWire, walker: &Walker) -> Option<Commit> {
271271
// - the first qubit is the one that is shared between the two CZ gates
272272
// - the second qubit only touches the first CZ (out_node)
273273
// - the third qubit only touches the second CZ (in_node)
274-
match (port.as_directed(), node == out_node) {
275-
(Either::Left(incoming), true) if incoming.index() == shared_qb_out => {
276-
// out_node on the shared qubit -> port 0
277-
OutgoingPort::from(0).into()
274+
match port.as_directed() {
275+
Either::Left(incoming) => {
276+
let in_boundary: [(_, IncomingPort); 3] = [
277+
(out_node, shared_qb_out.into()),
278+
(out_node, (1 - shared_qb_out).into()),
279+
(in_node, (1 - shared_qb_in).into()),
280+
];
281+
let out_index = in_boundary
282+
.iter()
283+
.position(|&(n, p)| n == node && p == incoming)
284+
.expect("invalid input port");
285+
OutgoingPort::from(out_index).into()
278286
}
279-
(Either::Left(incoming), true) if incoming.index() == 1 - shared_qb_out => {
280-
// out_node on the not shared qubit -> port 1
281-
OutgoingPort::from(1).into()
287+
Either::Right(outgoing) => {
288+
let out_boundary: [(_, OutgoingPort); 3] = [
289+
(in_node, shared_qb_in.into()),
290+
(out_node, (1 - shared_qb_out).into()),
291+
(in_node, (1 - shared_qb_in).into()),
292+
];
293+
let in_index = out_boundary
294+
.iter()
295+
.position(|&(n, p)| n == node && p == outgoing)
296+
.expect("invalid output port");
297+
IncomingPort::from(in_index).into()
282298
}
283-
(Either::Left(incoming), false) if incoming.index() == 1 - shared_qb_in => {
284-
// in_node on the not shared qubit -> port 2
285-
OutgoingPort::from(2).into()
286-
}
287-
(Either::Right(outgoing), false) if outgoing.index() == shared_qb_in => {
288-
// in_node on the shared qubit -> port 0
289-
IncomingPort::from(0).into()
290-
}
291-
(Either::Right(outgoing), true) if outgoing.index() == 1 - shared_qb_out => {
292-
// out_node on the not shared qubit -> port 1
293-
IncomingPort::from(1).into()
294-
}
295-
(Either::Right(outgoing), false) if outgoing.index() == 1 - shared_qb_in => {
296-
// in_node on the not shared qubit -> port 2
297-
IncomingPort::from(2).into()
298-
}
299-
_ => panic!("unexpected boundary port"),
300299
}
301300
})
302301
}
@@ -305,7 +304,7 @@ fn create_commit(wire: PersistentWire, walker: &Walker) -> Option<Commit> {
305304
.ok()
306305
}
307306

308-
#[ignore = "takes 10s (to be optimised)"]
307+
#[ignore = "takes 10s (todo: optimise)"]
309308
#[test]
310309
fn walker_example() {
311310
let state_space = build_state_space();
@@ -340,7 +339,6 @@ fn walker_example() {
340339
// and such that the resulting HUGR is empty
341340
let mut empty_hugr = None;
342341
for cs in empty_commits.iter().combinations(3) {
343-
// for cs in empty_commits.iter().combinations(2) {
344342
let cs = cs.into_iter().copied();
345343
if let Ok(hugr) = state_space.try_extract_hugr(cs) {
346344
empty_hugr = Some(hugr);

0 commit comments

Comments
 (0)