Skip to content

Commit 31878c9

Browse files
authored
[train] train v1 export api (#51177)
This PR implements the export API for Ray Train V1 state. This builds on top of #50622, which implements the export API for Ray Train V2. ## Key Changes - Added `export.py` with conversion functions between Train V1 state models and Train (V2) state export protobuf - Updated `TrainRunInfo` and `TrainWorkerInfo` schemas with additional fields for compatibility: - Log file paths for controller and workers - Note that these point to the Ray worker stderr logs, rather than specific train logs. - Resource allocation information - Made worker status a required field - Note that it is always set as ACTIVE for now. Signed-off-by: Matthew Deng <matt@anyscale.com>
1 parent 2e9e63b commit 31878c9

File tree

9 files changed

+378
-7
lines changed

9 files changed

+378
-7
lines changed

python/ray/train/BUILD

+8
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,14 @@ py_test(
745745
],
746746
)
747747

748+
py_test(
749+
name = "test_state_export",
750+
size = "small",
751+
srcs = ["tests/test_state_export.py"],
752+
tags = ["team:ml", "exclusive"],
753+
deps = [":train_lib", ":conftest"]
754+
)
755+
748756
py_test(
749757
name = "test_tensorflow_checkpoint",
750758
size = "small",

python/ray/train/_internal/backend_executor.py

+5
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,9 @@ def initialize_session(
565565
from ray.train._internal.state.schema import RunStatusEnum
566566

567567
core_context = ray.runtime_context.get_runtime_context()
568+
controller_log_file_path = (
569+
ray._private.worker.global_worker.get_err_file_path()
570+
)
568571

569572
self.state_manager.register_train_run(
570573
run_id=self._trial_info.run_id,
@@ -575,6 +578,8 @@ def initialize_session(
575578
worker_group=self.worker_group,
576579
start_time_ms=self._start_time_ms,
577580
run_status=RunStatusEnum.RUNNING,
581+
controller_log_file_path=controller_log_file_path,
582+
resources=[self._resources_per_worker] * self._num_workers,
578583
)
579584

580585
# Run the training function asynchronously in its own thread.
+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from typing import Optional
2+
from ray.core.generated.export_train_state_pb2 import (
3+
ExportTrainRunEventData as ProtoTrainRun,
4+
ExportTrainRunAttemptEventData as ProtoTrainRunAttempt,
5+
)
6+
from ray.train._internal.state.schema import (
7+
TrainRunInfo,
8+
TrainWorkerInfo,
9+
RunStatusEnum,
10+
ActorStatusEnum,
11+
)
12+
13+
14+
TRAIN_SCHEMA_VERSION = 1
15+
16+
# Status mapping dictionaries
17+
_ACTOR_STATUS_MAP = {
18+
ActorStatusEnum.ALIVE: ProtoTrainRunAttempt.ActorStatus.ALIVE,
19+
ActorStatusEnum.DEAD: ProtoTrainRunAttempt.ActorStatus.DEAD,
20+
}
21+
22+
_RUN_ATTEMPT_STATUS_MAP = {
23+
RunStatusEnum.STARTED: ProtoTrainRunAttempt.RunAttemptStatus.PENDING,
24+
RunStatusEnum.RUNNING: ProtoTrainRunAttempt.RunAttemptStatus.RUNNING,
25+
RunStatusEnum.FINISHED: ProtoTrainRunAttempt.RunAttemptStatus.FINISHED,
26+
RunStatusEnum.ERRORED: ProtoTrainRunAttempt.RunAttemptStatus.ERRORED,
27+
RunStatusEnum.ABORTED: ProtoTrainRunAttempt.RunAttemptStatus.ABORTED,
28+
}
29+
30+
_RUN_STATUS_MAP = {
31+
RunStatusEnum.STARTED: ProtoTrainRun.RunStatus.INITIALIZING,
32+
RunStatusEnum.RUNNING: ProtoTrainRun.RunStatus.RUNNING,
33+
RunStatusEnum.FINISHED: ProtoTrainRun.RunStatus.FINISHED,
34+
RunStatusEnum.ERRORED: ProtoTrainRun.RunStatus.ERRORED,
35+
RunStatusEnum.ABORTED: ProtoTrainRun.RunStatus.ABORTED,
36+
}
37+
38+
39+
def _ms_to_ns(ms: Optional[int]) -> Optional[int]:
40+
if ms is None:
41+
return None
42+
return ms * 1000000
43+
44+
45+
# Helper conversion functions
46+
def _to_proto_resources(resources: dict) -> ProtoTrainRunAttempt.TrainResources:
47+
"""Convert resources dictionary to protobuf TrainResources."""
48+
return ProtoTrainRunAttempt.TrainResources(resources=resources)
49+
50+
51+
def _to_proto_worker(worker: TrainWorkerInfo) -> ProtoTrainRunAttempt.TrainWorker:
52+
"""Convert TrainWorker to protobuf format."""
53+
proto_worker = ProtoTrainRunAttempt.TrainWorker(
54+
world_rank=worker.world_rank,
55+
local_rank=worker.local_rank,
56+
node_rank=worker.node_rank,
57+
actor_id=bytes.fromhex(worker.actor_id),
58+
node_id=bytes.fromhex(worker.node_id),
59+
node_ip=worker.node_ip,
60+
pid=worker.pid,
61+
gpu_ids=worker.gpu_ids,
62+
status=_ACTOR_STATUS_MAP[worker.status],
63+
resources=_to_proto_resources(worker.resources),
64+
log_file_path=worker.worker_log_file_path,
65+
)
66+
67+
return proto_worker
68+
69+
70+
# Main conversion functions
71+
def train_run_info_to_proto_run(run_info: TrainRunInfo) -> ProtoTrainRun:
72+
"""Convert TrainRunInfo to TrainRun protobuf format."""
73+
proto_run = ProtoTrainRun(
74+
schema_version=TRAIN_SCHEMA_VERSION,
75+
id=run_info.id,
76+
name=run_info.name,
77+
job_id=bytes.fromhex(run_info.job_id),
78+
controller_actor_id=bytes.fromhex(run_info.controller_actor_id),
79+
status=_RUN_STATUS_MAP[run_info.run_status],
80+
status_detail=run_info.status_detail,
81+
start_time_ns=_ms_to_ns(run_info.start_time_ms),
82+
end_time_ns=_ms_to_ns(run_info.end_time_ms),
83+
controller_log_file_path=run_info.controller_log_file_path,
84+
)
85+
86+
return proto_run
87+
88+
89+
def train_run_info_to_proto_attempt(run_info: TrainRunInfo) -> ProtoTrainRunAttempt:
90+
"""Convert TrainRunInfo to TrainRunAttempt protobuf format."""
91+
92+
proto_attempt = ProtoTrainRunAttempt(
93+
schema_version=TRAIN_SCHEMA_VERSION,
94+
run_id=run_info.id,
95+
attempt_id=run_info.id, # Same as run_id
96+
status=_RUN_ATTEMPT_STATUS_MAP[run_info.run_status],
97+
status_detail=run_info.status_detail,
98+
start_time_ns=_ms_to_ns(run_info.start_time_ms),
99+
end_time_ns=_ms_to_ns(run_info.end_time_ms),
100+
resources=[_to_proto_resources(r) for r in run_info.resources],
101+
workers=[_to_proto_worker(worker) for worker in run_info.workers],
102+
)
103+
104+
return proto_attempt

python/ray/train/_internal/state/schema.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from enum import Enum
2-
from typing import List, Optional
2+
from typing import Dict, List, Optional
33

44
from ray._private.pydantic_compat import BaseModel, Field
55
from ray.dashboard.modules.job.pydantic_models import JobDetails
@@ -47,9 +47,13 @@ class TrainWorkerInfo(BaseModel):
4747
gpu_ids: List[int] = Field(
4848
description="A list of GPU ids allocated to that worker."
4949
)
50-
status: Optional[ActorStatusEnum] = Field(
50+
status: ActorStatusEnum = Field(
5151
description="The status of the train worker actor. It can be ALIVE or DEAD."
5252
)
53+
resources: Dict[str, float] = Field(
54+
description="The resources allocated to the worker."
55+
)
56+
worker_log_file_path: str = Field(description="The path to the worker log file.")
5357

5458

5559
@DeveloperAPI
@@ -139,6 +143,12 @@ class TrainRunInfo(BaseModel):
139143
description="The UNIX timestamp of the end time of this Train run. "
140144
"If null, the Train run has not ended yet."
141145
)
146+
controller_log_file_path: str = Field(
147+
description="The path to the controller log file."
148+
)
149+
resources: List[Dict[str, float]] = Field(
150+
description="The resources allocated to the worker."
151+
)
142152

143153

144154
@DeveloperAPI

python/ray/train/_internal/state/state_actor.py

+90
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
import logging
2+
import os
23
import threading
34
from typing import Dict, Optional
45

56
import ray
7+
from ray._private.event.export_event_logger import (
8+
EventLogType,
9+
get_export_event_logger,
10+
check_export_api_enabled,
11+
)
612
from ray.actor import ActorHandle
713
from ray.train._internal.state.schema import TrainRunInfo
814

@@ -14,10 +20,19 @@ class TrainStateActor:
1420
def __init__(self):
1521
self._run_infos: Dict[str, TrainRunInfo] = {}
1622

23+
(
24+
self._export_logger,
25+
self._is_train_run_export_api_enabled,
26+
self._is_train_run_attempt_export_api_enabled,
27+
) = self._init_export_logger()
28+
1729
def register_train_run(self, run_info: TrainRunInfo) -> None:
1830
# Register a new train run.
1931
self._run_infos[run_info.id] = run_info
2032

33+
self._maybe_export_train_run(run_info)
34+
self._maybe_export_train_run_attempt(run_info)
35+
2136
def get_train_run(self, run_id: str) -> Optional[TrainRunInfo]:
2237
# Retrieve a registered run with its id
2338
return self._run_infos.get(run_id, None)
@@ -26,6 +41,81 @@ def get_all_train_runs(self) -> Dict[str, TrainRunInfo]:
2641
# Retrieve all registered train runs
2742
return self._run_infos
2843

44+
# ============================
45+
# Export API
46+
# ============================
47+
48+
def is_export_api_enabled(self) -> bool:
49+
return self._export_logger is not None
50+
51+
def _init_export_logger(self) -> tuple[Optional[logging.Logger], bool, bool]:
52+
"""Initialize the export logger and check if the export API is enabled.
53+
54+
Returns:
55+
A tuple containing:
56+
- The export logger (or None if export API is not enabled).
57+
- A boolean indicating if the export API is enabled for train runs.
58+
- A boolean indicating if the export API is enabled for train run attempts.
59+
"""
60+
# Proto schemas should be imported within the scope of TrainStateActor to
61+
# prevent serialization errors.
62+
from ray.core.generated.export_event_pb2 import ExportEvent
63+
64+
is_train_run_export_api_enabled = check_export_api_enabled(
65+
ExportEvent.SourceType.EXPORT_TRAIN_RUN
66+
)
67+
is_train_run_attempt_export_api_enabled = check_export_api_enabled(
68+
ExportEvent.SourceType.EXPORT_TRAIN_RUN_ATTEMPT
69+
)
70+
export_api_enabled = (
71+
is_train_run_export_api_enabled or is_train_run_attempt_export_api_enabled
72+
)
73+
74+
if not export_api_enabled:
75+
return None, False, False
76+
77+
log_directory = os.path.join(
78+
ray._private.worker._global_node.get_session_dir_path(), "logs"
79+
)
80+
logger = None
81+
try:
82+
logger = get_export_event_logger(
83+
EventLogType.TRAIN_STATE,
84+
log_directory,
85+
)
86+
except Exception:
87+
logger.exception(
88+
"Unable to initialize the export event logger, so no Train export "
89+
"events will be written."
90+
)
91+
92+
if logger is None:
93+
return None, False, False
94+
95+
return (
96+
logger,
97+
is_train_run_export_api_enabled,
98+
is_train_run_attempt_export_api_enabled,
99+
)
100+
101+
def _maybe_export_train_run(self, run_info: TrainRunInfo) -> None:
102+
if not self._is_train_run_export_api_enabled:
103+
return
104+
105+
from ray.train._internal.state.export import train_run_info_to_proto_run
106+
107+
run_proto = train_run_info_to_proto_run(run_info)
108+
self._export_logger.send_event(run_proto)
109+
110+
def _maybe_export_train_run_attempt(self, run_info: TrainRunInfo) -> None:
111+
if not self._is_train_run_attempt_export_api_enabled:
112+
return
113+
114+
from ray.train._internal.state.export import train_run_info_to_proto_attempt
115+
116+
run_attempt_proto = train_run_info_to_proto_attempt(run_info)
117+
self._export_logger.send_event(run_attempt_proto)
118+
29119

30120
TRAIN_STATE_ACTOR_NAME = "train_state_actor"
31121
TRAIN_STATE_ACTOR_NAMESPACE = "_train_state_actor"

python/ray/train/_internal/state/state_manager.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import logging
22
import os
33
from collections import defaultdict
4-
from typing import Any, Dict
4+
from typing import Any, Dict, List
55

66
import ray
77
from ray.data import Dataset
88
from ray.train._internal.state.schema import (
9+
ActorStatusEnum,
910
RunStatusEnum,
1011
TrainDatasetInfo,
1112
TrainRunInfo,
@@ -37,6 +38,8 @@ def register_train_run(
3738
datasets: Dict[str, Dataset],
3839
worker_group: WorkerGroup,
3940
start_time_ms: float,
41+
controller_log_file_path: str,
42+
resources: List[Dict[str, float]],
4043
status_detail: str = "",
4144
) -> None:
4245
"""Collect Train Run Info and report to StateActor."""
@@ -50,7 +53,7 @@ def register_train_run(
5053
def collect_train_worker_info():
5154
train_context = ray.train.get_context()
5255
core_context = ray.runtime_context.get_runtime_context()
53-
56+
worker_log_file_path = ray._private.worker.global_worker.get_err_file_path()
5457
return TrainWorkerInfo(
5558
world_rank=train_context.get_world_rank(),
5659
local_rank=train_context.get_local_rank(),
@@ -60,6 +63,9 @@ def collect_train_worker_info():
6063
node_ip=ray.util.get_node_ip_address(),
6164
gpu_ids=ray.get_gpu_ids(),
6265
pid=os.getpid(),
66+
resources=resources[0],
67+
worker_log_file_path=worker_log_file_path,
68+
status=ActorStatusEnum.ALIVE,
6369
)
6470

6571
futures = [
@@ -97,6 +103,8 @@ def collect_train_worker_info():
97103
start_time_ms=start_time_ms,
98104
run_status=run_status,
99105
status_detail=status_detail,
106+
controller_log_file_path=controller_log_file_path,
107+
resources=resources,
100108
)
101109

102110
# Clear the cached info to avoid registering the same run twice

0 commit comments

Comments
 (0)