|
| 1 | +from typing import Optional |
| 2 | +from ray.core.generated.export_train_state_pb2 import ( |
| 3 | + ExportTrainRunEventData as ProtoTrainRun, |
| 4 | + ExportTrainRunAttemptEventData as ProtoTrainRunAttempt, |
| 5 | +) |
| 6 | +from ray.train._internal.state.schema import ( |
| 7 | + TrainRunInfo, |
| 8 | + TrainWorkerInfo, |
| 9 | + RunStatusEnum, |
| 10 | + ActorStatusEnum, |
| 11 | +) |
| 12 | + |
| 13 | + |
| 14 | +TRAIN_SCHEMA_VERSION = 1 |
| 15 | + |
| 16 | +# Status mapping dictionaries |
| 17 | +_ACTOR_STATUS_MAP = { |
| 18 | + ActorStatusEnum.ALIVE: ProtoTrainRunAttempt.ActorStatus.ALIVE, |
| 19 | + ActorStatusEnum.DEAD: ProtoTrainRunAttempt.ActorStatus.DEAD, |
| 20 | +} |
| 21 | + |
| 22 | +_RUN_ATTEMPT_STATUS_MAP = { |
| 23 | + RunStatusEnum.STARTED: ProtoTrainRunAttempt.RunAttemptStatus.PENDING, |
| 24 | + RunStatusEnum.RUNNING: ProtoTrainRunAttempt.RunAttemptStatus.RUNNING, |
| 25 | + RunStatusEnum.FINISHED: ProtoTrainRunAttempt.RunAttemptStatus.FINISHED, |
| 26 | + RunStatusEnum.ERRORED: ProtoTrainRunAttempt.RunAttemptStatus.ERRORED, |
| 27 | + RunStatusEnum.ABORTED: ProtoTrainRunAttempt.RunAttemptStatus.ABORTED, |
| 28 | +} |
| 29 | + |
| 30 | +_RUN_STATUS_MAP = { |
| 31 | + RunStatusEnum.STARTED: ProtoTrainRun.RunStatus.INITIALIZING, |
| 32 | + RunStatusEnum.RUNNING: ProtoTrainRun.RunStatus.RUNNING, |
| 33 | + RunStatusEnum.FINISHED: ProtoTrainRun.RunStatus.FINISHED, |
| 34 | + RunStatusEnum.ERRORED: ProtoTrainRun.RunStatus.ERRORED, |
| 35 | + RunStatusEnum.ABORTED: ProtoTrainRun.RunStatus.ABORTED, |
| 36 | +} |
| 37 | + |
| 38 | + |
| 39 | +def _ms_to_ns(ms: Optional[int]) -> Optional[int]: |
| 40 | + if ms is None: |
| 41 | + return None |
| 42 | + return ms * 1000000 |
| 43 | + |
| 44 | + |
| 45 | +# Helper conversion functions |
| 46 | +def _to_proto_resources(resources: dict) -> ProtoTrainRunAttempt.TrainResources: |
| 47 | + """Convert resources dictionary to protobuf TrainResources.""" |
| 48 | + return ProtoTrainRunAttempt.TrainResources(resources=resources) |
| 49 | + |
| 50 | + |
| 51 | +def _to_proto_worker(worker: TrainWorkerInfo) -> ProtoTrainRunAttempt.TrainWorker: |
| 52 | + """Convert TrainWorker to protobuf format.""" |
| 53 | + proto_worker = ProtoTrainRunAttempt.TrainWorker( |
| 54 | + world_rank=worker.world_rank, |
| 55 | + local_rank=worker.local_rank, |
| 56 | + node_rank=worker.node_rank, |
| 57 | + actor_id=bytes.fromhex(worker.actor_id), |
| 58 | + node_id=bytes.fromhex(worker.node_id), |
| 59 | + node_ip=worker.node_ip, |
| 60 | + pid=worker.pid, |
| 61 | + gpu_ids=worker.gpu_ids, |
| 62 | + status=_ACTOR_STATUS_MAP[worker.status], |
| 63 | + resources=_to_proto_resources(worker.resources), |
| 64 | + log_file_path=worker.worker_log_file_path, |
| 65 | + ) |
| 66 | + |
| 67 | + return proto_worker |
| 68 | + |
| 69 | + |
| 70 | +# Main conversion functions |
| 71 | +def train_run_info_to_proto_run(run_info: TrainRunInfo) -> ProtoTrainRun: |
| 72 | + """Convert TrainRunInfo to TrainRun protobuf format.""" |
| 73 | + proto_run = ProtoTrainRun( |
| 74 | + schema_version=TRAIN_SCHEMA_VERSION, |
| 75 | + id=run_info.id, |
| 76 | + name=run_info.name, |
| 77 | + job_id=bytes.fromhex(run_info.job_id), |
| 78 | + controller_actor_id=bytes.fromhex(run_info.controller_actor_id), |
| 79 | + status=_RUN_STATUS_MAP[run_info.run_status], |
| 80 | + status_detail=run_info.status_detail, |
| 81 | + start_time_ns=_ms_to_ns(run_info.start_time_ms), |
| 82 | + end_time_ns=_ms_to_ns(run_info.end_time_ms), |
| 83 | + controller_log_file_path=run_info.controller_log_file_path, |
| 84 | + ) |
| 85 | + |
| 86 | + return proto_run |
| 87 | + |
| 88 | + |
| 89 | +def train_run_info_to_proto_attempt(run_info: TrainRunInfo) -> ProtoTrainRunAttempt: |
| 90 | + """Convert TrainRunInfo to TrainRunAttempt protobuf format.""" |
| 91 | + |
| 92 | + proto_attempt = ProtoTrainRunAttempt( |
| 93 | + schema_version=TRAIN_SCHEMA_VERSION, |
| 94 | + run_id=run_info.id, |
| 95 | + attempt_id=run_info.id, # Same as run_id |
| 96 | + status=_RUN_ATTEMPT_STATUS_MAP[run_info.run_status], |
| 97 | + status_detail=run_info.status_detail, |
| 98 | + start_time_ns=_ms_to_ns(run_info.start_time_ms), |
| 99 | + end_time_ns=_ms_to_ns(run_info.end_time_ms), |
| 100 | + resources=[_to_proto_resources(r) for r in run_info.resources], |
| 101 | + workers=[_to_proto_worker(worker) for worker in run_info.workers], |
| 102 | + ) |
| 103 | + |
| 104 | + return proto_attempt |
0 commit comments