Skip to content

Commit 7db3a10

Browse files
committed
Update dsc_job.py to support defining tags on infrastructure.
1 parent a16d479 commit 7db3a10

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

ads/jobs/builders/infrastructure/dsc_job.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def run(self, **kwargs) -> DataScienceJobRun:
443443
* maximum_runtime_in_minutes: int
444444
* display_name: str
445445
* freeform_tags: dict(str, str)
446-
* defined_tags: dict(str, object)
446+
* defined_tags: dict(str, dict(str, object))
447447
448448
If display_name is not specified, it will be generated as "<JOB_NAME>-run-<TIMESTAMP>".
449449
@@ -865,6 +865,8 @@ class DataScienceJob(Infrastructure):
865865
CONST_LOG_ID = "logId"
866866
CONST_LOG_GROUP_ID = "logGroupId"
867867
CONST_STORAGE_MOUNT = "storageMount"
868+
CONST_FREEFORM_TAGS = "freeformTags"
869+
CONST_DEFINED_TAGS = "definedTags"
868870

869871
attribute_map = {
870872
CONST_PROJECT_ID: "project_id",
@@ -879,6 +881,8 @@ class DataScienceJob(Infrastructure):
879881
CONST_LOG_ID: "log_id",
880882
CONST_LOG_GROUP_ID: "log_group_id",
881883
CONST_STORAGE_MOUNT: "storage_mount",
884+
CONST_FREEFORM_TAGS: "freeform_tags",
885+
CONST_DEFINED_TAGS: "defined_tags",
882886
}
883887

884888
shape_config_details_attribute_map = {
@@ -1268,6 +1272,36 @@ def storage_mount(self) -> List[dict]:
12681272
"""
12691273
return self.get_spec(self.CONST_STORAGE_MOUNT, [])
12701274

1275+
def with_freeform_tag(self, **kwargs) -> DataScienceJob:
1276+
"""Sets freeform tags
1277+
1278+
Returns
1279+
-------
1280+
DataScienceJob
1281+
The DataScienceJob instance (self)
1282+
"""
1283+
return self.set_spec(self.CONST_FREEFORM_TAGS, kwargs)
1284+
1285+
def with_defined_tag(self, **kwargs) -> DataScienceJob:
1286+
"""Sets defined tags
1287+
1288+
Returns
1289+
-------
1290+
DataScienceJob
1291+
The DataScienceJob instance (self)
1292+
"""
1293+
return self.set_spec(self.CONST_DEFINED_TAGS, kwargs)
1294+
1295+
@property
1296+
def freeform_tags(self) -> dict:
1297+
"""Freeform tags"""
1298+
return self.get_spec(self.CONST_FREEFORM_TAGS, {})
1299+
1300+
@property
1301+
def defined_tags(self) -> dict:
1302+
"""Defined tags"""
1303+
return self.get_spec(self.CONST_DEFINED_TAGS, {})
1304+
12711305
def _prepare_log_config(self) -> dict:
12721306
if not self.log_group_id and not self.log_id:
12731307
return None
@@ -1465,6 +1499,10 @@ def create(self, runtime, **kwargs) -> DataScienceJob:
14651499

14661500
payload["display_name"] = display_name
14671501
payload["job_log_configuration_details"] = self._prepare_log_config()
1502+
if not payload.get("freeform_tags"):
1503+
payload["freeform_tags"] = self.freeform_tags
1504+
if not payload.get("defined_tags"):
1505+
payload["defined_tags"] = self.defined_tags
14681506

14691507
self.dsc_job = DSCJob(**payload)
14701508
# Set Job infra to user values after DSCJob initialized the defaults

0 commit comments

Comments
 (0)