Skip to content

Commit 6ed266e

Browse files
committed
Updated PR.
1 parent 1433210 commit 6ed266e

File tree

2 files changed

+34
-11
lines changed

2 files changed

+34
-11
lines changed

ads/jobs/builders/infrastructure/dataflow.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@
4242
DEFAULT_SPARK_VERSION = "3.2.1"
4343
DEFAULT_NUM_EXECUTORS = 1
4444
DEFAULT_SHAPE = "VM.Standard.E3.Flex"
45+
DATAFLOW_SHAPE_FAMILY = [
46+
"Standard.E3",
47+
"Standard.E4",
48+
"Standard3",
49+
"Standard.A1",
50+
"Standard2"
51+
]
4552

4653

4754
def conda_pack_name_to_dataflow_config(conda_uri):
@@ -860,6 +867,15 @@ def create(self, runtime: DataFlowRuntime, **kwargs) -> "DataFlow":
860867
raise ValueError(
861868
"Compartment id is required. Specify compartment id via 'with_compartment_id()'."
862869
)
870+
self._validate_shapes(payload)
871+
payload.pop("id", None)
872+
logger.debug(f"Creating a DataFlow Application with payload {payload}")
873+
self.df_app = DataFlowApp(**payload).create()
874+
self.with_id(self.df_app.id)
875+
return self
876+
877+
@staticmethod
878+
def _validate_shapes(payload: Dict):
863879
if "executor_shape" not in payload:
864880
payload["executor_shape"] = DEFAULT_SHAPE
865881
if "driver_shape" not in payload:
@@ -868,15 +884,22 @@ def create(self, runtime: DataFlowRuntime, **kwargs) -> "DataFlow":
868884
executor_shape_config = payload.get("executor_shape_config", {})
869885
driver_shape = payload["driver_shape"]
870886
driver_shape_config = payload.get("driver_shape_config", {})
871-
if executor_shape != driver_shape:
872-
raise ValueError("`executor_shape` and `driver_shape` must be from the same shape family.")
873-
if (not executor_shape.endswith("Flex") and executor_shape_config) or (not driver_shape.endswith("Flex") and driver_shape_config):
874-
raise ValueError("Shape config is not required for non flex shape from user end.")
875-
payload.pop("id", None)
876-
logger.debug(f"Creating a DataFlow Application with payload {payload}")
877-
self.df_app = DataFlowApp(**payload).create()
878-
self.with_id(self.df_app.id)
879-
return self
887+
same_shape_family = False
888+
for shape in DATAFLOW_SHAPE_FAMILY:
889+
if shape in executor_shape and shape in driver_shape:
890+
same_shape_family = True
891+
break
892+
if not same_shape_family:
893+
raise ValueError(
894+
"`executor_shape` and `driver_shape` must be from the same shape family."
895+
)
896+
if (
897+
(not executor_shape.endswith("Flex") and executor_shape_config)
898+
or (not driver_shape.endswith("Flex") and driver_shape_config)
899+
):
900+
raise ValueError(
901+
"Shape config is not required for non flex shape from user end."
902+
)
880903

881904
@staticmethod
882905
def _upload_file(local_path, bucket, overwrite=False):

tests/unitary/default_setup/jobs/test_jobs_dataflow.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def test_shape_and_details(self, mock_to_dict, mock_client, df):
448448
df.with_driver_shape(
449449
"VM.Standard2.1"
450450
).with_executor_shape(
451-
"VM.Standard2.8"
451+
"VM.Standard.E4.Flex"
452452
)
453453

454454
rt = (
@@ -475,7 +475,7 @@ def test_shape_and_details(self, mock_to_dict, mock_client, df):
475475
memory_in_gbs=SAMPLE_PAYLOAD["driver_shape_config"]["memory_in_gbs"],
476476
ocpus=SAMPLE_PAYLOAD["driver_shape_config"]["ocpus"],
477477
).with_executor_shape(
478-
"VM.Standard2.1"
478+
"VM.Standard2.16"
479479
).with_executor_shape_config(
480480
memory_in_gbs=SAMPLE_PAYLOAD["executor_shape_config"]["memory_in_gbs"],
481481
ocpus=SAMPLE_PAYLOAD["executor_shape_config"]["ocpus"],

0 commit comments

Comments
 (0)