@@ -21,7 +21,7 @@ use super::call_graph::{CallGraph, CallGraphNode};
21
21
#[ non_exhaustive]
22
22
/// Errors produced by [`RemoveDeadFuncsPass`].
23
23
pub enum RemoveDeadFuncsError < N = Node > {
24
- /// The specified entry point is not a `FuncDefn` node or is not a child of the root.
24
+ /// The specified entry point is not a `FuncDefn` node
25
25
#[ error(
26
26
"Entrypoint for RemoveDeadFuncsPass {node} was not a function definition in the root module"
27
27
) ]
@@ -35,30 +35,17 @@ fn reachable_funcs<'a, H: HugrView>(
35
35
cg : & ' a CallGraph < H :: Node > ,
36
36
h : & ' a H ,
37
37
entry_points : impl IntoIterator < Item = H :: Node > ,
38
- ) -> Result < impl Iterator < Item = H :: Node > + ' a , RemoveDeadFuncsError < H :: Node > > {
38
+ ) -> impl Iterator < Item = H :: Node > + ' a {
39
39
let g = cg. graph ( ) ;
40
- let mut entry_points = entry_points. into_iter ( ) ;
41
- let searcher = if h. get_optype ( h. entrypoint ( ) ) . is_module ( ) {
42
- let mut d = Dfs :: new ( g, 0 . into ( ) ) ;
43
- d. stack . clear ( ) ;
44
- for n in entry_points {
45
- if !h. get_optype ( n) . is_func_defn ( ) || h. get_parent ( n) != Some ( h. entrypoint ( ) ) {
46
- return Err ( RemoveDeadFuncsError :: InvalidEntryPoint { node : n } ) ;
47
- }
48
- d. stack . push ( cg. node_index ( n) . unwrap ( ) ) ;
49
- }
50
- d
51
- } else {
52
- if let Some ( n) = entry_points. next ( ) {
53
- // Can't be a child of the module root as there isn't a module root!
54
- return Err ( RemoveDeadFuncsError :: InvalidEntryPoint { node : n } ) ;
55
- }
56
- Dfs :: new ( g, cg. node_index ( h. entrypoint ( ) ) . unwrap ( ) )
57
- } ;
58
- Ok ( searcher. iter ( g) . map ( |i| match g. node_weight ( i) . unwrap ( ) {
40
+ let mut d = Dfs :: new ( g, 0 . into ( ) ) ;
41
+ d. stack . clear ( ) ; // Remove the fake 0
42
+ for n in entry_points {
43
+ d. stack . push ( cg. node_index ( n) . unwrap ( ) ) ;
44
+ }
45
+ d. iter ( g) . map ( |i| match g. node_weight ( i) . unwrap ( ) {
59
46
CallGraphNode :: FuncDefn ( n) | CallGraphNode :: FuncDecl ( n) => * n,
60
47
CallGraphNode :: NonFuncRoot => h. entrypoint ( ) ,
61
- } ) )
48
+ } )
62
49
}
63
50
64
51
#[ derive( Debug , Clone , Default ) ]
@@ -86,14 +73,31 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
86
73
type Error = RemoveDeadFuncsError ;
87
74
type Result = ( ) ;
88
75
fn run ( & self , hugr : & mut H ) -> Result < ( ) , RemoveDeadFuncsError > {
89
- let reachable = reachable_funcs (
90
- & CallGraph :: new ( hugr) ,
91
- hugr,
92
- self . entry_points . iter ( ) . copied ( ) ,
93
- ) ?
94
- . collect :: < HashSet < _ > > ( ) ;
76
+ let mut entry_points = Vec :: new ( ) ;
77
+ for & n in self . entry_points . iter ( ) {
78
+ if !hugr. get_optype ( n) . is_func_defn ( ) {
79
+ return Err ( RemoveDeadFuncsError :: InvalidEntryPoint { node : n } ) ;
80
+ }
81
+ debug_assert_eq ! ( hugr. get_parent( n) , Some ( hugr. module_root( ) ) ) ;
82
+ entry_points. push ( n) ;
83
+ }
84
+ if hugr. entrypoint ( ) != hugr. module_root ( ) {
85
+ entry_points. push ( hugr. entrypoint ( ) )
86
+ }
87
+
88
+ let mut reachable =
89
+ reachable_funcs ( & CallGraph :: new ( hugr) , hugr, entry_points) . collect :: < HashSet < _ > > ( ) ;
90
+ // Also prevent removing the entrypoint itself
91
+ let mut n = Some ( hugr. entrypoint ( ) ) ;
92
+ while let Some ( n2) = n {
93
+ n = hugr. get_parent ( n2) ;
94
+ if n == Some ( hugr. module_root ( ) ) {
95
+ reachable. insert ( n2) ;
96
+ }
97
+ }
98
+
95
99
let unreachable = hugr
96
- . entry_descendants ( )
100
+ . children ( hugr . module_root ( ) )
97
101
. filter ( |n| {
98
102
OpTag :: Function . is_superset ( hugr. get_optype ( * n) . tag ( ) ) && !reachable. contains ( n)
99
103
} )
@@ -108,17 +112,13 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for RemoveDeadFuncsPass {
108
112
/// Deletes from the Hugr any functions that are not used by either [`Call`] or
109
113
/// [`LoadFunction`] nodes in reachable parts.
110
114
///
111
- /// For [`Module`]-rooted Hugrs, `entry_points` may provide a list of entry points,
112
- /// which must be children of the root. Note that if `entry_points` is empty, this will
113
- /// result in all functions in the module being removed.
114
- ///
115
- /// For non-[`Module`]-rooted Hugrs, `entry_points` must be empty; the root node is used.
115
+ /// `entry_points` may provide a list of entry points, which must be [`FuncDefn`]s (children of the root).
116
+ /// The [HugrView::entrypoint] will also be used unless it is the [HugrView::module_root].
117
+ /// Note that for a [`Module`]-rooted Hugr with no `entry_points` provided, this will remove
118
+ /// all functions from the module.
116
119
///
117
120
/// # Errors
118
- /// * If there are any `entry_points` but the root of the hugr is not a [`Module`]
119
- /// * If any node in `entry_points` is
120
- /// * not a [`FuncDefn`], or
121
- /// * not a child of the root
121
+ /// * If any node in `entry_points` is not a [`FuncDefn`]
122
122
///
123
123
/// [`Call`]: hugr_core::ops::OpType::Call
124
124
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
@@ -138,22 +138,26 @@ pub fn remove_dead_funcs(
138
138
mod test {
139
139
use std:: collections:: HashMap ;
140
140
141
+ use hugr_core:: ops:: handle:: NodeHandle ;
141
142
use itertools:: Itertools ;
142
143
use rstest:: rstest;
143
144
144
145
use hugr_core:: builder:: { Dataflow , DataflowSubContainer , HugrBuilder , ModuleBuilder } ;
146
+ use hugr_core:: hugr:: hugrmut:: HugrMut ;
145
147
use hugr_core:: { HugrView , extension:: prelude:: usize_t, types:: Signature } ;
146
148
147
149
use super :: remove_dead_funcs;
148
150
149
151
#[ rstest]
150
- #[ case( [ ] , vec![ ] ) ] // No entry_points removes everything!
151
- #[ case( [ "main" ] , vec![ "from_main" , "main" ] ) ]
152
- #[ case( [ "from_main" ] , vec![ "from_main" ] ) ]
153
- #[ case( [ "other1" ] , vec![ "other1" , "other2" ] ) ]
154
- #[ case( [ "other2" ] , vec![ "other2" ] ) ]
155
- #[ case( [ "other1" , "other2" ] , vec![ "other1" , "other2" ] ) ]
152
+ #[ case( false , [ ] , vec![ ] ) ] // No entry_points removes everything!
153
+ #[ case( true , [ ] , vec![ "from_main" , "main" ] ) ]
154
+ #[ case( false , [ "main" ] , vec![ "from_main" , "main" ] ) ]
155
+ #[ case( false , [ "from_main" ] , vec![ "from_main" ] ) ]
156
+ #[ case( false , [ "other1" ] , vec![ "other1" , "other2" ] ) ]
157
+ #[ case( true , [ "other2" ] , vec![ "from_main" , "main" , "other2" ] ) ]
158
+ #[ case( false , [ "other1" , "other2" ] , vec![ "other1" , "other2" ] ) ]
156
159
fn remove_dead_funcs_entry_points (
160
+ #[ case] use_hugr_entrypoint : bool ,
157
161
#[ case] entry_points : impl IntoIterator < Item = & ' static str > ,
158
162
#[ case] retained_funcs : Vec < & ' static str > ,
159
163
) -> Result < ( ) , Box < dyn std:: error:: Error > > {
@@ -171,12 +175,15 @@ mod test {
171
175
let fm = fm. finish_with_outputs ( f_inp) ?;
172
176
let mut m = hb. define_function ( "main" , Signature :: new_endo ( usize_t ( ) ) ) ?;
173
177
let mc = m. call ( fm. handle ( ) , & [ ] , m. input_wires ( ) ) ?;
174
- m. finish_with_outputs ( mc. outputs ( ) ) ?;
178
+ let m = m. finish_with_outputs ( mc. outputs ( ) ) ?;
175
179
176
180
let mut hugr = hb. finish_hugr ( ) ?;
181
+ if use_hugr_entrypoint {
182
+ hugr. set_entrypoint ( m. node ( ) ) ;
183
+ }
177
184
178
185
let avail_funcs = hugr
179
- . entry_descendants ( )
186
+ . children ( hugr . module_root ( ) )
180
187
. filter_map ( |n| {
181
188
hugr. get_optype ( n)
182
189
. as_func_defn ( )
0 commit comments