Skip to content

Commit 5049b0e

Browse files
committed
Update job.run() to support defined tags.
1 parent 6a5c14d commit 5049b0e

File tree

3 files changed

+44
-21
lines changed

3 files changed

+44
-21
lines changed

ads/jobs/builders/infrastructure/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def run(
6161
args: str = None,
6262
env_var: dict = None,
6363
freeform_tags: dict = None,
64+
defined_tags: dict = None,
6465
wait: bool = False,
6566
):
6667
"""Runs a job on the infrastructure.
@@ -72,9 +73,11 @@ def run(
7273
args : str, optional
7374
Command line arguments for the job run, by default None.
7475
env_var : dict, optional
75-
Environment variable for the job run, by default None
76+
Environment variable for the job run, by default None.
7677
freeform_tags : dict, optional
77-
Freeform tags for the job run, by default None
78+
Freeform tags for the job run, by default None.
79+
defined_tags : dict, optional
80+
Defined tags for the job run, by default None.
7881
wait : bool, optional
7982
Indicate if this method should wait for the run to finish before it returns, by default False.
8083
"""

ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,7 @@
3535
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
3636
from ads.jobs.builders.runtimes.python_runtime import GitPythonRuntime
3737

38-
from ads.common.dsc_file_system import (
39-
OCIFileStorage,
40-
DSCFileSystemManager
41-
)
38+
from ads.common.dsc_file_system import OCIFileStorage, DSCFileSystemManager
4239

4340
logger = logging.getLogger(__name__)
4441

@@ -445,6 +442,8 @@ def run(self, **kwargs) -> DataScienceJobRun:
445442
* command_line_arguments: str
446443
* maximum_runtime_in_minutes: int
447444
* display_name: str
445+
* freeform_tags: dict(str, str)
446+
* defined_tags: dict(str, object)
448447
449448
If display_name is not specified, it will be generated as "<JOB_NAME>-run-<TIMESTAMP>".
450449
@@ -845,7 +844,7 @@ class DataScienceJob(Infrastructure):
845844
.with_storage_mount(
846845
{
847846
"src" : "<mount_target_ip_address>:<export_path>",
848-
"dest" : "<destination_directory_name>"
847+
"dest" : "<destination_directory_name>"
849848
}
850849
)
851850
)
@@ -1231,9 +1230,7 @@ def log_group_id(self) -> str:
12311230
"""
12321231
return self.get_spec(self.CONST_LOG_GROUP_ID)
12331232

1234-
def with_storage_mount(
1235-
self, *storage_mount: List[dict]
1236-
) -> DataScienceJob:
1233+
def with_storage_mount(self, *storage_mount: List[dict]) -> DataScienceJob:
12371234
"""Sets the file systems to be mounted for the data science job.
12381235
A maximum number of 5 file systems are allowed to be mounted for a single data science job.
12391236
@@ -1425,7 +1422,8 @@ def _update_job_infra(self, dsc_job: DSCJob) -> DataScienceJob:
14251422
"Storage mount hasn't been supported in the current OCI SDK installed."
14261423
)
14271424
dsc_job.job_storage_mount_configuration_details_list = [
1428-
DSCFileSystemManager.initialize(file_system) for file_system in self.storage_mount
1425+
DSCFileSystemManager.initialize(file_system)
1426+
for file_system in self.storage_mount
14291427
]
14301428
return self
14311429

@@ -1477,7 +1475,13 @@ def create(self, runtime, **kwargs) -> DataScienceJob:
14771475
return self
14781476

14791477
def run(
1480-
self, name=None, args=None, env_var=None, freeform_tags=None, wait=False
1478+
self,
1479+
name=None,
1480+
args=None,
1481+
env_var=None,
1482+
freeform_tags=None,
1483+
defined_tags=None,
1484+
wait=False,
14811485
) -> DataScienceJobRun:
14821486
"""Runs a job on OCI Data Science job
14831487
@@ -1491,6 +1495,8 @@ def run(
14911495
Environment variable for the job run, by default None
14921496
freeform_tags : dict, optional
14931497
Freeform tags for the job run, by default None
1498+
defined_tags : dict, optional
1499+
Defined tags for the job run, by default None
14941500
wait : bool, optional
14951501
Indicate if this method should wait for the run to finish before it returns, by default False.
14961502
@@ -1505,11 +1511,18 @@ def run(
15051511
raise RuntimeError(
15061512
"Job is not created. Call create() to create the job first."
15071513
)
1508-
tags = self.runtime.freeform_tags
1509-
if not tags:
1510-
tags = {}
1511-
if freeform_tags:
1512-
tags.update(freeform_tags)
1514+
1515+
if not freeform_tags:
1516+
freeform_tags = {}
1517+
runtime_freeform_tags = self.runtime.freeform_tags
1518+
if runtime_freeform_tags:
1519+
freeform_tags.update(runtime_freeform_tags)
1520+
1521+
if not defined_tags:
1522+
defined_tags = {}
1523+
runtime_defined_tags = self.runtime.defined_tags
1524+
if runtime_defined_tags:
1525+
defined_tags.update(runtime_defined_tags)
15131526

15141527
if name:
15151528
envs = self.runtime.envs
@@ -1521,7 +1534,8 @@ def run(
15211534
display_name=name,
15221535
command_line_arguments=args,
15231536
environment_variables=env_var,
1524-
freeform_tags=tags,
1537+
freeform_tags=freeform_tags,
1538+
defined_tags=defined_tags,
15251539
wait=wait,
15261540
)
15271541

ads/jobs/builders/infrastructure/dsc_job_runtime.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from ads.jobs.builders.infrastructure.utils import get_value
3939

4040

41-
class IncompatibleRuntime(Exception): # pragma: no cover
41+
class IncompatibleRuntime(Exception): # pragma: no cover
4242
"""Represents an exception when runtime is not compatible with the OCI data science job configuration.
4343
This exception is designed to be raised during the extraction of a runtime from OCI data science job.
4444
The data science job does not explicitly contain information of the type of the ADS runtime.
@@ -104,6 +104,8 @@ def translate(self, runtime: Runtime) -> dict:
104104
payload["job_configuration_details"] = self._translate_config(runtime)
105105
if runtime.freeform_tags:
106106
payload["freeform_tags"] = runtime.freeform_tags
107+
if runtime.defined_tags:
108+
payload["defined_tags"] = runtime.defined_tags
107109
self.data_science_job.runtime = runtime
108110
return payload
109111

@@ -353,10 +355,14 @@ def _extract_tags(self, dsc_job):
353355
dict
354356
A runtime specification dictionary for initializing a runtime.
355357
"""
358+
tags = {}
356359
value = get_value(dsc_job, "freeform_tags")
357360
if value:
358-
return {Runtime.CONST_TAG: value}
359-
return {}
361+
tags[Runtime.CONST_FREEFORM_TAGS] = value
362+
value = get_value(dsc_job, "defined_tags")
363+
if value:
364+
tags[Runtime.CONST_DEFINED_TAGS] = value
365+
return tags
360366

361367
def _extract_artifact(self, dsc_job):
362368
"""Extract the job artifact from data science job.

0 commit comments

Comments
 (0)