
Commit c2d4944

[core] Fix race condition when canceling task that hasn't started yet (#52703)
See the linked issue for the original symptom and reproduction.

When a `CancelTask` RPC message is received, we need to handle 4 possible cases:

1. The `PushTask` RPC hasn't been received yet.
2. The `PushTask` RPC has been received but the task isn't executing yet.
3. The `PushTask` RPC has been received and the task is now executing.
4. The task finished executing and the `PushTask` RPC reply has been sent.

The code currently handles (1) and (4) by relying on client-side retries: we return `success=False` and expect the client to retry the cancellation (unless the task has already finished, as in case (4), which the client knows). However, there is a race condition between cases (2) and (3): the task is no longer considered queued in the `OutOfOrderActorSchedulingQueue`, but it hasn't actually started executing yet, so there is no future to cancel. This can happen because:

- We [erase the task ID](https://github.com/ray-project/ray/blob/master/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.cc#L240) from the pending map before actually executing the task. After this, `CancelTaskIfFound` will return false.
- We then post the work to start running the request [to the io_service_](https://github.com/ray-project/ray/blob/master/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.cc#L245).
- We post the `RunRequest` callback that eventually runs the task [to the fiber thread](https://github.com/ray-project/ray/blob/master/src/ray/core_worker/transport/out_of_order_actor_scheduling_queue.cc#L156).
- The logic to cancel the task runs on the [task_execution_service_](https://github.com/ray-project/ray/blob/master/src/ray/core_worker/core_worker.cc#L4485).

This means there is no guarantee that the task has actually started executing when we call [cancel_async_task_](https://github.com/ray-project/ray/blob/master/src/ray/core_worker/core_worker.cc#L4462).

This PR fixes the problem by extending the reliance on client retries: `cancel_async_task_` now returns a boolean indicating whether the task was cancelled. If not, it's up to the client to retry. The proper long-term fix would be to serialize executions and cancellations inside the scheduling queue / task executor, but that requires significant refactoring to simplify the concurrency model in these classes.

Closes #52628

---------

Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
1 parent 0b65b4a commit c2d4944
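The retry-based protocol described in the commit message can be sketched as follows. This is a minimal illustration only: `cancel_with_retries`, `send_cancel_task_rpc`, and `task_is_finished` are hypothetical stand-ins for the `CancelTask` RPC (whose reply now reflects the boolean returned by `cancel_async_task_`) and the caller's knowledge of case (4); the real retry logic lives in the C++ core worker, not in Python.

```python
import time


def cancel_with_retries(task_id, send_cancel_task_rpc, task_is_finished,
                        max_attempts=10, backoff_s=0.1):
    """Sketch of the caller-side retry loop implied by the boolean return value.

    `send_cancel_task_rpc` and `task_is_finished` are hypothetical callables used
    only for illustration; they stand in for the CancelTask RPC and for detecting
    that the PushTask reply has already been sent (case 4).
    """
    for _ in range(max_attempts):
        if task_is_finished(task_id):
            # Case (4): the task already finished, so there is nothing to cancel.
            return True
        # Cases (1)-(3): ask the executor to cancel. It reports False when the task
        # isn't found running yet (PushTask not received, or received but the task
        # hasn't actually started executing).
        if send_cancel_task_rpc(task_id):
            return True
        time.sleep(backoff_s)  # back off and retry until the task is actually running
    return False
```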

File tree

8 files changed: +139 -191 lines changed


python/ray/_raylet.pyx

+28 -28
@@ -36,9 +36,7 @@ from typing import (
 )

 import contextvars
-import concurrent
-from concurrent.futures import ThreadPoolExecutor
-from concurrent.futures import Future as ConcurrentFuture
+import concurrent.futures

 from libc.stdint cimport (
     int32_t,
@@ -253,7 +251,7 @@ GRPC_STATUS_CODE_UNIMPLEMENTED = CGrpcStatusCode.UNIMPLEMENTED

 logger = logging.getLogger(__name__)

-# The currently executing task, if any. These are used to synchronize task
+# The currently running task, if any. These are used to synchronize task
 # interruption for ray.cancel.
 current_task_id = None
 current_task_id_lock = threading.Lock()
@@ -1108,7 +1106,7 @@ cdef store_task_errors(


 cdef class StreamingGeneratorExecutionContext:
-    """The context to execute streaming generator function.
+    """The context to run a streaming generator function.

     Make sure you always call `initialize` API before
     accessing any fields.
@@ -2550,26 +2548,25 @@ cdef void delete_spilled_objects_handler(
             job_id=None)


-cdef void cancel_async_task(
-        const CTaskID &c_task_id,
-        const CRayFunction &ray_function,
-        const c_string c_name_of_concurrency_group_to_execute) nogil:
+cdef c_bool cancel_async_actor_task(const CTaskID &c_task_id) nogil:
+    """Attempt to cancel a task running in this asyncio actor.
+
+    Returns True if the task was currently running and was cancelled, else False.
+
+    Note that the underlying asyncio task may not actually have been cancelled: it
+    could already have completed or else might not gracefully handle cancellation.
+    The return value only indicates that the task was found and cancelled.
+    """
     with gil:
-        function_descriptor = CFunctionDescriptorToPython(
-            ray_function.GetFunctionDescriptor())
-        name_of_concurrency_group_to_execute = \
-            c_name_of_concurrency_group_to_execute.decode("ascii")
         task_id = TaskID(c_task_id.Binary())
-
         worker = ray._private.worker.global_worker
-        eventloop, _ = worker.core_worker.get_event_loop(
-            function_descriptor, name_of_concurrency_group_to_execute)
-        future = worker.core_worker.get_queued_future(task_id)
-        if future is not None:
-            future.cancel()
-        # else, the task is already finished. If the task
-        # wasn't finished (task is queued on a client or server side),
-        # this method shouldn't have been called.
+        fut = worker.core_worker.get_future_for_running_task(task_id)
+        if fut is None:
+            # Either the task hasn't started executing yet or already finished.
+            return False
+
+        fut.cancel()
+        return True


 cdef void unhandled_exception_handler(const CRayObject& error) nogil:
@@ -2970,7 +2967,7 @@ cdef class CoreWorker:
         options.restore_spilled_objects = restore_spilled_objects_handler
         options.delete_spilled_objects = delete_spilled_objects_handler
         options.unhandled_exception_handler = unhandled_exception_handler
-        options.cancel_async_task = cancel_async_task
+        options.cancel_async_actor_task = cancel_async_actor_task
         options.get_lang_stack = get_py_stack
         options.is_local_mode = local_mode
         options.kill_main = kill_main_task
@@ -4461,15 +4458,15 @@ cdef class CoreWorker:
             for fd in function_descriptors:
                 self.fd_to_cgname_dict[fd] = cg_name

-    def get_event_loop_executor(self) -> ThreadPoolExecutor:
+    def get_event_loop_executor(self) -> concurrent.futures.ThreadPoolExecutor:
         if self.event_loop_executor is None:
             # NOTE: We're deliberately allocating thread-pool executor with
             #       a single thread, provided that many of its use-cases are
             #       not thread-safe yet (for ex, reporting streaming generator output)
-            self.event_loop_executor = ThreadPoolExecutor(max_workers=1)
+            self.event_loop_executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
         return self.event_loop_executor

-    def reset_event_loop_executor(self, executor: ThreadPoolExecutor):
+    def reset_event_loop_executor(self, executor: concurrent.futures.ThreadPoolExecutor):
         self.event_loop_executor = executor

     def get_event_loop(self, function_descriptor, specified_cgname):
@@ -4622,8 +4619,11 @@ cdef class CoreWorker:
         return ActorID(CCoreWorkerProcess.GetCoreWorker().GetWorkerContext()
                        .GetRootDetachedActorID().Binary())

-    def get_queued_future(self, task_id: Optional[TaskID]) -> ConcurrentFuture:
-        """Get a asyncio.Future that's queued in the event loop."""
+    def get_future_for_running_task(self, task_id: Optional[TaskID]) -> Optional[concurrent.futures.Future]:
+        """Get the future corresponding to a running task (or None).
+
+        The underyling asyncio task might be queued, running, or completed.
+        """
         with self._task_id_to_future_lock:
             return self._task_id_to_future.get(task_id)

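For background on the future-cancellation semantics referenced in the new `cancel_async_actor_task` docstring above, here is a standalone sketch (not Ray code) of how cancelling the `concurrent.futures.Future` returned by `asyncio.run_coroutine_threadsafe` propagates a cancellation request to a coroutine running on a background event loop. The assumption that the futures stored in `_task_id_to_future` are of this kind is an inference made for illustration, not something stated on this page.

```python
import asyncio
import concurrent.futures
import threading
import time

# Standalone illustration: a background event loop stands in for an async actor's
# event loop.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()


async def long_running_task():
    await asyncio.sleep(30)


# run_coroutine_threadsafe returns a concurrent.futures.Future chained to the
# scheduled asyncio task.
fut = asyncio.run_coroutine_threadsafe(long_running_task(), loop)
time.sleep(0.1)  # give the coroutine a chance to start

# Cancelling the concurrent future requests cancellation of the asyncio task. As the
# new docstring notes, this is only a request: a coroutine that swallows
# CancelledError could still run to completion.
fut.cancel()

try:
    fut.result(timeout=5)
except concurrent.futures.CancelledError:
    print("task was cancelled before completing")

loop.call_soon_threadsafe(loop.stop)
```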
python/ray/includes/libcoreworker.pxd

+1 -5
@@ -420,11 +420,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
             const c_string&,
             const c_vector[c_string]&) nogil) run_on_util_worker_handler
         (void(const CRayObject&) nogil) unhandled_exception_handler
-        (void(
-            const CTaskID &c_task_id,
-            const CRayFunction &ray_function,
-            const c_string c_name_of_concurrency_group_to_execute
-        ) nogil) cancel_async_task
+        (c_bool(const CTaskID &c_task_id) nogil) cancel_async_actor_task
         (void(c_string *stack_out) nogil) get_lang_stack
         c_bool is_local_mode
         int num_workers

python/ray/tests/test_actor_cancel.py

+35 -92
@@ -2,6 +2,7 @@
 import os
 import sys
 import time
+import concurrent.futures
 from collections import defaultdict

 import pytest
@@ -134,70 +135,6 @@ def f():
         ray.get(ref_dep_not_resolved)


-@pytest.mark.skip(
-    reason=("The guarantee in this case is too weak now. Need more work.")
-)
-def test_in_flight_queued_requests_canceled(shutdown_only, monkeypatch):
-    """
-    When there are large input size in-flight actor tasks
-    tasks are queued inside a RPC layer (core_worker_client.h)
-    In this case, we don't cancel a request from a client side
-    but wait until it is sent to the server side and cancel it.
-    See SendRequests() inside core_worker_client.h
-    """
-    # Currently the max bytes is
-    # const int64_t kMaxBytesInFlight = 16 * 1024 * 1024.
-    # See core_worker_client.h.
-    input_arg = b"1" * 15 * 1024  # 15KB.
-    # Tasks are queued when there are more than 1024 tasks.
-    sig = SignalActor.remote()
-
-    @ray.remote
-    class Actor:
-        def __init__(self, signal_actor):
-            self.signal_actor = signal_actor
-
-        def f(self, input_arg):
-            ray.get(self.signal_actor.wait.remote())
-            return True
-
-    a = Actor.remote(sig)
-    refs = [a.f.remote(input_arg) for _ in range(5000)]
-
-    # Wait until the first task runs.
-    wait_for_condition(
-        lambda: len(list_tasks(filters=[("STATE", "=", "RUNNING")])) == 1
-    )
-
-    # Cancel all tasks.
-    for ref in refs:
-        ray.cancel(ref)
-
-    # The first ref is in progress, so we pop it out
-    first_ref = refs.pop(0)
-    ray.get(sig.send.remote())
-
-    # Make sure all tasks that are queued (including queued
-    # due to in-flight bytes) are canceled.
-    canceled = 0
-    for ref in refs:
-        try:
-            ray.get(ref)
-        except TaskCancelledError:
-            canceled += 1
-
-    # Verify at least half of tasks are canceled.
-    # Currently, the guarantee is weak because we cannot
-    # detect queued tasks due to inflight bytes limit.
-    # TODO(sang): Move the in flight bytes logic into
-    # actor submission queue instead of doing it inside
-    # core worker client.
-    assert canceled > 2500
-
-    # first ref shouldn't have been canceled.
-    assert ray.get(first_ref)
-
-
 def test_async_actor_server_side_cancel(shutdown_only):
     """
     Test Cancelation when a task is queued on a server side.
@@ -324,34 +261,6 @@ def f(refs):
     ray.get(sleep_ref)


-@pytest.mark.skip(reason=("Currently not passing. There's one edge case to fix."))
-def test_cancel_stress(shutdown_only):
-    ray.init()
-
-    @ray.remote
-    class Actor:
-        async def sleep(self):
-            await asyncio.sleep(1000)
-
-    actors = [Actor.remote() for _ in range(30)]
-
-    refs = []
-    for _ in range(20):
-        for actor in actors:
-            for i in range(100):
-                ref = actor.sleep.remote()
-                refs.append(ref)
-                if i % 2 == 0:
-                    ray.cancel(ref)
-
-    for ref in refs:
-        ray.cancel(ref)
-
-    for ref in refs:
-        with pytest.raises((ray.exceptions.TaskCancelledError, TaskCancelledError)):
-            ray.get(ref)
-
-
 def test_cancel_recursive_tree(shutdown_only):
     """Verify recursive cancel works for tree-nested tasks.

@@ -529,6 +438,40 @@ def get_child_ref(self):
         ray.get(ref)


+def test_concurrent_submission_and_cancellation(shutdown_only):
+    """Test submitting and then cancelling many tasks concurrently.
+
+    This is a regression test for race conditions such as:
+    https://github.com/ray-project/ray/issues/52628.
+    """
+    NUM_TASKS = 2500
+
+    @ray.remote(num_cpus=0)
+    class Worker:
+        async def sleep(self, i: int):
+            # NOTE: all tasks should be cancelled, so this won't actually sleep for the
+            # full duration if the test is passing.
+            await asyncio.sleep(30)
+
+    worker = Worker.remote()
+
+    # Submit many tasks in parallel to cause queueing on the caller and receiver.
+    with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_TASKS) as executor:
+        futures = [executor.submit(worker.sleep.remote, i) for i in range(NUM_TASKS)]
+        refs = [f.result() for f in concurrent.futures.as_completed(futures)]
+
+    # Cancel the tasks in reverse order of submission.
+    for ref in reversed(refs):
+        ray.cancel(ref)
+
+    # Check that all tasks were successfully cancelled (none ran to completion).
+    for ref in refs:
+        with pytest.raises(ray.exceptions.TaskCancelledError):
+            ray.get(ref)
+
+    print(f"All {NUM_TASKS} tasks were cancelled successfully.")
+
+
 if __name__ == "__main__":
     if os.environ.get("PARALLEL_CI"):
         sys.exit(pytest.main(["-n", "auto", "--boxed", "-vs", __file__]))
