Skip to content

Commit 71f4106

Browse files
committed
Move IncludeExports up to root of hugr-passes, use for const_fold, +deprecate
1 parent 2ea97f4 commit 71f4106

File tree

5 files changed

+68
-33
lines changed

5 files changed

+68
-33
lines changed

hugr-passes/src/const_fold.rs

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,23 @@ use hugr_core::{
1616
};
1717
use value_handle::ValueHandle;
1818

19-
use crate::dataflow::{
20-
ConstLoader, ConstLocation, DFContext, Machine, PartialValue, TailLoopTermination,
21-
partial_from_const,
22-
};
2319
use crate::dead_code::{DeadCodeElimPass, PreserveNode};
2420
use crate::{ComposablePass, composable::validate_if_test};
21+
use crate::{
22+
IncludeExports,
23+
dataflow::{
24+
ConstLoader, ConstLocation, DFContext, Machine, PartialValue, TailLoopTermination,
25+
partial_from_const,
26+
},
27+
};
2528

2629
#[derive(Debug, Clone, Default)]
2730
/// A configuration for the Constant Folding pass.
2831
pub struct ConstantFoldPass {
2932
allow_increase_termination: bool,
3033
/// Each outer key Node must be either:
3134
/// - a `FuncDefn` child of the root, if the root is a module; or
32-
/// - the root, if the root is not a Module
35+
/// - the entrypoint, if the entrypoint is not a Module
3336
inputs: HashMap<Node, HashMap<IncomingPort, Value>>,
3437
}
3538

@@ -185,25 +188,53 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for ConstantFoldPass {
185188
}
186189
}
187190

191+
const NO_INPUTS: [(IncomingPort, Value); 0] = [];
192+
188193
/// Exhaustively apply constant folding to a HUGR.
189194
/// If the Hugr's entrypoint is its [`Module`], assumes all [`FuncDefn`] children are reachable.
190195
///
191196
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
192197
/// [`Module`]: hugr_core::ops::OpType::Module
198+
#[deprecated(note = "Use constant_fold_pass_pub, or manually configure ConstantFoldPass")]
193199
pub fn constant_fold_pass<H: HugrMut<Node = Node> + 'static>(mut h: impl AsMut<H>) {
194200
let h = h.as_mut();
195201
let c = ConstantFoldPass::default();
196202
let c = if h.get_optype(h.entrypoint()).is_module() {
197-
let no_inputs: [(IncomingPort, _); 0] = [];
198203
h.children(h.entrypoint())
199204
.filter(|n| h.get_optype(*n).is_func_defn())
200-
.fold(c, |c, n| c.with_inputs(n, no_inputs.iter().cloned()))
205+
.fold(c, |c, n| c.with_inputs(n, NO_INPUTS.clone()))
201206
} else {
202207
c
203208
};
204209
validate_if_test(c, h).unwrap();
205210
}
206211

212+
/// Exhaustively apply constant folding to a HUGR.
213+
/// Assumes that the Hugr's entrypoint is reachable (if it is not a [`Module`]).
214+
/// Also uses `policy` to determine which public [`FuncDefn`] children of the [`HugrView::module_root`] are reachable.
215+
///
216+
/// [`Module`]: hugr_core::ops::OpType::Module
217+
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
218+
pub fn constant_fold_pass_pub(
219+
h: &mut (impl HugrMut<Node = Node> + 'static),
220+
policy: IncludeExports,
221+
) {
222+
let mut funcs = Vec::new();
223+
if h.get_optype(h.entrypoint()).is_func_defn() {
224+
funcs.push(h.entrypoint());
225+
}
226+
if policy.for_hugr(&h) {
227+
funcs.extend(
228+
h.children(h.module_root())
229+
.filter(|n| h.get_optype(*n).is_func_defn()),
230+
)
231+
}
232+
let c = funcs.into_iter().fold(ConstantFoldPass::default(), |c, n| {
233+
c.with_inputs(n, NO_INPUTS.clone())
234+
});
235+
validate_if_test(c, h).unwrap();
236+
}
237+
207238
struct ConstFoldContext;
208239

209240
impl ConstLoader<ValueHandle<Node>> for ConstFoldContext {

hugr-passes/src/const_fold/test.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,14 @@ use hugr_core::std_extensions::logic::LogicOp;
2929
use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV};
3030
use hugr_core::{Hugr, HugrView, IncomingPort, Node, type_row};
3131

32-
use crate::ComposablePass as _;
3332
use crate::dataflow::{DFContext, PartialValue, partial_from_const};
33+
use crate::{ComposablePass as _, IncludeExports};
3434

35-
use super::{ConstFoldContext, ConstantFoldPass, ValueHandle, constant_fold_pass};
35+
use super::{ConstFoldContext, ConstantFoldPass, ValueHandle, constant_fold_pass_pub};
36+
37+
fn constant_fold_pass(h: &mut (impl HugrMut<Node = Node> + 'static)) {
38+
constant_fold_pass_pub(h, IncludeExports::Always);
39+
}
3640

3741
#[rstest]
3842
#[case(ConstInt::new_u(4, 2).unwrap(), true)]

hugr-passes/src/dead_funcs.rs

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use hugr_core::{
1111
use petgraph::visit::{Dfs, Walker};
1212

1313
use crate::{
14-
ComposablePass,
14+
ComposablePass, IncludeExports,
1515
composable::{ValidatePassError, validate_if_test},
1616
};
1717

@@ -31,14 +31,6 @@ pub enum RemoveDeadFuncsError<N = Node> {
3131
},
3232
}
3333

34-
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
35-
pub enum IncludeExports {
36-
Always,
37-
Never,
38-
#[default]
39-
OnlyIfEntrypointIsModuleRoot,
40-
}
41-
4234
fn reachable_funcs<'a, H: HugrView>(
4335
cg: &'a CallGraph<H::Node>,
4436
h: &'a H,
@@ -90,20 +82,7 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
9082
type Result = ();
9183
fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> {
9284
let mut entry_points = Vec::new();
93-
let include_exports = match self.include_exports {
94-
IncludeExports::Always => true,
95-
IncludeExports::Never => false,
96-
IncludeExports::OnlyIfEntrypointIsModuleRoot => hugr.entrypoint() == hugr.module_root(),
97-
};
98-
let include_exports2 = matches!(
99-
(
100-
self.include_exports,
101-
hugr.entrypoint() == hugr.module_root()
102-
),
103-
(IncludeExports::Always, _) | (IncludeExports::OnlyIfEntrypointIsModuleRoot, true)
104-
);
105-
assert_eq!(include_exports, include_exports2);
106-
if include_exports {
85+
if self.include_exports.for_hugr(hugr) {
10786
entry_points.extend(hugr.children(hugr.module_root()).filter(|ch| {
10887
hugr.get_optype(*ch)
10988
.as_func_defn()

hugr-passes/src/lib.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub use dead_funcs::{RemoveDeadFuncsError, RemoveDeadFuncsPass, remove_dead_func
1313
pub mod force_order;
1414
mod half_node;
1515
pub mod linearize_array;
16+
use hugr_core::HugrView;
1617
pub use linearize_array::LinearizeArrayPass;
1718
pub mod lower;
1819
pub mod merge_bbs;
@@ -28,3 +29,23 @@ pub use force_order::{force_order, force_order_by_key};
2829
pub use lower::{lower_ops, replace_many_ops};
2930
pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges};
3031
pub use untuple::UntuplePass;
32+
33+
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
34+
/// A policy for whether to include the public (exported) functions of a Hugr
35+
/// (typically, as starting points for analysis)
36+
pub enum IncludeExports {
37+
Always,
38+
Never,
39+
#[default]
40+
OnlyIfEntrypointIsModuleRoot,
41+
}
42+
43+
impl IncludeExports {
44+
/// Returns whether to include the public functions of a particular Hugr
45+
fn for_hugr(&self, h: &impl HugrView) -> bool {
46+
matches!(
47+
(self, h.entrypoint() == h.module_root()),
48+
(Self::Always, _) | (Self::OnlyIfEntrypointIsModuleRoot, true)
49+
)
50+
}
51+
}

hugr-passes/src/monomorphize.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ mod test {
290290
use hugr_core::{Hugr, HugrView, Node};
291291
use rstest::rstest;
292292

293-
use crate::dead_funcs::IncludeExports;
293+
use crate::IncludeExports;
294294
use crate::{ComposablePass, RemoveDeadFuncsPass, monomorphize, remove_dead_funcs};
295295

296296
use super::{is_polymorphic, mangle_name};

0 commit comments

Comments
 (0)