Skip to content

Commit 5788e0f

Browse files
committed
Include in DeadCode/DeadFunc elimination
1 parent d255a23 commit 5788e0f

File tree

3 files changed

+68
-13
lines changed

3 files changed

+68
-13
lines changed

hugr-core/src/ops/constant.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ impl Const {
5656
&self.value
5757
}
5858

59+
/// Whether the [`Const`] would be exported for linking
60+
/// (only applies if the parent is a [Module](super::Module))
61+
pub fn is_public(&self) -> bool {
62+
self.name.is_some()
63+
}
64+
5965
delegate! {
6066
to self.value {
6167
/// Returns the type of this constant.

hugr-passes/src/dead_code.rs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pub struct DeadCodeElimPass {
1919
/// results are not used. Defaults to [PreserveNode::default_for].
2020
preserve_callback: Arc<PreserveCallback>,
2121
validation: ValidationLevel,
22+
include_exports: bool,
2223
}
2324

2425
impl Default for DeadCodeElimPass {
@@ -27,6 +28,7 @@ impl Default for DeadCodeElimPass {
2728
entry_points: Default::default(),
2829
preserve_callback: Arc::new(PreserveNode::default_for),
2930
validation: ValidationLevel::default(),
31+
include_exports: true,
3032
}
3133
}
3234
}
@@ -102,18 +104,32 @@ impl DeadCodeElimPass {
102104

103105
/// Mark some nodes as entry points to the Hugr, i.e. so we cannot eliminate any code
104106
/// used to evaluate these nodes.
105-
/// The root node is assumed to be an entry point;
106-
/// for Module roots the client will want to mark some of the FuncDefn children
107-
/// as entry points too.
107+
/// The root node is assumed to be an entry point; for Module roots, any public
108+
/// [FuncDefn](OpType::FuncDefn)s and [Const](OpType::Const)s are also considered entry points
109+
/// by default, but these can be removed by [Self::include_module_exports].
108110
pub fn with_entry_points(mut self, entry_points: impl IntoIterator<Item = Node>) -> Self {
109111
self.entry_points.extend(entry_points);
110112
self
111113
}
112114

115+
/// Sets whether, for Module-rooted Hugrs, the exported [FuncDefn](OpType::FuncDefn)s
116+
/// and [Const](OpType::Const)s are included as entry points (they are by default)
117+
pub fn include_module_exports(mut self, include: bool) -> Self {
118+
self.include_exports = include;
119+
self
120+
}
121+
113122
fn find_needed_nodes(&self, h: impl HugrView<Node = Node>) -> HashSet<Node> {
114123
let mut must_preserve = HashMap::new();
115124
let mut needed = HashSet::new();
116125
let mut q = VecDeque::from_iter(self.entry_points.iter().cloned());
126+
if self.include_exports && h.root_type().is_module() {
127+
q.extend(h.children(h.root()).filter(|ch| {
128+
let op = h.get_optype(*ch);
129+
op.as_func_defn().is_some_and(|fd| fd.public)
130+
|| op.as_const().is_some_and(|c| c.is_public())
131+
}))
132+
}
117133
q.push_front(h.root());
118134
while let Some(n) = q.pop_front() {
119135
if !needed.insert(n) {

hugr-passes/src/dead_funcs.rs

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,22 @@ fn reachable_funcs<'a, H: HugrView>(
6161
}))
6262
}
6363

64-
#[derive(Debug, Clone, Default)]
64+
#[derive(Debug, Clone)]
6565
/// A configuration for the Dead Function Removal pass.
6666
pub struct RemoveDeadFuncsPass {
6767
validation: ValidationLevel,
6868
entry_points: Vec<Node>,
69+
include_exports: bool,
70+
}
71+
72+
impl Default for RemoveDeadFuncsPass {
73+
fn default() -> Self {
74+
Self {
75+
validation: Default::default(),
76+
entry_points: Default::default(),
77+
include_exports: true,
78+
}
79+
}
6980
}
7081

7182
impl RemoveDeadFuncsPass {
@@ -88,10 +99,28 @@ impl RemoveDeadFuncsPass {
8899
self
89100
}
90101

102+
/// Sets whether the exported [FuncDefn](hugr_core::ops::FuncDefn) children of a
103+
/// [Module](hugr_core::ops::Module) are included as entry points (yes by default)
104+
pub fn include_module_exports(mut self, include: bool) -> Self {
105+
self.include_exports = include;
106+
self
107+
}
108+
91109
/// Runs the pass (see [remove_dead_funcs]) with this configuration
92110
pub fn run<H: HugrMut>(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> {
93111
self.validation.run_validated_pass(hugr, |hugr: &mut H, _| {
94-
remove_dead_funcs(hugr, self.entry_points.iter().cloned())
112+
let exports = if hugr.root_type().is_module() && self.include_exports {
113+
hugr.children(hugr.root())
114+
.filter(|ch| {
115+
hugr.get_optype(*ch)
116+
.as_func_defn()
117+
.is_some_and(|fd| fd.public)
118+
})
119+
.collect()
120+
} else {
121+
vec![]
122+
};
123+
remove_dead_funcs(hugr, self.entry_points.iter().cloned().chain(exports))
95124
})
96125
}
97126
}
@@ -145,26 +174,29 @@ mod test {
145174
use super::RemoveDeadFuncsPass;
146175

147176
#[rstest]
148-
#[case([], vec![])] // No entry_points removes everything!
149-
#[case(["main"], vec!["from_main", "main"])]
150-
#[case(["from_main"], vec!["from_main"])]
151-
#[case(["other1"], vec!["other1", "other2"])]
152-
#[case(["other2"], vec!["other2"])]
153-
#[case(["other1", "other2"], vec!["other1", "other2"])]
177+
#[case(false, [], vec![])] // No entry_points removes everything!
178+
#[case(false, ["main"], vec!["from_main", "main"])]
179+
#[case(false, ["from_main"], vec!["from_main"])]
180+
#[case(false, ["other1"], vec!["other1", "other2"])]
181+
#[case(false, ["other2"], vec!["other2"])]
182+
#[case(false, ["other1", "other2"], vec!["other1", "other2"])]
183+
#[case(true, [], vec!["from_main", "main", "other2"])]
184+
#[case(true, ["other1"], vec!["from_main", "main", "other1", "other2"])]
154185
fn remove_dead_funcs_entry_points(
186+
#[case] include_exports: bool,
155187
#[case] entry_points: impl IntoIterator<Item = &'static str>,
156188
#[case] retained_funcs: Vec<&'static str>,
157189
) -> Result<(), Box<dyn std::error::Error>> {
158190
let mut hb = ModuleBuilder::new();
159191
let o2 = hb.define_function("other2", Signature::new_endo(usize_t()))?;
160192
let o2inp = o2.input_wires();
161193
let o2 = o2.finish_with_outputs(o2inp)?;
162-
let mut o1 = hb.define_function("other1", Signature::new_endo(usize_t()))?;
194+
let mut o1 = hb.define_function_vis("other1", Signature::new_endo(usize_t()), false)?;
163195

164196
let o1c = o1.call(o2.handle(), &[], o1.input_wires())?;
165197
o1.finish_with_outputs(o1c.outputs())?;
166198

167-
let fm = hb.define_function("from_main", Signature::new_endo(usize_t()))?;
199+
let fm = hb.define_function_vis("from_main", Signature::new_endo(usize_t()), false)?;
168200
let f_inp = fm.input_wires();
169201
let fm = fm.finish_with_outputs(f_inp)?;
170202
let mut m = hb.define_function("main", Signature::new_endo(usize_t()))?;
@@ -183,6 +215,7 @@ mod test {
183215
.collect::<HashMap<_, _>>();
184216

185217
RemoveDeadFuncsPass::default()
218+
.include_module_exports(include_exports)
186219
.with_module_entry_points(
187220
entry_points
188221
.into_iter()

0 commit comments

Comments
 (0)