42
42
DEFAULT_SPARK_VERSION = "3.2.1"
43
43
DEFAULT_NUM_EXECUTORS = 1
44
44
DEFAULT_SHAPE = "VM.Standard.E3.Flex"
45
+ DATAFLOW_SHAPE_FAMILY = [
46
+ "Standard.E3" ,
47
+ "Standard.E4" ,
48
+ "Standard3" ,
49
+ "Standard.A1" ,
50
+ "Standard2"
51
+ ]
45
52
46
53
47
54
def conda_pack_name_to_dataflow_config (conda_uri ):
@@ -860,6 +867,15 @@ def create(self, runtime: DataFlowRuntime, **kwargs) -> "DataFlow":
860
867
raise ValueError (
861
868
"Compartment id is required. Specify compartment id via 'with_compartment_id()'."
862
869
)
870
+ self ._validate_shapes (payload )
871
+ payload .pop ("id" , None )
872
+ logger .debug (f"Creating a DataFlow Application with payload { payload } " )
873
+ self .df_app = DataFlowApp (** payload ).create ()
874
+ self .with_id (self .df_app .id )
875
+ return self
876
+
877
+ @staticmethod
878
+ def _validate_shapes (payload : Dict ):
863
879
if "executor_shape" not in payload :
864
880
payload ["executor_shape" ] = DEFAULT_SHAPE
865
881
if "driver_shape" not in payload :
@@ -868,15 +884,22 @@ def create(self, runtime: DataFlowRuntime, **kwargs) -> "DataFlow":
868
884
executor_shape_config = payload .get ("executor_shape_config" , {})
869
885
driver_shape = payload ["driver_shape" ]
870
886
driver_shape_config = payload .get ("driver_shape_config" , {})
871
- if executor_shape != driver_shape :
872
- raise ValueError ("`executor_shape` and `driver_shape` must be from the same shape family." )
873
- if (not executor_shape .endswith ("Flex" ) and executor_shape_config ) or (not driver_shape .endswith ("Flex" ) and driver_shape_config ):
874
- raise ValueError ("Shape config is not required for non flex shape from user end." )
875
- payload .pop ("id" , None )
876
- logger .debug (f"Creating a DataFlow Application with payload { payload } " )
877
- self .df_app = DataFlowApp (** payload ).create ()
878
- self .with_id (self .df_app .id )
879
- return self
887
+ same_shape_family = False
888
+ for shape in DATAFLOW_SHAPE_FAMILY :
889
+ if shape in executor_shape and shape in driver_shape :
890
+ same_shape_family = True
891
+ break
892
+ if not same_shape_family :
893
+ raise ValueError (
894
+ "`executor_shape` and `driver_shape` must be from the same shape family."
895
+ )
896
+ if (
897
+ (not executor_shape .endswith ("Flex" ) and executor_shape_config )
898
+ or (not driver_shape .endswith ("Flex" ) and driver_shape_config )
899
+ ):
900
+ raise ValueError (
901
+ "Shape config is not required for non flex shape from user end."
902
+ )
880
903
881
904
@staticmethod
882
905
def _upload_file (local_path , bucket , overwrite = False ):
0 commit comments