Skip to content

Commit 19f10bc

Browse files
committed
Added defined tags for data flow
1 parent eb56fd5 commit 19f10bc

File tree

2 files changed

+82
-3
lines changed

2 files changed

+82
-3
lines changed

ads/jobs/builders/infrastructure/dataflow.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,8 @@ class DataFlow(Infrastructure):
384384
CONST_OCPUS = "ocpus"
385385
CONST_ID = "id"
386386
CONST_PRIVATE_ENDPOINT_ID = "private_endpoint_id"
387+
CONST_FREEFORM_TAGS = "freeform_tags"
388+
CONST_DEFINED_TAGS = "defined_tags"
387389

388390
attribute_map = {
389391
CONST_COMPARTMENT_ID: "compartmentId",
@@ -402,6 +404,8 @@ class DataFlow(Infrastructure):
402404
CONST_OCPUS: CONST_OCPUS,
403405
CONST_ID: CONST_ID,
404406
CONST_PRIVATE_ENDPOINT_ID: "privateEndpointId",
407+
CONST_FREEFORM_TAGS: "freeformTags",
408+
CONST_DEFINED_TAGS: "definedTags"
405409
}
406410

407411
def __init__(self, spec: dict = None, **kwargs):
@@ -414,7 +418,9 @@ def __init__(self, spec: dict = None, **kwargs):
414418
spec = {
415419
k: v
416420
for k, v in spec.items()
417-
if f"with_{camel_to_snake(k)}" in self.__dir__() and v is not None
421+
if (f"with_{camel_to_snake(k)}" in self.__dir__()
422+
or (k == "defined_tags" or "freeform_tags"))
423+
and v is not None
418424
}
419425
defaults.update(spec)
420426
super().__init__(defaults, **kwargs)
@@ -775,6 +781,36 @@ def with_private_endpoint_id(self, private_endpoint_id: str) -> "DataFlow":
775781
the Data Flow instance itself
776782
"""
777783
return self.set_spec(self.CONST_PRIVATE_ENDPOINT_ID, private_endpoint_id)
784+
785+
def with_freeform_tag(self, **kwargs) -> "DataFlow":
786+
"""Sets freeform tags
787+
788+
Returns
789+
-------
790+
DataFlow
791+
The DataFlow instance (self)
792+
"""
793+
return self.set_spec(self.CONST_FREEFORM_TAGS, kwargs)
794+
795+
def with_defined_tag(self, **kwargs) -> "DataFlow":
796+
"""Sets defined tags
797+
798+
Returns
799+
-------
800+
DataFlow
801+
The DataFlow instance (self)
802+
"""
803+
return self.set_spec(self.CONST_DEFINED_TAGS, kwargs)
804+
805+
@property
806+
def freeform_tags(self) -> dict:
807+
"""Freeform tags"""
808+
return self.get_spec(self.CONST_FREEFORM_TAGS, {})
809+
810+
@property
811+
def defined_tags(self) -> dict:
812+
"""Defined tags"""
813+
return self.get_spec(self.CONST_DEFINED_TAGS, {})
778814

779815
def __getattr__(self, item):
780816
if f"with_{item}" in self.__dir__():
@@ -849,7 +885,8 @@ def create(self, runtime: DataFlowRuntime, **kwargs) -> "DataFlow":
849885
{
850886
"display_name": self.name,
851887
"file_uri": runtime.script_uri,
852-
"freeform_tags": runtime.freeform_tags,
888+
"freeform_tags": runtime.freeform_tags or self.freeform_tags,
889+
"defined_tags": runtime.defined_tags or self.defined_tags,
853890
"archive_uri": runtime.archive_uri,
854891
"configuration": runtime.configuration,
855892
}
@@ -915,6 +952,7 @@ def run(
915952
args: List[str] = None,
916953
env_vars: Dict[str, str] = None,
917954
freeform_tags: Dict[str, str] = None,
955+
defined_tags: Dict[str, Dict[str, object]] = None,
918956
wait: bool = False,
919957
**kwargs,
920958
) -> DataFlowRun:
@@ -932,6 +970,8 @@ def run(
932970
dictionary of environment variables (not used for data flow)
933971
freeform_tags: Dict[str, str], optional
934972
freeform tags
973+
defined_tags: Dict[str, Dict[str, object]], optional
974+
defined tags
935975
wait: bool, optional
936976
whether to wait for a run to terminate
937977
kwargs
@@ -950,7 +990,8 @@ def run(
950990
# Set default display_name if not specified - randomly generated easy to remember name generated
951991
payload["display_name"] = name if name else utils.get_random_name_for_resource()
952992
payload["arguments"] = args if args and len(args) > 0 else None
953-
payload["freeform_tags"] = freeform_tags
993+
payload["freeform_tags"] = freeform_tags or self.freeform_tags
994+
payload["defined_tags"] = defined_tags or self.defined_tags
954995
payload.pop("spark_version", None)
955996
logger.debug(f"Creating a DataFlow Run with payload {payload}")
956997
run = DataFlowRun(**payload).create()

tests/unitary/default_setup/jobs/test_jobs_dataflow.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,13 +320,27 @@ def df(self):
320320
2
321321
).with_private_endpoint_id(
322322
"test_private_endpoint"
323+
).with_freeform_tag(
324+
test_freeform_tags_key="test_freeform_tags_value",
325+
).with_defined_tag(
326+
test_defined_tags_namespace={
327+
"test_defined_tags_key": "test_defined_tags_value"
328+
}
323329
)
324330
return df
325331

326332
def test_create_with_builder_pattern(self, mock_to_dict, mock_client, df):
327333
assert df.language == "PYTHON"
328334
assert df.spark_version == "3.2.1"
329335
assert df.num_executors == 2
336+
assert df.freeform_tags == {
337+
"test_freeform_tags_key": "test_freeform_tags_value"
338+
}
339+
assert df.defined_tags == {
340+
"test_defined_tags_namespace": {
341+
"test_defined_tags_key": "test_defined_tags_value"
342+
}
343+
}
330344

331345
rt = (
332346
DataFlowRuntime()
@@ -335,9 +349,25 @@ def test_create_with_builder_pattern(self, mock_to_dict, mock_client, df):
335349
.with_custom_conda(
336350
"oci://my_bucket@my_namespace/conda_environments/cpu/PySpark 3.0 and Data Flow/5.0/pyspark30_p37_cpu_v5"
337351
)
352+
.with_freeform_tag(
353+
test_freeform_tags_runtime_key="test_freeform_tags_runtime_value"
354+
)
355+
.with_defined_tag(
356+
test_defined_tags_namespace={
357+
"test_defined_tags_runtime_key": "test_defined_tags_runtime_value"
358+
}
359+
)
338360
.with_overwrite(True)
339361
)
340362
assert rt.overwrite == True
363+
assert rt.freeform_tags == {
364+
"test_freeform_tags_runtime_key": "test_freeform_tags_runtime_value"
365+
}
366+
assert rt.defined_tags == {
367+
"test_defined_tags_namespace": {
368+
"test_defined_tags_runtime_key": "test_defined_tags_runtime_value"
369+
}
370+
}
341371

342372
with patch.object(DataFlowApp, "client", mock_client):
343373
with patch.object(DataFlowApp, "to_dict", mock_to_dict):
@@ -429,6 +459,14 @@ def test_to_and_from_dict(self, df):
429459
assert df_dict["spec"]["privateEndpointId"] == "test_private_endpoint"
430460
assert df_dict["spec"]["driverShapeConfig"] == {"memoryInGBs": 1, "ocpus": 16}
431461
assert df_dict["spec"]["executorShapeConfig"] == {"memoryInGBs": 1, "ocpus": 16}
462+
assert df_dict["spec"]["freeformTags"] == {
463+
"test_freeform_tags_key": "test_freeform_tags_value"
464+
}
465+
assert df_dict["spec"]["definedTags"] == {
466+
"test_defined_tags_namespace": {
467+
"test_defined_tags_key": "test_defined_tags_value"
468+
}
469+
}
432470

433471
df_dict["spec"].pop("language")
434472
df_dict["spec"].pop("numExecutors")

0 commit comments

Comments
 (0)