-
Notifications
You must be signed in to change notification settings - Fork 41
feat(trainer): Add wait_for_job_status() API
#52
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
35da385
76fd9a2
dc1b011
476b1b2
ddd2655
675f101
65b8a37
a4f0bcd
81e43fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,7 +18,7 @@ | |||||||||||||||||
| import random | ||||||||||||||||||
| import string | ||||||||||||||||||
| import uuid | ||||||||||||||||||
| from typing import Dict, List, Optional, Union | ||||||||||||||||||
| from typing import Dict, List, Optional, Union, Set | ||||||||||||||||||
|
|
||||||||||||||||||
| from kubeflow.trainer.constants import constants | ||||||||||||||||||
| from kubeflow.trainer.types import types | ||||||||||||||||||
|
|
@@ -433,6 +433,75 @@ def get_job_logs( | |||||||||||||||||
|
|
||||||||||||||||||
| return logs_dict | ||||||||||||||||||
|
|
||||||||||||||||||
| def wait_for_job_status( | ||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would that make sense to be a bit more generic and make it something like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking about it, however I am not sure if we want to expose complexity of CR condition to the SDK user. Thus, I suggest that we just expose single TrainJob status to make it easier to read. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds good, that makes sense. We can always add a "lower" level method later if needed. |
||||||||||||||||||
| self, | ||||||||||||||||||
| name: str, | ||||||||||||||||||
| status: Set[str] = {constants.TRAINJOB_COMPLETE}, | ||||||||||||||||||
| timeout: int = 600, | ||||||||||||||||||
| ) -> types.TrainJob: | ||||||||||||||||||
| """Wait for TrainJob to reach the desired status | ||||||||||||||||||
| Args: | ||||||||||||||||||
| name: Name of the TrainJob. | ||||||||||||||||||
| status: Set of expected statuses. It must be subset of Created, Running, Complete, and | ||||||||||||||||||
| Failed statuses. | ||||||||||||||||||
| timeout: How many seconds to wait until TrainJob reaches one of the expected conditions. | ||||||||||||||||||
| Returns: | ||||||||||||||||||
| TrainJob: The training job that reaches the desired status. | ||||||||||||||||||
| Raises: | ||||||||||||||||||
| ValueError: The input values are incorrect. | ||||||||||||||||||
| RuntimeError: Failed to get TrainJob or TrainJob reaches unexpected Failed status. | ||||||||||||||||||
| TimeoutError: Timeout to wait for TrainJob status. | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| job_statuses = { | ||||||||||||||||||
| constants.TRAINJOB_CREATED, | ||||||||||||||||||
| constants.TRAINJOB_RUNNING, | ||||||||||||||||||
| constants.TRAINJOB_COMPLETE, | ||||||||||||||||||
| constants.TRAINJOB_FAILED, | ||||||||||||||||||
| } | ||||||||||||||||||
| if not status.issubset(job_statuses): | ||||||||||||||||||
| raise ValueError( | ||||||||||||||||||
| f"Expected status {status} must be a subset of {job_statuses}" | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Use Kubernetes watch API to monitor the TrainJob's Pods. | ||||||||||||||||||
| w = watch.Watch() | ||||||||||||||||||
| try: | ||||||||||||||||||
| for event in w.stream( | ||||||||||||||||||
| self.core_api.list_namespaced_pod, | ||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @astefanutti I had to watch for Pod events since we don't push all events to TrainJob, but I think that would be fine for now. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @andreyvelich I'm not sure I understand why it's needed to watch for pods since the logic then gets the TrainJob with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @astefanutti The trick is that we don't expose events of running pods to the TrainJob. sdk/python/kubeflow/trainer/api/trainer_client.py Lines 661 to 668 in 81e43fd
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @andreyvelich Ah I see. I missed this. That makes sense then. |
||||||||||||||||||
| self.namespace, | ||||||||||||||||||
| label_selector=constants.POD_LABEL_SELECTOR.format(trainjob_name=name), | ||||||||||||||||||
| timeout_seconds=timeout, | ||||||||||||||||||
| ): | ||||||||||||||||||
| # Check the status after event is generated for the TrainJob's Pods. | ||||||||||||||||||
| trainjob = self.get_job(name) | ||||||||||||||||||
| logger.debug(f"TrainJob {name}, status {trainjob.status}") | ||||||||||||||||||
|
|
||||||||||||||||||
| # Raise an error if TrainJob is Failed and it is not the expected status. | ||||||||||||||||||
| if ( | ||||||||||||||||||
| constants.TRAINJOB_FAILED not in status | ||||||||||||||||||
| and trainjob.status == constants.TRAINJOB_FAILED | ||||||||||||||||||
| ): | ||||||||||||||||||
| raise RuntimeError(f"TrainJob {name} is Failed") | ||||||||||||||||||
|
|
||||||||||||||||||
| # Return the TrainJob if it reaches the expected status. | ||||||||||||||||||
| if trainjob.status in status: | ||||||||||||||||||
| return trainjob | ||||||||||||||||||
|
|
||||||||||||||||||
| except TimeoutError: | ||||||||||||||||||
| raise TimeoutError(f"Timeout to get the TrainJob {name}") | ||||||||||||||||||
| except Exception: | ||||||||||||||||||
| raise RuntimeError(f"Failed to watch Pods for TrainJob {name}") | ||||||||||||||||||
| finally: | ||||||||||||||||||
| w.stop() | ||||||||||||||||||
|
|
||||||||||||||||||
| raise TimeoutError( | ||||||||||||||||||
| f"Timeout waiting for TrainJob {name} to reach status: {status} status" | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| def delete_job(self, name: str): | ||||||||||||||||||
| """Delete the TrainJob. | ||||||||||||||||||
|
|
@@ -485,7 +554,6 @@ def __get_runtime_from_crd( | |||||||||||||||||
| trainer=utils.get_runtime_trainer( | ||||||||||||||||||
| runtime_crd.spec.template.spec.replicated_jobs, | ||||||||||||||||||
| runtime_crd.spec.ml_policy, | ||||||||||||||||||
| runtime_crd.metadata, | ||||||||||||||||||
| ), | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -506,43 +574,28 @@ def __get_trainjob_from_crd( | |||||||||||||||||
| name = trainjob_crd.metadata.name | ||||||||||||||||||
| namespace = trainjob_crd.metadata.namespace | ||||||||||||||||||
|
|
||||||||||||||||||
| runtime = self.get_runtime(trainjob_crd.spec.runtime_ref.name) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Construct the TrainJob from the CRD. | ||||||||||||||||||
| trainjob = types.TrainJob( | ||||||||||||||||||
| name=name, | ||||||||||||||||||
| creation_timestamp=trainjob_crd.metadata.creation_timestamp, | ||||||||||||||||||
| runtime=self.get_runtime(trainjob_crd.spec.runtime_ref.name), | ||||||||||||||||||
| runtime=runtime, | ||||||||||||||||||
| steps=[], | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Add the TrainJob status. | ||||||||||||||||||
| # TODO (andreyvelich): Discuss how we should show TrainJob status to SDK users. | ||||||||||||||||||
| # The TrainJob exists at that stage so its status can safely default to Created | ||||||||||||||||||
| trainjob.status = constants.TRAINJOB_CREATED | ||||||||||||||||||
| # Then it can be read from the TrainJob conditions if any | ||||||||||||||||||
| if trainjob_crd.status and trainjob_crd.status.conditions: | ||||||||||||||||||
| for c in trainjob_crd.status.conditions: | ||||||||||||||||||
| if c.type == "Complete" and c.status == "True": | ||||||||||||||||||
| trainjob.status = "Succeeded" | ||||||||||||||||||
| elif c.type == "Failed" and c.status == "True": | ||||||||||||||||||
| trainjob.status = "Failed" | ||||||||||||||||||
|
|
||||||||||||||||||
| # Select Pods created by the appropriate JobSet. It checks the following ReplicatedJob.name: | ||||||||||||||||||
| # dataset-initializer, model-initializer, launcher, node. | ||||||||||||||||||
| label_selector = "{}={},{} in ({}, {}, {}, {})".format( | ||||||||||||||||||
| constants.JOBSET_NAME_LABEL, | ||||||||||||||||||
| name, | ||||||||||||||||||
| constants.JOBSET_RJOB_NAME_LABEL, | ||||||||||||||||||
| constants.DATASET_INITIALIZER, | ||||||||||||||||||
| constants.MODEL_INITIALIZER, | ||||||||||||||||||
| constants.LAUNCHER, | ||||||||||||||||||
| constants.NODE, | ||||||||||||||||||
| # Number of nodes is taken from TrainJob or TrainingRuntime | ||||||||||||||||||
| num_nodes=( | ||||||||||||||||||
| trainjob_crd.spec.trainer.num_nodes | ||||||||||||||||||
| if trainjob_crd.spec.trainer and trainjob_crd.spec.trainer.num_nodes | ||||||||||||||||||
| else runtime.trainer.num_nodes | ||||||||||||||||||
| ), | ||||||||||||||||||
| status=constants.TRAINJOB_CREATED, # The default TrainJob status. | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Add the TrainJob components, e.g. trainer nodes and initializer. | ||||||||||||||||||
| try: | ||||||||||||||||||
| response = self.core_api.list_namespaced_pod( | ||||||||||||||||||
| namespace, | ||||||||||||||||||
| label_selector=label_selector, | ||||||||||||||||||
| label_selector=constants.POD_LABEL_SELECTOR.format(trainjob_name=name), | ||||||||||||||||||
| async_req=True, | ||||||||||||||||||
| ).get(constants.DEFAULT_TIMEOUT) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -567,26 +620,28 @@ def __get_trainjob_from_crd( | |||||||||||||||||
| constants.DATASET_INITIALIZER, | ||||||||||||||||||
| constants.MODEL_INITIALIZER, | ||||||||||||||||||
| }: | ||||||||||||||||||
| step = utils.get_trainjob_initializer_step( | ||||||||||||||||||
| pod.metadata.name, | ||||||||||||||||||
| pod.spec, | ||||||||||||||||||
| pod.status, | ||||||||||||||||||
| trainjob.steps.append( | ||||||||||||||||||
| utils.get_trainjob_initializer_step( | ||||||||||||||||||
| pod.metadata.name, | ||||||||||||||||||
| pod.spec, | ||||||||||||||||||
| pod.status, | ||||||||||||||||||
| ) | ||||||||||||||||||
| ) | ||||||||||||||||||
| # Get the Node step. | ||||||||||||||||||
| elif pod.metadata.labels[constants.JOBSET_RJOB_NAME_LABEL] in { | ||||||||||||||||||
| constants.LAUNCHER, | ||||||||||||||||||
| constants.NODE, | ||||||||||||||||||
| }: | ||||||||||||||||||
| step = utils.get_trainjob_node_step( | ||||||||||||||||||
| pod.metadata.name, | ||||||||||||||||||
| pod.spec, | ||||||||||||||||||
| pod.status, | ||||||||||||||||||
| trainjob.runtime, | ||||||||||||||||||
| pod.metadata.labels[constants.JOBSET_RJOB_NAME_LABEL], | ||||||||||||||||||
| int(pod.metadata.labels[constants.JOB_INDEX_LABEL]), | ||||||||||||||||||
| trainjob.steps.append( | ||||||||||||||||||
| utils.get_trainjob_node_step( | ||||||||||||||||||
| pod.metadata.name, | ||||||||||||||||||
| pod.spec, | ||||||||||||||||||
| pod.status, | ||||||||||||||||||
| trainjob.runtime, | ||||||||||||||||||
| pod.metadata.labels[constants.JOBSET_RJOB_NAME_LABEL], | ||||||||||||||||||
| int(pod.metadata.labels[constants.JOB_INDEX_LABEL]), | ||||||||||||||||||
| ) | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| trainjob.steps.append(step) | ||||||||||||||||||
| except multiprocessing.TimeoutError: | ||||||||||||||||||
| raise TimeoutError( | ||||||||||||||||||
| f"Timeout to list {constants.TRAINJOB_KIND}'s steps: {namespace}/{name}" | ||||||||||||||||||
|
|
@@ -596,4 +651,23 @@ def __get_trainjob_from_crd( | |||||||||||||||||
| f"Failed to list {constants.TRAINJOB_KIND}'s steps: {namespace}/{name}" | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| # Update the TrainJob status from its conditions. | ||||||||||||||||||
| if trainjob_crd.status and trainjob_crd.status.conditions: | ||||||||||||||||||
| for c in trainjob_crd.status.conditions: | ||||||||||||||||||
| if c.type == constants.TRAINJOB_COMPLETE and c.status == "True": | ||||||||||||||||||
| trainjob.status = c.type | ||||||||||||||||||
| elif c.type == constants.TRAINJOB_FAILED and c.status == "True": | ||||||||||||||||||
| trainjob.status = c.type | ||||||||||||||||||
| else: | ||||||||||||||||||
| # The TrainJob running status is defined when all training node (e.g. Pods) are running. | ||||||||||||||||||
| num_running_nodes = sum( | ||||||||||||||||||
| 1 | ||||||||||||||||||
| for step in trainjob.steps | ||||||||||||||||||
| if step.name.startswith(constants.NODE) | ||||||||||||||||||
| and step.status == constants.TRAINJOB_RUNNING | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| if trainjob.num_nodes == num_running_nodes: | ||||||||||||||||||
| trainjob.status = constants.TRAINJOB_RUNNING | ||||||||||||||||||
|
|
||||||||||||||||||
| return trainjob | ||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shall we also accept
namespaceparameter?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, since the namespace is controlled by TrainerClient: https://github.com/kubeflow/sdk/blob/main/python/kubeflow/trainer/api/trainer_client.py#L38
We don't allow APIs (e.g.
get_job()) to override it.Ideally, we should abstract the namespace context from the SDK user.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I see, makes sense now!