Skip to content

Commit 0ac1e47

Browse files
committed
Separate result setting RPC call
1 parent 5e80e46 commit 0ac1e47

File tree

5 files changed

+61
-20
lines changed

5 files changed

+61
-20
lines changed

mars/deploy/oscar/tests/test_local.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,8 @@
9393
"serialization": {},
9494
"most_calls": DICT_NOT_EMPTY,
9595
"slow_calls": DICT_NOT_EMPTY,
96-
"band_subtasks": DICT_NOT_EMPTY,
97-
"slow_subtasks": DICT_NOT_EMPTY,
96+
# "band_subtasks": DICT_NOT_EMPTY,
97+
# "slow_subtasks": DICT_NOT_EMPTY,
9898
}
9999
}
100100
EXPECT_PROFILING_STRUCTURE_NO_SLOW = copy.deepcopy(EXPECT_PROFILING_STRUCTURE)

mars/services/scheduling/supervisor/manager.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,13 @@ async def _get_execution_ref(self, band: BandType):
172172

173173
return await mo.actor_ref(SubtaskExecutionActor.default_uid(), address=band[0])
174174

175-
async def _handle_subtask_result(
176-
self, info: SubtaskScheduleInfo, result: SubtaskResult, band: BandType
175+
async def set_subtask_result(
176+
self, result: SubtaskResult, band: BandType
177177
):
178+
info = self._subtask_infos[result.subtask_id]
178179
subtask_id = info.subtask.subtask_id
180+
notify_task_service = True
181+
179182
async with redirect_subtask_errors(self, [info.subtask], reraise=False):
180183
try:
181184
info.band_futures[band].set_result(result)
@@ -199,6 +202,7 @@ async def _handle_subtask_result(
199202
[info.subtask.priority or tuple()],
200203
exclude_bands=set(info.band_futures.keys()),
201204
)
205+
notify_task_service = False
202206
else:
203207
raise ex
204208
except asyncio.CancelledError:
@@ -236,6 +240,10 @@ async def _handle_subtask_result(
236240
if info.num_reschedules > 0:
237241
await self._queueing_ref.submit_subtasks.tell()
238242

243+
if notify_task_service:
244+
task_api = await self._get_task_api()
245+
await task_api.set_subtask_result(result)
246+
239247
async def finish_subtasks(
240248
self,
241249
subtask_results: List[SubtaskResult],
@@ -251,11 +259,6 @@ async def finish_subtasks(
251259
subtask_info = self._subtask_infos.get(subtask_id, None)
252260

253261
if subtask_info is not None:
254-
if subtask_band is not None:
255-
await self._handle_subtask_result(
256-
subtask_info, result, subtask_band
257-
)
258-
259262
self._finished_subtask_count.record(
260263
1,
261264
{
@@ -273,7 +276,7 @@ async def finish_subtasks(
273276
# Cancel subtask on other bands.
274277
aio_task = subtask_info.band_futures.pop(subtask_band, None)
275278
if aio_task:
276-
await aio_task
279+
yield aio_task
277280
if schedule_next:
278281
band_tasks[subtask_band] += 1
279282
if subtask_info.band_futures:

mars/services/scheduling/supervisor/tests/test_manager.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from .....typing import BandType
2424
from ....cluster import MockClusterAPI
2525
from ....subtask import Subtask, SubtaskResult, SubtaskStatus
26-
from ....task import TaskAPI
2726
from ....task.supervisor.manager import TaskManagerActor
2827
from ...supervisor import (
2928
SubtaskQueueingActor,
@@ -91,7 +90,10 @@ async def run_subtask(
9190
self._run_subtask_events[subtask.subtask_id].set()
9291

9392
async def task_fun():
94-
task_api = await TaskAPI.create(subtask.session_id, supervisor_address)
93+
manager_ref = await mo.actor_ref(
94+
uid=SubtaskManagerActor.gen_uid(subtask.session_id),
95+
address=supervisor_address,
96+
)
9597
result = SubtaskResult(
9698
subtask_id=subtask.subtask_id,
9799
session_id=subtask.session_id,
@@ -107,12 +109,12 @@ async def task_fun():
107109
result.status = SubtaskStatus.cancelled
108110
result.error = ex
109111
result.traceback = ex.__traceback__
110-
await task_api.set_subtask_result(result)
112+
await manager_ref.set_subtask_result.tell(result, (self.address, band_name))
111113
raise
112114
else:
113115
result.status = SubtaskStatus.succeeded
114116
result.execution_end_time = time.time()
115-
await task_api.set_subtask_result(result)
117+
await manager_ref.set_subtask_result.tell(result, (self.address, band_name))
116118

117119
self._subtask_aiotasks[subtask.subtask_id][band_name] = asyncio.create_task(
118120
task_fun()

mars/services/scheduling/worker/execution.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from ...meta import MetaAPI
3737
from ...storage import StorageAPI
3838
from ...subtask import Subtask, SubtaskAPI, SubtaskResult, SubtaskStatus
39-
from ...task import TaskAPI
4039
from .quota import QuotaActor
4140
from .workerslot import BandSlotManagerActor
4241

@@ -178,6 +177,17 @@ async def _get_slot_manager_ref(
178177
BandSlotManagerActor.gen_uid(band), address=self.address
179178
)
180179

180+
@classmethod
181+
@alru_cache(cache_exceptions=False)
182+
async def _get_manager_ref(
183+
cls, session_id: str, supervisor_address: str
184+
) -> mo.ActorRefType[BandSlotManagerActor]:
185+
from ..supervisor import SubtaskManagerActor
186+
187+
return await mo.actor_ref(
188+
SubtaskManagerActor.gen_uid(session_id), address=supervisor_address
189+
)
190+
181191
@alru_cache(cache_exceptions=False)
182192
async def _get_band_quota_ref(self, band: str) -> mo.ActorRefType[QuotaActor]:
183193
return await mo.actor_ref(QuotaActor.gen_uid(band), address=self.address)
@@ -415,10 +425,12 @@ async def internal_run_subtask(self, subtask: Subtask, band_name: str):
415425
# pop the subtask info at the end is to cancel the job.
416426
self._subtask_info.pop(subtask.subtask_id, None)
417427

418-
task_api = await TaskAPI.create(
428+
manager_ref = await self._get_manager_ref(
419429
subtask.session_id, subtask_info.supervisor_address
420430
)
421-
await task_api.set_subtask_result(subtask_info.result)
431+
await manager_ref.set_subtask_result.tell(
432+
subtask_info.result, (self.address, subtask_info.band_name)
433+
)
422434
return subtask_info.result
423435

424436
async def _retry_run_subtask(
@@ -557,8 +569,10 @@ async def subtask_caller():
557569
)
558570
_fill_subtask_result_with_exception(subtask, band_name, res)
559571

560-
task_api = await TaskAPI.create(subtask.session_id, supervisor_address)
561-
await task_api.set_subtask_result(res)
572+
manager_ref = await self._get_manager_ref(
573+
subtask.session_id, supervisor_address
574+
)
575+
await manager_ref.set_subtask_result.tell(res, (self.address, band_name))
562576
finally:
563577
self._subtask_info.pop(subtask_id, None)
564578
self._finished_subtask_count.record(1, {"band": self.address})

mars/services/scheduling/worker/tests/test_execution.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from .....resource import Resource
3838
from .....tensor.fetch import TensorFetch
3939
from .....tensor.arithmetic import TensorTreeAdd
40+
from .....typing import BandType
4041
from .....utils import Timer
4142
from ....cluster import MockClusterAPI
4243
from ....lifecycle import MockLifecycleAPI
@@ -47,7 +48,7 @@
4748
from ....subtask import MockSubtaskAPI, Subtask, SubtaskStatus, SubtaskResult
4849
from ....task.supervisor.manager import TaskManagerActor
4950
from ....mutable import MockMutableAPI
50-
from ...supervisor import GlobalResourceManagerActor
51+
from ...supervisor import GlobalResourceManagerActor, SubtaskManagerActor
5152
from ...worker import SubtaskExecutionActor, QuotaActor, BandSlotManagerActor
5253

5354

@@ -155,6 +156,19 @@ def get_results(self):
155156
return list(self._results.values())
156157

157158

159+
class MockSubtaskManagerActor(mo.Actor):
160+
def __init__(self, session_id: str):
161+
self._session_id = session_id
162+
163+
async def __post_create__(self):
164+
self._task_manager_ref = await mo.actor_ref(
165+
uid=TaskManagerActor.gen_uid(self._session_id), address=self.address
166+
)
167+
168+
async def set_subtask_result(self, result: SubtaskResult, band: BandType):
169+
await self._task_manager_ref.set_subtask_result.tell(result)
170+
171+
158172
@pytest.fixture
159173
async def actor_pool(request):
160174
n_slots, enable_kill = request.param
@@ -221,9 +235,17 @@ async def actor_pool(request):
221235
address=pool.external_address,
222236
)
223237

238+
subtask_manager_ref = await mo.create_actor(
239+
MockSubtaskManagerActor,
240+
session_id,
241+
uid=SubtaskManagerActor.gen_uid(session_id),
242+
address=pool.external_address,
243+
)
244+
224245
try:
225246
yield pool, session_id, meta_api, worker_meta_api, storage_api, execution_ref
226247
finally:
248+
await mo.destroy_actor(subtask_manager_ref)
227249
await mo.destroy_actor(task_manager_ref)
228250
await mo.destroy_actor(band_slot_ref)
229251
await mo.destroy_actor(global_resource_ref)

0 commit comments

Comments
 (0)