22
22
from .... import oscar as mo
23
23
from ....lib .aio import alru_cache
24
24
from ....metrics import Metrics
25
- from ....oscar .backends .context import ProfilingContext
26
25
from ....oscar .errors import MarsError
27
- from ....oscar .profiling import ProfilingData , MARS_ENABLE_PROFILING
28
26
from ....typing import BandType
29
- from ....utils import dataslots , Timer
27
+ from ....utils import dataslots
30
28
from ...subtask import Subtask , SubtaskResult , SubtaskStatus
31
29
from ...task import TaskAPI
32
30
from ..core import SubtaskScheduleSummary
@@ -127,14 +125,6 @@ async def __post_create__(self):
127
125
)
128
126
await self ._speculation_execution_scheduler .start ()
129
127
130
- async def dump_running ():
131
- while True :
132
- if self ._subtask_infos :
133
- logger .warning ("RUNNING: %r" , list (self ._subtask_infos ))
134
- await asyncio .sleep (5 )
135
-
136
- asyncio .create_task (dump_running ())
137
-
138
128
async def __pre_destroy__ (self ):
139
129
await self ._speculation_execution_scheduler .stop ()
140
130
@@ -186,7 +176,7 @@ async def _handle_subtask_result(
186
176
self , info : SubtaskScheduleInfo , result : SubtaskResult , band : BandType
187
177
):
188
178
subtask_id = info .subtask .subtask_id
189
- async with redirect_subtask_errors (self , [info .subtask ]):
179
+ async with redirect_subtask_errors (self , [info .subtask ], reraise = False ):
190
180
try :
191
181
info .band_futures [band ].set_result (result )
192
182
if result .error is not None :
@@ -262,9 +252,9 @@ async def finish_subtasks(
262
252
263
253
if subtask_info is not None :
264
254
if subtask_band is not None :
265
- logger . warning ( "BEFORE await self._handle_subtask_result(subtask_info, result, subtask_band)" )
266
- await self . _handle_subtask_result ( subtask_info , result , subtask_band )
267
- logger . warning ( "AFTER await self._handle_subtask_result(subtask_info, result, subtask_band)" )
255
+ await self ._handle_subtask_result (
256
+ subtask_info , result , subtask_band
257
+ )
268
258
269
259
self ._finished_subtask_count .record (
270
260
1 ,
@@ -275,16 +265,15 @@ async def finish_subtasks(
275
265
},
276
266
)
277
267
self ._subtask_summaries [subtask_id ] = subtask_info .to_summary (
278
- is_finished = True , is_cancelled = result .status == SubtaskStatus .cancelled
268
+ is_finished = True ,
269
+ is_cancelled = result .status == SubtaskStatus .cancelled ,
279
270
)
280
271
subtask_info .end_time = time .time ()
281
272
self ._speculation_execution_scheduler .finish_subtask (subtask_info )
282
273
# Cancel subtask on other bands.
283
274
aio_task = subtask_info .band_futures .pop (subtask_band , None )
284
275
if aio_task :
285
- logger .warning ("BEFORE await aio_task" )
286
276
await aio_task
287
- logger .warning ("AFTER await aio_task" )
288
277
if schedule_next :
289
278
band_tasks [subtask_band ] += 1
290
279
if subtask_info .band_futures :
@@ -304,7 +293,6 @@ async def finish_subtasks(
304
293
if schedule_next :
305
294
for band in subtask_info .band_futures .keys ():
306
295
band_tasks [band ] += 1
307
- # await self._queueing_ref.remove_queued_subtasks(subtask_ids)
308
296
if band_tasks :
309
297
await self ._queueing_ref .submit_subtasks .tell (dict (band_tasks ))
310
298
@@ -345,7 +333,9 @@ async def batch_submit_subtask_to_band(self, args_list, kwargs_list):
345
333
band_to_subtask_ids [band ].append (subtask_id )
346
334
347
335
if res_release_delays :
348
- await self ._global_resource_ref .release_subtask_resource .batch (* res_release_delays )
336
+ await self ._global_resource_ref .release_subtask_resource .batch (
337
+ * res_release_delays
338
+ )
349
339
350
340
for band , subtask_ids in band_to_subtask_ids .items ():
351
341
asyncio .create_task (self ._submit_subtasks_to_band (band , subtask_ids ))
@@ -386,29 +376,22 @@ async def cancel_subtasks(
386
376
subtask_ids ,
387
377
kill_timeout ,
388
378
)
389
- queued_subtask_ids = []
390
- single_cancel_tasks = []
391
379
392
380
task_api = await self ._get_task_api ()
393
381
394
- async def cancel_single_task (subtask , raw_tasks , cancel_tasks ):
395
- if cancel_tasks :
396
- await asyncio .wait (cancel_tasks )
397
- if raw_tasks :
398
- dones , _ = await asyncio .wait (raw_tasks )
399
- else :
400
- dones = []
401
- if not dones or all (fut .cancelled () for fut in dones ):
402
- await task_api .set_subtask_result (
403
- SubtaskResult (
404
- subtask_id = subtask .subtask_id ,
405
- session_id = subtask .session_id ,
406
- task_id = subtask .task_id ,
407
- stage_id = subtask .stage_id ,
408
- status = SubtaskStatus .cancelled ,
409
- )
410
- )
382
+ async def cancel_task_in_band (band ):
383
+ cancel_delays = band_to_cancel_delays .get (band ) or []
384
+ execution_ref = await self ._get_execution_ref (band )
385
+ if cancel_delays :
386
+ await execution_ref .cancel_subtask .batch (* cancel_delays )
387
+ band_futures = band_to_futures .get (band )
388
+ if band_futures :
389
+ await asyncio .wait (band_futures )
411
390
391
+ queued_subtask_ids = []
392
+ cancel_tasks = []
393
+ band_to_cancel_delays = defaultdict (list )
394
+ band_to_futures = defaultdict (list )
412
395
for subtask_id in subtask_ids :
413
396
if subtask_id not in self ._subtask_infos :
414
397
# subtask may already finished or not submitted at all
@@ -423,35 +406,33 @@ async def cancel_single_task(subtask, raw_tasks, cancel_tasks):
423
406
raw_tasks_to_cancel = list (info .band_futures .values ())
424
407
425
408
if not raw_tasks_to_cancel :
426
- queued_subtask_ids .append (subtask_id )
427
- single_cancel_tasks .append (
428
- asyncio .create_task (
429
- cancel_single_task (info .subtask , [], [])
430
- )
409
+ # not submitted yet: mark subtasks as cancelled
410
+ result = SubtaskResult (
411
+ subtask_id = info .subtask .subtask_id ,
412
+ session_id = info .subtask .session_id ,
413
+ task_id = info .subtask .task_id ,
414
+ stage_id = info .subtask .stage_id ,
415
+ status = SubtaskStatus .cancelled ,
431
416
)
417
+ cancel_tasks .append (task_api .set_subtask_result (result ))
418
+ queued_subtask_ids .append (subtask_id )
432
419
else :
433
- cancel_tasks = []
434
- for band in info .band_futures .keys ():
420
+ for band , future in info .band_futures .items ():
435
421
execution_ref = await self ._get_execution_ref (band )
436
- cancel_tasks .append (
437
- asyncio .create_task (
438
- execution_ref .cancel_subtask (
439
- subtask_id , kill_timeout = kill_timeout
440
- )
441
- )
422
+ band_to_cancel_delays [band ].append (
423
+ execution_ref .cancel_subtask .delay (subtask_id , kill_timeout )
442
424
)
443
- single_cancel_tasks .append (
444
- asyncio .create_task (
445
- cancel_single_task (
446
- info .subtask , raw_tasks_to_cancel , cancel_tasks
447
- )
448
- )
449
- )
425
+ band_to_futures [band ].append (future )
426
+
427
+ for band in band_to_futures :
428
+ cancel_tasks .append (asyncio .create_task (cancel_task_in_band (band )))
429
+
450
430
if queued_subtask_ids :
451
431
# Don't use `finish_subtasks` because it may remove queued
452
432
await self ._queueing_ref .remove_queued_subtasks (queued_subtask_ids )
453
- if single_cancel_tasks :
454
- yield asyncio .wait (single_cancel_tasks )
433
+
434
+ if cancel_tasks :
435
+ yield asyncio .gather (* cancel_tasks )
455
436
456
437
for subtask_id in subtask_ids :
457
438
info = self ._subtask_infos .pop (subtask_id , None )
0 commit comments