@@ -99,6 +99,23 @@ def predecessor_ops(dag, name):
99
99
yield pre_list [0 ]
100
100
101
101
102
+ def predecessor_ops_and_arrays (dag , name ):
103
+ # returns op predecessors, the arrays that they produce (only one since we don't support multiple outputs yet),
104
+ # and a flag indicating if the op can be fused with each predecessor, taking into account the number of dependents for the array
105
+ nodes = dict (dag .nodes (data = True ))
106
+ for input in nodes [name ]["primitive_op" ].source_array_names :
107
+ pre_list = list (predecessors_unordered (dag , input ))
108
+ assert len (pre_list ) == 1 # each array is produced by a single op
109
+ pre = pre_list [0 ]
110
+ can_fuse = is_primitive_op (nodes [pre ]) and out_degree_unique (dag , input ) == 1
111
+ yield pre , input , can_fuse
112
+
113
+
114
+ def out_degree_unique (dag , name ):
115
+ """Returns number of unique out edges"""
116
+ return len (set (post for _ , post in dag .out_edges (name )))
117
+
118
+
102
119
def is_primitive_op (node_dict ):
103
120
"""Return True if a node is a primitive op"""
104
121
return "primitive_op" in node_dict
@@ -142,7 +159,8 @@ def can_fuse_predecessors(
142
159
return False
143
160
144
161
# if no predecessor ops can be fused then there is nothing to fuse
145
- if all (not is_primitive_op (nodes [pre ]) for pre in predecessor_ops (dag , name )):
162
+ # (this may be because predecessor ops produce arrays with multiple dependents)
163
+ if all (not can_fuse for _ , _ , can_fuse in predecessor_ops_and_arrays (dag , name )):
146
164
logger .debug ("can't fuse %s since no predecessor ops can be fused" , name )
147
165
return False
148
166
@@ -158,8 +176,8 @@ def can_fuse_predecessors(
158
176
# the fused function would be more than an allowed maximum, then don't fuse
159
177
if len (list (predecessor_ops (dag , name ))) > 1 :
160
178
total_source_arrays = sum (
161
- num_source_arrays (dag , pre ) if is_primitive_op ( nodes [ pre ]) else 1
162
- for pre in predecessor_ops (dag , name )
179
+ num_source_arrays (dag , pre ) if can_fuse else 1
180
+ for pre , _ , can_fuse in predecessor_ops_and_arrays (dag , name )
163
181
)
164
182
if total_source_arrays > max_total_source_arrays :
165
183
logger .debug (
@@ -172,8 +190,8 @@ def can_fuse_predecessors(
172
190
173
191
predecessor_primitive_ops = [
174
192
nodes [pre ]["primitive_op" ]
175
- for pre in predecessor_ops (dag , name )
176
- if is_primitive_op ( nodes [ pre ])
193
+ for pre , _ , can_fuse in predecessor_ops_and_arrays (dag , name )
194
+ if can_fuse
177
195
]
178
196
return can_fuse_multiple_primitive_ops (
179
197
name ,
@@ -211,8 +229,8 @@ def fuse_predecessors(
211
229
212
230
# if a predecessor has no primitive op then just use None
213
231
predecessor_primitive_ops = [
214
- nodes [pre ]["primitive_op" ] if is_primitive_op ( nodes [ pre ]) else None
215
- for pre in predecessor_ops (dag , name )
232
+ nodes [pre ]["primitive_op" ] if can_fuse else None
233
+ for pre , _ , can_fuse in predecessor_ops_and_arrays (dag , name )
216
234
]
217
235
218
236
fused_primitive_op = fuse_multiple (primitive_op , * predecessor_primitive_ops )
@@ -224,28 +242,15 @@ def fuse_predecessors(
224
242
fused_nodes [name ]["pipeline" ] = fused_primitive_op .pipeline
225
243
226
244
# re-wire dag to remove predecessor nodes that have been fused
227
-
228
- # 1. update edges to change inputs
229
- for input in predecessors_unordered (dag , name ):
230
- pre = next (predecessors_unordered (dag , input ))
231
- if not is_primitive_op (fused_nodes [pre ]):
232
- # if a predecessor is not fusable then don't change the edge
233
- continue
234
- fused_dag .remove_edge (input , name )
235
- for pre in predecessor_ops (dag , name ):
236
- if not is_primitive_op (fused_nodes [pre ]):
237
- # if a predecessor is not fusable then don't change the edge
238
- continue
239
- for input in predecessors_unordered (dag , pre ):
240
- fused_dag .add_edge (input , name )
241
-
242
- # 2. remove predecessor nodes with no successors
243
- # (ones with successors are needed by other nodes)
244
- for input in predecessors_unordered (dag , name ):
245
- if fused_dag .out_degree (input ) == 0 :
246
- for pre in list (predecessors_unordered (fused_dag , input )):
245
+ for pre , input , can_fuse in predecessor_ops_and_arrays (dag , name ):
246
+ if can_fuse :
247
+ # check if already removed for repeated arguments
248
+ if input in fused_dag :
249
+ fused_dag .remove_node (input )
250
+ if pre in fused_dag :
247
251
fused_dag .remove_node (pre )
248
- fused_dag .remove_node (input )
252
+ for pre_input in predecessors_unordered (dag , pre ):
253
+ fused_dag .add_edge (pre_input , name )
249
254
250
255
return fused_dag
251
256
0 commit comments