@@ -36,8 +36,12 @@ def can_fuse(n):
36
36
if dag .in_degree (op2 ) != 1 :
37
37
return False
38
38
39
- # if input is used by another node then don't fuse
39
+ # if input is one of the arrays being computed then don't fuse
40
40
op2_input = next (dag .predecessors (op2 ))
41
+ if array_names is not None and op2_input in array_names :
42
+ return False
43
+
44
+ # if input is used by another node then don't fuse
41
45
if dag .out_degree (op2_input ) != 1 :
42
46
return False
43
47
@@ -143,6 +147,7 @@ def can_fuse_predecessors(
143
147
dag ,
144
148
name ,
145
149
* ,
150
+ array_names = None ,
146
151
max_total_source_arrays = 4 ,
147
152
max_total_num_input_blocks = None ,
148
153
always_fuse = None ,
@@ -164,6 +169,20 @@ def can_fuse_predecessors(
164
169
logger .debug ("can't fuse %s since no predecessor ops can be fused" , name )
165
170
return False
166
171
172
+ # if a predecessor op produces one of the arrays being computed, then don't fuse
173
+ if array_names is not None :
174
+ predecessor_array_names = set (
175
+ array_name for _ , array_name , _ in predecessor_ops_and_arrays (dag , name )
176
+ )
177
+ array_names_intersect = set (array_names ) & predecessor_array_names
178
+ if len (array_names_intersect ) > 0 :
179
+ logger .debug (
180
+ "can't fuse %s since predecessor ops produce one or more arrays being computed %s" ,
181
+ name ,
182
+ array_names_intersect ,
183
+ )
184
+ return False
185
+
167
186
# if node is in never_fuse or always_fuse list then it overrides logic below
168
187
if never_fuse is not None and name in never_fuse :
169
188
logger .debug ("can't fuse %s since it is in 'never_fuse'" , name )
@@ -217,6 +236,7 @@ def fuse_predecessors(
217
236
if not can_fuse_predecessors (
218
237
dag ,
219
238
name ,
239
+ array_names = array_names ,
220
240
max_total_source_arrays = max_total_source_arrays ,
221
241
max_total_num_input_blocks = max_total_num_input_blocks ,
222
242
always_fuse = always_fuse ,
0 commit comments