Skip to content

Commit d74d17c

Browse files
authored
[core] Remove worker_context_ dependency from the task receiver (#52740)
This was only used for one very minor thing: setting the current actor ID. Better to lift that logic into `core_worker.cc` and avoid this dependency entirely. --------- Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
1 parent c6c338c commit d74d17c

File tree

4 files changed

+24
-37
lines changed

4 files changed

+24
-37
lines changed

src/ray/core_worker/core_worker.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,6 @@ CoreWorker::CoreWorker(CoreWorkerOptions options, const WorkerID &worker_id)
432432
std::placeholders::_7,
433433
std::placeholders::_8);
434434
task_receiver_ = std::make_unique<TaskReceiver>(
435-
worker_context_,
436435
task_execution_service_,
437436
*task_event_buffer_,
438437
execute_task,
@@ -3835,12 +3834,31 @@ void CoreWorker::HandlePushTask(rpc::PushTaskRequest request,
38353834
send_reply_callback)) {
38363835
return;
38373836
}
3837+
3838+
// Set actor info in the worker context.
3839+
if (request.task_spec().type() == TaskType::ACTOR_CREATION_TASK) {
3840+
auto actor_id =
3841+
ActorID::FromBinary(request.task_spec().actor_creation_task_spec().actor_id());
3842+
3843+
// Handle duplicate actor creation tasks that might be sent from the GCS on restart.
3844+
// Ignore the message and reply OK.
3845+
if (worker_context_.GetCurrentActorID() == actor_id) {
3846+
RAY_LOG(INFO) << "Ignoring duplicate actor creation task for actor " << actor_id
3847+
<< ". This is likely due to a GCS server restart.";
3848+
send_reply_callback(Status::OK(), nullptr, nullptr);
3849+
return;
3850+
}
3851+
worker_context_.SetCurrentActorId(actor_id);
3852+
}
3853+
3854+
// Set job info in the worker context.
38383855
if (request.task_spec().type() == TaskType::ACTOR_CREATION_TASK ||
38393856
request.task_spec().type() == TaskType::NORMAL_TASK) {
38403857
auto job_id = JobID::FromBinary(request.task_spec().job_id());
38413858
worker_context_.MaybeInitializeJobInfo(job_id, request.task_spec().job_config());
38423859
task_counter_.SetJobId(job_id);
38433860
}
3861+
38443862
// Increment the task_queue_length and per function counter.
38453863
task_queue_length_ += 1;
38463864
std::string func_name =

src/ray/core_worker/test/direct_actor_transport_test.cc

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -772,14 +772,6 @@ class MockDependencyWaiter : public DependencyWaiter {
772772
virtual ~MockDependencyWaiter() {}
773773
};
774774

775-
class MockWorkerContext : public WorkerContext {
776-
public:
777-
MockWorkerContext(WorkerType worker_type, const JobID &job_id)
778-
: WorkerContext(worker_type, WorkerID::FromRandom(), job_id) {
779-
current_actor_is_direct_call_ = true;
780-
}
781-
};
782-
783775
class MockTaskEventBuffer : public worker::TaskEventBuffer {
784776
public:
785777
void AddTaskEvent(std::unique_ptr<worker::TaskEvent> task_event) override {}
@@ -797,14 +789,12 @@ class MockTaskEventBuffer : public worker::TaskEventBuffer {
797789

798790
class MockTaskReceiver : public TaskReceiver {
799791
public:
800-
MockTaskReceiver(WorkerContext &worker_context,
801-
instrumented_io_context &task_execution_service,
792+
MockTaskReceiver(instrumented_io_context &task_execution_service,
802793
worker::TaskEventBuffer &task_event_buffer,
803794
const TaskHandler &task_handler,
804795
std::function<std::function<void()>()> initialize_thread_callback,
805796
const OnActorCreationTaskDone &actor_creation_task_done_)
806-
: TaskReceiver(worker_context,
807-
task_execution_service,
797+
: TaskReceiver(task_execution_service,
808798
task_event_buffer,
809799
task_handler,
810800
initialize_thread_callback,
@@ -819,8 +809,7 @@ class MockTaskReceiver : public TaskReceiver {
819809
class TaskReceiverTest : public ::testing::Test {
820810
public:
821811
TaskReceiverTest()
822-
: worker_context_(WorkerType::WORKER, JobID::FromInt(0)),
823-
worker_client_(std::make_shared<MockWorkerClient>()),
812+
: worker_client_(std::make_shared<MockWorkerClient>()),
824813
dependency_waiter_(std::make_unique<MockDependencyWaiter>()) {
825814
auto execute_task = std::bind(&TaskReceiverTest::MockExecuteTask,
826815
this,
@@ -831,7 +820,6 @@ class TaskReceiverTest : public ::testing::Test {
831820
std::placeholders::_5,
832821
std::placeholders::_6);
833822
receiver_ = std::make_unique<MockTaskReceiver>(
834-
worker_context_,
835823
task_execution_service_,
836824
task_event_buffer_,
837825
execute_task,
@@ -867,7 +855,6 @@ class TaskReceiverTest : public ::testing::Test {
867855

868856
private:
869857
rpc::Address rpc_address_;
870-
MockWorkerContext worker_context_;
871858
instrumented_io_context task_execution_service_;
872859
MockTaskEventBuffer task_event_buffer_;
873860
std::shared_ptr<MockWorkerClient> worker_client_;

src/ray/core_worker/transport/task_receiver.cc

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,7 @@ void TaskReceiver::HandleTask(rpc::PushTaskRequest request,
3838
RAY_CHECK(waiter_ != nullptr) << "Must call init() prior to use";
3939
TaskSpecification task_spec(std::move(*request.mutable_task_spec()));
4040

41-
// If GCS server is restarted after sending an actor creation task to this core worker,
42-
// the restarted GCS server will send the same actor creation task to the core worker
43-
// again. We just need to ignore it and reply ok.
44-
if (task_spec.IsActorCreationTask() &&
45-
worker_context_.GetCurrentActorID() == task_spec.ActorCreationId()) {
46-
send_reply_callback(Status::OK(), nullptr, nullptr);
47-
RAY_LOG(INFO) << "Ignoring duplicate actor creation task for actor "
48-
<< task_spec.ActorCreationId()
49-
<< ". This is likely due to a GCS server restart.";
50-
return;
51-
}
52-
5341
if (task_spec.IsActorCreationTask()) {
54-
worker_context_.SetCurrentActorId(task_spec.ActorCreationId());
5542
SetupActor(task_spec.IsAsyncioActor(),
5643
task_spec.MaxActorConcurrency(),
5744
task_spec.ExecuteOutOfOrder());

src/ray/core_worker/transport/task_receiver.h

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
#include "ray/core_worker/actor_creator.h"
3333
#include "ray/core_worker/actor_handle.h"
3434
#include "ray/core_worker/common.h"
35-
#include "ray/core_worker/context.h"
3635
#include "ray/core_worker/fiber.h"
3736
#include "ray/core_worker/store_provider/memory_store/memory_store.h"
3837
#include "ray/core_worker/transport/actor_scheduling_queue.h"
@@ -63,14 +62,12 @@ class TaskReceiver {
6362

6463
using OnActorCreationTaskDone = std::function<Status()>;
6564

66-
TaskReceiver(WorkerContext &worker_context,
67-
instrumented_io_context &task_execution_service,
65+
TaskReceiver(instrumented_io_context &task_execution_service,
6866
worker::TaskEventBuffer &task_event_buffer,
6967
TaskHandler task_handler,
7068
std::function<std::function<void()>()> initialize_thread_callback,
7169
const OnActorCreationTaskDone &actor_creation_task_done)
72-
: worker_context_(worker_context),
73-
task_handler_(std::move(task_handler)),
70+
: task_handler_(std::move(task_handler)),
7471
task_execution_service_(task_execution_service),
7572
task_event_buffer_(task_event_buffer),
7673
initialize_thread_callback_(std::move(initialize_thread_callback)),
@@ -124,8 +121,6 @@ class TaskReceiver {
124121
absl::flat_hash_map<ActorID, std::vector<ConcurrencyGroup>> concurrency_groups_cache_;
125122

126123
private:
127-
// Worker context.
128-
WorkerContext &worker_context_;
129124
/// The callback function to process a task.
130125
TaskHandler task_handler_;
131126
/// The event loop for running tasks on.

0 commit comments

Comments
 (0)