Merged
Changes from 7 commits
25 changes: 20 additions & 5 deletions python/kubeflow/trainer/__init__.py
@@ -19,8 +19,10 @@

# Import the Kubeflow Trainer client.
from kubeflow.trainer.api.trainer_client import TrainerClient # noqa: F401

# Import the Kubeflow Trainer constants.
from kubeflow.trainer.constants.constants import DATASET_PATH, MODEL_PATH # noqa: F401

# Import the Kubeflow Trainer types.
from kubeflow.trainer.types.types import (
BuiltinTrainer,
@@ -35,13 +37,26 @@
Runtime,
TorchTuneConfig,
TorchTuneInstructDataset,
Trainer,
RuntimeTrainer,
TrainerType,
)

__all__ = [
"BuiltinTrainer", "CustomTrainer", "DataFormat", "DATASET_PATH", "DataType", "Framework",
"HuggingFaceDatasetInitializer", "HuggingFaceModelInitializer", "Initializer", "Loss",
"MODEL_PATH", "Runtime", "TorchTuneConfig", "TorchTuneInstructDataset", "Trainer",
"TrainerClient", "TrainerType"
"BuiltinTrainer",
"CustomTrainer",
"DataFormat",
"DATASET_PATH",
"DataType",
"Framework",
"HuggingFaceDatasetInitializer",
"HuggingFaceModelInitializer",
"Initializer",
"Loss",
"MODEL_PATH",
"Runtime",
"TorchTuneConfig",
"TorchTuneInstructDataset",
"RuntimeTrainer",
"TrainerClient",
"TrainerType",
]
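
Since the public surface is now spelled out one name per line, a quick sanity check is to compare `__all__` against the module attributes. A minimal sketch, purely illustrative and assuming the package is installed:

```python
# Illustrative check that every name in __all__ resolves on the package.
import kubeflow.trainer as trainer

missing = [name for name in trainer.__all__ if not hasattr(trainer, name)]
assert not missing, f"Names exported in __all__ but not importable: {missing}"
# Per this diff, Trainer was replaced by RuntimeTrainer in the public API.
assert "RuntimeTrainer" in trainer.__all__ and "Trainer" not in trainer.__all__
```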
129 changes: 101 additions & 28 deletions python/kubeflow/trainer/api/trainer_client.py
@@ -18,7 +18,8 @@
import random
import string
import uuid
from typing import Dict, List, Optional, Union
import time
from typing import Dict, List, Optional, Union, Set

from kubeflow.trainer.constants import constants
from kubeflow.trainer.types import types
@@ -433,6 +434,61 @@ def get_job_logs(

return logs_dict

def wait_for_job_status(
Contributor: Shall we also accept a namespace parameter?

Member Author: No, since the namespace is controlled by TrainerClient: https://github.com/kubeflow/sdk/blob/main/python/kubeflow/trainer/api/trainer_client.py#L38
We don't allow APIs (e.g. get_job()) to override it.
Ideally, we should abstract the namespace context from the SDK user.

Contributor: Oh, I see, makes sense now!

Contributor: Would it make sense to be a bit more generic and make this something like wait_for_job_condition, similar to what the kubectl wait command provides?

Member Author: I was thinking about it; however, I am not sure we want to expose the complexity of CR conditions to the SDK user. For example, a condition has type, status, reason, etc., which we don't really need to expose (at least for now).
Thus, I suggest that we just expose a single TrainJob status to make it easier to read.

Contributor: Sounds good, that makes sense. We can always add a "lower" level method later if needed.

self,
name: str,
status: Set[str] = {constants.TRAINJOB_COMPLETE},
timeout: int = 600,
polling_interval: int = 5,
) -> types.TrainJob:
"""Wait for TrainJob to reach the desired status

Args:
name: Name of the TrainJob.
status: Set of expected statuses. It must be a subset of the Created, Running, Complete,
and Failed statuses.
timeout: How many seconds to wait for the TrainJob to reach one of the expected statuses.
polling_interval: The polling interval in seconds to check TrainJob status.

Returns:
TrainJob: The training job that reaches desired status.

Raises:
ValueError: The input values are incorrect.
RuntimeError: Failed to get the TrainJob, or the TrainJob reached the Failed status when it was not expected.
TimeoutError: Timed out waiting for the TrainJob to reach one of the expected statuses.
"""

job_statuses = {
constants.TRAINJOB_CREATED,
constants.TRAINJOB_RUNNING,
constants.TRAINJOB_COMPLETE,
constants.TRAINJOB_FAILED,
}
if not status.issubset(job_statuses):
raise ValueError(
f"Expected status {status} must be a subset of {job_statuses}"
)
for _ in range(round(timeout / polling_interval)):
Contributor: Do you think a watch request could be used instead of polling?

Member Author: Sure, let me try that!
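
(For reference, a watch-based variant of this loop might look roughly like the sketch below. It is a hedged illustration only, not the code merged in this PR: the Kubernetes Python client's Watch API is real, but the TrainJob group/version/plural values and the helper name are assumptions.)

```python
from kubernetes import client, watch


def wait_with_watch(namespace: str, name: str, expected: set, timeout: int = 600):
    """Hypothetical watch-based alternative to the polling loop above."""
    api = client.CustomObjectsApi()
    w = watch.Watch()
    # group/version/plural below are assumed values for the TrainJob CRD.
    for event in w.stream(
        api.list_namespaced_custom_object,
        group="trainer.kubeflow.org",
        version="v1alpha1",
        namespace=namespace,
        plural="trainjobs",
        field_selector=f"metadata.name={name}",
        timeout_seconds=timeout,
    ):
        # Custom objects are returned as plain dicts by the watch stream.
        conditions = event["object"].get("status", {}).get("conditions", [])
        for c in conditions:
            if c.get("type") in expected and c.get("status") == "True":
                w.stop()
                return event["object"]
    raise TimeoutError(f"Timeout waiting for TrainJob {name} to reach {expected}")
```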

trainjob = self.get_job(name)

# Raise an error if the TrainJob is Failed and Failed is not among the expected statuses.
if (
constants.TRAINJOB_FAILED not in status
and trainjob.status == constants.TRAINJOB_FAILED
):
raise RuntimeError(f"TrainJob {name} is Failed")

# Return the TrainJob if it reaches the expected status.
if trainjob.status in status:
return trainjob

time.sleep(polling_interval)

raise TimeoutError(
f"Timeout waiting for TrainJob {name} to reach {status} Status"
)
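
From the SDK user's perspective, a call might look like the hedged sketch below ("my-trainjob" is a placeholder name; including TRAINJOB_FAILED in the set avoids the RuntimeError so the caller can inspect a failed job):

```python
from kubeflow.trainer import TrainerClient
from kubeflow.trainer.constants import constants

client = TrainerClient()

# Wait up to 30 minutes for a previously created TrainJob to finish,
# checking its status every 10 seconds.
trainjob = client.wait_for_job_status(
    name="my-trainjob",  # placeholder TrainJob name
    status={constants.TRAINJOB_COMPLETE, constants.TRAINJOB_FAILED},
    timeout=1800,
    polling_interval=10,
)
print(f"TrainJob {trainjob.name} finished with status: {trainjob.status}")
```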

def delete_job(self, name: str):
"""Delete the TrainJob.

@@ -485,7 +541,6 @@ def __get_runtime_from_crd(
trainer=utils.get_runtime_trainer(
runtime_crd.spec.template.spec.replicated_jobs,
runtime_crd.spec.ml_policy,
runtime_crd.metadata,
),
)

@@ -506,26 +561,23 @@ def __get_trainjob_from_crd(
name = trainjob_crd.metadata.name
namespace = trainjob_crd.metadata.namespace

runtime = self.get_runtime(trainjob_crd.spec.runtime_ref.name)

# Construct the TrainJob from the CRD.
trainjob = types.TrainJob(
name=name,
creation_timestamp=trainjob_crd.metadata.creation_timestamp,
runtime=self.get_runtime(trainjob_crd.spec.runtime_ref.name),
runtime=runtime,
steps=[],
# Number of nodes is taken from TrainJob or TrainingRuntime
num_nodes=(
trainjob_crd.spec.trainer.num_nodes
if trainjob_crd.spec.trainer and trainjob_crd.spec.trainer.num_nodes
else runtime.trainer.num_nodes
),
status=constants.TRAINJOB_CREATED, # The default TrainJob status.
)

# Add the TrainJob status.
# TODO (andreyvelich): Discuss how we should show TrainJob status to SDK users.
# The TrainJob exists at that stage so its status can safely default to Created
trainjob.status = constants.TRAINJOB_CREATED
# Then it can be read from the TrainJob conditions if any
if trainjob_crd.status and trainjob_crd.status.conditions:
for c in trainjob_crd.status.conditions:
if c.type == "Complete" and c.status == "True":
trainjob.status = "Succeeded"
elif c.type == "Failed" and c.status == "True":
trainjob.status = "Failed"

# Select Pods created by the appropriate JobSet. It checks the following ReplicatedJob.name:
# dataset-initializer, model-initializer, launcher, node.
label_selector = "{}={},{} in ({}, {}, {}, {})".format(
@@ -567,26 +619,28 @@ def __get_trainjob_from_crd(
constants.DATASET_INITIALIZER,
constants.MODEL_INITIALIZER,
}:
step = utils.get_trainjob_initializer_step(
pod.metadata.name,
pod.spec,
pod.status,
trainjob.steps.append(
utils.get_trainjob_initializer_step(
pod.metadata.name,
pod.spec,
pod.status,
)
)
# Get the Node step.
elif pod.metadata.labels[constants.JOBSET_RJOB_NAME_LABEL] in {
constants.LAUNCHER,
constants.NODE,
}:
step = utils.get_trainjob_node_step(
pod.metadata.name,
pod.spec,
pod.status,
trainjob.runtime,
pod.metadata.labels[constants.JOBSET_RJOB_NAME_LABEL],
int(pod.metadata.labels[constants.JOB_INDEX_LABEL]),
trainjob.steps.append(
utils.get_trainjob_node_step(
pod.metadata.name,
pod.spec,
pod.status,
trainjob.runtime,
pod.metadata.labels[constants.JOBSET_RJOB_NAME_LABEL],
int(pod.metadata.labels[constants.JOB_INDEX_LABEL]),
)
)

trainjob.steps.append(step)
except multiprocessing.TimeoutError:
raise TimeoutError(
f"Timeout to list {constants.TRAINJOB_KIND}'s steps: {namespace}/{name}"
@@ -596,4 +650,23 @@ def __get_trainjob_from_crd(
f"Failed to list {constants.TRAINJOB_KIND}'s steps: {namespace}/{name}"
)

# Update the TrainJob status from its conditions.
if trainjob_crd.status and trainjob_crd.status.conditions:
for c in trainjob_crd.status.conditions:
if c.type == constants.TRAINJOB_COMPLETE and c.status == "True":
trainjob.status = c.type
elif c.type == constants.TRAINJOB_FAILED and c.status == "True":
trainjob.status = c.type
else:
# The TrainJob Running status is set when all training nodes (i.e. Pods) are running.
num_running_nodes = sum(
1
for step in trainjob.steps
if step.name.startswith(constants.NODE)
and step.status == constants.TRAINJOB_RUNNING
)

if trainjob.num_nodes == num_running_nodes:
trainjob.status = constants.TRAINJOB_RUNNING

return trainjob
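
To make the status derivation concrete, here is a small self-contained sketch using stand-in objects (not the SDK's real models) that roughly mirrors the branch above: with no terminal condition reported, the job counts as Running only once every node step is Running.

```python
from types import SimpleNamespace

# Stand-in data, for illustration only.
conditions = []  # no Complete/Failed condition reported yet
steps = [
    SimpleNamespace(name="node-0", status="Running"),
    SimpleNamespace(name="node-1", status="Running"),
]
num_nodes = 2
status = "Created"  # default status

if conditions:
    for c in conditions:
        if c.type in ("Complete", "Failed") and c.status == "True":
            status = c.type
else:
    running = sum(
        1 for s in steps if s.name.startswith("node") and s.status == "Running"
    )
    if running == num_nodes:
        status = "Running"

print(status)  # -> "Running"
```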