@@ -61,11 +61,22 @@ fn reachable_funcs<'a, H: HugrView>(
61
61
} ) )
62
62
}
63
63
64
- #[ derive( Debug , Clone , Default ) ]
64
+ #[ derive( Debug , Clone ) ]
65
65
/// A configuration for the Dead Function Removal pass.
66
66
pub struct RemoveDeadFuncsPass {
67
67
validation : ValidationLevel ,
68
68
entry_points : Vec < Node > ,
69
+ include_exports : bool ,
70
+ }
71
+
72
+ impl Default for RemoveDeadFuncsPass {
73
+ fn default ( ) -> Self {
74
+ Self {
75
+ validation : Default :: default ( ) ,
76
+ entry_points : Default :: default ( ) ,
77
+ include_exports : true ,
78
+ }
79
+ }
69
80
}
70
81
71
82
impl RemoveDeadFuncsPass {
@@ -88,10 +99,28 @@ impl RemoveDeadFuncsPass {
88
99
self
89
100
}
90
101
102
+ /// Sets whether the exported [FuncDefn](hugr_core::ops::FuncDefn) children of a
103
+ /// [Module](hugr_core::ops::Module) are included as entry points (yes by default)
104
+ pub fn include_module_exports ( mut self , include : bool ) -> Self {
105
+ self . include_exports = include;
106
+ self
107
+ }
108
+
91
109
/// Runs the pass (see [remove_dead_funcs]) with this configuration
92
110
pub fn run < H : HugrMut > ( & self , hugr : & mut H ) -> Result < ( ) , RemoveDeadFuncsError > {
93
111
self . validation . run_validated_pass ( hugr, |hugr : & mut H , _| {
94
- remove_dead_funcs ( hugr, self . entry_points . iter ( ) . cloned ( ) )
112
+ let exports = if hugr. root_type ( ) . is_module ( ) && self . include_exports {
113
+ hugr. children ( hugr. root ( ) )
114
+ . filter ( |ch| {
115
+ hugr. get_optype ( * ch)
116
+ . as_func_defn ( )
117
+ . is_some_and ( |fd| fd. public )
118
+ } )
119
+ . collect ( )
120
+ } else {
121
+ vec ! [ ]
122
+ } ;
123
+ remove_dead_funcs ( hugr, self . entry_points . iter ( ) . cloned ( ) . chain ( exports) )
95
124
} )
96
125
}
97
126
}
@@ -145,26 +174,29 @@ mod test {
145
174
use super :: RemoveDeadFuncsPass ;
146
175
147
176
#[ rstest]
148
- #[ case( [ ] , vec![ ] ) ] // No entry_points removes everything!
149
- #[ case( [ "main" ] , vec![ "from_main" , "main" ] ) ]
150
- #[ case( [ "from_main" ] , vec![ "from_main" ] ) ]
151
- #[ case( [ "other1" ] , vec![ "other1" , "other2" ] ) ]
152
- #[ case( [ "other2" ] , vec![ "other2" ] ) ]
153
- #[ case( [ "other1" , "other2" ] , vec![ "other1" , "other2" ] ) ]
177
+ #[ case( false , [ ] , vec![ ] ) ] // No entry_points removes everything!
178
+ #[ case( false , [ "main" ] , vec![ "from_main" , "main" ] ) ]
179
+ #[ case( false , [ "from_main" ] , vec![ "from_main" ] ) ]
180
+ #[ case( false , [ "other1" ] , vec![ "other1" , "other2" ] ) ]
181
+ #[ case( false , [ "other2" ] , vec![ "other2" ] ) ]
182
+ #[ case( false , [ "other1" , "other2" ] , vec![ "other1" , "other2" ] ) ]
183
+ #[ case( true , [ ] , vec![ "from_main" , "main" , "other2" ] ) ]
184
+ #[ case( true , [ "other1" ] , vec![ "from_main" , "main" , "other1" , "other2" ] ) ]
154
185
fn remove_dead_funcs_entry_points (
186
+ #[ case] include_exports : bool ,
155
187
#[ case] entry_points : impl IntoIterator < Item = & ' static str > ,
156
188
#[ case] retained_funcs : Vec < & ' static str > ,
157
189
) -> Result < ( ) , Box < dyn std:: error:: Error > > {
158
190
let mut hb = ModuleBuilder :: new ( ) ;
159
191
let o2 = hb. define_function ( "other2" , Signature :: new_endo ( usize_t ( ) ) ) ?;
160
192
let o2inp = o2. input_wires ( ) ;
161
193
let o2 = o2. finish_with_outputs ( o2inp) ?;
162
- let mut o1 = hb. define_function ( "other1" , Signature :: new_endo ( usize_t ( ) ) ) ?;
194
+ let mut o1 = hb. define_function_vis ( "other1" , Signature :: new_endo ( usize_t ( ) ) , false ) ?;
163
195
164
196
let o1c = o1. call ( o2. handle ( ) , & [ ] , o1. input_wires ( ) ) ?;
165
197
o1. finish_with_outputs ( o1c. outputs ( ) ) ?;
166
198
167
- let fm = hb. define_function ( "from_main" , Signature :: new_endo ( usize_t ( ) ) ) ?;
199
+ let fm = hb. define_function_vis ( "from_main" , Signature :: new_endo ( usize_t ( ) ) , false ) ?;
168
200
let f_inp = fm. input_wires ( ) ;
169
201
let fm = fm. finish_with_outputs ( f_inp) ?;
170
202
let mut m = hb. define_function ( "main" , Signature :: new_endo ( usize_t ( ) ) ) ?;
@@ -183,6 +215,7 @@ mod test {
183
215
. collect :: < HashMap < _ , _ > > ( ) ;
184
216
185
217
RemoveDeadFuncsPass :: default ( )
218
+ . include_module_exports ( include_exports)
186
219
. with_module_entry_points (
187
220
entry_points
188
221
. into_iter ( )
0 commit comments