Skip to content

Commit 1433210

Browse files
committed
Added validation for driver and executor shape details.
1 parent eb56fd5 commit 1433210

File tree

2 files changed

+63
-6
lines changed

2 files changed

+63
-6
lines changed

ads/jobs/builders/infrastructure/dataflow.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,18 @@ def create(self, runtime: DataFlowRuntime, **kwargs) -> "DataFlow":
860860
raise ValueError(
861861
"Compartment id is required. Specify compartment id via 'with_compartment_id()'."
862862
)
863+
if "executor_shape" not in payload:
864+
payload["executor_shape"] = DEFAULT_SHAPE
865+
if "driver_shape" not in payload:
866+
payload["driver_shape"] = DEFAULT_SHAPE
867+
executor_shape = payload["executor_shape"]
868+
executor_shape_config = payload.get("executor_shape_config", {})
869+
driver_shape = payload["driver_shape"]
870+
driver_shape_config = payload.get("driver_shape_config", {})
871+
if executor_shape != driver_shape:
872+
raise ValueError("`executor_shape` and `driver_shape` must be from the same shape family.")
873+
if (not executor_shape.endswith("Flex") and executor_shape_config) or (not driver_shape.endswith("Flex") and driver_shape_config):
874+
raise ValueError("Shape config is not required for non flex shape from user end.")
863875
payload.pop("id", None)
864876
logger.debug(f"Creating a DataFlow Application with payload {payload}")
865877
self.df_app = DataFlowApp(**payload).create()

tests/unitary/default_setup/jobs/test_jobs_dataflow.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@
3737
arguments=["test-df"],
3838
compartment_id="ocid1.compartment.oc1..<unique_ocid>",
3939
display_name="test-df",
40-
driver_shape="VM.Standard2.1",
40+
driver_shape="VM.Standard.E4.Flex",
4141
driver_shape_config={"memory_in_gbs": 1, "ocpus": 16},
42-
executor_shape="VM.Standard2.1",
42+
executor_shape="VM.Standard.E4.Flex",
4343
executor_shape_config={"memory_in_gbs": 1, "ocpus": 16},
4444
file_uri="oci://test_bucket@test_namespace/test-dataflow/test-dataflow.py",
4545
num_executors=1,
@@ -124,7 +124,7 @@ def test_create_delete(self, mock_to_dict, mock_client):
124124
df.lifecycle_state
125125
== oci.data_flow.models.Application.LIFECYCLE_STATE_DELETED
126126
)
127-
assert len(df.to_yaml()) == 557
127+
assert len(df.to_yaml()) == 567
128128

129129
def test_create_df_app_with_default_display_name(
130130
self,
@@ -403,8 +403,8 @@ def test_create_from_id(self, mock_from_ocid):
403403
mock_from_ocid.return_value = Application(**SAMPLE_PAYLOAD)
404404
df = DataFlow.from_id("ocid1.datasciencejob.oc1.iad.<unique_ocid>")
405405
assert df.name == "test-df"
406-
assert df.driver_shape == "VM.Standard2.1"
407-
assert df.executor_shape == "VM.Standard2.1"
406+
assert df.driver_shape == "VM.Standard.E4.Flex"
407+
assert df.executor_shape == "VM.Standard.E4.Flex"
408408
assert df.private_endpoint_id == "test_private_endpoint"
409409

410410
assert (
@@ -424,7 +424,7 @@ def test_create_from_id(self, mock_from_ocid):
424424
def test_to_and_from_dict(self, df):
425425
df_dict = df.to_dict()
426426
assert df_dict["spec"]["numExecutors"] == 2
427-
assert df_dict["spec"]["driverShape"] == "VM.Standard2.1"
427+
assert df_dict["spec"]["driverShape"] == "VM.Standard.E4.Flex"
428428
assert df_dict["spec"]["logsBucketUri"] == "oci://test_bucket@test_namespace/"
429429
assert df_dict["spec"]["privateEndpointId"] == "test_private_endpoint"
430430
assert df_dict["spec"]["driverShapeConfig"] == {"memoryInGBs": 1, "ocpus": 16}
@@ -444,6 +444,51 @@ def test_to_and_from_dict(self, df):
444444
assert df3_dict["spec"]["sparkVersion"] == "3.2.1"
445445
assert df3_dict["spec"]["numExecutors"] == 2
446446

447+
def test_shape_and_details(self, mock_to_dict, mock_client, df):
448+
df.with_driver_shape(
449+
"VM.Standard2.1"
450+
).with_executor_shape(
451+
"VM.Standard2.8"
452+
)
453+
454+
rt = (
455+
DataFlowRuntime()
456+
.with_script_uri(SAMPLE_PAYLOAD["file_uri"])
457+
.with_archive_uri(SAMPLE_PAYLOAD["archive_uri"])
458+
.with_custom_conda(
459+
"oci://my_bucket@my_namespace/conda_environments/cpu/PySpark 3.0 and Data Flow/5.0/pyspark30_p37_cpu_v5"
460+
)
461+
.with_overwrite(True)
462+
)
463+
464+
with pytest.raises(
465+
ValueError,
466+
match="`executor_shape` and `driver_shape` must be from the same shape family."
467+
):
468+
with patch.object(DataFlowApp, "client", mock_client):
469+
with patch.object(DataFlowApp, "to_dict", mock_to_dict):
470+
df.create(rt)
471+
472+
df.with_driver_shape(
473+
"VM.Standard2.1"
474+
).with_driver_shape_config(
475+
memory_in_gbs=SAMPLE_PAYLOAD["driver_shape_config"]["memory_in_gbs"],
476+
ocpus=SAMPLE_PAYLOAD["driver_shape_config"]["ocpus"],
477+
).with_executor_shape(
478+
"VM.Standard2.1"
479+
).with_executor_shape_config(
480+
memory_in_gbs=SAMPLE_PAYLOAD["executor_shape_config"]["memory_in_gbs"],
481+
ocpus=SAMPLE_PAYLOAD["executor_shape_config"]["ocpus"],
482+
)
483+
484+
with pytest.raises(
485+
ValueError,
486+
match="Shape config is not required for non flex shape from user end."
487+
):
488+
with patch.object(DataFlowApp, "client", mock_client):
489+
with patch.object(DataFlowApp, "to_dict", mock_to_dict):
490+
df.create(rt)
491+
447492

448493
class TestDataFlowNotebookRuntime:
449494
@pytest.mark.skipif(

0 commit comments

Comments (0)