Skip to content

Commit 0b18dfe

Browse files
authored
Support setting "defined tags" in data science jobs. (#167)
2 parents a34675e + 7d6f475 commit 0b18dfe

File tree

7 files changed

+177
-28
lines changed

7 files changed

+177
-28
lines changed

ads/jobs/ads_job.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ class Job(Builder):
7474
.with_python_path("other_packages")
7575
# Copy files in "code_dir/output" to object storage after job finishes.
7676
.with_output("output", "oci://bucket_name@namespace/path/to/dir")
77+
# Tags
78+
.with_freeform_tag(my_tag="my_value")
79+
.with_defined_tag(**{"Operations": {"CostCenter": "42"}})
7780
)
7881
)
7982
# Create and Run the job

ads/jobs/builders/infrastructure/base.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def run(
6161
args: str = None,
6262
env_var: dict = None,
6363
freeform_tags: dict = None,
64+
defined_tags: dict = None,
6465
wait: bool = False,
6566
):
6667
"""Runs a job on the infrastructure.
@@ -72,9 +73,11 @@ def run(
7273
args : str, optional
7374
Command line arguments for the job run, by default None.
7475
env_var : dict, optional
75-
Environment variable for the job run, by default None
76+
Environment variable for the job run, by default None.
7677
freeform_tags : dict, optional
77-
Freeform tags for the job run, by default None
78+
Freeform tags for the job run, by default None.
79+
defined_tags : dict, optional
80+
Defined tags for the job run, by default None.
7881
wait : bool, optional
7982
Indicate if this method should wait for the run to finish before it returns, by default False.
8083
"""

ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,7 @@
3535
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
3636
from ads.jobs.builders.runtimes.python_runtime import GitPythonRuntime
3737

38-
from ads.common.dsc_file_system import (
39-
OCIFileStorage,
40-
DSCFileSystemManager
41-
)
38+
from ads.common.dsc_file_system import OCIFileStorage, DSCFileSystemManager
4239

4340
logger = logging.getLogger(__name__)
4441

@@ -445,6 +442,8 @@ def run(self, **kwargs) -> DataScienceJobRun:
445442
* command_line_arguments: str
446443
* maximum_runtime_in_minutes: int
447444
* display_name: str
445+
* freeform_tags: dict(str, str)
446+
* defined_tags: dict(str, dict(str, object))
448447
449448
If display_name is not specified, it will be generated as "<JOB_NAME>-run-<TIMESTAMP>".
450449
@@ -845,9 +844,12 @@ class DataScienceJob(Infrastructure):
845844
.with_storage_mount(
846845
{
847846
"src" : "<mount_target_ip_address>:<export_path>",
848-
"dest" : "<destination_directory_name>"
847+
"dest" : "<destination_directory_name>"
849848
}
850849
)
850+
# Tags
851+
.with_freeform_tag(my_tag="my_value")
852+
.with_defined_tag(**{"Operations": {"CostCenter": "42"}})
851853
)
852854
853855
"""
@@ -866,6 +868,8 @@ class DataScienceJob(Infrastructure):
866868
CONST_LOG_ID = "logId"
867869
CONST_LOG_GROUP_ID = "logGroupId"
868870
CONST_STORAGE_MOUNT = "storageMount"
871+
CONST_FREEFORM_TAGS = "freeformTags"
872+
CONST_DEFINED_TAGS = "definedTags"
869873

870874
attribute_map = {
871875
CONST_PROJECT_ID: "project_id",
@@ -880,6 +884,8 @@ class DataScienceJob(Infrastructure):
880884
CONST_LOG_ID: "log_id",
881885
CONST_LOG_GROUP_ID: "log_group_id",
882886
CONST_STORAGE_MOUNT: "storage_mount",
887+
CONST_FREEFORM_TAGS: "freeform_tags",
888+
CONST_DEFINED_TAGS: "defined_tags",
883889
}
884890

885891
shape_config_details_attribute_map = {
@@ -1231,9 +1237,7 @@ def log_group_id(self) -> str:
12311237
"""
12321238
return self.get_spec(self.CONST_LOG_GROUP_ID)
12331239

1234-
def with_storage_mount(
1235-
self, *storage_mount: List[dict]
1236-
) -> DataScienceJob:
1240+
def with_storage_mount(self, *storage_mount: List[dict]) -> DataScienceJob:
12371241
"""Sets the file systems to be mounted for the data science job.
12381242
A maximum number of 5 file systems are allowed to be mounted for a single data science job.
12391243
@@ -1271,6 +1275,36 @@ def storage_mount(self) -> List[dict]:
12711275
"""
12721276
return self.get_spec(self.CONST_STORAGE_MOUNT, [])
12731277

1278+
def with_freeform_tag(self, **kwargs) -> DataScienceJob:
1279+
"""Sets freeform tags
1280+
1281+
Returns
1282+
-------
1283+
DataScienceJob
1284+
The DataScienceJob instance (self)
1285+
"""
1286+
return self.set_spec(self.CONST_FREEFORM_TAGS, kwargs)
1287+
1288+
def with_defined_tag(self, **kwargs) -> DataScienceJob:
1289+
"""Sets defined tags
1290+
1291+
Returns
1292+
-------
1293+
DataScienceJob
1294+
The DataScienceJob instance (self)
1295+
"""
1296+
return self.set_spec(self.CONST_DEFINED_TAGS, kwargs)
1297+
1298+
@property
1299+
def freeform_tags(self) -> dict:
1300+
"""Freeform tags"""
1301+
return self.get_spec(self.CONST_FREEFORM_TAGS, {})
1302+
1303+
@property
1304+
def defined_tags(self) -> dict:
1305+
"""Defined tags"""
1306+
return self.get_spec(self.CONST_DEFINED_TAGS, {})
1307+
12741308
def _prepare_log_config(self) -> dict:
12751309
if not self.log_group_id and not self.log_id:
12761310
return None
@@ -1425,7 +1459,8 @@ def _update_job_infra(self, dsc_job: DSCJob) -> DataScienceJob:
14251459
"Storage mount hasn't been supported in the current OCI SDK installed."
14261460
)
14271461
dsc_job.job_storage_mount_configuration_details_list = [
1428-
DSCFileSystemManager.initialize(file_system) for file_system in self.storage_mount
1462+
DSCFileSystemManager.initialize(file_system)
1463+
for file_system in self.storage_mount
14291464
]
14301465
return self
14311466

@@ -1467,6 +1502,10 @@ def create(self, runtime, **kwargs) -> DataScienceJob:
14671502

14681503
payload["display_name"] = display_name
14691504
payload["job_log_configuration_details"] = self._prepare_log_config()
1505+
if not payload.get("freeform_tags"):
1506+
payload["freeform_tags"] = self.freeform_tags
1507+
if not payload.get("defined_tags"):
1508+
payload["defined_tags"] = self.defined_tags
14701509

14711510
self.dsc_job = DSCJob(**payload)
14721511
# Set Job infra to user values after DSCJob initialized the defaults
@@ -1477,7 +1516,13 @@ def create(self, runtime, **kwargs) -> DataScienceJob:
14771516
return self
14781517

14791518
def run(
1480-
self, name=None, args=None, env_var=None, freeform_tags=None, wait=False
1519+
self,
1520+
name=None,
1521+
args=None,
1522+
env_var=None,
1523+
freeform_tags=None,
1524+
defined_tags=None,
1525+
wait=False,
14811526
) -> DataScienceJobRun:
14821527
"""Runs a job on OCI Data Science job
14831528
@@ -1491,6 +1536,8 @@ def run(
14911536
Environment variable for the job run, by default None
14921537
freeform_tags : dict, optional
14931538
Freeform tags for the job run, by default None
1539+
defined_tags : dict, optional
1540+
Defined tags for the job run, by default None
14941541
wait : bool, optional
14951542
Indicate if this method should wait for the run to finish before it returns, by default False.
14961543
@@ -1505,11 +1552,18 @@ def run(
15051552
raise RuntimeError(
15061553
"Job is not created. Call create() to create the job first."
15071554
)
1508-
tags = self.runtime.freeform_tags
1509-
if not tags:
1510-
tags = {}
1511-
if freeform_tags:
1512-
tags.update(freeform_tags)
1555+
1556+
if not freeform_tags:
1557+
freeform_tags = {}
1558+
runtime_freeform_tags = self.runtime.freeform_tags
1559+
if runtime_freeform_tags:
1560+
freeform_tags.update(runtime_freeform_tags)
1561+
1562+
if not defined_tags:
1563+
defined_tags = {}
1564+
runtime_defined_tags = self.runtime.defined_tags
1565+
if runtime_defined_tags:
1566+
defined_tags.update(runtime_defined_tags)
15131567

15141568
if name:
15151569
envs = self.runtime.envs
@@ -1521,7 +1575,8 @@ def run(
15211575
display_name=name,
15221576
command_line_arguments=args,
15231577
environment_variables=env_var,
1524-
freeform_tags=tags,
1578+
freeform_tags=freeform_tags,
1579+
defined_tags=defined_tags,
15251580
wait=wait,
15261581
)
15271582

ads/jobs/builders/infrastructure/dsc_job_runtime.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from ads.jobs.builders.infrastructure.utils import get_value
3939

4040

41-
class IncompatibleRuntime(Exception): # pragma: no cover
41+
class IncompatibleRuntime(Exception): # pragma: no cover
4242
"""Represents an exception when runtime is not compatible with the OCI data science job configuration.
4343
This exception is designed to be raised during the extraction of a runtime from OCI data science job.
4444
The data science job does not explicitly contain information of the type of the ADS runtime.
@@ -104,6 +104,8 @@ def translate(self, runtime: Runtime) -> dict:
104104
payload["job_configuration_details"] = self._translate_config(runtime)
105105
if runtime.freeform_tags:
106106
payload["freeform_tags"] = runtime.freeform_tags
107+
if runtime.defined_tags:
108+
payload["defined_tags"] = runtime.defined_tags
107109
self.data_science_job.runtime = runtime
108110
return payload
109111

@@ -353,10 +355,14 @@ def _extract_tags(self, dsc_job):
353355
dict
354356
A runtime specification dictionary for initializing a runtime.
355357
"""
358+
tags = {}
356359
value = get_value(dsc_job, "freeform_tags")
357360
if value:
358-
return {Runtime.CONST_TAG: value}
359-
return {}
361+
tags[Runtime.CONST_FREEFORM_TAGS] = value
362+
value = get_value(dsc_job, "defined_tags")
363+
if value:
364+
tags[Runtime.CONST_DEFINED_TAGS] = value
365+
return tags
360366

361367
def _extract_artifact(self, dsc_job):
362368
"""Extract the job artifact from data science job.

ads/jobs/builders/runtimes/base.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@ class Runtime(Builder):
2424
CONST_ENV_VAR = "env"
2525
CONST_ARGS = "args"
2626
CONST_MAXIMUM_RUNTIME_IN_MINUTES = "maximumRuntimeInMinutes"
27-
CONST_TAG = "freeformTags"
27+
CONST_FREEFORM_TAGS = "freeformTags"
28+
CONST_DEFINED_TAGS = "definedTags"
2829

2930
attribute_map = {
30-
CONST_TAG: "freeform_tags",
31+
CONST_FREEFORM_TAGS: "freeform_tags",
32+
CONST_DEFINED_TAGS: "defined_tags",
3133
CONST_ENV_VAR: CONST_ENV_VAR,
3234
}
3335

@@ -164,14 +166,24 @@ def with_environment_variable(self: Self, **kwargs) -> Self:
164166
return self.set_spec(self.CONST_ENV_VAR, envs)
165167

166168
def with_freeform_tag(self: Self, **kwargs) -> Self:
167-
"""Sets freeform tag
169+
"""Sets freeform tags
168170
169171
Returns
170172
-------
171173
Self
172174
This method returns self to support chaining methods.
173175
"""
174-
return self.set_spec(self.CONST_TAG, kwargs)
176+
return self.set_spec(self.CONST_FREEFORM_TAGS, kwargs)
177+
178+
def with_defined_tag(self: Self, **kwargs) -> Self:
179+
"""Sets defined tags
180+
181+
Returns
182+
-------
183+
Self
184+
This method returns self to support chaining methods.
185+
"""
186+
return self.set_spec(self.CONST_DEFINED_TAGS, kwargs)
175187

176188
def with_maximum_runtime_in_minutes(
177189
self: Self, maximum_runtime_in_minutes: int
@@ -209,8 +221,13 @@ def envs(self) -> dict:
209221

210222
@property
211223
def freeform_tags(self) -> dict:
212-
"""freeform_tags"""
213-
return self.get_spec(self.CONST_TAG, {})
224+
"""Freeform tags"""
225+
return self.get_spec(self.CONST_FREEFORM_TAGS, {})
226+
227+
@property
228+
def defined_tags(self) -> dict:
229+
"""Defined tags"""
230+
return self.get_spec(self.CONST_DEFINED_TAGS, {})
214231

215232
@property
216233
def args(self) -> list:

tests/unitary/default_setup/jobs/test_jobs_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import unittest
1111
from unittest import mock
1212
from zipfile import ZipFile
13-
from unittest.mock import PropertyMock, patch
13+
from unittest.mock import patch
1414
import pytest
1515

1616

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from ads.jobs import Job, DataScienceJob, ContainerRuntime
2+
from tests.unitary.default_setup.jobs.test_jobs_base import DataScienceJobPayloadTest
3+
4+
5+
class JobTagTestCase(DataScienceJobPayloadTest):
6+
@staticmethod
7+
def runtime() -> ContainerRuntime:
8+
return ContainerRuntime().with_image(
9+
"iad.ocir.io/my_namespace/my_ubuntu_image",
10+
entrypoint="/bin/sh",
11+
cmd="-c,echo Hello World",
12+
)
13+
14+
@staticmethod
15+
def infra() -> DataScienceJob:
16+
return (
17+
DataScienceJob()
18+
.with_compartment_id("ocid1.compartment.oc1..<unique_ocid>")
19+
.with_project_id("ocid1.datascienceproject.oc1.iad.<unique_ocid>")
20+
)
21+
22+
def create_job(self, infra, runtime) -> dict:
23+
job = (
24+
Job(name=self.__class__.__name__)
25+
.with_infrastructure(infra)
26+
.with_runtime(runtime)
27+
)
28+
job = self.mock_create_job(job)
29+
return job.infrastructure.dsc_job.to_dict()
30+
31+
def test_create_job_with_runtime_tags(self):
32+
runtime = (
33+
self.runtime()
34+
.with_freeform_tag(freeform_tag="freeform_tag_val")
35+
.with_defined_tag(Operations={"CostCenter": "42"})
36+
)
37+
payload = self.create_job(self.infra(), runtime)
38+
self.assertEqual(payload["freeformTags"], dict(freeform_tag="freeform_tag_val"))
39+
self.assertEqual(payload["definedTags"], {"Operations": {"CostCenter": "42"}})
40+
41+
def test_create_job_with_infra_tags(self):
42+
infra = (
43+
self.infra()
44+
.with_freeform_tag(freeform_tag="freeform_tag_val")
45+
.with_defined_tag(Operations={"CostCenter": "42"})
46+
)
47+
payload = self.create_job(infra, self.runtime())
48+
self.assertEqual(payload["freeformTags"], dict(freeform_tag="freeform_tag_val"))
49+
self.assertEqual(payload["definedTags"], {"Operations": {"CostCenter": "42"}})
50+
51+
def test_create_job_with_infra_and_runtime_tags(self):
52+
# Tags defined in runtime will have higher priority
53+
infra = (
54+
self.infra()
55+
.with_freeform_tag(freeform_tag="freeform_tag_val")
56+
.with_defined_tag(Operations={"CostCenter": "41"})
57+
)
58+
runtime = (
59+
self.runtime()
60+
.with_freeform_tag(freeform_tag="freeform_tag_val")
61+
.with_defined_tag(Operations={"CostCenter": "42"})
62+
)
63+
payload = self.create_job(infra, runtime)
64+
self.assertEqual(payload["freeformTags"], dict(freeform_tag="freeform_tag_val"))
65+
self.assertEqual(payload["definedTags"], {"Operations": {"CostCenter": "42"}})

0 commit comments

Comments
 (0)