@@ -219,6 +219,8 @@ def _execute_operand(self, op):
219
219
results = self ._chunk_results
220
220
ref_counts = self ._chunk_key_ref_counts
221
221
op_keys = self ._executed_op_keys
222
+ executed_chunk_keys = set ()
223
+ deleted_chunk_keys = set ()
222
224
try :
223
225
ops = list (self ._op_key_to_ops [op .key ])
224
226
if not self ._mock :
@@ -227,11 +229,15 @@ def _execute_operand(self, op):
227
229
# so we pass the first operand's first output to Executor.handle
228
230
first_op = ops [0 ]
229
231
Executor .handle (first_op .outputs [0 ], results )
232
+ executed_chunk_keys .update ([c .key for c in first_op .outputs ])
230
233
op_keys .add (first_op .key )
231
234
# handle other operands
232
235
for rest_op in ops [1 :]:
233
236
for op_output , rest_op_output in zip (first_op .outputs , rest_op .outputs ):
234
- results [rest_op_output .key ] = results [op_output .key ]
237
+ # if the op's outputs have been stored,
238
+ # other same key ops' results will be the same
239
+ if rest_op_output .key not in executed_chunk_keys :
240
+ results [rest_op_output .key ] = results [op_output .key ]
235
241
else :
236
242
sparse_percent = self ._sparse_mock_percent if op .sparse else 1.0
237
243
for output in op .outputs :
@@ -245,7 +251,10 @@ def _execute_operand(self, op):
245
251
# in case that operand has multiple outputs
246
252
# and some of the output not in result keys, delete them
247
253
if ref_counts .get (output .key ) == 0 :
248
- del results [output .key ]
254
+ # if the result has been deleted, it should be skipped
255
+ if output .key not in deleted_chunk_keys :
256
+ deleted_chunk_keys .add (output .key )
257
+ del results [output .key ]
249
258
250
259
# clean the predecessors' results if ref counts equals 0
251
260
for pred_chunk in self ._graph .iter_predecessors (output ):
0 commit comments