Skip to content

Commit 5cee462

Browse files
committed
Only fuse predecessor ops that produce arrays with single dependent op
1 parent fcd4d21 commit 5cee462

File tree

2 files changed

+46
-46
lines changed

2 files changed

+46
-46
lines changed

cubed/core/optimization.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,23 @@ def predecessor_ops(dag, name):
9999
yield pre_list[0]
100100

101101

102+
def predecessor_ops_and_arrays(dag, name):
    """Yield ``(pre, array_name, can_fuse)`` for each source array of op ``name``.

    The source arrays are taken, in order, from the ``source_array_names`` of
    the ``primitive_op`` stored on node ``name``, so an array that is passed
    more than once is yielded more than once.

    Each yielded tuple holds:

    * ``pre`` - the op that produces the array (only one, since multiple
      outputs are not supported yet),
    * ``array_name`` - the name of the array itself,
    * ``can_fuse`` - True if ``pre`` is a primitive op and the array has
      exactly one unique dependent op, i.e. ``name`` can be fused with it.
    """
    nodes = dict(dag.nodes(data=True))
    for array_name in nodes[name]["primitive_op"].source_array_names:
        pre_list = list(predecessors_unordered(dag, array_name))
        assert len(pre_list) == 1  # each array is produced by a single op
        pre = pre_list[0]
        # an array with other dependents must still be materialized, so the
        # op producing it cannot be fused away
        can_fuse = (
            is_primitive_op(nodes[pre]) and out_degree_unique(dag, array_name) == 1
        )
        yield pre, array_name, can_fuse
112+
113+
114+
def out_degree_unique(dag, name):
    """Return the number of unique out edges of node ``name`` in ``dag``.

    Edges are identified by their target node, so several edges from
    ``name`` to the same successor are counted once.
    """
    return len({post for _, post in dag.out_edges(name)})
117+
118+
102119
def is_primitive_op(node_dict):
    """Return True if a node is a primitive op.

    ``node_dict`` is the attribute mapping of a DAG node; the node counts as
    a primitive op when that mapping has a ``"primitive_op"`` key.
    """
    return "primitive_op" in node_dict
@@ -142,7 +159,8 @@ def can_fuse_predecessors(
142159
return False
143160

144161
# if no predecessor ops can be fused then there is nothing to fuse
145-
if all(not is_primitive_op(nodes[pre]) for pre in predecessor_ops(dag, name)):
162+
# (this may be because predecessor ops produce arrays with multiple dependents)
163+
if all(not can_fuse for _, _, can_fuse in predecessor_ops_and_arrays(dag, name)):
146164
logger.debug("can't fuse %s since no predecessor ops can be fused", name)
147165
return False
148166

@@ -158,8 +176,8 @@ def can_fuse_predecessors(
158176
# the fused function would be more than an allowed maximum, then don't fuse
159177
if len(list(predecessor_ops(dag, name))) > 1:
160178
total_source_arrays = sum(
161-
num_source_arrays(dag, pre) if is_primitive_op(nodes[pre]) else 1
162-
for pre in predecessor_ops(dag, name)
179+
num_source_arrays(dag, pre) if can_fuse else 1
180+
for pre, _, can_fuse in predecessor_ops_and_arrays(dag, name)
163181
)
164182
if total_source_arrays > max_total_source_arrays:
165183
logger.debug(
@@ -172,8 +190,8 @@ def can_fuse_predecessors(
172190

173191
predecessor_primitive_ops = [
174192
nodes[pre]["primitive_op"]
175-
for pre in predecessor_ops(dag, name)
176-
if is_primitive_op(nodes[pre])
193+
for pre, _, can_fuse in predecessor_ops_and_arrays(dag, name)
194+
if can_fuse
177195
]
178196
return can_fuse_multiple_primitive_ops(
179197
name,
@@ -211,8 +229,8 @@ def fuse_predecessors(
211229

212230
# if a predecessor has no primitive op then just use None
213231
predecessor_primitive_ops = [
214-
nodes[pre]["primitive_op"] if is_primitive_op(nodes[pre]) else None
215-
for pre in predecessor_ops(dag, name)
232+
nodes[pre]["primitive_op"] if can_fuse else None
233+
for pre, _, can_fuse in predecessor_ops_and_arrays(dag, name)
216234
]
217235

218236
fused_primitive_op = fuse_multiple(primitive_op, *predecessor_primitive_ops)
@@ -224,28 +242,15 @@ def fuse_predecessors(
224242
fused_nodes[name]["pipeline"] = fused_primitive_op.pipeline
225243

226244
# re-wire dag to remove predecessor nodes that have been fused
227-
228-
# 1. update edges to change inputs
229-
for input in predecessors_unordered(dag, name):
230-
pre = next(predecessors_unordered(dag, input))
231-
if not is_primitive_op(fused_nodes[pre]):
232-
# if a predecessor is not fusable then don't change the edge
233-
continue
234-
fused_dag.remove_edge(input, name)
235-
for pre in predecessor_ops(dag, name):
236-
if not is_primitive_op(fused_nodes[pre]):
237-
# if a predecessor is not fusable then don't change the edge
238-
continue
239-
for input in predecessors_unordered(dag, pre):
240-
fused_dag.add_edge(input, name)
241-
242-
# 2. remove predecessor nodes with no successors
243-
# (ones with successors are needed by other nodes)
244-
for input in predecessors_unordered(dag, name):
245-
if fused_dag.out_degree(input) == 0:
246-
for pre in list(predecessors_unordered(fused_dag, input)):
245+
for pre, input, can_fuse in predecessor_ops_and_arrays(dag, name):
246+
if can_fuse:
247+
# check if already removed for repeated arguments
248+
if input in fused_dag:
249+
fused_dag.remove_node(input)
250+
if pre in fused_dag:
247251
fused_dag.remove_node(pre)
248-
fused_dag.remove_node(input)
252+
for pre_input in predecessors_unordered(dag, pre):
253+
fused_dag.add_edge(pre_input, name)
249254

250255
return fused_dag
251256

cubed/tests/test_optimization.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -448,9 +448,9 @@ def test_fuse_diamond(spec):
448448
# from https://github.com/cubed-dev/cubed/issues/126
449449
#
450450
# a -> a
451-
# | /|
452-
# b b |
453-
# /| \|
451+
# | |
452+
# b b
453+
# /|
454454
# c | d
455455
# \|
456456
# d
@@ -469,7 +469,7 @@ def test_fuse_mixed_levels_and_diamond(spec):
469469
expected_fused_dag = create_dag()
470470
add_placeholder_op(expected_fused_dag, (), (a,))
471471
add_placeholder_op(expected_fused_dag, (a,), (b,))
472-
add_placeholder_op(expected_fused_dag, (a, b), (d,))
472+
add_placeholder_op(expected_fused_dag, (b, b), (d,))
473473
optimized_dag = d.plan.optimize(optimize_function=opt_fn).dag
474474
assert structurally_equivalent(optimized_dag, expected_fused_dag)
475475
assert get_num_input_blocks(d.plan.dag, d.name) == (1, 1)
@@ -535,35 +535,30 @@ def test_fuse_repeated_argument(spec):
535535
assert_array_equal(result, -2 * np.ones((2, 2)))
536536

537537

538-
# other dependents
538+
# other dependents - no optimization is made in this case (cf previously)
539539
#
540-
# a -> a
541-
# | / \
542-
# b c b
543-
# / \ |
544-
# c d d
540+
# a
541+
# |
542+
# b
543+
# / \
544+
# c d
545545
#
546546
def test_fuse_other_dependents(spec):
547547
a = xp.ones((2, 2), chunks=(2, 2), spec=spec)
548548
b = xp.negative(a)
549549
c = xp.negative(b)
550550
d = xp.negative(b)
551551

552-
# only fuse c; leave d unfused
552+
# try to fuse c; leave d unfused
553553
opt_fn = fuse_one_level(c)
554554

555555
# note multi-arg forms of visualize and compute below
556556
cubed.visualize(c, d, optimize_function=opt_fn)
557557

558-
# check structure of optimized dag
559-
expected_fused_dag = create_dag()
560-
add_placeholder_op(expected_fused_dag, (), (a,))
561-
add_placeholder_op(expected_fused_dag, (a,), (b,))
562-
add_placeholder_op(expected_fused_dag, (a,), (c,))
563-
add_placeholder_op(expected_fused_dag, (b,), (d,))
558+
# optimization does nothing
564559
plan = arrays_to_plan(c, d)
565560
optimized_dag = plan.optimize(optimize_function=opt_fn).dag
566-
assert structurally_equivalent(optimized_dag, expected_fused_dag)
561+
assert structurally_equivalent(optimized_dag, plan.dag)
567562
assert get_num_input_blocks(c.plan.dag, c.name) == (1,)
568563
assert get_num_input_blocks(optimized_dag, c.name) == (1,)
569564

0 commit comments

Comments
 (0)