Skip to content

Commit 5901054

Browse files
committed
Fix for multiple_inputs_optimize_dag
Fix for simple_optimize_dag
1 parent 381aa7d commit 5901054

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

cubed/core/optimization.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,12 @@ def can_fuse(n):
3636
if dag.in_degree(op2) != 1:
3737
return False
3838

39-
# if input is used by another node then don't fuse
39+
# if input is one of the arrays being computed then don't fuse
4040
op2_input = next(dag.predecessors(op2))
41+
if array_names is not None and op2_input in array_names:
42+
return False
43+
44+
# if input is used by another node then don't fuse
4145
if dag.out_degree(op2_input) != 1:
4246
return False
4347

@@ -143,6 +147,7 @@ def can_fuse_predecessors(
143147
dag,
144148
name,
145149
*,
150+
array_names=None,
146151
max_total_source_arrays=4,
147152
max_total_num_input_blocks=None,
148153
always_fuse=None,
@@ -164,6 +169,20 @@ def can_fuse_predecessors(
164169
logger.debug("can't fuse %s since no predecessor ops can be fused", name)
165170
return False
166171

172+
# if a predecessor op produces one of the arrays being computed, then don't fuse
173+
if array_names is not None:
174+
predecessor_array_names = set(
175+
array_name for _, array_name, _ in predecessor_ops_and_arrays(dag, name)
176+
)
177+
array_names_intersect = set(array_names) & predecessor_array_names
178+
if len(array_names_intersect) > 0:
179+
logger.debug(
180+
"can't fuse %s since predecessor ops produce one or more arrays being computed %s",
181+
name,
182+
array_names_intersect,
183+
)
184+
return False
185+
167186
# if node is in never_fuse or always_fuse list then it overrides logic below
168187
if never_fuse is not None and name in never_fuse:
169188
logger.debug("can't fuse %s since it is in 'never_fuse'", name)
@@ -217,6 +236,7 @@ def fuse_predecessors(
217236
if not can_fuse_predecessors(
218237
dag,
219238
name,
239+
array_names=array_names,
220240
max_total_source_arrays=max_total_source_arrays,
221241
max_total_num_input_blocks=max_total_num_input_blocks,
222242
always_fuse=always_fuse,

0 commit comments

Comments
 (0)