Skip to content

Commit e71e398

Browse files
authored
Added validation for driver and executor shape details. (#211)
2 parents 589d93a + 6ed266e commit e71e398

File tree

2 files changed

+86
-6
lines changed

2 files changed

+86
-6
lines changed

ads/jobs/builders/infrastructure/dataflow.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@
4242
# Defaults applied when the caller does not specify Spark/compute settings.
DEFAULT_SPARK_VERSION = "3.2.1"
DEFAULT_NUM_EXECUTORS = 1
DEFAULT_SHAPE = "VM.Standard.E3.Flex"

# Shape-family substrings used to check that the driver and executor shapes
# belong to the same family (matched by substring against the full shape
# name, e.g. "Standard.E4" matches "VM.Standard.E4.Flex").
DATAFLOW_SHAPE_FAMILY = [
    "Standard.E3",
    "Standard.E4",
    "Standard3",
    "Standard.A1",
    "Standard2",
]
4552

4653

4754
def conda_pack_name_to_dataflow_config(conda_uri):
@@ -889,11 +896,39 @@ def create(self, runtime: DataFlowRuntime, **kwargs) -> "DataFlow":
889896
raise ValueError(
890897
"Compartment id is required. Specify compartment id via 'with_compartment_id()'."
891898
)
899+
self._validate_shapes(payload)
892900
payload.pop("id", None)
893901
logger.debug(f"Creating a DataFlow Application with payload {payload}")
894902
self.df_app = DataFlowApp(**payload).create()
895903
self.with_id(self.df_app.id)
896904
return self
905+
906+
@staticmethod
def _validate_shapes(payload: Dict):
    """Validate driver/executor shape settings in a create-application payload.

    Fills in ``DEFAULT_SHAPE`` for a missing ``driver_shape`` or
    ``executor_shape`` (mutating *payload* in place), then enforces two rules:

    * both shape names must contain the same family substring from
      ``DATAFLOW_SHAPE_FAMILY``;
    * a ``driver_shape_config`` / ``executor_shape_config`` may only be
      supplied alongside a flex shape (name ending in ``"Flex"``).

    Parameters
    ----------
    payload : Dict
        The application-creation payload; may be modified in place.

    Raises
    ------
    ValueError
        If the shapes are from different families, or a shape config is
        supplied for a non-flex shape.
    """
    # Default any missing shape before validating.
    payload.setdefault("executor_shape", DEFAULT_SHAPE)
    payload.setdefault("driver_shape", DEFAULT_SHAPE)

    executor_shape = payload["executor_shape"]
    driver_shape = payload["driver_shape"]
    executor_shape_config = payload.get("executor_shape_config", {})
    driver_shape_config = payload.get("driver_shape_config", {})

    # Both shape names must share at least one family substring.
    if not any(
        family in executor_shape and family in driver_shape
        for family in DATAFLOW_SHAPE_FAMILY
    ):
        raise ValueError(
            "`executor_shape` and `driver_shape` must be from the same shape family."
        )

    # Shape configs (ocpus / memory) are only meaningful for flex shapes.
    config_on_fixed_executor = (
        not executor_shape.endswith("Flex") and executor_shape_config
    )
    config_on_fixed_driver = not driver_shape.endswith("Flex") and driver_shape_config
    if config_on_fixed_executor or config_on_fixed_driver:
        raise ValueError(
            "Shape config is not required for non flex shape from user end."
        )
897932

898933
@staticmethod
899934
def _upload_file(local_path, bucket, overwrite=False):

tests/unitary/default_setup/jobs/test_jobs_dataflow.py

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@
3737
arguments=["test-df"],
3838
compartment_id="ocid1.compartment.oc1..<unique_ocid>",
3939
display_name="test-df",
40-
driver_shape="VM.Standard2.1",
40+
driver_shape="VM.Standard.E4.Flex",
4141
driver_shape_config={"memory_in_gbs": 1, "ocpus": 16},
42-
executor_shape="VM.Standard2.1",
42+
executor_shape="VM.Standard.E4.Flex",
4343
executor_shape_config={"memory_in_gbs": 1, "ocpus": 16},
4444
file_uri="oci://test_bucket@test_namespace/test-dataflow/test-dataflow.py",
4545
num_executors=1,
@@ -124,7 +124,7 @@ def test_create_delete(self, mock_to_dict, mock_client):
124124
df.lifecycle_state
125125
== oci.data_flow.models.Application.LIFECYCLE_STATE_DELETED
126126
)
127-
assert len(df.to_yaml()) == 557
127+
assert len(df.to_yaml()) == 567
128128

129129
def test_create_df_app_with_default_display_name(
130130
self,
@@ -433,8 +433,8 @@ def test_create_from_id(self, mock_from_ocid):
433433
mock_from_ocid.return_value = Application(**SAMPLE_PAYLOAD)
434434
df = DataFlow.from_id("ocid1.datasciencejob.oc1.iad.<unique_ocid>")
435435
assert df.name == "test-df"
436-
assert df.driver_shape == "VM.Standard2.1"
437-
assert df.executor_shape == "VM.Standard2.1"
436+
assert df.driver_shape == "VM.Standard.E4.Flex"
437+
assert df.executor_shape == "VM.Standard.E4.Flex"
438438
assert df.private_endpoint_id == "test_private_endpoint"
439439

440440
assert (
@@ -454,7 +454,7 @@ def test_create_from_id(self, mock_from_ocid):
454454
def test_to_and_from_dict(self, df):
455455
df_dict = df.to_dict()
456456
assert df_dict["spec"]["numExecutors"] == 2
457-
assert df_dict["spec"]["driverShape"] == "VM.Standard2.1"
457+
assert df_dict["spec"]["driverShape"] == "VM.Standard.E4.Flex"
458458
assert df_dict["spec"]["logsBucketUri"] == "oci://test_bucket@test_namespace/"
459459
assert df_dict["spec"]["privateEndpointId"] == "test_private_endpoint"
460460
assert df_dict["spec"]["driverShapeConfig"] == {"memoryInGBs": 1, "ocpus": 16}
@@ -482,6 +482,51 @@ def test_to_and_from_dict(self, df):
482482
assert df3_dict["spec"]["sparkVersion"] == "3.2.1"
483483
assert df3_dict["spec"]["numExecutors"] == 2
484484

485+
def test_shape_and_details(self, mock_to_dict, mock_client, df):
    """Creating an app fails on mismatched shape families, and on shape
    configs supplied for non-flex shapes."""
    # Driver and executor from different shape families.
    df.with_driver_shape("VM.Standard2.1").with_executor_shape(
        "VM.Standard.E4.Flex"
    )

    rt = (
        DataFlowRuntime()
        .with_script_uri(SAMPLE_PAYLOAD["file_uri"])
        .with_archive_uri(SAMPLE_PAYLOAD["archive_uri"])
        .with_custom_conda(
            "oci://my_bucket@my_namespace/conda_environments/cpu/PySpark 3.0 and Data Flow/5.0/pyspark30_p37_cpu_v5"
        )
        .with_overwrite(True)
    )

    with pytest.raises(
        ValueError,
        match="`executor_shape` and `driver_shape` must be from the same shape family.",
    ):
        with patch.object(DataFlowApp, "client", mock_client):
            with patch.object(DataFlowApp, "to_dict", mock_to_dict):
                df.create(rt)

    # Same (Standard2) family, but shape configs passed for non-flex shapes.
    df.with_driver_shape("VM.Standard2.1").with_driver_shape_config(
        memory_in_gbs=SAMPLE_PAYLOAD["driver_shape_config"]["memory_in_gbs"],
        ocpus=SAMPLE_PAYLOAD["driver_shape_config"]["ocpus"],
    ).with_executor_shape("VM.Standard2.16").with_executor_shape_config(
        memory_in_gbs=SAMPLE_PAYLOAD["executor_shape_config"]["memory_in_gbs"],
        ocpus=SAMPLE_PAYLOAD["executor_shape_config"]["ocpus"],
    )

    with pytest.raises(
        ValueError,
        match="Shape config is not required for non flex shape from user end.",
    ):
        with patch.object(DataFlowApp, "client", mock_client):
            with patch.object(DataFlowApp, "to_dict", mock_to_dict):
                df.create(rt)
529+
485530

486531
class TestDataFlowNotebookRuntime:
487532
@pytest.mark.skipif(

0 commit comments

Comments
 (0)