Skip to content

Commit 187cc8b

Browse files
committed
Cherry-pick call_graph + dead_funcs
1 parent 83ace81 commit 187cc8b

File tree

2 files changed

+64
-54
lines changed

2 files changed

+64
-54
lines changed

hugr-passes/src/call_graph.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,17 @@ pub enum CallGraphNode<N = Node> {
2626
}
2727

2828
/// Details the [`Call`]s and [`LoadFunction`]s in a Hugr.
29-
/// Each node in the `CallGraph` corresponds to a [`FuncDefn`] in the Hugr; each edge corresponds
30-
/// to a [`Call`]/[`LoadFunction`] of the edge's target, contained in the edge's source.
29+
///
30+
/// Each node in the `CallGraph` corresponds to a [`FuncDefn`] or [`FuncDecl`] in the Hugr;
31+
/// each edge corresponds to a [`Call`]/[`LoadFunction`] of the edge's target, contained in
32+
/// the edge's source.
3133
///
32-
/// For Hugrs whose root is neither a [Module](OpType::Module) nor a [`FuncDefn`], the call graph
33-
/// will have an additional [`CallGraphNode::NonFuncRoot`] corresponding to the Hugr's root, with no incoming edges.
34+
/// For Hugrs whose entrypoint is neither a [Module](OpType::Module) nor a [`FuncDefn`], the
35+
/// call graph will have an additional [`CallGraphNode::NonFuncRoot`] corresponding to the Hugr's
36+
/// entrypoint, with no incoming edges.
3437
///
3538
/// [`Call`]: OpType::Call
39+
/// [`FuncDecl`]: OpType::FuncDecl
3640
/// [`FuncDefn`]: OpType::FuncDefn
3741
/// [`LoadFunction`]: OpType::LoadFunction
3842
pub struct CallGraph<N = Node> {
@@ -41,14 +45,13 @@ pub struct CallGraph<N = Node> {
4145
}
4246

4347
impl<N: HugrNode> CallGraph<N> {
44-
/// Makes a new `CallGraph` for a specified (subview) of a Hugr.
45-
/// Calls to functions outside the view will be dropped.
48+
/// Makes a new `CallGraph` for a Hugr.
4649
pub fn new(hugr: &impl HugrView<Node = N>) -> Self {
4750
let mut g = Graph::default();
4851
let non_func_root =
4952
(!hugr.get_optype(hugr.entrypoint()).is_module()).then_some(hugr.entrypoint());
5053
let node_to_g = hugr
51-
.entry_descendants()
54+
.children(hugr.module_root())
5255
.filter_map(|n| {
5356
let weight = match hugr.get_optype(n) {
5457
OpType::FuncDecl(_) => CallGraphNode::FuncDecl(n),
@@ -94,7 +97,7 @@ impl<N: HugrNode> CallGraph<N> {
9497

9598
/// Convert a Hugr [Node] into a petgraph node index.
9699
/// Result will be `None` if `n` is not a [`FuncDefn`](OpType::FuncDefn),
97-
/// [`FuncDecl`](OpType::FuncDecl) or the hugr root.
100+
/// [`FuncDecl`](OpType::FuncDecl) or the [HugrView::entrypoint].
98101
pub fn node_index(&self, n: N) -> Option<petgraph::graph::NodeIndex<u32>> {
99102
self.node_to_g.get(&n).copied()
100103
}

hugr-passes/src/dead_funcs.rs

Lines changed: 53 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use super::call_graph::{CallGraph, CallGraphNode};
2121
#[non_exhaustive]
2222
/// Errors produced by [`RemoveDeadFuncsPass`].
2323
pub enum RemoveDeadFuncsError<N = Node> {
24-
/// The specified entry point is not a `FuncDefn` node or is not a child of the root.
24+
/// The specified entry point is not a `FuncDefn` node
2525
#[error(
2626
"Entrypoint for RemoveDeadFuncsPass {node} was not a function definition in the root module"
2727
)]
@@ -35,30 +35,17 @@ fn reachable_funcs<'a, H: HugrView>(
3535
cg: &'a CallGraph<H::Node>,
3636
h: &'a H,
3737
entry_points: impl IntoIterator<Item = H::Node>,
38-
) -> Result<impl Iterator<Item = H::Node> + 'a, RemoveDeadFuncsError<H::Node>> {
38+
) -> impl Iterator<Item = H::Node> + 'a {
3939
let g = cg.graph();
40-
let mut entry_points = entry_points.into_iter();
41-
let searcher = if h.get_optype(h.entrypoint()).is_module() {
42-
let mut d = Dfs::new(g, 0.into());
43-
d.stack.clear();
44-
for n in entry_points {
45-
if !h.get_optype(n).is_func_defn() || h.get_parent(n) != Some(h.entrypoint()) {
46-
return Err(RemoveDeadFuncsError::InvalidEntryPoint { node: n });
47-
}
48-
d.stack.push(cg.node_index(n).unwrap());
49-
}
50-
d
51-
} else {
52-
if let Some(n) = entry_points.next() {
53-
// Can't be a child of the module root as there isn't a module root!
54-
return Err(RemoveDeadFuncsError::InvalidEntryPoint { node: n });
55-
}
56-
Dfs::new(g, cg.node_index(h.entrypoint()).unwrap())
57-
};
58-
Ok(searcher.iter(g).map(|i| match g.node_weight(i).unwrap() {
40+
let mut d = Dfs::new(g, 0.into());
41+
d.stack.clear(); // Remove the fake 0
42+
for n in entry_points {
43+
d.stack.push(cg.node_index(n).unwrap());
44+
}
45+
d.iter(g).map(|i| match g.node_weight(i).unwrap() {
5946
CallGraphNode::FuncDefn(n) | CallGraphNode::FuncDecl(n) => *n,
6047
CallGraphNode::NonFuncRoot => h.entrypoint(),
61-
}))
48+
})
6249
}
6350

6451
#[derive(Debug, Clone, Default)]
@@ -86,14 +73,31 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
8673
type Error = RemoveDeadFuncsError;
8774
type Result = ();
8875
fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> {
89-
let reachable = reachable_funcs(
90-
&CallGraph::new(hugr),
91-
hugr,
92-
self.entry_points.iter().copied(),
93-
)?
94-
.collect::<HashSet<_>>();
76+
let mut entry_points = Vec::new();
77+
for &n in self.entry_points.iter() {
78+
if !hugr.get_optype(n).is_func_defn() {
79+
return Err(RemoveDeadFuncsError::InvalidEntryPoint { node: n });
80+
}
81+
debug_assert_eq!(hugr.get_parent(n), Some(hugr.module_root()));
82+
entry_points.push(n);
83+
}
84+
if hugr.entrypoint() != hugr.module_root() {
85+
entry_points.push(hugr.entrypoint())
86+
}
87+
88+
let mut reachable =
89+
reachable_funcs(&CallGraph::new(hugr), hugr, entry_points).collect::<HashSet<_>>();
90+
// Also prevent removing the entrypoint itself
91+
let mut n = Some(hugr.entrypoint());
92+
while let Some(n2) = n {
93+
n = hugr.get_parent(n2);
94+
if n == Some(hugr.module_root()) {
95+
reachable.insert(n2);
96+
}
97+
}
98+
9599
let unreachable = hugr
96-
.entry_descendants()
100+
.children(hugr.module_root())
97101
.filter(|n| {
98102
OpTag::Function.is_superset(hugr.get_optype(*n).tag()) && !reachable.contains(n)
99103
})
@@ -108,17 +112,13 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
108112
/// Deletes from the Hugr any functions that are not used by either [`Call`] or
109113
/// [`LoadFunction`] nodes in reachable parts.
110114
///
111-
/// For [`Module`]-rooted Hugrs, `entry_points` may provide a list of entry points,
112-
/// which must be children of the root. Note that if `entry_points` is empty, this will
113-
/// result in all functions in the module being removed.
114-
///
115-
/// For non-[`Module`]-rooted Hugrs, `entry_points` must be empty; the root node is used.
115+
/// `entry_points` may provide a list of entry points, which must be [`FuncDefn`]s (children of the root).
116+
/// The [HugrView::entrypoint] will also be used unless it is the [HugrView::module_root].
117+
/// Note that for a [`Module`]-rooted Hugr with no `entry_points` provided, this will remove
118+
/// all functions from the module.
116119
///
117120
/// # Errors
118-
/// * If there are any `entry_points` but the root of the hugr is not a [`Module`]
119-
/// * If any node in `entry_points` is
120-
/// * not a [`FuncDefn`], or
121-
/// * not a child of the root
121+
/// * If any node in `entry_points` is not a [`FuncDefn`]
122122
///
123123
/// [`Call`]: hugr_core::ops::OpType::Call
124124
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
@@ -138,22 +138,26 @@ pub fn remove_dead_funcs(
138138
mod test {
139139
use std::collections::HashMap;
140140

141+
use hugr_core::ops::handle::NodeHandle;
141142
use itertools::Itertools;
142143
use rstest::rstest;
143144

144145
use hugr_core::builder::{Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
146+
use hugr_core::hugr::hugrmut::HugrMut;
145147
use hugr_core::{HugrView, extension::prelude::usize_t, types::Signature};
146148

147149
use super::remove_dead_funcs;
148150

149151
#[rstest]
150-
#[case([], vec![])] // No entry_points removes everything!
151-
#[case(["main"], vec!["from_main", "main"])]
152-
#[case(["from_main"], vec!["from_main"])]
153-
#[case(["other1"], vec!["other1", "other2"])]
154-
#[case(["other2"], vec!["other2"])]
155-
#[case(["other1", "other2"], vec!["other1", "other2"])]
152+
#[case(false, [], vec![])] // No entry_points removes everything!
153+
#[case(true, [], vec!["from_main", "main"])]
154+
#[case(false, ["main"], vec!["from_main", "main"])]
155+
#[case(false, ["from_main"], vec!["from_main"])]
156+
#[case(false, ["other1"], vec!["other1", "other2"])]
157+
#[case(true, ["other2"], vec!["from_main", "main", "other2"])]
158+
#[case(false, ["other1", "other2"], vec!["other1", "other2"])]
156159
fn remove_dead_funcs_entry_points(
160+
#[case] use_hugr_entrypoint: bool,
157161
#[case] entry_points: impl IntoIterator<Item = &'static str>,
158162
#[case] retained_funcs: Vec<&'static str>,
159163
) -> Result<(), Box<dyn std::error::Error>> {
@@ -171,12 +175,15 @@ mod test {
171175
let fm = fm.finish_with_outputs(f_inp)?;
172176
let mut m = hb.define_function("main", Signature::new_endo(usize_t()))?;
173177
let mc = m.call(fm.handle(), &[], m.input_wires())?;
174-
m.finish_with_outputs(mc.outputs())?;
178+
let m = m.finish_with_outputs(mc.outputs())?;
175179

176180
let mut hugr = hb.finish_hugr()?;
181+
if use_hugr_entrypoint {
182+
hugr.set_entrypoint(m.node());
183+
}
177184

178185
let avail_funcs = hugr
179-
.entry_descendants()
186+
.children(hugr.module_root())
180187
.filter_map(|n| {
181188
hugr.get_optype(n)
182189
.as_func_defn()

0 commit comments

Comments
 (0)