Skip to content

Commit b835edb

Browse files
authored
Added support to cancel all job runs (#346)
2 parents fa6f315 + 435741b commit b835edb

File tree

6 files changed

+99
-15
lines changed

6 files changed

+99
-15
lines changed

ads/jobs/ads_job.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,20 @@
33

44
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
import time
67
from typing import List, Union, Dict
78
from urllib.parse import urlparse
89

910
import fsspec
11+
import oci
1012
from ads.common.auth import default_signer
1113
from ads.jobs.builders.base import Builder
1214
from ads.jobs.builders.infrastructure.dataflow import DataFlow, DataFlowRun
13-
from ads.jobs.builders.infrastructure.dsc_job import DataScienceJob, DataScienceJobRun
15+
from ads.jobs.builders.infrastructure.dsc_job import (
16+
DataScienceJob,
17+
DataScienceJobRun,
18+
SLEEP_INTERVAL
19+
)
1420
from ads.jobs.builders.runtimes.pytorch_runtime import PyTorchDistributedRuntime
1521
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
1622
from ads.jobs.builders.runtimes.python_runtime import (
@@ -460,7 +466,29 @@ def run_list(self, **kwargs) -> list:
460466
A list of job run instances, the actual object type depends on the infrastructure.
461467
"""
462468
return self.infrastructure.run_list(**kwargs)
469+
470+
def cancel(self, wait_for_completion: bool = True) -> None:
    """Cancels the runs of the job.

    A cancel request is sent to every run returned by ``run_list()``
    without blocking, so all requests are issued before any waiting starts.

    Parameters
    ----------
    wait_for_completion: bool
        Whether to wait for the runs to stop before proceeding.
        Defaults to True.
    """
    runs = self.run_list()
    for run in runs:
        run.cancel(wait_for_completion=False)

    if not wait_for_completion:
        return

    job_run = oci.data_science.models.JobRun
    # Stop polling on ANY terminal state. Checking for CANCELED alone
    # never ends for a run that finished in another terminal state
    # (e.g. it succeeded or failed before the cancellation took effect).
    terminal_states = {
        job_run.LIFECYCLE_STATE_CANCELED,
        job_run.LIFECYCLE_STATE_SUCCEEDED,
        job_run.LIFECYCLE_STATE_FAILED,
        job_run.LIFECYCLE_STATE_DELETED,
    }
    for run in runs:
        while run.lifecycle_state not in terminal_states:
            # Refresh the locally cached state, then back off.
            run.sync()
            time.sleep(SLEEP_INTERVAL)
491+
464492
def delete(self) -> None:
465493
"""Deletes the job from the infrastructure."""
466494
self.infrastructure.delete()

ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -725,19 +725,28 @@ def stop_condition():
725725

726726
return self
727727

728-
def cancel(self) -> DataScienceJobRun:
728+
def cancel(self, wait_for_completion: bool = True) -> DataScienceJobRun:
    """Cancels a job run.

    Parameters
    ----------
    wait_for_completion: bool
        Whether to wait for the job run to stop before proceeding.
        Defaults to True.

    Returns
    -------
    self
        The job run instance.
    """
    self.client.cancel_job_run(self.id)
    if wait_for_completion:
        job_run = oci.data_science.models.JobRun
        # Stop polling on ANY terminal state. Checking for CANCELED alone
        # never ends for a run that finished in another terminal state.
        terminal_states = {
            job_run.LIFECYCLE_STATE_CANCELED,
            job_run.LIFECYCLE_STATE_SUCCEEDED,
            job_run.LIFECYCLE_STATE_FAILED,
            job_run.LIFECYCLE_STATE_DELETED,
        }
        while self.lifecycle_state not in terminal_states:
            # Refresh the locally cached state, then back off.
            self.sync()
            time.sleep(SLEEP_INTERVAL)
    return self
742751

743752
def __repr__(self) -> str:

ads/opctl/backend/ads_ml_job.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,22 @@ def cancel(self):
217217
"""
218218
Cancel Job Run from OCID.
219219
"""
220-
run_id = self.config["execution"]["run_id"]
221220
with AuthContext(auth=self.auth_type, profile=self.profile):
222-
DataScienceJobRun.from_ocid(run_id).cancel()
223-
print(f"Job run {run_id} has been cancelled.")
221+
wait_for_completion = self.config["execution"].get("wait_for_completion")
222+
if self.config["execution"].get("id"):
223+
id = self.config["execution"]["id"]
224+
Job.from_datascience_job(id).cancel(
225+
wait_for_completion=wait_for_completion
226+
)
227+
if wait_for_completion:
228+
print(f"All job runs under {id} have been cancelled.")
229+
elif self.config["execution"].get("run_id"):
230+
run_id = self.config["execution"]["run_id"]
231+
DataScienceJobRun.from_ocid(run_id).cancel(
232+
wait_for_completion=wait_for_completion
233+
)
234+
if wait_for_completion:
235+
print(f"Job run {run_id} has been cancelled.")
224236

225237
def watch(self):
226238
"""

ads/opctl/cmds.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def delete(**kwargs) -> None:
374374
):
375375
kwargs["id"] = kwargs.pop("ocid")
376376
else:
377-
raise ValueError(f"{kwargs['ocid']} is valid or supported.")
377+
raise ValueError(f"{kwargs['ocid']} is invalid or not supported.")
378378

379379
p = ConfigProcessor().step(ConfigMerger, **kwargs)
380380
return _BackendFactory(p.config).backend.delete()
@@ -388,13 +388,24 @@ def cancel(**kwargs) -> None:
388388
----------
389389
kwargs: dict
390390
keyword argument, stores command line args
391+
391392
Returns
392393
-------
393394
None
394395
"""
395-
kwargs["run_id"] = kwargs.pop("ocid")
396-
if not kwargs.get("backend"):
397-
kwargs["backend"] = _get_backend_from_run_id(kwargs["run_id"])
396+
kwargs["backend"] = _get_backend_from_ocid(kwargs["ocid"])
397+
if (
398+
DataScienceResourceRun.JOB_RUN in kwargs["ocid"]
399+
or DataScienceResourceRun.DATAFLOW_RUN in kwargs["ocid"]
400+
or DataScienceResourceRun.PIPELINE_RUN in kwargs["ocid"]
401+
):
402+
kwargs["run_id"] = kwargs.pop("ocid")
403+
elif (
404+
DataScienceResource.JOB in kwargs["ocid"]
405+
):
406+
kwargs["id"] = kwargs.pop("ocid")
407+
else:
408+
raise ValueError(f"{kwargs['ocid']} is invalid or not supported.")
398409
p = ConfigProcessor().step(ConfigMerger, **kwargs)
399410
return _BackendFactory(p.config).backend.cancel()
400411

tests/unitary/default_setup/jobs/test_jobs_base.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
ScriptRuntime,
2121
NotebookRuntime,
2222
)
23+
from ads.jobs.builders.infrastructure.dsc_job import DataScienceJobRun
2324
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
2425
CondaRuntimeHandler,
2526
ScriptRuntimeHandler,
@@ -603,3 +604,25 @@ def test_run_details_link_fail(self, mock_extract_region):
603604
test_run_instance = RunInstance()
604605
test_result = test_run_instance.run_details_link
605606
assert test_result == ""
607+
608+
609+
class DataScienceJobMethodTest(DataScienceJobPayloadTest):
    """Tests for job-level methods on a ``Job`` with ``DataScienceJob`` infrastructure."""

    @patch("ads.jobs.builders.infrastructure.dsc_job.DataScienceJobRun.cancel")
    @patch("ads.jobs.ads_job.Job.run_list")
    def test_job_cancel(self, mock_run_list, mock_cancel):
        """``Job.cancel`` sends a non-blocking cancel request to every run."""
        # Build distinct run objects; ``[obj] * 3`` would alias a single
        # instance three times.
        runs = [DataScienceJobRun(lifecycle_state="CANCELED") for _ in range(3)]
        mock_run_list.return_value = runs

        job = (
            Job(name="test")
            .with_infrastructure(infrastructure.DataScienceJob())
            .with_runtime(ScriptRuntime().with_script(self.SCRIPT_URI))
        )

        job.cancel()
        mock_run_list.assert_called()
        # ``assert_called_with`` only inspects the last call, so also check
        # that one cancel request was issued per run.
        assert mock_cancel.call_count == len(runs)
        mock_cancel.assert_called_with(wait_for_completion=False)

tests/unitary/with_extras/opctl/test_opctl_cmds.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,8 +181,9 @@ def test_cancel(self, job_cancel_func, pipeline_cancel_func, monkeypatch):
181181
monkeypatch.delenv("NB_SESSION_OCID", raising=False)
182182
cancel(ocid="...datasciencejobrun...")
183183
job_cancel_func.assert_called()
184-
with pytest.raises(ValueError):
185-
cancel(ocid="....datasciencejob....")
184+
185+
cancel(ocid="....datasciencejob....")
186+
job_cancel_func.assert_called()
186187

187188
cancel(ocid="...datasciencepipelinerun...")
188189
pipeline_cancel_func.assert_called()

0 commit comments

Comments
 (0)