Merged
128 changes: 96 additions & 32 deletions python/kubeflow/trainer/api/trainer_client.py
@@ -17,6 +17,7 @@
import queue
import random
import string
import time
import uuid
from typing import Dict, List, Optional, Union, Set

@@ -159,6 +160,71 @@ def get_runtime(self, name: str) -> types.Runtime:

return self.__get_runtime_from_crd(runtime) # type: ignore

def get_runtime_packages(self, runtime: types.Runtime):
"""
Print the installed Python packages for the given Runtime. If the Runtime has GPUs, it also
prints the GPUs available on a single training node.
Args:
runtime: Reference to one of existing Runtimes.
Raises:
ValueError: Input arguments are invalid.
RuntimeError: Failed to get Runtime.
"""

if runtime.trainer.trainer_type == types.TrainerType.BUILTIN_TRAINER:
raise ValueError("Cannot get Runtime packages for BuiltinTrainer")

# Run mpirun with only a single process.
if runtime.trainer.command[0] == "mpirun":
mpi_command = list(constants.MPI_COMMAND)
mpi_command[1:3] = ["-np", "1"]
runtime.trainer.set_command(tuple(mpi_command))

def print_packages():
import subprocess
import shutil
import sys

# Print Python version.
print(f"Python: {sys.version}")

# Print Python packages.
if shutil.which("pip"):
pip_list = subprocess.run(
["pip", "list"], capture_output=True, text=True
)
print(pip_list.stdout)
else:
print("Unable to get installed packages: pip command not found")

# Print nvidia-smi if GPUs are available.
if shutil.which("nvidia-smi"):
print("Available GPUs on the single training node")
nvidia_smi = subprocess.run(
["nvidia-smi"], capture_output=True, text=True
)
print(nvidia_smi.stdout)

# Create the TrainJob and wait until it completes.
# If the Runtime trainer has GPU resources, use them; otherwise run the TrainJob with 1 CPU.
job_name = self.train(
runtime=runtime,
trainer=types.CustomTrainer(
func=print_packages,
num_nodes=1,
resources_per_node=(
{"cpu": 1} if runtime.trainer.device != "gpu" else None
),
),
)

self.wait_for_job_status(job_name)
print(self.get_job_logs(job_name)["node-0"])
self.delete_job(job_name)
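
Note: a minimal usage sketch of the new get_runtime_packages() API. It assumes TrainerClient is importable from kubeflow.trainer and that a Runtime named "torch-distributed" exists in the cluster (the runtime name is illustrative):

from kubeflow.trainer import TrainerClient

client = TrainerClient()

# Pick one of the installed Runtimes (the name is illustrative).
runtime = client.get_runtime("torch-distributed")

# Creates a short-lived TrainJob that prints the Python version, the output of
# `pip list`, and `nvidia-smi` (when GPUs are present), then deletes the job.
client.get_runtime_packages(runtime=runtime)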

def train(
self,
runtime: Optional[types.Runtime] = None,
@@ -174,11 +240,11 @@ def train(
the post-training logic, requiring only parameter adjustments, e.g. `BuiltinTrainer`.
Args:
runtime (`types.Runtime`): Reference to one of existing Runtimes. By default the
runtime: Reference to one of existing Runtimes. By default the
torch-distributed Runtime is used.
initializer (`Optional[types.Initializer]`):
initializer:
Configuration for the dataset and model initializers.
trainer (`Optional[types.CustomTrainer, types.BuiltinTrainer]`):
trainer:
Configuration for Custom Training Task or Config-driven Task with Builtin Trainer.
Returns:
Expand Down Expand Up @@ -460,6 +526,7 @@ def wait_for_job_status(
name: str,
status: Set[str] = {constants.TRAINJOB_COMPLETE},
timeout: int = 600,
polling_interval: int = 2,
) -> types.TrainJob:
"""Wait for TrainJob to reach the desired status
@@ -468,6 +535,7 @@ def wait_for_job_status(
status: Set of expected statuses. It must be a subset of the Created, Running, Complete,
and Failed statuses.
timeout: How many seconds to wait until TrainJob reaches one of the expected conditions.
polling_interval: The polling interval in seconds to check TrainJob status.
Returns:
TrainJob: The training job that reaches the desired status.
@@ -489,36 +557,28 @@ def wait_for_job_status(
f"Expected status {status} must be a subset of {job_statuses}"
)

# Use Kubernetes watch API to monitor the TrainJob's Pods.
w = watch.Watch()
try:
for event in w.stream(
self.core_api.list_namespaced_pod,
self.namespace,
label_selector=constants.POD_LABEL_SELECTOR.format(trainjob_name=name),
timeout_seconds=timeout,
):
# Check the status after event is generated for the TrainJob's Pods.
trainjob = self.get_job(name)
logger.debug(f"TrainJob {name}, status {trainjob.status}")
if polling_interval > timeout:
@andreyvelich (Member Author), Aug 5, 2025:

@astefanutti I have to refactor the wait_for_job_status() API to perform polling as before.
The problem I saw is that when Pods succeed too quickly, the TrainJob controller doesn't add the Complete condition to .status.conditions.

Since we only watch for Pod events, we can't catch this, and the TrainJob gets stuck in the Running condition.

Alternatively, we could watch both the TrainJob and the Pods with two Python threads, but I am not sure it is worth it.

What do you think, @astefanutti?

Contributor:

@andreyvelich I agree with you. This problem should be addressed once we have comprehensive TrainJob conditions. In the interim, it's better to keep things simple in the SDK and refactor once the new TrainJob conditions are available.

raise ValueError(
f"Polling interval {polling_interval} must be less than timeout: {timeout}"
)

# Raise an error if TrainJob is Failed and it is not the expected status.
if (
constants.TRAINJOB_FAILED not in status
and trainjob.status == constants.TRAINJOB_FAILED
):
raise RuntimeError(f"TrainJob {name} is Failed")
for _ in range(round(timeout / polling_interval)):
# Check the TrainJob status on each polling iteration.
trainjob = self.get_job(name)
logger.debug(f"TrainJob {name}, status {trainjob.status}")

# Return the TrainJob if it reaches the expected status.
if trainjob.status in status:
return trainjob
# Raise an error if TrainJob is Failed and it is not the expected status.
if (
constants.TRAINJOB_FAILED not in status
and trainjob.status == constants.TRAINJOB_FAILED
):
raise RuntimeError(f"TrainJob {name} is Failed")

except TimeoutError:
raise TimeoutError(f"Timeout to get the TrainJob {name}")
except Exception:
raise RuntimeError(f"Failed to watch Pods for TrainJob {name}")
finally:
w.stop()
# Return the TrainJob if it reaches the expected status.
if trainjob.status in status:
return trainjob

time.sleep(polling_interval)

raise TimeoutError(
f"Timeout waiting for TrainJob {name} to reach status: {status} status"
@@ -691,12 +751,16 @@ def __get_trainjob_from_crd(
elif c.type == constants.TRAINJOB_FAILED and c.status == "True":
trainjob.status = c.type
else:
# The TrainJob running status is defined when all training node (e.g. Pods) are running.
# The TrainJob running status is defined when all training nodes (e.g. Pods) are
# running or succeeded.
num_running_nodes = sum(
1
for step in trainjob.steps
if step.name.startswith(constants.NODE)
and step.status == constants.TRAINJOB_RUNNING
and (
step.status == constants.TRAINJOB_RUNNING
or step.status == constants.POD_SUCCEEDED
)
)

if trainjob.num_nodes == num_running_nodes:
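
Note: a minimal sketch of the node-counting rule this hunk implements, using plain tuples instead of the SDK's Step objects; a TrainJob is reported Running once every node step is either Running or has already Succeeded:

# Illustrative step states for a 2-node TrainJob.
steps = [("node-0", "Running"), ("node-1", "Succeeded")]

num_running_nodes = sum(
    1
    for step_name, step_status in steps
    if step_name.startswith("node") and step_status in ("Running", "Succeeded")
)

# Both nodes count toward the total, so the TrainJob status is reported as Running.
assert num_running_nodes == 2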
67 changes: 59 additions & 8 deletions python/kubeflow/trainer/api/trainer_client_test.py
@@ -56,6 +56,10 @@ class TestCase:
# In all tests runtime name is equal to the framework name.
TORCH_RUNTIME = "torch"
TORCH_TUNE_RUNTIME = "torchtune"

# 2 nodes * 2 nproc
RUNTIME_DEVICES = "4"

FAIL_LOGS = "fail_logs"
LIST_RUNTIMES = "list_runtimes"
BASIC_TRAIN_JOB_NAME = "basic-job"
@@ -95,11 +99,6 @@ def trainer_client(request):
list_namespaced_pod=Mock(side_effect=list_namespaced_pod_response),
read_namespaced_pod_log=Mock(side_effect=mock_read_namespaced_pod_log),
),
), patch(
"kubernetes.watch.Watch",
return_value=Mock(
stream=Mock(side_effect=mock_watch),
),
):
yield TrainerClient()

@@ -509,7 +508,8 @@ def create_runtime_type(
trainer_type=types.TrainerType.CUSTOM_TRAINER,
framework=name,
num_nodes=2,
accelerator_count=4,
device="gpu",
device_count=RUNTIME_DEVICES,
)
trainer.set_command(constants.TORCH_COMMAND)
return types.Runtime(
@@ -528,7 +528,8 @@ def get_train_job_data_type(
trainer = types.RuntimeTrainer(
trainer_type=types.TrainerType.CUSTOM_TRAINER,
framework=runtime_name,
accelerator_count=4,
device="gpu",
device_count=RUNTIME_DEVICES,
num_nodes=2,
)
trainer.set_command(constants.TORCH_COMMAND)
@@ -644,6 +645,45 @@ def test_list_runtimes(trainer_client, test_case):
print("test execution complete")


@pytest.mark.parametrize(
"test_case",
[
TestCase(
name="valid flow with custom trainer runtime",
expected_status=SUCCESS,
config={"runtime": create_runtime_type(name=TORCH_RUNTIME)},
),
TestCase(
name="value error with builtin trainer runtime",
expected_status=FAILED,
config={
"runtime": types.Runtime(
name="torchtune-runtime",
trainer=types.RuntimeTrainer(
trainer_type=types.TrainerType.BUILTIN_TRAINER,
framework="torchtune",
num_nodes=1,
device="cpu",
device_count="1",
),
)
},
expected_error=ValueError,
),
],
)
def test_get_runtime_packages(trainer_client, test_case):
"""Test TrainerClient.get_runtime_packages with basic success path."""
print("Executing test:", test_case.name)

try:
trainer_client.get_runtime_packages(**test_case.config)
except Exception as e:
assert type(e) is test_case.expected_error

print("test execution complete")


@pytest.mark.parametrize(
"test_case",
[
@@ -944,6 +984,16 @@ def test_get_job_logs(trainer_client, test_case):
},
expected_error=ValueError,
),
TestCase(
name="polling interval is more than timeout error",
expected_status=FAILED,
config={
"name": BASIC_TRAIN_JOB_NAME,
"timeout": 1,
"polling_interval": 2,
},
expected_error=ValueError,
),
TestCase(
name="job failed when not expected",
expected_status=FAILED,
@@ -959,7 +1009,8 @@
config={
"name": BASIC_TRAIN_JOB_NAME,
"status": {constants.TRAINJOB_FAILED},
"timeout": 1,
"polling_interval": 1,
"timeout": 2,
},
expected_error=TimeoutError,
),
6 changes: 5 additions & 1 deletion python/kubeflow/trainer/constants/constants.py
@@ -41,7 +41,8 @@
# The default status for the TrainJob once users create it.
TRAINJOB_CREATED = "Created"

# The running status of the TrainJob, defined when all training node (e.g. Pods) are running.
# The running status of the TrainJob, defined when all training nodes (e.g. Pods) are
# running or succeeded.
TRAINJOB_RUNNING = "Running"

# The complete status of the TrainJob, defined when TrainJob CR has complete condition.
@@ -50,6 +51,9 @@
# The failed status of the TrainJob, defined when TrainJob CR has failed condition.
TRAINJOB_FAILED = "Failed"

# The succeeded phase of the Pod.
POD_SUCCEEDED = "Succeeded"

# The label key to identify the relationship between TrainJob and Pod template in the runtime.
# For example, what PodTemplate must be overridden by TrainJob's .spec.trainer APIs.
TRAINJOB_ANCESTOR_LABEL = "trainer.kubeflow.org/trainjob-ancestor-step"
7 changes: 4 additions & 3 deletions python/kubeflow/trainer/types/types.py
@@ -16,7 +16,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Callable, Dict, Optional, Union
from typing import Callable, Dict, Optional

from kubeflow.trainer.constants import constants

@@ -168,7 +168,8 @@ class RuntimeTrainer:
trainer_type: TrainerType
framework: str
num_nodes: int = 1 # The default value is set in the APIs.
accelerator_count: Union[str, float, int] = constants.UNKNOWN
device: str = constants.UNKNOWN
device_count: str = constants.UNKNOWN
Comment on lines +171 to +172
@andreyvelich (Member Author):
I made it consistent with Step's device and device_count. I think it looks better.

Contributor:

Sounds great. We should also update the notebooks in the trainer examples.

__command: tuple[str, ...] = field(init=False, repr=False)

@property
@@ -194,7 +195,7 @@ class Step:
status: Optional[str]
pod_name: str
device: str = constants.UNKNOWN
device_count: Union[str, int] = constants.UNKNOWN
device_count: str = constants.UNKNOWN


# Representation for the TrainJob.
10 changes: 5 additions & 5 deletions python/kubeflow/trainer/utils/utils.py
@@ -131,19 +131,19 @@ def get_runtime_trainer(

# Get the container devices.
if devices := get_container_devices(trainer_container.resources):
_, trainer.accelerator_count = devices
trainer.device, trainer.device_count = devices

# Torch and MPI plugins override the device count.
if ml_policy.torch and ml_policy.torch.num_proc_per_node:
num_proc = ml_policy.torch.num_proc_per_node.actual_instance
if isinstance(num_proc, int):
trainer.accelerator_count = num_proc
trainer.device_count = str(num_proc)
elif ml_policy.mpi and ml_policy.mpi.num_proc_per_node:
trainer.accelerator_count = ml_policy.mpi.num_proc_per_node
trainer.device_count = str(ml_policy.mpi.num_proc_per_node)

# Multiply device_count by the number of nodes.
if isinstance(trainer.accelerator_count, (int, float)) and ml_policy.num_nodes:
trainer.accelerator_count *= ml_policy.num_nodes
if trainer.device_count.isdigit() and ml_policy.num_nodes:
trainer.device_count = str(int(trainer.device_count) * ml_policy.num_nodes)

# Add number of training nodes.
if ml_policy.num_nodes:
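
Note: a small worked example of the device_count accounting above, mirroring the RUNTIME_DEVICES = "4" constant in the tests (2 nodes * 2 nproc); the values are illustrative and this only sketches the arithmetic, not the real get_runtime_trainer() API:

num_proc_per_node = 2                   # e.g. from ml_policy.torch.num_proc_per_node
num_nodes = 2                           # e.g. from ml_policy.num_nodes

device_count = str(num_proc_per_node)   # per-node device count as a string: "2"
if device_count.isdigit() and num_nodes:
    device_count = str(int(device_count) * num_nodes)

assert device_count == "4"              # matches RUNTIME_DEVICES in trainer_client_test.py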